From ae6c1906e42285984e6b4037e228abfe14eb339b Mon Sep 17 00:00:00 2001 From: Marc 'risson' Schmitt Date: Mon, 6 Oct 2025 20:43:00 +0200 Subject: [PATCH] packages/django-channels-postgres/layer: fix connection deadlock (#17270) --- .../django_channels_postgres/layer.py | 147 +++++------------- .../django-channels-postgres/pyproject.toml | 1 - uv.lock | 2 - 3 files changed, 41 insertions(+), 109 deletions(-) diff --git a/packages/django-channels-postgres/django_channels_postgres/layer.py b/packages/django-channels-postgres/django_channels_postgres/layer.py index 447c0eccd0..14b4f1a817 100644 --- a/packages/django-channels-postgres/django_channels_postgres/layer.py +++ b/packages/django-channels-postgres/django_channels_postgres/layer.py @@ -11,7 +11,6 @@ import msgpack from channels.layers import BaseChannelLayer from django.db import DEFAULT_DB_ALIAS, connections from django.utils.timezone import now -from pglock.core import _cast_lock_id from psycopg import AsyncConnection, Notify, sql from psycopg.conninfo import make_conninfo from psycopg.errors import Error as PsycopgError @@ -136,9 +135,13 @@ class PostgresChannelLoopLayer(BaseChannelLayer): # Each consumer gets its own *specific* channel, created with the `new_channel()` method. # This dict maps `channel_name` to a queue of messages for that channel. - self.channels: dict[str, asyncio.Queue[bytes]] = {} + self.channels: dict[str, asyncio.Queue[tuple[str, bytes | None]]] = {} - self.connection = PostgresChannelLayerConnection(self.using, self) + self.receiver = PostgresChannelLayerReceiver(self.using, self) + + async def _subscribe_to_channel(self, channel: str) -> None: + self.channels[channel] = asyncio.Queue() + await self.receiver.subscribe(channel) extensions = ["groups", "flush"] @@ -165,33 +168,29 @@ class PostgresChannelLoopLayer(BaseChannelLayer): Returns a new channel name that can be used by something in our process as a specific channel. """ - return f"{self.prefix}.{prefix}.{uuid4().hex}" + channel = f"{self.prefix}.{prefix}.{uuid4().hex}" + await self._subscribe_to_channel(channel) + return channel async def receive(self, channel: str) -> dict[str, Any]: """ Receive the first message that arrives on the channel. If more than one coroutine waits on the same channel, the first waiter will be given the message when it arrives. - - This is done by acquiring an `advistory_lock` from the database - based on the channel name. - - If the lock is acquired successfully, subsequent calls to this method - will not try to acquire the lock again. - _The lock is session based and should be released by postgres when - the session is closed_ - - If the lock is already acquired by another coroutine, - subsequent calls to this method will repeatedly try to acquire the lock - before proceeding to wait for a message. """ if channel not in self.channels: - self.channels[channel] = asyncio.Queue() - await self.connection.subscribe(channel) + await self._subscribe_to_channel(channel) q = self.channels[channel] try: - message = await q.get() + while True: + (message_id, message) = await q.get() + if message is None: + m = await Message.objects.filter(pk=message_id).afirst() + if m is None: + continue + message = cast(bytes, m.message) + break except (asyncio.CancelledError, TimeoutError, GeneratorExit): # We assume here that the reason we are cancelled is because the consumer # is exiting, therefore we need to cleanup by unsubscribe below. Indeed, @@ -204,7 +203,7 @@ class PostgresChannelLoopLayer(BaseChannelLayer): if channel in self.channels: del self.channels[channel] try: - await self.connection.unsubscribe(channel) + await self.receiver.unsubscribe(channel) except BaseException as exc: # noqa: BLE001 LOGGER.warning("Unexpected exception while cleaning-up channel", exc=exc) # We don't re-raise here because we want the CancelledError to be the one @@ -286,15 +285,14 @@ class PostgresChannelLoopLayer(BaseChannelLayer): Deletes all messages and groups. """ self.channels = {} - await self.connection.flush() + await self.receiver.flush() -class PostgresChannelLayerConnection: +class PostgresChannelLayerReceiver: def __init__(self, using: str, channel_layer: PostgresChannelLoopLayer) -> None: self.using = using self.channel_layer = channel_layer self._subscribed_to: set[str] = set() - self._locked_channels: set[str] = set() self._lock = asyncio.Lock() self._receive_task: asyncio.Task[None] | None = None @@ -325,33 +323,23 @@ class PostgresChannelLayerConnection: while True: try: async with await self._create_connection() as conn: - await self._update_locks(conn) + await self._process_backlog(conn) + await conn.execute( + sql.SQL("LISTEN {channel}").format(channel=sql.Identifier(NOTIFY_CHANNEL)) + ) while True: - await self._process_backlog(conn) - await conn.execute( - sql.SQL("LISTEN {channel}").format( - channel=sql.Identifier(NOTIFY_CHANNEL) - ) - ) - first_loop = True - async for notify in conn.notifies(stop_after=1, timeout=5): - if first_loop: - await self._update_locks(conn) - first_loop = False - await self._receive_notify(conn, notify) - if first_loop: - await self._update_locks(conn) + async for notify in conn.notifies(timeout=30): + await self._receive_notify(notify) except (asyncio.CancelledError, TimeoutError, GeneratorExit): raise except PsycopgError as exc: LOGGER.warning("Postgres connection is not healthy", exc=exc) except BaseException as exc: # noqa: BLE001 LOGGER.warning("Unexpected exception in receive task", exc=exc, exc_info=True) - self._locked_channels = set() await asyncio.sleep(1) async def _process_backlog(self, conn: AsyncConnection) -> None: - if not self._locked_channels: + if not self._subscribed_to: return async with conn.cursor() as cursor: await cursor.execute( @@ -362,102 +350,49 @@ class PostgresChannelLayerConnection: WHERE {table}.{channel} IN (%s) AND {table}.{expires} >= %s - RETURNING {table}.{channel}, {table}.{message} + RETURNING {table}.{id}, {table}.{channel}, {table}.{message} """ ).format( table=sql.Identifier(MESSAGE_TABLE), + id=sql.Identifier("id"), channel=sql.Identifier("channel"), expires=sql.Identifier("expires"), message=sql.Identifier("message"), ), - (tuple(self._locked_channels), now()), + (tuple(self._subscribed_to), now()), ) async for row in cursor: - channel, message = row - self._receive_message(channel, message) + message_id, channel, message = row + self._receive_message(channel, message_id, message) - def _get_lock_id(self, channel: str) -> int: - lock_id = _cast_lock_id(f"channels.{channel}") # type: ignore[no-untyped-call] - return cast(int, lock_id) - - async def _update_locks(self, conn: AsyncConnection) -> None: - async with self._lock: - locks_to_release = self._locked_channels - self._subscribed_to - locks_to_acquire = self._subscribed_to - self._locked_channels - - for channel in locks_to_acquire: - lock_id = self._get_lock_id(channel) - async with conn.cursor() as cursor: - await cursor.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,)) - row = await cursor.fetchone() - if row is not None: - if row[0]: - self._locked_channels.add(channel) - - for channel in locks_to_release: - lock_id = self._get_lock_id(channel) - async with conn.cursor() as cursor: - await cursor.execute("SELECT pg_advisory_unlock(%s)", (lock_id,)) - row = await cursor.fetchone() - if row is not None: - if row[0]: - self._locked_channels.remove(channel) - - async def _receive_notify(self, conn: AsyncConnection, notify: Notify) -> None: + async def _receive_notify(self, notify: Notify) -> None: payload = notify.payload split_payload = payload.split(":") + message: bytes | None = None match len(split_payload): case 4: message_id, channel, timestamp, base64_message = split_payload - if channel not in self._locked_channels: + if channel not in self._subscribed_to: return expires = datetime.fromtimestamp(float(timestamp), tz=UTC) if expires < now(): return message = b64decode(base64_message) - async with conn.cursor() as cursor: - await cursor.execute( - sql.SQL("DELETE FROM {table} WHERE {table}.{id} = %s").format( - table=sql.Identifier(MESSAGE_TABLE), - id=sql.Identifier("id"), - ) - ) case 3: message_id, channel, timestamp = split_payload - if channel not in self._locked_channels: + if channel not in self._subscribed_to: return expires = datetime.fromtimestamp(float(timestamp), tz=UTC) if expires < now(): return - async with conn.cursor() as cursor: - await cursor.execute( - sql.SQL( - """ - DELETE - FROM {table} - WHERE - {table}.{id} = %s - RETURNING {table}.{message}, {table}.{expires} - """ - ).format( - table=sql.Identifier(MESSAGE_TABLE), - id=sql.Identifier("id"), - message=sql.Identifier("message"), - expires=sql.Identifier("expires"), - ), - (message_id,), - ) - row = await cursor.fetchone() - if row is None: - return - message, expires = row + message = None case _: return - self._receive_message(channel, message) + self._receive_message(channel, message_id, message) - def _receive_message(self, channel: str, message: bytes) -> None: + def _receive_message(self, channel: str, message_id: str, message: bytes | None) -> None: if (q := self.channel_layer.channels.get(channel)) is not None: - q.put_nowait(message) + q.put_nowait((message_id, message)) def _ensure_receiver(self) -> None: if self._receive_task is None: diff --git a/packages/django-channels-postgres/pyproject.toml b/packages/django-channels-postgres/pyproject.toml index d15b1fe0f9..45da6a091f 100644 --- a/packages/django-channels-postgres/pyproject.toml +++ b/packages/django-channels-postgres/pyproject.toml @@ -33,7 +33,6 @@ classifiers = [ dependencies = [ "channels >=4.3,<4.4", "django >=4.2,<6.0", - "django-pglock >=1.7,<2", "django-pgtrigger >=4,<5", "msgpack >=1,<2", "psycopg >=3,<4", diff --git a/uv.lock b/uv.lock index e18f7a41f8..9ec7842083 100644 --- a/uv.lock +++ b/uv.lock @@ -989,7 +989,6 @@ source = { editable = "packages/django-channels-postgres" } dependencies = [ { name = "channels" }, { name = "django" }, - { name = "django-pglock" }, { name = "django-pgtrigger" }, { name = "msgpack" }, { name = "psycopg" }, @@ -1000,7 +999,6 @@ dependencies = [ requires-dist = [ { name = "channels", specifier = ">=4.3,<4.4" }, { name = "django", specifier = ">=4.2,<6.0" }, - { name = "django-pglock", specifier = ">=1.7,<2" }, { name = "django-pgtrigger", specifier = ">=4,<5" }, { name = "msgpack", specifier = ">=1,<2" }, { name = "psycopg", specifier = ">=3,<4" },