packages/django-channels-postgres: provide sync API for group_send (cherry-pick #20740 to version-2026.2) (#20741)

packages/django-channels-postgres: provide sync API for group_send (#20740)

Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
authentik-automation[bot]
2026-03-05 19:02:49 +01:00
committed by GitHub
parent 093e60c753
commit 0d8f366af8
3 changed files with 61 additions and 11 deletions
+2 -3
View File
@@ -7,7 +7,6 @@ from socket import gethostname
from typing import Any
from urllib.parse import urlparse
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.core.cache import cache
from django.utils.translation import gettext_lazy as _
@@ -159,7 +158,7 @@ def outpost_send_update(pk: Any):
layer = get_channel_layer()
group = build_outpost_group(outpost.pk)
LOGGER.debug("sending update", channel=group, outpost=outpost)
async_to_sync(layer.group_send)(group, {"type": "event.update"})
layer.group_send_blocking(group, {"type": "event.update"})
@actor(description=_("Checks the local environment and create Service connections."))
@@ -210,7 +209,7 @@ def outpost_session_end(session_id: str):
for outpost in Outpost.objects.all():
LOGGER.info("Sending session end signal to outpost", outpost=outpost)
group = build_outpost_group(outpost.pk)
async_to_sync(layer.group_send)(
layer.group_send_blocking(
group,
{
"type": "event.session.end",
+1 -2
View File
@@ -1,6 +1,5 @@
"""proxy provider tasks"""
from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer
from django.utils.translation import gettext_lazy as _
from dramatiq.actor import actor
@@ -16,7 +15,7 @@ def proxy_on_logout(session_id: str):
hashed_session_id = hash_session_key(session_id)
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = build_outpost_group(outpost.pk)
async_to_sync(layer.group_send)(
layer.group_send_blocking(
group,
{
"type": "event.provider.specific",
@@ -36,7 +36,7 @@ async def _async_proxy(
) -> Any:
# Must be defined as a function and not a method due to
# https://bugs.python.org/issue38364
layer = obj._get_layer()
layer = obj._get_layer(allow_sync=False)
return await getattr(layer, name)(*args, **kwargs)
@@ -63,7 +63,7 @@ class PostgresChannelLayerLoopProxy:
self._args = args
self._kwargs = kwargs
self._kwargs["channel_layer"] = self
self._layers: dict[asyncio.AbstractEventLoop, PostgresChannelLoopLayer] = {}
self._layers: dict[asyncio.AbstractEventLoop | None, PostgresChannelLoopLayer] = {}
def __getattr__(self, name: str) -> Any:
if name in (
@@ -77,7 +77,7 @@ class PostgresChannelLayerLoopProxy:
):
return functools.partial(_async_proxy, self, name)
else:
return getattr(self._get_layer(), name)
return getattr(self._get_layer(allow_sync=True), name)
def serialize(self, message: dict[str, Any]) -> bytes:
"""Serializes message to a byte string."""
@@ -90,15 +90,23 @@ class PostgresChannelLayerLoopProxy:
m = zlib.decompress(message)
return cast(dict[str, Any], msgpack.unpackb(m, raw=False))
def _get_layer(self) -> PostgresChannelLoopLayer:
loop = asyncio.get_running_loop()
def _get_layer(self, allow_sync: bool) -> PostgresChannelLoopLayer:
try:
loop = asyncio.get_running_loop()
except RuntimeError as exc:
if allow_sync:
# No loop configured, we will only allow sync APIs
loop = None
else:
raise exc
try:
layer = self._layers[loop]
except KeyError:
layer = PostgresChannelLoopLayer(*self._args, **self._kwargs)
self._layers[loop] = layer
_wrap_close(self, loop)
if loop is not None:
_wrap_close(self, loop)
return layer
@@ -396,6 +404,50 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
messages,
)
def group_send_blocking(self, group: str, message: dict[str, Any]) -> None:
"""
Sends a message to the entire group, blocking version.
"""
assert self.require_valid_group_name(group), "Group name not valid" # nosec
group_key = self._group_key(group)
serialized_message = self.channel_layer.serialize(message)
with connections[self.using].cursor() as cursor:
cursor.execute(
sql.SQL("""
SELECT DISTINCT {table}.{channel}
FROM {table}
WHERE {table}.{group_key} = %s
""").format(
table=sql.Identifier(GROUP_CHANNEL_TABLE),
channel=sql.Identifier("channel"),
group_key=sql.Identifier("group_key"),
),
(group_key,),
)
channels = [row[0] for row in cursor.fetchall()]
messages = [
(uuid4(), channel, serialized_message, now() + timedelta(seconds=self.expiry))
for channel in channels
]
with connections[self.using].cursor() as cursor:
cursor.executemany(
sql.SQL("""
INSERT INTO {table}
({id}, {channel}, {message}, {expires})
VALUES (%s, %s, %s, %s)
""").format(
table=sql.Identifier(MESSAGE_TABLE),
id=sql.Identifier("id"),
channel=sql.Identifier("channel"),
message=sql.Identifier("message"),
expires=sql.Identifier("expires"),
),
messages,
)
def _group_key(self, group: str) -> str:
"""
Common function to make the storage key for the group.