"""
Tests for FSM registry functionality.

Tests registry management, state model registration, transition registration,
and related error handling scenarios.
"""
|
|
from typing import Optional
|
from unittest.mock import Mock, patch
|
|
import pytest
|
from django.test import TestCase
|
from fsm.registry import (
|
register_state_model,
|
register_state_transition,
|
state_model_registry,
|
transition_registry,
|
)
|
from fsm.transitions import BaseTransition, TransitionContext
|
|
|
class MockEntity:
    """Lightweight stand-in entity for the registry tests.

    Exposes the handful of attributes the tests need without touching the
    database:

    * ``pk`` and ``id`` -- both set to the ``pk`` constructor argument.
    * ``_meta`` -- a ``Mock`` whose ``model_name`` is ``'testentity'`` and
      whose ``label_lower`` is ``'tests.testentity'``.
    * ``organization_id`` -- always ``1``.
    """

    def __init__(self, pk=1):
        """Build the entity; ``pk`` is mirrored onto both ``pk`` and ``id``."""
        self.pk = pk
        self.id = pk
        # Configure the fake ``_meta`` in one call; these keyword arguments
        # become plain attributes on the Mock.
        self._meta = Mock(
            model_name='testentity',
            label_lower='tests.testentity',
        )
        self.organization_id = 1
|
|
|
class RegistryTests(TestCase):
    """Tests for registry functionality and edge cases.

    Every test works against the ``state_model_registry`` and
    ``transition_registry`` objects imported from ``fsm.registry``. Because
    those objects are shared by the whole test process, ``setUp`` schedules a
    reset of both so that mock registrations made here do not leak into
    whichever test runs next.
    """

    def setUp(self):
        """Create a fresh entity and schedule registry cleanup for the test."""
        super().setUp()
        self.entity = MockEntity()
        # Cleanups run even when the test body fails, so the Mock models
        # registered below never outlive the test that created them.
        # NOTE(review): clear() also drops anything registered outside this
        # class. test_registry_clear_methods already does the same, but
        # confirm no other suite relies on pre-existing registrations.
        self.addCleanup(state_model_registry.clear)
        self.addCleanup(transition_registry.clear)

    def test_registry_state_model_with_denormalizer(self):
        """Test StateModelRegistry with state model that has get_denormalized_fields"""
        expected_fields = {'custom_field': 'denormalized_1'}

        mock_state_model = Mock()
        mock_state_model.__name__ = 'MockStateModel'
        # Stand-in for the get_denormalized_fields classmethod.
        mock_state_model.get_denormalized_fields = Mock(return_value=expected_fields)

        # Register the model (no denormalizer parameter anymore).
        state_model_registry.register_model('testentity', mock_state_model)

        # The registry must hand back the very object that was registered.
        registered_model = state_model_registry.get_model('testentity')
        assert registered_model is not None
        assert registered_model is mock_state_model

        # Call through the registered object so the lookup is part of the check.
        result = registered_model.get_denormalized_fields(self.entity)
        assert result == expected_fields
        mock_state_model.get_denormalized_fields.assert_called_once_with(self.entity)

    def test_registry_denormalizer_error_handling(self):
        """Test error handling when get_denormalized_fields raises an exception"""
        mock_state_model = Mock()
        mock_state_model.__name__ = 'MockStateModel'
        mock_state_model.get_denormalized_fields = Mock(
            side_effect=RuntimeError('Denormalizer failed')
        )

        state_model_registry.register_model('testentity', mock_state_model)
        registered_model = state_model_registry.get_model('testentity')

        # The error raised by the model must reach the caller unchanged.
        with pytest.raises(RuntimeError, match='Denormalizer failed'):
            registered_model.get_denormalized_fields(self.entity)

    def test_registry_overwrite_warning(self):
        """Test that re-registering an entity name logs an overwrite message"""
        first_model = Mock()
        first_model.__name__ = 'MockModel1'
        second_model = Mock()
        second_model.__name__ = 'MockModel2'

        state_model_registry.register_model('testentity', first_model)

        # Patch the logger only around the second registration so that just
        # the overwriting call is captured.
        with patch('fsm.registry.logger') as mock_logger:
            state_model_registry.register_model('testentity', second_model)

        mock_logger.debug.assert_called()
        # Each recorded call unpacks to (positional args, keyword args); calls
        # without positional args carry no message and are skipped.
        logged_messages = [
            str(args[0])
            for args, _kwargs in mock_logger.debug.call_args_list
            if args
        ]
        assert any(
            'Overwriting existing state model' in message
            for message in logged_messages
        ), 'Expected debug log about overwriting existing state model'

    def test_registry_clear_methods(self):
        """Test registry clear methods"""
        mock_state_model = Mock()
        mock_state_model.__name__ = 'MockStateModel'
        state_model_registry.register_model('testentity', mock_state_model)

        class TestTransition(BaseTransition):
            def get_target_state(self, context: Optional[TransitionContext] = None) -> str:
                return 'TEST'

            def transition(self, context):
                return {}

        transition_registry.register('testentity', 'test_transition', TestTransition)

        # Both registries must hold the entries before they are cleared.
        assert state_model_registry.get_model('testentity') is not None
        assert 'test_transition' in transition_registry.get_transitions_for_entity('testentity')

        state_model_registry.clear()
        transition_registry.clear()

        # After clearing, lookups fall back to their "nothing registered" values.
        assert state_model_registry.get_model('testentity') is None
        assert transition_registry.get_transitions_for_entity('testentity') == {}

    def test_registry_decorator_functions(self):
        """Test decorator functions for registration"""

        @register_state_model('decorated_entity')
        class DecoratedStateModel:
            pass

        # The decorator must have registered the class under the given name.
        assert state_model_registry.get_model('decorated_entity') is DecoratedStateModel

        @register_state_transition('decorated_entity', 'decorated_transition')
        class DecoratedTransition(BaseTransition):
            def get_target_state(self, context: Optional[TransitionContext] = None) -> str:
                return 'DECORATED'

            def transition(self, context):
                return {}

        # The transition must be retrievable by entity and transition name.
        transitions = transition_registry.get_transitions_for_entity('decorated_entity')
        assert 'decorated_transition' in transitions
        assert transitions['decorated_transition'] is DecoratedTransition
|