mirror of
https://github.com/goauthentik/authentik.git
synced 2026-06-17 19:09:11 +03:00
packages/django-channels-postgres: provide sync API for group_send (cherry-pick #20740 to version-2026.2) (#20741)
packages/django-channels-postgres: provide sync API for group_send (#20740) Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
committed by
GitHub
parent
093e60c753
commit
0d8f366af8
@@ -7,7 +7,6 @@ from socket import gethostname
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.core.cache import cache
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
@@ -159,7 +158,7 @@ def outpost_send_update(pk: Any):
|
||||
layer = get_channel_layer()
|
||||
group = build_outpost_group(outpost.pk)
|
||||
LOGGER.debug("sending update", channel=group, outpost=outpost)
|
||||
async_to_sync(layer.group_send)(group, {"type": "event.update"})
|
||||
layer.group_send_blocking(group, {"type": "event.update"})
|
||||
|
||||
|
||||
@actor(description=_("Checks the local environment and create Service connections."))
|
||||
@@ -210,7 +209,7 @@ def outpost_session_end(session_id: str):
|
||||
for outpost in Outpost.objects.all():
|
||||
LOGGER.info("Sending session end signal to outpost", outpost=outpost)
|
||||
group = build_outpost_group(outpost.pk)
|
||||
async_to_sync(layer.group_send)(
|
||||
layer.group_send_blocking(
|
||||
group,
|
||||
{
|
||||
"type": "event.session.end",
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""proxy provider tasks"""
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from dramatiq.actor import actor
|
||||
@@ -16,7 +15,7 @@ def proxy_on_logout(session_id: str):
|
||||
hashed_session_id = hash_session_key(session_id)
|
||||
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
|
||||
group = build_outpost_group(outpost.pk)
|
||||
async_to_sync(layer.group_send)(
|
||||
layer.group_send_blocking(
|
||||
group,
|
||||
{
|
||||
"type": "event.provider.specific",
|
||||
|
||||
@@ -36,7 +36,7 @@ async def _async_proxy(
|
||||
) -> Any:
|
||||
# Must be defined as a function and not a method due to
|
||||
# https://bugs.python.org/issue38364
|
||||
layer = obj._get_layer()
|
||||
layer = obj._get_layer(allow_sync=False)
|
||||
return await getattr(layer, name)(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class PostgresChannelLayerLoopProxy:
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
self._kwargs["channel_layer"] = self
|
||||
self._layers: dict[asyncio.AbstractEventLoop, PostgresChannelLoopLayer] = {}
|
||||
self._layers: dict[asyncio.AbstractEventLoop | None, PostgresChannelLoopLayer] = {}
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in (
|
||||
@@ -77,7 +77,7 @@ class PostgresChannelLayerLoopProxy:
|
||||
):
|
||||
return functools.partial(_async_proxy, self, name)
|
||||
else:
|
||||
return getattr(self._get_layer(), name)
|
||||
return getattr(self._get_layer(allow_sync=True), name)
|
||||
|
||||
def serialize(self, message: dict[str, Any]) -> bytes:
|
||||
"""Serializes message to a byte string."""
|
||||
@@ -90,15 +90,23 @@ class PostgresChannelLayerLoopProxy:
|
||||
m = zlib.decompress(message)
|
||||
return cast(dict[str, Any], msgpack.unpackb(m, raw=False))
|
||||
|
||||
def _get_layer(self) -> PostgresChannelLoopLayer:
|
||||
loop = asyncio.get_running_loop()
|
||||
def _get_layer(self, allow_sync: bool) -> PostgresChannelLoopLayer:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError as exc:
|
||||
if allow_sync:
|
||||
# No loop configured, we will only allow sync APIs
|
||||
loop = None
|
||||
else:
|
||||
raise exc
|
||||
|
||||
try:
|
||||
layer = self._layers[loop]
|
||||
except KeyError:
|
||||
layer = PostgresChannelLoopLayer(*self._args, **self._kwargs)
|
||||
self._layers[loop] = layer
|
||||
_wrap_close(self, loop)
|
||||
if loop is not None:
|
||||
_wrap_close(self, loop)
|
||||
|
||||
return layer
|
||||
|
||||
@@ -396,6 +404,50 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
|
||||
messages,
|
||||
)
|
||||
|
||||
def group_send_blocking(self, group: str, message: dict[str, Any]) -> None:
|
||||
"""
|
||||
Sends a message to the entire group, blocking version.
|
||||
"""
|
||||
assert self.require_valid_group_name(group), "Group name not valid" # nosec
|
||||
|
||||
group_key = self._group_key(group)
|
||||
|
||||
serialized_message = self.channel_layer.serialize(message)
|
||||
|
||||
with connections[self.using].cursor() as cursor:
|
||||
cursor.execute(
|
||||
sql.SQL("""
|
||||
SELECT DISTINCT {table}.{channel}
|
||||
FROM {table}
|
||||
WHERE {table}.{group_key} = %s
|
||||
""").format(
|
||||
table=sql.Identifier(GROUP_CHANNEL_TABLE),
|
||||
channel=sql.Identifier("channel"),
|
||||
group_key=sql.Identifier("group_key"),
|
||||
),
|
||||
(group_key,),
|
||||
)
|
||||
channels = [row[0] for row in cursor.fetchall()]
|
||||
messages = [
|
||||
(uuid4(), channel, serialized_message, now() + timedelta(seconds=self.expiry))
|
||||
for channel in channels
|
||||
]
|
||||
with connections[self.using].cursor() as cursor:
|
||||
cursor.executemany(
|
||||
sql.SQL("""
|
||||
INSERT INTO {table}
|
||||
({id}, {channel}, {message}, {expires})
|
||||
VALUES (%s, %s, %s, %s)
|
||||
""").format(
|
||||
table=sql.Identifier(MESSAGE_TABLE),
|
||||
id=sql.Identifier("id"),
|
||||
channel=sql.Identifier("channel"),
|
||||
message=sql.Identifier("message"),
|
||||
expires=sql.Identifier("expires"),
|
||||
),
|
||||
messages,
|
||||
)
|
||||
|
||||
def _group_key(self, group: str) -> str:
|
||||
"""
|
||||
Common function to make the storage key for the group.
|
||||
|
||||
Reference in New Issue
Block a user