import logging
from datetime import datetime

from core.permissions import ViewClassPermission, all_permissions
from django.db import transaction
from django.utils import timezone
from django.utils.decorators import method_decorator
from drf_spectacular.utils import extend_schema
from jwt_auth.auth import TokenAuthenticationPhaseout
from jwt_auth.models import LSAPIToken, TruncatedLSAPIToken
from jwt_auth.serializers import (
    JWTSettingsSerializer,
    LSAPITokenCreateSerializer,
    LSAPITokenListSerializer,
    TokenRefreshResponseSerializer,
    TokenRotateResponseSerializer,
)
from rest_framework import generics, status
from rest_framework.authentication import SessionAuthentication
from rest_framework.exceptions import APIException
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response
from rest_framework.settings import api_settings
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.exceptions import TokenBackendError, TokenError
from rest_framework_simplejwt.token_blacklist.models import BlacklistedToken, OutstandingToken
from rest_framework_simplejwt.views import TokenRefreshView, TokenViewBase

logger = logging.getLogger(__name__)


class TokenExistsError(APIException):
    """HTTP 409 raised when creating an API token while the user still has a non-expired, non-blacklisted one."""

    status_code = status.HTTP_409_CONFLICT
    default_detail = 'You already have a valid token. Please revoke it before creating a new one.'
    default_code = 'token_exists'


@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['JWT'],
        summary='Retrieve JWT Settings',
        description='Retrieve JWT settings for the currently active organization.',
        extensions={
            'x-fern-sdk-group-name': 'jwt_settings',
            'x-fern-sdk-method-name': 'get',
            'x-fern-audiences': ['public'],
        },
    ),
)
@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['JWT'],
        summary='Update JWT Settings',
        description='Update JWT settings for the currently active organization.',
        extensions={
            'x-fern-sdk-group-name': 'jwt_settings',
            'x-fern-sdk-method-name': 'update',
            'x-fern-audiences': ['public'],
        },
    ),
)
class JWTSettingsAPI(CreateAPIView):
    """Read (GET) and update (POST) the JWT settings of the requesting user's active organization."""

    serializer_class = JWTSettingsSerializer
    permission_required = ViewClassPermission(
        GET=all_permissions.organizations_view,
        POST=all_permissions.organizations_change,
    )

    def get_object(self):
        """Return the active organization's `jwt` settings object after object-level permission checks."""
        jwt = self.request.user.active_organization.jwt
        self.check_object_permissions(self.request, jwt)
        return jwt

    def get(self, request, *args, **kwargs):
        """Serialize and return the current JWT settings."""
        jwt_settings = self.get_object()
        return Response(self.get_serializer(jwt_settings).data)

    def post(self, request, *args, **kwargs):
        """Validate the payload against the existing settings object, save it, and return the result."""
        jwt_settings = self.get_object()
        # Bind to the existing instance so save() updates it rather than creating a new one.
        serializer = self.get_serializer(data=request.data, instance=jwt_settings)
        serializer.is_valid(raise_exception=True)
        serializer.save()
        return Response(serializer.data)


class DecoratedTokenRefreshView(TokenRefreshView):
    """simple-jwt's TokenRefreshView, overridden only to attach OpenAPI schema metadata to `post`."""

    @extend_schema(
        tags=['JWT'],
        summary='Refresh JWT token',
        description='Get a new access token, using a refresh token.',
        responses={
            status.HTTP_200_OK: TokenRefreshResponseSerializer,
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'refresh',
            'x-fern-audiences': ['public'],
        },
    )
    def post(self, request, *args, **kwargs):
        return super().post(request, *args, **kwargs)


@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['JWT'],
        summary='List API tokens',
        description='List all API tokens for the current user.',
        responses={
            status.HTTP_200_OK: LSAPITokenListSerializer,
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'list',
            'x-fern-audiences': ['public'],
        },
    ),
)
@method_decorator(
    name='post',
    decorator=extend_schema(
        tags=['JWT'],
        summary='Create API token',
        description='Create a new API token for the current user.',
        responses={
            status.HTTP_201_CREATED: LSAPITokenCreateSerializer,
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'create',
            'x-fern-audiences': ['public'],
        },
    ),
)
class LSAPITokenView(generics.ListCreateAPIView):
    """List (GET) the current user's live refresh tokens, or create (POST) a new `LSAPIToken`."""

    permission_required = all_permissions.users_token_any
    token_class = LSAPIToken

    def get_queryset(self):
        """Return all non-expired, non-blacklisted OutstandingTokens for the current user.

        The `list` method handles filtering for refresh tokens (as opposed to access tokens),
        since simple-jwt makes it hard to do this at the DB level.
        """
        # Use Django's clock rather than naive datetime.now(): with USE_TZ enabled, comparing a
        # naive datetime against a DateTimeField triggers a RuntimeWarning and depends on the
        # server's local time matching TIME_ZONE. Evaluate it once so both filters agree.
        now = timezone.now()
        # Notably, if the list of non-expired blacklisted tokens ever gets too long
        # (e.g. users from orgs who have not set a token expiration for their org
        # revoke enough tokens for this to blow up), this will become inefficient.
        # Would be ideal to just add a "blacklisted" attr to our own subclass of
        # OutstandingToken so we can check at that level, or just clean up
        # OutstandingTokens that have been blacklisted every so often.
        current_blacklisted_tokens = BlacklistedToken.objects.filter(token__expires_at__gt=now).values_list(
            'token_id', flat=True
        )
        return OutstandingToken.objects.filter(user_id=self.request.user.id, expires_at__gt=now).exclude(
            id__in=current_blacklisted_tokens
        )

    def list(self, request, *args, **kwargs):
        """Return the user's live tokens that parse as `TruncatedLSAPIToken` and have token_type 'refresh'."""
        all_tokens = self.get_queryset()

        def _maybe_get_token(token: OutstandingToken):
            """Parse a stored token string; return None if simple-jwt rejects it."""
            try:
                return TruncatedLSAPIToken(str(token.token))
            except (TokenError, TokenBackendError) as e:  # expired/invalid token
                logger.debug('JWT API token validation failed: %s', e)
                return None

        # Annoyingly, token_type not stored directly so we have to filter it here.
        # Shouldn't be many unexpired tokens to iterate through.
        token_objects = list(filter(None, [_maybe_get_token(token) for token in all_tokens]))
        refresh_tokens = [tok for tok in token_objects if tok['token_type'] == 'refresh']
        serializer = self.get_serializer(refresh_tokens, many=True)
        data = serializer.data
        return Response(data)

    def get_serializer_class(self):
        """Use the create serializer for POST and the list serializer for everything else."""
        if self.request.method == 'POST':
            return LSAPITokenCreateSerializer
        return LSAPITokenListSerializer

    def perform_create(self, serializer):
        """Issue a new token for the user, refusing (409) if any live outstanding token already exists.

        Raises:
            TokenExistsError: the user already has a non-expired, non-blacklisted OutstandingToken.
        """
        # Check for existing valid tokens.
        # NOTE(review): this counts every row from get_queryset(), whereas list() only shows tokens
        # that parse and have token_type 'refresh' -- a token hidden from list() can still block
        # creation here. Confirm this is intended.
        existing_tokens = self.get_queryset()
        if existing_tokens.exists():
            raise TokenExistsError()

        token = self.token_class.for_user(self.request.user)
        serializer.instance = token


