mirror of
https://github.com/goauthentik/authentik.git
synced 2026-06-17 19:09:11 +03:00
packages/django-dramatiq-postgres: broker: ensure locking happens with the same connection (cherry-pick #18095 to version-2025.10) (#18119)
packages/django-dramatiq-postgres: broker: ensure locking happens with the same connection (#18095) Co-authored-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
This commit is contained in:
committed by
GitHub
parent
7dd1cd5c59
commit
d018f0381c
@@ -145,7 +145,6 @@ worker:
|
||||
consumer_listen_timeout: "seconds=30"
|
||||
task_max_retries: 5
|
||||
task_default_time_limit: "minutes=10"
|
||||
lock_purge_interval: "minutes=1"
|
||||
task_purge_interval: "days=1"
|
||||
task_expiration: "days=30"
|
||||
scheduler_interval: "seconds=60"
|
||||
|
||||
@@ -380,9 +380,6 @@ DRAMATIQ = {
|
||||
"broker_class": "authentik.tasks.broker.Broker",
|
||||
"channel_prefix": "authentik",
|
||||
"task_model": "authentik.tasks.models.Task",
|
||||
"lock_purge_interval": timedelta_from_string(
|
||||
CONFIG.get("worker.lock_purge_interval")
|
||||
).total_seconds(),
|
||||
"task_purge_interval": timedelta_from_string(
|
||||
CONFIG.get("worker.task_purge_interval")
|
||||
).total_seconds(),
|
||||
|
||||
@@ -61,6 +61,7 @@ def raise_connection_error(func: Callable[P, R]) -> Callable[P, R]:
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except DATABASE_ERRORS as exc:
|
||||
logger.warning("Database error encountered", exc=exc)
|
||||
raise ConnectionError(str(exc)) from exc # type: ignore[no-untyped-call]
|
||||
|
||||
return wrapper
|
||||
@@ -239,15 +240,18 @@ class _PostgresConsumer(Consumer):
|
||||
self.in_processing: set[str] = set()
|
||||
self.prefetch = prefetch
|
||||
self.misses = 0
|
||||
# We have two different connections here. One for locks and one for listening to
|
||||
# notifications. We can't use the same connection for both as the listen connection might
|
||||
# be blocked with pending notifications. We also can't use a Django connection as we can't
|
||||
# be sure we'll get the same one every time to be able to release locks from the same
|
||||
# connection.
|
||||
self._locks_connection: DatabaseWrapper | None = None
|
||||
self._listen_connection: DatabaseWrapper | None = None
|
||||
self.postgres_channel = channel_name(self.queue_name, ChannelIdentifier.ENQUEUE)
|
||||
|
||||
# Override because dramatiq doesn't allow us setting this manually
|
||||
self.timeout = Conf().worker["consumer_listen_timeout"]
|
||||
|
||||
self.lock_purge_interval = timedelta(seconds=Conf().lock_purge_interval)
|
||||
self.lock_purge_last_run = timezone.now()
|
||||
|
||||
self.task_purge_interval = timedelta(seconds=Conf().task_purge_interval)
|
||||
self.task_purge_last_run = timezone.now() - self.task_purge_interval
|
||||
|
||||
@@ -258,14 +262,17 @@ class _PostgresConsumer(Consumer):
|
||||
self.scheduler_interval = timedelta(seconds=Conf().scheduler_interval)
|
||||
self.scheduler_last_run = timezone.now() - self.scheduler_interval
|
||||
|
||||
@property
|
||||
def connection(self) -> DatabaseWrapper:
|
||||
return cast(DatabaseWrapper, connections[self.db_alias])
|
||||
|
||||
@property
|
||||
def query_set(self) -> QuerySet[TaskBase]:
|
||||
return self.broker.query_set
|
||||
|
||||
@property
|
||||
def locks_connection(self) -> DatabaseWrapper:
|
||||
if self._locks_connection is not None and self._locks_connection.is_usable():
|
||||
return self._locks_connection
|
||||
self._locks_connection = cast(DatabaseWrapper, connections.create_connection(self.db_alias))
|
||||
return self._locks_connection
|
||||
|
||||
@property
|
||||
def listen_connection(self) -> DatabaseWrapper:
|
||||
if self._listen_connection is not None and self._listen_connection.is_usable():
|
||||
@@ -320,21 +327,40 @@ class _PostgresConsumer(Consumer):
|
||||
self.logger.debug("Message already consumed by self", message_id=message_id)
|
||||
return None
|
||||
|
||||
lock_result = (
|
||||
self.query_set.filter(message_id=message_id)
|
||||
.exclude(state__in=(TaskState.DONE, TaskState.REJECTED))
|
||||
.exclude(eta__gte=timezone.now() + timedelta(seconds=self.timeout))
|
||||
.extra(
|
||||
where=["pg_try_advisory_lock(%s)"],
|
||||
params=[self._get_message_lock_id(message_id)],
|
||||
with self.locks_connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
sql.SQL(
|
||||
"""
|
||||
UPDATE {table}
|
||||
SET {state} = %(state)s, {mtime} = %(mtime)s
|
||||
WHERE
|
||||
{table}.{message_id} = %(message_id)s
|
||||
AND
|
||||
{table}.{state} != ALL(%(excluded_states)s)
|
||||
AND
|
||||
({table}.{eta} < %(maximum_eta)s OR {table}.{eta} IS NULL)
|
||||
AND
|
||||
pg_try_advisory_lock(%(lock_id)s)
|
||||
"""
|
||||
).format(
|
||||
table=sql.Identifier(self.query_set.model._meta.db_table),
|
||||
state=sql.Identifier("state"),
|
||||
mtime=sql.Identifier("mtime"),
|
||||
message_id=sql.Identifier("message_id"),
|
||||
eta=sql.Identifier("eta"),
|
||||
),
|
||||
{
|
||||
"state": TaskState.CONSUMED.value,
|
||||
"mtime": timezone.now(),
|
||||
"message_id": message_id,
|
||||
"excluded_states": [TaskState.DONE.value, TaskState.REJECTED.value],
|
||||
"maximum_eta": timezone.now() + timedelta(seconds=self.timeout),
|
||||
"lock_id": self._get_message_lock_id(message_id),
|
||||
},
|
||||
)
|
||||
.update(
|
||||
state=TaskState.CONSUMED,
|
||||
mtime=timezone.now(),
|
||||
)
|
||||
)
|
||||
if lock_result != 1:
|
||||
return None
|
||||
if cursor.rowcount != 1:
|
||||
self._unlock_message(message_id)
|
||||
return None
|
||||
|
||||
task: TaskBase | None = (
|
||||
self.query_set.defer(None).defer("result").filter(message_id=message_id).first()
|
||||
@@ -405,9 +431,10 @@ class _PostgresConsumer(Consumer):
|
||||
def _unlock_message(self, message_id: str) -> bool:
|
||||
self.logger.debug("Unlocking message", message_id=message_id)
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
with self.locks_connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT pg_advisory_unlock(%s)", (self._get_message_lock_id(message_id),)
|
||||
"SELECT pg_advisory_unlock(%s)",
|
||||
(self._get_message_lock_id(message_id),),
|
||||
)
|
||||
return True
|
||||
except DATABASE_ERRORS:
|
||||
@@ -420,7 +447,7 @@ class _PostgresConsumer(Consumer):
|
||||
self.in_processing.remove(str(message.message_id))
|
||||
except KeyError:
|
||||
pass
|
||||
self._unlock_message(str(message.message_id))
|
||||
self.to_unlock.add(str(message.message_id))
|
||||
task = message.options.pop("task", None)
|
||||
self.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
@@ -453,7 +480,6 @@ class _PostgresConsumer(Consumer):
|
||||
for message in messages:
|
||||
self.to_unlock.add(str(message.message_id))
|
||||
self.in_processing.remove(str(message.message_id))
|
||||
self._purge_locks()
|
||||
|
||||
def _scheduler(self) -> None:
|
||||
if not self.scheduler:
|
||||
@@ -464,8 +490,6 @@ class _PostgresConsumer(Consumer):
|
||||
self.schedule_last_run = timezone.now()
|
||||
|
||||
def _purge_locks(self) -> None:
|
||||
if timezone.now() - self.lock_purge_last_run < self.lock_purge_interval:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
message_id = self.to_unlock.pop()
|
||||
@@ -473,7 +497,6 @@ class _PostgresConsumer(Consumer):
|
||||
break
|
||||
if not self._unlock_message(str(message_id)):
|
||||
return
|
||||
self.lock_purge_last_run = timezone.now()
|
||||
|
||||
def _auto_purge(self) -> None:
|
||||
if timezone.now() - self.task_purge_last_run < self.task_purge_interval:
|
||||
@@ -492,15 +515,17 @@ class _PostgresConsumer(Consumer):
|
||||
try:
|
||||
self._purge_locks()
|
||||
finally:
|
||||
try:
|
||||
self.connection.close()
|
||||
except DATABASE_ERRORS:
|
||||
pass
|
||||
finally:
|
||||
if self._listen_connection is not None:
|
||||
conn = self._listen_connection
|
||||
self._listen_connection = None
|
||||
try:
|
||||
conn.close()
|
||||
except DATABASE_ERRORS:
|
||||
pass
|
||||
if self._locks_connection is not None:
|
||||
conn = self._locks_connection
|
||||
self._locks_connection = None
|
||||
try:
|
||||
conn.close()
|
||||
except DATABASE_ERRORS:
|
||||
pass
|
||||
if self._listen_connection is not None:
|
||||
conn = self._listen_connection
|
||||
self._listen_connection = None
|
||||
try:
|
||||
conn.close()
|
||||
except DATABASE_ERRORS:
|
||||
pass
|
||||
|
||||
@@ -63,10 +63,6 @@ class Conf:
|
||||
def task_model(self) -> str:
|
||||
return cast(str, self.conf["task_model"])
|
||||
|
||||
@property
|
||||
def lock_purge_interval(self) -> int:
|
||||
return cast(int, self.conf.get("lock_purge_interval", 60))
|
||||
|
||||
@property
|
||||
def task_purge_interval(self) -> int:
|
||||
# 24 hours
|
||||
|
||||
@@ -215,12 +215,6 @@ Configure the default duration a task can run for before it is aborted. Some tas
|
||||
|
||||
Defaults to `minutes=10`.
|
||||
|
||||
##### `AUTHENTIK_WORKER__LOCK_PURGE_INTERVAL`
|
||||
|
||||
Configure the interval at which old PostgreSQL locks are cleaned up.
|
||||
|
||||
Defaults to `minutes=1`.
|
||||
|
||||
##### `AUTHENTIK_WORKER__TASK_PURGE_INTERVAL`
|
||||
|
||||
Configure the interval at which old tasks are cleaned up.
|
||||
|
||||
Reference in New Issue
Block a user