Skip to content

Commit

Permalink
Add selected_rows method to data frame renderer(#1121)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gordon Shotwell authored Feb 14, 2024
1 parent e43e859 commit 6b844e8
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 30 deletions.
9 changes: 3 additions & 6 deletions examples/dataframe/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,10 @@ def handle_edit():

@render.text
def detail():
if (
input.grid_selected_rows() is not None
and len(input.grid_selected_rows()) > 0
):
selected_rows = grid.input_selected_rows() or ()
if len(selected_rows) > 0:
# "split", "records", "index", "columns", "values", "table"

return df().iloc[list(input.grid_selected_rows())]
return df().iloc[list(grid.input_selected_rows())]


app = App(app_ui, server)
6 changes: 4 additions & 2 deletions shiny/api-examples/data_frame/app-core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ def summary_data():

@reactive.calc
def filtered_df():
# input.summary_data_selected_rows() is a tuple, so we must convert it to list,
req(summary_data.input_selected_rows())

# summary_data.selected_rows() is a tuple, so we must convert it to list,
# as that's what Pandas requires for indexing.
selected_idx = list(req(input.summary_data_selected_rows()))
selected_idx = list(summary_data.input_selected_rows())
countries = summary_df.iloc[selected_idx]["country"]
# Filter data for selected countries
return df[df["country"].isin(countries)]
Expand Down
9 changes: 6 additions & 3 deletions shiny/api-examples/data_frame/app-express.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from shinywidgets import render_widget

from shiny import reactive, req
from shiny.express import input, render, ui
from shiny.express import render, ui

# Load the Gapminder dataset
df = px.data.gapminder()
Expand Down Expand Up @@ -66,9 +66,12 @@ def country_detail_percap():

@reactive.calc
def filtered_df():
# input.summary_data_selected_rows() is a tuple, so we must convert it to list,
req(summary_data.input_selected_rows())

# summary_data.input_selected_rows() is a tuple, so we must convert it to list,
# as that's what Pandas requires for indexing.
selected_idx = list(req(input.summary_data_selected_rows()))

selected_idx = list(summary_data.input_selected_rows())
countries = summary_df.iloc[selected_idx]["country"]
# Filter data for selected countries
return df[df["country"].isin(countries)]
24 changes: 21 additions & 3 deletions shiny/render/_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,22 @@

import abc
import json
from typing import TYPE_CHECKING, Any, Literal, Protocol, Union, cast, runtime_checkable
from typing import (
TYPE_CHECKING,
Any,
Literal,
Optional,
Protocol,
Union,
cast,
runtime_checkable,
)

from htmltools import Tag

from .. import ui
from .._docstring import add_example, no_example
from ..session._utils import require_active_session
from ._dataframe_unsafe import serialize_numpy_dtypes
from .renderer import Jsonifiable, Renderer

Expand Down Expand Up @@ -237,8 +247,7 @@ class data_frame(Renderer[DataFrameResult]):
Row selection
-------------
When using the row selection feature, you can access the selected rows by using the
`input.<id>_selected_rows()` function, where `<id>` is the `id` of the
:func:`~shiny.ui.output_data_frame`. The value returned will be `None` if no rows
`<data_frame_renderer>.input_selected_rows()` method, where `<data_frame_renderer>` is the render function name that corresponds with the `id=` used in :func:`~shiny.ui.outout_data_frame`. Internally, this method retrieves the selected row value from session's `input.<id>_selected_rows()` value. The value returned will be `None` if no rows
are selected, or a tuple of integers representing the indices of the selected rows.
To filter a pandas data frame down to the selected rows, use
`df.iloc[list(input.<id>_selected_rows())]`.
Expand Down Expand Up @@ -270,6 +279,15 @@ async def transform(self, value: DataFrameResult) -> Jsonifiable:
)
return value.to_payload()

def input_selected_rows(self) -> Optional[tuple[int]]:
"""
When `row_selection_mode` is set to "single" or "multiple" this will return
a tuple of integers representing the rows selected by a user.
"""

active_session = require_active_session(None)
return active_session.input[self.output_id + "_selected_rows"]()


@runtime_checkable
class PandasCompatible(Protocol):
Expand Down
25 changes: 16 additions & 9 deletions shiny/session/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,10 @@ def __init__(

self._outbound_message_queues = OutBoundMessageQueues()

self._message_handlers: dict[str, Callable[..., Awaitable[object]]] = (
self._create_message_handlers()
)
self._message_handlers: dict[
str,
Callable[..., Awaitable[object]],
] = self._create_message_handlers()
self._file_upload_manager: FileUploadManager = FileUploadManager()
self._on_ended_callbacks = _utils.AsyncCallbacks()
self._has_run_session_end_tasks: bool = False
Expand Down Expand Up @@ -608,22 +609,26 @@ def _send_remove_ui(self, selector: str, multiple: bool) -> None:
@overload
def _send_progress(
self, type: Literal["binding"], message: BindingProgressMessage
) -> None: ...
) -> None:
pass

@overload
def _send_progress(
self, type: Literal["open"], message: OpenProgressMessage
) -> None: ...
) -> None:
pass

@overload
def _send_progress(
self, type: Literal["close"], message: CloseProgressMessage
) -> None: ...
) -> None:
pass

@overload
def _send_progress(
self, type: Literal["update"], message: UpdateProgressMessage
) -> None: ...
) -> None:
pass

def _send_progress(self, type: str, message: object) -> None:
msg: dict[str, object] = {"progress": {"type": type, "message": message}}
Expand Down Expand Up @@ -1033,7 +1038,8 @@ def __init__(
self._suspend_when_hidden = suspend_when_hidden

@overload
def __call__(self, renderer: RendererT) -> RendererT: ...
def __call__(self, renderer: RendererT) -> RendererT:
pass

@overload
def __call__(
Expand All @@ -1042,7 +1048,8 @@ def __call__(
id: Optional[str] = None,
suspend_when_hidden: bool = True,
priority: int = 0,
) -> Callable[[RendererT], RendererT]: ...
) -> Callable[[RendererT], RendererT]:
pass

def __call__(
self,
Expand Down
4 changes: 3 additions & 1 deletion tests/playwright/deploys/plotly/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,11 @@ def summary_data():

@reactive.Calc
def filtered_df():
req(summary_data.input_selected_rows())

# input.summary_data_selected_rows() is a tuple, so we must convert it to list,
# as that's what Pandas requires for indexing.
selected_idx = list(req(input.summary_data_selected_rows()))
selected_idx = list(summary_data.input_selected_rows())
countries = summary_df.iloc[selected_idx]["country"]
# Filter data for selected countries
return df[df["country"].isin(countries)]
Expand Down
9 changes: 3 additions & 6 deletions tests/playwright/shiny/bugs/0676-row-selection/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,9 @@ def grid():

@render.table
def detail():
if (
input.grid_selected_rows() is not None
and len(input.grid_selected_rows()) > 0
):
# "split", "records", "index", "columns", "values", "table"
return df.iloc[list(input.grid_selected_rows())]
selected_rows = grid.input_selected_rows() or ()
if len(selected_rows) > 0:
return df.iloc[list(selected_rows)]

@render.text
def debug():
Expand Down

0 comments on commit 6b844e8

Please sign in to comment.