mirror of
https://github.com/goauthentik/authentik.git
synced 2026-06-17 19:09:11 +03:00
packages/django-channels-postgres: init (#17247)
This commit is contained in:
committed by
GitHub
parent
bd421e5815
commit
4fb61bb991
@@ -24,6 +24,7 @@ Makefile @goauthentik/infrastructure
|
||||
.editorconfig @goauthentik/infrastructure
|
||||
CODEOWNERS @goauthentik/infrastructure
|
||||
# Backend packages
|
||||
packages/django-channels-postgres @goauthentik/backend
|
||||
packages/django-postgres-cache @goauthentik/backend
|
||||
packages/django-dramatiq-postgres @goauthentik/backend
|
||||
# Web packages
|
||||
|
||||
@@ -4,6 +4,7 @@ from datetime import datetime, timedelta
|
||||
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
from django_channels_postgres.models import GroupChannel, Message
|
||||
from django_postgres_cache.tasks import clear_expired_cache
|
||||
from dramatiq.actor import actor
|
||||
from structlog.stdlib import get_logger
|
||||
@@ -34,6 +35,8 @@ def clean_expired_models():
|
||||
LOGGER.debug("Expired models", model=cls, amount=amount)
|
||||
self.info(f"Expired {amount} {cls._meta.verbose_name_plural}")
|
||||
clear_expired_cache()
|
||||
Message.delete_expired()
|
||||
GroupChannel.delete_expired()
|
||||
|
||||
|
||||
@actor(description=_("Remove temporary users created by SAML Sources."))
|
||||
|
||||
@@ -112,7 +112,6 @@ def get_logger_config():
|
||||
"hpack": "WARNING",
|
||||
"httpx": "WARNING",
|
||||
"azure": "WARNING",
|
||||
"channels_postgres": "WARNING",
|
||||
}
|
||||
for handler_name, level in handler_level_map.items():
|
||||
base_config["loggers"][handler_name] = {
|
||||
|
||||
@@ -1,35 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from channels_postgres.core import PostgresChannelLayer as BasePostgresChannelLayer
|
||||
from channels_postgres.db import DatabaseLayer as BaseDatabaseLayer
|
||||
from django.conf import settings
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
|
||||
from authentik.root.db.base import DatabaseWrapper
|
||||
|
||||
|
||||
class DatabaseLayer(BaseDatabaseLayer):
|
||||
async def get_db_pool(self, db_params: dict[str, Any]) -> AsyncConnectionPool:
|
||||
db_wrapper = DatabaseWrapper(settings.CHANNEL_LAYERS["default"]["CONFIG"])
|
||||
db_params = db_wrapper.get_connection_params()
|
||||
db_params.pop("cursor_factory")
|
||||
db_params.pop("context")
|
||||
return await super().get_db_pool(db_params)
|
||||
|
||||
|
||||
class PostgresChannelLayer(BasePostgresChannelLayer):
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.django_db = DatabaseLayer(self.django_db.psycopg_options, self.db_params)
|
||||
|
||||
@property
|
||||
def db_params(self):
|
||||
db_wrapper = DatabaseWrapper(settings.CHANNEL_LAYERS["default"]["CONFIG"])
|
||||
db_params = db_wrapper.get_connection_params()
|
||||
db_params.pop("cursor_factory")
|
||||
db_params.pop("context")
|
||||
return db_params
|
||||
|
||||
@db_params.setter
|
||||
def db_params(self, value):
|
||||
pass
|
||||
@@ -64,7 +64,7 @@ SHARED_APPS = [
|
||||
"pgactivity",
|
||||
"pglock",
|
||||
"channels",
|
||||
"channels_postgres",
|
||||
"django_channels_postgres",
|
||||
"django_dramatiq_postgres",
|
||||
"authentik.tasks",
|
||||
]
|
||||
@@ -304,11 +304,7 @@ DATABASE_ROUTERS = (
|
||||
|
||||
CHANNEL_LAYERS = {
|
||||
"default": {
|
||||
"BACKEND": "authentik.root.channels.PostgresChannelLayer",
|
||||
"CONFIG": {
|
||||
**DATABASES["default"],
|
||||
"TIME_ZONE": None,
|
||||
},
|
||||
"BACKEND": "django_channels_postgres.layer.PostgresChannelLayer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -62,11 +62,6 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
|
||||
"""Configure test environment settings"""
|
||||
settings.TEST = True
|
||||
settings.DRAMATIQ["test"] = True
|
||||
settings.CHANNEL_LAYERS["default"]["CONFIG"] = {
|
||||
**settings.DATABASES["default"],
|
||||
**settings.DATABASES["default"]["TEST"],
|
||||
"TIME_ZONE": None,
|
||||
}
|
||||
|
||||
# Test-specific configuration
|
||||
test_config = {
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# django-channels-postgres
|
||||
@@ -0,0 +1,5 @@
|
||||
from django.apps import AppConfig
|
||||
|
||||
|
||||
class DjangoChannelsPostgresConfig(AppConfig):
|
||||
name = "django_channels_postgres"
|
||||
@@ -0,0 +1,472 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import types
|
||||
from base64 import b64decode
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from re import Pattern
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import msgpack
|
||||
from channels.layers import BaseChannelLayer
|
||||
from django.db import DEFAULT_DB_ALIAS, connections
|
||||
from django.utils.timezone import now
|
||||
from pglock.core import _cast_lock_id
|
||||
from psycopg import AsyncConnection, Notify, sql
|
||||
from psycopg.conninfo import make_conninfo
|
||||
from psycopg.errors import Error as PsycopgError
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from django_channels_postgres.models import NOTIFY_CHANNEL, GroupChannel, Message
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
GROUP_CHANNEL_TABLE = GroupChannel._meta.db_table
|
||||
MESSAGE_TABLE = Message._meta.db_table
|
||||
|
||||
|
||||
async def _async_proxy(
|
||||
obj: "PostgresChannelLayerLoopProxy",
|
||||
name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
# Must be defined as a function and not a method due to
|
||||
# https://bugs.python.org/issue38364
|
||||
layer = obj._get_layer()
|
||||
return await getattr(layer, name)(*args, **kwargs)
|
||||
|
||||
|
||||
def _wrap_close(proxy: "PostgresChannelLayerLoopProxy", loop: asyncio.AbstractEventLoop) -> None:
|
||||
original_impl = loop.close
|
||||
|
||||
def _wrapper(self: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any) -> None:
|
||||
if loop in proxy._layers:
|
||||
layer = proxy._layers[loop]
|
||||
del proxy._layers[loop]
|
||||
loop.run_until_complete(layer.flush())
|
||||
self.close = original_impl # type: ignore[method-assign]
|
||||
return self.close(*args, **kwargs)
|
||||
|
||||
loop.close = types.MethodType(_wrapper, loop) # type: ignore[method-assign]
|
||||
|
||||
|
||||
class PostgresChannelLayerLoopProxy:
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._args = args
|
||||
self._kwargs = kwargs
|
||||
self._kwargs["channel_layer"] = self
|
||||
self._layers: dict[asyncio.AbstractEventLoop, PostgresChannelLoopLayer] = {}
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in (
|
||||
"new_channel",
|
||||
"send",
|
||||
"receive",
|
||||
"group_add",
|
||||
"group_discard",
|
||||
"group_send",
|
||||
"flush",
|
||||
):
|
||||
return functools.partial(_async_proxy, self, name)
|
||||
else:
|
||||
return getattr(self._get_layer(), name)
|
||||
|
||||
def serialize(self, message: dict[str, Any]) -> bytes:
|
||||
"""Serializes message to a byte string."""
|
||||
return cast(bytes, msgpack.packb(message, use_bin_type=True))
|
||||
|
||||
def deserialize(self, message: bytes) -> dict[str, Any]:
|
||||
"""Deserializes from a byte string."""
|
||||
return cast(dict[str, Any], msgpack.unpackb(message, raw=False))
|
||||
|
||||
def _get_layer(self) -> "PostgresChannelLoopLayer":
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
layer = self._layers[loop]
|
||||
except KeyError:
|
||||
layer = PostgresChannelLoopLayer(*self._args, **self._kwargs)
|
||||
self._layers[loop] = layer
|
||||
_wrap_close(self, loop)
|
||||
|
||||
return layer
|
||||
|
||||
|
||||
PostgresChannelLayer = PostgresChannelLayerLoopProxy
|
||||
|
||||
|
||||
class PostgresChannelLoopLayer(BaseChannelLayer):
|
||||
"""
|
||||
Postgres channel layer.
|
||||
|
||||
It uses the NOTIFY/LISTEN functionality of postgres to broadcast messages
|
||||
|
||||
It also makes use of an internal message table to overcome the
|
||||
8000bytes limit of Postgres' NOTIFY messages.
|
||||
Which is a far cry from the channels standard of 1MB
|
||||
This table has a trigger that sends out the `NOTIFY` signal.
|
||||
|
||||
Using a database also means messages are durable and will always be
|
||||
available to consumers (as long as they're not expired).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channel_layer: PostgresChannelLayerLoopProxy,
|
||||
prefix: str = "asgi",
|
||||
expiry: int = 60,
|
||||
group_expiry: int = 86400,
|
||||
capacity: int = 100,
|
||||
channel_capacity: dict[Pattern[str] | str, int] | None = None,
|
||||
using: str = DEFAULT_DB_ALIAS,
|
||||
) -> None:
|
||||
super().__init__(expiry=expiry, capacity=capacity, channel_capacity=channel_capacity)
|
||||
|
||||
self.group_expiry = group_expiry
|
||||
self.prefix = prefix
|
||||
assert isinstance(self.prefix, str), "Prefix must be unicode" # nosec
|
||||
self.channel_layer = channel_layer
|
||||
self.using = using
|
||||
|
||||
# Each consumer gets its own *specific* channel, created with the `new_channel()` method.
|
||||
# This dict maps `channel_name` to a queue of messages for that channel.
|
||||
self.channels: dict[str, asyncio.Queue[bytes]] = {}
|
||||
|
||||
self.connection = PostgresChannelLayerConnection(self.using, self)
|
||||
|
||||
extensions = ["groups", "flush"]
|
||||
|
||||
### Channel layer API ###
|
||||
|
||||
async def send(self, channel: str, message: dict[str, Any]) -> None:
|
||||
"""
|
||||
Send a message onto a (general or specific) channel.
|
||||
"""
|
||||
# Typecheck
|
||||
assert isinstance(message, dict), "message is not a dict" # nosec
|
||||
assert self.require_valid_channel_name(channel), "Channel name not valid" # nosec
|
||||
# Make sure the message does not contain reserved keys
|
||||
assert "__asgi_channel__" not in message # nosec
|
||||
|
||||
await Message.objects.using(self.using).acreate(
|
||||
channel=channel,
|
||||
message=self.channel_layer.serialize(message),
|
||||
expires=now() + timedelta(seconds=self.expiry),
|
||||
)
|
||||
|
||||
async def new_channel(self, prefix: str = "specific") -> str:
|
||||
"""
|
||||
Returns a new channel name that can be used by something in our
|
||||
process as a specific channel.
|
||||
"""
|
||||
return f"{self.prefix}.{prefix}.{uuid4().hex}"
|
||||
|
||||
async def receive(self, channel: str) -> dict[str, Any]:
|
||||
"""
|
||||
Receive the first message that arrives on the channel.
|
||||
If more than one coroutine waits on the same channel, the first waiter
|
||||
will be given the message when it arrives.
|
||||
|
||||
This is done by acquiring an `advistory_lock` from the database
|
||||
based on the channel name.
|
||||
|
||||
If the lock is acquired successfully, subsequent calls to this method
|
||||
will not try to acquire the lock again.
|
||||
_The lock is session based and should be released by postgres when
|
||||
the session is closed_
|
||||
|
||||
If the lock is already acquired by another coroutine,
|
||||
subsequent calls to this method will repeatedly try to acquire the lock
|
||||
before proceeding to wait for a message.
|
||||
"""
|
||||
if channel not in self.channels:
|
||||
self.channels[channel] = asyncio.Queue()
|
||||
await self.connection.subscribe(channel)
|
||||
|
||||
q = self.channels[channel]
|
||||
try:
|
||||
message = await q.get()
|
||||
except (asyncio.CancelledError, TimeoutError, GeneratorExit):
|
||||
# We assume here that the reason we are cancelled is because the consumer
|
||||
# is exiting, therefore we need to cleanup by unsubscribe below. Indeed,
|
||||
# currently the way that Django Channels works, this is a safe assumption.
|
||||
# In the future, Django Channels could change to call a *new* method that
|
||||
# would serve as the antithesis of `new_channel()`; this new method might
|
||||
# be named `delete_channel()`. If that were the case, we would do the
|
||||
# following cleanup from that new `delete_channel()` method, but, since
|
||||
# that's not how Django Channels works (yet), we do the cleanup below:
|
||||
if channel in self.channels:
|
||||
del self.channels[channel]
|
||||
try:
|
||||
await self.connection.unsubscribe(channel)
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
LOGGER.warning("Unexpected exception while cleaning-up channel", exc=exc)
|
||||
# We don't re-raise here because we want the CancelledError to be the one
|
||||
# re-raised
|
||||
raise
|
||||
return self.channel_layer.deserialize(message)
|
||||
|
||||
# ==============================================================
|
||||
# Groups extension
|
||||
# ==============================================================
|
||||
|
||||
async def group_add(self, group: str, channel: str) -> None:
|
||||
"""
|
||||
Adds the channel name to a group.
|
||||
"""
|
||||
# Check the inputs
|
||||
assert self.require_valid_group_name(group), "Group name not valid" # nosec
|
||||
assert self.require_valid_channel_name(channel), "Channel name not valid" # nosec
|
||||
|
||||
group_key = self._group_key(group)
|
||||
|
||||
await GroupChannel.objects.using(self.using).aupdate_or_create(
|
||||
group_key=group_key,
|
||||
channel=channel,
|
||||
defaults={
|
||||
"expires": now() + timedelta(seconds=self.group_expiry),
|
||||
},
|
||||
)
|
||||
|
||||
async def group_discard(self, group: str, channel: str) -> None:
|
||||
"""
|
||||
Removes the channel from the named group if it is in the group;
|
||||
does nothing otherwise (does not error)
|
||||
"""
|
||||
# Check the inputs
|
||||
assert self.require_valid_group_name(group), "Group name not valid" # nosec
|
||||
assert self.require_valid_channel_name(channel), "Channel name not valid" # nosec
|
||||
|
||||
group_key = self._group_key(group)
|
||||
|
||||
await (
|
||||
GroupChannel.objects.using(self.using)
|
||||
.filter(group_key=group_key, channel=channel)
|
||||
.adelete()
|
||||
)
|
||||
|
||||
async def group_send(self, group: str, message: dict[str, Any]) -> None:
|
||||
"""
|
||||
Sends a message to the entire group.
|
||||
"""
|
||||
assert self.require_valid_group_name(group), "Group name not valid" # nosec
|
||||
|
||||
group_key = self._group_key(group)
|
||||
|
||||
serialized_message = self.channel_layer.serialize(message)
|
||||
messages = [
|
||||
Message(
|
||||
channel=channel,
|
||||
message=serialized_message,
|
||||
expires=now() + timedelta(seconds=self.expiry),
|
||||
)
|
||||
async for channel in GroupChannel.objects.using(self.using)
|
||||
.filter(group_key=group_key, expires__gte=now())
|
||||
.values_list("channel", flat=True)
|
||||
.distinct()
|
||||
]
|
||||
await Message.objects.using(self.using).abulk_create(messages)
|
||||
|
||||
def _group_key(self, group: str) -> str:
|
||||
"""
|
||||
Common function to make the storage key for the group.
|
||||
"""
|
||||
return f"{self.prefix}.group.{group}"
|
||||
|
||||
### Flush extension ###
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""
|
||||
Deletes all messages and groups.
|
||||
"""
|
||||
self.channels = {}
|
||||
await self.connection.flush()
|
||||
|
||||
|
||||
class PostgresChannelLayerConnection:
|
||||
def __init__(self, using: str, channel_layer: PostgresChannelLoopLayer) -> None:
|
||||
self.using = using
|
||||
self.channel_layer = channel_layer
|
||||
self._subscribed_to: set[str] = set()
|
||||
self._locked_channels: set[str] = set()
|
||||
self._lock = asyncio.Lock()
|
||||
self._receive_task: asyncio.Task[None] | None = None
|
||||
|
||||
async def subscribe(self, channel: str) -> None:
|
||||
async with self._lock:
|
||||
if channel not in self._subscribed_to:
|
||||
self._ensure_receiver()
|
||||
self._subscribed_to.add(channel)
|
||||
|
||||
async def unsubscribe(self, channel: str) -> None:
|
||||
async with self._lock:
|
||||
if channel in self._subscribed_to:
|
||||
self._ensure_receiver()
|
||||
self._subscribed_to.remove(channel)
|
||||
|
||||
async def flush(self) -> None:
|
||||
async with self._lock:
|
||||
if self._receive_task is not None:
|
||||
self._receive_task.cancel()
|
||||
try:
|
||||
await self._receive_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._receive_task = None
|
||||
self._subscribed_to = set()
|
||||
|
||||
async def _do_receiving(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
async with await self._create_connection() as conn:
|
||||
await self._update_locks(conn)
|
||||
while True:
|
||||
await self._process_backlog(conn)
|
||||
await conn.execute(
|
||||
sql.SQL("LISTEN {channel}").format(
|
||||
channel=sql.Identifier(NOTIFY_CHANNEL)
|
||||
)
|
||||
)
|
||||
first_loop = True
|
||||
async for notify in conn.notifies(stop_after=1, timeout=5):
|
||||
if first_loop:
|
||||
await self._update_locks(conn)
|
||||
first_loop = False
|
||||
await self._receive_notify(conn, notify)
|
||||
if first_loop:
|
||||
await self._update_locks(conn)
|
||||
except (asyncio.CancelledError, TimeoutError, GeneratorExit):
|
||||
raise
|
||||
except PsycopgError as exc:
|
||||
LOGGER.warning("Postgres connection is not healthy", exc=exc)
|
||||
except BaseException as exc: # noqa: BLE001
|
||||
LOGGER.warning("Unexpected exception in receive task", exc=exc, exc_info=True)
|
||||
self._locked_channels = set()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def _process_backlog(self, conn: AsyncConnection) -> None:
|
||||
if not self._locked_channels:
|
||||
return
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
sql.SQL(
|
||||
"""
|
||||
DELETE
|
||||
FROM {table}
|
||||
WHERE
|
||||
{table}.{channel} IN (%s)
|
||||
AND {table}.{expires} >= %s
|
||||
RETURNING {table}.{channel}, {table}.{message}
|
||||
"""
|
||||
).format(
|
||||
table=sql.Identifier(MESSAGE_TABLE),
|
||||
channel=sql.Identifier("channel"),
|
||||
expires=sql.Identifier("expires"),
|
||||
message=sql.Identifier("message"),
|
||||
),
|
||||
(tuple(self._locked_channels), now()),
|
||||
)
|
||||
async for row in cursor:
|
||||
channel, message = row
|
||||
self._receive_message(channel, message)
|
||||
|
||||
def _get_lock_id(self, channel: str) -> int:
|
||||
lock_id = _cast_lock_id(f"channels.{channel}") # type: ignore[no-untyped-call]
|
||||
return cast(int, lock_id)
|
||||
|
||||
async def _update_locks(self, conn: AsyncConnection) -> None:
|
||||
async with self._lock:
|
||||
locks_to_release = self._locked_channels - self._subscribed_to
|
||||
locks_to_acquire = self._subscribed_to - self._locked_channels
|
||||
|
||||
for channel in locks_to_acquire:
|
||||
lock_id = self._get_lock_id(channel)
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT pg_try_advisory_lock(%s)", (lock_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row is not None:
|
||||
if row[0]:
|
||||
self._locked_channels.add(channel)
|
||||
|
||||
for channel in locks_to_release:
|
||||
lock_id = self._get_lock_id(channel)
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT pg_advisory_unlock(%s)", (lock_id,))
|
||||
row = await cursor.fetchone()
|
||||
if row is not None:
|
||||
if row[0]:
|
||||
self._locked_channels.remove(channel)
|
||||
|
||||
async def _receive_notify(self, conn: AsyncConnection, notify: Notify) -> None:
|
||||
payload = notify.payload
|
||||
split_payload = payload.split(":")
|
||||
match len(split_payload):
|
||||
case 4:
|
||||
message_id, channel, timestamp, base64_message = split_payload
|
||||
if channel not in self._locked_channels:
|
||||
return
|
||||
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
|
||||
if expires < now():
|
||||
return
|
||||
message = b64decode(base64_message)
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
sql.SQL("DELETE FROM {table} WHERE {table}.{id} = %s").format(
|
||||
table=sql.Identifier(MESSAGE_TABLE),
|
||||
id=sql.Identifier("id"),
|
||||
)
|
||||
)
|
||||
case 3:
|
||||
message_id, channel, timestamp = split_payload
|
||||
if channel not in self._locked_channels:
|
||||
return
|
||||
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
|
||||
if expires < now():
|
||||
return
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute(
|
||||
sql.SQL(
|
||||
"""
|
||||
DELETE
|
||||
FROM {table}
|
||||
WHERE
|
||||
{table}.{id} = %s
|
||||
RETURNING {table}.{message}, {table}.{expires}
|
||||
"""
|
||||
).format(
|
||||
table=sql.Identifier(MESSAGE_TABLE),
|
||||
id=sql.Identifier("id"),
|
||||
message=sql.Identifier("message"),
|
||||
expires=sql.Identifier("expires"),
|
||||
),
|
||||
(message_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
if row is None:
|
||||
return
|
||||
message, expires = row
|
||||
case _:
|
||||
return
|
||||
self._receive_message(channel, message)
|
||||
|
||||
def _receive_message(self, channel: str, message: bytes) -> None:
|
||||
if (q := self.channel_layer.channels.get(channel)) is not None:
|
||||
q.put_nowait(message)
|
||||
|
||||
def _ensure_receiver(self) -> None:
|
||||
if self._receive_task is None:
|
||||
self._receive_task = asyncio.ensure_future(self._do_receiving())
|
||||
|
||||
async def _create_connection(self) -> AsyncConnection:
|
||||
db_params = connections[self.using].get_connection_params()
|
||||
# Prevent psycopg from using the custom synchronous cursor factory from django
|
||||
db_params.pop("cursor_factory")
|
||||
db_params.pop("context")
|
||||
conninfo = make_conninfo(conninfo="", **db_params, connect_timeout=10)
|
||||
return await AsyncConnection.connect(conninfo=conninfo, autocommit=True)
|
||||
@@ -0,0 +1,94 @@
|
||||
# Generated by Django 5.1.13 on 2025-10-04 14:35
|
||||
|
||||
import django_channels_postgres.models
|
||||
import pgtrigger.compiler
|
||||
import pgtrigger.migrations
|
||||
import uuid
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
initial = True
|
||||
|
||||
dependencies = []
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="GroupChannel",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
("group_key", models.TextField(db_index=True)),
|
||||
("channel", models.TextField(db_index=True)),
|
||||
(
|
||||
"expires",
|
||||
models.DateTimeField(
|
||||
db_index=True, default=django_channels_postgres.models._default_group_expiry
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "Group channel",
|
||||
"verbose_name_plural": "Group channels",
|
||||
"indexes": [
|
||||
models.Index(
|
||||
fields=["group_key", "channel"], name="django_chan_group_k_173f44_idx"
|
||||
),
|
||||
models.Index(
|
||||
fields=["group_key", "expires"], name="django_chan_group_k_45d2e0_idx"
|
||||
),
|
||||
],
|
||||
},
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="Message",
|
||||
fields=[
|
||||
(
|
||||
"id",
|
||||
models.UUIDField(
|
||||
default=uuid.uuid4, editable=False, primary_key=True, serialize=False
|
||||
),
|
||||
),
|
||||
("channel", models.TextField(db_index=True)),
|
||||
("message", models.BinaryField()),
|
||||
(
|
||||
"expires",
|
||||
models.DateTimeField(
|
||||
db_index=True,
|
||||
default=django_channels_postgres.models._default_message_expiry,
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "Message",
|
||||
"verbose_name_plural": "Messages",
|
||||
"indexes": [
|
||||
models.Index(
|
||||
fields=["channel", "expires"], name="django_chan_channel_e8ca51_idx"
|
||||
)
|
||||
],
|
||||
},
|
||||
),
|
||||
pgtrigger.migrations.AddTrigger(
|
||||
model_name="message",
|
||||
trigger=pgtrigger.compiler.Trigger(
|
||||
name="notify_new_channels_message",
|
||||
sql=pgtrigger.compiler.UpsertTriggerSql(
|
||||
constraint="CONSTRAINT",
|
||||
declare="DECLARE payload text; encoded_message text; epoch text;",
|
||||
func="\n encoded_message := encode(NEW.message, 'base64');\n epoch := extract(epoch from NEW.expires)::text;\n IF octet_length(NEW.id::text) + octet_length(NEW.channel) + octet_length(epoch) + octet_length(encoded_message) + 3 <= 8000 THEN\n payload := NEW.id::text || ':' || NEW.channel || ':' || epoch || ':' || encoded_message;\n ELSE\n payload := NEW.id::text || ':' || NEW.channel || ':' || epoch;\n END IF;\n\n PERFORM pg_notify('channels_messages', payload);\n RETURN NEW;\n ",
|
||||
hash="cf7a665df0bbb7d865cdbc92b63d818fd25733d8",
|
||||
operation="INSERT",
|
||||
pgid="pgtrigger_notify_new_channels_message_d21ae",
|
||||
table="django_channels_postgres_message",
|
||||
timing="DEFERRABLE INITIALLY DEFERRED",
|
||||
when="AFTER",
|
||||
),
|
||||
),
|
||||
),
|
||||
]
|
||||
@@ -0,0 +1,97 @@
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pgtrigger
|
||||
from django.db import models
|
||||
from django.utils.timezone import now
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
def _default_group_expiry() -> datetime:
|
||||
return now() + timedelta(seconds=86400)
|
||||
|
||||
|
||||
def _default_message_expiry() -> datetime:
|
||||
return now() + timedelta(minutes=1)
|
||||
|
||||
|
||||
NOTIFY_CHANNEL = "channels_messages"
|
||||
|
||||
|
||||
class GroupChannel(models.Model):
|
||||
"""
|
||||
A model that represents a group channel.
|
||||
|
||||
Groups are used to send messages to multiple channels.
|
||||
"""
|
||||
|
||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
group_key = models.TextField(db_index=True)
|
||||
channel = models.TextField(db_index=True)
|
||||
expires = models.DateTimeField(db_index=True, default=_default_group_expiry)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Group channel")
|
||||
verbose_name_plural = _("Group channels")
|
||||
indexes = (
|
||||
models.Index(fields=("group_key", "channel")),
|
||||
models.Index(fields=("group_key", "expires")),
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Group '{self.group_key}' on channel '{self.channel}'"
|
||||
|
||||
@classmethod
|
||||
def delete_expired(cls) -> None:
|
||||
cls.objects.filter(expires__lt=now()).delete()
|
||||
|
||||
|
||||
class Message(models.Model):
|
||||
"""
|
||||
A model that represents a message.
|
||||
|
||||
Messages are used to send messages to a specific channel.
|
||||
E.g for user to user private messages.
|
||||
"""
|
||||
|
||||
id = models.UUIDField(primary_key=True, editable=False, default=uuid4)
|
||||
channel = models.TextField(db_index=True)
|
||||
message = models.BinaryField()
|
||||
expires = models.DateTimeField(db_index=True, default=_default_message_expiry)
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("Message")
|
||||
verbose_name_plural = _("Messages")
|
||||
indexes = (models.Index(fields=("channel", "expires")),)
|
||||
triggers = (
|
||||
pgtrigger.Trigger(
|
||||
name="notify_new_channels_message",
|
||||
operation=pgtrigger.Insert,
|
||||
when=pgtrigger.After,
|
||||
timing=pgtrigger.Deferred,
|
||||
declare=[
|
||||
("payload", "text"),
|
||||
("encoded_message", "text"),
|
||||
("epoch", "text"),
|
||||
],
|
||||
func=f"""
|
||||
encoded_message := encode(NEW.message, 'base64');
|
||||
epoch := extract(epoch from NEW.expires)::text;
|
||||
IF octet_length(NEW.id::text) + octet_length(NEW.channel) + octet_length(epoch) + octet_length(encoded_message) + 3 <= 8000 THEN
|
||||
payload := NEW.id::text || ':' || NEW.channel || ':' || epoch || ':' || encoded_message;
|
||||
ELSE
|
||||
payload := NEW.id::text || ':' || NEW.channel || ':' || epoch;
|
||||
END IF;
|
||||
|
||||
PERFORM pg_notify('{NOTIFY_CHANNEL}', payload);
|
||||
RETURN NEW;
|
||||
""", # noqa: E501
|
||||
),
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"Message '{self.pk}' on channel '{self.channel}'"
|
||||
|
||||
@classmethod
|
||||
def delete_expired(cls) -> None:
|
||||
cls.objects.filter(expires__lt=now()).delete()
|
||||
@@ -0,0 +1,53 @@
|
||||
[project]
|
||||
name = "django-channels-postgres"
|
||||
version = "0.1.0"
|
||||
description = "Django channels layer using PostgreSQL NOTIFY/LISTEN"
|
||||
requires-python = ">=3.9,<3.14"
|
||||
readme = "README.md"
|
||||
license = "MIT"
|
||||
authors = [{ name = "Authentik Security Inc.", email = "hello@goauthentik.io" }]
|
||||
keywords = ["django", "channels", "postgres"]
|
||||
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Environment :: Web Environment",
|
||||
"Framework :: Django",
|
||||
"Framework :: Django :: 4.2",
|
||||
"Framework :: Django :: 5.0",
|
||||
"Framework :: Django :: 5.1",
|
||||
"Framework :: Django :: 5.2",
|
||||
"Intended Audience :: Developers",
|
||||
"Operating System :: MacOS",
|
||||
"Operating System :: POSIX",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Programming Language :: Python",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"channels >=4.3,<4.4",
|
||||
"django >=4.2,<6.0",
|
||||
"django-pglock >=1.7,<2",
|
||||
"django-pgtrigger >=4,<5",
|
||||
"msgpack >=1,<2",
|
||||
"psycopg >=3,<4",
|
||||
"structlog >=25,<26",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/goauthentik/authentik/tree/main/packages/django-channels-postgres"
|
||||
Documentation = "https://github.com/goauthentik/authentik/tree/main/packages/django-channels-postgres"
|
||||
Repository = "https://github.com/goauthentik/authentik/tree/main/packages/django-channels-postgres"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.setuptools.packages]
|
||||
find = {}
|
||||
@@ -33,6 +33,7 @@ classifiers = [
|
||||
dependencies = [
|
||||
"cron-converter >=1,<2",
|
||||
"django >=4.2,<6.0",
|
||||
"django-pglock >=1.7,<2",
|
||||
"django-pgtrigger >=4,<5",
|
||||
"dramatiq[watch] >=1.17,<1.18",
|
||||
"tenacity >=9,<10",
|
||||
|
||||
+5
-3
@@ -7,12 +7,12 @@ requires-python = "==3.13.*"
|
||||
dependencies = [
|
||||
"argon2-cffi==25.1.0",
|
||||
"channels==4.3.1",
|
||||
"channels-postgres==1.1.2",
|
||||
"cryptography==45.0.5",
|
||||
"dacite==1.9.2",
|
||||
"deepmerge==2.0",
|
||||
"defusedxml==0.7.1",
|
||||
"django==5.1.13",
|
||||
"django-channels-postgres",
|
||||
"django-countries==7.6.1",
|
||||
"django-cte==2.0.0",
|
||||
"django-dramatiq-postgres",
|
||||
@@ -102,6 +102,7 @@ dev = [
|
||||
"requests-mock==1.12.1",
|
||||
"ruff==0.11.9",
|
||||
"selenium==4.32.0",
|
||||
"types-channels==4.3.0.20250822",
|
||||
"types-ldap3==2.9.13.20250622",
|
||||
]
|
||||
|
||||
@@ -119,13 +120,14 @@ no-binary-package = [
|
||||
|
||||
[tool.uv.sources]
|
||||
djangorestframework = { git = "https://github.com/goauthentik/django-rest-framework", rev = "896722bab969fabc74a08b827da59409cf9f1a4e" }
|
||||
django-channels-postgres = { workspace = true }
|
||||
django-dramatiq-postgres = { workspace = true }
|
||||
django-postgres-cache = { workspace = true }
|
||||
opencontainers = { git = "https://github.com/vsoch/oci-python", rev = "ceb4fcc090851717a3069d78e85ceb1e86c2740c" }
|
||||
channels-postgres = { git = "https://github.com/rissson/channels_postgres", rev = "93ed24e3c5317d7ccf7e8b1ce913c0f365d1728f" }
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = [
|
||||
"packages/django-channels-postgres",
|
||||
"packages/django-dramatiq-postgres",
|
||||
"packages/django-postgres-cache",
|
||||
]
|
||||
@@ -209,7 +211,7 @@ module = ["dramatiq.*", "pglock.*"]
|
||||
follow_untyped_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["cron_converter.*"]
|
||||
module = ["cron_converter.*", "msgpack.*"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = "==3.13.*"
|
||||
|
||||
[manifest]
|
||||
members = [
|
||||
"authentik",
|
||||
"django-channels-postgres",
|
||||
"django-dramatiq-postgres",
|
||||
"django-postgres-cache",
|
||||
]
|
||||
@@ -165,12 +166,12 @@ source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "argon2-cffi" },
|
||||
{ name = "channels" },
|
||||
{ name = "channels-postgres" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "dacite" },
|
||||
{ name = "deepmerge" },
|
||||
{ name = "defusedxml" },
|
||||
{ name = "django" },
|
||||
{ name = "django-channels-postgres" },
|
||||
{ name = "django-countries" },
|
||||
{ name = "django-cte" },
|
||||
{ name = "django-dramatiq-postgres" },
|
||||
@@ -260,6 +261,7 @@ dev = [
|
||||
{ name = "requests-mock" },
|
||||
{ name = "ruff" },
|
||||
{ name = "selenium" },
|
||||
{ name = "types-channels" },
|
||||
{ name = "types-ldap3" },
|
||||
]
|
||||
|
||||
@@ -267,12 +269,12 @@ dev = [
|
||||
requires-dist = [
|
||||
{ name = "argon2-cffi", specifier = "==25.1.0" },
|
||||
{ name = "channels", specifier = "==4.3.1" },
|
||||
{ name = "channels-postgres", git = "https://github.com/rissson/channels_postgres?rev=93ed24e3c5317d7ccf7e8b1ce913c0f365d1728f" },
|
||||
{ name = "cryptography", specifier = "==45.0.5" },
|
||||
{ name = "dacite", specifier = "==1.9.2" },
|
||||
{ name = "deepmerge", specifier = "==2.0" },
|
||||
{ name = "defusedxml", specifier = "==0.7.1" },
|
||||
{ name = "django", specifier = "==5.1.13" },
|
||||
{ name = "django-channels-postgres", editable = "packages/django-channels-postgres" },
|
||||
{ name = "django-countries", specifier = "==7.6.1" },
|
||||
{ name = "django-cte", specifier = "==2.0.0" },
|
||||
{ name = "django-dramatiq-postgres", editable = "packages/django-dramatiq-postgres" },
|
||||
@@ -362,6 +364,7 @@ dev = [
|
||||
{ name = "requests-mock", specifier = "==1.12.1" },
|
||||
{ name = "ruff", specifier = "==0.11.9" },
|
||||
{ name = "selenium", specifier = "==4.32.0" },
|
||||
{ name = "types-channels", specifier = "==4.3.0.20250822" },
|
||||
{ name = "types-ldap3", specifier = "==2.9.13.20250622" },
|
||||
]
|
||||
|
||||
@@ -684,17 +687,6 @@ daphne = [
|
||||
{ name = "daphne" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "channels-postgres"
|
||||
version = "1.1.4"
|
||||
source = { git = "https://github.com/rissson/channels_postgres?rev=93ed24e3c5317d7ccf7e8b1ce913c0f365d1728f#93ed24e3c5317d7ccf7e8b1ce913c0f365d1728f" }
|
||||
dependencies = [
|
||||
{ name = "asgiref" },
|
||||
{ name = "channels" },
|
||||
{ name = "msgpack" },
|
||||
{ name = "psycopg", extra = ["pool"] },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "charset-normalizer"
|
||||
version = "3.4.3"
|
||||
@@ -921,6 +913,31 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/6c/f2/4b39467b74de9bb698c95232011e97dc848f490baae8d78c2e58848c4562/django-5.1.13-py3-none-any.whl", hash = "sha256:06f257f79dc4c17f3f9e23b106a4c5ed1335abecbe731e83c598c941d14fbeed", size = 8277515, upload-time = "2025-10-01T14:25:28.65Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "django-channels-postgres"
|
||||
version = "0.1.0"
|
||||
source = { editable = "packages/django-channels-postgres" }
|
||||
dependencies = [
|
||||
{ name = "channels" },
|
||||
{ name = "django" },
|
||||
{ name = "django-pglock" },
|
||||
{ name = "django-pgtrigger" },
|
||||
{ name = "msgpack" },
|
||||
{ name = "psycopg" },
|
||||
{ name = "structlog" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "channels", specifier = ">=4.3,<4.4" },
|
||||
{ name = "django", specifier = ">=4.2,<6.0" },
|
||||
{ name = "django-pglock", specifier = ">=1.7,<2" },
|
||||
{ name = "django-pgtrigger", specifier = ">=4,<5" },
|
||||
{ name = "msgpack", specifier = ">=1,<2" },
|
||||
{ name = "psycopg", specifier = ">=3,<4" },
|
||||
{ name = "structlog", specifier = ">=25,<26" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "django-countries"
|
||||
version = "7.6.1"
|
||||
@@ -953,6 +970,7 @@ source = { editable = "packages/django-dramatiq-postgres" }
|
||||
dependencies = [
|
||||
{ name = "cron-converter" },
|
||||
{ name = "django" },
|
||||
{ name = "django-pglock" },
|
||||
{ name = "django-pgtrigger" },
|
||||
{ name = "dramatiq", extra = ["watch"] },
|
||||
{ name = "structlog" },
|
||||
@@ -963,6 +981,7 @@ dependencies = [
|
||||
requires-dist = [
|
||||
{ name = "cron-converter", specifier = ">=1,<2" },
|
||||
{ name = "django", specifier = ">=4.2,<6.0" },
|
||||
{ name = "django-pglock", specifier = ">=1.7,<2" },
|
||||
{ name = "django-pgtrigger", specifier = ">=4,<5" },
|
||||
{ name = "dramatiq", extras = ["watch"], specifier = ">=1.17,<1.18" },
|
||||
{ name = "structlog", specifier = ">=25,<26" },
|
||||
@@ -2821,15 +2840,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fa/de/02b54f42487e3d3c6efb3f89428677074ca7bf43aae402517bc7cca949f3/PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563", size = 156446, upload-time = "2024-08-06T20:33:04.33Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "redis"
|
||||
version = "6.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/0d/d6/e8b92798a5bd67d659d51a18170e91c16ac3b59738d91894651ee255ed49/redis-6.4.0.tar.gz", hash = "sha256:b01bc7282b8444e28ec36b261df5375183bb47a07eb9c603f284e89cbc5ef010", size = 4647399, upload-time = "2025-08-07T08:10:11.441Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e8/02/89e2ed7e85db6c93dfa9e8f691c5087df4e3551ab39081a4d7c6d1f90e05/redis-6.4.0-py3-none-any.whl", hash = "sha256:f0544fa9604264e9464cdf4814e7d4830f74b165d52f2a330a760a88dd248b7f", size = 279847, upload-time = "2025-08-07T08:10:09.84Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "referencing"
|
||||
version = "0.36.2"
|
||||
@@ -3265,6 +3275,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9a/bb/d43e5c75054e53efce310e79d63df0ac3f25e34c926be5dffb7d283fb2a8/typeguard-2.13.3-py3-none-any.whl", hash = "sha256:5e3e3be01e887e7eafae5af63d1f36c849aaa94e3a0112097312aabfa16284f1", size = 17605, upload-time = "2021-12-10T21:09:37.844Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-channels"
|
||||
version = "4.3.0.20250822"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "asgiref" },
|
||||
{ name = "django-stubs" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/22/3d/e0a164b7eaab18ab2302b7a3d90843824b729d9fe1d7bdbb3769dfc9d77b/types_channels-4.3.0.20250822.tar.gz", hash = "sha256:29a4928fdaed6d444b93b69d44fcdb5a8fe32fa72d6a41016c5d39fa7bd7f474", size = 15357, upload-time = "2025-08-22T03:04:26.444Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/cb/52/4e3094e43d460feacb9051ec4c3498f8272f69d92b772647211478b25079/types_channels-4.3.0.20250822-py3-none-any.whl", hash = "sha256:d3fc0a1467c8cc901686826408c8a673822e07aa79cbe1a6d21946e7e55d9ddf", size = 21125, upload-time = "2025-08-22T03:04:25.539Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "types-ldap3"
|
||||
version = "2.9.13.20250622"
|
||||
|
||||
Reference in New Issue
Block a user