from datetime import timedelta
from typing import Any

from annoying.fields import AutoOneToOneField
from django.db import models
from django.utils.translation import gettext_lazy as _
from organizations.models import Organization
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.tokens import RefreshToken
from rest_framework_simplejwt.tokens import api_settings as simple_jwt_settings


class JWTSettings(models.Model):
    """Organization-specific JWT settings for authentication"""

    # The organization is the primary key, so there is at most one settings row
    # per organization; it is reachable from the organization as `.jwt`.
    organization = AutoOneToOneField(Organization, related_name='jwt', primary_key=True, on_delete=models.DO_NOTHING)
    api_tokens_enabled = models.BooleanField(
        _('JWT API tokens enabled'),
        default=True,
        help_text='Enable JWT API token authentication for this organization',
    )
    # The default mirrors LSAPIToken.lifetime below (200 years, i.e. effectively no expiry).
    api_token_ttl_days = models.IntegerField(
        _('JWT API token time to live (days)'),
        default=(200 * 365),  # "eternity", 200 years
        help_text='Number of days before JWT API tokens expire',
    )
    legacy_api_tokens_enabled = models.BooleanField(
        _('legacy API tokens enabled'),
        default=False,
        help_text='Enable legacy API token authentication for this organization',
    )

    created_at = models.DateTimeField(_('created at'), auto_now_add=True)
    updated_at = models.DateTimeField(_('updated at'), auto_now=True)

    def has_permission(self, user):
        """Delegate the permission check for `user` to the owning organization.

        Args:
            user: The user whose access is being checked.

        Returns:
            Whatever `Organization.has_permission(user)` returns.
        """
        return self.organization.has_permission(user)


class LSTokenBackend(TokenBackend):
    """A custom JWT token backend that truncates tokens before storing in the database.

    Extends simple jwt's TokenBackend to provide methods for generating both truncated
    tokens (header + payload only) and full tokens (header + payload + signature).
    This preserves privacy of the token by not exposing the signature to the frontend.
    """

    def encode(self, payload: dict[str, Any]) -> str:
        """Encode a payload into a truncated JWT token string.

        Args:
            payload: Dictionary containing the JWT claims to encode

        Returns:
            A truncated JWT string containing only the header and payload portions,
            with the signature section removed

        Raises:
            ValueError: If the parent backend returns something that is not made
                of exactly three dot-separated segments.
        """
        # Unpack into dedicated names rather than rebinding the `payload` argument;
        # the signature segment is intentionally discarded.
        header, claims, _signature = super().encode(payload).split('.')
        return '.'.join([header, claims])

    def encode_full(self, payload: dict[str, Any]) -> str:
        """Encode a payload into a complete JWT token string.

        Args:
            payload: Dictionary containing the JWT claims to encode

        Returns:
            A complete JWT string containing header, payload and signature portions
        """
        # Bypass this class's truncating `encode` and use the parent implementation.
        return super().encode(payload)


class LSAPIToken(RefreshToken):
    """API token that utilizes JWT, but stores a truncated version and expires based on user settings

    This token class extends RefreshToken to provide organization-specific token lifetimes
    and support for truncated tokens. It uses the LSTokenBackend to securely store
    the token (without the signature).
    """

    lifetime = timedelta(days=365 * 200)  # "eternity" (200 years)

    # Shared backend instance, configured from the project's simplejwt settings.
    _token_backend = LSTokenBackend(
        simple_jwt_settings.ALGORITHM,
        simple_jwt_settings.SIGNING_KEY,
        simple_jwt_settings.VERIFYING_KEY,
        simple_jwt_settings.AUDIENCE,
        simple_jwt_settings.ISSUER,
        simple_jwt_settings.JWK_URL,
        simple_jwt_settings.LEEWAY,
        simple_jwt_settings.JSON_ENCODER,
    )

    def get_full_jwt(self) -> str:
        """Get the complete JWT token string (including the signature).

        Returns:
            The full JWT token string with header, payload and signature
        """
        return self.get_token_backend().encode_full(self.payload)

    def blacklist(self):
        """Blacklist this token.

        Raises:
            rest_framework_simplejwt.exceptions.TokenError: If the token is already blacklisted.
        """
        # Check first so that blacklisting an already-blacklisted token fails
        # instead of passing silently.
        self.check_blacklist()
        return super().blacklist()


class TruncatedLSAPIToken(LSAPIToken):
    """Handles JWT tokens that contain only header and payload (no signature).

    Used when frontend has access to truncated refresh tokens only.
    """

    def __init__(self, token, *args, **kwargs):
        """Initialize a truncated token, ensuring it has exactly 2 parts before adding a dummy signature.

        Args:
            token: Dot-separated token string. Only the first two segments
                (header and payload) are used; any further segments are dropped.
            *args: Extra positional arguments forwarded to the parent initializer.
            **kwargs: Extra keyword arguments forwarded to the parent initializer.

        Raises:
            rest_framework_simplejwt.exceptions.TokenError: If `token` is not a
                string or has fewer than two dot-separated segments.
        """
        # Report non-string input (e.g. None, bytes) as an invalid token rather
        # than letting `.split` fail with AttributeError/TypeError.
        if not isinstance(token, str):
            raise TokenError('Invalid Label Studio token')

        parts = token.split('.')
        if len(parts) < 2:
            raise TokenError('Invalid Label Studio token')

        # Keep header and payload only, then append a dummy signature with exactly
        # 43 'x' characters to match the expected JWT signature length.
        token = '.'.join([*parts[:2], 'x' * 43])

        # Signature verification is always disabled because the signature above is
        # a placeholder. Star-args are placed before the keyword so the call reads
        # in the order the arguments are actually bound.
        # NOTE(review): extra positional args likely collide with `verify` in the
        # parent signature - confirm against simplejwt's Token.__init__.
        super().__init__(token, *args, verify=False, **kwargs)