diff --git a/Makefile b/Makefile index 9d9d77b5db..6d85a741d1 100644 --- a/Makefile +++ b/Makefile @@ -24,6 +24,8 @@ BREW_LDFLAGS := BREW_CPPFLAGS := BREW_PKG_CONFIG_PATH := +UV := uv + # For macOS users, add the libxml2 installed from brew libxmlsec1 to the build path # to prevent SAML-related tests from failing and ensure correct pip dependency compilation ifeq ($(UNAME_S),Darwin) @@ -33,22 +35,21 @@ ifeq ($(UNAME_S),Darwin) LIBXML2_EXISTS := $(shell brew list libxml2 2> /dev/null) ifdef LIBXML2_EXISTS _xml_pref := $(shell brew --prefix libxml2) - BREW_LDFLAGS += -L${_xml_pref}/lib $(LDFLAGS) - BREW_CPPFLAGS += -I${_xml_pref}/include $(CPPFLAGS) - BREW_PKG_CONFIG_PATH += ${_xml_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH) + BREW_LDFLAGS += -L${_xml_pref}/lib + BREW_CPPFLAGS += -I${_xml_pref}/include + BREW_PKG_CONFIG_PATH = ${_xml_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH) endif KRB5_EXISTS := $(shell brew list krb5 2> /dev/null) ifdef KRB5_EXISTS _krb5_pref := $(shell brew --prefix krb5) - BREW_LDFLAGS += -L${_krb5_pref}/lib $(LDFLAGS) - BREW_CPPFLAGS += -I${_krb5_pref}/include $(CPPFLAGS) - BREW_PKG_CONFIG_PATH += ${_krb5_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH) + BREW_LDFLAGS += -L${_krb5_pref}/lib + BREW_CPPFLAGS += -I${_krb5_pref}/include + BREW_PKG_CONFIG_PATH = ${_krb5_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH) endif endif + UV := LDFLAGS="$(BREW_LDFLAGS)" CPPFLAGS="$(BREW_CPPFLAGS)" PKG_CONFIG_PATH="$(BREW_PKG_CONFIG_PATH)" uv endif -UV_ARGS := LDFLAGS="$(BREW_LDFLAGS)" CPPFLAGS="$(BREW_CPPFLAGS)" PKG_CONFIG_PATH="$(BREW_PKG_CONFIG_PATH)" - all: lint-fix lint gen web test ## Lint, build, and test everything HELP_WIDTH := $(shell grep -h '^[a-z][^ ]*:.*\#\#' $(MAKEFILE_LIST) 2>/dev/null | \ @@ -65,47 +66,47 @@ go-test: go test -timeout 0 -v -race -cover ./... test: ## Run the server tests and produce a coverage report (locally) - $(KRB_PATH) uv run coverage run manage.py test --keepdb $(or $(filter-out $@,$(MAKECMDGOALS)),authentik) - uv run coverage html - uv run coverage report + $(UV) run coverage run manage.py test --keepdb $(or $(filter-out $@,$(MAKECMDGOALS)),authentik) + $(UV) run coverage html + $(UV) run coverage report lint-fix: lint-codespell ## Lint and automatically fix errors in the python source code. Reports spelling errors. - uv run black $(PY_SOURCES) - uv run ruff check --fix $(PY_SOURCES) + $(UV) run black $(PY_SOURCES) + $(UV) run ruff check --fix $(PY_SOURCES) lint-codespell: ## Reports spelling errors. - uv run codespell -w + $(UV) run codespell -w lint: ## Lint the python and golang sources - uv run bandit -c pyproject.toml -r $(PY_SOURCES) + $(UV) run bandit -c pyproject.toml -r $(PY_SOURCES) golangci-lint run -v core-install: ifeq ($(UNAME_S),Darwin) # Clear cache to ensure fresh compilation - uv cache clean + $(UV) cache clean # Force compilation from source for lxml and xmlsec with correct environment - $(UV_ARGS) uv sync --frozen --reinstall-package lxml --reinstall-package xmlsec --no-binary-package lxml --no-binary-package xmlsec + $(UV) sync --frozen --reinstall-package lxml --reinstall-package xmlsec --no-binary-package lxml --no-binary-package xmlsec else - uv sync --frozen + $(UV) sync --frozen endif migrate: ## Run the Authentik Django server's migrations - uv run python -m lifecycle.migrate + $(UV) run python -m lifecycle.migrate i18n-extract: core-i18n-extract web-i18n-extract ## Extract strings that require translation into files to send to a translation service aws-cfn: - cd lifecycle/aws && npm i && uv run npm run aws-cfn + cd lifecycle/aws && npm i && $(UV) run npm run aws-cfn run-server: ## Run the main authentik server process - uv run ak server + $(UV) run ak server run-worker: ## Run the main authentik worker process - uv run ak worker + $(UV) run ak worker core-i18n-extract: - uv run ak makemessages \ + $(UV) run ak makemessages \ --add-location file \ --no-obsolete \ --ignore web \ @@ -118,17 +119,17 @@ core-i18n-extract: install: node-install docs-install core-install ## Install all requires dependencies for `node`, `docs` and `core` dev-drop-db: - $(eval pg_user := $(shell uv run python -m authentik.lib.config postgresql.user 2>/dev/null)) - $(eval pg_host := $(shell uv run python -m authentik.lib.config postgresql.host 2>/dev/null)) - $(eval pg_name := $(shell uv run python -m authentik.lib.config postgresql.name 2>/dev/null)) + $(eval pg_user := $(shell $(UV) run python -m authentik.lib.config postgresql.user 2>/dev/null)) + $(eval pg_host := $(shell $(UV) run python -m authentik.lib.config postgresql.host 2>/dev/null)) + $(eval pg_name := $(shell $(UV) run python -m authentik.lib.config postgresql.name 2>/dev/null)) dropdb -U ${pg_user} -h ${pg_host} ${pg_name} || true # Also remove the test-db if it exists dropdb -U ${pg_user} -h ${pg_host} test_${pg_name} || true dev-create-db: - $(eval pg_user := $(shell uv run python -m authentik.lib.config postgresql.user 2>/dev/null)) - $(eval pg_host := $(shell uv run python -m authentik.lib.config postgresql.host 2>/dev/null)) - $(eval pg_name := $(shell uv run python -m authentik.lib.config postgresql.name 2>/dev/null)) + $(eval pg_user := $(shell $(UV) run python -m authentik.lib.config postgresql.user 2>/dev/null)) + $(eval pg_host := $(shell $(UV) run python -m authentik.lib.config postgresql.host 2>/dev/null)) + $(eval pg_name := $(shell $(UV) run python -m authentik.lib.config postgresql.name 2>/dev/null)) createdb -U ${pg_user} -h ${pg_host} ${pg_name} dev-reset: dev-drop-db dev-create-db migrate ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state. @@ -156,10 +157,10 @@ gen-build: ## Extract the schema from the database AUTHENTIK_DEBUG=true \ AUTHENTIK_TENANTS__ENABLED=true \ AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \ - uv run ak build_schema + $(UV) run ak build_schema gen-compose: - uv run scripts/generate_compose.py + $(UV) run scripts/generate_compose.py gen-changelog: ## (Release) generate the changelog based from the commits since the last tag git log --pretty=format:" - %s" $(shell git describe --tags $(shell git rev-list --tags --max-count=1))...$(shell git branch --show-current) | sort > changelog.md @@ -227,7 +228,7 @@ endif go mod edit -replace goauthentik.io/api/v3=./${GEN_API_GO} gen-dev-config: ## Generate a local development config file - uv run scripts/generate_config.py + $(UV) run scripts/generate_config.py gen: gen-build gen-client-ts @@ -327,24 +328,24 @@ ci--meta-debug: node --version ci-mypy: ci--meta-debug - uv run mypy --strict $(PY_SOURCES) + $(UV) run mypy --strict $(PY_SOURCES) ci-black: ci--meta-debug - uv run black --check $(PY_SOURCES) + $(UV) run black --check $(PY_SOURCES) ci-ruff: ci--meta-debug - uv run ruff check $(PY_SOURCES) + $(UV) run ruff check $(PY_SOURCES) ci-codespell: ci--meta-debug - uv run codespell -s + $(UV) run codespell -s ci-bandit: ci--meta-debug - uv run bandit -r $(PY_SOURCES) + $(UV) run bandit -r $(PY_SOURCES) ci-pending-migrations: ci--meta-debug - uv run ak makemigrations --check + $(UV) run ak makemigrations --check ci-test: ci--meta-debug - uv run coverage run manage.py test --keepdb authentik - uv run coverage report - uv run coverage xml + $(UV) run coverage run manage.py test --keepdb authentik + $(UV) run coverage report + $(UV) run coverage xml diff --git a/authentik/blueprints/v1/common.py b/authentik/blueprints/v1/common.py index 4c975f7cd3..2a49aa1eef 100644 --- a/authentik/blueprints/v1/common.py +++ b/authentik/blueprints/v1/common.py @@ -9,7 +9,7 @@ from functools import reduce from json import JSONDecodeError, loads from operator import ixor from os import getenv -from typing import Any, Literal, Union +from typing import Any, Literal from uuid import UUID from deepmerge import always_merger @@ -70,19 +70,17 @@ class BlueprintEntryDesiredState(Enum): class BlueprintEntryPermission: """Describe object-level permissions""" - permission: Union[str, "YAMLTag"] - user: Union[int, "YAMLTag", None] = field(default=None) - role: Union[str, "YAMLTag", None] = field(default=None) + permission: str | YAMLTag + user: int | YAMLTag | None = field(default=None) + role: str | YAMLTag | None = field(default=None) @dataclass class BlueprintEntry: """Single entry of a blueprint""" - model: Union[str, "YAMLTag"] - state: Union[BlueprintEntryDesiredState, "YAMLTag"] = field( - default=BlueprintEntryDesiredState.PRESENT - ) + model: str | YAMLTag + state: BlueprintEntryDesiredState | YAMLTag = field(default=BlueprintEntryDesiredState.PRESENT) conditions: list[Any] = field(default_factory=list) identifiers: dict[str, Any] = field(default_factory=dict) attrs: dict[str, Any] | None = field(default_factory=dict) @@ -96,7 +94,7 @@ class BlueprintEntry: self.__tag_contexts: list[YAMLTagContext] = [] @staticmethod - def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry": + def from_model(model: SerializerModel, *extra_identifier_names: str) -> BlueprintEntry: """Convert a SerializerModel instance to a blueprint Entry""" identifiers = { "pk": model.pk, @@ -114,8 +112,8 @@ class BlueprintEntry: def get_tag_context( self, depth: int = 0, - context_tag_type: type["YAMLTagContext"] | tuple["YAMLTagContext", ...] | None = None, - ) -> "YAMLTagContext": + context_tag_type: type[YAMLTagContext] | tuple[YAMLTagContext, ...] | None = None, + ) -> YAMLTagContext: """Get a YAMLTagContext object located at a certain depth in the tag tree""" if depth < 0: raise ValueError("depth must be a positive number or zero") @@ -130,7 +128,7 @@ class BlueprintEntry: except IndexError as exc: raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") from exc - def tag_resolver(self, value: Any, blueprint: "Blueprint") -> Any: + def tag_resolver(self, value: Any, blueprint: Blueprint) -> Any: """Check if we have any special tags that need handling""" val = copy(value) @@ -152,23 +150,23 @@ class BlueprintEntry: return val - def get_attrs(self, blueprint: "Blueprint") -> dict[str, Any]: + def get_attrs(self, blueprint: Blueprint) -> dict[str, Any]: """Get attributes of this entry, with all yaml tags resolved""" return self.tag_resolver(self.attrs, blueprint) - def get_identifiers(self, blueprint: "Blueprint") -> dict[str, Any]: + def get_identifiers(self, blueprint: Blueprint) -> dict[str, Any]: """Get attributes of this entry, with all yaml tags resolved""" return self.tag_resolver(self.identifiers, blueprint) - def get_state(self, blueprint: "Blueprint") -> BlueprintEntryDesiredState: + def get_state(self, blueprint: Blueprint) -> BlueprintEntryDesiredState: """Get the blueprint state, with yaml tags resolved if present""" return BlueprintEntryDesiredState(self.tag_resolver(self.state, blueprint)) - def get_model(self, blueprint: "Blueprint") -> str: + def get_model(self, blueprint: Blueprint) -> str: """Get the blueprint model, with yaml tags resolved if present""" return str(self.tag_resolver(self.model, blueprint)) - def get_permissions(self, blueprint: "Blueprint") -> Generator[BlueprintEntryPermission]: + def get_permissions(self, blueprint: Blueprint) -> Generator[BlueprintEntryPermission]: """Get permissions of this entry, with all yaml tags resolved""" for perm in self.permissions: yield BlueprintEntryPermission( @@ -177,7 +175,7 @@ class BlueprintEntry: role=self.tag_resolver(perm.role, blueprint), ) - def check_all_conditions_match(self, blueprint: "Blueprint") -> bool: + def check_all_conditions_match(self, blueprint: Blueprint) -> bool: """Check all conditions of this entry match (evaluate to True)""" return all(self.tag_resolver(self.conditions, blueprint)) @@ -232,7 +230,7 @@ class KeyOf(YAMLTag): id_from: str - def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None: + def __init__(self, loader: BlueprintLoader, node: ScalarNode) -> None: super().__init__() self.id_from = node.value @@ -258,7 +256,7 @@ class Env(YAMLTag): key: str default: Any | None - def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: + def __init__(self, loader: BlueprintLoader, node: ScalarNode | SequenceNode) -> None: super().__init__() self.default = None if isinstance(node, ScalarNode): @@ -277,7 +275,7 @@ class File(YAMLTag): path: str default: Any | None - def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: + def __init__(self, loader: BlueprintLoader, node: ScalarNode | SequenceNode) -> None: super().__init__() self.default = None if isinstance(node, ScalarNode): @@ -305,7 +303,7 @@ class Context(YAMLTag): key: str default: Any | None - def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None: + def __init__(self, loader: BlueprintLoader, node: ScalarNode | SequenceNode) -> None: super().__init__() self.default = None if isinstance(node, ScalarNode): @@ -328,7 +326,7 @@ class ParseJSON(YAMLTag): raw: str - def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None: + def __init__(self, loader: BlueprintLoader, node: ScalarNode) -> None: super().__init__() self.raw = node.value @@ -345,7 +343,7 @@ class Format(YAMLTag): format_string: str args: list[Any] - def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: + def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None: super().__init__() self.format_string = loader.construct_object(node.value[0]) self.args = [] @@ -372,7 +370,7 @@ class Find(YAMLTag): model_name: str | YAMLTag conditions: list[list] - def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: + def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None: super().__init__() self.model_name = loader.construct_object(node.value[0]) self.conditions = [] @@ -444,7 +442,7 @@ class Condition(YAMLTag): "XNOR": lambda args: not (reduce(ixor, args) if len(args) > 1 else args[0]), } - def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: + def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None: super().__init__() self.mode = loader.construct_object(node.value[0]) self.args = [] @@ -478,7 +476,7 @@ class If(YAMLTag): when_true: Any when_false: Any - def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: + def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None: super().__init__() self.condition = loader.construct_object(node.value[0]) if len(node.value) == 1: @@ -518,7 +516,7 @@ class Enumerate(YAMLTag, YAMLTagContext): ), } - def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: + def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None: super().__init__() self.iterable = loader.construct_object(node.value[0]) self.output_body = loader.construct_object(node.value[1]) @@ -584,7 +582,7 @@ class EnumeratedItem(YAMLTag): _SUPPORTED_CONTEXT_TAGS = (Enumerate,) - def __init__(self, _loader: "BlueprintLoader", node: ScalarNode) -> None: + def __init__(self, _loader: BlueprintLoader, node: ScalarNode) -> None: super().__init__() self.depth = int(node.value) @@ -640,7 +638,7 @@ class AtIndex(YAMLTag): attribute: int | str | YAMLTag default: Any | UNSET - def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None: + def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None: super().__init__() self.obj = loader.construct_object(node.value[0]) self.attribute = loader.construct_object(node.value[1]) @@ -757,7 +755,7 @@ class EntryInvalidError(SentryIgnoredException): @staticmethod def from_entry( msg_or_exc: str | Exception, entry: BlueprintEntry, *args, **kwargs - ) -> "EntryInvalidError": + ) -> EntryInvalidError: """Create EntryInvalidError with the context of an entry""" error = EntryInvalidError(msg_or_exc, *args, **kwargs) if isinstance(msg_or_exc, ValidationError): diff --git a/authentik/blueprints/v1/importer.py b/authentik/blueprints/v1/importer.py index 312c92468d..86a54e4cbc 100644 --- a/authentik/blueprints/v1/importer.py +++ b/authentik/blueprints/v1/importer.py @@ -147,7 +147,7 @@ class Importer: } @staticmethod - def from_string(yaml_input: str, context: dict | None = None) -> "Importer": + def from_string(yaml_input: str, context: dict | None = None) -> Importer: """Parse YAML string and create blueprint importer from it""" import_dict = load(yaml_input, BlueprintLoader) try: diff --git a/authentik/blueprints/v1/meta/apply_blueprint.py b/authentik/blueprints/v1/meta/apply_blueprint.py index 7ae5b7bc2d..3f4f70c906 100644 --- a/authentik/blueprints/v1/meta/apply_blueprint.py +++ b/authentik/blueprints/v1/meta/apply_blueprint.py @@ -23,7 +23,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer): # We cannot override `instance` as that will confuse rest_framework # and make it attempt to update the instance - blueprint_instance: "BlueprintInstance" + blueprint_instance: BlueprintInstance def validate(self, attrs): from authentik.blueprints.models import BlueprintInstance diff --git a/authentik/core/models.py b/authentik/core/models.py index 57c221bf9d..039fa0dc54 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -3,7 +3,7 @@ from datetime import datetime from enum import StrEnum from hashlib import sha256 -from typing import Any, Optional, Self +from typing import Any, Self from uuid import uuid4 import pgtrigger @@ -225,7 +225,7 @@ class Group(SerializerModel, AttributesMixin): # in the LDAP Outpost we use the last 5 chars so match here return int(str(self.pk.int)[:5]) - def is_member(self, user: "User") -> bool: + def is_member(self, user: User) -> bool: """Recursively check if `user` is member of us, or any parent.""" return user.all_groups().filter(group_uuid=self.group_uuid).exists() @@ -466,7 +466,7 @@ class User(SerializerModel, AttributesMixin, AbstractUser): always_merger.merge(final_attributes, self.attributes) return final_attributes - def app_entitlements(self, app: "Application | None") -> QuerySet["ApplicationEntitlement"]: + def app_entitlements(self, app: Application | None) -> QuerySet[ApplicationEntitlement]: """Get all entitlements this user has for `app`.""" if not app: return [] @@ -485,7 +485,7 @@ class User(SerializerModel, AttributesMixin, AbstractUser): ).order_by("name") return qs - def app_entitlements_attributes(self, app: "Application | None") -> dict: + def app_entitlements_attributes(self, app: Application | None) -> dict: """Get a dictionary containing all merged attributes from app entitlements for `app`.""" final_attributes = {} for attrs in self.app_entitlements(app).values_list("attributes", flat=True): @@ -654,7 +654,7 @@ class BackchannelProvider(Provider): class ApplicationQuerySet(QuerySet): - def with_provider(self) -> "QuerySet[Application]": + def with_provider(self) -> QuerySet[Application]: qs = self.select_related("provider") for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider): qs = qs.select_related(f"provider__{subclass}") @@ -713,9 +713,7 @@ class Application(SerializerModel, PolicyBindingModel): return get_file_manager(FileUsage.MEDIA).file_url(self.meta_icon) - def get_launch_url( - self, user: Optional["User"] = None, user_data: dict | None = None - ) -> str | None: + def get_launch_url(self, user: User | None = None, user_data: dict | None = None) -> str | None: """Get launch URL if set, otherwise attempt to get launch URL based on provider. Args: @@ -948,7 +946,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel): raise NotImplementedError @property - def property_mapping_type(self) -> "type[PropertyMapping]": + def property_mapping_type(self) -> type[PropertyMapping]: """Return property mapping type used by this object""" if self.managed == self.MANAGED_INBUILT: from authentik.core.models import PropertyMapping @@ -1069,7 +1067,7 @@ class ExpiringModel(models.Model): return self.delete(*args, **kwargs) @classmethod - def filter_not_expired(cls, **kwargs) -> QuerySet["Self"]: + def filter_not_expired(cls, **kwargs) -> QuerySet[Self]: """Filer for tokens which are not expired yet or are not expiring, and match filters in `kwargs`""" for obj in cls.objects.filter(**kwargs).filter(Q(expires__lt=now(), expiring=True)): @@ -1265,7 +1263,7 @@ class AuthenticatedSession(SerializerModel): return f"Authenticated Session {str(self.pk)[:10]}" @staticmethod - def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]: + def from_request(request: HttpRequest, user: User) -> AuthenticatedSession | None: """Create a new session from a http request""" if not hasattr(request, "session") or not request.session.exists( request.session.session_key diff --git a/authentik/core/signals.py b/authentik/core/signals.py index 3ebda21be6..9fbd9aac44 100644 --- a/authentik/core/signals.py +++ b/authentik/core/signals.py @@ -63,7 +63,7 @@ def user_logged_in_session(sender, request: HttpRequest, user: User, **_): @receiver(post_delete, sender=AuthenticatedSession) -def authenticated_session_delete(sender: type[Model], instance: "AuthenticatedSession", **_): +def authenticated_session_delete(sender: type[Model], instance: AuthenticatedSession, **_): """Delete session when authenticated session is deleted""" Session.objects.filter(session_key=instance.pk).delete() diff --git a/authentik/core/sources/mapper.py b/authentik/core/sources/mapper.py index 8e163edae4..e384a3dca2 100644 --- a/authentik/core/sources/mapper.py +++ b/authentik/core/sources/mapper.py @@ -49,7 +49,7 @@ class SourceMapper: def build_object_properties( self, object_type: type[User | Group], - manager: "PropertyMappingManager | None" = None, + manager: PropertyMappingManager | None = None, user: User | None = None, request: HttpRequest | None = None, **kwargs, diff --git a/authentik/endpoints/connectors/agent/api/agent.py b/authentik/endpoints/connectors/agent/api/agent.py index c2e6e4690f..17688fa092 100644 --- a/authentik/endpoints/connectors/agent/api/agent.py +++ b/authentik/endpoints/connectors/agent/api/agent.py @@ -62,7 +62,7 @@ class AgentConfigSerializer(PassiveSerializer): def get_system_config(self, instance: AgentConnector) -> ConfigSerializer: return ConfigView.get_config(self.context["request"]).data - def get_license_status(self, instance: AgentConnector) -> "LicenseUsageStatus": + def get_license_status(self, instance: AgentConnector) -> LicenseUsageStatus: try: from authentik.enterprise.license import LicenseKey diff --git a/authentik/endpoints/connectors/agent/models.py b/authentik/endpoints/connectors/agent/models.py index b34d446962..5b3cb46cc7 100644 --- a/authentik/endpoints/connectors/agent/models.py +++ b/authentik/endpoints/connectors/agent/models.py @@ -68,7 +68,7 @@ class AgentConnector(Connector): return AuthenticatorEndpointStageView @property - def controller(self) -> type["AgentConnectorController"]: + def controller(self) -> type[AgentConnectorController]: from authentik.endpoints.connectors.agent.controller import AgentConnectorController return AgentConnectorController diff --git a/authentik/endpoints/models.py b/authentik/endpoints/models.py index 68bebebbcc..fe3114d520 100644 --- a/authentik/endpoints/models.py +++ b/authentik/endpoints/models.py @@ -43,7 +43,7 @@ class Device(InternallyManagedMixin, ExpiringModel, AttributesMixin, PolicyBindi return f"goauthentik.io/endpoints/devices/{self.device_uuid}/facts" @property - def cached_facts(self) -> "DeviceFactSnapshot": + def cached_facts(self) -> DeviceFactSnapshot: if cached := cache.get(self.cache_key_facts): return cached facts = self.facts @@ -51,7 +51,7 @@ class Device(InternallyManagedMixin, ExpiringModel, AttributesMixin, PolicyBindi return facts @property - def facts(self) -> "DeviceFactSnapshot": + def facts(self) -> DeviceFactSnapshot: data = {} last_updated = datetime.fromtimestamp(0, UTC) for snapshot_data, snapshort_created in DeviceFactSnapshot.filter_not_expired( @@ -157,7 +157,7 @@ class Connector(ScheduledModel, SerializerModel): raise NotImplementedError @property - def controller(self) -> type["BaseController[Connector]"]: + def controller(self) -> type[BaseController[Connector]]: raise NotImplementedError @property @@ -205,7 +205,7 @@ class EndpointStage(Stage): mode = models.TextField(choices=StageMode.choices, default=StageMode.OPTIONAL) @property - def view(self) -> type["StageView"]: + def view(self) -> type[StageView]: from authentik.endpoints.stage import EndpointStageView return EndpointStageView diff --git a/authentik/enterprise/license.py b/authentik/enterprise/license.py index 3ca209426a..e666b4ab14 100644 --- a/authentik/enterprise/license.py +++ b/authentik/enterprise/license.py @@ -93,7 +93,7 @@ class LicenseKey: license_flags: list[LicenseFlags] = field(default_factory=list) @staticmethod - def validate(jwt: str, check_expiry=True) -> "LicenseKey": + def validate(jwt: str, check_expiry=True) -> LicenseKey: """Validate the license from a given JWT""" try: headers = get_unverified_header(jwt) @@ -128,7 +128,7 @@ class LicenseKey: return body @staticmethod - def get_total() -> "LicenseKey": + def get_total() -> LicenseKey: """Get a summarized version of all (not expired) licenses""" total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0) for lic in License.objects.all(): diff --git a/authentik/enterprise/models.py b/authentik/enterprise/models.py index 1b90138709..93ea9e5533 100644 --- a/authentik/enterprise/models.py +++ b/authentik/enterprise/models.py @@ -50,7 +50,7 @@ class License(SerializerModel): return LicenseSerializer @property - def status(self) -> "LicenseKey": + def status(self) -> LicenseKey: """Get parsed license status""" from authentik.enterprise.license import LicenseKey diff --git a/authentik/enterprise/providers/scim/auth_oauth2.py b/authentik/enterprise/providers/scim/auth_oauth2.py index a7f1977131..a5ab7dae96 100644 --- a/authentik/enterprise/providers/scim/auth_oauth2.py +++ b/authentik/enterprise/providers/scim/auth_oauth2.py @@ -19,7 +19,7 @@ class SCIMOAuthException(SCIMRequestException): class SCIMOAuthAuth: - def __init__(self, provider: "SCIMProvider"): + def __init__(self, provider: SCIMProvider): self.provider = provider self.user = provider.auth_oauth_user self.logger = get_logger().bind() diff --git a/authentik/enterprise/providers/ssf/views/auth.py b/authentik/enterprise/providers/ssf/views/auth.py index 91f90b81f2..c455193bc4 100644 --- a/authentik/enterprise/providers/ssf/views/auth.py +++ b/authentik/enterprise/providers/ssf/views/auth.py @@ -17,9 +17,9 @@ if TYPE_CHECKING: class SSFTokenAuth(BaseAuthentication): """SSF Token auth""" - view: "SSFView" + view: SSFView - def __init__(self, view: "SSFView") -> None: + def __init__(self, view: SSFView) -> None: super().__init__() self.view = view diff --git a/authentik/events/context_processors/asn.py b/authentik/events/context_processors/asn.py index e04748ef32..48e9255521 100644 --- a/authentik/events/context_processors/asn.py +++ b/authentik/events/context_processors/asn.py @@ -1,6 +1,6 @@ """ASN Enricher""" -from typing import TYPE_CHECKING, Optional, TypedDict +from typing import TYPE_CHECKING, TypedDict from django.http import HttpRequest from geoip2.errors import GeoIP2Error @@ -27,7 +27,7 @@ class ASNDict(TypedDict): class ASNContextProcessor(MMDBContextProcessor): """ASN Database reader wrapper""" - def capability(self) -> Optional["Capabilities"]: + def capability(self) -> Capabilities | None: from authentik.api.v3.config import Capabilities return Capabilities.CAN_ASN @@ -35,7 +35,7 @@ class ASNContextProcessor(MMDBContextProcessor): def path(self) -> str | None: return CONFIG.get("events.context_processors.asn") - def enrich_event(self, event: "Event"): + def enrich_event(self, event: Event): asn = self.asn_dict(event.client_ip) if not asn: return diff --git a/authentik/events/context_processors/base.py b/authentik/events/context_processors/base.py index fe94dd6d09..f71756fe09 100644 --- a/authentik/events/context_processors/base.py +++ b/authentik/events/context_processors/base.py @@ -1,7 +1,7 @@ """Base event enricher""" from functools import cache -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from django.http import HttpRequest @@ -13,7 +13,7 @@ if TYPE_CHECKING: class EventContextProcessor: """Base event enricher""" - def capability(self) -> Optional["Capabilities"]: + def capability(self) -> Capabilities | None: """Return the capability this context processor provides""" return None @@ -21,7 +21,7 @@ class EventContextProcessor: """Return true if this context processor is configured""" return False - def enrich_event(self, event: "Event"): + def enrich_event(self, event: Event): """Modify event""" raise NotImplementedError diff --git a/authentik/events/context_processors/geoip.py b/authentik/events/context_processors/geoip.py index d176bba9fe..e6c4ba8e24 100644 --- a/authentik/events/context_processors/geoip.py +++ b/authentik/events/context_processors/geoip.py @@ -1,6 +1,6 @@ """events GeoIP Reader""" -from typing import TYPE_CHECKING, Optional, TypedDict +from typing import TYPE_CHECKING, TypedDict from django.http import HttpRequest from geoip2.errors import GeoIP2Error @@ -29,7 +29,7 @@ class GeoIPDict(TypedDict): class GeoIPContextProcessor(MMDBContextProcessor): """Slim wrapper around GeoIP API""" - def capability(self) -> Optional["Capabilities"]: + def capability(self) -> Capabilities | None: from authentik.api.v3.config import Capabilities return Capabilities.CAN_GEO_IP @@ -37,7 +37,7 @@ class GeoIPContextProcessor(MMDBContextProcessor): def path(self) -> str | None: return CONFIG.get("events.context_processors.geoip") - def enrich_event(self, event: "Event"): + def enrich_event(self, event: Event): city = self.city_dict(event.client_ip) if not city: return diff --git a/authentik/events/logs.py b/authentik/events/logs.py index 33feddbc3a..09b8148283 100644 --- a/authentik/events/logs.py +++ b/authentik/events/logs.py @@ -25,7 +25,7 @@ class LogEvent: attributes: dict[str, Any] = field(default_factory=dict) @staticmethod - def from_event_dict(item: EventDict) -> "LogEvent": + def from_event_dict(item: EventDict) -> LogEvent: event = item.pop("event") log_level = item.pop("level").lower() timestamp = datetime.fromisoformat(item.pop("timestamp")).replace(tzinfo=UTC) diff --git a/authentik/events/models.py b/authentik/events/models.py index 73d571dfa9..cfea806939 100644 --- a/authentik/events/models.py +++ b/authentik/events/models.py @@ -151,7 +151,7 @@ class Event(SerializerModel, ExpiringModel): action: str | EventAction, app: str | None = None, **kwargs, - ) -> "Event": + ) -> Event: """Create new Event instance from arguments. Instance is NOT saved.""" if not isinstance(action, EventAction): action = EventAction.CUSTOM_PREFIX + action @@ -169,19 +169,19 @@ class Event(SerializerModel, ExpiringModel): event = Event(action=action, app=app, context=cleaned_kwargs) return event - def with_exception(self, exc: Exception) -> "Event": + def with_exception(self, exc: Exception) -> Event: """Add data from 'exc' to the event in a database-saveable format""" self.context.setdefault("message", str(exc)) self.context["exception"] = exception_to_dict(exc) return self - def set_user(self, user: User) -> "Event": + def set_user(self, user: User) -> Event: """Set `.user` based on user, ensuring the correct attributes are copied. This should only be used when self.from_http is *not* used.""" self.user = get_user(user) return self - def from_http(self, request: HttpRequest, user: User | None = None) -> "Event": + def from_http(self, request: HttpRequest, user: User | None = None) -> Event: """Add data from a Django-HttpRequest, allowing the creation of Events independently from requests. `user` arguments optionally overrides user from requests.""" @@ -343,7 +343,7 @@ class NotificationTransport(TasksModel, SerializerModel): ), ) - def send(self, notification: "Notification") -> list[str]: + def send(self, notification: Notification) -> list[str]: """Send notification to user, called from async task""" if self.mode == TransportMode.LOCAL: return self.send_local(notification) @@ -355,7 +355,7 @@ class NotificationTransport(TasksModel, SerializerModel): return self.send_email(notification) raise ValueError(f"Invalid mode {self.mode} set") - def send_local(self, notification: "Notification") -> list[str]: + def send_local(self, notification: Notification) -> list[str]: """Local notification delivery""" if self.webhook_mapping_body: self.webhook_mapping_body.evaluate( @@ -375,7 +375,7 @@ class NotificationTransport(TasksModel, SerializerModel): ) return [] - def send_webhook(self, notification: "Notification") -> list[str]: + def send_webhook(self, notification: Notification) -> list[str]: """Send notification to generic webhook""" default_body = { "body": notification.body, @@ -419,7 +419,7 @@ class NotificationTransport(TasksModel, SerializerModel): response.text, ] - def send_webhook_slack(self, notification: "Notification") -> list[str]: + def send_webhook_slack(self, notification: Notification) -> list[str]: """Send notification to slack or slack-compatible endpoints""" fields = [ { @@ -477,7 +477,7 @@ class NotificationTransport(TasksModel, SerializerModel): response.text, ] - def send_email(self, notification: "Notification") -> list[str]: + def send_email(self, notification: Notification) -> list[str]: """Send notification via global email configuration""" from authentik.stages.email.tasks import send_mail diff --git a/authentik/flows/api/flows_diagram.py b/authentik/flows/api/flows_diagram.py index 784eeed8f3..debe6a60c4 100644 --- a/authentik/flows/api/flows_diagram.py +++ b/authentik/flows/api/flows_diagram.py @@ -18,7 +18,7 @@ class DiagramElement: identifier: str description: str action: str | None = None - source: list["DiagramElement"] | None = None + source: list[DiagramElement] | None = None style: list[str] = field(default_factory=lambda: ["[", "]"]) diff --git a/authentik/flows/challenge.py b/authentik/flows/challenge.py index dfb3585ec4..d99d1f1e57 100644 --- a/authentik/flows/challenge.py +++ b/authentik/flows/challenge.py @@ -2,7 +2,7 @@ from dataclasses import asdict, is_dataclass from enum import Enum -from typing import TYPE_CHECKING, Optional, TypedDict +from typing import TYPE_CHECKING, TypedDict from uuid import UUID from django.core.serializers.json import DjangoJSONEncoder @@ -137,7 +137,7 @@ class PermissionDict(TypedDict): class ChallengeResponse(PassiveSerializer): """Base class for all challenge responses""" - stage: Optional["StageView"] + stage: StageView | None component = CharField(default="xak-flow-response-default") def __init__(self, instance=None, data=None, **kwargs): diff --git a/authentik/flows/markers.py b/authentik/flows/markers.py index 7e6b9a2819..ac46062b23 100644 --- a/authentik/flows/markers.py +++ b/authentik/flows/markers.py @@ -23,7 +23,7 @@ class StageMarker: def process( self, - plan: "FlowPlan", + plan: FlowPlan, binding: FlowStageBinding, http_request: HttpRequest, ) -> FlowStageBinding | None: @@ -40,7 +40,7 @@ class ReevaluateMarker(StageMarker): def process( self, - plan: "FlowPlan", + plan: FlowPlan, binding: FlowStageBinding, http_request: HttpRequest, ) -> FlowStageBinding | None: diff --git a/authentik/flows/models.py b/authentik/flows/models.py index 6e6c458ccf..80d79066ad 100644 --- a/authentik/flows/models.py +++ b/authentik/flows/models.py @@ -90,7 +90,7 @@ class Stage(SerializerModel): objects = InheritanceManager() @property - def view(self) -> type["StageView"]: + def view(self) -> type[StageView]: """Return StageView class that implements logic for this stage""" # This is a bit of a workaround, since we can't set class methods with setattr if hasattr(self, "__in_memory_type"): @@ -117,7 +117,7 @@ class Stage(SerializerModel): return f"Stage {self.name}" -def in_memory_stage(view: type["StageView"], **kwargs) -> Stage: +def in_memory_stage(view: type[StageView], **kwargs) -> Stage: """Creates an in-memory stage instance, based on a `view` as view. Any key-word arguments are set as attributes on the stage object, accessible via `self.executor.current_stage`.""" @@ -310,13 +310,13 @@ class FlowToken(InternallyManagedMixin, Token): revoke_on_execution = models.BooleanField(default=True) @staticmethod - def pickle(plan: "FlowPlan") -> str: + def pickle(plan: FlowPlan) -> str: """Pickle into string""" data = dumps(plan) return b64encode(data).decode() @property - def plan(self) -> "FlowPlan": + def plan(self) -> FlowPlan: """Load Flow plan from pickled version""" return loads(b64decode(self._plan.encode())) # nosec diff --git a/authentik/flows/planner.py b/authentik/flows/planner.py index fb195afc7d..d81eff4d7d 100644 --- a/authentik/flows/planner.py +++ b/authentik/flows/planner.py @@ -123,7 +123,7 @@ class FlowPlan: def requires_flow_executor( self, - allowed_silent_types: list["StageView"] | None = None, + allowed_silent_types: list[StageView] | None = None, ): # Check if we actually need to show the Flow executor, or if we can jump straight to the end found_unskippable = True @@ -145,7 +145,7 @@ class FlowPlan: request: HttpRequest, flow: Flow, next: str | None = None, - allowed_silent_types: list["StageView"] | None = None, + allowed_silent_types: list[StageView] | None = None, ) -> HttpResponse: """Redirect to the flow executor for this flow plan""" from authentik.flows.views.executor import ( diff --git a/authentik/flows/stage.py b/authentik/flows/stage.py index 727a41ddb8..25718a9326 100644 --- a/authentik/flows/stage.py +++ b/authentik/flows/stage.py @@ -46,13 +46,13 @@ HIST_FLOWS_STAGE_TIME = Histogram( class StageView(View): """Abstract Stage""" - executor: "FlowExecutorView" + executor: FlowExecutorView request: HttpRequest = None logger: BoundLogger - def __init__(self, executor: "FlowExecutorView", **kwargs): + def __init__(self, executor: FlowExecutorView, **kwargs): self.executor = executor current_stage = getattr(self.executor, "current_stage", None) self.logger = get_logger().bind( @@ -257,7 +257,7 @@ class AccessDeniedStage(ChallengeStageView): error_message: str | None - def __init__(self, executor: "FlowExecutorView", error_message: str | None = None, **kwargs): + def __init__(self, executor: FlowExecutorView, error_message: str | None = None, **kwargs): super().__init__(executor, **kwargs) self.error_message = error_message diff --git a/authentik/lib/avatars.py b/authentik/lib/avatars.py index 1343a92cbc..853c94a85e 100644 --- a/authentik/lib/avatars.py +++ b/authentik/lib/avatars.py @@ -37,18 +37,18 @@ SVG_FONTS = [ ] -def avatar_mode_none(user: "User", mode: str) -> str | None: +def avatar_mode_none(user: User, mode: str) -> str | None: """No avatar""" return DEFAULT_AVATAR -def avatar_mode_attribute(user: "User", mode: str) -> str | None: +def avatar_mode_attribute(user: User, mode: str) -> str | None: """Avatars based on a user attribute""" avatar = get_path_from_dict(user.attributes, mode[11:], default=None) return avatar -def avatar_mode_gravatar(user: "User", mode: str) -> str | None: +def avatar_mode_gravatar(user: User, mode: str) -> str | None: """Gravatar avatars""" mail_hash = sha256(user.email.lower().encode("utf-8")).hexdigest() # nosec @@ -141,7 +141,7 @@ def generate_avatar_from_name( return etree.tostring(root_element).decode() -def avatar_mode_generated(user: "User", mode: str) -> str | None: +def avatar_mode_generated(user: User, mode: str) -> str | None: """Wrapper that converts generated avatar to base64 svg""" # By default generate based off of user's display name name = user.name.strip() @@ -155,7 +155,7 @@ def avatar_mode_generated(user: "User", mode: str) -> str | None: return f"data:image/svg+xml;base64,{b64encode(svg.encode('utf-8')).decode('utf-8')}" -def avatar_mode_url(user: "User", mode: str) -> str | None: +def avatar_mode_url(user: User, mode: str) -> str | None: """Format url""" mail_hash = md5(user.email.lower().encode("utf-8"), usedforsecurity=False).hexdigest() # nosec @@ -197,7 +197,7 @@ def avatar_mode_url(user: "User", mode: str) -> str | None: return formatted_url -def get_avatar(user: "User", request: HttpRequest | None = None) -> str: +def get_avatar(user: User, request: HttpRequest | None = None) -> str: """Get avatar with configured mode""" mode_map = { "none": avatar_mode_none, diff --git a/authentik/lib/expression/evaluator.py b/authentik/lib/expression/evaluator.py index 4119316c4e..7859a7e816 100644 --- a/authentik/lib/expression/evaluator.py +++ b/authentik/lib/expression/evaluator.py @@ -238,7 +238,7 @@ class BaseEvaluator: address: str | list[str], subject: str, body: str | None = None, - stage: "EmailStage | None" = None, + stage: EmailStage | None = None, template: str | None = None, context: dict | None = None, ) -> bool: diff --git a/authentik/outposts/controllers/k8s/base.py b/authentik/outposts/controllers/k8s/base.py index 7f4b153b30..f6dd4f6f49 100644 --- a/authentik/outposts/controllers/k8s/base.py +++ b/authentik/outposts/controllers/k8s/base.py @@ -36,9 +36,9 @@ def get_version() -> str: class KubernetesObjectReconciler[T]: """Base Kubernetes Reconciler, handles the basic logic.""" - controller: "KubernetesController" + controller: KubernetesController - def __init__(self, controller: "KubernetesController"): + def __init__(self, controller: KubernetesController): self.controller = controller self.namespace = controller.outpost.config.kubernetes_namespace self.logger = get_logger().bind(type=self.__class__.__name__) diff --git a/authentik/outposts/controllers/k8s/deployment.py b/authentik/outposts/controllers/k8s/deployment.py index 2b0d7db8b3..e13a6f462f 100644 --- a/authentik/outposts/controllers/k8s/deployment.py +++ b/authentik/outposts/controllers/k8s/deployment.py @@ -39,7 +39,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]): outpost: Outpost - def __init__(self, controller: "KubernetesController") -> None: + def __init__(self, controller: KubernetesController) -> None: super().__init__(controller) self.api = AppsV1Api(controller.client) self.outpost = self.controller.outpost diff --git a/authentik/outposts/controllers/k8s/secret.py b/authentik/outposts/controllers/k8s/secret.py index ca765b6833..4c89c6c778 100644 --- a/authentik/outposts/controllers/k8s/secret.py +++ b/authentik/outposts/controllers/k8s/secret.py @@ -21,7 +21,7 @@ def b64string(source: str) -> str: class SecretReconciler(KubernetesObjectReconciler[V1Secret]): """Kubernetes Secret Reconciler""" - def __init__(self, controller: "KubernetesController") -> None: + def __init__(self, controller: KubernetesController) -> None: super().__init__(controller) self.api = CoreV1Api(controller.client) diff --git a/authentik/outposts/controllers/k8s/service.py b/authentik/outposts/controllers/k8s/service.py index f3cbb75e0b..44bfd57396 100644 --- a/authentik/outposts/controllers/k8s/service.py +++ b/authentik/outposts/controllers/k8s/service.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: class ServiceReconciler(KubernetesObjectReconciler[V1Service]): """Kubernetes Service Reconciler""" - def __init__(self, controller: "KubernetesController") -> None: + def __init__(self, controller: KubernetesController) -> None: super().__init__(controller) self.api = CoreV1Api(controller.client) diff --git a/authentik/outposts/controllers/k8s/service_monitor.py b/authentik/outposts/controllers/k8s/service_monitor.py index c40bdb3fea..1856e61c3f 100644 --- a/authentik/outposts/controllers/k8s/service_monitor.py +++ b/authentik/outposts/controllers/k8s/service_monitor.py @@ -65,7 +65,7 @@ CRD_PLURAL = "servicemonitors" class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusServiceMonitor]): """Kubernetes Prometheus ServiceMonitor Reconciler""" - def __init__(self, controller: "KubernetesController") -> None: + def __init__(self, controller: KubernetesController) -> None: super().__init__(controller) self.api_ex = ApiextensionsV1Api(controller.client) self.api = CustomObjectsApi(controller.client) diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index 086cb507ce..ebc9f74035 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -304,7 +304,7 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel): return f"goauthentik.io/outposts/state/{self.uuid.hex}" @property - def state(self) -> list["OutpostState"]: + def state(self) -> list[OutpostState]: """Get outpost's health status""" return OutpostState.for_outpost(self) @@ -480,7 +480,7 @@ class OutpostState: return parse(self.version) != OUR_VERSION @staticmethod - def for_outpost(outpost: Outpost) -> list["OutpostState"]: + def for_outpost(outpost: Outpost) -> list[OutpostState]: """Get all states for an outpost""" keys = cache.keys(f"{outpost.state_cache_prefix}/*") if not keys: @@ -492,7 +492,7 @@ class OutpostState: return states @staticmethod - def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState": + def for_instance_uid(outpost: Outpost, uid: str) -> OutpostState: """Get state for a single instance""" key = f"{outpost.state_cache_prefix}/{uid}" default_data = {"uid": uid} diff --git a/authentik/policies/engine.py b/authentik/policies/engine.py index db03cd10b6..0a57b92e74 100644 --- a/authentik/policies/engine.py +++ b/authentik/policies/engine.py @@ -140,7 +140,7 @@ class PolicyEngine: passing = False self.__static_result = PolicyResult(passing) - def build(self) -> "PolicyEngine": + def build(self) -> PolicyEngine: """Build wrapper which monitors performance""" with ( start_span( diff --git a/authentik/policies/expression/evaluator.py b/authentik/policies/expression/evaluator.py index 82d71c7488..1561ed2aa4 100644 --- a/authentik/policies/expression/evaluator.py +++ b/authentik/policies/expression/evaluator.py @@ -1,7 +1,7 @@ """authentik expression policy evaluator""" from ipaddress import ip_address -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from django.http import HttpRequest from structlog.stdlib import get_logger @@ -23,7 +23,7 @@ class PolicyEvaluator(BaseEvaluator): _messages: list[str] - policy: Optional["ExpressionPolicy"] = None + policy: ExpressionPolicy | None = None def __init__(self, policy_name: str | None = None): super().__init__(policy_name or "PolicyEvaluator") diff --git a/authentik/policies/models.py b/authentik/policies/models.py index 9a5832851c..6ddec83658 100644 --- a/authentik/policies/models.py +++ b/authentik/policies/models.py @@ -55,7 +55,7 @@ class PolicyBindingModel(models.Model): class BoundPolicyQuerySet(models.QuerySet): """QuerySet for filtering enabled bindings for a Policy type""" - def for_policy(self, policy: "Policy"): + def for_policy(self, policy: Policy): return self.filter(policy__in=policy._default_manager.all()).filter(enabled=True) diff --git a/authentik/providers/oauth2/id_token.py b/authentik/providers/oauth2/id_token.py index fb19569111..cc48ac7494 100644 --- a/authentik/providers/oauth2/id_token.py +++ b/authentik/providers/oauth2/id_token.py @@ -73,8 +73,8 @@ class IDToken: @staticmethod def new( - provider: "OAuth2Provider", token: "BaseGrantModel", request: HttpRequest, **kwargs - ) -> "IDToken": + provider: OAuth2Provider, token: BaseGrantModel, request: HttpRequest, **kwargs + ) -> IDToken: """Create ID Token""" id_token = IDToken(provider, token, **kwargs) id_token.exp = int( @@ -147,7 +147,7 @@ class IDToken: id_dict.update(self.claims) return id_dict - def to_access_token(self, provider: "OAuth2Provider", token: "BaseGrantModel") -> str: + def to_access_token(self, provider: OAuth2Provider, token: BaseGrantModel) -> str: """Encode id_token for use as access token, adding fields""" final = self.to_dict() final["azp"] = provider.client_id @@ -155,6 +155,6 @@ class IDToken: final["scope"] = " ".join(token.scope) return provider.encode(final) - def to_jwt(self, provider: "OAuth2Provider") -> str: + def to_jwt(self, provider: OAuth2Provider) -> str: """Shortcut to encode id_token to jwt, signed by self.provider""" return provider.encode(self.to_dict()) diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index 2e2a5d05a7..79f236b79e 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -514,7 +514,7 @@ class AccessToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseGr return f"Access Token for {self.provider_id} for user {self.user_id}" @property - def id_token(self) -> "IDToken": + def id_token(self) -> IDToken: """Load ID Token from json""" from authentik.providers.oauth2.id_token import IDToken @@ -522,7 +522,7 @@ class AccessToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseGr return from_dict(IDToken, raw_token) @id_token.setter - def id_token(self, value: "IDToken"): + def id_token(self, value: IDToken): self.token = value.to_access_token(self.provider, self) self._id_token = json.dumps(asdict(value)) @@ -567,7 +567,7 @@ class RefreshToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseG return f"Refresh Token for {self.provider_id} for user {self.user_id}" @property - def id_token(self) -> "IDToken": + def id_token(self) -> IDToken: """Load ID Token from json""" from authentik.providers.oauth2.id_token import IDToken @@ -575,7 +575,7 @@ class RefreshToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseG return from_dict(IDToken, raw_token) @id_token.setter - def id_token(self, value: "IDToken"): + def id_token(self, value: IDToken): self._id_token = json.dumps(asdict(value)) @property diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index 433d317518..8e2d2a5c5b 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -105,7 +105,7 @@ class OAuthAuthorizationParams: github_compat: InitVar[bool] = False @staticmethod - def from_request(request: HttpRequest, github_compat=False) -> "OAuthAuthorizationParams": + def from_request(request: HttpRequest, github_compat=False) -> OAuthAuthorizationParams: """ Get all the params used by the Authorization Code Flow (and also for the Implicit and Hybrid). diff --git a/authentik/providers/oauth2/views/introspection.py b/authentik/providers/oauth2/views/introspection.py index 1d3c6d665e..9eaf973fa4 100644 --- a/authentik/providers/oauth2/views/introspection.py +++ b/authentik/providers/oauth2/views/introspection.py @@ -40,7 +40,7 @@ class TokenIntrospectionParams: raise TokenIntrospectionError() @staticmethod - def from_request(request: HttpRequest) -> "TokenIntrospectionParams": + def from_request(request: HttpRequest) -> TokenIntrospectionParams: """Extract required Parameters from HTTP Request""" raw_token = request.POST.get("token") provider = authenticate_provider(request) diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index 0afc9589eb..c546f2ac16 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -104,7 +104,7 @@ class TokenParams: provider: OAuth2Provider, client_id: str, client_secret: str, - ) -> "TokenParams": + ) -> TokenParams: """Parse params for request""" return TokenParams( # Init vars diff --git a/authentik/providers/oauth2/views/token_revoke.py b/authentik/providers/oauth2/views/token_revoke.py index 1f5da76913..52729b0c0d 100644 --- a/authentik/providers/oauth2/views/token_revoke.py +++ b/authentik/providers/oauth2/views/token_revoke.py @@ -27,7 +27,7 @@ class TokenRevocationParams: provider: OAuth2Provider @staticmethod - def from_request(request: HttpRequest) -> "TokenRevocationParams": + def from_request(request: HttpRequest) -> TokenRevocationParams: """Extract required Parameters from HTTP Request""" raw_token = request.POST.get("token") diff --git a/authentik/providers/proxy/controllers/k8s/httproute.py b/authentik/providers/proxy/controllers/k8s/httproute.py index a6fcf3a32c..552fcda0b4 100644 --- a/authentik/providers/proxy/controllers/k8s/httproute.py +++ b/authentik/providers/proxy/controllers/k8s/httproute.py @@ -82,7 +82,7 @@ class HTTPRoute: class HTTPRouteReconciler(KubernetesObjectReconciler): """Kubernetes Gateway API HTTPRoute Reconciler""" - def __init__(self, controller: "KubernetesController") -> None: + def __init__(self, controller: KubernetesController) -> None: super().__init__(controller) self.api_ex = ApiextensionsV1Api(controller.client) self.api = CustomObjectsApi(controller.client) diff --git a/authentik/providers/proxy/controllers/k8s/ingress.py b/authentik/providers/proxy/controllers/k8s/ingress.py index 6d3d825aa7..2307ff5229 100644 --- a/authentik/providers/proxy/controllers/k8s/ingress.py +++ b/authentik/providers/proxy/controllers/k8s/ingress.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: class IngressReconciler(KubernetesObjectReconciler[V1Ingress]): """Kubernetes Ingress Reconciler""" - def __init__(self, controller: "KubernetesController") -> None: + def __init__(self, controller: KubernetesController) -> None: super().__init__(controller) self.api = NetworkingV1Api(controller.client) diff --git a/authentik/providers/proxy/controllers/k8s/traefik.py b/authentik/providers/proxy/controllers/k8s/traefik.py index 406619231f..4f3e0530df 100644 --- a/authentik/providers/proxy/controllers/k8s/traefik.py +++ b/authentik/providers/proxy/controllers/k8s/traefik.py @@ -12,7 +12,7 @@ from authentik.providers.proxy.controllers.k8s.traefik_3 import ( class TraefikMiddlewareReconciler(KubernetesObjectReconciler): """Kubernetes Traefik Middleware Reconciler""" - def __init__(self, controller: "KubernetesController") -> None: + def __init__(self, controller: KubernetesController) -> None: super().__init__(controller) self.reconciler = Traefik3MiddlewareReconciler(controller) if not self.reconciler.crd_exists(): diff --git a/authentik/providers/proxy/controllers/k8s/traefik_2.py b/authentik/providers/proxy/controllers/k8s/traefik_2.py index 16e4011bc2..2cfcea7e9c 100644 --- a/authentik/providers/proxy/controllers/k8s/traefik_2.py +++ b/authentik/providers/proxy/controllers/k8s/traefik_2.py @@ -11,7 +11,7 @@ if TYPE_CHECKING: class Traefik2MiddlewareReconciler(Traefik3MiddlewareReconciler): """Kubernetes Traefik Middleware Reconciler""" - def __init__(self, controller: "KubernetesController") -> None: + def __init__(self, controller: KubernetesController) -> None: super().__init__(controller) self.crd_name = "middlewares.traefik.containo.us" self.crd_group = "traefik.containo.us" diff --git a/authentik/providers/proxy/controllers/k8s/traefik_3.py b/authentik/providers/proxy/controllers/k8s/traefik_3.py index a15f9caf4e..70c39466e7 100644 --- a/authentik/providers/proxy/controllers/k8s/traefik_3.py +++ b/authentik/providers/proxy/controllers/k8s/traefik_3.py @@ -57,7 +57,7 @@ class TraefikMiddleware: class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]): """Kubernetes Traefik Middleware Reconciler""" - def __init__(self, controller: "KubernetesController") -> None: + def __init__(self, controller: KubernetesController) -> None: super().__init__(controller) self.api_ex = ApiextensionsV1Api(controller.client) self.api = CustomObjectsApi(controller.client) diff --git a/authentik/providers/scim/clients/auth.py b/authentik/providers/scim/clients/auth.py index 7711808c0c..541f530866 100644 --- a/authentik/providers/scim/clients/auth.py +++ b/authentik/providers/scim/clients/auth.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: class SCIMTokenAuth: - def __init__(self, provider: "SCIMProvider"): + def __init__(self, provider: SCIMProvider): self.provider = provider def __call__(self, request: Request) -> Request: diff --git a/authentik/providers/scim/clients/schema.py b/authentik/providers/scim/clients/schema.py index 9a59f43bf7..dfc4460d40 100644 --- a/authentik/providers/scim/clients/schema.py +++ b/authentik/providers/scim/clients/schema.py @@ -132,7 +132,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration): return self._is_fallback @staticmethod - def default() -> "ServiceProviderConfiguration": + def default() -> ServiceProviderConfiguration: """Get default configuration, which doesn't support any optional features as fallback""" return ServiceProviderConfiguration( patch=Patch(supported=False), diff --git a/authentik/sources/ldap/models.py b/authentik/sources/ldap/models.py index 9d66c7ad79..9004d287fe 100644 --- a/authentik/sources/ldap/models.py +++ b/authentik/sources/ldap/models.py @@ -183,7 +183,7 @@ class LDAPSource(IncomingSyncSource): ] @property - def property_mapping_type(self) -> "type[PropertyMapping]": + def property_mapping_type(self) -> type[PropertyMapping]: from authentik.sources.ldap.models import LDAPSourcePropertyMapping return LDAPSourcePropertyMapping diff --git a/authentik/sources/oauth/models.py b/authentik/sources/oauth/models.py index a681b9a1b1..fd8b83e483 100644 --- a/authentik/sources/oauth/models.py +++ b/authentik/sources/oauth/models.py @@ -85,7 +85,7 @@ class OAuthSource(NonCreatableType, Source): ) @property - def source_type(self) -> type["SourceType"]: + def source_type(self) -> type[SourceType]: """Return the provider instance for this source""" from authentik.sources.oauth.types.registry import registry diff --git a/authentik/sources/saml/processors/response.py b/authentik/sources/saml/processors/response.py index f9855d1ca8..2c670f8dad 100644 --- a/authentik/sources/saml/processors/response.py +++ b/authentik/sources/saml/processors/response.py @@ -222,7 +222,7 @@ class ResponseProcessor: policy_context={}, ) - def _get_name_id(self) -> "Element": + def _get_name_id(self) -> Element: """Get NameID Element""" assertion = self._root.find(f"{{{NS_SAML_ASSERTION}}}Assertion") if assertion is None: diff --git a/authentik/sources/telegram/models.py b/authentik/sources/telegram/models.py index 6be980efc1..711b80d283 100644 --- a/authentik/sources/telegram/models.py +++ b/authentik/sources/telegram/models.py @@ -86,7 +86,7 @@ class TelegramSource(Source): ) @property - def property_mapping_type(self) -> "type[PropertyMapping]": + def property_mapping_type(self) -> type[PropertyMapping]: return TelegramSourcePropertyMapping def get_base_user_properties( diff --git a/authentik/stages/authenticator/__init__.py b/authentik/stages/authenticator/__init__.py index da685aaf96..1b5c781240 100644 --- a/authentik/stages/authenticator/__init__.py +++ b/authentik/stages/authenticator/__init__.py @@ -68,7 +68,7 @@ def match_token(user, token): return device -def devices_for_user(user: "User", confirmed: bool | None = True, for_verify: bool = False): +def devices_for_user(user: User, confirmed: bool | None = True, for_verify: bool = False): """ Return an iterable of all devices registered to the given user. diff --git a/authentik/stages/authenticator_email/models.py b/authentik/stages/authenticator_email/models.py index 03cd964ec4..01c74c6345 100644 --- a/authentik/stages/authenticator_email/models.py +++ b/authentik/stages/authenticator_email/models.py @@ -100,7 +100,7 @@ class AuthenticatorEmailStage(ConfigurableStage, FriendlyNamedStage, Stage): timeout=self.timeout, ) - def send(self, device: "EmailDevice"): + def send(self, device: EmailDevice): # Lazy import here to avoid circular import from authentik.stages.email.tasks import send_mails diff --git a/authentik/stages/authenticator_sms/models.py b/authentik/stages/authenticator_sms/models.py index 7137203e6c..62a5930b0a 100644 --- a/authentik/stages/authenticator_sms/models.py +++ b/authentik/stages/authenticator_sms/models.py @@ -68,7 +68,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage): help_text=_("Optionally modify the payload being sent to custom providers."), ) - def send(self, request: HttpRequest, token: str, device: "SMSDevice"): + def send(self, request: HttpRequest, token: str, device: SMSDevice): """Send message via selected provider""" if self.provider == SMSProviders.TWILIO: return self.send_twilio(request, token, device) @@ -80,7 +80,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage): """Get SMS message""" return _("Use this code to authenticate in authentik: {token}".format_map({"token": token})) - def send_twilio(self, request: HttpRequest, token: str, device: "SMSDevice"): + def send_twilio(self, request: HttpRequest, token: str, device: SMSDevice): """send sms via twilio provider""" client = Client(self.account_sid, self.auth) message_body = str(self.get_message(token)) @@ -105,7 +105,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage): LOGGER.warning("Error sending token by Twilio SMS", exc=exc, msg=exc.msg) raise ValidationError(exc.msg) from None - def send_generic(self, request: HttpRequest, token: str, device: "SMSDevice"): + def send_generic(self, request: HttpRequest, token: str, device: SMSDevice): """Send SMS via outside API""" payload = { "From": self.from_number, diff --git a/authentik/stages/authenticator_validate/challenge.py b/authentik/stages/authenticator_validate/challenge.py index 1fb1df4878..8f6c968e6d 100644 --- a/authentik/stages/authenticator_validate/challenge.py +++ b/authentik/stages/authenticator_validate/challenge.py @@ -55,7 +55,7 @@ class DeviceChallenge(PassiveSerializer): def get_challenge_for_device( - stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device + stage_view: AuthenticatorValidateStageView, stage: AuthenticatorValidateStage, device: Device ) -> dict: """Generate challenge for a single device""" if isinstance(device, WebAuthnDevice): @@ -67,7 +67,7 @@ def get_challenge_for_device( def get_webauthn_challenge_without_user( - stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage + stage_view: AuthenticatorValidateStageView, stage: AuthenticatorValidateStage ) -> dict: """Same as `get_webauthn_challenge`, but allows any client device. We can then later check who the device belongs to.""" @@ -85,7 +85,7 @@ def get_webauthn_challenge_without_user( def get_webauthn_challenge( - stage_view: "AuthenticatorValidateStageView", + stage_view: AuthenticatorValidateStageView, stage: AuthenticatorValidateStage, device: WebAuthnDevice | None = None, ) -> dict: diff --git a/authentik/tasks/schedules/common.py b/authentik/tasks/schedules/common.py index 4a807b2fad..ca8c9c3eae 100644 --- a/authentik/tasks/schedules/common.py +++ b/authentik/tasks/schedules/common.py @@ -41,7 +41,7 @@ class ScheduleSpec: options["uid"] = self.uid return pickle.dumps(options) - def update_or_create(self) -> "Schedule": + def update_or_create(self) -> Schedule: from django.contrib.contenttypes.models import ContentType from authentik.tasks.schedules.models import Schedule diff --git a/lifecycle/gunicorn.conf.py b/lifecycle/gunicorn.conf.py index dfdef057fe..b93439af85 100644 --- a/lifecycle/gunicorn.conf.py +++ b/lifecycle/gunicorn.conf.py @@ -53,27 +53,27 @@ workers = CONFIG.get_int("web.workers", default_workers) threads = CONFIG.get_int("web.threads", 4) -def post_fork(server: "Arbiter", worker: DjangoUvicornWorker): +def post_fork(server: Arbiter, worker: DjangoUvicornWorker): """Tell prometheus to use worker number instead of process ID for multiprocess""" from prometheus_client import values values.ValueClass = MultiProcessValue(lambda: worker._worker_id) -def worker_exit(server: "Arbiter", worker: DjangoUvicornWorker): +def worker_exit(server: Arbiter, worker: DjangoUvicornWorker): """Remove pid dbs when worker is shutdown""" from prometheus_client import multiprocess multiprocess.mark_process_dead(worker._worker_id) -def on_starting(server: "Arbiter"): +def on_starting(server: Arbiter): """Attach a set of IDs that can be temporarily reused. Used on reloads when each worker exists twice.""" server._worker_id_overload = set() -def nworkers_changed(server: "Arbiter", new_value, old_value): +def nworkers_changed(server: Arbiter, new_value, old_value): """Gets called on startup too. Set the current number of workers. Required if we raise the worker count temporarily using TTIN because server.cfg.workers won't be updated and if @@ -81,7 +81,7 @@ def nworkers_changed(server: "Arbiter", new_value, old_value): server._worker_id_current_workers = new_value -def _next_worker_id(server: "Arbiter"): +def _next_worker_id(server: Arbiter): """If there are IDs open for reuse, take one. Else look for a free one.""" if server._worker_id_overload: return server._worker_id_overload.pop() @@ -92,12 +92,12 @@ def _next_worker_id(server: "Arbiter"): return free.pop() -def on_reload(server: "Arbiter"): +def on_reload(server: Arbiter): """Add a full set of ids into overload so it can be reused once.""" server._worker_id_overload = set(range(1, server.cfg.workers + 1)) -def pre_fork(server: "Arbiter", worker: DjangoUvicornWorker): +def pre_fork(server: Arbiter, worker: DjangoUvicornWorker): """Attach the next free worker_id before forking off.""" worker._worker_id = _next_worker_id(server) diff --git a/packages/django-channels-postgres/django_channels_postgres/layer.py b/packages/django-channels-postgres/django_channels_postgres/layer.py index e5e20ebcf0..ff8388bb1b 100644 --- a/packages/django-channels-postgres/django_channels_postgres/layer.py +++ b/packages/django-channels-postgres/django_channels_postgres/layer.py @@ -29,7 +29,7 @@ MESSAGE_TABLE = Message._meta.db_table async def _async_proxy( - obj: "PostgresChannelLayerLoopProxy", + obj: PostgresChannelLayerLoopProxy, name: str, *args: Any, **kwargs: Any, @@ -40,7 +40,7 @@ async def _async_proxy( return await getattr(layer, name)(*args, **kwargs) -def _wrap_close(proxy: "PostgresChannelLayerLoopProxy", loop: asyncio.AbstractEventLoop) -> None: +def _wrap_close(proxy: PostgresChannelLayerLoopProxy, loop: asyncio.AbstractEventLoop) -> None: original_impl = loop.close def _wrapper(self: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any) -> None: @@ -90,7 +90,7 @@ class PostgresChannelLayerLoopProxy: m = zlib.decompress(message) return cast(dict[str, Any], msgpack.unpackb(m, raw=False)) - def _get_layer(self) -> "PostgresChannelLoopLayer": + def _get_layer(self) -> PostgresChannelLoopLayer: loop = asyncio.get_running_loop() try: diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py index 25c157dbd5..edc4a4621f 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/broker.py @@ -94,7 +94,7 @@ class PostgresBroker(Broker): return cast(DatabaseWrapper, connections[self.db_alias]) @property - def consumer_class(self) -> "type[_PostgresConsumer]": + def consumer_class(self) -> type[_PostgresConsumer]: return _PostgresConsumer @cached_property diff --git a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py index f2a3d47708..11763196d2 100644 --- a/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py +++ b/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py @@ -71,7 +71,7 @@ class DbConnectionMiddleware(Middleware): class TaskStateBeforeMiddleware(Middleware): - def before_process_message(self, broker: "PostgresBroker", message: Message[Any]) -> None: + def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None: broker.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, @@ -82,7 +82,7 @@ class TaskStateBeforeMiddleware(Middleware): class TaskStateAfterMiddleware(Middleware): - def before_process_message(self, broker: "PostgresBroker", message: Message[Any]) -> None: + def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None: broker.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, @@ -91,7 +91,7 @@ class TaskStateAfterMiddleware(Middleware): state=TaskState.RUNNING, ) - def after_skip_message(self, broker: "PostgresBroker", message: Message[Any]) -> None: + def after_skip_message(self, broker: PostgresBroker, message: Message[Any]) -> None: broker.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, @@ -102,7 +102,7 @@ class TaskStateAfterMiddleware(Middleware): def after_process_message( self, - broker: "PostgresBroker", + broker: PostgresBroker, message: Message[Any], *, result: Any | None = None, diff --git a/pyproject.toml b/pyproject.toml index da143a557d..05080c95da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,7 +185,7 @@ exclude = 'node_modules' [tool.ruff] line-length = 100 -target-version = "py313" +target-version = "py314" exclude = ["**/migrations/**", "**/node_modules/**"] [tool.ruff.lint]