From 0d8f366af86d8074a477ec7c132acb6f680972e8 Mon Sep 17 00:00:00 2001 From: "authentik-automation[bot]" <135050075+authentik-automation[bot]@users.noreply.github.com> Date: Thu, 5 Mar 2026 19:02:49 +0100 Subject: [PATCH] packages/django-channels-postgres: provide sync API for group_send (cherry-pick #20740 to version-2026.2) (#20741) packages/django-channels-postgres: provide sync API for group_send (#20740) Co-authored-by: Marc 'risson' Schmitt --- authentik/outposts/tasks.py | 5 +- authentik/providers/proxy/tasks.py | 3 +- .../django_channels_postgres/layer.py | 64 +++++++++++++++++-- 3 files changed, 61 insertions(+), 11 deletions(-) diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index 65dac91a77..5a5f3b2497 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -7,7 +7,6 @@ from socket import gethostname from typing import Any from urllib.parse import urlparse -from asgiref.sync import async_to_sync from channels.layers import get_channel_layer from django.core.cache import cache from django.utils.translation import gettext_lazy as _ @@ -159,7 +158,7 @@ def outpost_send_update(pk: Any): layer = get_channel_layer() group = build_outpost_group(outpost.pk) LOGGER.debug("sending update", channel=group, outpost=outpost) - async_to_sync(layer.group_send)(group, {"type": "event.update"}) + layer.group_send_blocking(group, {"type": "event.update"}) @actor(description=_("Checks the local environment and create Service connections.")) @@ -210,7 +209,7 @@ def outpost_session_end(session_id: str): for outpost in Outpost.objects.all(): LOGGER.info("Sending session end signal to outpost", outpost=outpost) group = build_outpost_group(outpost.pk) - async_to_sync(layer.group_send)( + layer.group_send_blocking( group, { "type": "event.session.end", diff --git a/authentik/providers/proxy/tasks.py b/authentik/providers/proxy/tasks.py index 1891afa00a..a5bde34f40 100644 --- a/authentik/providers/proxy/tasks.py +++ b/authentik/providers/proxy/tasks.py @@ -1,6 +1,5 @@ """proxy provider tasks""" -from asgiref.sync import async_to_sync from channels.layers import get_channel_layer from django.utils.translation import gettext_lazy as _ from dramatiq.actor import actor @@ -16,7 +15,7 @@ def proxy_on_logout(session_id: str): hashed_session_id = hash_session_key(session_id) for outpost in Outpost.objects.filter(type=OutpostType.PROXY): group = build_outpost_group(outpost.pk) - async_to_sync(layer.group_send)( + layer.group_send_blocking( group, { "type": "event.provider.specific", diff --git a/packages/django-channels-postgres/django_channels_postgres/layer.py b/packages/django-channels-postgres/django_channels_postgres/layer.py index 8afc03866b..9adc65b798 100644 --- a/packages/django-channels-postgres/django_channels_postgres/layer.py +++ b/packages/django-channels-postgres/django_channels_postgres/layer.py @@ -36,7 +36,7 @@ async def _async_proxy( ) -> Any: # Must be defined as a function and not a method due to # https://bugs.python.org/issue38364 - layer = obj._get_layer() + layer = obj._get_layer(allow_sync=False) return await getattr(layer, name)(*args, **kwargs) @@ -63,7 +63,7 @@ class PostgresChannelLayerLoopProxy: self._args = args self._kwargs = kwargs self._kwargs["channel_layer"] = self - self._layers: dict[asyncio.AbstractEventLoop, PostgresChannelLoopLayer] = {} + self._layers: dict[asyncio.AbstractEventLoop | None, PostgresChannelLoopLayer] = {} def __getattr__(self, name: str) -> Any: if name in ( @@ -77,7 +77,7 @@ class PostgresChannelLayerLoopProxy: ): return functools.partial(_async_proxy, self, name) else: - return getattr(self._get_layer(), name) + return getattr(self._get_layer(allow_sync=True), name) def serialize(self, message: dict[str, Any]) -> bytes: """Serializes message to a byte string.""" @@ -90,15 +90,23 @@ class PostgresChannelLayerLoopProxy: m = zlib.decompress(message) return cast(dict[str, Any], msgpack.unpackb(m, raw=False)) - def _get_layer(self) -> PostgresChannelLoopLayer: - loop = asyncio.get_running_loop() + def _get_layer(self, allow_sync: bool) -> PostgresChannelLoopLayer: + try: + loop = asyncio.get_running_loop() + except RuntimeError as exc: + if allow_sync: + # No loop configured, we will only allow sync APIs + loop = None + else: + raise exc try: layer = self._layers[loop] except KeyError: layer = PostgresChannelLoopLayer(*self._args, **self._kwargs) self._layers[loop] = layer - _wrap_close(self, loop) + if loop is not None: + _wrap_close(self, loop) return layer @@ -396,6 +404,50 @@ class PostgresChannelLoopLayer(BaseChannelLayer): messages, ) + def group_send_blocking(self, group: str, message: dict[str, Any]) -> None: + """ + Sends a message to the entire group, blocking version. + """ + assert self.require_valid_group_name(group), "Group name not valid" # nosec + + group_key = self._group_key(group) + + serialized_message = self.channel_layer.serialize(message) + + with connections[self.using].cursor() as cursor: + cursor.execute( + sql.SQL(""" + SELECT DISTINCT {table}.{channel} + FROM {table} + WHERE {table}.{group_key} = %s + """).format( + table=sql.Identifier(GROUP_CHANNEL_TABLE), + channel=sql.Identifier("channel"), + group_key=sql.Identifier("group_key"), + ), + (group_key,), + ) + channels = [row[0] for row in cursor.fetchall()] + messages = [ + (uuid4(), channel, serialized_message, now() + timedelta(seconds=self.expiry)) + for channel in channels + ] + with connections[self.using].cursor() as cursor: + cursor.executemany( + sql.SQL(""" + INSERT INTO {table} + ({id}, {channel}, {message}, {expires}) + VALUES (%s, %s, %s, %s) + """).format( + table=sql.Identifier(MESSAGE_TABLE), + id=sql.Identifier("id"), + channel=sql.Identifier("channel"), + message=sql.Identifier("message"), + expires=sql.Identifier("expires"), + ), + messages, + ) + def _group_key(self, group: str) -> str: """ Common function to make the storage key for the group.