class LSTokenBlacklistView(TokenViewBase):
    """Blacklist a submitted refresh token (204 on success, 404 if invalid or already blacklisted)."""

    _serializer_class = 'jwt_auth.serializers.LSAPITokenBlacklistSerializer'

    @extend_schema(
        tags=['JWT'],
        summary='Blacklist a JWT refresh token',
        description='Adds a JWT refresh token to the blacklist, preventing it from being used to obtain new access tokens.',
        responses={
            status.HTTP_204_NO_CONTENT: 'Token was successfully blacklisted',
            status.HTTP_404_NOT_FOUND: 'Token is already blacklisted',
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'blacklist',
            'x-fern-audiences': ['public'],
        },
    )
    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        try:
            # Notably, simple jwt's serializer (which we inherit from) calls
            # .blacklist() on the token under the hood
            serializer.is_valid(raise_exception=True)
        except TokenError as e:
            logger.error('Token error occurred while trying to blacklist a token: %s', str(e), exc_info=True)
            return Response({'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_404_NOT_FOUND)
        return Response(status=status.HTTP_204_NO_CONTENT)


class LSAPITokenRotateView(TokenViewBase):
    """Blacklist the submitted refresh token and return a freshly issued one for the requesting user."""

    # Have to explicitly set authentication_classes here, due to how auth works in our middleware, request.user is not set
    # properly before executing the view.
    authentication_classes = [JWTAuthentication, TokenAuthenticationPhaseout, SessionAuthentication]
    permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
    permission_required = all_permissions.users_token_any
    _serializer_class = 'jwt_auth.serializers.LSAPITokenRotateSerializer'
    token_class = LSAPIToken

    @extend_schema(
        tags=['JWT'],
        summary='Rotate JWT refresh token',
        description='Creates a new JWT refresh token and blacklists the current one.',
        responses={
            status.HTTP_200_OK: TokenRotateResponseSerializer,
            status.HTTP_400_BAD_REQUEST: 'Invalid token or token already blacklisted',
        },
        extensions={
            'x-fern-sdk-group-name': 'tokens',
            'x-fern-sdk-method-name': 'rotate',
            'x-fern-audiences': ['public'],
        },
    )
    def post(self, request, *args, **kwargs):
        """Rotate the submitted refresh token.

        Returns 200 with `{'refresh': <new full JWT>}` on success, or 400 when the token is invalid
        or cannot be blacklisted.
        """
        serializer = self.get_serializer(data=request.data)
        try:
            serializer.is_valid(raise_exception=True)
        except TokenError:
            # The serializer turns the submitted string into a token object (validated_data['refresh']
            # has .blacklist()), and that conversion can raise TokenError -- LSTokenBlacklistView guards
            # the same call. Without this, an invalid token would surface as a 500 instead of the
            # documented 400.
            return Response({'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_400_BAD_REQUEST)
        current_token = serializer.validated_data['refresh']

        # Revoking the old token and issuing the new one presumably both write to the DB (blacklist
        # / outstanding-token rows); keep them in one transaction so a failure while issuing does not
        # leave the user with the old token revoked and no replacement.
        with transaction.atomic():
            # Blacklist the current token
            try:
                current_token.blacklist()
            except TokenError:
                return Response(
                    {'detail': 'Token is invalid or already blacklisted.'}, status=status.HTTP_400_BAD_REQUEST
                )

            # Create a new token for the user
            new_token = self.create_token(request.user)
        return Response({'refresh': new_token.get_full_jwt()}, status=status.HTTP_200_OK)

    def create_token(self, user):
        """Create a new token for the user.

        Can be overridden by child classes to use different token classes.
        """
        return self.token_class.for_user(user)