mirror of
https://github.com/goauthentik/authentik.git
synced 2026-06-17 19:09:11 +03:00
packages/django-postgres-cache: use upsert instead of select/update in a transaction (#17760)
This commit is contained in:
committed by
GitHub
parent
9b6aa56df2
commit
27ceb3ccf3
@@ -1,15 +1,21 @@
|
||||
import base64
|
||||
import pickle # nosec
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.cache.backends.base import DEFAULT_TIMEOUT
|
||||
from django.core.cache.backends.db import DatabaseCache as BaseDatabaseCache
|
||||
from django.db import DatabaseError
|
||||
from django.db.utils import ProgrammingError
|
||||
from django.utils.module_loading import import_string
|
||||
from django.utils.timezone import now
|
||||
from psqlextra.types import ConflictAction
|
||||
|
||||
from django_postgres_cache.models import CacheEntry
|
||||
|
||||
|
||||
class DatabaseCache(BaseDatabaseCache):
|
||||
|
||||
def __init__(self, table: str, params: dict[str, Any]) -> None:
|
||||
super().__init__(table, params)
|
||||
self.reverse_key_func = import_string(params["REVERSE_KEY_FUNCTION"])
|
||||
@@ -49,3 +55,87 @@ class DatabaseCache(BaseDatabaseCache):
|
||||
if not entry:
|
||||
return None
|
||||
return int((entry.expires - now()).total_seconds())
|
||||
|
||||
def _base_set_expiry(self, timeout: float | None) -> datetime:
|
||||
timeout = self.get_backend_timeout(timeout)
|
||||
if timeout is None:
|
||||
exp = datetime.max
|
||||
else:
|
||||
tz = UTC if settings.USE_TZ else None
|
||||
exp = datetime.fromtimestamp(timeout, tz=tz)
|
||||
exp.replace(microsecond=0)
|
||||
return exp
|
||||
|
||||
def _base_set_data(
|
||||
self,
|
||||
key: Any,
|
||||
value: Any,
|
||||
timeout: float | None,
|
||||
version: int | None = None,
|
||||
) -> tuple[str, str, datetime]:
|
||||
key = self.make_and_validate_key(key, version=version)
|
||||
pickled = pickle.dumps(value, self.pickle_protocol)
|
||||
# The DB column is expecting a string, so make sure the value is a
|
||||
# string, not bytes. Refs #19274.
|
||||
b64encoded = base64.b64encode(pickled).decode("latin1")
|
||||
|
||||
return (key, b64encoded, self._base_set_expiry(timeout))
|
||||
|
||||
def touch(
|
||||
self,
|
||||
key: Any,
|
||||
timeout: float | None = DEFAULT_TIMEOUT,
|
||||
version: int | None = None,
|
||||
) -> bool:
|
||||
key = self.make_and_validate_key(key, version=version)
|
||||
expiry = self._base_set_expiry(timeout)
|
||||
try:
|
||||
count = CacheEntry.objects.filter(cache_key=key).update(expires=expiry)
|
||||
return bool(count != 0)
|
||||
except DatabaseError:
|
||||
return False
|
||||
|
||||
def add(
|
||||
self,
|
||||
key: Any,
|
||||
value: Any,
|
||||
timeout: float | None = DEFAULT_TIMEOUT,
|
||||
version: int | None = None,
|
||||
) -> bool:
|
||||
key, value, expiry = self._base_set_data(key, value, timeout, version)
|
||||
try:
|
||||
CacheEntry.objects.on_conflict(
|
||||
["cache_key"],
|
||||
ConflictAction.UPDATE,
|
||||
update_values=dict(
|
||||
expires=expiry,
|
||||
),
|
||||
).insert(
|
||||
cache_key=key,
|
||||
value=value,
|
||||
expires=expiry,
|
||||
)
|
||||
# We don't know if the row already existed, we just return True for success
|
||||
return True
|
||||
except DatabaseError:
|
||||
return False
|
||||
|
||||
def set(
|
||||
self,
|
||||
key: Any,
|
||||
value: Any,
|
||||
timeout: float | None = DEFAULT_TIMEOUT,
|
||||
version: int | None = None,
|
||||
) -> None:
|
||||
key, value, expiry = self._base_set_data(key, value, timeout, version)
|
||||
CacheEntry.objects.on_conflict(
|
||||
["cache_key"],
|
||||
ConflictAction.UPDATE,
|
||||
).insert(
|
||||
cache_key=key,
|
||||
value=value,
|
||||
expires=expiry,
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
CacheEntry.objects.truncate()
|
||||
|
||||
+20
@@ -0,0 +1,20 @@
|
||||
# Generated by Django 5.2.7 on 2025-10-28 14:04
|
||||
|
||||
import psqlextra.manager.manager
|
||||
from django.db import migrations
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
|
||||
dependencies = [
|
||||
("django_postgres_cache", "0001_initial"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterModelManagers(
|
||||
name="cacheentry",
|
||||
managers=[
|
||||
("objects", psqlextra.manager.manager.PostgresManager()), # type: ignore[no-untyped-call]
|
||||
],
|
||||
),
|
||||
]
|
||||
@@ -1,12 +1,14 @@
|
||||
from django.db import models
|
||||
from psqlextra.manager import PostgresManager
|
||||
|
||||
|
||||
class CacheEntry(models.Model):
|
||||
|
||||
cache_key = models.TextField(primary_key=True)
|
||||
value = models.TextField()
|
||||
expires = models.DateTimeField(db_index=True)
|
||||
|
||||
objects = PostgresManager() # type: ignore[no-untyped-call]
|
||||
|
||||
class Meta:
|
||||
default_permissions = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user