Bin
2025-12-17 05a69820e0c402b0b33c063d3b922f0a0571cbbb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import logging
from datetime import datetime
 
from core.permissions import ViewClassPermission, all_permissions
from django.utils.decorators import method_decorator
from drf_spectacular.utils import extend_schema
from jwt_auth.auth import TokenAuthenticationPhaseout
from jwt_auth.models import LSAPIToken, TruncatedLSAPIToken
from jwt_auth.serializers import (
    JWTSettingsSerializer,
    LSAPITokenCreateSerializer,
    LSAPITokenListSerializer,
    TokenRefreshResponseSerializer,
    TokenRotateResponseSerializer,
)
from rest_framework import generics, status
from rest_framework.authentication import SessionAuthentication
from rest_framework.exceptions import APIException
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import TokenBackendError, TokenError
from rest_framework_simplejwt.token_blacklist.models import BlacklistedToken, OutstandingToken
from rest_framework_simplejwt.views import TokenRefreshView, TokenViewBase
 
logger = logging.getLogger(__name__)
 
 
class TokenExistsError(APIException):
    status_code = status.HTTP_409_CONFLICT
    default_detail = 'You already have a valid token. Please revoke it before creating a new one.'
    default_code = 'token_exists'
 
 
@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['JWT'],
        summary='Retrieve JWT Settings',
        description='Retrieve JWT settings for the currently active organization.',
        extensions={
            'x-fern-sdk-group-name': 'jwt_settings',
            'x-fern-sdk-method-name': 'get',
            'x-fern-audiences': ['public'],
        },
    ),
)
@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['JWT'],
        summary='Update JWT Settings',
        description='Update JWT settings for the currently active organization.',
        extensions={
            'x-fern-sdk-group-name': 'jwt_settings',
            'x-fern-sdk-method-name': 'update',
            'x-fern-audiences': ['public'],
        },
    ),
)
class JWTSettingsAPI(CreateAPIView):
    serializer_class = JWTSettingsSerializer
    permission_required = ViewClassPermission(
        GET=all_permissions.organizations_view,
        POST=all_permissions.organizations_change,
    )
 
    def get_object(self):
        jwt = self.request.user.active_organization.jwt
        self.check_object_permissions(self.request, jwt)
        return jwt
 
    def get(self, request, *args, **kwargs):
        jwt_settings = self.get_object()
        return Response(self.get_serializer(jwt_settings).data)
 
    def post(self, request, *args, **kwargs):
        jwt_settings = self.get_object()
        serializer = self.get_serializer(data=request.data, instance=jwt_settings)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)
 
 
class DecoratedTokenRefreshView(TokenRefreshView):
    @extend_schema(
        tags=['JWT'],
        summary='Refresh JWT token',
        description='Get a new access token, using a refresh token.',
        responses={
            status.HTTP_200_OK: TokenRefreshResponseSerializer,
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'refresh',
            'x-fern-audiences': ['public'],
        },
    )
    def post(self, request, *args, **kwargs):
        return super().post(request, *args, **kwargs)
 
 
