"""This file and its contents are licensed under the Apache License 2.0. Please see the included NOTICE for copyright information and LICENSE for a copy of the license.
"""
import logging
import time
from datetime import timedelta
from uuid import uuid4

import ujson as json
from core.utils.contextlog import ContextLog
from csp.middleware import CSPMiddleware
from django.conf import settings
from django.contrib.auth import logout
from django.core.exceptions import MiddlewareNotUsed
from django.core.handlers.base import BaseHandler
from django.http import HttpResponsePermanentRedirect
from django.middleware.common import CommonMiddleware
from django.utils.deprecation import MiddlewareMixin
from django.utils.http import escape_leading_slashes
from rest_framework.permissions import SAFE_METHODS

logger = logging.getLogger(__name__)


def enforce_csrf_checks(func):
    """Enable csrf for specified view func.

    When ``settings.USE_ENFORCE_CSRF_CHECKS`` is truthy the view is wrapped and
    the wrapper is tagged with ``_dont_enforce_csrf_checks = False``; the
    ``DisableCSRF`` middleware copies that flag onto the request. Otherwise the
    view is returned untouched.
    """
    # USE_ENFORCE_CSRF_CHECKS=False is for tests
    if not settings.USE_ENFORCE_CSRF_CHECKS:
        return func

    def wrapper(request, *args, **kwargs):
        return func(request, *args, **kwargs)

    wrapper._dont_enforce_csrf_checks = False
    return wrapper


class DisableCSRF(MiddlewareMixin):
    """Disable csrf for api requests unless the view or the request opts in."""

    def process_view(self, request, callback, *args, **kwargs):
        """Set ``request._dont_enforce_csrf_checks`` before the view runs.

        Precedence: the flag carried by the view callback (see
        ``enforce_csrf_checks``), then the ``enforce_csrf_checks`` GET
        parameter, then the default of ``True`` (checks not enforced).
        """
        if hasattr(callback, '_dont_enforce_csrf_checks'):
            request._dont_enforce_csrf_checks = callback._dont_enforce_csrf_checks
        elif request.GET.get('enforce_csrf_checks'):
            # _dont_enforce_csrf_checks is for test
            request._dont_enforce_csrf_checks = False
        else:
            request._dont_enforce_csrf_checks = True


class HttpSmartRedirectResponse(HttpResponsePermanentRedirect):
    """Marker redirect class, recognised and replaced in ``process_response`` below."""

    pass


class CommonMiddlewareAppendSlashWithoutRedirect(CommonMiddleware):
    """This class converts HttpSmartRedirectResponse to the common response
    of Django view, without redirect. This is necessary to match status_codes
    for urls like /url?q=1 and /url/?q=1. If you don't use it, you will always
    get a permanent redirect on pages without slash.
    """

    response_redirect_class = HttpSmartRedirectResponse

    def __init__(self, *args, **kwargs):
        """Build a private request handler that excludes this middleware."""
        # create django request resolver
        self.handler = BaseHandler()

        # prevent recursive includes: load the middleware chain without ourselves
        old = settings.MIDDLEWARE
        name = self.__module__ + '.' + self.__class__.__name__
        settings.MIDDLEWARE = [i for i in old if i != name]
        try:
            self.handler.load_middleware()
        finally:
            # always restore the global setting, even if loading the chain failed
            settings.MIDDLEWARE = old

        super().__init__(*args, **kwargs)

    def get_full_path_with_slash(self, request):
        """Return the full path of the request with a trailing slash appended
        without Exception in Debug mode
        """
        new_path = request.get_full_path(force_append_slash=True)
        # Prevent construction of scheme relative urls.
        new_path = escape_leading_slashes(new_path)
        return new_path

    def process_response(self, request, response):
        """Serve the slash-appended URL in place instead of redirecting to it.

        If the parent middleware produced an ``HttpSmartRedirectResponse``, the
        request path is rewritten with a trailing slash and dispatched again
        through the private handler; that response is returned.
        """
        response = super().process_response(request, response)

        request.editor_keymap = settings.EDITOR_KEYMAP

        if isinstance(response, HttpSmartRedirectResponse):
            if not request.path.endswith('/'):
                # remove prefix SCRIPT_NAME
                path = request.path[len(settings.FORCE_SCRIPT_NAME) :] if settings.FORCE_SCRIPT_NAME else request.path
                request.path = path + '/'
            # we don't need query string in path_info because it's in request.GET already
            request.path_info = request.path
            response = self.handler.get_response(request)

        return response

    def should_redirect_with_slash(self, request):
        """Override the original method to keep global APPEND_SLASH setting false"""
        return not request.path_info.endswith('/')


class SetSessionUIDMiddleware(CommonMiddleware):
    """Give every session a random ``uid`` the first time it is seen."""

    def process_request(self, request):
        """Store a new UUID4 string under ``session['uid']`` if it is missing."""
        if 'uid' not in request.session:
            request.session['uid'] = str(uuid4())


class ContextLogMiddleware(CommonMiddleware):
    """Send request/response information to ``ContextLog`` for every request."""

    def __init__(self, get_response):
        self.get_response = get_response
        self.log = ContextLog()

    def __call__(self, request):
        """Capture the request body, run the view chain, then log the exchange."""
        # Best effort: parsed JSON first, raw text second, otherwise None.
        body = None
        try:
            body = json.loads(request.body)
        except Exception:
            try:
                body = request.body.decode('utf-8')
            except Exception:
                pass

        self._ensure_server_id(request)

        response = self.get_response(request)
        self.log.send(request=request, response=response, body=body)
        return response

    def process_request(self, request):
        """Attach ``server_id`` to the request if it is not set yet."""
        self._ensure_server_id(request)

    def _ensure_server_id(self, request):
        """Set ``request.server_id`` unless the request already carries one."""
        # hasattr, not `in`: a membership test on the request object iterates
        # over the body stream instead of looking at attributes.
        if not hasattr(request, 'server_id'):
            request.server_id = self.log._get_server_id()


