"""Tests for enum subset validation in REQUIRED_INPUTS. This module tests the enhanced validation logic that supports restricting enum fields to specific subsets using Literal types. """ from typing import Any, Literal from sentinel.risk_models.base import RiskModel from sentinel.user_input import ( Anthropometrics, Demographics, Ethnicity, Lifestyle, PersonalMedicalHistory, Sex, SmokingHistory, SmokingStatus, UserInput, ) class EnumValidationTestModel(RiskModel): """Test risk model with various enum restrictions for validation testing.""" def __init__(self): super().__init__("test_enum_validation") # Test cases for different enum restriction patterns REQUIRED_INPUTS: dict[str, tuple[type | Any, bool]] = { # Single enum value restriction "demographics.sex": (Literal[Sex.FEMALE], True), # Multiple enum value restriction (subset) "demographics.ethnicity": ( Literal[Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN] | None, False, ), } def compute_score(self, user: UserInput) -> str: """Test implementation. Args: user: The user profile to score. Returns: A test score string. """ return "test_score" def cancer_type(self) -> str: return "test" def description(self) -> str: return "Test model" def interpretation(self) -> str: return "Test interpretation" def references(self) -> list[str]: return ["Test reference"] def time_horizon_years(self) -> float | None: return None class TestEnumSubsetValidation: """Test enum subset validation functionality.""" def setup_method(self): """Set up test model.""" self.model = EnumValidationTestModel() def _create_user_input( self, sex: Sex, ethnicity: Ethnicity | None = None ) -> UserInput: """Create a valid UserInput instance for testing. Args: sex: The biological sex for the user. ethnicity: The ethnicity for the user (optional). Returns: A valid UserInput instance for testing. """ return UserInput( demographics=Demographics( age_years=40, sex=sex, ethnicity=ethnicity, anthropometrics=Anthropometrics(height_cm=165.0, weight_kg=65.0), ), lifestyle=Lifestyle( smoking=SmokingHistory(status=SmokingStatus.NEVER), ), personal_medical_history=PersonalMedicalHistory(), ) def test_single_enum_value_restriction_valid(self): """Test that valid single enum value passes validation.""" user = self._create_user_input(Sex.FEMALE, Ethnicity.WHITE) is_valid, errors = self.model.validate_inputs(user) assert is_valid assert len(errors) == 0 def test_single_enum_value_restriction_invalid(self): """Test that invalid single enum value fails validation with clear message.""" user = self._create_user_input(Sex.MALE, Ethnicity.WHITE) # Should be FEMALE is_valid, errors = self.model.validate_inputs(user) assert not is_valid assert len(errors) == 1 assert "Field 'demographics.sex': must be FEMALE" in errors[0] def test_multiple_enum_value_restriction_valid(self): """Test that valid enum values from subset pass validation.""" valid_ethnicities = [Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN] for ethnicity in valid_ethnicities: user = self._create_user_input(Sex.FEMALE, ethnicity) is_valid, errors = self.model.validate_inputs(user) assert is_valid, f"Failed for ethnicity: {ethnicity}" assert len(errors) == 0 def test_multiple_enum_value_restriction_invalid(self): """Test that invalid enum values fail validation with clear message.""" invalid_ethnicities = [ Ethnicity.HISPANIC, Ethnicity.ASHKENAZI_JEWISH, Ethnicity.NATIVE_AMERICAN, Ethnicity.PACIFIC_ISLANDER, Ethnicity.OTHER, Ethnicity.UNKNOWN, ] for ethnicity in invalid_ethnicities: user = self._create_user_input(Sex.FEMALE, ethnicity) is_valid, errors = self.model.validate_inputs(user) assert not is_valid, f"Should have failed for ethnicity: {ethnicity}" assert len(errors) == 1 assert "Field 'demographics.ethnicity': Input should be" in errors[0] assert ( "WHITE" in errors[0] and "BLACK" in errors[0] and "ASIAN" in errors[0] ) def test_optional_enum_field_with_none(self): """Test that None values are handled correctly for optional enum fields.""" user = self._create_user_input(Sex.FEMALE, None) # Optional field is_valid, errors = self.model.validate_inputs(user) assert is_valid assert len(errors) == 0 def test_missing_required_enum_field(self): """Test that missing required enum fields are caught.""" # Create a model that requires a field that's not in the user input class MissingFieldModel(RiskModel): """Test model for missing field validation.""" def __init__(self): super().__init__("missing_field_test") REQUIRED_INPUTS: dict[str, tuple[Any, bool]] = { "demographics.sex": (Literal[Sex.FEMALE], True), "demographics.ethnicity": (Ethnicity | None, False), "demographics.nonexistent_field": ( str, True, ), # This field doesn't exist } def compute_score(self, user: UserInput) -> str: return "test" def cancer_type(self) -> str: return "test" def description(self) -> str: return "test" def interpretation(self) -> str: return "test" def references(self) -> list[str]: return ["test"] def time_horizon_years(self) -> float | None: return None model = MissingFieldModel() user = self._create_user_input(Sex.FEMALE, Ethnicity.WHITE) is_valid, errors = model.validate_inputs(user) assert not is_valid assert len(errors) == 1 assert "Required field 'demographics.nonexistent_field' is missing" in errors[0] def test_multiple_validation_errors(self): """Test that multiple validation errors are reported.""" user = self._create_user_input(Sex.MALE, Ethnicity.HISPANIC) # Both wrong is_valid, errors = self.model.validate_inputs(user) assert not is_valid assert len(errors) == 2 # Check that both errors are present error_messages = " ".join(errors) assert "must be FEMALE" in error_messages assert "Input should be" in error_messages assert ( "WHITE" in error_messages and "BLACK" in error_messages and "ASIAN" in error_messages ) def test_literal_enum_type_detection(self): """Test the _is_literal_enum_type helper method.""" # Test Literal with single enum value single_literal = Literal[Sex.FEMALE] assert self.model._is_literal_enum_type(single_literal) # Test Literal with multiple enum values multi_literal = Literal[Ethnicity.WHITE, Ethnicity.BLACK] assert self.model._is_literal_enum_type(multi_literal) # Test non-Literal types assert not self.model._is_literal_enum_type(Sex) assert not self.model._is_literal_enum_type(int) assert not self.model._is_literal_enum_type(str) def test_extract_literal_enum_values(self): """Test the _extract_literal_enum_values helper method.""" # Test single enum value single_literal = Literal[Sex.FEMALE] values = self.model._extract_literal_enum_values(single_literal) assert values == ["FEMALE"] # Test multiple enum values multi_literal = Literal[Ethnicity.WHITE, Ethnicity.BLACK, Ethnicity.ASIAN] values = self.model._extract_literal_enum_values(multi_literal) assert set(values) == {"WHITE", "BLACK", "ASIAN"} def test_backward_compatibility_unrestricted_enum(self): """Test that unrestricted enum types still work (backward compatibility).""" # Create a model with unrestricted enum class UnrestrictedModel(RiskModel): """Test model for backward compatibility with unrestricted enums.""" def __init__(self): super().__init__("unrestricted_test") REQUIRED_INPUTS: dict[str, tuple[type | Any, bool]] = { "demographics.sex": (Sex, True), "demographics.ethnicity": (Ethnicity | None, False), } def compute_score(self, user: UserInput) -> str: return "test" def cancer_type(self) -> str: return "test" def description(self) -> str: return "test" def interpretation(self) -> str: return "test" def references(self) -> list[str]: return ["test"] def time_horizon_years(self) -> float | None: return None model = UnrestrictedModel() # Test with any valid enum values user = self._create_user_input( Sex.MALE, Ethnicity.HISPANIC ) # Any values should work is_valid, errors = model.validate_inputs(user) assert is_valid assert len(errors) == 0