packages/django-channels-postgres/layer: fix connection deadlock (#17270)

This commit is contained in:
Marc 'risson' Schmitt
2025-10-06 20:43:00 +02:00
committed by GitHub
parent 2108575b73
commit ae6c1906e4
3 changed files with 41 additions and 109 deletions
@@ -11,7 +11,6 @@ import msgpack
from channels.layers import BaseChannelLayer
from django.db import DEFAULT_DB_ALIAS, connections
from django.utils.timezone import now
from pglock.core import _cast_lock_id
from psycopg import AsyncConnection, Notify, sql
from psycopg.conninfo import make_conninfo
from psycopg.errors import Error as PsycopgError
@@ -136,9 +135,13 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
# Each consumer gets its own *specific* channel, created with the `new_channel()` method.
# This dict maps `channel_name` to a queue of messages for that channel.
self.channels: dict[str, asyncio.Queue[bytes]] = {}
self.channels: dict[str, asyncio.Queue[tuple[str, bytes | None]]] = {}
self.connection = PostgresChannelLayerConnection(self.using, self)
self.receiver = PostgresChannelLayerReceiver(self.using, self)
async def _subscribe_to_channel(self, channel: str) -> None:
    """
    Create the local queue for `channel` and subscribe the receiver to it.

    The queue is registered in `self.channels` before subscribing, because
    incoming messages are only delivered to channels that have a queue there.

    If subscribing fails or is cancelled, the queue is removed again and the
    exception is re-raised. `receive()` only subscribes when the channel is
    missing from `self.channels`, so a leftover queue would make every later
    call wait on a channel that was never subscribed.
    """
    queue: asyncio.Queue[tuple[str, bytes | None]] = asyncio.Queue()
    self.channels[channel] = queue
    try:
        await self.receiver.subscribe(channel)
    except BaseException:
        # Only roll back the queue installed by this call; leave the entry alone
        # if something else replaced it while we were awaiting.
        if self.channels.get(channel) is queue:
            del self.channels[channel]
        raise
extensions = ["groups", "flush"]
@@ -165,33 +168,29 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
Returns a new channel name that can be used by something in our
process as a specific channel.
"""
return f"{self.prefix}.{prefix}.{uuid4().hex}"
channel = f"{self.prefix}.{prefix}.{uuid4().hex}"
await self._subscribe_to_channel(channel)
return channel
async def receive(self, channel: str) -> dict[str, Any]:
"""
Receive the first message that arrives on the channel.
If more than one coroutine waits on the same channel, the first waiter
will be given the message when it arrives.
This is done by acquiring an `advisory_lock` from the database
based on the channel name.
If the lock is acquired successfully, subsequent calls to this method
will not try to acquire the lock again.
_The lock is session-based and should be released by Postgres when
the session is closed_
If the lock is already acquired by another coroutine,
subsequent calls to this method will repeatedly try to acquire the lock
before proceeding to wait for a message.
"""
if channel not in self.channels:
self.channels[channel] = asyncio.Queue()
await self.connection.subscribe(channel)
await self._subscribe_to_channel(channel)
q = self.channels[channel]
try:
message = await q.get()
while True:
(message_id, message) = await q.get()
if message is None:
m = await Message.objects.filter(pk=message_id).afirst()
if m is None:
continue
message = cast(bytes, m.message)
break
except (asyncio.CancelledError, TimeoutError, GeneratorExit):
# We assume here that the reason we are cancelled is because the consumer
# is exiting, therefore we need to clean up by unsubscribing below. Indeed,
@@ -204,7 +203,7 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
if channel in self.channels:
del self.channels[channel]
try:
await self.connection.unsubscribe(channel)
await self.receiver.unsubscribe(channel)
except BaseException as exc: # noqa: BLE001
LOGGER.warning("Unexpected exception while cleaning-up channel", exc=exc)
# We don't re-raise here because we want the CancelledError to be the one
@@ -286,15 +285,14 @@ class PostgresChannelLoopLayer(BaseChannelLayer):
Deletes all messages and groups.
"""
self.channels = {}
await self.connection.flush()
await self.receiver.flush()
class PostgresChannelLayerConnection:
class PostgresChannelLayerReceiver:
def __init__(self, using: str, channel_layer: PostgresChannelLoopLayer) -> None:
self.using = using
self.channel_layer = channel_layer
self._subscribed_to: set[str] = set()
self._locked_channels: set[str] = set()
self._lock = asyncio.Lock()
self._receive_task: asyncio.Task[None] | None = None
@@ -325,33 +323,23 @@ class PostgresChannelLayerConnection:
while True:
try:
async with await self._create_connection() as conn:
await self._update_locks(conn)
await self._process_backlog(conn)
await conn.execute(
sql.SQL("LISTEN {channel}").format(channel=sql.Identifier(NOTIFY_CHANNEL))
)
while True:
await self._process_backlog(conn)
await conn.execute(
sql.SQL("LISTEN {channel}").format(
channel=sql.Identifier(NOTIFY_CHANNEL)
)
)
first_loop = True
async for notify in conn.notifies(stop_after=1, timeout=5):
if first_loop:
await self._update_locks(conn)
first_loop = False
await self._receive_notify(conn, notify)
if first_loop:
await self._update_locks(conn)
async for notify in conn.notifies(timeout=30):
await self._receive_notify(notify)
except (asyncio.CancelledError, TimeoutError, GeneratorExit):
raise
except PsycopgError as exc:
LOGGER.warning("Postgres connection is not healthy", exc=exc)
except BaseException as exc: # noqa: BLE001
LOGGER.warning("Unexpected exception in receive task", exc=exc, exc_info=True)
self._locked_channels = set()
await asyncio.sleep(1)
async def _process_backlog(self, conn: AsyncConnection) -> None:
if not self._locked_channels:
if not self._subscribed_to:
return
async with conn.cursor() as cursor:
await cursor.execute(
@@ -362,102 +350,49 @@ class PostgresChannelLayerConnection:
WHERE
{table}.{channel} IN (%s)
AND {table}.{expires} >= %s
RETURNING {table}.{channel}, {table}.{message}
RETURNING {table}.{id}, {table}.{channel}, {table}.{message}
"""
).format(
table=sql.Identifier(MESSAGE_TABLE),
id=sql.Identifier("id"),
channel=sql.Identifier("channel"),
expires=sql.Identifier("expires"),
message=sql.Identifier("message"),
),
(tuple(self._locked_channels), now()),
(tuple(self._subscribed_to), now()),
)
async for row in cursor:
channel, message = row
self._receive_message(channel, message)
message_id, channel, message = row
self._receive_message(channel, message_id, message)
# Map a channel name to the id used for its Postgres advisory lock.
# The name is namespaced as "channels.<channel>" and passed to pglock's
# `_cast_lock_id`; the result is then passed through `cast(int, ...)`
# (presumably `typing.cast`, i.e. a type-checker hint only; confirm).
def _get_lock_id(self, channel: str) -> int:
lock_id = _cast_lock_id(f"channels.{channel}") # type: ignore[no-untyped-call]
return cast(int, lock_id)
# Reconcile the advisory locks held on `conn` with the current subscriptions.
# Channels that are subscribed but not yet locked are tried with
# `pg_try_advisory_lock`; channels that are locked but no longer subscribed
# are released with `pg_advisory_unlock`. `self._locked_channels` is updated
# only when the corresponding query returns a true value.
# NOTE(review): indentation is lost in this view, so how much of the body runs
# under `async with self._lock` cannot be confirmed from here.
async def _update_locks(self, conn: AsyncConnection) -> None:
async with self._lock:
# Both differences are computed up front, before any lock query is issued.
locks_to_release = self._locked_channels - self._subscribed_to
locks_to_acquire = self._subscribed_to - self._locked_channels
for channel in locks_to_acquire:
lock_id = self._get_lock_id(channel)
async with conn.cursor() as cursor:
await cursor.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,))
row = await cursor.fetchone()
# A missing row or a false result leaves the channel unlocked.
if row is not None:
if row[0]:
self._locked_channels.add(channel)
for channel in locks_to_release:
lock_id = self._get_lock_id(channel)
async with conn.cursor() as cursor:
await cursor.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))
row = await cursor.fetchone()
# The channel stays in `_locked_channels` unless the unlock returned a true value.
if row is not None:
if row[0]:
self._locked_channels.remove(channel)
async def _receive_notify(self, conn: AsyncConnection, notify: Notify) -> None:
async def _receive_notify(self, notify: Notify) -> None:
payload = notify.payload
split_payload = payload.split(":")
message: bytes | None = None
match len(split_payload):
case 4:
message_id, channel, timestamp, base64_message = split_payload
if channel not in self._locked_channels:
if channel not in self._subscribed_to:
return
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
if expires < now():
return
message = b64decode(base64_message)
async with conn.cursor() as cursor:
await cursor.execute(
sql.SQL("DELETE FROM {table} WHERE {table}.{id} = %s").format(
table=sql.Identifier(MESSAGE_TABLE),
id=sql.Identifier("id"),
)
)
case 3:
message_id, channel, timestamp = split_payload
if channel not in self._locked_channels:
if channel not in self._subscribed_to:
return
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
if expires < now():
return
async with conn.cursor() as cursor:
await cursor.execute(
sql.SQL(
"""
DELETE
FROM {table}
WHERE
{table}.{id} = %s
RETURNING {table}.{message}, {table}.{expires}
"""
).format(
table=sql.Identifier(MESSAGE_TABLE),
id=sql.Identifier("id"),
message=sql.Identifier("message"),
expires=sql.Identifier("expires"),
),
(message_id,),
)
row = await cursor.fetchone()
if row is None:
return
message, expires = row
message = None
case _:
return
self._receive_message(channel, message)
self._receive_message(channel, message_id, message)
def _receive_message(self, channel: str, message: bytes) -> None:
def _receive_message(self, channel: str, message_id: str, message: bytes | None) -> None:
if (q := self.channel_layer.channels.get(channel)) is not None:
q.put_nowait(message)
q.put_nowait((message_id, message))
def _ensure_receiver(self) -> None:
if self._receive_task is None:
@@ -33,7 +33,6 @@ classifiers = [
dependencies = [
"channels >=4.3,<4.4",
"django >=4.2,<6.0",
"django-pglock >=1.7,<2",
"django-pgtrigger >=4,<5",
"msgpack >=1,<2",
"psycopg >=3,<4",
Generated
-2
View File
@@ -989,7 +989,6 @@ source = { editable = "packages/django-channels-postgres" }
dependencies = [
{ name = "channels" },
{ name = "django" },
{ name = "django-pglock" },
{ name = "django-pgtrigger" },
{ name = "msgpack" },
{ name = "psycopg" },
@@ -1000,7 +999,6 @@ dependencies = [
requires-dist = [
{ name = "channels", specifier = ">=4.3,<4.4" },
{ name = "django", specifier = ">=4.2,<6.0" },
{ name = "django-pglock", specifier = ">=1.7,<2" },
{ name = "django-pgtrigger", specifier = ">=4,<5" },
{ name = "msgpack", specifier = ">=1,<2" },
{ name = "psycopg", specifier = ">=3,<4" },