packages/django-dramatiq-postgres: broker: ensure locking happens with the same connection (cherry-pick #18095 to version-2025.10) (#18119)

packages/django-dramatiq-postgres: broker: ensure locking happens with the same connection (#18095)

Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
authentik-automation[bot]
2025-11-13 17:18:13 +00:00
committed by GitHub
parent 7dd1cd5c59
commit d018f0381c
5 changed files with 65 additions and 54 deletions
-1
View File
@@ -145,7 +145,6 @@ worker:
consumer_listen_timeout: "seconds=30"
task_max_retries: 5
task_default_time_limit: "minutes=10"
lock_purge_interval: "minutes=1"
task_purge_interval: "days=1"
task_expiration: "days=30"
scheduler_interval: "seconds=60"
-3
View File
@@ -380,9 +380,6 @@ DRAMATIQ = {
"broker_class": "authentik.tasks.broker.Broker",
"channel_prefix": "authentik",
"task_model": "authentik.tasks.models.Task",
"lock_purge_interval": timedelta_from_string(
CONFIG.get("worker.lock_purge_interval")
).total_seconds(),
"task_purge_interval": timedelta_from_string(
CONFIG.get("worker.task_purge_interval")
).total_seconds(),
@@ -61,6 +61,7 @@ def raise_connection_error(func: Callable[P, R]) -> Callable[P, R]:
try:
return func(*args, **kwargs)
except DATABASE_ERRORS as exc:
logger.warning("Database error encountered", exc=exc)
raise ConnectionError(str(exc)) from exc # type: ignore[no-untyped-call]
return wrapper
@@ -239,15 +240,18 @@ class _PostgresConsumer(Consumer):
self.in_processing: set[str] = set()
self.prefetch = prefetch
self.misses = 0
# We have two different connections here. One for locks and one for listening to
# notifications. We can't use the same connection for both as the listen connection might
# be blocked with pending notifications. We also can't use a Django connection as we can't
# be sure we'll get the same one every time to be able to release locks from the same
# connection.
self._locks_connection: DatabaseWrapper | None = None
self._listen_connection: DatabaseWrapper | None = None
self.postgres_channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
# Override because dramatiq doesn't allow us setting this manually
self.timeout = Conf().worker["consumer_listen_timeout"]
self.lock_purge_interval = timedelta(seconds=Conf().lock_purge_interval)
self.lock_purge_last_run = timezone.now()
self.task_purge_interval = timedelta(seconds=Conf().task_purge_interval)
self.task_purge_last_run = timezone.now() - self.task_purge_interval
@@ -258,14 +262,17 @@ class _PostgresConsumer(Consumer):
self.scheduler_interval = timedelta(seconds=Conf().scheduler_interval)
self.scheduler_last_run = timezone.now() - self.scheduler_interval
@property
def connection(self) -> DatabaseWrapper:
return cast(DatabaseWrapper, connections[self.db_alias])
@property
def query_set(self) -> QuerySet[TaskBase]:
return self.broker.query_set
@property
def locks_connection(self) -> DatabaseWrapper:
if self._locks_connection is not None and self._locks_connection.is_usable():
return self._locks_connection
self._locks_connection = cast(DatabaseWrapper, connections.create_connection(self.db_alias))
return self._locks_connection
@property
def listen_connection(self) -> DatabaseWrapper:
if self._listen_connection is not None and self._listen_connection.is_usable():
@@ -320,21 +327,40 @@ class _PostgresConsumer(Consumer):
self.logger.debug("Message already consumed by self", message_id=message_id)
return None
lock_result = (
self.query_set.filter(message_id=message_id)
.exclude(state__in=(TaskState.DONE, TaskState.REJECTED))
.exclude(eta__gte=timezone.now() + timedelta(seconds=self.timeout))
.extra(
where=["pg_try_advisory_lock(%s)"],
params=[self._get_message_lock_id(message_id)],
with self.locks_connection.cursor() as cursor:
cursor.execute(
sql.SQL(
"""
UPDATE {table}
SET {state} = %(state)s, {mtime} = %(mtime)s
WHERE
{table}.{message_id} = %(message_id)s
AND
{table}.{state} != ALL(%(excluded_states)s)
AND
({table}.{eta} < %(maximum_eta)s OR {table}.{eta} IS NULL)
AND
pg_try_advisory_lock(%(lock_id)s)
"""
).format(
table=sql.Identifier(self.query_set.model._meta.db_table),
state=sql.Identifier("state"),
mtime=sql.Identifier("mtime"),
message_id=sql.Identifier("message_id"),
eta=sql.Identifier("eta"),
),
{
"state": TaskState.CONSUMED.value,
"mtime": timezone.now(),
"message_id": message_id,
"excluded_states": [TaskState.DONE.value, TaskState.REJECTED.value],
"maximum_eta": timezone.now() + timedelta(seconds=self.timeout),
"lock_id": self._get_message_lock_id(message_id),
},
)
.update(
state=TaskState.CONSUMED,
mtime=timezone.now(),
)
)
if lock_result != 1:
return None
if cursor.rowcount != 1:
self._unlock_message(message_id)
return None
task: TaskBase | None = (
self.query_set.defer(None).defer("result").filter(message_id=message_id).first()
@@ -405,9 +431,10 @@ class _PostgresConsumer(Consumer):
def _unlock_message(self, message_id: str) -> bool:
self.logger.debug("Unlocking message", message_id=message_id)
try:
with self.connection.cursor() as cursor:
with self.locks_connection.cursor() as cursor:
cursor.execute(
"SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message_id),)
"SELECT pg_advisory_unlock(%s)",
(self._get_message_lock_id(message_id),),
)
return True
except DATABASE_ERRORS:
@@ -420,7 +447,7 @@ class _PostgresConsumer(Consumer):
self.in_processing.remove(str(message.message_id))
except KeyError:
pass
self._unlock_message(str(message.message_id))
self.to_unlock.add(str(message.message_id))
task = message.options.pop("task", None)
self.query_set.filter(
message_id=message.message_id,
@@ -453,7 +480,6 @@ class _PostgresConsumer(Consumer):
for message in messages:
self.to_unlock.add(str(message.message_id))
self.in_processing.remove(str(message.message_id))
self._purge_locks()
def _scheduler(self) -> None:
if not self.scheduler:
@@ -464,8 +490,6 @@ class _PostgresConsumer(Consumer):
self.schedule_last_run = timezone.now()
def _purge_locks(self) -> None:
if timezone.now() - self.lock_purge_last_run < self.lock_purge_interval:
return
while True:
try:
message_id = self.to_unlock.pop()
@@ -473,7 +497,6 @@ class _PostgresConsumer(Consumer):
break
if not self._unlock_message(str(message_id)):
return
self.lock_purge_last_run = timezone.now()
def _auto_purge(self) -> None:
if timezone.now() - self.task_purge_last_run < self.task_purge_interval:
@@ -492,15 +515,17 @@ class _PostgresConsumer(Consumer):
try:
self._purge_locks()
finally:
try:
self.connection.close()
except DATABASE_ERRORS:
pass
finally:
if self._listen_connection is not None:
conn = self._listen_connection
self._listen_connection = None
try:
conn.close()
except DATABASE_ERRORS:
pass
if self._locks_connection is not None:
conn = self._locks_connection
self._locks_connection = None
try:
conn.close()
except DATABASE_ERRORS:
pass
if self._listen_connection is not None:
conn = self._listen_connection
self._listen_connection = None
try:
conn.close()
except DATABASE_ERRORS:
pass
@@ -63,10 +63,6 @@ class Conf:
def task_model(self) -> str:
return cast(str, self.conf["task_model"])
@property
def lock_purge_interval(self) -> int:
return cast(int, self.conf.get("lock_purge_interval", 60))
@property
def task_purge_interval(self) -> int:
# 24 hours
@@ -215,12 +215,6 @@ Configure the default duration a task can run for before it is aborted. Some tas
Defaults to `minutes=10`.
##### `AUTHENTIK_WORKER__LOCK_PURGE_INTERVAL`
Configure the interval at which old PostgreSQL locks are cleaned up.
Defaults to `minutes=1`.
##### `AUTHENTIK_WORKER__TASK_PURGE_INTERVAL`
Configure the interval at which old tasks are cleaned up.