diff --git a/authentik/providers/oauth2/tests/test_device_backchannel.py b/authentik/providers/oauth2/tests/test_device_backchannel.py index 1e75b7f8ef..1c54ca54c4 100644 --- a/authentik/providers/oauth2/tests/test_device_backchannel.py +++ b/authentik/providers/oauth2/tests/test_device_backchannel.py @@ -6,10 +6,11 @@ from urllib.parse import quote from django.urls import reverse +from authentik.blueprints.tests import apply_blueprint from authentik.core.models import Application from authentik.core.tests.utils import create_test_flow from authentik.lib.generators import generate_id -from authentik.providers.oauth2.models import OAuth2Provider +from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping from authentik.providers.oauth2.tests.utils import OAuthTestCase @@ -110,3 +111,57 @@ class TesOAuth2DeviceBackchannel(OAuthTestCase): self.assertEqual(res.status_code, 200) body = loads(res.content.decode()) self.assertEqual(body["expires_in"], 60) + + @apply_blueprint("system/providers-oauth2.yaml") + def test_backchannel_scopes(self): + """Test backchannel""" + self.provider.property_mappings.set( + ScopeMapping.objects.filter( + managed__in=[ + "goauthentik.io/providers/oauth2/scope-openid", + "goauthentik.io/providers/oauth2/scope-email", + "goauthentik.io/providers/oauth2/scope-profile", + ] + ) + ) + creds = b64encode(f"{self.provider.client_id}:".encode()).decode() + res = self.client.post( + reverse("authentik_providers_oauth2:device"), + HTTP_AUTHORIZATION=f"Basic {creds}", + data={"scope": "openid email"}, + ) + self.assertEqual(res.status_code, 200) + body = loads(res.content.decode()) + self.assertEqual(body["expires_in"], 60) + token = DeviceToken.objects.filter(device_code=body["device_code"]).first() + self.assertIsNotNone(token) + self.assertEqual(len(token.scope), 2) + self.assertIn("openid", token.scope) + self.assertIn("email", token.scope) + + @apply_blueprint("system/providers-oauth2.yaml") + def test_backchannel_scopes_extra(self): + """Test backchannel""" + self.provider.property_mappings.set( + ScopeMapping.objects.filter( + managed__in=[ + "goauthentik.io/providers/oauth2/scope-openid", + "goauthentik.io/providers/oauth2/scope-email", + "goauthentik.io/providers/oauth2/scope-profile", + ] + ) + ) + creds = b64encode(f"{self.provider.client_id}:".encode()).decode() + res = self.client.post( + reverse("authentik_providers_oauth2:device"), + HTTP_AUTHORIZATION=f"Basic {creds}", + data={"scope": "openid email foo"}, + ) + self.assertEqual(res.status_code, 200) + body = loads(res.content.decode()) + self.assertEqual(body["expires_in"], 60) + token = DeviceToken.objects.filter(device_code=body["device_code"]).first() + self.assertIsNotNone(token) + self.assertEqual(len(token.scope), 2) + self.assertIn("openid", token.scope) + self.assertIn("email", token.scope) diff --git a/authentik/providers/oauth2/views/device_backchannel.py b/authentik/providers/oauth2/views/device_backchannel.py index d8e73f9894..301f36492d 100644 --- a/authentik/providers/oauth2/views/device_backchannel.py +++ b/authentik/providers/oauth2/views/device_backchannel.py @@ -15,7 +15,7 @@ from authentik.core.models import Application from authentik.lib.config import CONFIG from authentik.lib.utils.time import timedelta_from_string from authentik.providers.oauth2.errors import DeviceCodeError -from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider +from authentik.providers.oauth2.models import DeviceToken, OAuth2Provider, ScopeMapping from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth from authentik.providers.oauth2.views.device_init import QS_KEY_CODE @@ -28,7 +28,7 @@ class DeviceView(View): client_id: str provider: OAuth2Provider - scopes: list[str] = [] + scopes: set[str] = [] def parse_request(self): """Parse incoming request""" @@ -44,7 +44,21 @@ class DeviceView(View): raise DeviceCodeError("invalid_client") from None self.provider = provider self.client_id = client_id - self.scopes = self.request.POST.get("scope", "").split(" ") + + scopes_to_check = set(self.request.POST.get("scope", "").split()) + default_scope_names = set( + ScopeMapping.objects.filter(provider__in=[self.provider]).values_list( + "scope_name", flat=True + ) + ) + self.scopes = scopes_to_check + if not scopes_to_check.issubset(default_scope_names): + LOGGER.info( + "Application requested scopes not configured, setting to overlap", + scope_allowed=default_scope_names, + scope_given=self.scopes, + ) + self.scopes = self.scopes.intersection(default_scope_names) def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: throttle = AnonRateThrottle()