From 25820f063e2b0ab6c43434a779798e1fdfbf3f27 Mon Sep 17 00:00:00 2001 From: Connor Peshek Date: Tue, 27 Jan 2026 08:15:24 -0600 Subject: [PATCH] providers/oauth2: Support login_hint (#19498) * clean up code * simplify skipping logic * clean up reading flag, fix user submission on identification stage * do not auto add login_hint if user doesnt exist and pretend_user_exists is off * rework Signed-off-by: Jens Langhammer * more tests Signed-off-by: Jens Langhammer * fix tests Signed-off-by: Jens Langhammer * sigh Signed-off-by: Jens Langhammer * fix login_hint conformance test Signed-off-by: Jens Langhammer --------- Signed-off-by: Jens Langhammer Co-authored-by: Jens Langhammer --- authentik/flows/tests/test_inspector.py | 1 + authentik/providers/oauth2/constants.py | 2 + .../providers/oauth2/tests/test_authorize.py | 29 +++++++ authentik/providers/oauth2/views/authorize.py | 11 ++- authentik/stages/identification/stage.py | 42 +++++++++- authentik/stages/identification/tests.py | 78 ++++++++++++++----- schema.yml | 3 + tests/e2e/utils.py | 63 ++++++++------- tests/openid_conformance/base.py | 6 +- .../identification/IdentificationStage.ts | 4 +- 10 files changed, 188 insertions(+), 51 deletions(-) diff --git a/authentik/flows/tests/test_inspector.py b/authentik/flows/tests/test_inspector.py index 1f2ff416ed..37bd8f0974 100644 --- a/authentik/flows/tests/test_inspector.py +++ b/authentik/flows/tests/test_inspector.py @@ -62,6 +62,7 @@ class TestFlowInspector(APITestCase): "primary_action": "Log in", "sources": [], "show_source_labels": False, + "pending_user_identifier": None, "user_fields": ["username"], }, ) diff --git a/authentik/providers/oauth2/constants.py b/authentik/providers/oauth2/constants.py index f8a938f761..cc4d5309d3 100644 --- a/authentik/providers/oauth2/constants.py +++ b/authentik/providers/oauth2/constants.py @@ -10,6 +10,8 @@ GRANT_TYPE_CLIENT_CREDENTIALS = "client_credentials" GRANT_TYPE_PASSWORD = "password" # nosec GRANT_TYPE_DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code" +QS_LOGIN_HINT = "login_hint" + CLIENT_ASSERTION = "client_assertion" CLIENT_ASSERTION_TYPE = "client_assertion_type" CLIENT_ASSERTION_TYPE_JWT = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" diff --git a/authentik/providers/oauth2/tests/test_authorize.py b/authentik/providers/oauth2/tests/test_authorize.py index 20cad56624..85ba3bf4aa 100644 --- a/authentik/providers/oauth2/tests/test_authorize.py +++ b/authentik/providers/oauth2/tests/test_authorize.py @@ -12,6 +12,8 @@ from authentik.core.models import Application from authentik.core.tests.utils import create_test_admin_user, create_test_brand, create_test_flow from authentik.events.models import Event, EventAction from authentik.flows.models import FlowStageBinding +from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER +from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.lib.generators import generate_id from authentik.lib.utils.time import timedelta_from_string from authentik.providers.oauth2.constants import SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, TOKEN_TYPE @@ -765,3 +767,30 @@ class TestAuthorize(OAuthTestCase): self.assertEqual(response.status_code, 302) self.assertIn(auth_flow.slug, response.url) self.assertNotIn(global_auth.slug, response.url) + + @apply_blueprint("default/flow-default-authentication-flow.yaml") + def test_login_hint(self): + """Login hint""" + flow = create_test_flow() + provider = OAuth2Provider.objects.create( + name=generate_id(), + client_id="test", + authorization_flow=flow, + redirect_uris=[RedirectURI(RedirectURIMatchingMode.STRICT, "foo://localhost")], + access_code_validity="seconds=100", + ) + Application.objects.create(name="app", slug="app", provider=provider) + state = generate_id() + response = self.client.get( + reverse("authentik_providers_oauth2:authorize"), + data={ + "response_type": "code", + "client_id": "test", + "state": state, + "redirect_uri": "foo://localhost", + "login_hint": "foo", + }, + ) + self.assertEqual(response.status_code, 302) + plan = self.client.session.get(SESSION_KEY_PLAN) + self.assertEqual(plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER], "foo") diff --git a/authentik/providers/oauth2/views/authorize.py b/authentik/providers/oauth2/views/authorize.py index 8e2d2a5c5b..516979d52d 100644 --- a/authentik/providers/oauth2/views/authorize.py +++ b/authentik/providers/oauth2/views/authorize.py @@ -5,6 +5,7 @@ from datetime import timedelta from json import dumps from re import error as RegexError from re import fullmatch +from typing import Any from urllib.parse import parse_qs, quote, urlencode, urlparse, urlsplit, urlunparse, urlunsplit from uuid import uuid4 @@ -25,9 +26,9 @@ from authentik.flows.challenge import ( HttpChallengeResponse, ) from authentik.flows.exceptions import FlowNonApplicableException -from authentik.flows.models import in_memory_stage +from authentik.flows.models import Flow, in_memory_stage from authentik.flows.planner import PLAN_CONTEXT_APPLICATION, PLAN_CONTEXT_SSO, FlowPlanner -from authentik.flows.stage import StageView +from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView from authentik.lib.utils.time import timedelta_from_string from authentik.lib.views import bad_request_message from authentik.policies.types import PolicyRequest @@ -38,6 +39,7 @@ from authentik.providers.oauth2.constants import ( PROMPT_CONSENT, PROMPT_LOGIN, PROMPT_NONE, + QS_LOGIN_HINT, SCOPE_GITHUB, SCOPE_OFFLINE_ACCESS, SCOPE_OPENID, @@ -379,6 +381,11 @@ class AuthorizationFlowInitView(BufferedPolicyAccessView): self.provider = get_object_or_404(OAuth2Provider, client_id=client_id) self.application = self.provider.application + def modify_flow_context(self, flow: Flow, context: dict[str, Any]) -> dict[str, Any]: + if QS_LOGIN_HINT in self.request.GET: + context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = self.request.GET.get(QS_LOGIN_HINT) + return super().modify_flow_context(flow, context) + def modify_policy_request(self, request: PolicyRequest) -> PolicyRequest: request.context["oauth_scopes"] = self.params.scope request.context["oauth_grant_type"] = self.params.grant_type diff --git a/authentik/stages/identification/stage.py b/authentik/stages/identification/stage.py index 5d60bbac8b..418d467aee 100644 --- a/authentik/stages/identification/stage.py +++ b/authentik/stages/identification/stage.py @@ -6,7 +6,7 @@ from typing import Any from django.contrib.auth.hashers import make_password from django.core.exceptions import PermissionDenied from django.db.models import Q -from django.http import HttpResponse +from django.http import HttpRequest, HttpResponse from django.utils.timezone import now from django.utils.translation import gettext as _ from drf_spectacular.utils import PolymorphicProxySerializer, extend_schema_field @@ -96,6 +96,8 @@ class IdentificationChallenge(Challenge): """Identification challenges with all UI elements""" user_fields = ListField(child=CharField(), allow_empty=True, allow_null=True) + pending_user_identifier = CharField(required=False, allow_null=True) + password_fields = BooleanField() allow_show_password = BooleanField(default=False) application_pre = CharField(required=False) @@ -285,6 +287,39 @@ class IdentificationStageView(ChallengeStageView): self.logger.debug("Generated passkey challenge", challenge=challenge) return challenge + def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: + """Check for existing pending user identifier and skip stage if possible""" + current_stage: IdentificationStage = self.executor.current_stage + pending_user_identifier = self.executor.plan.context.get( + PLAN_CONTEXT_PENDING_USER_IDENTIFIER + ) + + if not pending_user_identifier: + return super().get(request, *args, **kwargs) + + # Only skip if this is a "simple" identification stage with no extra features + can_skip = ( + not current_stage.password_stage + and not current_stage.captcha_stage + and not current_stage.webauthn_stage + and not self.executor.current_binding.policies.exists() + ) + + if can_skip: + # Use the normal validation flow (handles timing protection, logging, signals) + response = IdentificationChallengeResponse( + data={"uid_field": pending_user_identifier}, + stage=self, + ) + if response.is_valid(): + return self.challenge_valid(response) + # Validation failed (user doesn't exist and pretend_user_exists is off) + # Don't pre-fill invalid username, fall through to show the challenge + self.executor.plan.context.pop(PLAN_CONTEXT_PENDING_USER_IDENTIFIER, None) + + # Can't skip - just pre-fill the username field + return super().get(request, *args, **kwargs) + def get_challenge(self) -> Challenge: current_stage: IdentificationStage = self.executor.current_stage challenge = IdentificationChallenge( @@ -360,6 +395,11 @@ class IdentificationStageView(ChallengeStageView): button["challenge"] = source_challenge.data ui_sources.append(button) challenge.initial_data["sources"] = ui_sources + + # Pre-fill username from login_hint unless user clicked "Not you?" + if prefill := self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER_IDENTIFIER): + challenge.initial_data["pending_user_identifier"] = prefill + return challenge def challenge_valid(self, response: IdentificationChallengeResponse) -> HttpResponse: diff --git a/authentik/stages/identification/tests.py b/authentik/stages/identification/tests.py index ecbffafef2..18096a1a50 100644 --- a/authentik/stages/identification/tests.py +++ b/authentik/stages/identification/tests.py @@ -6,6 +6,7 @@ from rest_framework.exceptions import ValidationError from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.flows.models import FlowDesignation, FlowStageBinding +from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER from authentik.flows.tests import FlowTestCase from authentik.lib.generators import generate_id from authentik.sources.oauth.models import OAuthSource @@ -137,7 +138,7 @@ class TestIdentificationStage(FlowTestCase): self.user = create_test_admin_user() # OAuthSource for the login view - source = OAuthSource.objects.create(name="test", slug="test") + self.source = OAuthSource.objects.create(name=generate_id(), slug=generate_id()) self.flow = create_test_flow(FlowDesignation.AUTHENTICATION) self.stage = IdentificationStage.objects.create( @@ -145,7 +146,7 @@ class TestIdentificationStage(FlowTestCase): user_fields=[UserFields.E_MAIL], pretend_user_exists=False, ) - self.stage.sources.set([source]) + self.stage.sources.set([self.source]) self.stage.save() FlowStageBinding.objects.create( target=self.flow, @@ -203,10 +204,10 @@ class TestIdentificationStage(FlowTestCase): { "challenge": { "component": "xak-flow-redirect", - "to": "/source/oauth/login/test/", + "to": f"/source/oauth/login/{self.source.slug}/", }, "icon_url": "/static/authentik/sources/default.svg", - "name": "test", + "name": self.source.name, "promoted": False, } ], @@ -239,10 +240,10 @@ class TestIdentificationStage(FlowTestCase): { "challenge": { "component": "xak-flow-redirect", - "to": "/source/oauth/login/test/", + "to": f"/source/oauth/login/{self.source.slug}/", }, "icon_url": "/static/authentik/sources/default.svg", - "name": "test", + "name": self.source.name, "promoted": False, } ], @@ -314,10 +315,10 @@ class TestIdentificationStage(FlowTestCase): { "challenge": { "component": "xak-flow-redirect", - "to": "/source/oauth/login/test/", + "to": f"/source/oauth/login/{self.source.slug}/", }, "icon_url": "/static/authentik/sources/default.svg", - "name": "test", + "name": self.source.name, "promoted": False, } ], @@ -370,10 +371,10 @@ class TestIdentificationStage(FlowTestCase): { "challenge": { "component": "xak-flow-redirect", - "to": "/source/oauth/login/test/", + "to": f"/source/oauth/login/{self.source.slug}/", }, "icon_url": "/static/authentik/sources/default.svg", - "name": "test", + "name": self.source.name, "promoted": False, } ], @@ -433,10 +434,10 @@ class TestIdentificationStage(FlowTestCase): { "challenge": { "component": "xak-flow-redirect", - "to": "/source/oauth/login/test/", + "to": f"/source/oauth/login/{self.source.slug}/", }, "icon_url": "/static/authentik/sources/default.svg", - "name": "test", + "name": self.source.name, "promoted": False, } ], @@ -481,10 +482,10 @@ class TestIdentificationStage(FlowTestCase): sources=[ { "icon_url": "/static/authentik/sources/default.svg", - "name": "test", + "name": self.source.name, "challenge": { "component": "xak-flow-redirect", - "to": "/source/oauth/login/test/", + "to": f"/source/oauth/login/{self.source.slug}/", }, "promoted": False, } @@ -520,10 +521,10 @@ class TestIdentificationStage(FlowTestCase): { "challenge": { "component": "xak-flow-redirect", - "to": "/source/oauth/login/test/", + "to": f"/source/oauth/login/{self.source.slug}/", }, "icon_url": "/static/authentik/sources/default.svg", - "name": "test", + "name": self.source.name, "promoted": False, } ], @@ -548,10 +549,10 @@ class TestIdentificationStage(FlowTestCase): { "challenge": { "component": "xak-flow-redirect", - "to": "/source/oauth/login/test/", + "to": f"/source/oauth/login/{self.source.slug}/", }, "icon_url": "/static/authentik/sources/default.svg", - "name": "test", + "name": self.source.name, "promoted": False, } ], @@ -579,3 +580,44 @@ class TestIdentificationStage(FlowTestCase): "sources": [], } ).is_valid(raise_exception=True) + + def test_prefill(self): + """Username prefill from existing flow context""" + pw_stage = PasswordStage.objects.create(name=generate_id(), backends=[BACKEND_INBUILT]) + self.stage.password_stage = pw_stage + self.stage.save() + + self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) + ) + + plan = self.get_flow_plan() + plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = "foo" + self.set_flow_plan(plan) + with self.assertFlowFinishes() as plan: + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) + ) + self.assertEqual(response.status_code, 200) + self.assertStageResponse( + response, + self.flow, + component="ak-stage-identification", + pending_user_identifier="foo", + ) + + def test_prefill_simple(self): + """Username prefill from existing flow context""" + self.stage.pretend_user_exists = True + self.stage.save() + + self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) + ) + plan = self.get_flow_plan() + plan.context[PLAN_CONTEXT_PENDING_USER_IDENTIFIER] = "foo" + self.set_flow_plan(plan) + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}) + ) + self.assertStageRedirects(response, reverse("authentik_core:root-redirect")) diff --git a/schema.yml b/schema.yml index 7fb4bbe821..14f9a8adec 100644 --- a/schema.yml +++ b/schema.yml @@ -39715,6 +39715,9 @@ components: type: object additionalProperties: {} nullable: true + pending_user_identifier: + type: string + nullable: true required: - flow_designation - password_fields diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index 12e4147495..fd1cdcc032 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -332,41 +332,48 @@ class SeleniumTestCase(DockerTestCase, StaticLiveServerTestCase): return wrapper(self.driver) - def login(self, shadow_dom=True): + def login(self, shadow_dom=True, skip_stages: list[str] | None = None): """Perform the entire authentik login flow.""" + skip_stages = skip_stages or [] - if shadow_dom: - flow_executor = self.get_shadow_root("ak-flow-executor") - identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor) - else: - flow_executor = self.shady_dom() - identification_stage = self.shady_dom() + if "ak-stage-identification" not in skip_stages: + if shadow_dom: + flow_executor = self.get_shadow_root("ak-flow-executor") + identification_stage = self.get_shadow_root( + "ak-stage-identification", flow_executor + ) + else: + flow_executor = self.shady_dom() + identification_stage = self.shady_dom() - wait = WebDriverWait(identification_stage, self.wait_timeout) - wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=uidField]"))) + wait = WebDriverWait(identification_stage, self.wait_timeout) + wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=uidField]"))) - identification_stage.find_element(By.CSS_SELECTOR, "input[name=uidField]").click() - identification_stage.find_element(By.CSS_SELECTOR, "input[name=uidField]").send_keys( - self.user.username - ) - identification_stage.find_element(By.CSS_SELECTOR, "input[name=uidField]").send_keys( - Keys.ENTER - ) + identification_stage.find_element(By.CSS_SELECTOR, "input[name=uidField]").click() + identification_stage.find_element(By.CSS_SELECTOR, "input[name=uidField]").send_keys( + self.user.username + ) + identification_stage.find_element(By.CSS_SELECTOR, "input[name=uidField]").send_keys( + Keys.ENTER + ) - if shadow_dom: - flow_executor = self.get_shadow_root("ak-flow-executor") - password_stage = self.get_shadow_root("ak-stage-password", flow_executor) - else: - flow_executor = self.shady_dom() - password_stage = self.shady_dom() + if "ak-stage-password" not in skip_stages: + if shadow_dom: + flow_executor = self.get_shadow_root("ak-flow-executor") + password_stage = self.get_shadow_root("ak-stage-password", flow_executor) + else: + flow_executor = self.shady_dom() + password_stage = self.shady_dom() - wait = WebDriverWait(password_stage, self.wait_timeout) - wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=password]"))) + wait = WebDriverWait(password_stage, self.wait_timeout) + wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "input[name=password]"))) - password_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys( - self.user.username - ) - password_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys(Keys.ENTER) + password_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys( + self.user.username + ) + password_stage.find_element(By.CSS_SELECTOR, "input[name=password]").send_keys( + Keys.ENTER + ) sleep(1) def assert_user(self, expected_user: User): diff --git a/tests/openid_conformance/base.py b/tests/openid_conformance/base.py index a562cb03a4..f9f8aeb38a 100644 --- a/tests/openid_conformance/base.py +++ b/tests/openid_conformance/base.py @@ -43,6 +43,7 @@ class TestOpenIDConformance(SeleniumTestCase): "authentik_providers_oauth2:provider-info", application_slug="oidc-conformance-1", ), + "login_hint": self.user.username, }, "client": { "client_id": "4054d882aff59755f2f279968b97ce8806a926e1", @@ -138,7 +139,10 @@ class TestOpenIDConformance(SeleniumTestCase): should_expect_completion = False if "if/flow/default-authentication-flow" in self.driver.current_url: self.logger.debug("Logging in") - self.login() + skipped = [] + if "login_hint" in self.driver.current_url: + skipped.append("ak-stage-identification") + self.login(skip_stages=skipped) should_expect_completion = True if "prompt=consent" in url or "offline_access" in url: self.logger.debug("Authorizing") diff --git a/web/src/flow/stages/identification/IdentificationStage.ts b/web/src/flow/stages/identification/IdentificationStage.ts index 7eba73e40c..ec5bdf5ec2 100644 --- a/web/src/flow/stages/identification/IdentificationStage.ts +++ b/web/src/flow/stages/identification/IdentificationStage.ts @@ -437,7 +437,9 @@ export class IdentificationStage extends BaseStage< autocomplete=${autocomplete} spellcheck="false" class="pf-c-form-control" - value=${this.#rememberMe?.username ?? ""} + value=${this.#rememberMe?.username ?? + this.challenge.pendingUserIdentifier ?? + ""} required /> ${this.#rememberMe.render()}