enterprise: UI improvements, better handling of expiry (#10828)

* web/admin: show enterprise banner on the very top

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* rework license

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix a bunch of things

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add some more tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add more tests

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* fix middleware

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* better api

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* format

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* add tests for and fix read only mode

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

* field name consistency

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L.
2024-08-09 14:26:38 +02:00
committed by GitHub
parent 3265b4af01
commit 4b5bb77d99
20 changed files with 749 additions and 194 deletions
+1 -1
View File
@@ -73,7 +73,7 @@ class SystemInfoSerializer(PassiveSerializer):
"authentik_version": get_full_version(),
"environment": get_env(),
"openssl_fips_enabled": (
backend._fips_enabled if LicenseKey.get_total().is_valid() else None
backend._fips_enabled if LicenseKey.get_total().status().is_valid else None
),
"openssl_version": OPENSSL_VERSION,
"platform": platform.platform(),
+1 -1
View File
@@ -171,7 +171,7 @@ class Importer:
def default_context(self):
"""Default context"""
return {
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().is_valid(),
"goauthentik.io/enterprise/licensed": LicenseKey.get_total().status().is_valid,
"goauthentik.io/rbac/models": rbac_models(),
}
+3 -3
View File
@@ -19,7 +19,7 @@ from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import ModelSerializer, PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.license import LicenseKey, LicenseSummarySerializer
from authentik.enterprise.models import License
from authentik.enterprise.models import License, LicenseUsageStatus
from authentik.rbac.decorators import permission_required
from authentik.tenants.utils import get_unique_identifier
@@ -30,7 +30,7 @@ class EnterpriseRequiredMixin:
def validate(self, attrs: dict) -> dict:
"""Check that a valid license exists"""
if not LicenseKey.cached_summary().has_license:
if LicenseKey.cached_summary().status != LicenseUsageStatus.UNLICENSED:
raise ValidationError(_("Enterprise is required to create/update this object."))
return super().validate(attrs)
@@ -128,7 +128,7 @@ class LicenseViewSet(UsedByMixin, ModelViewSet):
forecast_for_months = 12
response = LicenseForecastSerializer(
data={
"internal_users": LicenseKey.get_default_user_count(),
"internal_users": LicenseKey.get_internal_user_count(),
"external_users": LicenseKey.get_external_user_count(),
"forecasted_internal_users": (internal_in_last_month * forecast_for_months),
"forecasted_external_users": (external_in_last_month * forecast_for_months),
+1 -1
View File
@@ -25,4 +25,4 @@ class AuthentikEnterpriseConfig(EnterpriseConfig):
"""Actual enterprise check, cached"""
from authentik.enterprise.license import LicenseKey
return LicenseKey.cached_summary().valid
return LicenseKey.cached_summary().status
+77 -54
View File
@@ -3,24 +3,36 @@
from base64 import b64decode
from binascii import Error
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import lru_cache
from time import mktime
from cryptography.exceptions import InvalidSignature
from cryptography.x509 import Certificate, load_der_x509_certificate, load_pem_x509_certificate
from dacite import from_dict
from dacite import DaciteError, from_dict
from django.core.cache import cache
from django.db.models.query import QuerySet
from django.utils.timezone import now
from jwt import PyJWTError, decode, get_unverified_header
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, DateTimeField, IntegerField
from rest_framework.fields import (
ChoiceField,
DateTimeField,
IntegerField,
)
from authentik.core.api.utils import PassiveSerializer
from authentik.core.models import User, UserTypes
from authentik.enterprise.models import License, LicenseUsage
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_EXPIRY_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.tenants.utils import get_unique_identifier
CACHE_KEY_ENTERPRISE_LICENSE = "goauthentik.io/enterprise/license"
@@ -42,6 +54,8 @@ def get_license_aud() -> str:
class LicenseFlags(Enum):
"""License flags"""
TRIAL = "trial"
@dataclass
class LicenseSummary:
@@ -49,12 +63,8 @@ class LicenseSummary:
internal_users: int
external_users: int
valid: bool
show_admin_warning: bool
show_user_warning: bool
read_only: bool
status: LicenseUsageStatus
latest_valid: datetime
has_license: bool
class LicenseSummarySerializer(PassiveSerializer):
@@ -62,12 +72,8 @@ class LicenseSummarySerializer(PassiveSerializer):
internal_users = IntegerField(required=True)
external_users = IntegerField(required=True)
valid = BooleanField()
show_admin_warning = BooleanField()
show_user_warning = BooleanField()
read_only = BooleanField()
status = ChoiceField(choices=LicenseUsageStatus.choices)
latest_valid = DateTimeField()
has_license = BooleanField()
@dataclass
@@ -83,7 +89,7 @@ class LicenseKey:
flags: list[LicenseFlags] = field(default_factory=list)
@staticmethod
def validate(jwt: str) -> "LicenseKey":
def validate(jwt: str, check_expiry=True) -> "LicenseKey":
"""Validate the license from a given JWT"""
try:
headers = get_unverified_header(jwt)
@@ -107,6 +113,7 @@ class LicenseKey:
our_cert.public_key(),
algorithms=["ES512"],
audience=get_license_aud(),
options={"verify_exp": check_expiry},
),
)
except PyJWTError:
@@ -116,9 +123,8 @@ class LicenseKey:
@staticmethod
def get_total() -> "LicenseKey":
"""Get a summarized version of all (not expired) licenses"""
active_licenses = License.objects.filter(expiry__gte=now())
total = LicenseKey(get_license_aud(), 0, "Summarized license", 0, 0)
for lic in active_licenses:
for lic in License.objects.all():
total.internal_users += lic.internal_users
total.external_users += lic.external_users
exp_ts = int(mktime(lic.expiry.timetuple()))
@@ -135,7 +141,7 @@ class LicenseKey:
return User.objects.all().exclude_anonymous().exclude(is_active=False)
@staticmethod
def get_default_user_count():
def get_internal_user_count():
"""Get current default user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.INTERNAL).count()
@@ -144,59 +150,72 @@ class LicenseKey:
"""Get current external user count"""
return LicenseKey.base_user_qs().filter(type=UserTypes.EXTERNAL).count()
def is_valid(self) -> bool:
"""Check if the given license body covers all users
def _last_valid_date(self):
last_valid_date = (
LicenseUsage.objects.order_by("-record_date")
.filter(status=LicenseUsageStatus.VALID)
.first()
)
if not last_valid_date:
return datetime.fromtimestamp(0, UTC)
return last_valid_date.record_date
Only checks the current count, no historical data is checked"""
default_users = self.get_default_user_count()
if default_users > self.internal_users:
return False
active_users = self.get_external_user_count()
if active_users > self.external_users:
return False
return True
def status(self) -> LicenseUsageStatus:
"""Check if the given license body covers all users, and is valid."""
last_valid = self._last_valid_date()
if self.exp == 0 and not License.objects.exists():
return LicenseUsageStatus.UNLICENSED
_now = now()
# Check limit-exceeded based status
internal_users = self.get_internal_user_count()
external_users = self.get_external_user_count()
if internal_users > self.internal_users or external_users > self.external_users:
if last_valid < _now - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS):
return LicenseUsageStatus.READ_ONLY
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_USER
if last_valid < _now - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS):
return LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
# Check expiry based status
if datetime.fromtimestamp(self.exp, UTC) < _now:
if datetime.fromtimestamp(self.exp, UTC) < _now - timedelta(
weeks=THRESHOLD_READ_ONLY_WEEKS
):
return LicenseUsageStatus.READ_ONLY
return LicenseUsageStatus.EXPIRED
# Expiry warning
if datetime.fromtimestamp(self.exp, UTC) <= _now + timedelta(
weeks=THRESHOLD_WARNING_EXPIRY_WEEKS
):
return LicenseUsageStatus.EXPIRY_SOON
return LicenseUsageStatus.VALID
def record_usage(self):
"""Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if not LicenseUsage.objects.filter(record_date__gte=threshold).exists():
LicenseUsage.objects.create(
user_count=self.get_default_user_count(),
usage = (
LicenseUsage.objects.order_by("-record_date").filter(record_date__gte=threshold).first()
)
if not usage:
usage = LicenseUsage.objects.create(
internal_user_count=self.get_internal_user_count(),
external_user_count=self.get_external_user_count(),
within_limits=self.is_valid(),
status=self.status(),
)
summary = asdict(self.summary())
# Also cache the latest summary for the middleware
cache.set(CACHE_KEY_ENTERPRISE_LICENSE, summary, timeout=CACHE_EXPIRY_ENTERPRISE_LICENSE)
return summary
@staticmethod
def last_valid_date() -> datetime:
"""Get the last date the license was valid"""
usage: LicenseUsage = (
LicenseUsage.filter_not_expired(within_limits=True).order_by("-record_date").first()
)
if not usage:
return now()
return usage.record_date
return usage
def summary(self) -> LicenseSummary:
"""Summary of license status"""
has_license = License.objects.all().count() > 0
last_valid = LicenseKey.last_valid_date()
show_admin_warning = last_valid < now() - timedelta(weeks=2)
show_user_warning = last_valid < now() - timedelta(weeks=4)
read_only = last_valid < now() - timedelta(weeks=6)
status = self.status()
latest_valid = datetime.fromtimestamp(self.exp)
return LicenseSummary(
show_admin_warning=show_admin_warning and has_license,
show_user_warning=show_user_warning and has_license,
read_only=read_only and has_license,
latest_valid=latest_valid,
internal_users=self.internal_users,
external_users=self.external_users,
valid=self.is_valid(),
has_license=has_license,
status=status,
)
@staticmethod
@@ -205,4 +224,8 @@ class LicenseKey:
summary = cache.get(CACHE_KEY_ENTERPRISE_LICENSE)
if not summary:
return LicenseKey.get_total().summary()
return from_dict(LicenseSummary, summary)
try:
return from_dict(LicenseSummary, summary)
except DaciteError:
cache.delete(CACHE_KEY_ENTERPRISE_LICENSE)
return LicenseKey.get_total().summary()
+4 -3
View File
@@ -8,6 +8,7 @@ from structlog.stdlib import BoundLogger, get_logger
from authentik.enterprise.api import LicenseViewSet
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import LicenseUsageStatus
from authentik.flows.views.executor import FlowExecutorView
from authentik.lib.utils.reflection import class_to_path
@@ -43,7 +44,7 @@ class EnterpriseMiddleware:
cached_status = LicenseKey.cached_summary()
if not cached_status:
return True
if cached_status.read_only:
if cached_status.status == LicenseUsageStatus.READ_ONLY:
return False
return True
@@ -53,10 +54,10 @@ class EnterpriseMiddleware:
if request.method.lower() in ["get", "head", "options", "trace"]:
return True
# Always allow requests to manage licenses
if class_to_path(request.resolver_match.func) == class_to_path(LicenseViewSet):
if request.resolver_match._func_path == class_to_path(LicenseViewSet):
return True
# Flow executor is mounted as an API path but explicitly allowed
if class_to_path(request.resolver_match.func) == class_to_path(FlowExecutorView):
if request.resolver_match._func_path == class_to_path(FlowExecutorView):
return True
# Only apply these restrictions to the API
if "authentik_api" not in request.resolver_match.app_names:
@@ -0,0 +1,68 @@
# Generated by Django 5.0.8 on 2024-08-08 14:15
from django.db import migrations, models
from django.apps.registry import Apps
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def migrate_license_usage(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
LicenseUsage = apps.get_model("authentik_enterprise", "licenseusage")
db_alias = schema_editor.connection.alias
for usage in LicenseUsage.objects.using(db_alias).all():
usage.status = "valid" if usage.within_limits else "limit_exceeded_admin"
usage.save()
class Migration(migrations.Migration):
dependencies = [
("authentik_enterprise", "0002_rename_users_license_internal_users_and_more"),
]
operations = [
migrations.AddField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
default=None,
null=True,
),
preserve_default=False,
),
migrations.RunPython(migrate_license_usage),
migrations.RemoveField(
model_name="licenseusage",
name="within_limits",
),
migrations.AlterField(
model_name="licenseusage",
name="status",
field=models.TextField(
choices=[
("unlicensed", "Unlicensed"),
("valid", "Valid"),
("expired", "Expired"),
("expiry_soon", "Expiry Soon"),
("limit_exceeded_admin", "Limit Exceeded Admin"),
("limit_exceeded_user", "Limit Exceeded User"),
("read_only", "Read Only"),
],
),
preserve_default=False,
),
migrations.RenameField(
model_name="licenseusage",
old_name="user_count",
new_name="internal_user_count",
),
]
+31 -6
View File
@@ -17,6 +17,17 @@ if TYPE_CHECKING:
from authentik.enterprise.license import LicenseKey
def usage_expiry():
"""Keep license usage records for 3 months"""
return now() + timedelta(days=30 * 3)
THRESHOLD_WARNING_ADMIN_WEEKS = 2
THRESHOLD_WARNING_USER_WEEKS = 4
THRESHOLD_WARNING_EXPIRY_WEEKS = 2
THRESHOLD_READ_ONLY_WEEKS = 6
class License(SerializerModel):
"""An authentik enterprise license"""
@@ -39,7 +50,7 @@ class License(SerializerModel):
"""Get parsed license status"""
from authentik.enterprise.license import LicenseKey
return LicenseKey.validate(self.key)
return LicenseKey.validate(self.key, check_expiry=False)
class Meta:
indexes = (HashIndex(fields=("key",)),)
@@ -47,9 +58,23 @@ class License(SerializerModel):
verbose_name_plural = _("Licenses")
def usage_expiry():
"""Keep license usage records for 3 months"""
return now() + timedelta(days=30 * 3)
class LicenseUsageStatus(models.TextChoices):
"""License states an instance/tenant can be in"""
UNLICENSED = "unlicensed"
VALID = "valid"
EXPIRED = "expired"
EXPIRY_SOON = "expiry_soon"
# User limit exceeded, 2 week threshold, show message in admin interface
LIMIT_EXCEEDED_ADMIN = "limit_exceeded_admin"
# User limit exceeded, 4 week threshold, show message in user interface
LIMIT_EXCEEDED_USER = "limit_exceeded_user"
READ_ONLY = "read_only"
@property
def is_valid(self) -> bool:
"""Quickly check if a license is valid"""
return self in [LicenseUsageStatus.VALID, LicenseUsageStatus.EXPIRY_SOON]
class LicenseUsage(ExpiringModel):
@@ -59,9 +84,9 @@ class LicenseUsage(ExpiringModel):
usage_uuid = models.UUIDField(primary_key=True, editable=False, default=uuid4)
user_count = models.BigIntegerField()
internal_user_count = models.BigIntegerField()
external_user_count = models.BigIntegerField()
within_limits = models.BooleanField()
status = models.TextField(choices=LicenseUsageStatus.choices)
record_date = models.DateTimeField(auto_now_add=True)
+1 -1
View File
@@ -13,7 +13,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
def check_license(self):
"""Check license"""
if not LicenseKey.get_total().is_valid():
if not LicenseKey.get_total().status().is_valid:
return PolicyResult(False, _("Enterprise required to access this feature."))
if self.request.user.type != UserTypes.INTERNAL:
return PolicyResult(False, _("Feature only accessible for internal users."))
+200 -9
View File
@@ -9,10 +9,26 @@ from django.utils.timezone import now
from rest_framework.exceptions import ValidationError
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import License
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
THRESHOLD_WARNING_ADMIN_WEEKS,
THRESHOLD_WARNING_USER_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.lib.generators import generate_id
_exp = int(mktime((now() + timedelta(days=3000)).timetuple()))
# Valid license expiry
expiry_valid = int(mktime((now() + timedelta(days=3000)).timetuple()))
# Valid license expiry, expires soon
expiry_soon = int(mktime((now() + timedelta(hours=10)).timetuple()))
# Invalid license expiry, recently expired
expiry_expired = int(mktime((now() - timedelta(hours=10)).timetuple()))
# Invalid license expiry, expired longer ago
expiry_expired_read_only = int(
mktime((now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)).timetuple())
)
class TestEnterpriseLicense(TestCase):
@@ -23,7 +39,7 @@ class TestEnterpriseLicense(TestCase):
MagicMock(
return_value=LicenseKey(
aud="",
exp=_exp,
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
@@ -33,7 +49,7 @@ class TestEnterpriseLicense(TestCase):
def test_valid(self):
"""Check license verification"""
lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.is_valid())
self.assertTrue(lic.status.status().is_valid)
self.assertEqual(lic.internal_users, 100)
def test_invalid(self):
@@ -46,7 +62,7 @@ class TestEnterpriseLicense(TestCase):
MagicMock(
return_value=LicenseKey(
aud="",
exp=_exp,
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
@@ -56,11 +72,186 @@ class TestEnterpriseLicense(TestCase):
def test_valid_multiple(self):
"""Check license verification"""
lic = License.objects.create(key=generate_id())
self.assertTrue(lic.status.is_valid())
self.assertTrue(lic.status.status().is_valid)
lic2 = License.objects.create(key=generate_id())
self.assertTrue(lic2.status.is_valid())
self.assertTrue(lic2.status.status().is_valid)
total = LicenseKey.get_total()
self.assertEqual(total.internal_users, 200)
self.assertEqual(total.external_users, 200)
self.assertEqual(total.exp, _exp)
self.assertTrue(total.is_valid())
self.assertEqual(total.exp, expiry_valid)
self.assertTrue(total.status().is_valid)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_user_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_USER_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_USER
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_limit_exceeded_admin_warning(self):
"""Check license verification"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_WARNING_ADMIN_WEEKS + 1)
usage.save(update_fields=["record_date"])
self.assertEqual(
LicenseKey.get_total().summary().status, LicenseUsageStatus.LIMIT_EXCEEDED_ADMIN
)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired_read_only,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_read_only(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.READ_ONLY)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_expired,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_expired(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRED)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_soon,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_expiry_soon(self):
"""Check license verification"""
License.objects.create(key=generate_id())
self.assertEqual(LicenseKey.get_total().summary().status, LicenseUsageStatus.EXPIRY_SOON)
@@ -0,0 +1,217 @@
"""read only tests"""
from datetime import timedelta
from unittest.mock import MagicMock, patch
from django.urls import reverse
from django.utils.timezone import now
from authentik.core.tests.utils import create_test_admin_user, create_test_flow, create_test_user
from authentik.enterprise.license import LicenseKey
from authentik.enterprise.models import (
THRESHOLD_READ_ONLY_WEEKS,
License,
LicenseUsage,
LicenseUsageStatus,
)
from authentik.enterprise.tests.test_license import expiry_valid
from authentik.flows.models import (
FlowDesignation,
FlowStageBinding,
)
from authentik.flows.tests import FlowTestCase
from authentik.lib.generators import generate_id
from authentik.stages.identification.models import IdentificationStage, UserFields
from authentik.stages.password import BACKEND_INBUILT
from authentik.stages.password.models import PasswordStage
from authentik.stages.user_login.models import UserLoginStage
class TestReadOnly(FlowTestCase):
"""Test read_only"""
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_login(self):
"""Test flow, ensure login is still possible with read only mode"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
flow = create_test_flow(
FlowDesignation.AUTHENTICATION,
)
ident_stage = IdentificationStage.objects.create(
name=generate_id(),
user_fields=[UserFields.E_MAIL],
pretend_user_exists=False,
)
FlowStageBinding.objects.create(
target=flow,
stage=ident_stage,
order=0,
)
password_stage = PasswordStage.objects.create(
name=generate_id(), backends=[BACKEND_INBUILT]
)
FlowStageBinding.objects.create(
target=flow,
stage=password_stage,
order=1,
)
login_stage = UserLoginStage.objects.create(
name=generate_id(),
)
FlowStageBinding.objects.create(
target=flow,
stage=login_stage,
order=2,
)
user = create_test_user()
exec_url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug})
response = self.client.get(exec_url)
self.assertStageResponse(
response,
flow,
component="ak-stage-identification",
password_fields=False,
primary_action="Log in",
sources=[],
show_source_labels=False,
user_fields=[UserFields.E_MAIL],
)
response = self.client.post(exec_url, {"uid_field": user.email}, follow=True)
self.assertStageResponse(response, flow, component="ak-stage-password")
response = self.client.post(exec_url, {"password": user.username}, follow=True)
self.assertStageRedirects(response, reverse("authentik_core:root-redirect"))
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_licenses(self):
"""Test that managing licenses is still possible"""
license = License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Reading is always allowed
response = self.client.get(reverse("authentik_api:license-list"))
self.assertEqual(response.status_code, 200)
# Writing should also be allowed
response = self.client.patch(
reverse("authentik_api:license-detail", kwargs={"pk": license.pk})
)
self.assertEqual(response.status_code, 200)
@patch(
"authentik.enterprise.license.LicenseKey.validate",
MagicMock(
return_value=LicenseKey(
aud="",
exp=expiry_valid,
name=generate_id(),
internal_users=100,
external_users=100,
)
),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_internal_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.get_external_user_count",
MagicMock(return_value=1000),
)
@patch(
"authentik.enterprise.license.LicenseKey.record_usage",
MagicMock(),
)
def test_manage_flows(self):
"""Test flow"""
License.objects.create(key=generate_id())
usage = LicenseUsage.objects.create(
internal_user_count=100,
external_user_count=100,
status=LicenseUsageStatus.VALID,
)
usage.record_date = now() - timedelta(weeks=THRESHOLD_READ_ONLY_WEEKS + 1)
usage.save(update_fields=["record_date"])
admin = create_test_admin_user()
self.client.force_login(admin)
# Read only is still allowed
response = self.client.get(reverse("authentik_api:flow-list"))
self.assertEqual(response.status_code, 200)
flow = create_test_flow()
# Writing is not
response = self.client.patch(
reverse("authentik_api:flow-detail", kwargs={"slug": flow.slug})
)
self.assertJSONEqual(
response.content,
{"detail": "Request denied due to expired/invalid license.", "code": "denied_license"},
)
self.assertEqual(response.status_code, 400)
+1 -1
View File
@@ -140,7 +140,7 @@ class OutpostHealthSerializer(PassiveSerializer):
def get_fips_enabled(self, obj: dict) -> bool | None:
"""Get FIPS enabled"""
if not LicenseKey.get_total().is_valid():
if not LicenseKey.get_total().status().is_valid:
return None
return obj["fips_enabled"]
+24 -32
View File
@@ -6321,22 +6321,6 @@
"authentik_rbac.edit_system_settings",
"authentik_rbac.view_system_info",
"authentik_rbac.view_system_settings",
"authentik_sources_kerberos.add_groupkerberossourceconnection",
"authentik_sources_kerberos.change_groupkerberossourceconnection",
"authentik_sources_kerberos.delete_groupkerberossourceconnection",
"authentik_sources_kerberos.view_groupkerberossourceconnection",
"authentik_sources_kerberos.add_kerberospropertymapping",
"authentik_sources_kerberos.change_kerberospropertymapping",
"authentik_sources_kerberos.delete_kerberospropertymapping",
"authentik_sources_kerberos.view_kerberospropertymapping",
"authentik_sources_kerberos.add_kerberossource",
"authentik_sources_kerberos.change_kerberossource",
"authentik_sources_kerberos.delete_kerberossource",
"authentik_sources_kerberos.view_kerberossource",
"authentik_sources_kerberos.add_userkerberossourceconnection",
"authentik_sources_kerberos.change_userkerberossourceconnection",
"authentik_sources_kerberos.delete_userkerberossourceconnection",
"authentik_sources_kerberos.view_userkerberossourceconnection",
"authentik_sources_ldap.add_ldapsource",
"authentik_sources_ldap.change_ldapsource",
"authentik_sources_ldap.delete_ldapsource",
@@ -6361,14 +6345,26 @@
"authentik_sources_oauth.change_useroauthsourceconnection",
"authentik_sources_oauth.delete_useroauthsourceconnection",
"authentik_sources_oauth.view_useroauthsourceconnection",
"authentik_sources_plex.add_groupplexsourceconnection",
"authentik_sources_plex.change_groupplexsourceconnection",
"authentik_sources_plex.delete_groupplexsourceconnection",
"authentik_sources_plex.view_groupplexsourceconnection",
"authentik_sources_plex.add_plexsource",
"authentik_sources_plex.change_plexsource",
"authentik_sources_plex.delete_plexsource",
"authentik_sources_plex.view_plexsource",
"authentik_sources_plex.add_plexsourcepropertymapping",
"authentik_sources_plex.change_plexsourcepropertymapping",
"authentik_sources_plex.delete_plexsourcepropertymapping",
"authentik_sources_plex.view_plexsourcepropertymapping",
"authentik_sources_plex.add_plexsourceconnection",
"authentik_sources_plex.add_userplexsourceconnection",
"authentik_sources_plex.change_plexsourceconnection",
"authentik_sources_plex.change_userplexsourceconnection",
"authentik_sources_plex.delete_plexsourceconnection",
"authentik_sources_plex.delete_userplexsourceconnection",
"authentik_sources_plex.view_plexsourceconnection",
"authentik_sources_plex.view_userplexsourceconnection",
"authentik_sources_saml.add_groupsamlsourceconnection",
"authentik_sources_saml.change_groupsamlsourceconnection",
"authentik_sources_saml.delete_groupsamlsourceconnection",
@@ -11984,22 +11980,6 @@
"authentik_rbac.edit_system_settings",
"authentik_rbac.view_system_info",
"authentik_rbac.view_system_settings",
"authentik_sources_kerberos.add_groupkerberossourceconnection",
"authentik_sources_kerberos.change_groupkerberossourceconnection",
"authentik_sources_kerberos.delete_groupkerberossourceconnection",
"authentik_sources_kerberos.view_groupkerberossourceconnection",
"authentik_sources_kerberos.add_kerberospropertymapping",
"authentik_sources_kerberos.change_kerberospropertymapping",
"authentik_sources_kerberos.delete_kerberospropertymapping",
"authentik_sources_kerberos.view_kerberospropertymapping",
"authentik_sources_kerberos.add_kerberossource",
"authentik_sources_kerberos.change_kerberossource",
"authentik_sources_kerberos.delete_kerberossource",
"authentik_sources_kerberos.view_kerberossource",
"authentik_sources_kerberos.add_userkerberossourceconnection",
"authentik_sources_kerberos.change_userkerberossourceconnection",
"authentik_sources_kerberos.delete_userkerberossourceconnection",
"authentik_sources_kerberos.view_userkerberossourceconnection",
"authentik_sources_ldap.add_ldapsource",
"authentik_sources_ldap.change_ldapsource",
"authentik_sources_ldap.delete_ldapsource",
@@ -12024,14 +12004,26 @@
"authentik_sources_oauth.change_useroauthsourceconnection",
"authentik_sources_oauth.delete_useroauthsourceconnection",
"authentik_sources_oauth.view_useroauthsourceconnection",
"authentik_sources_plex.add_groupplexsourceconnection",
"authentik_sources_plex.change_groupplexsourceconnection",
"authentik_sources_plex.delete_groupplexsourceconnection",
"authentik_sources_plex.view_groupplexsourceconnection",
"authentik_sources_plex.add_plexsource",
"authentik_sources_plex.change_plexsource",
"authentik_sources_plex.delete_plexsource",
"authentik_sources_plex.view_plexsource",
"authentik_sources_plex.add_plexsourcepropertymapping",
"authentik_sources_plex.change_plexsourcepropertymapping",
"authentik_sources_plex.delete_plexsourcepropertymapping",
"authentik_sources_plex.view_plexsourcepropertymapping",
"authentik_sources_plex.add_plexsourceconnection",
"authentik_sources_plex.add_userplexsourceconnection",
"authentik_sources_plex.change_plexsourceconnection",
"authentik_sources_plex.change_userplexsourceconnection",
"authentik_sources_plex.delete_plexsourceconnection",
"authentik_sources_plex.delete_userplexsourceconnection",
"authentik_sources_plex.view_plexsourceconnection",
"authentik_sources_plex.view_userplexsourceconnection",
"authentik_sources_saml.add_groupsamlsourceconnection",
"authentik_sources_saml.change_groupsamlsourceconnection",
"authentik_sources_saml.delete_groupsamlsourceconnection",
+12 -14
View File
@@ -41386,28 +41386,26 @@ components:
type: integer
external_users:
type: integer
valid:
type: boolean
show_admin_warning:
type: boolean
show_user_warning:
type: boolean
read_only:
type: boolean
status:
$ref: '#/components/schemas/LicenseSummaryStatusEnum'
latest_valid:
type: string
format: date-time
has_license:
type: boolean
required:
- external_users
- has_license
- internal_users
- latest_valid
- read_only
- show_admin_warning
- show_user_warning
- status
LicenseSummaryStatusEnum:
enum:
- unlicensed
- valid
- expired
- expiry_soon
- limit_exceeded_admin
- limit_exceeded_user
- read_only
type: string
Link:
type: object
description: Returns a single link
@@ -71,6 +71,12 @@ export class AdminInterface extends EnterpriseAwareInterface {
:host([theme="dark"]) .pf-c-page {
--pf-c-page--BackgroundColor: var(--ak-dark-background);
}
ak-enterprise-status {
grid-area: header;
}
ak-admin-sidebar {
grid-area: nav;
}
`,
];
}
@@ -118,6 +124,7 @@ export class AdminInterface extends EnterpriseAwareInterface {
return html` <ak-locale-context>
<div class="pf-c-page">
<ak-enterprise-status interface="admin"></ak-enterprise-status>
<ak-admin-sidebar
class="pf-c-page__sidebar ${classMap(sidebarClasses)}"
></ak-admin-sidebar>
@@ -29,6 +29,7 @@ import {
License,
LicenseForecast,
LicenseSummary,
LicenseSummaryStatusEnum,
RbacPermissionsAssignedByUsersListModelEnum,
} from "@goauthentik/api";
@@ -182,7 +183,7 @@ export class EnterpriseLicenseListPage extends TablePage<License> {
header=${msg("Expiry")}
subtext=${msg("Cumulative license expiry")}
>
${this.summary?.hasLicense
${this.summary?.status === LicenseSummaryStatusEnum.Unlicensed
? html`<div>${getRelativeTime(this.summary.latestValid)}</div>
<small>${this.summary.latestValid.toLocaleString()}</small>`
: "-"}
@@ -4,7 +4,7 @@ import { Constructor } from "@goauthentik/elements/types.js";
import { consume } from "@lit/context";
import type { LitElement } from "lit";
import type { LicenseSummary } from "@goauthentik/api";
import { type LicenseSummary, LicenseSummaryStatusEnum } from "@goauthentik/api";
export function WithLicenseSummary<T extends Constructor<LitElement>>(
superclass: T,
@@ -15,7 +15,7 @@ export function WithLicenseSummary<T extends Constructor<LitElement>>(
public licenseSummary!: LicenseSummary;
get hasEnterpriseLicense() {
return this.licenseSummary?.hasLicense;
return this.licenseSummary?.status !== LicenseSummaryStatusEnum.Unlicensed;
}
}
+56 -57
View File
@@ -138,63 +138,62 @@ export class PageHeader extends WithBrandConfig(AKElement) {
}
render(): TemplateResult {
return html` <ak-enterprise-status interface="admin"></ak-enterprise-status>
<div class="bar">
<button
class="sidebar-trigger pf-c-button pf-m-plain"
@click=${() => {
this.dispatchEvent(
new CustomEvent(EVENT_SIDEBAR_TOGGLE, {
bubbles: true,
composed: true,
}),
);
}}
>
<i class="fas fa-bars"></i>
</button>
<section class="pf-c-page__main-section pf-m-light">
<div class="pf-c-content">
<h1>
<slot name="icon">${this.renderIcon()}</slot>&nbsp;
<slot name="header">${this.header}</slot>
</h1>
${this.description ? html`<p>${this.description}</p>` : html``}
</div>
</section>
<button
class="notification-trigger pf-c-button pf-m-plain"
@click=${() => {
this.dispatchEvent(
new CustomEvent(EVENT_API_DRAWER_TOGGLE, {
bubbles: true,
composed: true,
}),
);
}}
>
<pf-tooltip position="top" content=${msg("Open API drawer")}>
<i class="fas fa-code"></i>
</pf-tooltip>
</button>
<button
class="notification-trigger pf-c-button pf-m-plain ${this.hasNotifications
? "has-notifications"
: ""}"
@click=${() => {
this.dispatchEvent(
new CustomEvent(EVENT_NOTIFICATION_DRAWER_TOGGLE, {
bubbles: true,
composed: true,
}),
);
}}
>
<pf-tooltip position="top" content=${msg("Open Notification drawer")}>
<i class="fas fa-bell"></i>
</pf-tooltip>
</button>
</div>`;
return html`<div class="bar">
<button
class="sidebar-trigger pf-c-button pf-m-plain"
@click=${() => {
this.dispatchEvent(
new CustomEvent(EVENT_SIDEBAR_TOGGLE, {
bubbles: true,
composed: true,
}),
);
}}
>
<i class="fas fa-bars"></i>
</button>
<section class="pf-c-page__main-section pf-m-light">
<div class="pf-c-content">
<h1>
<slot name="icon">${this.renderIcon()}</slot>&nbsp;
<slot name="header">${this.header}</slot>
</h1>
${this.description ? html`<p>${this.description}</p>` : html``}
</div>
</section>
<button
class="notification-trigger pf-c-button pf-m-plain"
@click=${() => {
this.dispatchEvent(
new CustomEvent(EVENT_API_DRAWER_TOGGLE, {
bubbles: true,
composed: true,
}),
);
}}
>
<pf-tooltip position="top" content=${msg("Open API drawer")}>
<i class="fas fa-code"></i>
</pf-tooltip>
</button>
<button
class="notification-trigger pf-c-button pf-m-plain ${this.hasNotifications
? "has-notifications"
: ""}"
@click=${() => {
this.dispatchEvent(
new CustomEvent(EVENT_NOTIFICATION_DRAWER_TOGGLE, {
bubbles: true,
composed: true,
}),
);
}}
>
<pf-tooltip position="top" content=${msg("Open Notification drawer")}>
<i class="fas fa-bell"></i>
</pf-tooltip>
</button>
</div>`;
}
}
@@ -7,6 +7,8 @@ import { customElement, property } from "lit/decorators.js";
import PFBanner from "@patternfly/patternfly/components/Banner/banner.css";
import { LicenseSummaryStatusEnum } from "@goauthentik/api";
@customElement("ak-enterprise-status")
export class EnterpriseStatusBanner extends WithLicenseSummary(AKElement) {
@property()
@@ -17,26 +19,58 @@ export class EnterpriseStatusBanner extends WithLicenseSummary(AKElement) {
}
renderBanner(): TemplateResult {
let message = "";
switch (this.licenseSummary.status) {
case LicenseSummaryStatusEnum.LimitExceededAdmin:
case LicenseSummaryStatusEnum.LimitExceededUser:
message = msg(
"Warning: The current user count has exceeded the configured licenses.",
);
break;
case LicenseSummaryStatusEnum.Expired:
message = msg("Warning: One or more license(s) have expired.");
break;
case LicenseSummaryStatusEnum.ExpirySoon:
message = msg(
"Warning: One or more license(s) will expire within the next 2 weeks.",
);
break;
case LicenseSummaryStatusEnum.ReadOnly:
message = msg(
"Caution: This authentik instance has entered read-only mode due to expired/exceeded licenses.",
);
break;
default:
break;
}
return html`<div
class="pf-c-banner ${this.licenseSummary?.readOnly ? "pf-m-red" : "pf-m-gold"}"
class="pf-c-banner ${this.licenseSummary?.status === LicenseSummaryStatusEnum.ReadOnly
? "pf-m-red"
: "pf-m-gold"}"
>
${msg("Warning: The current user count has exceeded the configured licenses.")}
${message}
<a href="/if/admin/#/enterprise/licenses"> ${msg("Click here for more info.")} </a>
</div>`;
}
render(): TemplateResult {
switch (this.interface.toLowerCase()) {
case "admin":
if (this.licenseSummary?.showAdminWarning || this.licenseSummary?.readOnly) {
switch (this.licenseSummary.status) {
case LicenseSummaryStatusEnum.LimitExceededUser:
if (this.interface.toLowerCase() === "user") {
return this.renderBanner();
}
break;
case "user":
if (this.licenseSummary?.showUserWarning || this.licenseSummary?.readOnly) {
case LicenseSummaryStatusEnum.ExpirySoon:
case LicenseSummaryStatusEnum.Expired:
case LicenseSummaryStatusEnum.LimitExceededAdmin:
if (this.interface.toLowerCase() === "admin") {
return this.renderBanner();
}
break;
case LicenseSummaryStatusEnum.ReadOnly:
return this.renderBanner();
default:
break;
}
return html``;
}
-1
View File
@@ -42,7 +42,6 @@ export class Sidebar extends AKElement {
nav {
display: flex;
flex-direction: column;
max-height: 100vh;
height: 100%;
overflow-y: hidden;
}