Bin
2025-12-16 971a2a12c03b74dd2d7d668b9dbc599f5131bcaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from datetime import timedelta
from typing import Any
 
from annoying.fields import AutoOneToOneField
from django.db import models
from django.utils.translation import gettext_lazy as _
from organizations.models import Organization
from rest_framework_simplejwt.backends import TokenBackend
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.tokens import RefreshToken
from rest_framework_simplejwt.tokens import api_settings as simple_jwt_settings
 
 
class JWTSettings(models.Model):
    """Organization-specific JWT settings for authentication"""
 
    organization = AutoOneToOneField(Organization, related_name='jwt', primary_key=True, on_delete=models.DO_NOTHING)
    api_tokens_enabled = models.BooleanField(
        _('JWT API tokens enabled'),
        default=True,
        help_text='Enable JWT API token authentication for this organization',
    )
    api_token_ttl_days = models.IntegerField(
        _('JWT API token time to live (days)'),
        default=(200 * 365),  # "eternity", 200 years
        help_text='Number of days before JWT API tokens expire',
    )
    legacy_api_tokens_enabled = models.BooleanField(
        _('legacy API tokens enabled'),
        default=False,
        help_text='Enable legacy API token authentication for this organization',
    )
 
    created_at = models.DateTimeField(_('created at'), auto_now_add=True)
    updated_at = models.DateTimeField(_('updated at'), auto_now=True)
 
    def has_permission(self, user):
        return self.organization.has_permission(user)
 
 
class LSTokenBackend(TokenBackend):
    """A custom JWT token backend that truncates tokens before storing in the database.
 
    Extends simlpe jwt's TokenBackend to provide methods for generating both
    truncated tokens (header + payload only) and full tokens (header + payload + signature).
    This preserves privacy of the token by not exposing the signature to the frontend.
    """
 
    def encode(self, payload: dict[str, Any]) -> str:
        """Encode a payload into a truncated JWT token string.
 
        Args:
            payload: Dictionary containing the JWT claims to encode
 
        Returns:
            A truncated JWT string containing only the header and payload portions,
            with the signature section removed
        """
        header, payload, signature = super().encode(payload).split('.')
        return '.'.join([header, payload])
 
    def encode_full(self, payload: dict[str, Any]) -> str:
        """Encode a payload into a complete JWT token string.
 
        Args:
            payload: Dictionary containing the JWT claims to encode
 
        Returns:
            A complete JWT string containing header, payload and signature portions
        """
        return super().encode(payload)
 
 
class LSAPIToken(RefreshToken):
    """API token that utilizes JWT, but stores a truncated version and expires
    based on user settings
 
    This token class extends RefreshToken to provide organization-specific token
    lifetimes and support for truncated tokens. It uses the LSTokenBackend to
    securely store the token (without the signature).
    """
 
    lifetime = timedelta(days=365 * 200)  # "eternity" (200 years)
 
    _token_backend = LSTokenBackend(
        simple_jwt_settings.ALGORITHM,
        simple_jwt_settings.SIGNING_KEY,
        simple_jwt_settings.VERIFYING_KEY,
        simple_jwt_settings.AUDIENCE,
        simple_jwt_settings.ISSUER,
        simple_jwt_settings.JWK_URL,
        simple_jwt_settings.LEEWAY,
        simple_jwt_settings.JSON_ENCODER,
    )
 
    def get_full_jwt(self) -> str:
        """Get the complete JWT token string (including the signature).
 
        Returns:
            The full JWT token string with header, payload and signature
        """
        return self.get_token_backend().encode_full(self.payload)
 
    def blacklist(self):
        """Blacklist this token.
 
        Raises:
            rest_framework_simplejwt.exceptions.TokenError: If the token is already blacklisted.
        """
        self.check_blacklist()
        return super().blacklist()
 
 
class TruncatedLSAPIToken(LSAPIToken):
    """Handles JWT tokens that contain only header and payload (no signature).
    Used when frontend has access to truncated refresh tokens only."""
 
    def __init__(self, token, *args, **kwargs):
        """Initialize a truncated token, ensuring it has exactly 2 parts before adding a dummy signature."""
        # Ensure we have exactly 2 parts (header and payload)
        parts = token.split('.')
        if len(parts) > 2:
            token = '.'.join(parts[:2])
        elif len(parts) < 2:
            raise TokenError('Invalid Label Studio token')
 
        # Add dummy signature with exactly 43 'x' characters to match expected JWT signature length
        token = token + '.' + ('x' * 43)
        super().__init__(token, verify=False, *args, **kwargs)