enterprise/providers/scim: add support for interactive OAuth2 (#22072)

* enterprise/providers/scim: add support for interactive OAuth2

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* prep different oauth mode

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* implement it

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add data to API

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* update ui

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fixes

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* cleanup

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* start adding tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add more tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* remove not-needed migration

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fixup

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix last_updated not being updated

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2026-05-13 18:27:34 +02:00
committed by GitHub
parent 4cfb61f83b
commit a712e5bb2f
22 changed files with 687 additions and 204 deletions
+60 -2
View File
@@ -1,14 +1,72 @@
from datetime import datetime
from django.urls import reverse
from django.utils.translation import gettext as _
from rest_framework.exceptions import ValidationError
from authentik.enterprise.license import LicenseKey
from authentik.providers.scim.models import SCIMAuthenticationMode
from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMProvider
from authentik.sources.oauth.models import UserOAuthSourceConnection
class SCIMProviderSerializerMixin:
def _get_token(self, instance: SCIMProvider) -> UserOAuthSourceConnection | None:
user = instance.auth_oauth_user
conn = UserOAuthSourceConnection.objects.filter(
user=user, source=instance.auth_oauth
).first()
return conn
def get_auth_oauth_token_last_updated(self, instance: SCIMProvider) -> datetime | None:
conn = self._get_token(instance)
return conn.last_updated if conn else None
def get_auth_oauth_token_expires(self, instance: SCIMProvider) -> datetime | None:
conn = self._get_token(instance)
return conn.expires if conn else None
def get_auth_oauth_url_callback(self, instance: SCIMProvider) -> str | None:
if (
instance.auth_mode
in [
SCIMAuthenticationMode.TOKEN,
SCIMAuthenticationMode.OAUTH_SILENT,
]
or not instance.backchannel_application
):
return None
relative_url = reverse(
"authentik_enterprise_providers_scim:callback",
kwargs={"application_slug": instance.backchannel_application.slug},
)
if "request" not in self.context:
return relative_url
return self.context["request"].build_absolute_uri(relative_url)
def get_auth_oauth_url_start(self, instance: SCIMProvider) -> str | None:
if (
instance.auth_mode
in [
SCIMAuthenticationMode.TOKEN,
SCIMAuthenticationMode.OAUTH_SILENT,
]
or not instance.backchannel_application
):
return None
relative_url = reverse(
"authentik_enterprise_providers_scim:start",
kwargs={"application_slug": instance.backchannel_application.slug},
)
if "request" not in self.context:
return relative_url
return self.context["request"].build_absolute_uri(relative_url)
def validate_auth_mode(self, auth_mode: SCIMAuthenticationMode) -> SCIMAuthenticationMode:
if auth_mode == SCIMAuthenticationMode.OAUTH:
if auth_mode in [
SCIMAuthenticationMode.OAUTH_SILENT,
SCIMAuthenticationMode.OAUTH_INTERACTIVE,
]:
if not LicenseKey.cached_summary().status.is_valid:
raise ValidationError(_("Enterprise is required to use the OAuth mode."))
return auth_mode
@@ -7,3 +7,4 @@ class AuthentikEnterpriseProviderSCIMConfig(EnterpriseConfig):
label = "authentik_enterprise_providers_scim"
verbose_name = "authentik Enterprise.Providers.SCIM"
default = True
mountpoint = "application/scim/"
@@ -1,12 +1,14 @@
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from django.utils.timezone import now
from requests import Request, RequestException
from structlog.stdlib import get_logger
from authentik.common.oauth.constants import GRANT_TYPE_PASSWORD, GRANT_TYPE_REFRESH_TOKEN
from authentik.providers.scim.clients.exceptions import SCIMRequestException
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.providers.scim.models import SCIMAuthenticationMode
from authentik.sources.oauth.clients.base import BaseOAuthClient
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
if TYPE_CHECKING:
@@ -18,23 +20,26 @@ class SCIMOAuthException(SCIMRequestException):
class SCIMOAuthAuth:
def __init__(self, provider: SCIMProvider):
self.provider = provider
self.user = provider.auth_oauth_user
self.logger = get_logger().bind()
self.connection = self.get_connection()
def retrieve_token(self):
if not self.provider.auth_oauth:
return None
def retrieve_token(self, conn: UserOAuthSourceConnection | None) -> dict[str, Any]:
source: OAuthSource = self.provider.auth_oauth
client = OAuth2Client(source, None)
client: BaseOAuthClient = source.source_type.callback_view(request=None).get_client(source)
access_token_url = source.source_type.access_token_url or ""
if source.source_type.urls_customizable and source.access_token_url:
access_token_url = source.access_token_url
data = client.get_access_token_args(None, None)
data["grant_type"] = "password"
if self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_SILENT:
data["grant_type"] = GRANT_TYPE_PASSWORD
elif self.provider.auth_mode == SCIMAuthenticationMode.OAUTH_INTERACTIVE:
data["grant_type"] = GRANT_TYPE_REFRESH_TOKEN
if not conn:
raise SCIMOAuthException(None, "Could not refresh SCIM OAuth token")
data["refresh_token"] = conn.refresh_token
data.update(self.provider.auth_oauth_params)
try:
response = client.do_request(
@@ -54,12 +59,14 @@ class SCIMOAuthAuth:
raise SCIMOAuthException(exc.response, message="Failed to get OAuth token") from exc
def get_connection(self):
token = UserOAuthSourceConnection.objects.filter(
source=self.provider.auth_oauth, user=self.user, expires__gt=now()
if not self.provider.auth_oauth:
return None
conn = UserOAuthSourceConnection.objects.filter(
source=self.provider.auth_oauth, user=self.user
).first()
if token and token.access_token:
return token
token = self.retrieve_token()
if conn and conn.access_token and conn.expires > now():
return conn
token = self.retrieve_token(conn)
access_token = token["access_token"]
expires_in = int(token.get("expires_in", 0))
token, _ = UserOAuthSourceConnection.objects.update_or_create(
@@ -67,7 +74,10 @@ class SCIMOAuthAuth:
user=self.user,
defaults={
"access_token": access_token,
"refresh_token": token.get("refresh_token"),
"expires": now() + timedelta(seconds=expires_in),
# When using `update_or_create`, `last_updated` is not updated
"last_updated": now(),
},
)
return token
@@ -14,7 +14,10 @@ def scim_provider_post_save(sender: type[Model], instance: SCIMProvider, created
"""Create service account before provider is saved"""
identifier = f"ak-providers-scim-{instance.pk}"
with audit_ignore():
if instance.auth_mode == SCIMAuthenticationMode.OAUTH:
if instance.auth_mode in [
SCIMAuthenticationMode.OAUTH_SILENT,
SCIMAuthenticationMode.OAUTH_INTERACTIVE,
]:
user, user_created = User.objects.update_or_create(
username=identifier,
defaults={
@@ -0,0 +1,73 @@
"""SCIM OAuth tests"""
from unittest.mock import MagicMock, PropertyMock, patch
from django.urls import reverse
from rest_framework.test import APITestCase
from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.lib.generators import generate_id
from authentik.sources.oauth.models import OAuthSource
class TestSCIMOAuthAPI(APITestCase):
"""SCIM User tests"""
def setUp(self):
self.source = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
access_token_url="http://localhost/token", # nosec
consumer_key=generate_id(),
consumer_secret=generate_id(),
provider_type="openidconnect",
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
def test_api_create(self):
License.objects.create(key=generate_id())
self.client.force_login(create_test_admin_user())
res = self.client.post(
reverse("authentik_api:scimprovider-list"),
{
"name": generate_id(),
"url": "http://localhost",
"auth_mode": "oauth",
"auth_oauth": str(self.source.pk),
},
)
self.assertEqual(res.status_code, 201)
@patch(
"authentik.enterprise.models.LicenseUsageStatus.is_valid",
PropertyMock(return_value=False),
)
def test_api_create_no_license(self):
self.client.force_login(create_test_admin_user())
res = self.client.post(
reverse("authentik_api:scimprovider-list"),
{
"name": generate_id(),
"url": "http://localhost",
"auth_mode": "oauth",
"auth_oauth": str(self.source.pk),
},
)
self.assertEqual(res.status_code, 400)
self.assertJSONEqual(
res.content, {"auth_mode": ["Enterprise is required to use the OAuth mode."]}
)
@@ -0,0 +1,100 @@
"""SCIM OAuth tests"""
from requests_mock import Mocker
from rest_framework.test import APITestCase
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.lib.generators import generate_id
from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMMapping, SCIMProvider
from authentik.sources.oauth.models import OAuthSource
from authentik.tenants.models import Tenant
class TestSCIMOAuthAuth(APITestCase):
"""SCIM User tests"""
@apply_blueprint("system/providers-scim.yaml")
def setUp(self) -> None:
# Delete all users and groups as the mocked HTTP responses only return one ID
# which will cause errors with multiple users
Tenant.objects.update(avatars="none")
User.objects.all().exclude_anonymous().delete()
Group.objects.all().delete()
self.source = OAuthSource.objects.create(
name=generate_id(),
slug=generate_id(),
access_token_url="http://localhost/token", # nosec
consumer_key=generate_id(),
consumer_secret=generate_id(),
provider_type="openidconnect",
)
self.provider = SCIMProvider.objects.create(
name=generate_id(),
url="https://localhost",
auth_mode=SCIMAuthenticationMode.OAUTH_SILENT,
auth_oauth=self.source,
auth_oauth_params={
"foo": "bar",
},
exclude_users_service_account=True,
)
self.app: Application = Application.objects.create(
name=generate_id(),
slug=generate_id(),
)
self.app.backchannel_providers.add(self.provider)
self.provider.property_mappings.add(
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/user")
)
self.provider.property_mappings_group.add(
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
)
@Mocker()
def test_user_create(self, mock: Mocker):
"""Test user creation"""
scim_id = generate_id()
token = generate_id()
mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
mock.get(
"https://localhost/ServiceProviderConfig",
json={},
)
mock.post(
"https://localhost/Users",
json={
"id": scim_id,
},
)
uid = generate_id()
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
self.assertEqual(mock.call_count, 3)
self.assertEqual(mock.request_history[1].method, "GET")
self.assertEqual(mock.request_history[2].method, "POST")
self.assertJSONEqual(
mock.request_history[2].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"emails": [
{
"primary": True,
"type": "other",
"value": f"{uid}@goauthentik.io",
}
],
"externalId": user.uid,
"name": {
"familyName": uid,
"formatted": f"{uid} {uid}",
"givenName": uid,
},
"displayName": f"{uid} {uid}",
"userName": uid,
},
)
@@ -2,7 +2,7 @@
from base64 import b64encode
from datetime import timedelta
from unittest.mock import MagicMock, PropertyMock, patch
from urllib.parse import parse_qs, urlencode, urlparse
from django.urls import reverse
from django.utils.timezone import now
@@ -11,17 +11,14 @@ from rest_framework.test import APITestCase
from authentik.blueprints.tests import apply_blueprint
from authentik.core.models import Application, Group, User
from authentik.core.tests.utils import create_test_admin_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.lib.generators import generate_id
from authentik.providers.scim.models import SCIMAuthenticationMode, SCIMMapping, SCIMProvider
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.tenants.models import Tenant
from tests.live import create_test_admin_user
class SCIMOAuthTests(APITestCase):
class TestSCIMOAuthToken(APITestCase):
"""SCIM User tests"""
@apply_blueprint("system/providers-scim.yaml")
@@ -42,7 +39,7 @@ class SCIMOAuthTests(APITestCase):
self.provider = SCIMProvider.objects.create(
name=generate_id(),
url="https://localhost",
auth_mode=SCIMAuthenticationMode.OAUTH,
auth_mode=SCIMAuthenticationMode.OAUTH_SILENT,
auth_oauth=self.source,
auth_oauth_params={
"foo": "bar",
@@ -60,8 +57,9 @@ class SCIMOAuthTests(APITestCase):
self.provider.property_mappings_group.add(
SCIMMapping.objects.get(managed="goauthentik.io/providers/scim/group")
)
self.admin = create_test_admin_user()
def test_retrieve_token(self):
def test_retrieve_token_silent(self):
"""Test token retrieval"""
with Mocker() as mocker:
token = generate_id()
@@ -86,6 +84,44 @@ class SCIMOAuthTests(APITestCase):
)
self.assertEqual(mocker.request_history[0].body, "grant_type=password&foo=bar")
def test_retrieve_token_interactive(self):
"""Test token retrieval"""
self.provider.auth_mode = SCIMAuthenticationMode.OAUTH_INTERACTIVE
self.provider.save()
refresh_token = generate_id()
access_token = generate_id()
UserOAuthSourceConnection.objects.create(
user=self.provider.auth_oauth_user,
source=self.source,
refresh_token=refresh_token,
access_token=access_token,
)
with Mocker() as mocker:
token = generate_id()
mocker.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
self.provider.scim_auth()
conn = UserOAuthSourceConnection.objects.filter(
source=self.source,
user=self.provider.auth_oauth_user,
).first()
self.assertIsNotNone(conn)
self.assertTrue(conn.is_valid)
auth = (
b64encode(
b":".join((self.source.consumer_key.encode(), self.source.consumer_secret.encode()))
)
.strip()
.decode()
)
self.assertEqual(
mocker.request_history[0].headers["Authorization"],
f"Basic {auth}",
)
self.assertEqual(
mocker.request_history[0].body,
f"grant_type=refresh_token&refresh_token={refresh_token}&foo=bar",
)
def test_existing_token(self):
"""Test existing token"""
UserOAuthSourceConnection.objects.create(
@@ -98,96 +134,54 @@ class SCIMOAuthTests(APITestCase):
self.provider.scim_auth()
self.assertEqual(len(mocker.request_history), 0)
@Mocker()
def test_user_create(self, mock: Mocker):
"""Test user creation"""
scim_id = generate_id()
token = generate_id()
mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
mock.get(
"https://localhost/ServiceProviderConfig",
json={},
)
mock.post(
"https://localhost/Users",
json={
"id": scim_id,
},
)
uid = generate_id()
user = User.objects.create(
username=uid,
name=f"{uid} {uid}",
email=f"{uid}@goauthentik.io",
)
self.assertEqual(mock.call_count, 3)
self.assertEqual(mock.request_history[1].method, "GET")
self.assertEqual(mock.request_history[2].method, "POST")
self.assertJSONEqual(
mock.request_history[2].body,
{
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True,
"emails": [
{
"primary": True,
"type": "other",
"value": f"{uid}@goauthentik.io",
}
],
"externalId": user.uid,
"name": {
"familyName": uid,
"formatted": f"{uid} {uid}",
"givenName": uid,
def test_interactive_start(self):
self.client.force_login(self.admin)
res = self.client.get(
reverse(
"authentik_enterprise_providers_scim:start",
kwargs={
"application_slug": self.app.slug,
},
"displayName": f"{uid} {uid}",
"userName": uid,
},
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
def test_api_create(self):
License.objects.create(key=generate_id())
self.client.force_login(create_test_admin_user())
res = self.client.post(
reverse("authentik_api:scimprovider-list"),
{
"name": generate_id(),
"url": "http://localhost",
"auth_mode": "oauth",
"auth_oauth": str(self.source.pk),
},
)
self.assertEqual(res.status_code, 201)
self.assertEqual(res.status_code, 302)
query = parse_qs(urlparse(res.url).query)
self.assertEqual(query["client_id"], [self.source.consumer_key])
self.assertEqual(
query["redirect_uri"],
[f"http://testserver/application/scim/{self.app.slug}/oauth2/callback/"],
)
self.assertEqual(query["response_type"], ["code"])
@patch(
"authentik.enterprise.models.LicenseUsageStatus.is_valid",
PropertyMock(return_value=False),
)
def test_api_create_no_license(self):
self.client.force_login(create_test_admin_user())
res = self.client.post(
reverse("authentik_api:scimprovider-list"),
{
"name": generate_id(),
"url": "http://localhost",
"auth_mode": "oauth",
"auth_oauth": str(self.source.pk),
},
)
self.assertEqual(res.status_code, 400)
self.assertJSONEqual(
res.content, {"auth_mode": ["Enterprise is required to use the OAuth mode."]}
def test_interactive_callback(self):
self.client.force_login(self.admin)
res = self.client.get(
reverse(
"authentik_enterprise_providers_scim:start",
kwargs={
"application_slug": self.app.slug,
},
)
)
self.assertEqual(res.status_code, 302)
query = parse_qs(urlparse(res.url).query)
with Mocker() as mock:
token = generate_id()
mock.post("http://localhost/token", json={"access_token": token, "expires_in": 3600})
res = self.client.get(
reverse(
"authentik_enterprise_providers_scim:callback",
kwargs={
"application_slug": self.app.slug,
},
)
+ "?"
+ urlencode({"state": query["state"][0], "code": generate_id()})
)
self.assertEqual(res.status_code, 302)
conn = UserOAuthSourceConnection.objects.filter(source=self.source).first()
self.assertIsNotNone(conn)
self.assertTrue(conn.is_valid)
@@ -0,0 +1,10 @@
from django.urls import path
from authentik.enterprise.providers.scim.views import SCIMOAuthStart, SCIMRedirectCallback
urlpatterns = [
path("<slug:application_slug>/oauth2/start/", SCIMOAuthStart.as_view(), name="start"),
path(
"<slug:application_slug>/oauth2/callback/", SCIMRedirectCallback.as_view(), name="callback"
),
]
@@ -0,0 +1,70 @@
from datetime import timedelta
from django.core.exceptions import PermissionDenied
from django.http import HttpRequest
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.models import Application
from authentik.providers.scim.models import SCIMProvider
from authentik.sources.oauth.clients.base import BaseOAuthClient
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.registry import RequestKind, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
class SCIMOAuthViewMixin:
provider: SCIMProvider
def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient:
source: OAuthSource = self.provider.auth_oauth
source_cls = registry.find(source.provider_type, kind=RequestKind.CALLBACK)
if not source_cls.client_class:
return super().get_client(source, **kwargs)
return source_cls.client_class(source, self.request, **kwargs)
def _get_scim_provider(self, app_slug: str):
app = Application.objects.filter(slug=app_slug).first()
if not app:
return None
provider = SCIMProvider.objects.filter(backchannel_application=app)
return provider.first()
def dispatch(self, request: HttpRequest, application_slug: str):
if not request.user.is_authenticated:
raise PermissionDenied()
provider = self._get_scim_provider(application_slug)
if not provider or not provider.auth_oauth:
raise PermissionDenied()
if not request.user.has_perm(
"authentik_providers_scim.change_scimprovider",
provider,
):
raise PermissionDenied()
self.provider = provider
return super().dispatch(request, source_slug=provider.auth_oauth.slug)
class SCIMOAuthStart(SCIMOAuthViewMixin, OAuthRedirect):
def get_callback_url(self, source: OAuthSource):
return reverse("authentik_enterprise_providers_scim:callback", kwargs=self.kwargs)
class SCIMRedirectCallback(SCIMOAuthViewMixin, OAuthCallback):
def redirect_flow_manager(self, client: BaseOAuthClient):
expires_in = int(self.token.get("expires_in", 0))
UserOAuthSourceConnection.objects.update_or_create(
source=self.provider.auth_oauth,
user=self.provider.auth_oauth_user,
defaults={
"access_token": self.token.get("access_token"),
"refresh_token": self.token.get("refresh_token"),
"expires": now() + timedelta(seconds=expires_in),
},
)
return redirect("authentik_core:if-admin")
+10
View File
@@ -1,5 +1,6 @@
"""SCIM Provider API Views"""
from rest_framework.fields import SerializerMethodField
from rest_framework.viewsets import ModelViewSet
from authentik.core.api.providers import ProviderSerializer
@@ -16,6 +17,11 @@ class SCIMProviderSerializer(
):
"""SCIMProvider Serializer"""
auth_oauth_token_last_updated = SerializerMethodField()
auth_oauth_token_expires = SerializerMethodField()
auth_oauth_url_callback = SerializerMethodField()
auth_oauth_url_start = SerializerMethodField()
class Meta:
model = SCIMProvider
fields = [
@@ -35,6 +41,10 @@ class SCIMProviderSerializer(
"auth_mode",
"auth_oauth",
"auth_oauth_params",
"auth_oauth_token_last_updated",
"auth_oauth_token_expires",
"auth_oauth_url_callback",
"auth_oauth_url_start",
"compatibility_mode",
"service_provider_config_cache_timeout",
"exclude_users_service_account",
@@ -102,4 +102,16 @@ class Migration(migrations.Migration):
verbose_name="SCIM Compatibility Mode",
),
),
migrations.AlterField(
model_name="scimprovider",
name="auth_mode",
field=models.TextField(
choices=[
("token", "Token"),
("oauth", "OAuth (Silent)"),
("oauth_interactive", "OAuth (interactive)"),
],
default="token",
),
),
]
+6 -2
View File
@@ -72,7 +72,8 @@ class SCIMAuthenticationMode(models.TextChoices):
"""SCIM authentication modes"""
TOKEN = "token", _("Token")
OAUTH = "oauth", _("OAuth")
OAUTH_SILENT = "oauth", _("OAuth (Silent)")
OAUTH_INTERACTIVE = "oauth_interactive", _("OAuth (interactive)")
class SCIMCompatibilityMode(models.TextChoices):
@@ -144,7 +145,10 @@ class SCIMProvider(OutgoingSyncProvider, BackchannelProvider):
)
def scim_auth(self) -> AuthBase:
if self.auth_mode == SCIMAuthenticationMode.OAUTH:
if self.auth_mode in [
SCIMAuthenticationMode.OAUTH_SILENT,
SCIMAuthenticationMode.OAUTH_INTERACTIVE,
]:
try:
from authentik.enterprise.providers.scim.auth_oauth2 import SCIMOAuthAuth
+1 -2
View File
@@ -1,6 +1,5 @@
"""Source type manager"""
from collections.abc import Callable
from enum import Enum
from typing import Any
@@ -114,7 +113,7 @@ class SourceTypeRegistry:
)
return found_type
def find(self, type_name: str, kind: RequestKind) -> Callable:
def find(self, type_name: str, kind: RequestKind) -> type[OAuthCallback | OAuthRedirect]:
"""Find fitting Source Type"""
found_type = self.find_type(type_name)
if kind == RequestKind.CALLBACK:
+16 -4
View File
@@ -15,6 +15,7 @@ from structlog.stdlib import get_logger
from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.events.models import Event, EventAction
from authentik.sources.oauth.clients.base import BaseOAuthClient
from authentik.sources.oauth.models import (
GroupOAuthSourceConnection,
OAuthSource,
@@ -29,7 +30,7 @@ class OAuthCallback(OAuthClientMixin, View):
"Base OAuth callback view."
source: OAuthSource
token: dict | None = None
token: dict[str, Any] | None = None
def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
"""View Get handler"""
@@ -49,20 +50,31 @@ class OAuthCallback(OAuthClientMixin, View):
if "error" in self.token:
return self.handle_login_failure(self.token["error"])
# Fetch profile info
try:
res = self.redirect_flow_manager(client)
except ValueError as exc:
# if we're authenticated and not in a source stage and this new flag is enabled,
# just continue
if self.request.user.is_authenticated:
pass
return self.handle_login_failure(exc.args[0])
return res
def redirect_flow_manager(self, client: BaseOAuthClient) -> HttpResponse:
try:
raw_info = client.get_profile_info(self.token)
if raw_info is None:
return self.handle_login_failure("Could not retrieve profile.")
raise ValueError("Could not retrieve profile.")
except JSONDecodeError as exc:
Event.new(
EventAction.CONFIGURATION_ERROR,
message="Failed to JSON-decode profile.",
raw_profile=exc.doc,
).from_http(self.request)
return self.handle_login_failure("Could not retrieve profile.")
raise ValueError("Could not retrieve profile.") from None
identifier = self.get_user_id(info=raw_info)
if identifier is None:
return self.handle_login_failure("Could not determine id.")
raise ValueError("Could not determine id.")
sfm = OAuthSourceFlowManager(
source=self.source,
request=self.request,
+2 -1
View File
@@ -11203,7 +11203,8 @@
"type": "string",
"enum": [
"token",
"oauth"
"oauth",
"oauth_interactive"
],
"title": "Auth mode"
},
@@ -19,6 +19,7 @@
export const SCIMAuthenticationModeEnum = {
Token: "token",
Oauth: "oauth",
OauthInteractive: "oauth_interactive",
UnknownDefaultOpenApi: "11184809",
} as const;
export type SCIMAuthenticationModeEnum =
+45
View File
@@ -125,6 +125,30 @@ export interface SCIMProvider {
* @memberof SCIMProvider
*/
authOauthParams?: { [key: string]: any };
/**
*
* @type {Date}
* @memberof SCIMProvider
*/
readonly authOauthTokenLastUpdated: Date | null;
/**
*
* @type {Date}
* @memberof SCIMProvider
*/
readonly authOauthTokenExpires: Date | null;
/**
*
* @type {string}
* @memberof SCIMProvider
*/
readonly authOauthUrlCallback: string | null;
/**
*
* @type {string}
* @memberof SCIMProvider
*/
readonly authOauthUrlStart: string | null;
/**
* Alter authentik behavior for vendor-specific SCIM implementations.
* @type {CompatibilityModeEnum}
@@ -190,6 +214,13 @@ export function instanceOfSCIMProvider(value: object): value is SCIMProvider {
if (!("verboseNamePlural" in value) || value["verboseNamePlural"] === undefined) return false;
if (!("metaModelName" in value) || value["metaModelName"] === undefined) return false;
if (!("url" in value) || value["url"] === undefined) return false;
if (!("authOauthTokenLastUpdated" in value) || value["authOauthTokenLastUpdated"] === undefined)
return false;
if (!("authOauthTokenExpires" in value) || value["authOauthTokenExpires"] === undefined)
return false;
if (!("authOauthUrlCallback" in value) || value["authOauthUrlCallback"] === undefined)
return false;
if (!("authOauthUrlStart" in value) || value["authOauthUrlStart"] === undefined) return false;
return true;
}
@@ -223,6 +254,16 @@ export function SCIMProviderFromJSONTyped(json: any, ignoreDiscriminator: boolea
: SCIMAuthenticationModeEnumFromJSON(json["auth_mode"]),
authOauth: json["auth_oauth"] == null ? undefined : json["auth_oauth"],
authOauthParams: json["auth_oauth_params"] == null ? undefined : json["auth_oauth_params"],
authOauthTokenLastUpdated:
json["auth_oauth_token_last_updated"] == null
? null
: new Date(json["auth_oauth_token_last_updated"]),
authOauthTokenExpires:
json["auth_oauth_token_expires"] == null
? null
: new Date(json["auth_oauth_token_expires"]),
authOauthUrlCallback: json["auth_oauth_url_callback"],
authOauthUrlStart: json["auth_oauth_url_start"],
compatibilityMode:
json["compatibility_mode"] == null
? undefined
@@ -256,6 +297,10 @@ export function SCIMProviderToJSONTyped(
| "verbose_name"
| "verbose_name_plural"
| "meta_model_name"
| "auth_oauth_token_last_updated"
| "auth_oauth_token_expires"
| "auth_oauth_url_callback"
| "auth_oauth_url_start"
> | null,
ignoreDiscriminator: boolean = false,
): any {
+23
View File
@@ -54916,6 +54916,7 @@ components:
enum:
- token
- oauth
- oauth_interactive
type: string
SCIMMapping:
type: object
@@ -55050,6 +55051,24 @@ components:
type: object
additionalProperties: {}
description: Additional OAuth parameters, such as grant_type
auth_oauth_token_last_updated:
type: string
format: date-time
nullable: true
readOnly: true
auth_oauth_token_expires:
type: string
format: date-time
nullable: true
readOnly: true
auth_oauth_url_callback:
type: string
nullable: true
readOnly: true
auth_oauth_url_start:
type: string
nullable: true
readOnly: true
compatibility_mode:
allOf:
- $ref: '#/components/schemas/CompatibilityModeEnum'
@@ -55082,6 +55101,10 @@ components:
required:
- assigned_backchannel_application_name
- assigned_backchannel_application_slug
- auth_oauth_token_expires
- auth_oauth_token_last_updated
- auth_oauth_url_callback
- auth_oauth_url_start
- component
- meta_model_name
- name
@@ -93,6 +93,7 @@ export function renderAuth(provider?: Partial<SCIMProvider>, errors: ValidationE
case SCIMAuthenticationModeEnum.Token:
return renderAuthToken(provider, errors);
case SCIMAuthenticationModeEnum.Oauth:
case SCIMAuthenticationModeEnum.OauthInteractive:
return renderAuthOAuth(provider, errors);
}
}
@@ -160,12 +161,18 @@ export function renderForm({ provider, errors, update }: SCIMProviderFormProps)
)}`,
},
{
label: msg("OAuth"),
label: msg("OAuth (Silent)"),
value: SCIMAuthenticationModeEnum.Oauth,
default: true,
description: html`${msg("Authenticate SCIM requests using OAuth.")}
<ak-license-notice></ak-license-notice>`,
},
{
label: msg("OAuth (Interactive)"),
value: SCIMAuthenticationModeEnum.OauthInteractive,
description: html`${msg(
"Authenticate SCIM requests using OAuth, interactively authorized.",
)} <ak-license-notice></ak-license-notice>`,
},
]}
></ak-radio>
</ak-form-element-horizontal>
@@ -13,6 +13,7 @@ import "#elements/buttons/ModalButton";
import "#elements/sync/SyncStatusCard";
import "#elements/tasks/ScheduleList";
import "#elements/tasks/TaskList";
import "#elements/timestamp/ak-timestamp";
import { DEFAULT_CONFIG } from "#common/api/config";
import { EVENT_REFRESH } from "#common/constants";
@@ -20,7 +21,14 @@ import { EVENT_REFRESH } from "#common/constants";
import { AKElement } from "#elements/Base";
import { SlottedTemplateResult } from "#elements/types";
import { ModelEnum, ProvidersApi, SCIMProvider } from "@goauthentik/api";
import renderDescriptionList from "#components/DescriptionList";
import {
ModelEnum,
ProvidersApi,
SCIMAuthenticationModeEnum,
SCIMProvider,
} from "@goauthentik/api";
import MDSCIMProvider from "~docs/add-secure-apps/providers/scim/index.md";
@@ -154,6 +162,42 @@ export class SCIMProviderViewPage extends AKElement {
</main>`;
}
renderSyncStatusExtra() {
if (
this.provider?.authMode !== SCIMAuthenticationModeEnum.Oauth &&
this.provider?.authMode !== SCIMAuthenticationModeEnum.OauthInteractive
)
return nothing;
return html`
<div class="pf-c-description-list__group">
<dt class="pf-c-description-list__term">
<span class="pf-c-description-list__text"
>${msg("OAuth Token last updated")}</span
>
</dt>
<dd class="pf-c-description-list__description">
<div class="pf-c-description-list__text">
<ak-timestamp
.timestamp=${this.provider?.authOauthTokenLastUpdated}
></ak-timestamp>
</div>
</dd>
</div>
<div class="pf-c-description-list__group">
<dt class="pf-c-description-list__term">
<span class="pf-c-description-list__text">${msg("OAuth Token expires")}</span>
</dt>
<dd class="pf-c-description-list__description">
<div class="pf-c-description-list__text">
<ak-timestamp
.timestamp=${this.provider?.authOauthTokenExpires}
></ak-timestamp>
</div>
</dd>
</div>
`;
}
renderTabOverview(): SlottedTemplateResult {
if (!this.provider) {
return nothing;
@@ -168,91 +212,94 @@ export class SCIMProviderViewPage extends AKElement {
: nothing}
<div class="pf-c-page__main-section pf-m-no-padding-mobile pf-l-grid pf-m-gutter">
<div
class="pf-c-card pf-l-grid__item pf-m-12-col pf-m-6-col-on-xl pf-m-6-col-on-2xl"
class="pf-c-card pf-l-grid__item pf-m-12-col pf-m-4-col-on-xl pf-m-4-col-on-2xl"
>
<div class="pf-c-card__title">${msg("Info")}</div>
<div class="pf-c-card__body">
<dl class="pf-c-description-list">
<div class="pf-c-description-list__group">
<dt class="pf-c-description-list__term">
<span class="pf-c-description-list__text">${msg("Name")}</span>
</dt>
<dd class="pf-c-description-list__description">
<div class="pf-c-description-list__text">
${this.provider.name}
</div>
</dd>
</div>
<div class="pf-c-description-list__group">
<dt class="pf-c-description-list__term">
<span class="pf-c-description-list__text"
>${msg("Assigned to application")}</span
${renderDescriptionList([
[msg("Name"), this.provider.name],
[
msg("Assigned to application"),
html`<ak-provider-related-application
mode="backchannel"
.provider=${this.provider}
></ak-provider-related-application>`,
],
[
msg("Dry-run"),
html`<ak-status-label
?good=${!this.provider.dryRun}
type="info"
good-label=${msg("No")}
bad-label=${msg("Yes")}
></ak-status-label>`,
],
[msg("URL"), this.provider.url],
[
msg("Service Provider Config cache timeout"),
this.provider.serviceProviderConfigCacheTimeout,
],
[
msg("Related actions"),
html`<ak-forms-modal>
<span slot="submit">${msg("Save Changes")}</span>
<span slot="header">${msg("Update SCIM Provider")}</span>
<ak-provider-scim-form
slot="form"
.instancePk=${this.provider.pk}
>
</dt>
<dd class="pf-c-description-list__description">
<div class="pf-c-description-list__text">
<ak-provider-related-application
mode="backchannel"
.provider=${this.provider}
></ak-provider-related-application>
</div>
</dd>
</div>
<div class="pf-c-description-list__group">
<dt class="pf-c-description-list__term">
<span class="pf-c-description-list__text"
>${msg("Dry-run")}</span
</ak-provider-scim-form>
<button
slot="trigger"
class="pf-c-button pf-m-primary pf-m-block"
>
</dt>
<dd class="pf-c-description-list__description">
<div class="pf-c-description-list__text">
<ak-status-label
?good=${!this.provider.dryRun}
type="info"
good-label=${msg("No")}
bad-label=${msg("Yes")}
></ak-status-label>
</div>
</dd>
</div>
<div class="pf-c-description-list__group">
<dt class="pf-c-description-list__term">
<span class="pf-c-description-list__text">${msg("URL")}</span>
</dt>
<dd class="pf-c-description-list__description">
<div class="pf-c-description-list__text">
${this.provider.url}
</div>
</dd>
</div>
<div class="pf-c-description-list__group">
<dt class="pf-c-description-list__term">
<span class="pf-c-description-list__text">
${msg("Service Provider Config cache timeout")}
</span>
</dt>
<dd class="pf-c-description-list__description">
<div class="pf-c-description-list__text">
${this.provider.serviceProviderConfigCacheTimeout}
</div>
</dd>
</div>
</dl>
</div>
<div class="pf-c-card__footer">
<ak-forms-modal>
<span slot="submit">${msg("Save Changes")}</span>
<span slot="header">${msg("Update SCIM Provider")}</span>
<ak-provider-scim-form slot="form" .instancePk=${this.provider.pk}>
</ak-provider-scim-form>
<button slot="trigger" class="pf-c-button pf-m-primary">
${msg("Edit")}
</button>
</ak-forms-modal>
${msg("Edit")}
</button>
</ak-forms-modal>`,
],
])}
</div>
</div>
<div
class="pf-c-card pf-l-grid__item pf-m-12-col pf-m-6-col-on-xl pf-m-6-col-on-2xl"
>
<div class="pf-l-grid__item pf-m-12-col pf-m-8-col-on-xl pf-m-8-col-on-2xl">
${this.provider.authMode === SCIMAuthenticationModeEnum.OauthInteractive
? html`
<div class="pf-c-card">
<div class="pf-c-card__body">
${renderDescriptionList(
[
[
msg("OAuth Status"),
html`<ak-status-label
?good=${this.provider
.authOauthTokenLastUpdated !== null}
good-label=${msg("Authenticated")}
bad-label=${msg("No token saved")}
></ak-status-label>
<a
class="pf-c-button pf-m-primary"
href=${this.provider?.authOauthUrlStart ||
""}
target="_blank"
>${msg("(Re-)authenticate")}</a
>`,
],
[
msg("OAuth Callback URL"),
html`<input
class="pf-c-form-control"
readonly
type="text"
value="${this.provider.authOauthUrlCallback ||
""}"
/>`,
],
],
{ horizontal: true },
)}
</div>
</div>
`
: nothing}
<ak-sync-status-card
.fetch=${() => {
return new ProvidersApi(DEFAULT_CONFIG).providersScimSyncStatusRetrieve(
@@ -261,7 +308,9 @@ export class SCIMProviderViewPage extends AKElement {
},
);
}}
></ak-sync-status-card>
>
${this.renderSyncStatusExtra()}
</ak-sync-status-card>
</div>
<div class="pf-l-grid__item pf-m-12-col pf-l-stack__item">
<div class="pf-c-card">
+1
View File
@@ -90,6 +90,7 @@ export class SyncStatusCard extends AKElement {
</div>
</dd>
</div>
<slot></slot>
</dl>
`;
}