diff --git a/authentik/core/urls.py b/authentik/core/urls.py index 704ddacc85..78db01ba3e 100644 --- a/authentik/core/urls.py +++ b/authentik/core/urls.py @@ -21,16 +21,11 @@ from authentik.core.api.users import UserViewSet from authentik.core.setup.views import SetupView from authentik.core.views.apps import RedirectToAppLaunch from authentik.core.views.debug import AccessDeniedView -from authentik.core.views.interface import ( - BrandDefaultRedirectView, - InterfaceView, - RootRedirectView, -) +from authentik.core.views.interface import BrandDefaultRedirectView, InterfaceView, RootRedirectView from authentik.flows.views.interface import FlowInterfaceView from authentik.root.asgi_middleware import AuthMiddlewareStack from authentik.root.middleware import ChannelsLoggingMiddleware from authentik.root.ws.consumer import MessageConsumer -from authentik.tenants.channels import TenantsAwareMiddleware urlpatterns = [ path( @@ -103,9 +98,7 @@ api_urlpatterns = [ websocket_urlpatterns = [ path( "ws/client/", - ChannelsLoggingMiddleware( - TenantsAwareMiddleware(AuthMiddlewareStack(MessageConsumer.as_asgi())) - ), + ChannelsLoggingMiddleware(AuthMiddlewareStack(MessageConsumer.as_asgi())), ), ] diff --git a/authentik/outposts/urls.py b/authentik/outposts/urls.py index 43cf872801..80ed1156d4 100644 --- a/authentik/outposts/urls.py +++ b/authentik/outposts/urls.py @@ -11,14 +11,11 @@ from authentik.outposts.api.service_connections import ( from authentik.outposts.channels import TokenOutpostMiddleware from authentik.outposts.consumer import OutpostConsumer from authentik.root.middleware import ChannelsLoggingMiddleware -from authentik.tenants.channels import TenantsAwareMiddleware websocket_urlpatterns = [ path( "ws/outpost//", - ChannelsLoggingMiddleware( - TenantsAwareMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi())) - ), + ChannelsLoggingMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi())), ), ] diff --git a/authentik/providers/rac/urls.py b/authentik/providers/rac/urls.py index 57824901e2..a9322f4400 100644 --- a/authentik/providers/rac/urls.py +++ b/authentik/providers/rac/urls.py @@ -12,7 +12,6 @@ from authentik.providers.rac.consumer_outpost import RACOutpostConsumer from authentik.providers.rac.views import RACInterface, RACStartView from authentik.root.asgi_middleware import AuthMiddlewareStack from authentik.root.middleware import ChannelsLoggingMiddleware -from authentik.tenants.channels import TenantsAwareMiddleware urlpatterns = [ path( @@ -30,15 +29,11 @@ urlpatterns = [ websocket_urlpatterns = [ path( "ws/rac//", - ChannelsLoggingMiddleware( - TenantsAwareMiddleware(AuthMiddlewareStack(RACClientConsumer.as_asgi())) - ), + ChannelsLoggingMiddleware(AuthMiddlewareStack(RACClientConsumer.as_asgi())), ), path( "ws/outpost_rac//", - ChannelsLoggingMiddleware( - TenantsAwareMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi())) - ), + ChannelsLoggingMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi())), ), ] diff --git a/authentik/tenants/channels.py b/authentik/tenants/channels.py deleted file mode 100644 index 5033cf8865..0000000000 --- a/authentik/tenants/channels.py +++ /dev/null @@ -1,48 +0,0 @@ -from channels.db import database_sync_to_async -from django.db import close_old_connections, connection -from django.http.request import split_domain_port -from django_tenants.utils import ( - get_public_schema_name, - remove_www, -) - -from authentik.tenants.models import Domain, Tenant - - -class TenantsAwareMiddleware: - """Set the database schema for use with django-tenants""" - - def __init__(self, inner): - self.inner = inner - - def get_hostname_from_scope(self, scope: list[tuple[bytes, bytes]]) -> str | None: - headers = {k.replace(b"-", b"_").upper(): v for k, v in scope.get("headers", [])} - hostname, _ = split_domain_port(headers.get(b"HOST", b"").decode("utf-8")) - if not hostname: - return None - return remove_www(hostname) - - async def get_default_tenant(self) -> Tenant: - return await database_sync_to_async(Tenant.objects.get)( - schema_name=get_public_schema_name() - ) - - async def get_tenant(self, hostname: str | None) -> Tenant: - if not hostname: - return await self.get_default_tenant() - - try: - domain = await database_sync_to_async(Domain.objects.select_related("tenant").get)( - domain=hostname - ) - except Domain.DoesNotExist: - return await self.get_default_tenant() - return domain.tenant - - async def __call__(self, scope, receive, send): - close_old_connections() - hostname = self.get_hostname_from_scope(scope) - tenant = await self.get_tenant(hostname) - scope["tenant"] = tenant - connection.set_tenant(tenant) - return await self.inner(scope, receive, send)