mirror of
https://github.com/goauthentik/authentik.git
synced 2026-06-17 19:09:11 +03:00
packages/django-channels-postgres/layer: fix connection deadlock (#17270)
This commit is contained in:
committed by
GitHub
parent
2108575b73
commit
ae6c1906e4
@@ -11,7 +11,6 @@ import msgpack
|
||||
from channels.layers import BaseChannelLayer
|
||||
from django.db import DEFAULT_DB_ALIAS, connections
|
||||
from django.utils.timezone import now
|
||||
from pglock.core import _cast_lock_id
|
||||
from psycopg import AsyncConnection, Notify, sql
|
||||
from psycopg.conninfo import make_conninfo
|
||||
from psycopg.errors import Error as PsycopgError
|
||||
@@ -136,9 +135,13 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
|
||||
|
||||
# Each consumer gets its own *specific* channel, created with the `new_channel()` method.
|
||||
# This dict maps `channel_name` to a queue of messages for that channel.
|
||||
self.channels: dict[str, asyncio.Queue[bytes]] = {}
|
||||
self.channels: dict[str, asyncio.Queue[tuple[str, bytes | None]]] = {}
|
||||
|
||||
self.connection = PostgresChannelLayerConnection(self.using, self)
|
||||
self.receiver = PostgresChannelLayerReceiver(self.using, self)
|
||||
|
||||
async def _subscribe_to_channel(self, channel: str) -> None:
|
||||
self.channels[channel] = asyncio.Queue()
|
||||
await self.receiver.subscribe(channel)
|
||||
|
||||
extensions = ["groups", "flush"]
|
||||
|
||||
@@ -165,33 +168,29 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
|
||||
Returns a new channel name that can be used by something in our
|
||||
process as a specific channel.
|
||||
"""
|
||||
return f"{self.prefix}.{prefix}.{uuid4().hex}"
|
||||
channel = f"{self.prefix}.{prefix}.{uuid4().hex}"
|
||||
await self._subscribe_to_channel(channel)
|
||||
return channel
|
||||
|
||||
async def receive(self, channel: str) -> dict[str, Any]:
|
||||
"""
|
||||
Receive the first message that arrives on the channel.
|
||||
If more than one coroutine waits on the same channel, the first waiter
|
||||
will be given the message when it arrives.
|
||||
|
||||
This is done by acquiring an `advisory_lock` from the database
|
||||
based on the channel name.
|
||||
|
||||
If the lock is acquired successfully, subsequent calls to this method
|
||||
will not try to acquire the lock again.
|
||||
_The lock is session based and should be released by postgres when
|
||||
the session is closed_
|
||||
|
||||
If the lock is already acquired by another coroutine,
|
||||
subsequent calls to this method will repeatedly try to acquire the lock
|
||||
before proceeding to wait for a message.
|
||||
"""
|
||||
if channel not in self.channels:
|
||||
self.channels[channel] = asyncio.Queue()
|
||||
await self.connection.subscribe(channel)
|
||||
await self._subscribe_to_channel(channel)
|
||||
|
||||
q = self.channels[channel]
|
||||
try:
|
||||
message = await q.get()
|
||||
while True:
|
||||
(message_id, message) = await q.get()
|
||||
if message is None:
|
||||
m = await Message.objects.filter(pk=message_id).afirst()
|
||||
if m is None:
|
||||
continue
|
||||
message = cast(bytes, m.message)
|
||||
break
|
||||
except (asyncio.CancelledError, TimeoutError, GeneratorExit):
|
||||
# We assume here that the reason we are cancelled is because the consumer
|
||||
# is exiting, therefore we need to cleanup by unsubscribe below. Indeed,
|
||||
@@ -204,7 +203,7 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
|
||||
if channel in self.channels:
|
||||
del self.channels[channel]
|
||||
try:
|
||||
await self.connection.unsubscribe(channel)
|
||||
await self.receiver.unsubscribe(channel)
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
LOGGER.warning("Unexpected exception while cleaning-up channel", exc=exc)
|
||||
# We don't re-raise here because we want the CancelledError to be the one
|
||||
@@ -286,15 +285,14 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
|
||||
Deletes all messages and groups.
|
||||
"""
|
||||
self.channels = {}
|
||||
await self.connection.flush()
|
||||
await self.receiver.flush()
|
||||
|
||||
|
||||
class PostgresChannelLayerConnection:
|
||||
class PostgresChannelLayerReceiver:
|
||||
def __init__(self, using: str, channel_layer: PostgresChannelLoopLayer) -> None:
|
||||
self.using = using
|
||||
self.channel_layer = channel_layer
|
||||
self._subscribed_to: set[str] = set()
|
||||
self._locked_channels: set[str] = set()
|
||||
self._lock = asyncio.Lock()
|
||||
self._receive_task: asyncio.Task[None] | None = None
|
||||
|
||||
@@ -325,33 +323,23 @@ class PostgresChannelLayerConnection:
|
||||
while True:
|
||||
try:
|
||||
async with await self._create_connection() as conn:
|
||||
await self._update_locks(conn)
|
||||
await self._process_backlog(conn)
|
||||
await conn.execute(
|
||||
sql.SQL("LISTEN {channel}").format(channel=sql.Identifier(NOTIFY_CHANNEL))
|
||||
)
|
||||
while True:
|
||||
await self._process_backlog(conn)
|
||||
await conn.execute(
|
||||
sql.SQL("LISTEN {channel}").format(
|
||||
channel=sql.Identifier(NOTIFY_CHANNEL)
|
||||
)
|
||||
)
|
||||
first_loop = True
|
||||
async for notify in conn.notifies(stop_after=1, timeout=5):
|
||||
if first_loop:
|
||||
await self._update_locks(conn)
|
||||
first_loop = False
|
||||
await self._receive_notify(conn, notify)
|
||||
if first_loop:
|
||||
await self._update_locks(conn)
|
||||
async for notify in conn.notifies(timeout=30):
|
||||
await self._receive_notify(notify)
|
||||
except (asyncio.CancelledError, TimeoutError, GeneratorExit):
|
||||
raise
|
||||
except PsycopgError as exc:
|
||||
LOGGER.warning("Postgres connection is not healthy", exc=exc)
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
LOGGER.warning("Unexpected exception in receive task", exc=exc, exc_info=True)
|
||||
self._locked_channels = set()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _process_backlog(self, conn: AsyncConnection) -> None:
|
||||
if not self._locked_channels:
|
||||
if not self._subscribed_to:
|
||||
return
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
@@ -362,102 +350,49 @@ class PostgresChannelLayerConnection:
|
||||
WHERE
|
||||
{table}.{channel} IN (%s)
|
||||
AND {table}.{expires} >= %s
|
||||
RETURNING {table}.{channel}, {table}.{message}
|
||||
RETURNING {table}.{id}, {table}.{channel}, {table}.{message}
|
||||
"""
|
||||
).format(
|
||||
table=sql.Identifier(MESSAGE_TABLE),
|
||||
id=sql.Identifier("id"),
|
||||
channel=sql.Identifier("channel"),
|
||||
expires=sql.Identifier("expires"),
|
||||
message=sql.Identifier("message"),
|
||||
),
|
||||
(tuple(self._locked_channels), now()),
|
||||
(tuple(self._subscribed_to), now()),
|
||||
)
|
||||
async for row in cursor:
|
||||
channel, message = row
|
||||
self._receive_message(channel, message)
|
||||
message_id, channel, message = row
|
||||
self._receive_message(channel, message_id, message)
|
||||
|
||||
def _get_lock_id(self, channel: str) -> int:
|
||||
lock_id = _cast_lock_id(f"channels.{channel}") # type: ignore[no-untyped-call]
|
||||
return cast(int, lock_id)
|
||||
|
||||
async def _update_locks(self, conn: AsyncConnection) -> None:
|
||||
async with self._lock:
|
||||
locks_to_release = self._locked_channels - self._subscribed_to
|
||||
locks_to_acquire = self._subscribed_to - self._locked_channels
|
||||
|
||||
for channel in locks_to_acquire:
|
||||
lock_id = self._get_lock_id(channel)
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row is not None:
|
||||
if row[0]:
|
||||
self._locked_channels.add(channel)
|
||||
|
||||
for channel in locks_to_release:
|
||||
lock_id = self._get_lock_id(channel)
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row is not None:
|
||||
if row[0]:
|
||||
self._locked_channels.remove(channel)
|
||||
|
||||
async def _receive_notify(self, conn: AsyncConnection, notify: Notify) -> None:
|
||||
async def _receive_notify(self, notify: Notify) -> None:
|
||||
payload = notify.payload
|
||||
split_payload = payload.split(":")
|
||||
message: bytes | None = None
|
||||
match len(split_payload):
|
||||
case 4:
|
||||
message_id, channel, timestamp, base64_message = split_payload
|
||||
if channel not in self._locked_channels:
|
||||
if channel not in self._subscribed_to:
|
||||
return
|
||||
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
|
||||
if expires < now():
|
||||
return
|
||||
message = b64decode(base64_message)
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
sql.SQL("DELETE FROM {table} WHERE {table}.{id} = %s").format(
|
||||
table=sql.Identifier(MESSAGE_TABLE),
|
||||
id=sql.Identifier("id"),
|
||||
)
|
||||
)
|
||||
case 3:
|
||||
message_id, channel, timestamp = split_payload
|
||||
if channel not in self._locked_channels:
|
||||
if channel not in self._subscribed_to:
|
||||
return
|
||||
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
|
||||
if expires < now():
|
||||
return
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
sql.SQL(
|
||||
"""
|
||||
DELETE
|
||||
FROM {table}
|
||||
WHERE
|
||||
{table}.{id} = %s
|
||||
RETURNING {table}.{message}, {table}.{expires}
|
||||
"""
|
||||
).format(
|
||||
table=sql.Identifier(MESSAGE_TABLE),
|
||||
id=sql.Identifier("id"),
|
||||
message=sql.Identifier("message"),
|
||||
expires=sql.Identifier("expires"),
|
||||
),
|
||||
(message_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return
|
||||
message, expires = row
|
||||
message = None
|
||||
case _:
|
||||
return
|
||||
self._receive_message(channel, message)
|
||||
self._receive_message(channel, message_id, message)
|
||||
|
||||
def _receive_message(self, channel: str, message: bytes) -> None:
|
||||
def _receive_message(self, channel: str, message_id: str, message: bytes | None) -> None:
|
||||
if (q := self.channel_layer.channels.get(channel)) is not None:
|
||||
q.put_nowait(message)
|
||||
q.put_nowait((message_id, message))
|
||||
|
||||
def _ensure_receiver(self) -> None:
|
||||
if self._receive_task is None:
|
||||
|
||||
@@ -33,7 +33,6 @@ classifiers = [
|
||||
dependencies = [
|
||||
"channels >=4.3,<4.4",
|
||||
"django >=4.2,<6.0",
|
||||
"django-pglock >=1.7,<2",
|
||||
"django-pgtrigger >=4,<5",
|
||||
"msgpack >=1,<2",
|
||||
"psycopg >=3,<4",
|
||||
|
||||
@@ -989,7 +989,6 @@ source = { editable = "packages/django-channels-postgres" }
|
||||
dependencies = [
|
||||
{ name = "channels" },
|
||||
{ name = "django" },
|
||||
{ name = "django-pglock" },
|
||||
{ name = "django-pgtrigger" },
|
||||
{ name = "msgpack" },
|
||||
{ name = "psycopg" },
|
||||
@@ -1000,7 +999,6 @@ dependencies = [
|
||||
requires-dist = [
|
||||
{ name = "channels", specifier = ">=4.3,<4.4" },
|
||||
{ name = "django", specifier = ">=4.2,<6.0" },
|
||||
{ name = "django-pglock", specifier = ">=1.7,<2" },
|
||||
{ name = "django-pgtrigger", specifier = ">=4,<5" },
|
||||
{ name = "msgpack", specifier = ">=1,<2" },
|
||||
{ name = "psycopg", specifier = ">=3,<4" },
|
||||
|
||||
Reference in New Issue
Block a user