143 lines
5.0 KiB
Python
143 lines
5.0 KiB
Python
"""
Tests for BaseRepository.

100% coverage tests for the base repository utilities.
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
from uuid import uuid4, UUID
|
|
|
|
from inference.data.repositories.base import BaseRepository
|
|
|
|
|
|
class ConcreteRepository(BaseRepository[MagicMock]):
    """Minimal concrete subclass used to exercise BaseRepository's helpers.

    BaseRepository is parameterized by an entity type; MagicMock stands in for
    a real model class because these tests only ever pass mocked entities.
    """
|
|
|
|
|
|
class TestBaseRepository:
    """Tests for the shared helper methods provided by BaseRepository."""

    @pytest.fixture
    def repo(self) -> ConcreteRepository:
        """Create a fresh ConcreteRepository instance for each test."""
        return ConcreteRepository()

    # =========================================================================
    # _session() tests
    # =========================================================================

    def test_session_yields_session(self, repo):
        """Test _session yields the session from get_session_context and exits it."""
        with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
            mock_session = MagicMock()
            session_cm = mock_ctx.return_value
            session_cm.__enter__.return_value = mock_session
            # Returning False means exceptions raised inside the block propagate.
            session_cm.__exit__.return_value = False

            with repo._session() as session:
                assert session is mock_session

            # The underlying session context must be opened once and exited
            # once, so the session is not left open after the with-block.
            mock_ctx.assert_called_once()
            session_cm.__exit__.assert_called_once()

    # =========================================================================
    # _expunge() tests
    # =========================================================================

    def test_expunge_detaches_entity(self, repo):
        """Test _expunge detaches the entity from the session and returns it."""
        mock_session = MagicMock()
        mock_entity = MagicMock()

        result = repo._expunge(mock_session, mock_entity)

        mock_session.expunge.assert_called_once_with(mock_entity)
        assert result is mock_entity

    # =========================================================================
    # _expunge_all() tests
    # =========================================================================

    def test_expunge_all_detaches_all_entities(self, repo):
        """Test _expunge_all detaches every entity and returns the same list."""
        mock_session = MagicMock()
        mock_entity1 = MagicMock()
        mock_entity2 = MagicMock()
        entities = [mock_entity1, mock_entity2]

        result = repo._expunge_all(mock_session, entities)

        assert mock_session.expunge.call_count == 2
        mock_session.expunge.assert_any_call(mock_entity1)
        mock_session.expunge.assert_any_call(mock_entity2)
        assert result is entities

    def test_expunge_all_empty_list(self, repo):
        """Test _expunge_all with an empty list expunges nothing."""
        mock_session = MagicMock()
        entities = []

        result = repo._expunge_all(mock_session, entities)

        mock_session.expunge.assert_not_called()
        # Same identity contract as the non-empty case: the input list is returned.
        assert result is entities

    # =========================================================================
    # _now() tests
    # =========================================================================

    def test_now_returns_utc_datetime(self, repo):
        """Test _now returns a timezone-aware UTC datetime."""
        result = repo._now()

        # Check the type first so a wrong return type fails with a clear message
        # instead of an AttributeError on .tzinfo.
        assert isinstance(result, datetime)
        assert result.tzinfo == timezone.utc

    def test_now_is_recent(self, repo):
        """Test _now returns a datetime between two surrounding clock reads."""
        before = datetime.now(timezone.utc)
        result = repo._now()
        after = datetime.now(timezone.utc)

        assert before <= result <= after

    # =========================================================================
    # _validate_uuid() tests
    # =========================================================================

    def test_validate_uuid_with_valid_string(self, repo):
        """Test _validate_uuid parses a valid UUID string into a UUID."""
        valid_uuid_str = str(uuid4())

        result = repo._validate_uuid(valid_uuid_str)

        assert isinstance(result, UUID)
        assert str(result) == valid_uuid_str

    def test_validate_uuid_with_invalid_string(self, repo):
        """Test _validate_uuid raises ValueError for a malformed UUID string."""
        with pytest.raises(ValueError, match="Invalid id"):
            repo._validate_uuid("not-a-valid-uuid")

    def test_validate_uuid_with_custom_field_name(self, repo):
        """Test _validate_uuid uses the custom field name in the error message."""
        with pytest.raises(ValueError, match="Invalid document_id"):
            repo._validate_uuid("invalid", field_name="document_id")

    def test_validate_uuid_with_none(self, repo):
        """Test _validate_uuid raises ValueError for None."""
        with pytest.raises(ValueError, match="Invalid id"):
            repo._validate_uuid(None)

    def test_validate_uuid_with_empty_string(self, repo):
        """Test _validate_uuid raises ValueError for an empty string."""
        with pytest.raises(ValueError, match="Invalid id"):
            repo._validate_uuid("")
|