Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions shiny/session/_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .._fileupload import FileInfo, FileUploadManager
from .._namespaces import Id, ResolvedId, Root
from .._typing_extensions import TypedDict
from .._utils import wrap_async
from ..http_staticfiles import FileResponse
from ..input_handler import input_handlers
from ..reactive import Effect, Effect_, Value, flush, isolate
Expand Down Expand Up @@ -196,15 +197,15 @@ def __init__(
str, Callable[..., Awaitable[object]]
] = self._create_message_handlers()
self._file_upload_manager: FileUploadManager = FileUploadManager()
self._on_ended_callbacks = _utils.Callbacks()
self._on_ended_callbacks = _utils.AsyncCallbacks()
self._has_run_session_end_tasks: bool = False
self._downloads: dict[str, DownloadInfo] = {}
self._dynamic_routes: dict[str, DynamicRouteHandler] = {}

self._register_session_end_callbacks()

self._flush_callbacks = _utils.Callbacks()
self._flushed_callbacks = _utils.Callbacks()
self._flush_callbacks = _utils.AsyncCallbacks()
self._flushed_callbacks = _utils.AsyncCallbacks()

def _register_session_end_callbacks(self) -> None:
# This is to be called from the initialization. It registers functions
Expand All @@ -213,13 +214,13 @@ def _register_session_end_callbacks(self) -> None:
# Clear file upload directories, if present
self.on_ended(self._file_upload_manager.rm_upload_dir)

def _run_session_end_tasks(self) -> None:
async def _run_session_end_tasks(self) -> None:
if self._has_run_session_end_tasks:
return
self._has_run_session_end_tasks = True

try:
self._on_ended_callbacks.invoke()
await self._on_ended_callbacks.invoke()
finally:
self.app._remove_session(self)

Expand All @@ -228,7 +229,7 @@ async def close(self, code: int = 1001) -> None:
Close the session.
"""
await self._conn.close(code, None)
self._run_session_end_tasks()
await self._run_session_end_tasks()

async def _run(self) -> None:
conn_state: ConnectionState = ConnectionState.Start
Expand Down Expand Up @@ -318,7 +319,7 @@ def verify_state(expected_state: ConnectionState) -> None:
finally:
await self.close()
finally:
self._run_session_end_tasks()
await self._run_session_end_tasks()

def _manage_inputs(self, data: dict[str, object]) -> None:
for key, val in data.items():
Expand Down Expand Up @@ -632,7 +633,11 @@ def _send_error_response(self, message_str: str) -> None:
# Flush
# ==========================================================================
@add_example()
def on_flush(self, fn: Callable[[], None], once: bool = True) -> Callable[[], None]:
def on_flush(
self,
fn: Callable[[], None] | Callable[[], Awaitable[None]],
once: bool = True,
) -> Callable[[], None]:
"""
Register a function to call before the next reactive flush.

Expand All @@ -648,11 +653,13 @@ def on_flush(self, fn: Callable[[], None], once: bool = True) -> Callable[[], No
:
A function that can be used to cancel the registration.
"""
return self._flush_callbacks.register(fn, once)
return self._flush_callbacks.register(wrap_async(fn), once)

@add_example()
def on_flushed(
self, fn: Callable[[], None], once: bool = True
self,
fn: Callable[[], None] | Callable[[], Awaitable[None]],
once: bool = True,
) -> Callable[[], None]:
"""
Register a function to call after the next reactive flush.
Expand All @@ -669,14 +676,14 @@ def on_flushed(
:
A function that can be used to cancel the registration.
"""
return self._flushed_callbacks.register(fn, once)
return self._flushed_callbacks.register(wrap_async(fn), once)

def _request_flush(self) -> None:
self.app._request_flush(self)

async def _flush(self) -> None:
with session_context(self):
self._flush_callbacks.invoke()
await self._flush_callbacks.invoke()

try:
omq = self._outbound_message_queues
Expand All @@ -701,13 +708,16 @@ async def _flush(self) -> None:
self._outbound_message_queues = empty_outbound_message_queues()
finally:
with session_context(self):
self._flushed_callbacks.invoke()
await self._flushed_callbacks.invoke()

# ==========================================================================
# On session ended
# ==========================================================================
@add_example()
def on_ended(self, fn: Callable[[], None]) -> Callable[[], None]:
def on_ended(
self,
fn: Callable[[], None] | Callable[[], Awaitable[None]],
) -> Callable[[], None]:
"""
Registers a function to be called after the client has disconnected.

Expand All @@ -721,7 +731,7 @@ def on_ended(self, fn: Callable[[], None]) -> Callable[[], None]:
:
A function that can be used to cancel the registration.
"""
return self._on_ended_callbacks.register(fn)
return self._on_ended_callbacks.register(wrap_async(fn))

# ==========================================================================
# Misc
Expand Down