class DatabaseIsLockedRetryMiddleware(CommonMiddleware):
    """Workaround for sqlite performance issues
    we wait and retry request if database is locked"""

    def __init__(self, get_response):
        """Register the middleware only when the configured database is sqlite.

        Raises:
            MiddlewareNotUsed: if ``settings.DJANGO_DB`` is not the sqlite value.
        """
        if settings.DJANGO_DB != settings.DJANGO_DB_SQLITE:
            raise MiddlewareNotUsed()
        self.get_response = get_response

    def __call__(self, request):
        """Re-run the request while the response reports a locked database.

        Retries at most 15 times; the pause starts at 1 second and is
        multiplied by 1.5 after every attempt.
        """
        response = self.get_response(request)
        retries_number = 0
        sleep_time = 1
        backoff = 1.5
        while (
            response.status_code == 500
            and hasattr(response, 'content')
            and b'database-is-locked-error' in response.content
            and retries_number < 15
        ):
            time.sleep(sleep_time)
            response = self.get_response(request)
            retries_number += 1
            sleep_time *= backoff
        return response


class XApiKeySupportMiddleware:
    """Middleware that adds support for the X-Api-Key header, by having its value supersede
    anything that's set in the Authorization header."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        """Rewrite ``X-Api-Key: <key>`` into ``Authorization: Token <key>``."""
        if 'HTTP_X_API_KEY' in request.META:
            request.META['HTTP_AUTHORIZATION'] = f'Token {request.META["HTTP_X_API_KEY"]}'
            del request.META['HTTP_X_API_KEY']

        return self.get_response(request)


class UpdateLastActivityMiddleware(CommonMiddleware):
    """Record user activity on requests whose method is not in ``SAFE_METHODS``."""

    def process_view(self, request, view_func, view_args, view_kwargs):
        """Call ``update_last_activity()`` for authenticated users on unsafe methods."""
        if hasattr(request, 'user') and request.method not in SAFE_METHODS and request.user.is_authenticated:
            request.user.update_last_activity()


class InactivitySessionTimeoutMiddleWare(CommonMiddleware):
    """Log the user out if they have been logged in for too long
    or inactive for too long"""

    # paths that don't count as user activity; entries are either plain path
    # strings or mappings with a 'query' key
    NOT_USER_ACTIVITY_PATHS = []

    def process_request(self, request) -> None:
        """Enforce the maximum session age and refresh the inactivity expiry.

        Limits come from the user's active organization when there is one,
        otherwise from ``settings``. Requests without a non-empty session, an
        authenticated user, or flagged as SCIM/JWT are ignored.
        """
        if (
            not hasattr(request, 'session')
            or request.session.is_empty()
            or not hasattr(request, 'user')
            or not request.user.is_authenticated
            # scim assign request.user implicitly, check CustomSCIMAuthCheckMiddleware
            or (hasattr(request, 'is_scim') and request.is_scim)
            or (hasattr(request, 'is_jwt') and request.is_jwt)
        ):
            return

        current_time = time.time()
        last_login = request.session.get('last_login', 0)

        active_org = request.user.active_organization
        if active_org:
            policy = active_org.session_timeout_policy
            # the policy values are passed as minutes; compare in seconds
            max_session_age = timedelta(minutes=policy.max_session_age).total_seconds()
            max_time_between_activity = timedelta(minutes=policy.max_time_between_activity).total_seconds()
        else:
            max_session_age = settings.MAX_SESSION_AGE
            max_time_between_activity = settings.MAX_TIME_BETWEEN_ACTIVITY

        # Check if this request is too far from when the login happened
        session_age = current_time - last_login
        if session_age > max_session_age:
            # report the limit that was actually applied (org policy or global)
            logger.info('Request is too far from last login %.0f > %s; logout', session_age, max_session_age)
            logout(request)

        # Push the expiry to the max every time a new request is made to a url that indicates user activity
        # but only if it's not a URL we want to ignore
        request_path = str(request.path_info)
        for path in self.NOT_USER_ACTIVITY_PATHS:
            if isinstance(path, str):
                # plain strings are only ever compared for equality
                if path == request_path:
                    return
            elif 'query' in path:
                # NOTE(review): assumes path_info can carry a '?query' part - confirm
                parts = request_path.split('?')
                if len(parts) == 2 and path['query'] in parts[1]:
                    return

        request.session.set_expiry(max_time_between_activity if request.session.get('keep_me_logged_in', True) else 0)


class HumanSignalCspMiddleware(CSPMiddleware):
    """
    Extend CSPMiddleware to support switching report-only CSP to regular CSP.

    For use with core.decorators.override_report_only_csp.
    """

    def process_response(self, request, response):
        """Promote the report-only CSP header to an enforced one when flagged.

        Responses with a truthy ``_override_report_only_csp`` attribute get
        their ``Content-Security-Policy-Report-Only`` value moved to
        ``Content-Security-Policy``; the flag is removed afterwards.
        """
        response = super().process_response(request, response)
        if getattr(response, '_override_report_only_csp', False):
            if csp_policy := response.get('Content-Security-Policy-Report-Only'):
                response['Content-Security-Policy'] = csp_policy
                del response['Content-Security-Policy-Report-Only']
            delattr(response, '_override_report_only_csp')
        return response