import logging
from threading import local
from typing import Any

from django.core.signals import request_finished
from django.dispatch import receiver
from django.middleware.common import CommonMiddleware

logger = logging.getLogger(__name__)

_thread_locals = local()


class CurrentContext:
    """Per-thread key/value store for request-scoped state.

    Values live in one of two dictionaries on the module-level
    ``threading.local`` instance:

    * ``job_data`` -- "shared" values, returned by :meth:`get_job_data`.
    * ``data``     -- values stored with ``shared=False``.

    The current request object is stored separately under ``request``.
    """

    @classmethod
    def set(cls, key: str, value: Any, shared: bool = True) -> None:
        """Store ``value`` under ``key`` for the current thread.

        Args:
            key: Name to store the value under.
            value: The value to store.
            shared: When True (default) the value goes into ``job_data``;
                otherwise it goes into ``data``.
        """
        # Both dictionaries are created lazily on first write in a thread.
        if not hasattr(_thread_locals, 'data'):
            _thread_locals.data = {}
        if not hasattr(_thread_locals, 'job_data'):
            _thread_locals.job_data = {}

        target = _thread_locals.job_data if shared else _thread_locals.data
        target[key] = value

    @classmethod
    def get(cls, key: str, default=None):
        """Return the value stored under ``key`` for the current thread.

        ``job_data`` is consulted first, then ``data``; ``default`` is returned
        when the key is present in neither.
        """
        job_data = getattr(_thread_locals, 'job_data', {})
        if key in job_data:
            return job_data[key]
        return getattr(_thread_locals, 'data', {}).get(key, default)

    @classmethod
    def set_request(cls, request):
        """Remember ``request`` for this thread and record its user, if any.

        The user is only recorded when the request exposes a truthy ``user``
        attribute. A request without that attribute is stored as-is instead of
        raising ``AttributeError``.
        """
        _thread_locals.request = request
        user = getattr(request, 'user', None)
        if user:
            cls.set_user(user)

    @classmethod
    def get_organization_id(cls):
        """Return the stored ``organization_id`` or None if not set."""
        return cls.get('organization_id')

    @classmethod
    def set_organization_id(cls, organization_id: int):
        """Store ``organization_id`` as a shared value."""
        cls.set('organization_id', organization_id)

    @classmethod
    def get_user(cls):
        """Return the stored user or None if not set."""
        return cls.get('user')

    @classmethod
    def set_user(cls, user):
        """Store ``user`` and derive dependent state from it.

        Also stores the user's ``active_organization_id`` when it is truthy,
        and caches the FSM feature-flag state for the user.
        """
        cls.set('user', user)
        if getattr(user, 'active_organization_id', None):
            cls.set_organization_id(user.active_organization_id)

        # PERFORMANCE: Cache FSM enabled state at request level when user is set.
        # This allows all downstream code to check a simple boolean property
        # instead of repeatedly calling feature flag checks and possibly having
        # to resolve the user, org and other related objects.
        cls._cache_fsm_enabled_state(user)

    @classmethod
    def set_fsm_disabled(cls, disabled: bool):
        """
        Temporarily disable/enable FSM for the current thread.

        This is useful for test cleanup and bulk operations where FSM state
        tracking is not needed and would cause performance issues.

        Args:
            disabled: True to disable FSM, False to enable it
        """
        cls.set('fsm_disabled', disabled)

    @classmethod
    def is_fsm_disabled(cls) -> bool:
        """
        Check if FSM is disabled for the current thread.

        Returns:
            True if FSM is disabled, False otherwise
        """
        return cls.get('fsm_disabled', False)

    @classmethod
    def _cache_fsm_enabled_state(cls, user):
        """
        Cache the FSM enabled state for this request/thread.

        PERFORMANCE: This is called when the user is set (see ``set_user``).
        It checks the feature flag once and caches the result, so downstream
        code can check a simple boolean instead of repeatedly calling feature
        flag checks.

        Args:
            user: The user to check FSM feature flag for
        """
        try:
            # Only import when needed to avoid circular imports
            from core.feature_flags import flag_set

            # Check feature flag once and cache the result
            fsm_enabled = flag_set('fflag_feat_fit_568_finite_state_management', user=user) if user else False
        except Exception:
            # If feature flag check fails, assume disabled to be safe, but leave
            # a trace so a broken flag backend is diagnosable.
            logger.debug('FSM feature flag check failed; treating FSM as disabled', exc_info=True)
            fsm_enabled = False

        cls.set('fsm_enabled_cached', fsm_enabled)

    @classmethod
    def is_fsm_enabled(cls) -> bool:
        """
        Check if FSM is enabled for the current request/thread.

        PERFORMANCE: Returns the cached value written by
        ``_cache_fsm_enabled_state``, avoiding repeated feature flag lookups.

        Returns:
            True if FSM is enabled, False otherwise (includes manual disable
            via set_fsm_disabled)
        """
        # Check manual override first (for tests and bulk operations)
        if cls.is_fsm_disabled():
            return False

        # Return cached feature flag state (set in _cache_fsm_enabled_state)
        return cls.get('fsm_enabled_cached', False)

    @classmethod
    def get_job_data(cls) -> dict:
        """
        This data will be shared to jobs spawned by the current thread.

        Returns the thread's ``job_data`` dictionary, or a new empty dict when
        nothing has been stored yet.
        """
        return getattr(_thread_locals, 'job_data', {})

    @classmethod
    def clear(cls) -> None:
        """Remove all state stored for the current thread (data, job data, request)."""
        for attr in ('data', 'job_data', 'request'):
            if hasattr(_thread_locals, attr):
                delattr(_thread_locals, attr)

    @classmethod
    def get_request(cls):
        """Return the request stored for this thread, or None."""
        return getattr(_thread_locals, 'request', None)


def get_current_request():
    """returns the request object for this thread"""
    return CurrentContext.get_request()


class ThreadLocalMiddleware(CommonMiddleware):
    """Middleware that records each incoming request in ``CurrentContext``."""

    # NOTE(review): this overrides CommonMiddleware.process_request without
    # calling super(), so the parent's request processing is not run from this
    # class -- confirm that is intended.
    def process_request(self, request):
        """Store ``request`` (and its user, if present) for the current thread."""
        CurrentContext.set_request(request)


@receiver(request_finished)
def clean_request(sender, **kwargs):
    """Clear the thread's ``CurrentContext`` when ``request_finished`` fires."""
    CurrentContext.clear()