mirror of
https://github.com/goauthentik/authentik.git
synced 2026-06-17 19:09:11 +03:00
root: upgrade ruff lint for 3.14 (#19461)
* root: upgrade ruff lint for 3.14 Signed-off-by: Jens Langhammer <jens@goauthentik.io> * redo makefile Signed-off-by: Jens Langhammer <jens@goauthentik.io> * format Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
@@ -24,6 +24,8 @@ BREW_LDFLAGS :=
|
||||
BREW_CPPFLAGS :=
|
||||
BREW_PKG_CONFIG_PATH :=
|
||||
|
||||
UV := uv
|
||||
|
||||
# For macOS users, add the libxml2 installed from brew libxmlsec1 to the build path
|
||||
# to prevent SAML-related tests from failing and ensure correct pip dependency compilation
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
@@ -33,22 +35,21 @@ ifeq ($(UNAME_S),Darwin)
|
||||
LIBXML2_EXISTS := $(shell brew list libxml2 2> /dev/null)
|
||||
ifdef LIBXML2_EXISTS
|
||||
_xml_pref := $(shell brew --prefix libxml2)
|
||||
BREW_LDFLAGS += -L${_xml_pref}/lib $(LDFLAGS)
|
||||
BREW_CPPFLAGS += -I${_xml_pref}/include $(CPPFLAGS)
|
||||
BREW_PKG_CONFIG_PATH += ${_xml_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH)
|
||||
BREW_LDFLAGS += -L${_xml_pref}/lib
|
||||
BREW_CPPFLAGS += -I${_xml_pref}/include
|
||||
BREW_PKG_CONFIG_PATH = ${_xml_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH)
|
||||
endif
|
||||
KRB5_EXISTS := $(shell brew list krb5 2> /dev/null)
|
||||
ifdef KRB5_EXISTS
|
||||
_krb5_pref := $(shell brew --prefix krb5)
|
||||
BREW_LDFLAGS += -L${_krb5_pref}/lib $(LDFLAGS)
|
||||
BREW_CPPFLAGS += -I${_krb5_pref}/include $(CPPFLAGS)
|
||||
BREW_PKG_CONFIG_PATH += ${_krb5_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH)
|
||||
BREW_LDFLAGS += -L${_krb5_pref}/lib
|
||||
BREW_CPPFLAGS += -I${_krb5_pref}/include
|
||||
BREW_PKG_CONFIG_PATH = ${_krb5_pref}/lib/pkgconfig:$(PKG_CONFIG_PATH)
|
||||
endif
|
||||
endif
|
||||
UV := LDFLAGS="$(BREW_LDFLAGS)" CPPFLAGS="$(BREW_CPPFLAGS)" PKG_CONFIG_PATH="$(BREW_PKG_CONFIG_PATH)" uv
|
||||
endif
|
||||
|
||||
UV_ARGS := LDFLAGS="$(BREW_LDFLAGS)" CPPFLAGS="$(BREW_CPPFLAGS)" PKG_CONFIG_PATH="$(BREW_PKG_CONFIG_PATH)"
|
||||
|
||||
all: lint-fix lint gen web test ## Lint, build, and test everything
|
||||
|
||||
HELP_WIDTH := $(shell grep -h '^[a-z][^ ]*:.*\#\#' $(MAKEFILE_LIST) 2>/dev/null | \
|
||||
@@ -65,47 +66,47 @@ go-test:
|
||||
go test -timeout 0 -v -race -cover ./...
|
||||
|
||||
test: ## Run the server tests and produce a coverage report (locally)
|
||||
$(KRB_PATH) uv run coverage run manage.py test --keepdb $(or $(filter-out $@,$(MAKECMDGOALS)),authentik)
|
||||
uv run coverage html
|
||||
uv run coverage report
|
||||
$(UV) run coverage run manage.py test --keepdb $(or $(filter-out $@,$(MAKECMDGOALS)),authentik)
|
||||
$(UV) run coverage html
|
||||
$(UV) run coverage report
|
||||
|
||||
lint-fix: lint-codespell ## Lint and automatically fix errors in the python source code. Reports spelling errors.
|
||||
uv run black $(PY_SOURCES)
|
||||
uv run ruff check --fix $(PY_SOURCES)
|
||||
$(UV) run black $(PY_SOURCES)
|
||||
$(UV) run ruff check --fix $(PY_SOURCES)
|
||||
|
||||
lint-codespell: ## Reports spelling errors.
|
||||
uv run codespell -w
|
||||
$(UV) run codespell -w
|
||||
|
||||
lint: ## Lint the python and golang sources
|
||||
uv run bandit -c pyproject.toml -r $(PY_SOURCES)
|
||||
$(UV) run bandit -c pyproject.toml -r $(PY_SOURCES)
|
||||
golangci-lint run -v
|
||||
|
||||
core-install:
|
||||
ifeq ($(UNAME_S),Darwin)
|
||||
# Clear cache to ensure fresh compilation
|
||||
uv cache clean
|
||||
$(UV) cache clean
|
||||
# Force compilation from source for lxml and xmlsec with correct environment
|
||||
$(UV_ARGS) uv sync --frozen --reinstall-package lxml --reinstall-package xmlsec --no-binary-package lxml --no-binary-package xmlsec
|
||||
$(UV) sync --frozen --reinstall-package lxml --reinstall-package xmlsec --no-binary-package lxml --no-binary-package xmlsec
|
||||
else
|
||||
uv sync --frozen
|
||||
$(UV) sync --frozen
|
||||
endif
|
||||
|
||||
migrate: ## Run the Authentik Django server's migrations
|
||||
uv run python -m lifecycle.migrate
|
||||
$(UV) run python -m lifecycle.migrate
|
||||
|
||||
i18n-extract: core-i18n-extract web-i18n-extract ## Extract strings that require translation into files to send to a translation service
|
||||
|
||||
aws-cfn:
|
||||
cd lifecycle/aws && npm i && uv run npm run aws-cfn
|
||||
cd lifecycle/aws && npm i && $(UV) run npm run aws-cfn
|
||||
|
||||
run-server: ## Run the main authentik server process
|
||||
uv run ak server
|
||||
$(UV) run ak server
|
||||
|
||||
run-worker: ## Run the main authentik worker process
|
||||
uv run ak worker
|
||||
$(UV) run ak worker
|
||||
|
||||
core-i18n-extract:
|
||||
uv run ak makemessages \
|
||||
$(UV) run ak makemessages \
|
||||
--add-location file \
|
||||
--no-obsolete \
|
||||
--ignore web \
|
||||
@@ -118,17 +119,17 @@ core-i18n-extract:
|
||||
install: node-install docs-install core-install ## Install all requires dependencies for `node`, `docs` and `core`
|
||||
|
||||
dev-drop-db:
|
||||
$(eval pg_user := $(shell uv run python -m authentik.lib.config postgresql.user 2>/dev/null))
|
||||
$(eval pg_host := $(shell uv run python -m authentik.lib.config postgresql.host 2>/dev/null))
|
||||
$(eval pg_name := $(shell uv run python -m authentik.lib.config postgresql.name 2>/dev/null))
|
||||
$(eval pg_user := $(shell $(UV) run python -m authentik.lib.config postgresql.user 2>/dev/null))
|
||||
$(eval pg_host := $(shell $(UV) run python -m authentik.lib.config postgresql.host 2>/dev/null))
|
||||
$(eval pg_name := $(shell $(UV) run python -m authentik.lib.config postgresql.name 2>/dev/null))
|
||||
dropdb -U ${pg_user} -h ${pg_host} ${pg_name} || true
|
||||
# Also remove the test-db if it exists
|
||||
dropdb -U ${pg_user} -h ${pg_host} test_${pg_name} || true
|
||||
|
||||
dev-create-db:
|
||||
$(eval pg_user := $(shell uv run python -m authentik.lib.config postgresql.user 2>/dev/null))
|
||||
$(eval pg_host := $(shell uv run python -m authentik.lib.config postgresql.host 2>/dev/null))
|
||||
$(eval pg_name := $(shell uv run python -m authentik.lib.config postgresql.name 2>/dev/null))
|
||||
$(eval pg_user := $(shell $(UV) run python -m authentik.lib.config postgresql.user 2>/dev/null))
|
||||
$(eval pg_host := $(shell $(UV) run python -m authentik.lib.config postgresql.host 2>/dev/null))
|
||||
$(eval pg_name := $(shell $(UV) run python -m authentik.lib.config postgresql.name 2>/dev/null))
|
||||
createdb -U ${pg_user} -h ${pg_host} ${pg_name}
|
||||
|
||||
dev-reset: dev-drop-db dev-create-db migrate ## Drop and restore the Authentik PostgreSQL instance to a "fresh install" state.
|
||||
@@ -156,10 +157,10 @@ gen-build: ## Extract the schema from the database
|
||||
AUTHENTIK_DEBUG=true \
|
||||
AUTHENTIK_TENANTS__ENABLED=true \
|
||||
AUTHENTIK_OUTPOSTS__DISABLE_EMBEDDED_OUTPOST=true \
|
||||
uv run ak build_schema
|
||||
$(UV) run ak build_schema
|
||||
|
||||
gen-compose:
|
||||
uv run scripts/generate_compose.py
|
||||
$(UV) run scripts/generate_compose.py
|
||||
|
||||
gen-changelog: ## (Release) generate the changelog based from the commits since the last tag
|
||||
git log --pretty=format:" - %s" $(shell git describe --tags $(shell git rev-list --tags --max-count=1))...$(shell git branch --show-current) | sort > changelog.md
|
||||
@@ -227,7 +228,7 @@ endif
|
||||
go mod edit -replace goauthentik.io/api/v3=./${GEN_API_GO}
|
||||
|
||||
gen-dev-config: ## Generate a local development config file
|
||||
uv run scripts/generate_config.py
|
||||
$(UV) run scripts/generate_config.py
|
||||
|
||||
gen: gen-build gen-client-ts
|
||||
|
||||
@@ -327,24 +328,24 @@ ci--meta-debug:
|
||||
node --version
|
||||
|
||||
ci-mypy: ci--meta-debug
|
||||
uv run mypy --strict $(PY_SOURCES)
|
||||
$(UV) run mypy --strict $(PY_SOURCES)
|
||||
|
||||
ci-black: ci--meta-debug
|
||||
uv run black --check $(PY_SOURCES)
|
||||
$(UV) run black --check $(PY_SOURCES)
|
||||
|
||||
ci-ruff: ci--meta-debug
|
||||
uv run ruff check $(PY_SOURCES)
|
||||
$(UV) run ruff check $(PY_SOURCES)
|
||||
|
||||
ci-codespell: ci--meta-debug
|
||||
uv run codespell -s
|
||||
$(UV) run codespell -s
|
||||
|
||||
ci-bandit: ci--meta-debug
|
||||
uv run bandit -r $(PY_SOURCES)
|
||||
$(UV) run bandit -r $(PY_SOURCES)
|
||||
|
||||
ci-pending-migrations: ci--meta-debug
|
||||
uv run ak makemigrations --check
|
||||
$(UV) run ak makemigrations --check
|
||||
|
||||
ci-test: ci--meta-debug
|
||||
uv run coverage run manage.py test --keepdb authentik
|
||||
uv run coverage report
|
||||
uv run coverage xml
|
||||
$(UV) run coverage run manage.py test --keepdb authentik
|
||||
$(UV) run coverage report
|
||||
$(UV) run coverage xml
|
||||
|
||||
@@ -9,7 +9,7 @@ from functools import reduce
|
||||
from json import JSONDecodeError, loads
|
||||
from operator import ixor
|
||||
from os import getenv
|
||||
from typing import Any, Literal, Union
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from deepmerge import always_merger
|
||||
@@ -70,19 +70,17 @@ class BlueprintEntryDesiredState(Enum):
|
||||
class BlueprintEntryPermission:
|
||||
"""Describe object-level permissions"""
|
||||
|
||||
permission: Union[str, "YAMLTag"]
|
||||
user: Union[int, "YAMLTag", None] = field(default=None)
|
||||
role: Union[str, "YAMLTag", None] = field(default=None)
|
||||
permission: str | YAMLTag
|
||||
user: int | YAMLTag | None = field(default=None)
|
||||
role: str | YAMLTag | None = field(default=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlueprintEntry:
|
||||
"""Single entry of a blueprint"""
|
||||
|
||||
model: Union[str, "YAMLTag"]
|
||||
state: Union[BlueprintEntryDesiredState, "YAMLTag"] = field(
|
||||
default=BlueprintEntryDesiredState.PRESENT
|
||||
)
|
||||
model: str | YAMLTag
|
||||
state: BlueprintEntryDesiredState | YAMLTag = field(default=BlueprintEntryDesiredState.PRESENT)
|
||||
conditions: list[Any] = field(default_factory=list)
|
||||
identifiers: dict[str, Any] = field(default_factory=dict)
|
||||
attrs: dict[str, Any] | None = field(default_factory=dict)
|
||||
@@ -96,7 +94,7 @@ class BlueprintEntry:
|
||||
self.__tag_contexts: list[YAMLTagContext] = []
|
||||
|
||||
@staticmethod
|
||||
def from_model(model: SerializerModel, *extra_identifier_names: str) -> "BlueprintEntry":
|
||||
def from_model(model: SerializerModel, *extra_identifier_names: str) -> BlueprintEntry:
|
||||
"""Convert a SerializerModel instance to a blueprint Entry"""
|
||||
identifiers = {
|
||||
"pk": model.pk,
|
||||
@@ -114,8 +112,8 @@ class BlueprintEntry:
|
||||
def get_tag_context(
|
||||
self,
|
||||
depth: int = 0,
|
||||
context_tag_type: type["YAMLTagContext"] | tuple["YAMLTagContext", ...] | None = None,
|
||||
) -> "YAMLTagContext":
|
||||
context_tag_type: type[YAMLTagContext] | tuple[YAMLTagContext, ...] | None = None,
|
||||
) -> YAMLTagContext:
|
||||
"""Get a YAMLTagContext object located at a certain depth in the tag tree"""
|
||||
if depth < 0:
|
||||
raise ValueError("depth must be a positive number or zero")
|
||||
@@ -130,7 +128,7 @@ class BlueprintEntry:
|
||||
except IndexError as exc:
|
||||
raise ValueError(f"invalid depth: {depth}. Max depth: {len(contexts) - 1}") from exc
|
||||
|
||||
def tag_resolver(self, value: Any, blueprint: "Blueprint") -> Any:
|
||||
def tag_resolver(self, value: Any, blueprint: Blueprint) -> Any:
|
||||
"""Check if we have any special tags that need handling"""
|
||||
val = copy(value)
|
||||
|
||||
@@ -152,23 +150,23 @@ class BlueprintEntry:
|
||||
|
||||
return val
|
||||
|
||||
def get_attrs(self, blueprint: "Blueprint") -> dict[str, Any]:
|
||||
def get_attrs(self, blueprint: Blueprint) -> dict[str, Any]:
|
||||
"""Get attributes of this entry, with all yaml tags resolved"""
|
||||
return self.tag_resolver(self.attrs, blueprint)
|
||||
|
||||
def get_identifiers(self, blueprint: "Blueprint") -> dict[str, Any]:
|
||||
def get_identifiers(self, blueprint: Blueprint) -> dict[str, Any]:
|
||||
"""Get attributes of this entry, with all yaml tags resolved"""
|
||||
return self.tag_resolver(self.identifiers, blueprint)
|
||||
|
||||
def get_state(self, blueprint: "Blueprint") -> BlueprintEntryDesiredState:
|
||||
def get_state(self, blueprint: Blueprint) -> BlueprintEntryDesiredState:
|
||||
"""Get the blueprint state, with yaml tags resolved if present"""
|
||||
return BlueprintEntryDesiredState(self.tag_resolver(self.state, blueprint))
|
||||
|
||||
def get_model(self, blueprint: "Blueprint") -> str:
|
||||
def get_model(self, blueprint: Blueprint) -> str:
|
||||
"""Get the blueprint model, with yaml tags resolved if present"""
|
||||
return str(self.tag_resolver(self.model, blueprint))
|
||||
|
||||
def get_permissions(self, blueprint: "Blueprint") -> Generator[BlueprintEntryPermission]:
|
||||
def get_permissions(self, blueprint: Blueprint) -> Generator[BlueprintEntryPermission]:
|
||||
"""Get permissions of this entry, with all yaml tags resolved"""
|
||||
for perm in self.permissions:
|
||||
yield BlueprintEntryPermission(
|
||||
@@ -177,7 +175,7 @@ class BlueprintEntry:
|
||||
role=self.tag_resolver(perm.role, blueprint),
|
||||
)
|
||||
|
||||
def check_all_conditions_match(self, blueprint: "Blueprint") -> bool:
|
||||
def check_all_conditions_match(self, blueprint: Blueprint) -> bool:
|
||||
"""Check all conditions of this entry match (evaluate to True)"""
|
||||
return all(self.tag_resolver(self.conditions, blueprint))
|
||||
|
||||
@@ -232,7 +230,7 @@ class KeyOf(YAMLTag):
|
||||
|
||||
id_from: str
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: ScalarNode) -> None:
|
||||
super().__init__()
|
||||
self.id_from = node.value
|
||||
|
||||
@@ -258,7 +256,7 @@ class Env(YAMLTag):
|
||||
key: str
|
||||
default: Any | None
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: ScalarNode | SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.default = None
|
||||
if isinstance(node, ScalarNode):
|
||||
@@ -277,7 +275,7 @@ class File(YAMLTag):
|
||||
path: str
|
||||
default: Any | None
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: ScalarNode | SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.default = None
|
||||
if isinstance(node, ScalarNode):
|
||||
@@ -305,7 +303,7 @@ class Context(YAMLTag):
|
||||
key: str
|
||||
default: Any | None
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: ScalarNode | SequenceNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: ScalarNode | SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.default = None
|
||||
if isinstance(node, ScalarNode):
|
||||
@@ -328,7 +326,7 @@ class ParseJSON(YAMLTag):
|
||||
|
||||
raw: str
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: ScalarNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: ScalarNode) -> None:
|
||||
super().__init__()
|
||||
self.raw = node.value
|
||||
|
||||
@@ -345,7 +343,7 @@ class Format(YAMLTag):
|
||||
format_string: str
|
||||
args: list[Any]
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.format_string = loader.construct_object(node.value[0])
|
||||
self.args = []
|
||||
@@ -372,7 +370,7 @@ class Find(YAMLTag):
|
||||
model_name: str | YAMLTag
|
||||
conditions: list[list]
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.model_name = loader.construct_object(node.value[0])
|
||||
self.conditions = []
|
||||
@@ -444,7 +442,7 @@ class Condition(YAMLTag):
|
||||
"XNOR": lambda args: not (reduce(ixor, args) if len(args) > 1 else args[0]),
|
||||
}
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.mode = loader.construct_object(node.value[0])
|
||||
self.args = []
|
||||
@@ -478,7 +476,7 @@ class If(YAMLTag):
|
||||
when_true: Any
|
||||
when_false: Any
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.condition = loader.construct_object(node.value[0])
|
||||
if len(node.value) == 1:
|
||||
@@ -518,7 +516,7 @@ class Enumerate(YAMLTag, YAMLTagContext):
|
||||
),
|
||||
}
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.iterable = loader.construct_object(node.value[0])
|
||||
self.output_body = loader.construct_object(node.value[1])
|
||||
@@ -584,7 +582,7 @@ class EnumeratedItem(YAMLTag):
|
||||
|
||||
_SUPPORTED_CONTEXT_TAGS = (Enumerate,)
|
||||
|
||||
def __init__(self, _loader: "BlueprintLoader", node: ScalarNode) -> None:
|
||||
def __init__(self, _loader: BlueprintLoader, node: ScalarNode) -> None:
|
||||
super().__init__()
|
||||
self.depth = int(node.value)
|
||||
|
||||
@@ -640,7 +638,7 @@ class AtIndex(YAMLTag):
|
||||
attribute: int | str | YAMLTag
|
||||
default: Any | UNSET
|
||||
|
||||
def __init__(self, loader: "BlueprintLoader", node: SequenceNode) -> None:
|
||||
def __init__(self, loader: BlueprintLoader, node: SequenceNode) -> None:
|
||||
super().__init__()
|
||||
self.obj = loader.construct_object(node.value[0])
|
||||
self.attribute = loader.construct_object(node.value[1])
|
||||
@@ -757,7 +755,7 @@ class EntryInvalidError(SentryIgnoredException):
|
||||
@staticmethod
|
||||
def from_entry(
|
||||
msg_or_exc: str | Exception, entry: BlueprintEntry, *args, **kwargs
|
||||
) -> "EntryInvalidError":
|
||||
) -> EntryInvalidError:
|
||||
"""Create EntryInvalidError with the context of an entry"""
|
||||
error = EntryInvalidError(msg_or_exc, *args, **kwargs)
|
||||
if isinstance(msg_or_exc, ValidationError):
|
||||
|
||||
@@ -147,7 +147,7 @@ class Importer:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def from_string(yaml_input: str, context: dict | None = None) -> "Importer":
|
||||
def from_string(yaml_input: str, context: dict | None = None) -> Importer:
|
||||
"""Parse YAML string and create blueprint importer from it"""
|
||||
import_dict = load(yaml_input, BlueprintLoader)
|
||||
try:
|
||||
|
||||
@@ -23,7 +23,7 @@ class ApplyBlueprintMetaSerializer(PassiveSerializer):
|
||||
|
||||
# We cannot override `instance` as that will confuse rest_framework
|
||||
# and make it attempt to update the instance
|
||||
blueprint_instance: "BlueprintInstance"
|
||||
blueprint_instance: BlueprintInstance
|
||||
|
||||
def validate(self, attrs):
|
||||
from authentik.blueprints.models import BlueprintInstance
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from hashlib import sha256
|
||||
from typing import Any, Optional, Self
|
||||
from typing import Any, Self
|
||||
from uuid import uuid4
|
||||
|
||||
import pgtrigger
|
||||
@@ -225,7 +225,7 @@ class Group(SerializerModel, AttributesMixin):
|
||||
# in the LDAP Outpost we use the last 5 chars so match here
|
||||
return int(str(self.pk.int)[:5])
|
||||
|
||||
def is_member(self, user: "User") -> bool:
|
||||
def is_member(self, user: User) -> bool:
|
||||
"""Recursively check if `user` is member of us, or any parent."""
|
||||
return user.all_groups().filter(group_uuid=self.group_uuid).exists()
|
||||
|
||||
@@ -466,7 +466,7 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
|
||||
always_merger.merge(final_attributes, self.attributes)
|
||||
return final_attributes
|
||||
|
||||
def app_entitlements(self, app: "Application | None") -> QuerySet["ApplicationEntitlement"]:
|
||||
def app_entitlements(self, app: Application | None) -> QuerySet[ApplicationEntitlement]:
|
||||
"""Get all entitlements this user has for `app`."""
|
||||
if not app:
|
||||
return []
|
||||
@@ -485,7 +485,7 @@ class User(SerializerModel, AttributesMixin, AbstractUser):
|
||||
).order_by("name")
|
||||
return qs
|
||||
|
||||
def app_entitlements_attributes(self, app: "Application | None") -> dict:
|
||||
def app_entitlements_attributes(self, app: Application | None) -> dict:
|
||||
"""Get a dictionary containing all merged attributes from app entitlements for `app`."""
|
||||
final_attributes = {}
|
||||
for attrs in self.app_entitlements(app).values_list("attributes", flat=True):
|
||||
@@ -654,7 +654,7 @@ class BackchannelProvider(Provider):
|
||||
|
||||
|
||||
class ApplicationQuerySet(QuerySet):
|
||||
def with_provider(self) -> "QuerySet[Application]":
|
||||
def with_provider(self) -> QuerySet[Application]:
|
||||
qs = self.select_related("provider")
|
||||
for subclass in Provider.objects.get_queryset()._get_subclasses_recurse(Provider):
|
||||
qs = qs.select_related(f"provider__{subclass}")
|
||||
@@ -713,9 +713,7 @@ class Application(SerializerModel, PolicyBindingModel):
|
||||
|
||||
return get_file_manager(FileUsage.MEDIA).file_url(self.meta_icon)
|
||||
|
||||
def get_launch_url(
|
||||
self, user: Optional["User"] = None, user_data: dict | None = None
|
||||
) -> str | None:
|
||||
def get_launch_url(self, user: User | None = None, user_data: dict | None = None) -> str | None:
|
||||
"""Get launch URL if set, otherwise attempt to get launch URL based on provider.
|
||||
|
||||
Args:
|
||||
@@ -948,7 +946,7 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def property_mapping_type(self) -> "type[PropertyMapping]":
|
||||
def property_mapping_type(self) -> type[PropertyMapping]:
|
||||
"""Return property mapping type used by this object"""
|
||||
if self.managed == self.MANAGED_INBUILT:
|
||||
from authentik.core.models import PropertyMapping
|
||||
@@ -1069,7 +1067,7 @@ class ExpiringModel(models.Model):
|
||||
return self.delete(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def filter_not_expired(cls, **kwargs) -> QuerySet["Self"]:
|
||||
def filter_not_expired(cls, **kwargs) -> QuerySet[Self]:
|
||||
"""Filer for tokens which are not expired yet or are not expiring,
|
||||
and match filters in `kwargs`"""
|
||||
for obj in cls.objects.filter(**kwargs).filter(Q(expires__lt=now(), expiring=True)):
|
||||
@@ -1265,7 +1263,7 @@ class AuthenticatedSession(SerializerModel):
|
||||
return f"Authenticated Session {str(self.pk)[:10]}"
|
||||
|
||||
@staticmethod
|
||||
def from_request(request: HttpRequest, user: User) -> Optional["AuthenticatedSession"]:
|
||||
def from_request(request: HttpRequest, user: User) -> AuthenticatedSession | None:
|
||||
"""Create a new session from a http request"""
|
||||
if not hasattr(request, "session") or not request.session.exists(
|
||||
request.session.session_key
|
||||
|
||||
@@ -63,7 +63,7 @@ def user_logged_in_session(sender, request: HttpRequest, user: User, **_):
|
||||
|
||||
|
||||
@receiver(post_delete, sender=AuthenticatedSession)
|
||||
def authenticated_session_delete(sender: type[Model], instance: "AuthenticatedSession", **_):
|
||||
def authenticated_session_delete(sender: type[Model], instance: AuthenticatedSession, **_):
|
||||
"""Delete session when authenticated session is deleted"""
|
||||
Session.objects.filter(session_key=instance.pk).delete()
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ class SourceMapper:
|
||||
def build_object_properties(
|
||||
self,
|
||||
object_type: type[User | Group],
|
||||
manager: "PropertyMappingManager | None" = None,
|
||||
manager: PropertyMappingManager | None = None,
|
||||
user: User | None = None,
|
||||
request: HttpRequest | None = None,
|
||||
**kwargs,
|
||||
|
||||
@@ -62,7 +62,7 @@ class AgentConfigSerializer(PassiveSerializer):
|
||||
def get_system_config(self, instance: AgentConnector) -> ConfigSerializer:
|
||||
return ConfigView.get_config(self.context["request"]).data
|
||||
|
||||
def get_license_status(self, instance: AgentConnector) -> "LicenseUsageStatus":
|
||||
def get_license_status(self, instance: AgentConnector) -> LicenseUsageStatus:
|
||||
try:
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class AgentConnector(Connector):
|
||||
return AuthenticatorEndpointStageView
|
||||
|
||||
@property
|
||||
def controller(self) -> type["AgentConnectorController"]:
|
||||
def controller(self) -> type[AgentConnectorController]:
|
||||
from authentik.endpoints.connectors.agent.controller import AgentConnectorController
|
||||
|
||||
return AgentConnectorController
|
||||
|
||||
@@ -43,7 +43,7 @@ class Device(InternallyManagedMixin, ExpiringModel, AttributesMixin, PolicyBindi
|
||||
return f"goauthentik.io/endpoints/devices/{self.device_uuid}/facts"
|
||||
|
||||
@property
|
||||
def cached_facts(self) -> "DeviceFactSnapshot":
|
||||
def cached_facts(self) -> DeviceFactSnapshot:
|
||||
if cached := cache.get(self.cache_key_facts):
|
||||
return cached
|
||||
facts = self.facts
|
||||
@@ -51,7 +51,7 @@ class Device(InternallyManagedMixin, ExpiringModel, AttributesMixin, PolicyBindi
|
||||
return facts
|
||||
|
||||
@property
|
||||
def facts(self) -> "DeviceFactSnapshot":
|
||||
def facts(self) -> DeviceFactSnapshot:
|
||||
data = {}
|
||||
last_updated = datetime.fromtimestamp(0, UTC)
|
||||
for snapshot_data, snapshort_created in DeviceFactSnapshot.filter_not_expired(
|
||||
@@ -157,7 +157,7 @@ class Connector(ScheduledModel, SerializerModel):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def controller(self) -> type["BaseController[Connector]"]:
|
||||
def controller(self) -> type[BaseController[Connector]]:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@@ -205,7 +205,7 @@ class EndpointStage(Stage):
|
||||
mode = models.TextField(choices=StageMode.choices, default=StageMode.OPTIONAL)
|
||||
|
||||
@property
|
||||
def view(self) -> type["StageView"]:
|
||||
def view(self) -> type[StageView]:
|
||||
from authentik.endpoints.stage import EndpointStageView
|
||||
|
||||
return EndpointStageView
|
||||
|
||||
@@ -93,7 +93,7 @@ class LicenseKey:
|
||||
license_flags: list[LicenseFlags] = field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def validate(jwt: str, check_expiry=True) -> "LicenseKey":
|
||||
def validate(jwt: str, check_expiry=True) -> LicenseKey:
|
||||
"""Validate the license from a given JWT"""
|
||||
try:
|
||||
headers = get_unverified_header(jwt)
|
||||
@@ -128,7 +128,7 @@ class LicenseKey:
|
||||
return body
|
||||
|
||||
@staticmethod
|
||||
def get_total() -> "LicenseKey":
|
||||
def get_total() -> LicenseKey:
|
||||
"""Get a summarized version of all (not expired) licenses"""
|
||||
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
|
||||
for lic in License.objects.all():
|
||||
|
||||
@@ -50,7 +50,7 @@ class License(SerializerModel):
|
||||
return LicenseSerializer
|
||||
|
||||
@property
|
||||
def status(self) -> "LicenseKey":
|
||||
def status(self) -> LicenseKey:
|
||||
"""Get parsed license status"""
|
||||
from authentik.enterprise.license import LicenseKey
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class SCIMOAuthException(SCIMRequestException):
|
||||
|
||||
class SCIMOAuthAuth:
|
||||
|
||||
def __init__(self, provider: "SCIMProvider"):
|
||||
def __init__(self, provider: SCIMProvider):
|
||||
self.provider = provider
|
||||
self.user = provider.auth_oauth_user
|
||||
self.logger = get_logger().bind()
|
||||
|
||||
@@ -17,9 +17,9 @@ if TYPE_CHECKING:
|
||||
class SSFTokenAuth(BaseAuthentication):
|
||||
"""SSF Token auth"""
|
||||
|
||||
view: "SSFView"
|
||||
view: SSFView
|
||||
|
||||
def __init__(self, view: "SSFView") -> None:
|
||||
def __init__(self, view: SSFView) -> None:
|
||||
super().__init__()
|
||||
self.view = view
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""ASN Enricher"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
from django.http import HttpRequest
|
||||
from geoip2.errors import GeoIP2Error
|
||||
@@ -27,7 +27,7 @@ class ASNDict(TypedDict):
|
||||
class ASNContextProcessor(MMDBContextProcessor):
|
||||
"""ASN Database reader wrapper"""
|
||||
|
||||
def capability(self) -> Optional["Capabilities"]:
|
||||
def capability(self) -> Capabilities | None:
|
||||
from authentik.api.v3.config import Capabilities
|
||||
|
||||
return Capabilities.CAN_ASN
|
||||
@@ -35,7 +35,7 @@ class ASNContextProcessor(MMDBContextProcessor):
|
||||
def path(self) -> str | None:
|
||||
return CONFIG.get("events.context_processors.asn")
|
||||
|
||||
def enrich_event(self, event: "Event"):
|
||||
def enrich_event(self, event: Event):
|
||||
asn = self.asn_dict(event.client_ip)
|
||||
if not asn:
|
||||
return
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Base event enricher"""
|
||||
|
||||
from functools import cache
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.http import HttpRequest
|
||||
|
||||
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
|
||||
class EventContextProcessor:
|
||||
"""Base event enricher"""
|
||||
|
||||
def capability(self) -> Optional["Capabilities"]:
|
||||
def capability(self) -> Capabilities | None:
|
||||
"""Return the capability this context processor provides"""
|
||||
return None
|
||||
|
||||
@@ -21,7 +21,7 @@ class EventContextProcessor:
|
||||
"""Return true if this context processor is configured"""
|
||||
return False
|
||||
|
||||
def enrich_event(self, event: "Event"):
|
||||
def enrich_event(self, event: Event):
|
||||
"""Modify event"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""events GeoIP Reader"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
from django.http import HttpRequest
|
||||
from geoip2.errors import GeoIP2Error
|
||||
@@ -29,7 +29,7 @@ class GeoIPDict(TypedDict):
|
||||
class GeoIPContextProcessor(MMDBContextProcessor):
|
||||
"""Slim wrapper around GeoIP API"""
|
||||
|
||||
def capability(self) -> Optional["Capabilities"]:
|
||||
def capability(self) -> Capabilities | None:
|
||||
from authentik.api.v3.config import Capabilities
|
||||
|
||||
return Capabilities.CAN_GEO_IP
|
||||
@@ -37,7 +37,7 @@ class GeoIPContextProcessor(MMDBContextProcessor):
|
||||
def path(self) -> str | None:
|
||||
return CONFIG.get("events.context_processors.geoip")
|
||||
|
||||
def enrich_event(self, event: "Event"):
|
||||
def enrich_event(self, event: Event):
|
||||
city = self.city_dict(event.client_ip)
|
||||
if not city:
|
||||
return
|
||||
|
||||
@@ -25,7 +25,7 @@ class LogEvent:
|
||||
attributes: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@staticmethod
|
||||
def from_event_dict(item: EventDict) -> "LogEvent":
|
||||
def from_event_dict(item: EventDict) -> LogEvent:
|
||||
event = item.pop("event")
|
||||
log_level = item.pop("level").lower()
|
||||
timestamp = datetime.fromisoformat(item.pop("timestamp")).replace(tzinfo=UTC)
|
||||
|
||||
@@ -151,7 +151,7 @@ class Event(SerializerModel, ExpiringModel):
|
||||
action: str | EventAction,
|
||||
app: str | None = None,
|
||||
**kwargs,
|
||||
) -> "Event":
|
||||
) -> Event:
|
||||
"""Create new Event instance from arguments. Instance is NOT saved."""
|
||||
if not isinstance(action, EventAction):
|
||||
action = EventAction.CUSTOM_PREFIX + action
|
||||
@@ -169,19 +169,19 @@ class Event(SerializerModel, ExpiringModel):
|
||||
event = Event(action=action, app=app, context=cleaned_kwargs)
|
||||
return event
|
||||
|
||||
def with_exception(self, exc: Exception) -> "Event":
|
||||
def with_exception(self, exc: Exception) -> Event:
|
||||
"""Add data from 'exc' to the event in a database-saveable format"""
|
||||
self.context.setdefault("message", str(exc))
|
||||
self.context["exception"] = exception_to_dict(exc)
|
||||
return self
|
||||
|
||||
def set_user(self, user: User) -> "Event":
|
||||
def set_user(self, user: User) -> Event:
|
||||
"""Set `.user` based on user, ensuring the correct attributes are copied.
|
||||
This should only be used when self.from_http is *not* used."""
|
||||
self.user = get_user(user)
|
||||
return self
|
||||
|
||||
def from_http(self, request: HttpRequest, user: User | None = None) -> "Event":
|
||||
def from_http(self, request: HttpRequest, user: User | None = None) -> Event:
|
||||
"""Add data from a Django-HttpRequest, allowing the creation of
|
||||
Events independently from requests.
|
||||
`user` arguments optionally overrides user from requests."""
|
||||
@@ -343,7 +343,7 @@ class NotificationTransport(TasksModel, SerializerModel):
|
||||
),
|
||||
)
|
||||
|
||||
def send(self, notification: "Notification") -> list[str]:
|
||||
def send(self, notification: Notification) -> list[str]:
|
||||
"""Send notification to user, called from async task"""
|
||||
if self.mode == TransportMode.LOCAL:
|
||||
return self.send_local(notification)
|
||||
@@ -355,7 +355,7 @@ class NotificationTransport(TasksModel, SerializerModel):
|
||||
return self.send_email(notification)
|
||||
raise ValueError(f"Invalid mode {self.mode} set")
|
||||
|
||||
def send_local(self, notification: "Notification") -> list[str]:
|
||||
def send_local(self, notification: Notification) -> list[str]:
|
||||
"""Local notification delivery"""
|
||||
if self.webhook_mapping_body:
|
||||
self.webhook_mapping_body.evaluate(
|
||||
@@ -375,7 +375,7 @@ class NotificationTransport(TasksModel, SerializerModel):
|
||||
)
|
||||
return []
|
||||
|
||||
def send_webhook(self, notification: "Notification") -> list[str]:
|
||||
def send_webhook(self, notification: Notification) -> list[str]:
|
||||
"""Send notification to generic webhook"""
|
||||
default_body = {
|
||||
"body": notification.body,
|
||||
@@ -419,7 +419,7 @@ class NotificationTransport(TasksModel, SerializerModel):
|
||||
response.text,
|
||||
]
|
||||
|
||||
def send_webhook_slack(self, notification: "Notification") -> list[str]:
|
||||
def send_webhook_slack(self, notification: Notification) -> list[str]:
|
||||
"""Send notification to slack or slack-compatible endpoints"""
|
||||
fields = [
|
||||
{
|
||||
@@ -477,7 +477,7 @@ class NotificationTransport(TasksModel, SerializerModel):
|
||||
response.text,
|
||||
]
|
||||
|
||||
def send_email(self, notification: "Notification") -> list[str]:
|
||||
def send_email(self, notification: Notification) -> list[str]:
|
||||
"""Send notification via global email configuration"""
|
||||
from authentik.stages.email.tasks import send_mail
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class DiagramElement:
|
||||
identifier: str
|
||||
description: str
|
||||
action: str | None = None
|
||||
source: list["DiagramElement"] | None = None
|
||||
source: list[DiagramElement] | None = None
|
||||
|
||||
style: list[str] = field(default_factory=lambda: ["[", "]"])
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Optional, TypedDict
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
from django.core.serializers.json import DjangoJSONEncoder
|
||||
@@ -137,7 +137,7 @@ class PermissionDict(TypedDict):
|
||||
class ChallengeResponse(PassiveSerializer):
|
||||
"""Base class for all challenge responses"""
|
||||
|
||||
stage: Optional["StageView"]
|
||||
stage: StageView | None
|
||||
component = CharField(default="xak-flow-response-default")
|
||||
|
||||
def __init__(self, instance=None, data=None, **kwargs):
|
||||
|
||||
@@ -23,7 +23,7 @@ class StageMarker:
|
||||
|
||||
def process(
|
||||
self,
|
||||
plan: "FlowPlan",
|
||||
plan: FlowPlan,
|
||||
binding: FlowStageBinding,
|
||||
http_request: HttpRequest,
|
||||
) -> FlowStageBinding | None:
|
||||
@@ -40,7 +40,7 @@ class ReevaluateMarker(StageMarker):
|
||||
|
||||
def process(
|
||||
self,
|
||||
plan: "FlowPlan",
|
||||
plan: FlowPlan,
|
||||
binding: FlowStageBinding,
|
||||
http_request: HttpRequest,
|
||||
) -> FlowStageBinding | None:
|
||||
|
||||
@@ -90,7 +90,7 @@ class Stage(SerializerModel):
|
||||
objects = InheritanceManager()
|
||||
|
||||
@property
|
||||
def view(self) -> type["StageView"]:
|
||||
def view(self) -> type[StageView]:
|
||||
"""Return StageView class that implements logic for this stage"""
|
||||
# This is a bit of a workaround, since we can't set class methods with setattr
|
||||
if hasattr(self, "__in_memory_type"):
|
||||
@@ -117,7 +117,7 @@ class Stage(SerializerModel):
|
||||
return f"Stage {self.name}"
|
||||
|
||||
|
||||
def in_memory_stage(view: type["StageView"], **kwargs) -> Stage:
|
||||
def in_memory_stage(view: type[StageView], **kwargs) -> Stage:
|
||||
"""Creates an in-memory stage instance, based on a `view` as view.
|
||||
Any key-word arguments are set as attributes on the stage object,
|
||||
accessible via `self.executor.current_stage`."""
|
||||
@@ -310,13 +310,13 @@ class FlowToken(InternallyManagedMixin, Token):
|
||||
revoke_on_execution = models.BooleanField(default=True)
|
||||
|
||||
@staticmethod
|
||||
def pickle(plan: "FlowPlan") -> str:
|
||||
def pickle(plan: FlowPlan) -> str:
|
||||
"""Pickle into string"""
|
||||
data = dumps(plan)
|
||||
return b64encode(data).decode()
|
||||
|
||||
@property
|
||||
def plan(self) -> "FlowPlan":
|
||||
def plan(self) -> FlowPlan:
|
||||
"""Load Flow plan from pickled version"""
|
||||
return loads(b64decode(self._plan.encode())) # nosec
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ class FlowPlan:
|
||||
|
||||
def requires_flow_executor(
|
||||
self,
|
||||
allowed_silent_types: list["StageView"] | None = None,
|
||||
allowed_silent_types: list[StageView] | None = None,
|
||||
):
|
||||
# Check if we actually need to show the Flow executor, or if we can jump straight to the end
|
||||
found_unskippable = True
|
||||
@@ -145,7 +145,7 @@ class FlowPlan:
|
||||
request: HttpRequest,
|
||||
flow: Flow,
|
||||
next: str | None = None,
|
||||
allowed_silent_types: list["StageView"] | None = None,
|
||||
allowed_silent_types: list[StageView] | None = None,
|
||||
) -> HttpResponse:
|
||||
"""Redirect to the flow executor for this flow plan"""
|
||||
from authentik.flows.views.executor import (
|
||||
|
||||
@@ -46,13 +46,13 @@ HIST_FLOWS_STAGE_TIME = Histogram(
|
||||
class StageView(View):
|
||||
"""Abstract Stage"""
|
||||
|
||||
executor: "FlowExecutorView"
|
||||
executor: FlowExecutorView
|
||||
|
||||
request: HttpRequest = None
|
||||
|
||||
logger: BoundLogger
|
||||
|
||||
def __init__(self, executor: "FlowExecutorView", **kwargs):
|
||||
def __init__(self, executor: FlowExecutorView, **kwargs):
|
||||
self.executor = executor
|
||||
current_stage = getattr(self.executor, "current_stage", None)
|
||||
self.logger = get_logger().bind(
|
||||
@@ -257,7 +257,7 @@ class AccessDeniedStage(ChallengeStageView):
|
||||
|
||||
error_message: str | None
|
||||
|
||||
def __init__(self, executor: "FlowExecutorView", error_message: str | None = None, **kwargs):
|
||||
def __init__(self, executor: FlowExecutorView, error_message: str | None = None, **kwargs):
|
||||
super().__init__(executor, **kwargs)
|
||||
self.error_message = error_message
|
||||
|
||||
|
||||
@@ -37,18 +37,18 @@ SVG_FONTS = [
|
||||
]
|
||||
|
||||
|
||||
def avatar_mode_none(user: "User", mode: str) -> str | None:
|
||||
def avatar_mode_none(user: User, mode: str) -> str | None:
|
||||
"""No avatar"""
|
||||
return DEFAULT_AVATAR
|
||||
|
||||
|
||||
def avatar_mode_attribute(user: "User", mode: str) -> str | None:
|
||||
def avatar_mode_attribute(user: User, mode: str) -> str | None:
|
||||
"""Avatars based on a user attribute"""
|
||||
avatar = get_path_from_dict(user.attributes, mode[11:], default=None)
|
||||
return avatar
|
||||
|
||||
|
||||
def avatar_mode_gravatar(user: "User", mode: str) -> str | None:
|
||||
def avatar_mode_gravatar(user: User, mode: str) -> str | None:
|
||||
"""Gravatar avatars"""
|
||||
|
||||
mail_hash = sha256(user.email.lower().encode("utf-8")).hexdigest() # nosec
|
||||
@@ -141,7 +141,7 @@ def generate_avatar_from_name(
|
||||
return etree.tostring(root_element).decode()
|
||||
|
||||
|
||||
def avatar_mode_generated(user: "User", mode: str) -> str | None:
|
||||
def avatar_mode_generated(user: User, mode: str) -> str | None:
|
||||
"""Wrapper that converts generated avatar to base64 svg"""
|
||||
# By default generate based off of user's display name
|
||||
name = user.name.strip()
|
||||
@@ -155,7 +155,7 @@ def avatar_mode_generated(user: "User", mode: str) -> str | None:
|
||||
return f"data:image/svg+xml;base64,{b64encode(svg.encode('utf-8')).decode('utf-8')}"
|
||||
|
||||
|
||||
def avatar_mode_url(user: "User", mode: str) -> str | None:
|
||||
def avatar_mode_url(user: User, mode: str) -> str | None:
|
||||
"""Format url"""
|
||||
mail_hash = md5(user.email.lower().encode("utf-8"), usedforsecurity=False).hexdigest() # nosec
|
||||
|
||||
@@ -197,7 +197,7 @@ def avatar_mode_url(user: "User", mode: str) -> str | None:
|
||||
return formatted_url
|
||||
|
||||
|
||||
def get_avatar(user: "User", request: HttpRequest | None = None) -> str:
|
||||
def get_avatar(user: User, request: HttpRequest | None = None) -> str:
|
||||
"""Get avatar with configured mode"""
|
||||
mode_map = {
|
||||
"none": avatar_mode_none,
|
||||
|
||||
@@ -238,7 +238,7 @@ class BaseEvaluator:
|
||||
address: str | list[str],
|
||||
subject: str,
|
||||
body: str | None = None,
|
||||
stage: "EmailStage | None" = None,
|
||||
stage: EmailStage | None = None,
|
||||
template: str | None = None,
|
||||
context: dict | None = None,
|
||||
) -> bool:
|
||||
|
||||
@@ -36,9 +36,9 @@ def get_version() -> str:
|
||||
class KubernetesObjectReconciler[T]:
|
||||
"""Base Kubernetes Reconciler, handles the basic logic."""
|
||||
|
||||
controller: "KubernetesController"
|
||||
controller: KubernetesController
|
||||
|
||||
def __init__(self, controller: "KubernetesController"):
|
||||
def __init__(self, controller: KubernetesController):
|
||||
self.controller = controller
|
||||
self.namespace = controller.outpost.config.kubernetes_namespace
|
||||
self.logger = get_logger().bind(type=self.__class__.__name__)
|
||||
|
||||
@@ -39,7 +39,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
|
||||
|
||||
outpost: Outpost
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
def __init__(self, controller: KubernetesController) -> None:
|
||||
super().__init__(controller)
|
||||
self.api = AppsV1Api(controller.client)
|
||||
self.outpost = self.controller.outpost
|
||||
|
||||
@@ -21,7 +21,7 @@ def b64string(source: str) -> str:
|
||||
class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
|
||||
"""Kubernetes Secret Reconciler"""
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
def __init__(self, controller: KubernetesController) -> None:
|
||||
super().__init__(controller)
|
||||
self.api = CoreV1Api(controller.client)
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@ if TYPE_CHECKING:
|
||||
class ServiceReconciler(KubernetesObjectReconciler[V1Service]):
|
||||
"""Kubernetes Service Reconciler"""
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
def __init__(self, controller: KubernetesController) -> None:
|
||||
super().__init__(controller)
|
||||
self.api = CoreV1Api(controller.client)
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ CRD_PLURAL = "servicemonitors"
|
||||
class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusServiceMonitor]):
|
||||
"""Kubernetes Prometheus ServiceMonitor Reconciler"""
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
def __init__(self, controller: KubernetesController) -> None:
|
||||
super().__init__(controller)
|
||||
self.api_ex = ApiextensionsV1Api(controller.client)
|
||||
self.api = CustomObjectsApi(controller.client)
|
||||
|
||||
@@ -304,7 +304,7 @@ class Outpost(ScheduledModel, SerializerModel, ManagedModel):
|
||||
return f"goauthentik.io/outposts/state/{self.uuid.hex}"
|
||||
|
||||
@property
|
||||
def state(self) -> list["OutpostState"]:
|
||||
def state(self) -> list[OutpostState]:
|
||||
"""Get outpost's health status"""
|
||||
return OutpostState.for_outpost(self)
|
||||
|
||||
@@ -480,7 +480,7 @@ class OutpostState:
|
||||
return parse(self.version) != OUR_VERSION
|
||||
|
||||
@staticmethod
|
||||
def for_outpost(outpost: Outpost) -> list["OutpostState"]:
|
||||
def for_outpost(outpost: Outpost) -> list[OutpostState]:
|
||||
"""Get all states for an outpost"""
|
||||
keys = cache.keys(f"{outpost.state_cache_prefix}/*")
|
||||
if not keys:
|
||||
@@ -492,7 +492,7 @@ class OutpostState:
|
||||
return states
|
||||
|
||||
@staticmethod
|
||||
def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
|
||||
def for_instance_uid(outpost: Outpost, uid: str) -> OutpostState:
|
||||
"""Get state for a single instance"""
|
||||
key = f"{outpost.state_cache_prefix}/{uid}"
|
||||
default_data = {"uid": uid}
|
||||
|
||||
@@ -140,7 +140,7 @@ class PolicyEngine:
|
||||
passing = False
|
||||
self.__static_result = PolicyResult(passing)
|
||||
|
||||
def build(self) -> "PolicyEngine":
|
||||
def build(self) -> PolicyEngine:
|
||||
"""Build wrapper which monitors performance"""
|
||||
with (
|
||||
start_span(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""authentik expression policy evaluator"""
|
||||
|
||||
from ipaddress import ip_address
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from django.http import HttpRequest
|
||||
from structlog.stdlib import get_logger
|
||||
@@ -23,7 +23,7 @@ class PolicyEvaluator(BaseEvaluator):
|
||||
|
||||
_messages: list[str]
|
||||
|
||||
policy: Optional["ExpressionPolicy"] = None
|
||||
policy: ExpressionPolicy | None = None
|
||||
|
||||
def __init__(self, policy_name: str | None = None):
|
||||
super().__init__(policy_name or "PolicyEvaluator")
|
||||
|
||||
@@ -55,7 +55,7 @@ class PolicyBindingModel(models.Model):
|
||||
class BoundPolicyQuerySet(models.QuerySet):
|
||||
"""QuerySet for filtering enabled bindings for a Policy type"""
|
||||
|
||||
def for_policy(self, policy: "Policy"):
|
||||
def for_policy(self, policy: Policy):
|
||||
return self.filter(policy__in=policy._default_manager.all()).filter(enabled=True)
|
||||
|
||||
|
||||
|
||||
@@ -73,8 +73,8 @@ class IDToken:
|
||||
|
||||
@staticmethod
|
||||
def new(
|
||||
provider: "OAuth2Provider", token: "BaseGrantModel", request: HttpRequest, **kwargs
|
||||
) -> "IDToken":
|
||||
provider: OAuth2Provider, token: BaseGrantModel, request: HttpRequest, **kwargs
|
||||
) -> IDToken:
|
||||
"""Create ID Token"""
|
||||
id_token = IDToken(provider, token, **kwargs)
|
||||
id_token.exp = int(
|
||||
@@ -147,7 +147,7 @@ class IDToken:
|
||||
id_dict.update(self.claims)
|
||||
return id_dict
|
||||
|
||||
def to_access_token(self, provider: "OAuth2Provider", token: "BaseGrantModel") -> str:
|
||||
def to_access_token(self, provider: OAuth2Provider, token: BaseGrantModel) -> str:
|
||||
"""Encode id_token for use as access token, adding fields"""
|
||||
final = self.to_dict()
|
||||
final["azp"] = provider.client_id
|
||||
@@ -155,6 +155,6 @@ class IDToken:
|
||||
final["scope"] = " ".join(token.scope)
|
||||
return provider.encode(final)
|
||||
|
||||
def to_jwt(self, provider: "OAuth2Provider") -> str:
|
||||
def to_jwt(self, provider: OAuth2Provider) -> str:
|
||||
"""Shortcut to encode id_token to jwt, signed by self.provider"""
|
||||
return provider.encode(self.to_dict())
|
||||
|
||||
@@ -514,7 +514,7 @@ class AccessToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseGr
|
||||
return f"Access Token for {self.provider_id} for user {self.user_id}"
|
||||
|
||||
@property
|
||||
def id_token(self) -> "IDToken":
|
||||
def id_token(self) -> IDToken:
|
||||
"""Load ID Token from json"""
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
|
||||
@@ -522,7 +522,7 @@ class AccessToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseGr
|
||||
return from_dict(IDToken, raw_token)
|
||||
|
||||
@id_token.setter
|
||||
def id_token(self, value: "IDToken"):
|
||||
def id_token(self, value: IDToken):
|
||||
self.token = value.to_access_token(self.provider, self)
|
||||
self._id_token = json.dumps(asdict(value))
|
||||
|
||||
@@ -567,7 +567,7 @@ class RefreshToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseG
|
||||
return f"Refresh Token for {self.provider_id} for user {self.user_id}"
|
||||
|
||||
@property
|
||||
def id_token(self) -> "IDToken":
|
||||
def id_token(self) -> IDToken:
|
||||
"""Load ID Token from json"""
|
||||
from authentik.providers.oauth2.id_token import IDToken
|
||||
|
||||
@@ -575,7 +575,7 @@ class RefreshToken(InternallyManagedMixin, SerializerModel, ExpiringModel, BaseG
|
||||
return from_dict(IDToken, raw_token)
|
||||
|
||||
@id_token.setter
|
||||
def id_token(self, value: "IDToken"):
|
||||
def id_token(self, value: IDToken):
|
||||
self._id_token = json.dumps(asdict(value))
|
||||
|
||||
@property
|
||||
|
||||
@@ -105,7 +105,7 @@ class OAuthAuthorizationParams:
|
||||
github_compat: InitVar[bool] = False
|
||||
|
||||
@staticmethod
|
||||
def from_request(request: HttpRequest, github_compat=False) -> "OAuthAuthorizationParams":
|
||||
def from_request(request: HttpRequest, github_compat=False) -> OAuthAuthorizationParams:
|
||||
"""
|
||||
Get all the params used by the Authorization Code Flow
|
||||
(and also for the Implicit and Hybrid).
|
||||
|
||||
@@ -40,7 +40,7 @@ class TokenIntrospectionParams:
|
||||
raise TokenIntrospectionError()
|
||||
|
||||
@staticmethod
|
||||
def from_request(request: HttpRequest) -> "TokenIntrospectionParams":
|
||||
def from_request(request: HttpRequest) -> TokenIntrospectionParams:
|
||||
"""Extract required Parameters from HTTP Request"""
|
||||
raw_token = request.POST.get("token")
|
||||
provider = authenticate_provider(request)
|
||||
|
||||
@@ -104,7 +104,7 @@ class TokenParams:
|
||||
provider: OAuth2Provider,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
) -> "TokenParams":
|
||||
) -> TokenParams:
|
||||
"""Parse params for request"""
|
||||
return TokenParams(
|
||||
# Init vars
|
||||
|
||||
@@ -27,7 +27,7 @@ class TokenRevocationParams:
|
||||
provider: OAuth2Provider
|
||||
|
||||
@staticmethod
|
||||
def from_request(request: HttpRequest) -> "TokenRevocationParams":
|
||||
def from_request(request: HttpRequest) -> TokenRevocationParams:
|
||||
"""Extract required Parameters from HTTP Request"""
|
||||
raw_token = request.POST.get("token")
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ class HTTPRoute:
|
||||
class HTTPRouteReconciler(KubernetesObjectReconciler):
|
||||
"""Kubernetes Gateway API HTTPRoute Reconciler"""
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
def __init__(self, controller: KubernetesController) -> None:
|
||||
super().__init__(controller)
|
||||
self.api_ex = ApiextensionsV1Api(controller.client)
|
||||
self.api = CustomObjectsApi(controller.client)
|
||||
|
||||
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
class IngressReconciler(KubernetesObjectReconciler[V1Ingress]):
|
||||
"""Kubernetes Ingress Reconciler"""
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
def __init__(self, controller: KubernetesController) -> None:
|
||||
super().__init__(controller)
|
||||
self.api = NetworkingV1Api(controller.client)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from authentik.providers.proxy.controllers.k8s.traefik_3 import (
|
||||
class TraefikMiddlewareReconciler(KubernetesObjectReconciler):
|
||||
"""Kubernetes Traefik Middleware Reconciler"""
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
def __init__(self, controller: KubernetesController) -> None:
|
||||
super().__init__(controller)
|
||||
self.reconciler = Traefik3MiddlewareReconciler(controller)
|
||||
if not self.reconciler.crd_exists():
|
||||
|
||||
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
|
||||
class Traefik2MiddlewareReconciler(Traefik3MiddlewareReconciler):
|
||||
"""Kubernetes Traefik Middleware Reconciler"""
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
def __init__(self, controller: KubernetesController) -> None:
|
||||
super().__init__(controller)
|
||||
self.crd_name = "middlewares.traefik.containo.us"
|
||||
self.crd_group = "traefik.containo.us"
|
||||
|
||||
@@ -57,7 +57,7 @@ class TraefikMiddleware:
|
||||
class Traefik3MiddlewareReconciler(KubernetesObjectReconciler[TraefikMiddleware]):
|
||||
"""Kubernetes Traefik Middleware Reconciler"""
|
||||
|
||||
def __init__(self, controller: "KubernetesController") -> None:
|
||||
def __init__(self, controller: KubernetesController) -> None:
|
||||
super().__init__(controller)
|
||||
self.api_ex = ApiextensionsV1Api(controller.client)
|
||||
self.api = CustomObjectsApi(controller.client)
|
||||
|
||||
@@ -8,7 +8,7 @@ if TYPE_CHECKING:
|
||||
|
||||
class SCIMTokenAuth:
|
||||
|
||||
def __init__(self, provider: "SCIMProvider"):
|
||||
def __init__(self, provider: SCIMProvider):
|
||||
self.provider = provider
|
||||
|
||||
def __call__(self, request: Request) -> Request:
|
||||
|
||||
@@ -132,7 +132,7 @@ class ServiceProviderConfiguration(BaseServiceProviderConfiguration):
|
||||
return self._is_fallback
|
||||
|
||||
@staticmethod
|
||||
def default() -> "ServiceProviderConfiguration":
|
||||
def default() -> ServiceProviderConfiguration:
|
||||
"""Get default configuration, which doesn't support any optional features as fallback"""
|
||||
return ServiceProviderConfiguration(
|
||||
patch=Patch(supported=False),
|
||||
|
||||
@@ -183,7 +183,7 @@ class LDAPSource(IncomingSyncSource):
|
||||
]
|
||||
|
||||
@property
|
||||
def property_mapping_type(self) -> "type[PropertyMapping]":
|
||||
def property_mapping_type(self) -> type[PropertyMapping]:
|
||||
from authentik.sources.ldap.models import LDAPSourcePropertyMapping
|
||||
|
||||
return LDAPSourcePropertyMapping
|
||||
|
||||
@@ -85,7 +85,7 @@ class OAuthSource(NonCreatableType, Source):
|
||||
)
|
||||
|
||||
@property
|
||||
def source_type(self) -> type["SourceType"]:
|
||||
def source_type(self) -> type[SourceType]:
|
||||
"""Return the provider instance for this source"""
|
||||
from authentik.sources.oauth.types.registry import registry
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class ResponseProcessor:
|
||||
policy_context={},
|
||||
)
|
||||
|
||||
def _get_name_id(self) -> "Element":
|
||||
def _get_name_id(self) -> Element:
|
||||
"""Get NameID Element"""
|
||||
assertion = self._root.find(f"{{{NS_SAML_ASSERTION}}}Assertion")
|
||||
if assertion is None:
|
||||
|
||||
@@ -86,7 +86,7 @@ class TelegramSource(Source):
|
||||
)
|
||||
|
||||
@property
|
||||
def property_mapping_type(self) -> "type[PropertyMapping]":
|
||||
def property_mapping_type(self) -> type[PropertyMapping]:
|
||||
return TelegramSourcePropertyMapping
|
||||
|
||||
def get_base_user_properties(
|
||||
|
||||
@@ -68,7 +68,7 @@ def match_token(user, token):
|
||||
return device
|
||||
|
||||
|
||||
def devices_for_user(user: "User", confirmed: bool | None = True, for_verify: bool = False):
|
||||
def devices_for_user(user: User, confirmed: bool | None = True, for_verify: bool = False):
|
||||
"""
|
||||
Return an iterable of all devices registered to the given user.
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ class AuthenticatorEmailStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
timeout=self.timeout,
|
||||
)
|
||||
|
||||
def send(self, device: "EmailDevice"):
|
||||
def send(self, device: EmailDevice):
|
||||
# Lazy import here to avoid circular import
|
||||
from authentik.stages.email.tasks import send_mails
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
help_text=_("Optionally modify the payload being sent to custom providers."),
|
||||
)
|
||||
|
||||
def send(self, request: HttpRequest, token: str, device: "SMSDevice"):
|
||||
def send(self, request: HttpRequest, token: str, device: SMSDevice):
|
||||
"""Send message via selected provider"""
|
||||
if self.provider == SMSProviders.TWILIO:
|
||||
return self.send_twilio(request, token, device)
|
||||
@@ -80,7 +80,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
"""Get SMS message"""
|
||||
return _("Use this code to authenticate in authentik: {token}".format_map({"token": token}))
|
||||
|
||||
def send_twilio(self, request: HttpRequest, token: str, device: "SMSDevice"):
|
||||
def send_twilio(self, request: HttpRequest, token: str, device: SMSDevice):
|
||||
"""send sms via twilio provider"""
|
||||
client = Client(self.account_sid, self.auth)
|
||||
message_body = str(self.get_message(token))
|
||||
@@ -105,7 +105,7 @@ class AuthenticatorSMSStage(ConfigurableStage, FriendlyNamedStage, Stage):
|
||||
LOGGER.warning("Error sending token by Twilio SMS", exc=exc, msg=exc.msg)
|
||||
raise ValidationError(exc.msg) from None
|
||||
|
||||
def send_generic(self, request: HttpRequest, token: str, device: "SMSDevice"):
|
||||
def send_generic(self, request: HttpRequest, token: str, device: SMSDevice):
|
||||
"""Send SMS via outside API"""
|
||||
payload = {
|
||||
"From": self.from_number,
|
||||
|
||||
@@ -55,7 +55,7 @@ class DeviceChallenge(PassiveSerializer):
|
||||
|
||||
|
||||
def get_challenge_for_device(
|
||||
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage, device: Device
|
||||
stage_view: AuthenticatorValidateStageView, stage: AuthenticatorValidateStage, device: Device
|
||||
) -> dict:
|
||||
"""Generate challenge for a single device"""
|
||||
if isinstance(device, WebAuthnDevice):
|
||||
@@ -67,7 +67,7 @@ def get_challenge_for_device(
|
||||
|
||||
|
||||
def get_webauthn_challenge_without_user(
|
||||
stage_view: "AuthenticatorValidateStageView", stage: AuthenticatorValidateStage
|
||||
stage_view: AuthenticatorValidateStageView, stage: AuthenticatorValidateStage
|
||||
) -> dict:
|
||||
"""Same as `get_webauthn_challenge`, but allows any client device. We can then later check
|
||||
who the device belongs to."""
|
||||
@@ -85,7 +85,7 @@ def get_webauthn_challenge_without_user(
|
||||
|
||||
|
||||
def get_webauthn_challenge(
|
||||
stage_view: "AuthenticatorValidateStageView",
|
||||
stage_view: AuthenticatorValidateStageView,
|
||||
stage: AuthenticatorValidateStage,
|
||||
device: WebAuthnDevice | None = None,
|
||||
) -> dict:
|
||||
|
||||
@@ -41,7 +41,7 @@ class ScheduleSpec:
|
||||
options["uid"] = self.uid
|
||||
return pickle.dumps(options)
|
||||
|
||||
def update_or_create(self) -> "Schedule":
|
||||
def update_or_create(self) -> Schedule:
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
|
||||
from authentik.tasks.schedules.models import Schedule
|
||||
|
||||
@@ -53,27 +53,27 @@ workers = CONFIG.get_int("web.workers", default_workers)
|
||||
threads = CONFIG.get_int("web.threads", 4)
|
||||
|
||||
|
||||
def post_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
||||
def post_fork(server: Arbiter, worker: DjangoUvicornWorker):
|
||||
"""Tell prometheus to use worker number instead of process ID for multiprocess"""
|
||||
from prometheus_client import values
|
||||
|
||||
values.ValueClass = MultiProcessValue(lambda: worker._worker_id)
|
||||
|
||||
|
||||
def worker_exit(server: "Arbiter", worker: DjangoUvicornWorker):
|
||||
def worker_exit(server: Arbiter, worker: DjangoUvicornWorker):
|
||||
"""Remove pid dbs when worker is shutdown"""
|
||||
from prometheus_client import multiprocess
|
||||
|
||||
multiprocess.mark_process_dead(worker._worker_id)
|
||||
|
||||
|
||||
def on_starting(server: "Arbiter"):
|
||||
def on_starting(server: Arbiter):
|
||||
"""Attach a set of IDs that can be temporarily reused.
|
||||
Used on reloads when each worker exists twice."""
|
||||
server._worker_id_overload = set()
|
||||
|
||||
|
||||
def nworkers_changed(server: "Arbiter", new_value, old_value):
|
||||
def nworkers_changed(server: Arbiter, new_value, old_value):
|
||||
"""Gets called on startup too.
|
||||
Set the current number of workers. Required if we raise the worker count
|
||||
temporarily using TTIN because server.cfg.workers won't be updated and if
|
||||
@@ -81,7 +81,7 @@ def nworkers_changed(server: "Arbiter", new_value, old_value):
|
||||
server._worker_id_current_workers = new_value
|
||||
|
||||
|
||||
def _next_worker_id(server: "Arbiter"):
|
||||
def _next_worker_id(server: Arbiter):
|
||||
"""If there are IDs open for reuse, take one. Else look for a free one."""
|
||||
if server._worker_id_overload:
|
||||
return server._worker_id_overload.pop()
|
||||
@@ -92,12 +92,12 @@ def _next_worker_id(server: "Arbiter"):
|
||||
return free.pop()
|
||||
|
||||
|
||||
def on_reload(server: "Arbiter"):
|
||||
def on_reload(server: Arbiter):
|
||||
"""Add a full set of ids into overload so it can be reused once."""
|
||||
server._worker_id_overload = set(range(1, server.cfg.workers + 1))
|
||||
|
||||
|
||||
def pre_fork(server: "Arbiter", worker: DjangoUvicornWorker):
|
||||
def pre_fork(server: Arbiter, worker: DjangoUvicornWorker):
|
||||
"""Attach the next free worker_id before forking off."""
|
||||
worker._worker_id = _next_worker_id(server)
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ MESSAGE_TABLE = Message._meta.db_table
|
||||
|
||||
|
||||
async def _async_proxy(
|
||||
obj: "PostgresChannelLayerLoopProxy",
|
||||
obj: PostgresChannelLayerLoopProxy,
|
||||
name: str,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
@@ -40,7 +40,7 @@ async def _async_proxy(
|
||||
return await getattr(layer, name)(*args, **kwargs)
|
||||
|
||||
|
||||
def _wrap_close(proxy: "PostgresChannelLayerLoopProxy", loop: asyncio.AbstractEventLoop) -> None:
|
||||
def _wrap_close(proxy: PostgresChannelLayerLoopProxy, loop: asyncio.AbstractEventLoop) -> None:
|
||||
original_impl = loop.close
|
||||
|
||||
def _wrapper(self: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any) -> None:
|
||||
@@ -90,7 +90,7 @@ class PostgresChannelLayerLoopProxy:
|
||||
m = zlib.decompress(message)
|
||||
return cast(dict[str, Any], msgpack.unpackb(m, raw=False))
|
||||
|
||||
def _get_layer(self) -> "PostgresChannelLoopLayer":
|
||||
def _get_layer(self) -> PostgresChannelLoopLayer:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
|
||||
@@ -94,7 +94,7 @@ class PostgresBroker(Broker):
|
||||
return cast(DatabaseWrapper, connections[self.db_alias])
|
||||
|
||||
@property
|
||||
def consumer_class(self) -> "type[_PostgresConsumer]":
|
||||
def consumer_class(self) -> type[_PostgresConsumer]:
|
||||
return _PostgresConsumer
|
||||
|
||||
@cached_property
|
||||
|
||||
@@ -71,7 +71,7 @@ class DbConnectionMiddleware(Middleware):
|
||||
|
||||
|
||||
class TaskStateBeforeMiddleware(Middleware):
|
||||
def before_process_message(self, broker: "PostgresBroker", message: Message[Any]) -> None:
|
||||
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
|
||||
broker.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
@@ -82,7 +82,7 @@ class TaskStateBeforeMiddleware(Middleware):
|
||||
|
||||
|
||||
class TaskStateAfterMiddleware(Middleware):
|
||||
def before_process_message(self, broker: "PostgresBroker", message: Message[Any]) -> None:
|
||||
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
|
||||
broker.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
@@ -91,7 +91,7 @@ class TaskStateAfterMiddleware(Middleware):
|
||||
state=TaskState.RUNNING,
|
||||
)
|
||||
|
||||
def after_skip_message(self, broker: "PostgresBroker", message: Message[Any]) -> None:
|
||||
def after_skip_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
|
||||
broker.query_set.filter(
|
||||
message_id=message.message_id,
|
||||
queue_name=message.queue_name,
|
||||
@@ -102,7 +102,7 @@ class TaskStateAfterMiddleware(Middleware):
|
||||
|
||||
def after_process_message(
|
||||
self,
|
||||
broker: "PostgresBroker",
|
||||
broker: PostgresBroker,
|
||||
message: Message[Any],
|
||||
*,
|
||||
result: Any | None = None,
|
||||
|
||||
+1
-1
@@ -185,7 +185,7 @@ exclude = 'node_modules'
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py313"
|
||||
target-version = "py314"
|
||||
exclude = ["**/migrations/**", "**/node_modules/**"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
|
||||
Reference in New Issue
Block a user