diff --git a/CODEOWNERS b/CODEOWNERS index 92f5a888d8..33d99c6634 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -24,6 +24,7 @@ Makefile @goauthentik/infrastructure .editorconfig @goauthentik/infrastructure CODEOWNERS @goauthentik/infrastructure # Backend packages +packages/django-channels-postgres @goauthentik/backend packages/django-postgres-cache @goauthentik/backend packages/django-dramatiq-postgres @goauthentik/backend # Web packages diff --git a/authentik/core/tasks.py b/authentik/core/tasks.py index 833822cf49..01e029ecee 100644 --- a/authentik/core/tasks.py +++ b/authentik/core/tasks.py @@ -4,6 +4,7 @@ from datetime import datetime, timedelta from django.utils.timezone import now from django.utils.translation import gettext_lazy as _ +from django_channels_postgres.models import GroupChannel, Message from django_postgres_cache.tasks import clear_expired_cache from dramatiq.actor import actor from structlog.stdlib import get_logger @@ -34,6 +35,8 @@ def clean_expired_models(): LOGGER.debug("Expired models", model=cls, amount=amount) self.info(f"Expired {amount} {cls._meta.verbose_name_plural}") clear_expired_cache() + Message.delete_expired() + GroupChannel.delete_expired() @actor(description=_("Remove temporary users created by SAML Sources.")) diff --git a/authentik/lib/logging.py b/authentik/lib/logging.py index aec7afc489..8f750da07e 100644 --- a/authentik/lib/logging.py +++ b/authentik/lib/logging.py @@ -112,7 +112,6 @@ def get_logger_config(): "hpack": "WARNING", "httpx": "WARNING", "azure": "WARNING", - "channels_postgres": "WARNING", } for handler_name, level in handler_level_map.items(): base_config["loggers"][handler_name] = { diff --git a/authentik/root/channels.py b/authentik/root/channels.py deleted file mode 100644 index 43b8156213..0000000000 --- a/authentik/root/channels.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Any - -from channels_postgres.core import PostgresChannelLayer as BasePostgresChannelLayer -from channels_postgres.db import DatabaseLayer as BaseDatabaseLayer -from django.conf import settings -from psycopg_pool import AsyncConnectionPool - -from authentik.root.db.base import DatabaseWrapper - - -class DatabaseLayer(BaseDatabaseLayer): - async def get_db_pool(self, db_params: dict[str, Any]) -> AsyncConnectionPool: - db_wrapper = DatabaseWrapper(settings.CHANNEL_LAYERS["default"]["CONFIG"]) - db_params = db_wrapper.get_connection_params() - db_params.pop("cursor_factory") - db_params.pop("context") - return await super().get_db_pool(db_params) - - -class PostgresChannelLayer(BasePostgresChannelLayer): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - self.django_db = DatabaseLayer(self.django_db.psycopg_options, self.db_params) - - @property - def db_params(self): - db_wrapper = DatabaseWrapper(settings.CHANNEL_LAYERS["default"]["CONFIG"]) - db_params = db_wrapper.get_connection_params() - db_params.pop("cursor_factory") - db_params.pop("context") - return db_params - - @db_params.setter - def db_params(self, value): - pass diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 080001f9e6..d6d9a1a930 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -64,7 +64,7 @@ SHARED_APPS = [ "pgactivity", "pglock", "channels", - "channels_postgres", + "django_channels_postgres", "django_dramatiq_postgres", "authentik.tasks", ] @@ -304,11 +304,7 @@ DATABASE_ROUTERS = ( CHANNEL_LAYERS = { "default": { - "BACKEND": "authentik.root.channels.PostgresChannelLayer", - "CONFIG": { - **DATABASES["default"], - "TIME_ZONE": None, - }, + "BACKEND": "django_channels_postgres.layer.PostgresChannelLayer", }, } diff --git a/authentik/root/test_runner.py b/authentik/root/test_runner.py index 0def491817..2a9bbf5653 100644 --- a/authentik/root/test_runner.py +++ b/authentik/root/test_runner.py @@ -62,11 +62,6 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover """Configure test environment settings""" settings.TEST = True settings.DRAMATIQ["test"] = True - settings.CHANNEL_LAYERS["default"]["CONFIG"] = { - **settings.DATABASES["default"], - **settings.DATABASES["default"]["TEST"], - "TIME_ZONE": None, - } # Test-specific configuration test_config = { diff --git a/packages/django-channels-postgres/README.md b/packages/django-channels-postgres/README.md new file mode 100644 index 0000000000..c69bd8ab8f --- /dev/null +++ b/packages/django-channels-postgres/README.md @@ -0,0 +1 @@ +# django-channels-postgres diff --git a/packages/django-channels-postgres/django_channels_postgres/__init__.py b/packages/django-channels-postgres/django_channels_postgres/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/django-channels-postgres/django_channels_postgres/apps.py b/packages/django-channels-postgres/django_channels_postgres/apps.py new file mode 100644 index 0000000000..488286d774 --- /dev/null +++ b/packages/django-channels-postgres/django_channels_postgres/apps.py @@ -0,0 +1,5 @@ +from django.apps import AppConfig + + +class DjangoChannelsPostgresConfig(AppConfig): + name = "django_channels_postgres" diff --git a/packages/django-channels-postgres/django_channels_postgres/layer.py b/packages/django-channels-postgres/django_channels_postgres/layer.py new file mode 100644 index 0000000000..447c0eccd0 --- /dev/null +++ b/packages/django-channels-postgres/django_channels_postgres/layer.py @@ -0,0 +1,472 @@ +import asyncio +import functools +import types +from base64 import b64decode +from datetime import UTC, datetime, timedelta +from re import Pattern +from typing import Any, cast +from uuid import uuid4 + +import msgpack +from channels.layers import BaseChannelLayer +from django.db import DEFAULT_DB_ALIAS, connections +from django.utils.timezone import now +from pglock.core import _cast_lock_id +from psycopg import AsyncConnection, Notify, sql +from psycopg.conninfo import make_conninfo +from psycopg.errors import Error as PsycopgError +from structlog.stdlib import get_logger + +from django_channels_postgres.models import NOTIFY_CHANNEL, GroupChannel, Message + +LOGGER = get_logger() + + +GROUP_CHANNEL_TABLE = GroupChannel._meta.db_table +MESSAGE_TABLE = Message._meta.db_table + + +async def _async_proxy( + obj: "PostgresChannelLayerLoopProxy", + name: str, + *args: Any, + **kwargs: Any, +) -> Any: + # Must be defined as a function and not a method due to + # https://bugs.python.org/issue38364 + layer = obj._get_layer() + return await getattr(layer, name)(*args, **kwargs) + + +def _wrap_close(proxy: "PostgresChannelLayerLoopProxy", loop: asyncio.AbstractEventLoop) -> None: + original_impl = loop.close + + def _wrapper(self: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any) -> None: + if loop in proxy._layers: + layer = proxy._layers[loop] + del proxy._layers[loop] + loop.run_until_complete(layer.flush()) + self.close = original_impl # type: ignore[method-assign] + return self.close(*args, **kwargs) + + loop.close = types.MethodType(_wrapper, loop) # type: ignore[method-assign] + + +class PostgresChannelLayerLoopProxy: + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + self._args = args + self._kwargs = kwargs + self._kwargs["channel_layer"] = self + self._layers: dict[asyncio.AbstractEventLoop, PostgresChannelLoopLayer] = {} + + def __getattr__(self, name: str) -> Any: + if name in ( + "new_channel", + "send", + "receive", + "group_add", + "group_discard", + "group_send", + "flush", + ): + return functools.partial(_async_proxy, self, name) + else: + return getattr(self._get_layer(), name) + + def serialize(self, message: dict[str, Any]) -> bytes: + """Serializes message to a byte string.""" + return cast(bytes, msgpack.packb(message, use_bin_type=True)) + + def deserialize(self, message: bytes) -> dict[str, Any]: + """Deserializes from a byte string.""" + return cast(dict[str, Any], msgpack.unpackb(message, raw=False)) + + def _get_layer(self) -> "PostgresChannelLoopLayer": + loop = asyncio.get_running_loop() + + try: + layer = self._layers[loop] + except KeyError: + layer = PostgresChannelLoopLayer(*self._args, **self._kwargs) + self._layers[loop] = layer + _wrap_close(self, loop) + + return layer + + +PostgresChannelLayer = PostgresChannelLayerLoopProxy + + +class PostgresChannelLoopLayer(BaseChannelLayer): + """ + Postgres channel layer. + + It uses the NOTIFY/LISTEN functionality of postgres to broadcast messages + + It also makes use of an internal message table to overcome the + 8000bytes limit of Postgres' NOTIFY messages. + Which is a far cry from the channels standard of 1MB + This table has a trigger that sends out the `NOTIFY` signal. + + Using a database also means messages are durable and will always be + available to consumers (as long as they're not expired). + """ + + def __init__( + self, + channel_layer: PostgresChannelLayerLoopProxy, + prefix: str = "asgi", + expiry: int = 60, + group_expiry: int = 86400, + capacity: int = 100, + channel_capacity: dict[Pattern[str] | str, int] | None = None, + using: str = DEFAULT_DB_ALIAS, + ) -> None: + super().__init__(expiry=expiry, capacity=capacity, channel_capacity=channel_capacity) + + self.group_expiry = group_expiry + self.prefix = prefix + assert isinstance(self.prefix, str), "Prefix must be unicode" # nosec + self.channel_layer = channel_layer + self.using = using + + # Each consumer gets its own *specific* channel, created with the `new_channel()` method. + # This dict maps `channel_name` to a queue of messages for that channel. + self.channels: dict[str, asyncio.Queue[bytes]] = {} + + self.connection = PostgresChannelLayerConnection(self.using, self) + + extensions = ["groups", "flush"] + + ### Channel layer API ### + + async def send(self, channel: str, message: dict[str, Any]) -> None: + """ + Send a message onto a (general or specific) channel. + """ + # Typecheck + assert isinstance(message, dict), "message is not a dict" # nosec + assert self.require_valid_channel_name(channel), "Channel name not valid" # nosec + # Make sure the message does not contain reserved keys + assert "__asgi_channel__" not in message # nosec + + await Message.objects.using(self.using).acreate( + channel=channel, + message=self.channel_layer.serialize(message), + expires=now() + timedelta(seconds=self.expiry), + ) + + async def new_channel(self, prefix: str = "specific") -> str: + """ + Returns a new channel name that can be used by something in our + process as a specific channel. + """ + return f"{self.prefix}.{prefix}.{uuid4().hex}" + + async def receive(self, channel: str) -> dict[str, Any]: + """ + Receive the first message that arrives on the channel. + If more than one coroutine waits on the same channel, the first waiter + will be given the message when it arrives. + + This is done by acquiring an `advistory_lock` from the database + based on the channel name. + + If the lock is acquired successfully, subsequent calls to this method + will not try to acquire the lock again. + _The lock is session based and should be released by postgres when + the session is closed_ + + If the lock is already acquired by another coroutine, + subsequent calls to this method will repeatedly try to acquire the lock + before proceeding to wait for a message. + """ + if channel not in self.channels: + self.channels[channel] = asyncio.Queue() + await self.connection.subscribe(channel) + + q = self.channels[channel] + try: + message = await q.get() + except (asyncio.CancelledError, TimeoutError, GeneratorExit): + # We assume here that the reason we are cancelled is because the consumer + # is exiting, therefore we need to cleanup by unsubscribe below. Indeed, + # currently the way that Django Channels works, this is a safe assumption. + # In the future, Django Channels could change to call a *new* method that + # would serve as the antithesis of `new_channel()`; this new method might + # be named `delete_channel()`. If that were the case, we would do the + # following cleanup from that new `delete_channel()` method, but, since + # that's not how Django Channels works (yet), we do the cleanup below: + if channel in self.channels: + del self.channels[channel] + try: + await self.connection.unsubscribe(channel) + except BaseException as exc: # noqa: BLE001 + LOGGER.warning("Unexpected exception while cleaning-up channel", exc=exc) + # We don't re-raise here because we want the CancelledError to be the one + # re-raised + raise + return self.channel_layer.deserialize(message) + + # ============================================================== + # Groups extension + # ============================================================== + + async def group_add(self, group: str, channel: str) -> None: + """ + Adds the channel name to a group. + """ + # Check the inputs + assert self.require_valid_group_name(group), "Group name not valid" # nosec + assert self.require_valid_channel_name(channel), "Channel name not valid" # nosec + + group_key = self._group_key(group) + + await GroupChannel.objects.using(self.using).aupdate_or_create( + group_key=group_key, + channel=channel, + defaults={ + "expires": now() + timedelta(seconds=self.group_expiry), + }, + ) + + async def group_discard(self, group: str, channel: str) -> None: + """ + Removes the channel from the named group if it is in the group; + does nothing otherwise (does not error) + """ + # Check the inputs + assert self.require_valid_group_name(group), "Group name not valid" # nosec + assert self.require_valid_channel_name(channel), "Channel name not valid" # nosec + + group_key = self._group_key(group) + + await ( + GroupChannel.objects.using(self.using) + .filter(group_key=group_key, channel=channel) + .adelete() + ) + + async def group_send(self, group: str, message: dict[str, Any]) -> None: + """ + Sends a message to the entire group. + """ + assert self.require_valid_group_name(group), "Group name not valid" # nosec + + group_key = self._group_key(group) + + serialized_message = self.channel_layer.serialize(message) + messages = [ + Message( + channel=channel, + message=serialized_message, + expires=now() + timedelta(seconds=self.expiry), + ) + async for channel in GroupChannel.objects.using(self.using) + .filter(group_key=group_key, expires__gte=now()) + .values_list("channel", flat=True) + .distinct() + ] + await Message.objects.using(self.using).abulk_create(messages) + + def _group_key(self, group: str) -> str: + """ + Common function to make the storage key for the group. + """ + return f"{self.prefix}.group.{group}" + + ### Flush extension ### + + async def flush(self) -> None: + """ + Deletes all messages and groups. + """ + self.channels = {} + await self.connection.flush() + + +class PostgresChannelLayerConnection: + def __init__(self, using: str, channel_layer: PostgresChannelLoopLayer) -> None: + self.using = using + self.channel_layer = channel_layer + self._subscribed_to: set[str] = set() + self._locked_channels: set[str] = set() + self._lock = asyncio.Lock() + self._receive_task: asyncio.Task[None] | None = None + + async def subscribe(self, channel: str) -> None: + async with self._lock: + if channel not in self._subscribed_to: + self._ensure_receiver() + self._subscribed_to.add(channel) + + async def unsubscribe(self, channel: str) -> None: + async with self._lock: + if channel in self._subscribed_to: + self._ensure_receiver() + self._subscribed_to.remove(channel) + + async def flush(self) -> None: + async with self._lock: + if self._receive_task is not None: + self._receive_task.cancel() + try: + await self._receive_task + except asyncio.CancelledError: + pass + self._receive_task = None + self._subscribed_to = set() + + async def _do_receiving(self) -> None: + while True: + try: + async with await self._create_connection() as conn: + await self._update_locks(conn) + while True: + await self._process_backlog(conn) + await conn.execute( + sql.SQL("LISTEN {channel}").format( + channel=sql.Identifier(NOTIFY_CHANNEL) + ) + ) + first_loop = True + async for notify in conn.notifies(stop_after=1, timeout=5): + if first_loop: + await self._update_locks(conn) + first_loop = False + await self._receive_notify(conn, notify) + if first_loop: + await self._update_locks(conn) + except (asyncio.CancelledError, TimeoutError, GeneratorExit): + raise + except PsycopgError as exc: + LOGGER.warning("Postgres connection is not healthy", exc=exc) + except BaseException as exc: # noqa: BLE001 + LOGGER.warning("Unexpected exception in receive task", exc=exc, exc_info=True) + self._locked_channels = set() + await asyncio.sleep(1) + + async def _process_backlog(self, conn: AsyncConnection) -> None: + if not self._locked_channels: + return + async with conn.cursor() as cursor: + await cursor.execute( + sql.SQL( + """ + DELETE + FROM {table} + WHERE + {table}.{channel} IN (%s) + AND {table}.{expires} >= %s + RETURNING {table}.{channel}, {table}.{message} + """ + ).format( + table=sql.Identifier(MESSAGE_TABLE), + channel=sql.Identifier("channel"), + expires=sql.Identifier("expires"), + message=sql.Identifier("message"), + ), + (tuple(self._locked_channels), now()), + ) + async for row in cursor: + channel, message = row + self._receive_message(channel, message) + + def _get_lock_id(self, channel: str) -> int: + lock_id = _cast_lock_id(f"channels.{channel}") # type: ignore[no-untyped-call] + return cast(int, lock_id) + + async def _update_locks(self, conn: AsyncConnection) -> None: + async with self._lock: + locks_to_release = self._locked_channels - self._subscribed_to + locks_to_acquire = self._subscribed_to - self._locked_channels + + for channel in locks_to_acquire: + lock_id = self._get_lock_id(channel) + async with conn.cursor() as cursor: + await cursor.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,)) + row = await cursor.fetchone() + if row is not None: + if row[0]: + self._locked_channels.add(channel) + + for channel in locks_to_release: + lock_id = self._get_lock_id(channel) + async with conn.cursor() as cursor: + await cursor.execute("SELECT pg_advisory_unlock(%s)", (lock_id,)) + row = await cursor.fetchone() + if row is not None: + if row[0]: + self._locked_channels.remove(channel) + + async def _receive_notify(self, conn: AsyncConnection, notify: Notify) -> None: + payload = notify.payload + split_payload = payload.split(":") + match len(split_payload): + case 4: + message_id, channel, timestamp, base64_message = split_payload + if channel not in self._locked_channels: + return + expires = datetime.fromtimestamp(float(timestamp), tz=UTC) + if expires < now(): + return + message = b64decode(base64_message) + async with conn.cursor() as cursor: + await cursor.execute( + sql.SQL("DELETE FROM {table} WHERE {table}.{id} = %s").format( + table=sql.Identifier(MESSAGE_TABLE), + id=sql.Identifier("id"), + ) + ) + case 3: + message_id, channel, timestamp = split_payload + if channel not in self._locked_channels: + return + expires = datetime.fromtimestamp(float(timestamp), tz=UTC) + if expires < now(): + return + async with conn.cursor() as cursor: + await cursor.execute( + sql.SQL( + """ + DELETE + FROM {table} + WHERE + {table}.{id} = %s + RETURNING {table}.{message}, {table}.{expires} + """ + ).format( + table=sql.Identifier(MESSAGE_TABLE), + id=sql.Identifier("id"), + message=sql.Identifier("message"), + expires=sql.Identifier("expires"), + ), + (message_id,), + ) + row = await cursor.fetchone() + if row is None: + return + message, expires = row + case _: + return + self._receive_message(channel, message) + + def _receive_message(self, channel: str, message: bytes) -> None: + if (q := self.channel_layer.channels.get(channel)) is not None: + q.put_nowait(message) + + def _ensure_receiver(self) -> None: + if self._receive_task is None: + self._receive_task = asyncio.ensure_future(self._do_receiving()) + + async def _create_connection(self) -> AsyncConnection: + db_params = connections[self.using].get_connection_params() + # Prevent psycopg from using the custom synchronous cursor factory from django + db_params.pop("cursor_factory") + db_params.pop("context") + conninfo = make_conninfo(conninfo="", **db_params, connect_timeout=10) + return await AsyncConnection.connect(conninfo=conninfo, autocommit=True) diff --git a/packages/django-channels-postgres/django_channels_postgres/migrations/0001_initial.py b/packages/django-channels-postgres/django_channels_postgres/migrations/0001_initial.py new file mode 100644 index 0000000000..b973c859da --- /dev/null +++ b/packages/django-channels-postgres/django_channels_postgres/migrations/0001_initial.py @@ -0,0 +1,94 @@ +# Generated by Django 5.1.13 on 2025-10-04 14:35 + +import django_channels_postgres.models +import pgtrigger.compiler +import pgtrigger.migrations +import uuid +from django.db import migrations, models + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [] + + operations = [ + migrations.CreateModel( + name="GroupChannel", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, editable=False, primary_key=True, serialize=False + ), + ), + ("group_key", models.TextField(db_index=True)), + ("channel", models.TextField(db_index=True)), + ( + "expires", + models.DateTimeField( + db_index=True, default=django_channels_postgres.models._default_group_expiry + ), + ), + ], + options={ + "verbose_name": "Group channel", + "verbose_name_plural": "Group channels", + "indexes": [ + models.Index( + fields=["group_key", "channel"], name="django_chan_group_k_173f44_idx" + ), + models.Index( + fields=["group_key", "expires"], name="django_chan_group_k_45d2e0_idx" + ), + ], + }, + ), + migrations.CreateModel( + name="Message", + fields=[ + ( + "id", + models.UUIDField( + default=uuid.uuid4, editable=False, primary_key=True, serialize=False + ), + ), + ("channel", models.TextField(db_index=True)), + ("message", models.BinaryField()), + ( + "expires", + models.DateTimeField( + db_index=True, + default=django_channels_postgres.models._default_message_expiry, + ), + ), + ], + options={ + "verbose_name": "Message", + "verbose_name_plural": "Messages", + "indexes": [ + models.Index( + fields=["channel", "expires"], name="django_chan_channel_e8ca51_idx" + ) + ], + }, + ), + pgtrigger.migrations.AddTrigger( + model_name="message", + trigger=pgtrigger.compiler.Trigger( + name="notify_new_channels_message", + sql=pgtrigger.compiler.UpsertTriggerSql( + constraint="CONSTRAINT", + declare="DECLARE payload text; encoded_message text; epoch text;", + func="\n encoded_message := encode(NEW.message, 'base64');\n epoch := extract(epoch from NEW.expires)::text;\n IF octet_length(NEW.id::text) + octet_length(NEW.channel) + octet_length(epoch) + octet_length(encoded_message) + 3 <= 8000 THEN\n payload := NEW.id::text || ':' || NEW.channel || ':' || epoch || ':' || encoded_message;\n ELSE\n payload := NEW.id::text || ':' || NEW.channel || ':' || epoch;\n END IF;\n\n PERFORM pg_notify('channels_messages', payload);\n RETURN NEW;\n ", + hash="cf7a665df0bbb7d865cdbc92b63d818fd25733d8", + operation="INSERT", + pgid="pgtrigger_notify_new_channels_message_d21ae", + table="django_channels_postgres_message", + timing="DEFERRABLE INITIALLY DEFERRED", + when="AFTER", + ), + ), + ), + ] diff --git a/packages/django-channels-postgres/django_channels_postgres/migrations/__init__.py b/packages/django-channels-postgres/django_channels_postgres/migrations/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/django-channels-postgres/django_channels_postgres/models.py b/packages/django-channels-postgres/django_channels_postgres/models.py new file mode 100644 index 0000000000..b38ed236b7 --- /dev/null +++ b/packages/django-channels-postgres/django_channels_postgres/models.py @@ -0,0 +1,97 @@ +from datetime import datetime, timedelta +from uuid import uuid4 + +import pgtrigger +from django.db import models +from django.utils.timezone import now +from django.utils.translation import gettext_lazy as _ + + +def _default_group_expiry() -> datetime: + return now() + timedelta(seconds=86400) + + +def _default_message_expiry() -> datetime: + return now() + timedelta(minutes=1) + + +NOTIFY_CHANNEL = "channels_messages" + + +class GroupChannel(models.Model): + """ + A model that represents a group channel. + + Groups are used to send messages to multiple channels. + """ + + id = models.UUIDField(primary_key=True, editable=False, default=uuid4) + group_key = models.TextField(db_index=True) + channel = models.TextField(db_index=True) + expires = models.DateTimeField(db_index=True, default=_default_group_expiry) + + class Meta: + verbose_name = _("Group channel") + verbose_name_plural = _("Group channels") + indexes = ( + models.Index(fields=("group_key", "channel")), + models.Index(fields=("group_key", "expires")), + ) + + def __str__(self) -> str: + return f"Group '{self.group_key}' on channel '{self.channel}'" + + @classmethod + def delete_expired(cls) -> None: + cls.objects.filter(expires__lt=now()).delete() + + +class Message(models.Model): + """ + A model that represents a message. + + Messages are used to send messages to a specific channel. + E.g for user to user private messages. + """ + + id = models.UUIDField(primary_key=True, editable=False, default=uuid4) + channel = models.TextField(db_index=True) + message = models.BinaryField() + expires = models.DateTimeField(db_index=True, default=_default_message_expiry) + + class Meta: + verbose_name = _("Message") + verbose_name_plural = _("Messages") + indexes = (models.Index(fields=("channel", "expires")),) + triggers = ( + pgtrigger.Trigger( + name="notify_new_channels_message", + operation=pgtrigger.Insert, + when=pgtrigger.After, + timing=pgtrigger.Deferred, + declare=[ + ("payload", "text"), + ("encoded_message", "text"), + ("epoch", "text"), + ], + func=f""" + encoded_message := encode(NEW.message, 'base64'); + epoch := extract(epoch from NEW.expires)::text; + IF octet_length(NEW.id::text) + octet_length(NEW.channel) + octet_length(epoch) + octet_length(encoded_message) + 3 <= 8000 THEN + payload := NEW.id::text || ':' || NEW.channel || ':' || epoch || ':' || encoded_message; + ELSE + payload := NEW.id::text || ':' || NEW.channel || ':' || epoch; + END IF; + + PERFORM pg_notify('{NOTIFY_CHANNEL}', payload); + RETURN NEW; + """, # noqa: E501 + ), + ) + + def __str__(self) -> str: + return f"Message '{self.pk}' on channel '{self.channel}'" + + @classmethod + def delete_expired(cls) -> None: + cls.objects.filter(expires__lt=now()).delete() diff --git a/packages/django-channels-postgres/pyproject.toml b/packages/django-channels-postgres/pyproject.toml new file mode 100644 index 0000000000..d15b1fe0f9 --- /dev/null +++ b/packages/django-channels-postgres/pyproject.toml @@ -0,0 +1,53 @@ +[project] +name = "django-channels-postgres" +version = "0.1.0" +description = "Django channels layer using PostgreSQL NOTIFY/LISTEN" +requires-python = ">=3.9,<3.14" +readme = "README.md" +license = "MIT" +authors = [{ name = "Authentik Security Inc.", email = "hello@goauthentik.io" }] +keywords = ["django", "channels", "postgres"] + +classifiers = [ + "Development Status :: 3 - Alpha", + "Environment :: Web Environment", + "Framework :: Django", + "Framework :: Django :: 4.2", + "Framework :: Django :: 5.0", + "Framework :: Django :: 5.1", + "Framework :: Django :: 5.2", + "Intended Audience :: Developers", + "Operating System :: MacOS", + "Operating System :: POSIX", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python", + "Topic :: Software Development :: Libraries :: Python Modules", + "Typing :: Typed", +] + +dependencies = [ + "channels >=4.3,<4.4", + "django >=4.2,<6.0", + "django-pglock >=1.7,<2", + "django-pgtrigger >=4,<5", + "msgpack >=1,<2", + "psycopg >=3,<4", + "structlog >=25,<26", +] + +[project.urls] +Homepage = "https://github.com/goauthentik/authentik/tree/main/packages/django-channels-postgres" +Documentation = "https://github.com/goauthentik/authentik/tree/main/packages/django-channels-postgres" +Repository = "https://github.com/goauthentik/authentik/tree/main/packages/django-channels-postgres" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.setuptools.packages] +find = {} diff --git a/packages/django-dramatiq-postgres/pyproject.toml b/packages/django-dramatiq-postgres/pyproject.toml index 7524a6d3cb..ab872d7b35 100644 --- a/packages/django-dramatiq-postgres/pyproject.toml +++ b/packages/django-dramatiq-postgres/pyproject.toml @@ -33,6 +33,7 @@ classifiers = [ dependencies = [ "cron-converter >=1,<2", "django >=4.2,<6.0", + "django-pglock >=1.7,<2", "django-pgtrigger >=4,<5", "dramatiq[watch] >=1.17,<1.18", "tenacity >=9,<10", diff --git a/pyproject.toml b/pyproject.toml index 341da4cc27..1e254c60e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,12 +7,12 @@ requires-python = "==3.13.*" dependencies = [ "argon2-cffi==25.1.0", "channels==4.3.1", - "channels-postgres==1.1.2", "cryptography==45.0.5", "dacite==1.9.2", "deepmerge==2.0", "defusedxml==0.7.1", "django==5.1.13", + "django-channels-postgres", "django-countries==7.6.1", "django-cte==2.0.0", "django-dramatiq-postgres", @@ -102,6 +102,7 @@ dev = [ "requests-mock==1.12.1", "ruff==0.11.9", "selenium==4.32.0", + "types-channels==4.3.0.20250822", "types-ldap3==2.9.13.20250622", ] @@ -119,13 +120,14 @@ no-binary-package = [ [tool.uv.sources] djangorestframework = { git = "https://github.com/goauthentik/django-rest-framework", rev = "896722bab969fabc74a08b827da59409cf9f1a4e" } +django-channels-postgres = { workspace = true } django-dramatiq-postgres = { workspace = true } django-postgres-cache = { workspace = true } opencontainers = { git = "https://github.com/vsoch/oci-python", rev = "ceb4fcc090851717a3069d78e85ceb1e86c2740c" } -channels-postgres = { git = "https://github.com/rissson/channels_postgres", rev = "93ed24e3c5317d7ccf7e8b1ce913c0f365d1728f" } [tool.uv.workspace] members = [ + "packages/django-channels-postgres", "packages/django-dramatiq-postgres", "packages/django-postgres-cache", ] @@ -209,7 +211,7 @@ module = ["dramatiq.*", "pglock.*"] follow_untyped_imports = true [[tool.mypy.overrides]] -module = ["cron_converter.*"] +module = ["cron_converter.*", "msgpack.*"] ignore_missing_imports = true [[tool.mypy.overrides]] diff --git a/uv.lock b/uv.lock index 1307b10341..bef3cf8989 100644 --- a/uv.lock +++ b/uv.lock @@ -1,10 +1,11 @@ version = 1 -revision = 3 +revision = 2 requires-python = "==3.13.*" [manifest] members = [ "authentik", + "django-channels-postgres", "django-dramatiq-postgres", "django-postgres-cache", ] @@ -165,12 +166,12 @@ source = { editable = "." } dependencies = [ { name = "argon2-cffi" }, { name = "channels" }, - { name = "channels-postgres" }, { name = "cryptography" }, { name = "dacite" }, { name = "deepmerge" }, { name = "defusedxml" }, { name = "django" }, + { name = "django-channels-postgres" }, { name = "django-countries" }, { name = "django-cte" }, { name = "django-dramatiq-postgres" }, @@ -260,6 +261,7 @@ dev = [ { name = "requests-mock" }, { name = "ruff" }, { name = "selenium" }, + { name = "types-channels" }, { name = "types-ldap3" }, ] @@ -267,12 +269,12 @@ dev = [ requires-dist = [ { name = "argon2-cffi", specifier = "==25.1.0" }, { name = "channels", specifier = "==4.3.1" }, - { name = "channels-postgres", git = "https://github.com/rissson/channels_postgres?rev=93ed24e3c5317d7ccf7e8b1ce913c0f365d1728f" }, { name = "cryptography", specifier = "==45.0.5" }, { name = "dacite", specifier = "==1.9.2" }, { name = "deepmerge", specifier = "==2.0" }, { name = "defusedxml", specifier = "==0.7.1" }, { name = "django", specifier = "==5.1.13" }, + { name = "django-channels-postgres", editable = "packages/django-channels-postgres" }, { name = "django-countries", specifier = "==7.6.1" }, { name = "django-cte", specifier = "==2.0.0" }, { name = "django-dramatiq-postgres", editable = "packages/django-dramatiq-postgres" }, @@ -362,6 +364,7 @@ dev = [ { name = "requests-mock", specifier = "==1.12.1" }, { name = "ruff", specifier = "==0.11.9" }, { name = "selenium", specifier = "==4.32.0" }, + { name = "types-channels", specifier = "==4.3.0.20250822" }, { name = "types-ldap3", specifier = "==2.9.13.20250622" }, ] @@ -684,17 +687,6 @@ daphne = [ { name = "daphne" }, ] -[[package]] -name = "channels-postgres" -version = "1.1.4" -source = { git = "https://github.com/rissson/channels_postgres?rev=93ed24e3c5317d7ccf7e8b1ce913c0f365d1728f#93ed24e3c5317d7ccf7e8b1ce913c0f365d1728f" } -dependencies = [ - { name = "asgiref" }, - { name = "channels" }, - { name = "msgpack" }, - { name = "psycopg", extra = ["pool"] }, -] - [[package]] name = "charset-normalizer" version = "3.4.3" @@ -921,6 +913,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6c/f2/4b39467b74de9bb698c95232011e97dc848f490baae8d78c2e58848c4562/django-5.1.13-py3-none-any.whl", hash = "sha256:06f257f79dc4c17f3f9e23b106a4c5ed1335abecbe731e83c598c941d14fbeed", size = 8277515, upload-time = "2025-10-01T14:25:28.65Z" }, ] +[[package]] +name = "django-channels-postgres" +version = "0.1.0" +source = { editable = "packages/django-channels-postgres" } +dependencies = [ + { name = "channels" }, + { name = "django" }, + { name = "django-pglock" }, + { name = "django-pgtrigger" }, + { name = "msgpack" }, + { name = "psycopg" }, + { name = "structlog" }, +] + +[package.metadata] +requires-dist = [ + { name = "channels", specifier = ">=4.3,<4.4" }, + { name = "django", specifier = ">=4.2,<6.0" }, + { name = "django-pglock", specifier = ">=1.7,<2" }, + { name = "django-pgtrigger", specifier = ">=4,<5" }, + { name = "msgpack", specifier = ">=1,<2" }, + { name = "psycopg", specifier = ">=3,<4" }, + { name = "structlog", specifier = ">=25,<26" }, +] + [[package]] name = "django-countries" version = "7.6.1" @@ -953,6 +970,7 @@ source = { editable = "packages/django-dramatiq-postgres" } dependencies = [ { name = "cron-converter" }, { name = "django" }, + { name = "django-pglock" }, { name = "django-pgtrigger" }, { name = "dramatiq", extra = ["watch"] }, { name = "structlog" }, @@ -963,6 +981,7 @@ dependencies = [ requires-dist = [ { name = "cron-converter", specifier = ">=1,<2" }, { name = "django", specifier = ">=4.2,<6.0" }, + { name = "django-pglock", specifier = ">=1.7,<2" }, { name = "django-pgtrigger", specifier = ">=4,<5" }, { name = "dramatiq", extras = ["watch"], specifier = ">=1.17,<1.18" }, { name = "structlog", specifier = ">=25,<26" }, @@ -2821,15 +2840,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446, upload-time = "2024-08-06T20:33:04.33Z" }, ] -[[package]] -name = "redis" -version = "6.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0d/d6/e8b92798a5bd67d659d51a18170e91c16ac3b59738d91894651ee255ed49/redis-6.4.0.tar.gz", hash = "sha256:b01bc7282b8444e28ec36b261df5375183bb47a07eb9c603f284e89cbc5ef010", size = 4647399, upload-time = "2025-08-07T08:10:11.441Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e8/02/89e2ed7e85db6c93dfa9e8f691c5087df4e3551ab39081a4d7c6d1f90e05/redis-6.4.0-py3-none-any.whl", hash = "sha256:f0544fa9604264e9464cdf4814e7d4830f74b165d52f2a330a760a88dd248b7f", size = 279847, upload-time = "2025-08-07T08:10:09.84Z" }, -] - [[package]] name = "referencing" version = "0.36.2" @@ -3265,6 +3275,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9a/bb/d43e5c75054e53efce310e79d63df0ac3f25e34c926be5dffb7d283fb2a8/typeguard-2.13.3-py3-none-any.whl", hash = "sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1", size = 17605, upload-time = "2021-12-10T21:09:37.844Z" }, ] +[[package]] +name = "types-channels" +version = "4.3.0.20250822" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asgiref" }, + { name = "django-stubs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/22/3d/e0a164b7eaab18ab2302b7a3d90843824b729d9fe1d7bdbb3769dfc9d77b/types_channels-4.3.0.20250822.tar.gz", hash = "sha256:29a4928fdaed6d444b93b69d44fcdb5a8fe32fa72d6a41016c5d39fa7bd7f474", size = 15357, upload-time = "2025-08-22T03:04:26.444Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/52/4e3094e43d460feacb9051ec4c3498f8272f69d92b772647211478b25079/types_channels-4.3.0.20250822-py3-none-any.whl", hash = "sha256:d3fc0a1467c8cc901686826408c8a673822e07aa79cbe1a6d21946e7e55d9ddf", size = 21125, upload-time = "2025-08-22T03:04:25.539Z" }, +] + [[package]] name = "types-ldap3" version = "2.9.13.20250622"