brands: select_related models accessed in the hot path (#23162)

Co-authored-by: Ryan Pesek <rpesek@cloudflare.com>
Co-authored-by: Jens Langhammer <jens@goauthentik.io>
Signed-off-by: Marc 'risson' Schmitt <marc.schmitt@risson.space>
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Marc 'risson' Schmitt
2026-06-17 15:17:04 +02:00
committed by GitHub
parent 5839b40efa
commit c755232f0a
5 changed files with 96 additions and 2 deletions
+14
View File
@@ -18,6 +18,20 @@ from authentik.lib.models import SerializerModel
LOGGER = get_logger()
# Brand FKs read on the request hot path. select_related pulls them into the
# same SELECT to avoid N+1 lazy loads; CurrentBrandSerializer alone reads 7.
_BRAND_RELATED_FK_FIELDS = (
"flow_authentication",
"flow_invalidation",
"flow_recovery",
"flow_unenrollment",
"flow_user_settings",
"flow_device_code",
"flow_lockdown",
"default_application",
)
class Brand(SerializerModel):
"""Single brand"""
View File
@@ -0,0 +1,79 @@
"""Test brands"""
from django.http import HttpRequest
from django.test import TestCase
from authentik.brands.models import Brand
from authentik.brands.utils import _BRAND_RELATED_FK_FIELDS, get_brand_for_request
from authentik.core.tests.utils import create_test_flow
from authentik.flows.models import FlowDesignation
class TestGetBrandForRequestSelectRelated(TestCase):
"""``get_brand_for_request`` must hydrate the FK fields read on the
request hot path so later access doesn't trigger lazy loads."""
def setUp(self):
Brand.objects.all().delete()
self.flow_auth = create_test_flow(designation=FlowDesignation.AUTHENTICATION)
self.brand = Brand.objects.create(
domain="select-related-test.example.com",
flow_authentication=self.flow_auth,
)
def _make_request(self, host: str) -> HttpRequest:
request = HttpRequest()
request.META["HTTP_HOST"] = host
return request
def test_brand_fks_are_loaded_in_single_query(self):
"""Brand FK access after ``get_brand_for_request`` must not trigger
extra queries."""
request = self._make_request("select-related-test.example.com")
with self.assertNumQueries(1):
brand = get_brand_for_request(request)
_ = brand.flow_authentication
_ = brand.flow_authentication.slug if brand.flow_authentication else None
_ = brand.flow_invalidation
_ = brand.flow_recovery
_ = brand.flow_unenrollment
_ = brand.flow_user_settings
_ = brand.flow_device_code
_ = brand.flow_lockdown
_ = brand.default_application
def test_brand_related_fk_list_complete(self):
"""``_BRAND_RELATED_FK_FIELDS`` covers every Flow/Application FK on
Brand — fails loud when a new FK is added but not registered here."""
actual_fks = {
f.name
for f in Brand._meta.get_fields()
if f.many_to_one and f.related_model is not None
}
relevant_fks = {
name for name in actual_fks if name.startswith("flow_") or name == "default_application"
}
declared = set(_BRAND_RELATED_FK_FIELDS)
missing = relevant_fks - declared
self.assertFalse(
missing,
f"Brand FK fields {missing} aren't in _BRAND_RELATED_FK_FIELDS — "
"add them or the request hot path will incur extra queries.",
)
def test_brand_related_fks_all_exist_on_model(self):
"""Every entry in ``_BRAND_RELATED_FK_FIELDS`` is a real FK on Brand.
``select_related`` raises ``FieldError`` at first use if any entry
is stale, which would break every request."""
actual_fks = {
f.name
for f in Brand._meta.get_fields()
if f.many_to_one and f.related_model is not None
}
declared = set(_BRAND_RELATED_FK_FIELDS)
extraneous = declared - actual_fks
self.assertFalse(
extraneous,
f"_BRAND_RELATED_FK_FIELDS contains {extraneous} which don't "
f"exist on Brand (actual FKs: {sorted(actual_fks)}).",
)
+3 -2
View File
@@ -9,7 +9,7 @@ from django.utils.html import _json_script_escapes
from django.utils.safestring import mark_safe
from authentik import authentik_full_version
from authentik.brands.models import Brand
from authentik.brands.models import _BRAND_RELATED_FK_FIELDS, Brand
from authentik.lib.sentry import get_http_meta
from authentik.tenants.models import Tenant
@@ -21,7 +21,8 @@ def get_brand_for_request(request: HttpRequest) -> Brand:
"""Get brand object for current request"""
brand = (
Brand.objects.annotate(
Brand.objects.select_related(*_BRAND_RELATED_FK_FIELDS)
.annotate(
host_domain=Value(request.get_host()),
domain_length=Length("domain"),
match_priority=Case(