@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['JWT'],
        summary='List API tokens',
        description='List all API tokens for the current user.',
        responses={
            status.HTTP_200_OK: LSAPITokenListSerializer,
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'list',
            'x-fern-audiences': ['public'],
        },
    ),
)
@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['JWT'],
        summary='Create API token',
        description='Create a new API token for the current user.',
        responses={
            status.HTTP_201_CREATED: LSAPITokenCreateSerializer,
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'create',
            'x-fern-audiences': ['public'],
        },
    ),
)
class LSAPITokenView(generics.ListCreateAPIView):
    permission_required = all_permissions.users_token_any
    token_class = LSAPIToken
 
    def get_queryset(self):
        """Returns all non-expired non-blacklisted tokens for the current user.
 
        The `list` method handles filtering for refresh tokens (as opposed to access tokens),
        since simple-jwt makes it hard to do this at the DB level."""
        # Notably, if the list of non-expired blacklisted tokens ever gets too long
        # (e.g. users from orgs who have not set a token expiration for their org
        # revoke enough tokens for this to blow up), this will become inefficient.
        # Would be ideal to just add a "blacklisted" attr to our own subclass of
        # OutstandingToken so we can check at that level, or just clean up
        # OutstandingTokens that have been blacklisted every so often.
        current_blacklisted_tokens = BlacklistedToken.objects.filter(token__expires_at__gt=datetime.now()).values_list(
            'token_id', flat=True
        )
        return OutstandingToken.objects.filter(user_id=self.request.user.id, expires_at__gt=datetime.now()).exclude(
            id__in=current_blacklisted_tokens
        )
 
    def list(self, request, *args, **kwargs):
        all_tokens = self.get_queryset()
 
        def _maybe_get_token(token: OutstandingToken):
            try:
                return TruncatedLSAPIToken(str(token.token))
            except (TokenError, TokenBackendError) as e:  # expired/invalid token
                logger.debug('JWT API token validation failed: %s', e)
                return None
 
        # Annoyingly, token_type not stored directly so we have to filter it here.
        # Shouldn't be many unexpired tokens to iterate through.
        token_objects = list(filter(None, [_maybe_get_token(token) for token in all_tokens]))
        refresh_tokens = [tok for tok in token_objects if tok['token_type'] == 'refresh']
 
        serializer = self.get_serializer(refresh_tokens, many=True)
        data = serializer.data
        return Response(data)
 
    def get_serializer_class(self):
        if self.request.method == 'POST':
            return LSAPITokenCreateSerializer
        return LSAPITokenListSerializer
 
    def perform_create(self, serializer):
        # Check for existing valid tokens
        existing_tokens = self.get_queryset()
        if existing_tokens.exists():
            raise TokenExistsError()
 
        token = self.token_class.for_user(self.request.user)
        serializer.instance = token
 
 
class LSTokenBlacklistView(TokenViewBase):
    _serializer_class = 'jwt_auth.serializers.LSAPITokenBlacklistSerializer'
 
    @extend_schema(
        tags=['JWT'],
        summary='Blacklist a JWT refresh token',
        description='Adds a JWT refresh token to the blacklist, preventing it from being used to obtain new access tokens.',
        responses={
            status.HTTP_204_NO_CONTENT: 'Token was successfully blacklisted',
            status.HTTP_404_NOT_FOUND: 'Token is already blacklisted',
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'blacklist',
            'x-fern-audiences': ['public'],
        },
    )
    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        try:
            # Notably, simple jwt's serializer (which we inherit from) calls
            # .blacklist() on the token under the hood
            serializer.is_valid(raise_exception=True)
        except TokenError as e:
            logger.error('Token error occurred while trying to blacklist a token: %s', str(e), exc_info=True)
            return Response({'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_404_NOT_FOUND)
 
        return Response(status=status.HTTP_204_NO_CONTENT)
 
 
class LSAPITokenRotateView(TokenViewBase):
    # Have to explicitly set authentication_classes here, due to how auth works in our middleware, request.user is not set
    # properly before executing the view.
    authentication_classes = [JWTAuthentication, TokenAuthenticationPhaseout, SessionAuthentication]
    permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
    permission_required = all_permissions.users_token_any
    _serializer_class = 'jwt_auth.serializers.LSAPITokenRotateSerializer'
    token_class = LSAPIToken
 
    @extend_schema(
        tags=['JWT'],
        summary='Rotate JWT refresh token',
        description='Creates a new JWT refresh token and blacklists the current one.',
        responses={
            status.HTTP_200_OK: TokenRotateResponseSerializer,
            status.HTTP_400_BAD_REQUEST: 'Invalid token or token already blacklisted',
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'rotate',
            'x-fern-audiences': ['public'],
        },
    )
    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        current_token = serializer.validated_data['refresh']
 
        # Blacklist the current token
        try:
            current_token.blacklist()
        except TokenError:
            return Response({'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_400_BAD_REQUEST)
 
        # Create a new token for the user
        new_token = self.create_token(request.user)
        return Response({'refresh': new_token.get_full_jwt()}, status=status.HTTP_200_OK)
 
    def create_token(self, user):
        """Create a new token for the user. Can be overridden by child classes to use different token classes."""
        return self.token_class.for_user(user)