root: upgrade ruff lint for 3.14 (#19461)

* root: upgrade ruff lint for 3.14

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* redo makefile

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* format

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2026-01-15 03:15:02 +01:00
committed by GitHub
parent 51a7eb96fb
commit 06ed43002f
63 changed files with 197 additions and 200 deletions
+42 -41
View File
@@ -24,6 +24,8 @@ BREW_LDFLAGS :=
BREW_CPPFLAGS :=
BREW_PKG_CONFIG_PATH :=
UV := uv
# For macOS users, add the libxml2 installed from brew libxmlsec1 to the build path
# to prevent SAML-related tests from failing and ensure correct pip dependency compilation
ifeq ($(UNAME_S),Darwin)
@@ -33,22 +35,21 @@ ifeq ($(UNAME_S),Darwin)
LIBXML2_EXISTS := $(shell brew list libxml2 2> /dev/null)
ifdef LIBXML2_EXISTS
_xml_pref := $(shell brew --prefix libxml2)
BREW_LDFLAGS += -L${_xml_pref}/lib $(LDFLAGS)
BREW_CPPFLAGS += -I${_xml_pref}/include $(CPPFLAGS)
BREW_PKG_CONFIG_PATH += ${_xml_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH)
BREW_LDFLAGS += -L${_xml_pref}/lib
BREW_CPPFLAGS += -I${_xml_pref}/include
BREW_PKG_CONFIG_PATH = ${_xml_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH)
endif
KRB5_EXISTS := $(shell brew list krb5 2> /dev/null)
ifdef KRB5_EXISTS
_krb5_pref := $(shell brew --prefix krb5)
BREW_LDFLAGS += -L${_krb5_pref}/lib $(LDFLAGS)
BREW_CPPFLAGS += -I${_krb5_pref}/include $(CPPFLAGS)
BREW_PKG_CONFIG_PATH += ${_krb5_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH)
BREW_LDFLAGS += -L${_krb5_pref}/lib
BREW_CPPFLAGS += -I${_krb5_pref}/include
BREW_PKG_CONFIG_PATH = ${_krb5_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH)
endif
endif
UV := LDFLAGS="$(BREW_LDFLAGS)" CPPFLAGS="$(BREW_CPPFLAGS)" PKG_CONFIG_PATH="$(BREW_PKG_CONFIG_PATH)" uv
endif
UV_ARGS := LDFLAGS="$(BREW_LDFLAGS)" CPPFLAGS="$(BREW_CPPFLAGS)" PKG_CONFIG_PATH="$(BREW_PKG_CONFIG_PATH)"
all: lint-fix lint gen web test ## Lint, build, and test everything
HELP_WIDTH := $(shell grep -h '^[a-z][^ ]*:.*\#\#' $(MAKEFILE_LIST) 2>/dev/null | \
@@ -65,47 +66,47 @@ go-test:
go test -timeout 0 -v -race -cover ./...
test: ## Run the server tests and produce a coverage report (locally)
$(KRB_PATH) uv run coverage run manage.py test --keepdb $(or $(filter-out $@,$(MAKECMDGOALS)),authentik)
uv run coverage html
uv run coverage report
$(UV) run coverage run manage.py test --keepdb $(or $(filter-out $@,$(MAKECMDGOALS)),authentik)
$(UV) run coverage html
$(UV) run coverage report
lint-fix: lint-codespell ## Lint and automatically fix errors in the python source code. Reports spelling errors.
uv run black $(PY_SOURCES)
uv run ruff check --fix $(PY_SOURCES)
$(UV) run black $(PY_SOURCES)
$(UV) run ruff check --fix $(PY_SOURCES)
lint-codespell: ## Reports spelling errors.
uv run codespell -w
$(UV) run codespell -w
lint: ## Lint the python and golang sources
uv run bandit -c pyproject.toml -r $(PY_SOURCES)
$(UV) run bandit -c pyproject.toml -r $(PY_SOURCES)
golangci-lint run -v
core-install:
ifeq ($(UNAME_S),Darwin)
# Clear cache to ensure fresh compilation
uv cache clean
$(UV) cache clean
# Force compilation from source for lxml and xmlsec with correct environment
$(UV_ARGS) uv sync --frozen --reinstall-package lxml --reinstall-package xmlsec --no-binary-package lxml --no-binary-package xmlsec
$(UV) sync --frozen --reinstall-package lxml --reinstall-package xmlsec --no-binary-package lxml --no-binary-package xmlsec
else
uv sync --frozen
$(UV) sync --frozen
endif
migrate: ## Run the Authentik Django server's migrations
uv run python -m lifecycle.migrate
$(UV) run python -m lifecycle.migrate
i18n-extract: core-i18n-extract web-i18n-extract ## Extract strings that require translation into files to send to a translation service
aws-cfn:
cd lifecycle/aws && npm i && uv run npm run aws-cfn
cd lifecycle/aws && npm i && $(UV) run npm run aws-cfn
run-server: ## Run the main authentik server process
uv run ak server
$(UV) run ak server
run-worker: ## Run the main authentik worker process
uv run ak worker
$(UV) run ak worker
core-i18n-extract:
uv run ak makemessages \
$(UV) run ak makemessages \
--add-location file \
--no-obsolete \
--ignore web \
@@ -118,17 +119,17 @@ core-i18n-extract:
install: node-install docs-install core-install ## Install all requires dependencies for `node`, `docs` and `core`
dev-drop-db:
$(eval pg_user := $(shell uv run python -m authentik.lib.config postgresql.user 2>/dev/null))
$(eval pg_host := $(shell uv run python -m authentik.lib.config postgresql.host 2>/dev/null))
$(eval pg_name := $(shell uv run python -m authentik.lib.config postgresql.name 2>/dev/null))
$(eval pg_user := $(shell $(UV) run python -m authentik.lib.config postgresql.user 2>/dev/null))
$(eval pg_host := $(shell $(UV) run python -m authentik.lib.config postgresql.host 2>/dev/null))
$(eval pg_name := $(shell $(UV) run python -m authentik.lib.config postgresql.name 2>/dev/null))
dropdb -U ${pg_user} -h ${pg_host} ${pg_name} || true
# Also remove the test-db if it exists
dropdb -U ${pg_user} -h ${pg_host} test_${pg_name} || true
dev-create-db:
$(eval pg_user := $(shell uv run python -m authentik.lib.config postgresql.user 2>/dev/null))
$(eval pg_host := $(shell uv run python -m authentik.lib.config postgresql.host 2>/dev/null))
$(eval pg_name := $(shell uv run python -m authentik.lib.config postgresql.name 2>/dev/null))
$(eval pg_user := $(shell $(UV) run python -m authentik.lib.config postgresql.user 2>/dev/null))
$(eval pg_host := $(shell $(UV) run python -m authentik.lib.config postgresql.host 2>/dev/null))
$(eval pg_name := $(shell $(UV) run python -m authentik.lib.config postgresql.name 2>/dev/null))
createdb -U ${pg_user} -h ${pg_host} ${pg_name}
dev-reset: dev-drop-db dev-create-db migrate ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state.
@@ -156,10 +157,10 @@ gen-build: ## Extract the schema from the database
AUTHENTIK_DEBUG=true \
AUTHENTIK_TENANTS__ENABLED=true \
AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \
uv run ak build_schema
$(UV) run ak build_schema
gen-compose:
uv run scripts/generate_compose.py
$(UV) run scripts/generate_compose.py
gen-changelog: ## (Release) generate the changelog based from the commits since the last tag
git log --pretty=format:" - %s" $(shell git describe --tags $(shell git rev-list --tags --max-count=1))...$(shell git branch --show-current) | sort > changelog.md
@@ -227,7 +228,7 @@ endif
go mod edit -replace goauthentik.io/api/v3=./${GEN_API_GO}
gen-dev-config: ## Generate a local development config file
uv run scripts/generate_config.py
$(UV) run scripts/generate_config.py
gen: gen-build gen-client-ts
@@ -327,24 +328,24 @@ ci--meta-debug:
node --version
ci-mypy: ci--meta-debug
uv run mypy --strict $(PY_SOURCES)
$(UV) run mypy --strict $(PY_SOURCES)
ci-black: ci--meta-debug
uv run black --check $(PY_SOURCES)
$(UV) run black --check $(PY_SOURCES)
ci-ruff: ci--meta-debug
uv run ruff check $(PY_SOURCES)
$(UV) run ruff check $(PY_SOURCES)
ci-codespell: ci--meta-debug
uv run codespell -s
$(UV) run codespell -s
ci-bandit: ci--meta-debug
uv run bandit -r $(PY_SOURCES)
$(UV) run bandit -r $(PY_SOURCES)
ci-pending-migrations: ci--meta-debug
uv run ak makemigrations --check
$(UV) run ak makemigrations --check
ci-test: ci--meta-debug
uv run coverage run manage.py test --keepdb authentik
uv run coverage report
uv run coverage xml
$(UV) run coverage run manage.py test --keepdb authentik
$(UV) run coverage report
$(UV) run coverage xml
+29 -31
View File
@@ -9,7 +9,7 @@ from functools import reduce
from json import JSONDecodeError, loads
from operator import ixor
from os import getenv
from typing import Any, Literal, Union
from typing import Any, Literal
from uuid import UUID
from deepmerge import always_merger
@@ -70,19 +70,17 @@ class BlueprintEntryDesiredState(Enum):
class BlueprintEntryPermission:
"""Describe object-level permissions"""
permission: Union[str, "YAMLTag"]
user: Union[int, "YAMLTag", None] = field(default=None)
role: Union[str, "YAMLTag", None] = field(default=None)
permission: str | YAMLTag
user: int | YAMLTag | None = field(default=None)
role: str | YAMLTag | None = field(default=None)
@dataclass
class BlueprintEntry:
"""Single entry of a blueprint"""
model: Union[str, "YAMLTag"]
state: Union[BlueprintEntryDesiredState, "YAMLTag"] = field(
default=BlueprintEntryDesiredState.PRESENT
)
model: str | YAMLTag
state: BlueprintEntryDesiredState | YAMLTag = field(default=BlueprintEntryDesiredState.PRESENT)
conditions: list[Any] = field(default_factory=list)
identifiers: dict[str, Any] = field(default_factory=dict)
attrs: dict[str, Any] | None = field(default_factory=dict)
@@ -96,7 +94,7 @@ class BlueprintEntry:
self.__tag_contexts: list[YAMLTagContext] = []
@staticmethod
def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry":
def from_model(model: SerializerModel, *extra_identifier_names: str) -> BlueprintEntry:
"""Convert a SerializerModel instance to a blueprint Entry"""
identifiers = {
"pk": model.pk,
@@ -114,8 +112,8 @@ class BlueprintEntry:
def get_tag_context(
self,
depth: int = 0,
context_tag_type: type["YAMLTagContext"] | tuple["YAMLTagContext", ...] | None = None,
) -> "YAMLTagContext":
context_tag_type: type[YAMLTagContext] | tuple[YAMLTagContext, ...] | None = None,
) -> YAMLTagContext:
"""Get a YAMLTagContext object located at a certain depth in the tag tree"""
if depth < 0:
raise ValueError("depth must be a positive number or zero")
@@ -130,7 +128,7 @@ class BlueprintEntry:
except IndexError as exc:
raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") from exc
def tag_resolver(self, value: Any, blueprint: "Blueprint") -> Any:
def tag_resolver(self, value: Any, blueprint: Blueprint) -> Any:
"""Check if we have any special tags that need handling"""
val = copy(value)
@@ -152,23 +150,23 @@ class BlueprintEntry:
return val
def get_attrs(self, blueprint: "Blueprint") -> dict[str, Any]:
def get_attrs(self, blueprint: Blueprint) -> dict[str, Any]:
"""Get attributes of this entry, with all yaml tags resolved"""
return self.tag_resolver(self.attrs, blueprint)
def get_identifiers(self, blueprint: "Blueprint") -> dict[str, Any]:
def get_identifiers(self, blueprint: Blueprint) -> dict[str, Any]:
"""Get attributes of this entry, with all yaml tags resolved"""
return self.tag_resolver(self.identifiers, blueprint)
def get_state(self, blueprint: "Blueprint") -> BlueprintEntryDesiredState:
def get_state(self, blueprint: Blueprint) -> BlueprintEntryDesiredState:
"""Get the blueprint state, with yaml tags resolved if present"""
return BlueprintEntryDesiredState(self.tag_resolver(self.state, blueprint))
def get_model(self, blueprint: "Blueprint") -> str:
def get_model(self, blueprint: Blueprint) -> str:
"""Get the blueprint model, with yaml tags resolved if present"""
return str(self.tag_resolver(self.model, blueprint))
def get_permissions(self, blueprint: "Blueprint") -> Generator[BlueprintEntryPermission]:
def get_permissions(self, blueprint: Blueprint) -> Generator[BlueprintEntryPermission]:
"""Get permissions of this entry, with all yaml tags resolved"""
for perm in self.permissions:
yield BlueprintEntryPermission(
@@ -177,7 +175,7 @@ class BlueprintEntry:
role=self.tag_resolver(perm.role, blueprint),
)
def check_all_conditions_match(self, blueprint: "Blueprint") -> bool:
def check_all_conditions_match(self, blueprint: Blueprint) -> bool:
"""Check all conditions of this entry match (evaluate to True)"""
return all(self.tag_resolver(self.conditions, blueprint))
@@ -232,7 +230,7 @@ class KeyOf(YAMLTag):
id_from: str
def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None:
def __init__(self, loader: BlueprintLoader, node: ScalarNode) -> None:
super().__init__()
self.id_from = node.value
@@ -258,7 +256,7 @@ class Env(YAMLTag):
key: str
default: Any | None
def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None:
def __init__(self, loader: BlueprintLoader, node: ScalarNode | SequenceNode) -> None:
super().__init__()
self.default = None
if isinstance(node, ScalarNode):
@@ -277,7 +275,7 @@ class File(YAMLTag):
path: str
default: Any | None
def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None:
def __init__(self, loader: BlueprintLoader, node: ScalarNode | SequenceNode) -> None:
super().__init__()
self.default = None
if isinstance(node, ScalarNode):
@@ -305,7 +303,7 @@ class Context(YAMLTag):
key: str
default: Any | None
def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None:
def __init__(self, loader: BlueprintLoader, node: ScalarNode | SequenceNode) -> None:
super().__init__()
self.default = None
if isinstance(node, ScalarNode):
@@ -328,7 +326,7 @@ class ParseJSON(YAMLTag):
raw: str
def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None:
def __init__(self, loader: BlueprintLoader, node: ScalarNode) -> None:
super().__init__()
self.raw = node.value
@@ -345,7 +343,7 @@ class Format(YAMLTag):
format_string: str
args: list[Any]
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
super().__init__()
self.format_string = loader.construct_object(node.value[0])
self.args = []
@@ -372,7 +370,7 @@ class Find(YAMLTag):
model_name: str | YAMLTag
conditions: list[list]
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
super().__init__()
self.model_name = loader.construct_object(node.value[0])
self.conditions = []
@@ -444,7 +442,7 @@ class Condition(YAMLTag):
"XNOR": lambda args: not (reduce(ixor, args) if len(args) > 1 else args[0]),
}
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
super().__init__()
self.mode = loader.construct_object(node.value[0])
self.args = []
@@ -478,7 +476,7 @@ class If(YAMLTag):
when_true: Any
when_false: Any
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
super().__init__()
self.condition = loader.construct_object(node.value[0])
if len(node.value) == 1:
@@ -518,7 +516,7 @@ class Enumerate(YAMLTag, YAMLTagContext):
),
}
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
super().__init__()
self.iterable = loader.construct_object(node.value[0])
self.output_body = loader.construct_object(node.value[1])
@@ -584,7 +582,7 @@ class EnumeratedItem(YAMLTag):
_SUPPORTED_CONTEXT_TAGS = (Enumerate,)
def __init__(self, _loader: "BlueprintLoader", node: ScalarNode) -> None:
def __init__(self, _loader: BlueprintLoader, node: ScalarNode) -> None:
super().__init__()
self.depth = int(node.value)
@@ -640,7 +638,7 @@ class AtIndex(YAMLTag):
attribute: int | str | YAMLTag
default: Any | UNSET
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
super().__init__()
self.obj = loader.construct_object(node.value[0])
self.attribute = loader.construct_object(node.value[1])
@@ -757,7 +755,7 @@ class EntryInvalidError(SentryIgnoredException):
@staticmethod
def from_entry(
msg_or_exc: str | Exception, entry: BlueprintEntry, *args, **kwargs
) -> "EntryInvalidError":
) -> EntryInvalidError:
"""Create EntryInvalidError with the context of an entry"""
error = EntryInvalidError(msg_or_exc, *args, **kwargs)
if isinstance(msg_or_exc, ValidationError):
+1 -1
View File
@@ -147,7 +147,7 @@ class Importer:
}
@staticmethod
def from_string(yaml_input: str, context: dict | None = None) -> "Importer":
def from_string(yaml_input: str, context: dict | None = None) -> Importer:
"""Parse YAML string and create blueprint importer from it"""
import_dict = load(yaml_input, BlueprintLoader)
try:
@@ -23,7 +23,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
# We cannot override `instance` as that will confuse rest_framework
# and make it attempt to update the instance
blueprint_instance: "BlueprintInstance"
blueprint_instance: BlueprintInstance
def validate(self, attrs):
from authentik.blueprints.models import BlueprintInstance
+9 -11
View File
@@ -3,7 +3,7 @@
from datetime import datetime
from enum import StrEnum
from hashlib import sha256
from typing import Any, Optional, Self
from typing import Any, Self
from uuid import uuid4
import pgtrigger
@@ -225,7 +225,7 @@ class Group(SerializerModel, AttributesMixin):
# in the LDAP Outpost we use the last 5 chars so match here
return int(str(self.pk.int)[:5])
def is_member(self, user: "User") -> bool:
def is_member(self, user: User) -> bool:
"""Recursively check if `user` is member of us, or any parent."""
return user.all_groups().filter(group_uuid=self.group_uuid).exists()
@@ -466,7 +466,7 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
always_merger.merge(final_attributes, self.attributes)
return final_attributes
def app_entitlements(self, app: "Application | None") -> QuerySet["ApplicationEntitlement"]:
def app_entitlements(self, app: Application | None) -> QuerySet[ApplicationEntitlement]:
"""Get all entitlements this user has for `app`."""
if not app:
return []
@@ -485,7 +485,7 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
).order_by("name")
return qs
def app_entitlements_attributes(self, app: "Application | None") -> dict:
def app_entitlements_attributes(self, app: Application | None) -> dict:
"""Get a dictionary containing all merged attributes from app entitlements for `app`."""
final_attributes = {}
for attrs in self.app_entitlements(app).values_list("attributes", flat=True):
@@ -654,7 +654,7 @@ class BackchannelProvider(Provider):
class ApplicationQuerySet(QuerySet):
def with_provider(self) -> "QuerySet[Application]":
def with_provider(self) -> QuerySet[Application]:
qs = self.select_related("provider")
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
qs = qs.select_related(f"provider__{subclass}")
@@ -713,9 +713,7 @@ class Application(SerializerModel, PolicyBindingModel):
return get_file_manager(FileUsage.MEDIA).file_url(self.meta_icon)
def get_launch_url(
self, user: Optional["User"] = None, user_data: dict | None = None
) -> str | None:
def get_launch_url(self, user: User | None = None, user_data: dict | None = None) -> str | None:
"""Get launch URL if set, otherwise attempt to get launch URL based on provider.
Args:
@@ -948,7 +946,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
raise NotImplementedError
@property
def property_mapping_type(self) -> "type[PropertyMapping]":
def property_mapping_type(self) -> type[PropertyMapping]:
"""Return property mapping type used by this object"""
if self.managed == self.MANAGED_INBUILT:
from authentik.core.models import PropertyMapping
@@ -1069,7 +1067,7 @@ class ExpiringModel(models.Model):
return self.delete(*args, **kwargs)
@classmethod
def filter_not_expired(cls, **kwargs) -> QuerySet["Self"]:
def filter_not_expired(cls, **kwargs) -> QuerySet[Self]:
"""Filer for tokens which are not expired yet or are not expiring,
and match filters in `kwargs`"""
for obj in cls.objects.filter(**kwargs).filter(Q(expires__lt=now(), expiring=True)):
@@ -1265,7 +1263,7 @@ class AuthenticatedSession(SerializerModel):
return f"Authenticated Session {str(self.pk)[:10]}"
@staticmethod
def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]:
def from_request(request: HttpRequest, user: User) -> AuthenticatedSession | None:
"""Create a new session from a http request"""
if not hasattr(request, "session") or not request.session.exists(
request.session.session_key
+1 -1
View File
@@ -63,7 +63,7 @@ def user_logged_in_session(sender, request: HttpRequest, user: User, **_):
@receiver(post_delete, sender=AuthenticatedSession)
def authenticated_session_delete(sender: type[Model], instance: "AuthenticatedSession", **_):
def authenticated_session_delete(sender: type[Model], instance: AuthenticatedSession, **_):
"""Delete session when authenticated session is deleted"""
Session.objects.filter(session_key=instance.pk).delete()
+1 -1
View File
@@ -49,7 +49,7 @@ class SourceMapper:
def build_object_properties(
self,
object_type: type[User | Group],
manager: "PropertyMappingManager | None" = None,
manager: PropertyMappingManager | None = None,
user: User | None = None,
request: HttpRequest | None = None,
**kwargs,
@@ -62,7 +62,7 @@ class AgentConfigSerializer(PassiveSerializer):
def get_system_config(self, instance: AgentConnector) -> ConfigSerializer:
return ConfigView.get_config(self.context["request"]).data
def get_license_status(self, instance: AgentConnector) -> "LicenseUsageStatus":
def get_license_status(self, instance: AgentConnector) -> LicenseUsageStatus:
try:
from authentik.enterprise.license import LicenseKey
@@ -68,7 +68,7 @@ class AgentConnector(Connector):
return AuthenticatorEndpointStageView
@property
def controller(self) -> type["AgentConnectorController"]:
def controller(self) -> type[AgentConnectorController]:
from authentik.endpoints.connectors.agent.controller import AgentConnectorController
return AgentConnectorController
+4 -4
View File
@@ -43,7 +43,7 @@ class Device(InternallyManagedMixin, ExpiringModel, AttributesMixin, PolicyBindi
return f"goauthentik.io/endpoints/devices/{self.device_uuid}/facts"
@property
def cached_facts(self) -> "DeviceFactSnapshot":
def cached_facts(self) -> DeviceFactSnapshot:
if cached := cache.get(self.cache_key_facts):
return cached
facts = self.facts
@@ -51,7 +51,7 @@ class Device(InternallyManagedMixin, ExpiringModel, AttributesMixin, PolicyBindi
return facts
@property
def facts(self) -> "DeviceFactSnapshot":
def facts(self) -> DeviceFactSnapshot:
data = {}
last_updated = datetime.fromtimestamp(0, UTC)
for snapshot_data, snapshort_created in DeviceFactSnapshot.filter_not_expired(
@@ -157,7 +157,7 @@ class Connector(ScheduledModel, SerializerModel):
raise NotImplementedError
@property
def controller(self) -> type["BaseController[Connector]"]:
def controller(self) -> type[BaseController[Connector]]:
raise NotImplementedError
@property
@@ -205,7 +205,7 @@ class EndpointStage(Stage):
mode = models.TextField(choices=StageMode.choices, default=StageMode.OPTIONAL)
@property
def view(self) -> type["StageView"]:
def view(self) -> type[StageView]:
from authentik.endpoints.stage import EndpointStageView
return EndpointStageView
+2 -2
View File
@@ -93,7 +93,7 @@ class LicenseKey:
license_flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
def validate(jwt: str, check_expiry=True) -> "LicenseKey":
def validate(jwt: str, check_expiry=True) -> LicenseKey:
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
@@ -128,7 +128,7 @@ class LicenseKey:
return body
@staticmethod
def get_total() -> "LicenseKey":
def get_total() -> LicenseKey:
"""Get a summarized version of all (not expired) licenses"""
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in License.objects.all():
+1 -1
View File
@@ -50,7 +50,7 @@ class License(SerializerModel):
return LicenseSerializer
@property
def status(self) -> "LicenseKey":
def status(self) -> LicenseKey:
"""Get parsed license status"""
from authentik.enterprise.license import LicenseKey
@@ -19,7 +19,7 @@ class SCIMOAuthException(SCIMRequestException):
class SCIMOAuthAuth:
def __init__(self, provider: "SCIMProvider"):
def __init__(self, provider: SCIMProvider):
self.provider = provider
self.user = provider.auth_oauth_user
self.logger = get_logger().bind()
@@ -17,9 +17,9 @@ if TYPE_CHECKING:
class SSFTokenAuth(BaseAuthentication):
"""SSF Token auth"""
view: "SSFView"
view: SSFView
def __init__(self, view: "SSFView") -> None:
def __init__(self, view: SSFView) -> None:
super().__init__()
self.view = view
+3 -3
View File
@@ -1,6 +1,6 @@
"""ASN Enricher"""
from typing import TYPE_CHECKING, Optional, TypedDict
from typing import TYPE_CHECKING, TypedDict
from django.http import HttpRequest
from geoip2.errors import GeoIP2Error
@@ -27,7 +27,7 @@ class ASNDict(TypedDict):
class ASNContextProcessor(MMDBContextProcessor):
"""ASN Database reader wrapper"""
def capability(self) -> Optional["Capabilities"]:
def capability(self) -> Capabilities | None:
from authentik.api.v3.config import Capabilities
return Capabilities.CAN_ASN
@@ -35,7 +35,7 @@ class ASNContextProcessor(MMDBContextProcessor):
def path(self) -> str | None:
return CONFIG.get("events.context_processors.asn")
def enrich_event(self, event: "Event"):
def enrich_event(self, event: Event):
asn = self.asn_dict(event.client_ip)
if not asn:
return
+3 -3
View File
@@ -1,7 +1,7 @@
"""Base event enricher"""
from functools import cache
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from django.http import HttpRequest
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
class EventContextProcessor:
"""Base event enricher"""
def capability(self) -> Optional["Capabilities"]:
def capability(self) -> Capabilities | None:
"""Return the capability this context processor provides"""
return None
@@ -21,7 +21,7 @@ class EventContextProcessor:
"""Return true if this context processor is configured"""
return False
def enrich_event(self, event: "Event"):
def enrich_event(self, event: Event):
"""Modify event"""
raise NotImplementedError
+3 -3
View File
@@ -1,6 +1,6 @@
"""events GeoIP Reader"""
from typing import TYPE_CHECKING, Optional, TypedDict
from typing import TYPE_CHECKING, TypedDict
from django.http import HttpRequest
from geoip2.errors import GeoIP2Error
@@ -29,7 +29,7 @@ class GeoIPDict(TypedDict):
class GeoIPContextProcessor(MMDBContextProcessor):
"""Slim wrapper around GeoIP API"""
def capability(self) -> Optional["Capabilities"]:
def capability(self) -> Capabilities | None:
from authentik.api.v3.config import Capabilities
return Capabilities.CAN_GEO_IP
@@ -37,7 +37,7 @@ class GeoIPContextProcessor(MMDBContextProcessor):
def path(self) -> str | None:
return CONFIG.get("events.context_processors.geoip")
def enrich_event(self, event: "Event"):
def enrich_event(self, event: Event):
city = self.city_dict(event.client_ip)
if not city:
return
+1 -1
View File
@@ -25,7 +25,7 @@ class LogEvent:
attributes: dict[str, Any] = field(default_factory=dict)
@staticmethod
def from_event_dict(item: EventDict) -> "LogEvent":
def from_event_dict(item: EventDict) -> LogEvent:
event = item.pop("event")
log_level = item.pop("level").lower()
timestamp = datetime.fromisoformat(item.pop("timestamp")).replace(tzinfo=UTC)
+9 -9
View File
@@ -151,7 +151,7 @@ class Event(SerializerModel, ExpiringModel):
action: str | EventAction,
app: str | None = None,
**kwargs,
) -> "Event":
) -> Event:
"""Create new Event instance from arguments. Instance is NOT saved."""
if not isinstance(action, EventAction):
action = EventAction.CUSTOM_PREFIX + action
@@ -169,19 +169,19 @@ class Event(SerializerModel, ExpiringModel):
event = Event(action=action, app=app, context=cleaned_kwargs)
return event
def with_exception(self, exc: Exception) -> "Event":
def with_exception(self, exc: Exception) -> Event:
"""Add data from 'exc' to the event in a database-saveable format"""
self.context.setdefault("message", str(exc))
self.context["exception"] = exception_to_dict(exc)
return self
def set_user(self, user: User) -> "Event":
def set_user(self, user: User) -> Event:
"""Set `.user` based on user, ensuring the correct attributes are copied.
This should only be used when self.from_http is *not* used."""
self.user = get_user(user)
return self
def from_http(self, request: HttpRequest, user: User | None = None) -> "Event":
def from_http(self, request: HttpRequest, user: User | None = None) -> Event:
"""Add data from a Django-HttpRequest, allowing the creation of
Events independently from requests.
`user` arguments optionally overrides user from requests."""
@@ -343,7 +343,7 @@ class NotificationTransport(TasksModel, SerializerModel):
),
)
def send(self, notification: "Notification") -> list[str]:
def send(self, notification: Notification) -> list[str]:
"""Send notification to user, called from async task"""
if self.mode == TransportMode.LOCAL:
return self.send_local(notification)
@@ -355,7 +355,7 @@ class NotificationTransport(TasksModel, SerializerModel):
return self.send_email(notification)
raise ValueError(f"Invalid mode {self.mode} set")
def send_local(self, notification: "Notification") -> list[str]:
def send_local(self, notification: Notification) -> list[str]:
"""Local notification delivery"""
if self.webhook_mapping_body:
self.webhook_mapping_body.evaluate(
@@ -375,7 +375,7 @@ class NotificationTransport(TasksModel, SerializerModel):
)
return []
def send_webhook(self, notification: "Notification") -> list[str]:
def send_webhook(self, notification: Notification) -> list[str]:
"""Send notification to generic webhook"""
default_body = {
"body": notification.body,
@@ -419,7 +419,7 @@ class NotificationTransport(TasksModel, SerializerModel):
response.text,
]
def send_webhook_slack(self, notification: "Notification") -> list[str]:
def send_webhook_slack(self, notification: Notification) -> list[str]:
"""Send notification to slack or slack-compatible endpoints"""
fields = [
{
@@ -477,7 +477,7 @@ class NotificationTransport(TasksModel, SerializerModel):
response.text,
]
def send_email(self, notification: "Notification") -> list[str]:
def send_email(self, notification: Notification) -> list[str]:
"""Send notification via global email configuration"""
from authentik.stages.email.tasks import send_mail
+1 -1
View File
@@ -18,7 +18,7 @@ class DiagramElement:
identifier: str
description: str
action: str | None = None
source: list["DiagramElement"] | None = None
source: list[DiagramElement] | None = None
style: list[str] = field(default_factory=lambda: ["[", "]"])
+2 -2
View File
@@ -2,7 +2,7 @@
from dataclasses import asdict, is_dataclass
from enum import Enum
from typing import TYPE_CHECKING, Optional, TypedDict
from typing import TYPE_CHECKING, TypedDict
from uuid import UUID
from django.core.serializers.json import DjangoJSONEncoder
@@ -137,7 +137,7 @@ class PermissionDict(TypedDict):
class ChallengeResponse(PassiveSerializer):
"""Base class for all challenge responses"""
stage: Optional["StageView"]
stage: StageView | None
component = CharField(default="xak-flow-response-default")
def __init__(self, instance=None, data=None, **kwargs):
+2 -2
View File
@@ -23,7 +23,7 @@ class StageMarker:
def process(
self,
plan: "FlowPlan",
plan: FlowPlan,
binding: FlowStageBinding,
http_request: HttpRequest,
) -> FlowStageBinding | None:
@@ -40,7 +40,7 @@ class ReevaluateMarker(StageMarker):
def process(
self,
plan: "FlowPlan",
plan: FlowPlan,
binding: FlowStageBinding,
http_request: HttpRequest,
) -> FlowStageBinding | None:
+4 -4
View File
@@ -90,7 +90,7 @@ class Stage(SerializerModel):
objects = InheritanceManager()
@property
def view(self) -> type["StageView"]:
def view(self) -> type[StageView]:
"""Return StageView class that implements logic for this stage"""
# This is a bit of a workaround, since we can't set class methods with setattr
if hasattr(self, "__in_memory_type"):
@@ -117,7 +117,7 @@ class Stage(SerializerModel):
return f"Stage {self.name}"
def in_memory_stage(view: type["StageView"], **kwargs) -> Stage:
def in_memory_stage(view: type[StageView], **kwargs) -> Stage:
"""Creates an in-memory stage instance, based on a `view` as view.
Any key-word arguments are set as attributes on the stage object,
accessible via `self.executor.current_stage`."""
@@ -310,13 +310,13 @@ class FlowToken(InternallyManagedMixin, Token):
revoke_on_execution = models.BooleanField(default=True)
@staticmethod
def pickle(plan: "FlowPlan") -> str:
def pickle(plan: FlowPlan) -> str:
"""Pickle into string"""
data = dumps(plan)
return b64encode(data).decode()
@property
def plan(self) -> "FlowPlan":
def plan(self) -> FlowPlan:
"""Load Flow plan from pickled version"""
return loads(b64decode(self._plan.encode())) # nosec
+2 -2
View File
@@ -123,7 +123,7 @@ class FlowPlan:
def requires_flow_executor(
self,
allowed_silent_types: list["StageView"] | None = None,
allowed_silent_types: list[StageView] | None = None,
):
# Check if we actually need to show the Flow executor, or if we can jump straight to the end
found_unskippable = True
@@ -145,7 +145,7 @@ class FlowPlan:
request: HttpRequest,
flow: Flow,
next: str | None = None,
allowed_silent_types: list["StageView"] | None = None,
allowed_silent_types: list[StageView] | None = None,
) -> HttpResponse:
"""Redirect to the flow executor for this flow plan"""
from authentik.flows.views.executor import (
+3 -3
View File
@@ -46,13 +46,13 @@ HIST_FLOWS_STAGE_TIME = Histogram(
class StageView(View):
"""Abstract Stage"""
executor: "FlowExecutorView"
executor: FlowExecutorView
request: HttpRequest = None
logger: BoundLogger
def __init__(self, executor: "FlowExecutorView", **kwargs):
def __init__(self, executor: FlowExecutorView, **kwargs):
self.executor = executor
current_stage = getattr(self.executor, "current_stage", None)
self.logger = get_logger().bind(
@@ -257,7 +257,7 @@ class AccessDeniedStage(ChallengeStageView):
error_message: str | None
def __init__(self, executor: "FlowExecutorView", error_message: str | None = None, **kwargs):
def __init__(self, executor: FlowExecutorView, error_message: str | None = None, **kwargs):
super().__init__(executor, **kwargs)
self.error_message = error_message
+6 -6
View File
@@ -37,18 +37,18 @@ SVG_FONTS = [
]
def avatar_mode_none(user: "User", mode: str) -> str | None:
def avatar_mode_none(user: User, mode: str) -> str | None:
"""No avatar"""
return DEFAULT_AVATAR
def avatar_mode_attribute(user: "User", mode: str) -> str | None:
def avatar_mode_attribute(user: User, mode: str) -> str | None:
"""Avatars based on a user attribute"""
avatar = get_path_from_dict(user.attributes, mode[11:], default=None)
return avatar
def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
def avatar_mode_gravatar(user: User, mode: str) -> str | None:
"""Gravatar avatars"""
mail_hash = sha256(user.email.lower().encode("utf-8")).hexdigest() # nosec
@@ -141,7 +141,7 @@ def generate_avatar_from_name(
return etree.tostring(root_element).decode()
def avatar_mode_generated(user: "User", mode: str) -> str | None:
def avatar_mode_generated(user: User, mode: str) -> str | None:
"""Wrapper that converts generated avatar to base64 svg"""
# By default generate based off of user's display name
name = user.name.strip()
@@ -155,7 +155,7 @@ def avatar_mode_generated(user: "User", mode: str) -> str | None:
return f"data:image/svg+xml;base64,{b64encode(svg.encode('utf-8')).decode('utf-8')}"
def avatar_mode_url(user: "User", mode: str) -> str | None:
def avatar_mode_url(user: User, mode: str) -> str | None:
"""Format url"""
mail_hash = md5(user.email.lower().encode("utf-8"), usedforsecurity=False).hexdigest() # nosec
@@ -197,7 +197,7 @@ def avatar_mode_url(user: "User", mode: str) -> str | None:
return formatted_url
def get_avatar(user: "User", request: HttpRequest | None = None) -> str:
def get_avatar(user: User, request: HttpRequest | None = None) -> str:
"""Get avatar with configured mode"""
mode_map = {
"none": avatar_mode_none,
+1 -1
View File
@@ -238,7 +238,7 @@ class BaseEvaluator:
address: str | list[str],
subject: str,
body: str | None = None,
stage: "EmailStage | None" = None,
stage: EmailStage | None = None,
template: str | None = None,
context: dict | None = None,
) -> bool:
+2 -2
View File
@@ -36,9 +36,9 @@ def get_version() -> str:
class KubernetesObjectReconciler[T]:
"""Base Kubernetes Reconciler, handles the basic logic."""
controller: "KubernetesController"
controller: KubernetesController
def __init__(self, controller: "KubernetesController"):
def __init__(self, controller: KubernetesController):
self.controller = controller
self.namespace = controller.outpost.config.kubernetes_namespace
self.logger = get_logger().bind(type=self.__class__.__name__)
@@ -39,7 +39,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
outpost: Outpost
def __init__(self, controller: "KubernetesController") -> None:
def __init__(self, controller: KubernetesController) -> None:
super().__init__(controller)
self.api = AppsV1Api(controller.client)
self.outpost = self.controller.outpost
+1 -1
View File
@@ -21,7 +21,7 @@ def b64string(source: str) -> str:
class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
"""Kubernetes Secret Reconciler"""
def __init__(self, controller: "KubernetesController") -> None:
def __init__(self, controller: KubernetesController) -> None:
super().__init__(controller)
self.api = CoreV1Api(controller.client)
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
class ServiceReconciler(KubernetesObjectReconciler[V1Service]):
"""Kubernetes Service Reconciler"""
def __init__(self, controller: "KubernetesController") -> None:
def __init__(self, controller: KubernetesController) -> None:
super().__init__(controller)
self.api = CoreV1Api(controller.client)
@@ -65,7 +65,7 @@ CRD_PLURAL = "servicemonitors"
class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusServiceMonitor]):
"""Kubernetes Prometheus ServiceMonitor Reconciler"""
def __init__(self, controller: "KubernetesController") -> None:
def __init__(self, controller: KubernetesController) -> None:
super().__init__(controller)
self.api_ex = ApiextensionsV1Api(controller.client)
self.api = CustomObjectsApi(controller.client)
+3 -3
View File
@@ -304,7 +304,7 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
return f"goauthentik.io/outposts/state/{self.uuid.hex}"
@property
def state(self) -> list["OutpostState"]:
def state(self) -> list[OutpostState]:
"""Get outpost's health status"""
return OutpostState.for_outpost(self)
@@ -480,7 +480,7 @@ class OutpostState:
return parse(self.version) != OUR_VERSION
@staticmethod
def for_outpost(outpost: Outpost) -> list["OutpostState"]:
def for_outpost(outpost: Outpost) -> list[OutpostState]:
"""Get all states for an outpost"""
keys = cache.keys(f"{outpost.state_cache_prefix}/*")
if not keys:
@@ -492,7 +492,7 @@ class OutpostState:
return states
@staticmethod
def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
def for_instance_uid(outpost: Outpost, uid: str) -> OutpostState:
"""Get state for a single instance"""
key = f"{outpost.state_cache_prefix}/{uid}"
default_data = {"uid": uid}
+1 -1
View File
@@ -140,7 +140,7 @@ class PolicyEngine:
passing = False
self.__static_result = PolicyResult(passing)
def build(self) -> "PolicyEngine":
def build(self) -> PolicyEngine:
"""Build wrapper which monitors performance"""
with (
start_span(
+2 -2
View File
@@ -1,7 +1,7 @@
"""authentik expression policy evaluator"""
from ipaddress import ip_address
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING
from django.http import HttpRequest
from structlog.stdlib import get_logger
@@ -23,7 +23,7 @@ class PolicyEvaluator(BaseEvaluator):
_messages: list[str]
policy: Optional["ExpressionPolicy"] = None
policy: ExpressionPolicy | None = None
def __init__(self, policy_name: str | None = None):
super().__init__(policy_name or "PolicyEvaluator")
+1 -1
View File
@@ -55,7 +55,7 @@ class PolicyBindingModel(models.Model):
class BoundPolicyQuerySet(models.QuerySet):
"""QuerySet for filtering enabled bindings for a Policy type"""
def for_policy(self, policy: "Policy"):
def for_policy(self, policy: Policy):
return self.filter(policy__in=policy._default_manager.all()).filter(enabled=True)
+4 -4
View File
@@ -73,8 +73,8 @@ class IDToken:
@staticmethod
def new(
provider: "OAuth2Provider", token: "BaseGrantModel", request: HttpRequest, **kwargs
) -> "IDToken":
provider: OAuth2Provider, token: BaseGrantModel, request: HttpRequest, **kwargs
) -> IDToken:
"""Create ID Token"""
id_token = IDToken(provider, token, **kwargs)
id_token.exp = int(
@@ -147,7 +147,7 @@ class IDToken:
id_dict.update(self.claims)
return id_dict
def to_access_token(self, provider: "OAuth2Provider", token: "BaseGrantModel") -> str:
def to_access_token(self, provider: OAuth2Provider, token: BaseGrantModel) -> str:
"""Encode id_token for use as access token, adding fields"""
final = self.to_dict()
final["azp"] = provider.client_id
@@ -155,6 +155,6 @@ class IDToken:
final["scope"] = " ".join(token.scope)
return provider.encode(final)
def to_jwt(self, provider: "OAuth2Provider") -> str:
def to_jwt(self, provider: OAuth2Provider) -> str:
"""Shortcut to encode id_token to jwt, signed by self.provider"""
return provider.encode(self.to_dict())
+4 -4
View File
@@ -514,7 +514,7 @@ class AccessToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseGr
return f"Access Token for {self.provider_id} for user {self.user_id}"
@property
def id_token(self) -> "IDToken":
def id_token(self) -> IDToken:
"""Load ID Token from json"""
from authentik.providers.oauth2.id_token import IDToken
@@ -522,7 +522,7 @@ class AccessToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseGr
return from_dict(IDToken, raw_token)
@id_token.setter
def id_token(self, value: "IDToken"):
def id_token(self, value: IDToken):
self.token = value.to_access_token(self.provider, self)
self._id_token = json.dumps(asdict(value))
@@ -567,7 +567,7 @@ class RefreshToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseG
return f"Refresh Token for {self.provider_id} for user {self.user_id}"
@property
def id_token(self) -> "IDToken":
def id_token(self) -> IDToken:
"""Load ID Token from json"""
from authentik.providers.oauth2.id_token import IDToken
@@ -575,7 +575,7 @@ class RefreshToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseG
return from_dict(IDToken, raw_token)
@id_token.setter
def id_token(self, value: "IDToken"):
def id_token(self, value: IDToken):
self._id_token = json.dumps(asdict(value))
@property
@@ -105,7 +105,7 @@ class OAuthAuthorizationParams:
github_compat: InitVar[bool] = False
@staticmethod
def from_request(request: HttpRequest, github_compat=False) -> "OAuthAuthorizationParams":
def from_request(request: HttpRequest, github_compat=False) -> OAuthAuthorizationParams:
"""
Get all the params used by the Authorization Code Flow
(and also for the Implicit and Hybrid).
@@ -40,7 +40,7 @@ class TokenIntrospectionParams:
raise TokenIntrospectionError()
@staticmethod
def from_request(request: HttpRequest) -> "TokenIntrospectionParams":
def from_request(request: HttpRequest) -> TokenIntrospectionParams:
"""Extract required Parameters from HTTP Request"""
raw_token = request.POST.get("token")
provider = authenticate_provider(request)
+1 -1
View File
@@ -104,7 +104,7 @@ class TokenParams:
provider: OAuth2Provider,
client_id: str,
client_secret: str,
) -> "TokenParams":
) -> TokenParams:
"""Parse params for request"""
return TokenParams(
# Init vars
@@ -27,7 +27,7 @@ class TokenRevocationParams:
provider: OAuth2Provider
@staticmethod
def from_request(request: HttpRequest) -> "TokenRevocationParams":
def from_request(request: HttpRequest) -> TokenRevocationParams:
"""Extract required Parameters from HTTP Request"""
raw_token = request.POST.get("token")
@@ -82,7 +82,7 @@ class HTTPRoute:
class HTTPRouteReconciler(KubernetesObjectReconciler):
"""Kubernetes Gateway API HTTPRoute Reconciler"""
def __init__(self, controller: "KubernetesController") -> None:
def __init__(self, controller: KubernetesController) -> None:
super().__init__(controller)
self.api_ex = ApiextensionsV1Api(controller.client)
self.api = CustomObjectsApi(controller.client)
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
class IngressReconciler(KubernetesObjectReconciler[V1Ingress]):
"""Kubernetes Ingress Reconciler"""
def __init__(self, controller: "KubernetesController") -> None:
def __init__(self, controller: KubernetesController) -> None:
super().__init__(controller)
self.api = NetworkingV1Api(controller.client)
@@ -12,7 +12,7 @@ from authentik.providers.proxy.controllers.k8s.traefik_3 import (
class TraefikMiddlewareReconciler(KubernetesObjectReconciler):
"""Kubernetes Traefik Middleware Reconciler"""
def __init__(self, controller: "KubernetesController") -> None:
def __init__(self, controller: KubernetesController) -> None:
super().__init__(controller)
self.reconciler = Traefik3MiddlewareReconciler(controller)
if not self.reconciler.crd_exists():
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
class Traefik2MiddlewareReconciler(Traefik3MiddlewareReconciler):
"""Kubernetes Traefik Middleware Reconciler"""
def __init__(self, controller: "KubernetesController") -> None:
def __init__(self, controller: KubernetesController) -> None:
super().__init__(controller)
self.crd_name = "middlewares.traefik.containo.us"
self.crd_group = "traefik.containo.us"
@@ -57,7 +57,7 @@ class TraefikMiddleware:
class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]):
"""Kubernetes Traefik Middleware Reconciler"""
def __init__(self, controller: "KubernetesController") -> None:
def __init__(self, controller: KubernetesController) -> None:
super().__init__(controller)
self.api_ex = ApiextensionsV1Api(controller.client)
self.api = CustomObjectsApi(controller.client)
+1 -1
View File
@@ -8,7 +8,7 @@ if TYPE_CHECKING:
class SCIMTokenAuth:
def __init__(self, provider: "SCIMProvider"):
def __init__(self, provider: SCIMProvider):
self.provider = provider
def __call__(self, request: Request) -> Request:
+1 -1
View File
@@ -132,7 +132,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
return self._is_fallback
@staticmethod
def default() -> "ServiceProviderConfiguration":
def default() -> ServiceProviderConfiguration:
"""Get default configuration, which doesn't support any optional features as fallback"""
return ServiceProviderConfiguration(
patch=Patch(supported=False),
+1 -1
View File
@@ -183,7 +183,7 @@ class LDAPSource(IncomingSyncSource):
]
@property
def property_mapping_type(self) -> "type[PropertyMapping]":
def property_mapping_type(self) -> type[PropertyMapping]:
from authentik.sources.ldap.models import LDAPSourcePropertyMapping
return LDAPSourcePropertyMapping
+1 -1
View File
@@ -85,7 +85,7 @@ class OAuthSource(NonCreatableType, Source):
)
@property
def source_type(self) -> type["SourceType"]:
def source_type(self) -> type[SourceType]:
"""Return the provider instance for this source"""
from authentik.sources.oauth.types.registry import registry
@@ -222,7 +222,7 @@ class ResponseProcessor:
policy_context={},
)
def _get_name_id(self) -> "Element":
def _get_name_id(self) -> Element:
"""Get NameID Element"""
assertion = self._root.find(f"{{{NS_SAML_ASSERTION}}}Assertion")
if assertion is None:
+1 -1
View File
@@ -86,7 +86,7 @@ class TelegramSource(Source):
)
@property
def property_mapping_type(self) -> "type[PropertyMapping]":
def property_mapping_type(self) -> type[PropertyMapping]:
return TelegramSourcePropertyMapping
def get_base_user_properties(
+1 -1
View File
@@ -68,7 +68,7 @@ def match_token(user, token):
return device
def devices_for_user(user: "User", confirmed: bool | None = True, for_verify: bool = False):
def devices_for_user(user: User, confirmed: bool | None = True, for_verify: bool = False):
"""
Return an iterable of all devices registered to the given user.
@@ -100,7 +100,7 @@ class AuthenticatorEmailStage(ConfigurableStage, FriendlyNamedStage, Stage):
timeout=self.timeout,
)
def send(self, device: "EmailDevice"):
def send(self, device: EmailDevice):
# Lazy import here to avoid circular import
from authentik.stages.email.tasks import send_mails
+3 -3
View File
@@ -68,7 +68,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
help_text=_("Optionally modify the payload being sent to custom providers."),
)
def send(self, request: HttpRequest, token: str, device: "SMSDevice"):
def send(self, request: HttpRequest, token: str, device: SMSDevice):
"""Send message via selected provider"""
if self.provider == SMSProviders.TWILIO:
return self.send_twilio(request, token, device)
@@ -80,7 +80,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
"""Get SMS message"""
return _("Use this code to authenticate in authentik: {token}".format_map({"token": token}))
def send_twilio(self, request: HttpRequest, token: str, device: "SMSDevice"):
def send_twilio(self, request: HttpRequest, token: str, device: SMSDevice):
"""send sms via twilio provider"""
client = Client(self.account_sid, self.auth)
message_body = str(self.get_message(token))
@@ -105,7 +105,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
LOGGER.warning("Error sending token by Twilio SMS", exc=exc, msg=exc.msg)
raise ValidationError(exc.msg) from None
def send_generic(self, request: HttpRequest, token: str, device: "SMSDevice"):
def send_generic(self, request: HttpRequest, token: str, device: SMSDevice):
"""Send SMS via outside API"""
payload = {
"From": self.from_number,
@@ -55,7 +55,7 @@ class DeviceChallenge(PassiveSerializer):
def get_challenge_for_device(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device
stage_view: AuthenticatorValidateStageView, stage: AuthenticatorValidateStage, device: Device
) -> dict:
"""Generate challenge for a single device"""
if isinstance(device, WebAuthnDevice):
@@ -67,7 +67,7 @@ def get_challenge_for_device(
def get_webauthn_challenge_without_user(
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage
stage_view: AuthenticatorValidateStageView, stage: AuthenticatorValidateStage
) -> dict:
"""Same as `get_webauthn_challenge`, but allows any client device. We can then later check
who the device belongs to."""
@@ -85,7 +85,7 @@ def get_webauthn_challenge_without_user(
def get_webauthn_challenge(
stage_view: "AuthenticatorValidateStageView",
stage_view: AuthenticatorValidateStageView,
stage: AuthenticatorValidateStage,
device: WebAuthnDevice | None = None,
) -> dict:
+1 -1
View File
@@ -41,7 +41,7 @@ class ScheduleSpec:
options["uid"] = self.uid
return pickle.dumps(options)
def update_or_create(self) -> "Schedule":
def update_or_create(self) -> Schedule:
from django.contrib.contenttypes.models import ContentType
from authentik.tasks.schedules.models import Schedule
+7 -7
View File
@@ -53,27 +53,27 @@ workers = CONFIG.get_int("web.workers", default_workers)
threads = CONFIG.get_int("web.threads", 4)
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
def post_fork(server: Arbiter, worker: DjangoUvicornWorker):
"""Tell prometheus to use worker number instead of process ID for multiprocess"""
from prometheus_client import values
values.ValueClass = MultiProcessValue(lambda: worker._worker_id)
def worker_exit(server: "Arbiter", worker: DjangoUvicornWorker):
def worker_exit(server: Arbiter, worker: DjangoUvicornWorker):
"""Remove pid dbs when worker is shutdown"""
from prometheus_client import multiprocess
multiprocess.mark_process_dead(worker._worker_id)
def on_starting(server: "Arbiter"):
def on_starting(server: Arbiter):
"""Attach a set of IDs that can be temporarily reused.
Used on reloads when each worker exists twice."""
server._worker_id_overload = set()
def nworkers_changed(server: "Arbiter", new_value, old_value):
def nworkers_changed(server: Arbiter, new_value, old_value):
"""Gets called on startup too.
Set the current number of workers. Required if we raise the worker count
temporarily using TTIN because server.cfg.workers won't be updated and if
@@ -81,7 +81,7 @@ def nworkers_changed(server: "Arbiter", new_value, old_value):
server._worker_id_current_workers = new_value
def _next_worker_id(server: "Arbiter"):
def _next_worker_id(server: Arbiter):
"""If there are IDs open for reuse, take one. Else look for a free one."""
if server._worker_id_overload:
return server._worker_id_overload.pop()
@@ -92,12 +92,12 @@ def _next_worker_id(server: "Arbiter"):
return free.pop()
def on_reload(server: "Arbiter"):
def on_reload(server: Arbiter):
"""Add a full set of ids into overload so it can be reused once."""
server._worker_id_overload = set(range(1, server.cfg.workers + 1))
def pre_fork(server: "Arbiter", worker: DjangoUvicornWorker):
def pre_fork(server: Arbiter, worker: DjangoUvicornWorker):
"""Attach the next free worker_id before forking off."""
worker._worker_id = _next_worker_id(server)
@@ -29,7 +29,7 @@ MESSAGE_TABLE = Message._meta.db_table
async def _async_proxy(
obj: "PostgresChannelLayerLoopProxy",
obj: PostgresChannelLayerLoopProxy,
name: str,
*args: Any,
**kwargs: Any,
@@ -40,7 +40,7 @@ async def _async_proxy(
return await getattr(layer, name)(*args, **kwargs)
def _wrap_close(proxy: "PostgresChannelLayerLoopProxy", loop: asyncio.AbstractEventLoop) -> None:
def _wrap_close(proxy: PostgresChannelLayerLoopProxy, loop: asyncio.AbstractEventLoop) -> None:
original_impl = loop.close
def _wrapper(self: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any) -> None:
@@ -90,7 +90,7 @@ class PostgresChannelLayerLoopProxy:
m = zlib.decompress(message)
return cast(dict[str, Any], msgpack.unpackb(m, raw=False))
def _get_layer(self) -> "PostgresChannelLoopLayer":
def _get_layer(self) -> PostgresChannelLoopLayer:
loop = asyncio.get_running_loop()
try:
@@ -94,7 +94,7 @@ class PostgresBroker(Broker):
return cast(DatabaseWrapper, connections[self.db_alias])
@property
def consumer_class(self) -> "type[_PostgresConsumer]":
def consumer_class(self) -> type[_PostgresConsumer]:
return _PostgresConsumer
@cached_property
@@ -71,7 +71,7 @@ class DbConnectionMiddleware(Middleware):
class TaskStateBeforeMiddleware(Middleware):
def before_process_message(self, broker: "PostgresBroker", message: Message[Any]) -> None:
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
broker.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
@@ -82,7 +82,7 @@ class TaskStateBeforeMiddleware(Middleware):
class TaskStateAfterMiddleware(Middleware):
def before_process_message(self, broker: "PostgresBroker", message: Message[Any]) -> None:
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
broker.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
@@ -91,7 +91,7 @@ class TaskStateAfterMiddleware(Middleware):
state=TaskState.RUNNING,
)
def after_skip_message(self, broker: "PostgresBroker", message: Message[Any]) -> None:
def after_skip_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
broker.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
@@ -102,7 +102,7 @@ class TaskStateAfterMiddleware(Middleware):
def after_process_message(
self,
broker: "PostgresBroker",
broker: PostgresBroker,
message: Message[Any],
*,
result: Any | None = None,
+1 -1
View File
@@ -185,7 +185,7 @@ exclude = 'node_modules'
[tool.ruff]
line-length = 100
target-version = "py313"
target-version = "py314"
exclude = ["**/migrations/**", "**/node_modules/**"]
[tool.ruff.lint]