This commit is contained in:
Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

View File

@@ -0,0 +1 @@
"""Tests for repository pattern implementation."""

View File

@@ -0,0 +1,711 @@
"""
Tests for AnnotationRepository
100% coverage tests for annotation management.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
from inference.data.repositories.annotation_repository import AnnotationRepository
class TestAnnotationRepository:
    """Unit tests for AnnotationRepository running against a mocked session.

    Tests that request the ``mock_session`` fixture run with
    ``get_session_context`` patched inside the repository module, so the
    repository receives a ``MagicMock`` instead of a real database session.
    """

    # Dotted path of the session factory as it is looked up by the repository
    # module; this is the name that has to be patched.
    _SESSION_CTX = (
        "inference.data.repositories.annotation_repository.get_session_context"
    )

    @pytest.fixture
    def sample_annotation(self) -> AdminAnnotation:
        """Create a sample annotation for testing."""
        return AdminAnnotation(
            annotation_id=uuid4(),
            document_id=uuid4(),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=100,
            bbox_y=200,
            bbox_width=150,
            bbox_height=30,
            text_value="INV-001",
            confidence=0.95,
            source="auto",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def sample_history(self) -> AnnotationHistory:
        """Create a sample annotation history record for testing."""
        return AnnotationHistory(
            history_id=uuid4(),
            annotation_id=uuid4(),
            document_id=uuid4(),
            action="override",
            previous_value={"class_name": "old_class"},
            new_value={"class_name": "new_class"},
            changed_by="admin-token",
            change_reason="Correction",
            created_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def repo(self) -> AnnotationRepository:
        """Create an AnnotationRepository instance."""
        return AnnotationRepository()

    @pytest.fixture
    def mock_session(self):
        """Patch the repository's session context and yield the fake session.

        The patch stays active for the whole test, so assertions on the
        yielded ``MagicMock`` can be made at any point in the test body.
        """
        with patch(self._SESSION_CTX) as mock_ctx:
            session = MagicMock()
            mock_ctx.return_value.__enter__ = MagicMock(return_value=session)
            mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
            yield session

    @staticmethod
    def _invoice_number_payload(**extra) -> dict:
        """Return the "invoice_number" field values shared by several tests.

        A fresh random ``document_id`` is generated on every call; keyword
        arguments are merged on top of the base values.
        """
        payload = {
            "document_id": str(uuid4()),
            "class_id": 0,
            "class_name": "invoice_number",
            "x_center": 0.5,
            "y_center": 0.3,
            "width": 0.2,
            "height": 0.05,
            "bbox_x": 100,
            "bbox_y": 200,
            "bbox_width": 150,
            "bbox_height": 30,
        }
        payload.update(extra)
        return payload

    # =========================================================================
    # create() tests
    # =========================================================================
    def test_create_returns_annotation_id(self, repo, mock_session):
        """Test create returns annotation ID."""
        result = repo.create(**self._invoice_number_payload(page_number=1))

        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.flush.assert_called_once()

    def test_create_with_optional_params(self, repo, mock_session):
        """Test create with optional text_value, confidence and source."""
        result = repo.create(
            document_id=str(uuid4()),
            page_number=2,
            class_id=1,
            class_name="invoice_date",
            x_center=0.6,
            y_center=0.4,
            width=0.15,
            height=0.04,
            bbox_x=200,
            bbox_y=300,
            bbox_width=100,
            bbox_height=25,
            text_value="2024-01-15",
            confidence=0.88,
            source="auto",
        )

        assert result is not None
        mock_session.add.assert_called_once()
        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.text_value == "2024-01-15"
        assert added_annotation.confidence == 0.88
        assert added_annotation.source == "auto"

    def test_create_default_source_is_manual(self, repo, mock_session):
        """Test create uses manual as default source."""
        repo.create(**self._invoice_number_payload(page_number=1))

        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.source == "manual"

    # =========================================================================
    # create_batch() tests
    # =========================================================================
    def test_create_batch_returns_ids(self, repo, mock_session):
        """Test create_batch returns list of annotation IDs."""
        annotations = [
            self._invoice_number_payload(),
            {
                "document_id": str(uuid4()),
                "class_id": 1,
                "class_name": "invoice_date",
                "x_center": 0.6,
                "y_center": 0.4,
                "width": 0.15,
                "height": 0.04,
                "bbox_x": 200,
                "bbox_y": 300,
                "bbox_width": 100,
                "bbox_height": 25,
            },
        ]

        result = repo.create_batch(annotations)

        assert len(result) == 2
        assert mock_session.add.call_count == 2
        assert mock_session.flush.call_count == 2

    def test_create_batch_default_page_number(self, repo, mock_session):
        """Test create_batch uses page_number=1 by default."""
        # The payload deliberately carries no "page_number" key.
        repo.create_batch([self._invoice_number_payload()])

        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.page_number == 1

    def test_create_batch_with_all_optional_params(self, repo, mock_session):
        """Test create_batch with all optional parameters."""
        annotations = [
            self._invoice_number_payload(
                page_number=3,
                text_value="INV-123",
                confidence=0.92,
                source="ocr",
            ),
        ]

        repo.create_batch(annotations)

        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.page_number == 3
        assert added_annotation.text_value == "INV-123"
        assert added_annotation.confidence == 0.92
        assert added_annotation.source == "ocr"

    def test_create_batch_empty_list(self, repo, mock_session):
        """Test create_batch with empty list returns empty."""
        result = repo.create_batch([])

        assert result == []
        mock_session.add.assert_not_called()

    # =========================================================================
    # get() tests
    # =========================================================================
    def test_get_returns_annotation(self, repo, sample_annotation, mock_session):
        """Test get returns annotation when exists."""
        mock_session.get.return_value = sample_annotation

        result = repo.get(str(sample_annotation.annotation_id))

        assert result is not None
        assert result.class_name == "invoice_number"
        mock_session.expunge.assert_called_once()

    def test_get_returns_none_when_not_found(self, repo, mock_session):
        """Test get returns None when annotation not found."""
        mock_session.get.return_value = None

        result = repo.get(str(uuid4()))

        assert result is None
        mock_session.expunge.assert_not_called()

    # =========================================================================
    # get_for_document() tests
    # =========================================================================
    def test_get_for_document_returns_all_annotations(
        self, repo, sample_annotation, mock_session
    ):
        """Test get_for_document returns all annotations for document."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.get_for_document(str(sample_annotation.document_id))

        assert len(result) == 1
        assert result[0].class_name == "invoice_number"

    def test_get_for_document_with_page_filter(
        self, repo, sample_annotation, mock_session
    ):
        """Test get_for_document accepts a page_number filter."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.get_for_document(
            str(sample_annotation.document_id), page_number=1
        )

        assert len(result) == 1

    def test_get_for_document_returns_empty_list(self, repo, mock_session):
        """Test get_for_document returns empty list when no annotations."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_for_document(str(uuid4()))

        assert result == []

    # =========================================================================
    # update() tests
    # =========================================================================
    def test_update_returns_true(self, repo, sample_annotation, mock_session):
        """Test update returns True when annotation exists."""
        mock_session.get.return_value = sample_annotation

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="INV-002",
        )

        assert result is True
        assert sample_annotation.text_value == "INV-002"

    def test_update_returns_false_when_not_found(self, repo, mock_session):
        """Test update returns False when annotation not found."""
        mock_session.get.return_value = None

        result = repo.update(str(uuid4()), text_value="INV-002")

        assert result is False

    def test_update_all_fields(self, repo, sample_annotation, mock_session):
        """Test update can update all fields."""
        mock_session.get.return_value = sample_annotation

        result = repo.update(
            str(sample_annotation.annotation_id),
            x_center=0.6,
            y_center=0.4,
            width=0.25,
            height=0.06,
            bbox_x=150,
            bbox_y=250,
            bbox_width=175,
            bbox_height=35,
            text_value="NEW-VALUE",
            class_id=5,
            class_name="new_class",
        )

        assert result is True
        assert sample_annotation.x_center == 0.6
        assert sample_annotation.y_center == 0.4
        assert sample_annotation.width == 0.25
        assert sample_annotation.height == 0.06
        assert sample_annotation.bbox_x == 150
        assert sample_annotation.bbox_y == 250
        assert sample_annotation.bbox_width == 175
        assert sample_annotation.bbox_height == 35
        assert sample_annotation.text_value == "NEW-VALUE"
        assert sample_annotation.class_id == 5
        assert sample_annotation.class_name == "new_class"

    def test_update_partial_fields(self, repo, sample_annotation, mock_session):
        """Test update only updates provided fields."""
        original_x = sample_annotation.x_center
        mock_session.get.return_value = sample_annotation

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="UPDATED",
        )

        assert result is True
        assert sample_annotation.text_value == "UPDATED"
        assert sample_annotation.x_center == original_x  # unchanged

    # =========================================================================
    # delete() tests
    # =========================================================================
    def test_delete_returns_true(self, repo, sample_annotation, mock_session):
        """Test delete returns True when annotation exists."""
        mock_session.get.return_value = sample_annotation

        result = repo.delete(str(sample_annotation.annotation_id))

        assert result is True
        mock_session.delete.assert_called_once()

    def test_delete_returns_false_when_not_found(self, repo, mock_session):
        """Test delete returns False when annotation not found."""
        mock_session.get.return_value = None

        result = repo.delete(str(uuid4()))

        assert result is False
        mock_session.delete.assert_not_called()

    # =========================================================================
    # delete_for_document() tests
    # =========================================================================
    def test_delete_for_document_returns_count(
        self, repo, sample_annotation, mock_session
    ):
        """Test delete_for_document returns count of deleted annotations."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.delete_for_document(str(sample_annotation.document_id))

        assert result == 1
        mock_session.delete.assert_called_once()

    def test_delete_for_document_with_source_filter(
        self, repo, sample_annotation, mock_session
    ):
        """Test delete_for_document accepts a source filter."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.delete_for_document(
            str(sample_annotation.document_id), source="auto"
        )

        assert result == 1

    def test_delete_for_document_returns_zero(self, repo, mock_session):
        """Test delete_for_document returns 0 when no annotations."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.delete_for_document(str(uuid4()))

        assert result == 0
        mock_session.delete.assert_not_called()

    # =========================================================================
    # verify() tests
    # =========================================================================
    def test_verify_marks_annotation_verified(
        self, repo, sample_annotation, mock_session
    ):
        """Test verify marks annotation as verified."""
        sample_annotation.is_verified = False
        mock_session.get.return_value = sample_annotation

        result = repo.verify(str(sample_annotation.annotation_id), "admin-token")

        assert result is not None
        assert sample_annotation.is_verified is True
        assert sample_annotation.verified_by == "admin-token"
        mock_session.commit.assert_called_once()

    def test_verify_returns_none_when_not_found(self, repo, mock_session):
        """Test verify returns None when annotation not found."""
        mock_session.get.return_value = None

        result = repo.verify(str(uuid4()), "admin-token")

        assert result is None

    # =========================================================================
    # override() tests
    # =========================================================================
    def test_override_updates_annotation(
        self, repo, sample_annotation, mock_session
    ):
        """Test override updates annotation and creates history."""
        sample_annotation.source = "auto"
        mock_session.get.return_value = sample_annotation

        result = repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            change_reason="Correction",
            text_value="NEW-VALUE",
        )

        assert result is not None
        assert sample_annotation.text_value == "NEW-VALUE"
        assert sample_annotation.source == "manual"
        assert sample_annotation.override_source == "auto"
        assert mock_session.add.call_count >= 2  # annotation + history

    def test_override_returns_none_when_not_found(self, repo, mock_session):
        """Test override returns None when annotation not found."""
        mock_session.get.return_value = None

        result = repo.override(str(uuid4()), "admin-token", text_value="NEW")

        assert result is None

    def test_override_does_not_change_source_if_already_manual(
        self, repo, sample_annotation, mock_session
    ):
        """Test override does not change override_source if already manual."""
        sample_annotation.source = "manual"
        sample_annotation.override_source = None
        mock_session.get.return_value = sample_annotation

        repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            text_value="NEW-VALUE",
        )

        assert sample_annotation.source == "manual"
        assert sample_annotation.override_source is None

    def test_override_skips_unknown_attributes(
        self, repo, sample_annotation, mock_session
    ):
        """Test override ignores unknown attributes."""
        sample_annotation.source = "auto"
        mock_session.get.return_value = sample_annotation

        result = repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            unknown_field="should_be_ignored",
            text_value="VALID",
        )

        assert result is not None
        assert sample_annotation.text_value == "VALID"
        # getattr's default covers the "attribute absent" case, so a single
        # comparison is enough to prove the unknown value was not stored.
        assert (
            getattr(sample_annotation, "unknown_field", None)
            != "should_be_ignored"
        )

    # =========================================================================
    # create_history() tests
    # =========================================================================
    def test_create_history_returns_history(self, repo, mock_session):
        """Test create_history returns created history record."""
        annotation_id = uuid4()
        document_id = uuid4()

        result = repo.create_history(
            annotation_id=annotation_id,
            document_id=document_id,
            action="create",
            previous_value=None,
            new_value={"class_name": "invoice_number"},
            changed_by="admin-token",
            change_reason="Initial creation",
        )

        # The test name promises a returned record, so check the return value
        # instead of leaving it unused.
        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_create_history_with_minimal_params(self, repo, mock_session):
        """Test create_history with minimal parameters."""
        repo.create_history(
            annotation_id=uuid4(),
            document_id=uuid4(),
            action="delete",
        )

        mock_session.add.assert_called_once()
        added_history = mock_session.add.call_args[0][0]
        assert added_history.action == "delete"
        assert added_history.previous_value is None
        assert added_history.new_value is None

    # =========================================================================
    # get_history() tests
    # =========================================================================
    def test_get_history_returns_list(self, repo, sample_history, mock_session):
        """Test get_history returns list of history records."""
        mock_session.exec.return_value.all.return_value = [sample_history]

        result = repo.get_history(sample_history.annotation_id)

        assert len(result) == 1
        assert result[0].action == "override"

    def test_get_history_returns_empty_list(self, repo, mock_session):
        """Test get_history returns empty list when no history."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_history(uuid4())

        assert result == []

    # =========================================================================
    # get_document_history() tests
    # =========================================================================
    def test_get_document_history_returns_list(
        self, repo, sample_history, mock_session
    ):
        """Test get_document_history returns list of history records."""
        mock_session.exec.return_value.all.return_value = [sample_history]

        result = repo.get_document_history(sample_history.document_id)

        assert len(result) == 1

    def test_get_document_history_returns_empty_list(self, repo, mock_session):
        """Test get_document_history returns empty list when no history."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_document_history(uuid4())

        assert result == []

View File

@@ -0,0 +1,142 @@
"""
Tests for BaseRepository
100% coverage tests for base repository utilities.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.repositories.base import BaseRepository
class ConcreteRepository(BaseRepository[MagicMock]):
    """Bare BaseRepository subclass that the tests instantiate.

    It adds no behaviour of its own; every helper exercised through it is
    inherited from BaseRepository.
    """
class TestBaseRepository:
    """Exercise the helper methods inherited from BaseRepository."""

    @pytest.fixture
    def repo(self) -> ConcreteRepository:
        """Provide a fresh ConcreteRepository for each test."""
        return ConcreteRepository()

    # =========================================================================
    # _session() tests
    # =========================================================================
    def test_session_yields_session(self, repo):
        """_session() yields the object produced by get_session_context."""
        fake_session = MagicMock()
        with patch("inference.data.repositories.base.get_session_context") as ctx_factory:
            # Configure the context-manager protocol on the patched factory.
            ctx_factory.return_value.__enter__.return_value = fake_session
            ctx_factory.return_value.__exit__.return_value = False
            with repo._session() as session:
                assert session is fake_session

    # =========================================================================
    # _expunge() tests
    # =========================================================================
    def test_expunge_detaches_entity(self, repo):
        """_expunge() expunges the entity from the session and returns it."""
        session, entity = MagicMock(), MagicMock()

        returned = repo._expunge(session, entity)

        assert returned is entity
        session.expunge.assert_called_once_with(entity)

    # =========================================================================
    # _expunge_all() tests
    # =========================================================================
    def test_expunge_all_detaches_all_entities(self, repo):
        """_expunge_all() expunges each entity and returns the same list."""
        session = MagicMock()
        first, second = MagicMock(), MagicMock()
        entities = [first, second]

        returned = repo._expunge_all(session, entities)

        assert returned is entities
        assert session.expunge.call_count == 2
        for entity in (first, second):
            session.expunge.assert_any_call(entity)

    def test_expunge_all_empty_list(self, repo):
        """_expunge_all() with an empty list makes no expunge calls."""
        session = MagicMock()
        entities = []

        returned = repo._expunge_all(session, entities)

        assert returned == []
        session.expunge.assert_not_called()

    # =========================================================================
    # _now() tests
    # =========================================================================
    def test_now_returns_utc_datetime(self, repo):
        """_now() returns a datetime whose tzinfo is UTC."""
        stamp = repo._now()

        assert isinstance(stamp, datetime)
        assert stamp.tzinfo == timezone.utc

    def test_now_is_recent(self, repo):
        """_now() falls between two wall-clock readings taken around it."""
        lower = datetime.now(timezone.utc)
        stamp = repo._now()
        upper = datetime.now(timezone.utc)

        assert lower <= stamp <= upper

    # =========================================================================
    # _validate_uuid() tests
    # =========================================================================
    def test_validate_uuid_with_valid_string(self, repo):
        """_validate_uuid() converts a valid UUID string into a UUID."""
        text = str(uuid4())

        parsed = repo._validate_uuid(text)

        assert isinstance(parsed, UUID)
        assert str(parsed) == text

    def test_validate_uuid_with_invalid_string(self, repo):
        """_validate_uuid() raises ValueError for a malformed UUID string."""
        with pytest.raises(ValueError, match="Invalid id"):
            repo._validate_uuid("not-a-valid-uuid")

    def test_validate_uuid_with_custom_field_name(self, repo):
        """_validate_uuid() names the custom field in its error message."""
        with pytest.raises(ValueError, match="Invalid document_id"):
            repo._validate_uuid("invalid", field_name="document_id")

    def test_validate_uuid_with_none(self, repo):
        """_validate_uuid() raises ValueError for None."""
        with pytest.raises(ValueError, match="Invalid id"):
            repo._validate_uuid(None)

    def test_validate_uuid_with_empty_string(self, repo):
        """_validate_uuid() raises ValueError for an empty string."""
        with pytest.raises(ValueError, match="Invalid id"):
            repo._validate_uuid("")

View File

@@ -0,0 +1,386 @@
"""
Tests for BatchUploadRepository
100% coverage tests for batch upload management.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import BatchUpload, BatchUploadFile
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
class TestBatchUploadRepository:
"""Tests for BatchUploadRepository."""
@pytest.fixture
def sample_batch(self) -> BatchUpload:
    """Build a pending BatchUpload for ``invoices.zip`` with ten files."""
    fields = {
        "batch_id": uuid4(),
        "admin_token": "admin-token",
        "filename": "invoices.zip",
        "file_size": 1024000,
        "upload_source": "ui",
        "status": "pending",
        "total_files": 10,
        "processed_files": 0,
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
    }
    return BatchUpload(**fields)
@pytest.fixture
def sample_file(self) -> BatchUploadFile:
    """Build a pending BatchUploadFile for ``invoice_001.pdf``."""
    fields = {
        "file_id": uuid4(),
        "batch_id": uuid4(),
        "filename": "invoice_001.pdf",
        "status": "pending",
        "created_at": datetime.now(timezone.utc),
    }
    return BatchUploadFile(**fields)
@pytest.fixture
def repo(self) -> BatchUploadRepository:
    """Provide a fresh BatchUploadRepository for each test."""
    repository = BatchUploadRepository()
    return repository
# =========================================================================
# create() tests
# =========================================================================
def test_create_returns_batch(self, repo):
    """Test create returns the created batch upload and persists it once."""
    with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
        mock_session = MagicMock()
        mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
        result = repo.create(
            admin_token="admin-token",
            filename="test.zip",
            file_size=1024,
        )
        # The test name promises a returned batch, so check the return value
        # instead of leaving it unused.
        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
def test_create_with_upload_source(self, repo):
    """create() passes an explicit upload_source to the object it adds."""
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    session = MagicMock()
    with patch(target) as ctx_factory:
        ctx_factory.return_value.__enter__.return_value = session
        ctx_factory.return_value.__exit__.return_value = False
        create_kwargs = {
            "admin_token": "admin-token",
            "filename": "test.zip",
            "file_size": 1024,
            "upload_source": "api",
        }
        repo.create(**create_kwargs)
        # First positional argument of the most recent session.add() call.
        created = session.add.call_args[0][0]
        assert created.upload_source == "api"
def test_create_default_upload_source(self, repo):
    """Check that omitting ``upload_source`` yields ``"ui"`` on the batch given to the session."""
    session = MagicMock()
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        repo.create(
            admin_token="admin-token",
            filename="test.zip",
            file_size=1024,
        )
    persisted = session.add.call_args[0][0]
    assert persisted.upload_source == "ui"
# =========================================================================
# get() tests
# =========================================================================
def test_get_returns_batch(self, repo, sample_batch):
    """Check that get hands back the batch the session finds and expunges once."""
    session = MagicMock()
    session.get.return_value = sample_batch
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        found = repo.get(sample_batch.batch_id)
    assert found is not None
    assert found.filename == "invoices.zip"
    session.expunge.assert_called_once()
def test_get_returns_none_when_not_found(self, repo):
    """Check that get yields ``None`` and never expunges when the session finds nothing."""
    session = MagicMock()
    session.get.return_value = None
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        found = repo.get(uuid4())
    assert found is None
    session.expunge.assert_not_called()
# =========================================================================
# update() tests
# =========================================================================
def test_update_updates_batch(self, repo, sample_batch):
    """Check that update writes the given field values onto the fetched batch and re-adds it."""
    session = MagicMock()
    session.get.return_value = sample_batch
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    changes = {"status": "processing", "processed_files": 5}
    with patch(target, return_value=ctx):
        repo.update(sample_batch.batch_id, **changes)
    assert sample_batch.status == "processing"
    assert sample_batch.processed_files == 5
    session.add.assert_called_once()
def test_update_ignores_unknown_fields(self, repo, sample_batch):
    """Test update ignores unknown fields.

    The original test only checked that ``session.add`` ran. It now also
    confirms that the unknown keyword did not leak onto the batch and that
    the fixture's existing values were left alone.
    """
    session = MagicMock()
    session.get.return_value = sample_batch
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        repo.update(
            sample_batch.batch_id,
            unknown_field="should_be_ignored",
        )
    session.add.assert_called_once()
    # The unknown keyword must not become an attribute of the batch.
    assert not hasattr(sample_batch, "unknown_field")
    # Values set by the fixture must be untouched.
    assert sample_batch.status == "pending"
    assert sample_batch.processed_files == 0
def test_update_not_found(self, repo):
    """Check that update adds nothing to the session when the batch lookup returns ``None``."""
    session = MagicMock()
    session.get.return_value = None
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        repo.update(uuid4(), status="processing")
    session.add.assert_not_called()
def test_update_multiple_fields(self, repo, sample_batch):
    """Check that a single update call can change status and both file counters together."""
    session = MagicMock()
    session.get.return_value = sample_batch
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    changes = {"status": "completed", "processed_files": 10, "total_files": 10}
    with patch(target, return_value=ctx):
        repo.update(sample_batch.batch_id, **changes)
    assert sample_batch.status == "completed"
    assert sample_batch.processed_files == 10
    assert sample_batch.total_files == 10
# =========================================================================
# create_file() tests
# =========================================================================
def test_create_file_returns_file(self, repo):
    """Test create_file returns a record and persists it through the session.

    The return value was previously bound but never checked. It is now
    asserted, and the object passed to ``session.add`` is inspected to
    confirm it carries the batch ID and filename supplied by the caller.
    """
    session = MagicMock()
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    batch_id = uuid4()
    with patch(target, return_value=ctx):
        result = repo.create_file(
            batch_id=batch_id,
            filename="invoice_001.pdf",
        )
    assert result is not None
    session.add.assert_called_once()
    session.commit.assert_called_once()
    added_file = session.add.call_args[0][0]
    assert added_file.batch_id == batch_id
    assert added_file.filename == "invoice_001.pdf"
def test_create_file_with_kwargs(self, repo):
    """Test create_file forwards additional kwargs onto the new record.

    The original only asserted ``filename``, which is not one of the extra
    kwargs. ``status`` is now checked as well, so the test covers the
    kwargs path it is named after. The unused ``result`` binding is gone.
    """
    session = MagicMock()
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        repo.create_file(
            batch_id=uuid4(),
            filename="invoice_001.pdf",
            status="processing",
            file_size=1024,
        )
    added_file = session.add.call_args[0][0]
    assert added_file.filename == "invoice_001.pdf"
    # ``status`` arrives via **kwargs and must reach the persisted record.
    assert added_file.status == "processing"
# =========================================================================
# update_file() tests
# =========================================================================
def test_update_file_updates_file(self, repo, sample_file):
    """Check that update_file writes the new status onto the fetched file and re-adds it."""
    session = MagicMock()
    session.get.return_value = sample_file
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        repo.update_file(sample_file.file_id, status="completed")
    assert sample_file.status == "completed"
    session.add.assert_called_once()
def test_update_file_ignores_unknown_fields(self, repo, sample_file):
    """Test update_file ignores unknown fields.

    In addition to the ``session.add`` check, the test now confirms that
    the unknown keyword was not attached to the file record and that the
    fixture's status is unchanged.
    """
    session = MagicMock()
    session.get.return_value = sample_file
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        repo.update_file(
            sample_file.file_id,
            unknown_field="should_be_ignored",
        )
    session.add.assert_called_once()
    # The unknown keyword must not become an attribute of the record.
    assert not hasattr(sample_file, "unknown_field")
    assert sample_file.status == "pending"
def test_update_file_not_found(self, repo):
    """Check that update_file adds nothing to the session when the file lookup returns ``None``."""
    session = MagicMock()
    session.get.return_value = None
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        repo.update_file(uuid4(), status="completed")
    session.add.assert_not_called()
def test_update_file_multiple_fields(self, repo, sample_file):
    """Test update_file can update multiple fields in one call.

    The original passed only ``status`` despite its name. ``filename``
    (the other mutable field the ``sample_file`` fixture populates) is now
    updated alongside it so more than one field is really exercised.
    """
    session = MagicMock()
    session.get.return_value = sample_file
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        repo.update_file(
            sample_file.file_id,
            status="failed",
            filename="invoice_001_retry.pdf",
        )
    assert sample_file.status == "failed"
    assert sample_file.filename == "invoice_001_retry.pdf"
# =========================================================================
# get_files() tests
# =========================================================================
def test_get_files_returns_list(self, repo, sample_file):
    """Check that get_files returns the rows produced by the session query."""
    session = MagicMock()
    session.exec.return_value.all.return_value = [sample_file]
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        files = repo.get_files(sample_file.batch_id)
    assert len(files) == 1
    assert files[0].filename == "invoice_001.pdf"
def test_get_files_returns_empty_list(self, repo):
    """Check that get_files returns ``[]`` when the session query yields no rows."""
    session = MagicMock()
    session.exec.return_value.all.return_value = []
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        files = repo.get_files(uuid4())
    assert files == []
# =========================================================================
# get_paginated() tests
# =========================================================================
def test_get_paginated_returns_batches_and_total(self, repo, sample_batch):
    """Check that get_paginated returns a (batches, total) pair from the mocked queries."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.one.return_value = 1
    query_result.all.return_value = [sample_batch]
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        page, count = repo.get_paginated()
    assert len(page) == 1
    assert count == 1
def test_get_paginated_with_pagination(self, repo, sample_batch):
    """Test get_paginated with explicit limit and offset.

    The mocked session reports a total of 100 and a single-row page. Both
    halves of the returned tuple are now asserted; previously the page was
    bound to a local and never checked.
    """
    session = MagicMock()
    session.exec.return_value.one.return_value = 100
    session.exec.return_value.all.return_value = [sample_batch]
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        batches, total = repo.get_paginated(limit=25, offset=50)
    assert len(batches) == 1
    assert total == 100
def test_get_paginated_empty_results(self, repo):
    """Check that get_paginated returns ``([], 0)`` when the mocked queries are empty."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.one.return_value = 0
    query_result.all.return_value = []
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        page, count = repo.get_paginated()
    assert page == []
    assert count == 0
def test_get_paginated_with_admin_token(self, repo, sample_batch):
    """Test get_paginated with admin_token parameter (deprecated, ignored).

    Passing ``admin_token`` must still give the same page and total the
    mocked session provides. ``total`` is now asserted too; previously it
    was unpacked and left unchecked.
    """
    session = MagicMock()
    session.exec.return_value.one.return_value = 1
    session.exec.return_value.all.return_value = [sample_batch]
    ctx = MagicMock()
    ctx.__enter__.return_value = session
    ctx.__exit__.return_value = False
    target = "inference.data.repositories.batch_upload_repository.get_session_context"
    with patch(target, return_value=ctx):
        batches, total = repo.get_paginated(admin_token="admin-token")
    assert len(batches) == 1
    assert total == 1

View File

@@ -0,0 +1,597 @@
"""
Tests for DatasetRepository
100% coverage tests for dataset management.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
from inference.data.repositories.dataset_repository import DatasetRepository
class TestDatasetRepository:
"""Tests for DatasetRepository."""
@pytest.fixture
def sample_dataset(self) -> TrainingDataset:
"""Create a sample dataset for testing."""
return TrainingDataset(
dataset_id=uuid4(),
name="Test Dataset",
description="A test dataset",
status="ready",
train_ratio=0.8,
val_ratio=0.1,
seed=42,
total_documents=100,
total_images=100,
total_annotations=500,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)
@pytest.fixture
def sample_dataset_document(self) -> DatasetDocument:
"""Create a sample dataset document for testing."""
return DatasetDocument(
id=uuid4(),
dataset_id=uuid4(),
document_id=uuid4(),
split="train",
page_count=2,
annotation_count=10,
created_at=datetime.now(timezone.utc),
)
@pytest.fixture
def sample_training_task(self) -> TrainingTask:
"""Create a sample training task for testing."""
return TrainingTask(
task_id=uuid4(),
admin_token="admin-token",
name="Test Task",
status="running",
dataset_id=uuid4(),
)
@pytest.fixture
def repo(self) -> DatasetRepository:
"""Create a DatasetRepository instance."""
return DatasetRepository()
# =========================================================================
# create() tests
# =========================================================================
def test_create_returns_dataset(self, repo):
"""Test create returns created dataset."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.create(name="Test Dataset")
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
def test_create_with_all_params(self, repo):
"""Test create with all parameters."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.create(
name="Full Dataset",
description="A complete dataset",
train_ratio=0.7,
val_ratio=0.15,
seed=123,
)
added_dataset = mock_session.add.call_args[0][0]
assert added_dataset.name == "Full Dataset"
assert added_dataset.description == "A complete dataset"
assert added_dataset.train_ratio == 0.7
assert added_dataset.val_ratio == 0.15
assert added_dataset.seed == 123
def test_create_default_values(self, repo):
"""Test create uses default values."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.create(name="Minimal Dataset")
added_dataset = mock_session.add.call_args[0][0]
assert added_dataset.train_ratio == 0.8
assert added_dataset.val_ratio == 0.1
assert added_dataset.seed == 42
# =========================================================================
# get() tests
# =========================================================================
def test_get_returns_dataset(self, repo, sample_dataset):
"""Test get returns dataset when exists."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get(str(sample_dataset.dataset_id))
assert result is not None
assert result.name == "Test Dataset"
mock_session.expunge.assert_called_once()
def test_get_with_uuid(self, repo, sample_dataset):
"""Test get works with UUID object."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get(sample_dataset.dataset_id)
assert result is not None
def test_get_returns_none_when_not_found(self, repo):
"""Test get returns None when dataset not found."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get(str(uuid4()))
assert result is None
mock_session.expunge.assert_not_called()
# =========================================================================
# get_paginated() tests
# =========================================================================
def test_get_paginated_returns_datasets_and_total(self, repo, sample_dataset):
"""Test get_paginated returns list of datasets and total count."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_dataset]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
datasets, total = repo.get_paginated()
assert len(datasets) == 1
assert total == 1
def test_get_paginated_with_status_filter(self, repo, sample_dataset):
"""Test get_paginated filters by status."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_dataset]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
datasets, total = repo.get_paginated(status="ready")
assert len(datasets) == 1
def test_get_paginated_with_pagination(self, repo, sample_dataset):
"""Test get_paginated with limit and offset."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 50
mock_session.exec.return_value.all.return_value = [sample_dataset]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
datasets, total = repo.get_paginated(limit=10, offset=20)
assert total == 50
def test_get_paginated_empty_results(self, repo):
"""Test get_paginated with no results."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 0
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
datasets, total = repo.get_paginated()
assert datasets == []
assert total == 0
# =========================================================================
# get_active_training_tasks() tests
# =========================================================================
def test_get_active_training_tasks_returns_dict(self, repo, sample_training_task):
"""Test get_active_training_tasks returns dict of active tasks."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_training_task]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get_active_training_tasks([str(sample_training_task.dataset_id)])
assert str(sample_training_task.dataset_id) in result
def test_get_active_training_tasks_empty_input(self, repo):
"""Test get_active_training_tasks with empty input."""
result = repo.get_active_training_tasks([])
assert result == {}
def test_get_active_training_tasks_invalid_uuid(self, repo):
"""Test get_active_training_tasks filters invalid UUIDs."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get_active_training_tasks(["invalid-uuid", str(uuid4())])
# Should still query with valid UUID
assert result == {}
def test_get_active_training_tasks_all_invalid_uuids(self, repo):
"""Test get_active_training_tasks with all invalid UUIDs."""
result = repo.get_active_training_tasks(["invalid-uuid-1", "invalid-uuid-2"])
assert result == {}
# =========================================================================
# update_status() tests
# =========================================================================
def test_update_status_updates_dataset(self, repo, sample_dataset):
"""Test update_status updates dataset status."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_status(str(sample_dataset.dataset_id), "training")
assert sample_dataset.status == "training"
mock_session.commit.assert_called_once()
def test_update_status_with_error_message(self, repo, sample_dataset):
"""Test update_status with error message."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_status(
str(sample_dataset.dataset_id),
"failed",
error_message="Training failed",
)
assert sample_dataset.error_message == "Training failed"
def test_update_status_with_totals(self, repo, sample_dataset):
"""Test update_status with total counts."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_status(
str(sample_dataset.dataset_id),
"ready",
total_documents=200,
total_images=200,
total_annotations=1000,
)
assert sample_dataset.total_documents == 200
assert sample_dataset.total_images == 200
assert sample_dataset.total_annotations == 1000
def test_update_status_with_dataset_path(self, repo, sample_dataset):
"""Test update_status with dataset path."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_status(
str(sample_dataset.dataset_id),
"ready",
dataset_path="/path/to/dataset",
)
assert sample_dataset.dataset_path == "/path/to/dataset"
def test_update_status_with_uuid(self, repo, sample_dataset):
"""Test update_status works with UUID object."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_status(sample_dataset.dataset_id, "ready")
assert sample_dataset.status == "ready"
def test_update_status_not_found(self, repo):
"""Test update_status does nothing when dataset not found."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_status(str(uuid4()), "ready")
mock_session.add.assert_not_called()
# =========================================================================
# update_training_status() tests
# =========================================================================
def test_update_training_status_updates_dataset(self, repo, sample_dataset):
"""Test update_training_status updates training status."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_training_status(str(sample_dataset.dataset_id), "running")
assert sample_dataset.training_status == "running"
mock_session.commit.assert_called_once()
def test_update_training_status_with_task_id(self, repo, sample_dataset):
"""Test update_training_status with active task ID."""
task_id = uuid4()
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_training_status(
str(sample_dataset.dataset_id),
"running",
active_training_task_id=str(task_id),
)
assert sample_dataset.active_training_task_id == task_id
def test_update_training_status_updates_main_status(self, repo, sample_dataset):
"""Test update_training_status updates main status when completed."""
sample_dataset.status = "ready"
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_training_status(
str(sample_dataset.dataset_id),
"completed",
update_main_status=True,
)
assert sample_dataset.training_status == "completed"
assert sample_dataset.status == "trained"
def test_update_training_status_clears_task_id(self, repo, sample_dataset):
"""Test update_training_status clears task ID when None."""
sample_dataset.active_training_task_id = uuid4()
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_training_status(
str(sample_dataset.dataset_id),
None,
active_training_task_id=None,
)
assert sample_dataset.active_training_task_id is None
def test_update_training_status_not_found(self, repo):
"""Test update_training_status does nothing when dataset not found."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.update_training_status(str(uuid4()), "running")
mock_session.add.assert_not_called()
# =========================================================================
# add_documents() tests
# =========================================================================
def test_add_documents_creates_links(self, repo):
"""Test add_documents creates dataset document links."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
documents = [
{
"document_id": str(uuid4()),
"split": "train",
"page_count": 2,
"annotation_count": 10,
},
{
"document_id": str(uuid4()),
"split": "val",
"page_count": 1,
"annotation_count": 5,
},
]
repo.add_documents(str(uuid4()), documents)
assert mock_session.add.call_count == 2
mock_session.commit.assert_called_once()
def test_add_documents_default_counts(self, repo):
"""Test add_documents uses default counts."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
documents = [
{
"document_id": str(uuid4()),
"split": "train",
},
]
repo.add_documents(str(uuid4()), documents)
added_doc = mock_session.add.call_args[0][0]
assert added_doc.page_count == 0
assert added_doc.annotation_count == 0
def test_add_documents_with_uuid(self, repo):
"""Test add_documents works with UUID object."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
documents = [
{
"document_id": uuid4(),
"split": "train",
},
]
repo.add_documents(uuid4(), documents)
mock_session.add.assert_called_once()
def test_add_documents_empty_list(self, repo):
"""Test add_documents with empty list."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.add_documents(str(uuid4()), [])
mock_session.add.assert_not_called()
mock_session.commit.assert_called_once()
# =========================================================================
# get_documents() tests
# =========================================================================
def test_get_documents_returns_list(self, repo, sample_dataset_document):
"""Test get_documents returns list of dataset documents."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get_documents(str(sample_dataset_document.dataset_id))
assert len(result) == 1
assert result[0].split == "train"
def test_get_documents_with_uuid(self, repo, sample_dataset_document):
"""Test get_documents works with UUID object."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get_documents(sample_dataset_document.dataset_id)
assert len(result) == 1
def test_get_documents_returns_empty_list(self, repo):
"""Test get_documents returns empty list when no documents."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get_documents(str(uuid4()))
assert result == []
# =========================================================================
# delete() tests
# =========================================================================
def test_delete_returns_true(self, repo, sample_dataset):
"""Test delete returns True when dataset exists."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.delete(str(sample_dataset.dataset_id))
assert result is True
mock_session.delete.assert_called_once()
mock_session.commit.assert_called_once()
def test_delete_with_uuid(self, repo, sample_dataset):
    """delete works when the dataset ID is passed as-is instead of as a string."""
    session = MagicMock()
    session.get.return_value = sample_dataset

    with patch("inference.data.repositories.dataset_repository.get_session_context") as session_ctx:
        # Make the patched context manager yield our prepared session.
        session_ctx.return_value.__enter__.return_value = session
        session_ctx.return_value.__exit__.return_value = False

        deleted = repo.delete(sample_dataset.dataset_id)

    assert deleted is True
def test_delete_returns_false_when_not_found(self, repo):
    """delete returns False and never calls session.delete when the lookup yields None."""
    session = MagicMock()
    session.get.return_value = None

    with patch("inference.data.repositories.dataset_repository.get_session_context") as session_ctx:
        # Make the patched context manager yield our prepared session.
        session_ctx.return_value.__enter__.return_value = session
        session_ctx.return_value.__exit__.return_value = False

        deleted = repo.delete(str(uuid4()))

    assert deleted is False
    session.delete.assert_not_called()

View File

@@ -0,0 +1,748 @@
"""
Tests for DocumentRepository
Comprehensive TDD tests for document management - targeting 100% coverage.
"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4
from inference.data.admin_models import AdminDocument, AdminAnnotation
from inference.data.repositories.document_repository import DocumentRepository
class TestDocumentRepository:
    """Tests for DocumentRepository.

    Session-backed tests patch ``get_session_context`` on the repository
    module so the repository code runs against a ``MagicMock`` session.
    The wiring for that lives in ``_bind_session`` instead of being repeated
    in every test.
    """

    # Dotted path handed to ``patch`` by every session-backed test.
    _SESSION_CTX = "inference.data.repositories.document_repository.get_session_context"

    # ==========================================================================
    # shared helpers
    # ==========================================================================

    @staticmethod
    def _bind_session(mock_ctx: MagicMock) -> MagicMock:
        """Make ``mock_ctx()`` a context manager that yields a fresh mock session.

        Args:
            mock_ctx: The mock produced by patching ``get_session_context``.

        Returns:
            The ``MagicMock`` session handed out by ``with mock_ctx() as s``.
        """
        session = MagicMock()
        mock_ctx.return_value.__enter__ = MagicMock(return_value=session)
        # A falsy __exit__ result lets exceptions raised in the block propagate.
        mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
        return session

    def _call_paged(self, query, document: AdminDocument, **filters):
        """Run a paged repository query against a session holding one document.

        The mock session answers ``exec(...).one()`` with ``1`` and
        ``exec(...).all()`` with ``[document]``.

        Args:
            query: Bound repository method to call, e.g. ``repo.get_paginated``.
            document: The single document the mock session returns.
            **filters: Keyword arguments forwarded to ``query``.

        Returns:
            Whatever ``query`` returns (the tests unpack it as ``results, total``).
        """
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.exec.return_value.one.return_value = 1
            session.exec.return_value.all.return_value = [document]
            return query(**filters)

    @staticmethod
    def _make_document(
        filename: str,
        *,
        file_size: int = 1024,
        page_count: int = 1,
        status: str = "pending",
    ) -> AdminDocument:
        """Build an "invoice" PDF AdminDocument stored under ``/tmp/<filename>``."""
        return AdminDocument(
            document_id=uuid4(),
            filename=filename,
            file_size=file_size,
            content_type="application/pdf",
            file_path=f"/tmp/{filename}",
            page_count=page_count,
            status=status,
            category="invoice",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    # ==========================================================================
    # fixtures
    # ==========================================================================

    @pytest.fixture
    def sample_document(self) -> AdminDocument:
        """Create a sample pending document for testing."""
        return self._make_document("test.pdf")

    @pytest.fixture
    def labeled_document(self) -> AdminDocument:
        """Create a labeled two-page document for testing."""
        return self._make_document(
            "labeled.pdf", file_size=2048, page_count=2, status="labeled"
        )

    @pytest.fixture
    def locked_document(self) -> AdminDocument:
        """Create a document whose annotation lock expires five minutes from now."""
        doc = self._make_document("locked.pdf")
        doc.annotation_lock_until = datetime.now(timezone.utc) + timedelta(minutes=5)
        return doc

    @pytest.fixture
    def expired_lock_document(self) -> AdminDocument:
        """Create a document whose annotation lock expired five minutes ago."""
        doc = self._make_document("expired_lock.pdf")
        doc.annotation_lock_until = datetime.now(timezone.utc) - timedelta(minutes=5)
        return doc

    @pytest.fixture
    def repo(self) -> DocumentRepository:
        """Create a DocumentRepository instance."""
        return DocumentRepository()

    # ==========================================================================
    # create() tests
    # ==========================================================================

    def test_create_returns_document_id(self, repo):
        """Test create returns document ID."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            result = repo.create(
                filename="test.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path="/tmp/test.pdf",
            )

        assert result is not None
        session.add.assert_called_once()
        session.flush.assert_called_once()

    def test_create_with_all_parameters(self, repo):
        """Test create with all optional parameters."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            result = repo.create(
                filename="test.pdf",
                file_size=1024,
                content_type="application/pdf",
                file_path="/tmp/test.pdf",
                page_count=5,
                upload_source="api",
                csv_field_values={"InvoiceNumber": "INV-001"},
                group_key="batch-001",
                category="receipt",
                admin_token="token-123",
            )

        assert result is not None
        # Inspect the object that was handed to session.add().
        added_doc = session.add.call_args[0][0]
        assert added_doc.page_count == 5
        assert added_doc.upload_source == "api"
        assert added_doc.csv_field_values == {"InvoiceNumber": "INV-001"}
        assert added_doc.group_key == "batch-001"
        assert added_doc.category == "receipt"

    # ==========================================================================
    # get() tests
    # ==========================================================================

    def test_get_returns_document(self, repo, sample_document):
        """Test get returns document when exists."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            result = repo.get(str(sample_document.document_id))

        assert result is not None
        assert result.filename == "test.pdf"
        session.expunge.assert_called_once()

    def test_get_returns_none_when_not_found(self, repo):
        """Test get returns None when document not found."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = None
            result = repo.get(str(uuid4()))

        assert result is None

    # ==========================================================================
    # get_by_token() tests
    # ==========================================================================

    def test_get_by_token_delegates_to_get(self, repo, sample_document):
        """Test get_by_token delegates to get method."""
        with patch.object(repo, "get", return_value=sample_document) as mock_get:
            result = repo.get_by_token(str(sample_document.document_id), "token-123")

        assert result == sample_document
        mock_get.assert_called_once_with(str(sample_document.document_id))

    # ==========================================================================
    # get_paginated() tests
    # ==========================================================================

    def test_get_paginated_no_filters(self, repo, sample_document):
        """Test get_paginated with no filters."""
        results, total = self._call_paged(repo.get_paginated, sample_document)
        assert total == 1
        assert len(results) == 1

    def test_get_paginated_with_status_filter(self, repo, sample_document):
        """Test get_paginated with status filter."""
        _, total = self._call_paged(
            repo.get_paginated, sample_document, status="pending"
        )
        assert total == 1

    def test_get_paginated_with_upload_source_filter(self, repo, sample_document):
        """Test get_paginated with upload_source filter."""
        _, total = self._call_paged(
            repo.get_paginated, sample_document, upload_source="ui"
        )
        assert total == 1

    def test_get_paginated_with_auto_label_status_filter(self, repo, sample_document):
        """Test get_paginated with auto_label_status filter."""
        _, total = self._call_paged(
            repo.get_paginated, sample_document, auto_label_status="completed"
        )
        assert total == 1

    def test_get_paginated_with_batch_id_filter(self, repo, sample_document):
        """Test get_paginated with batch_id filter."""
        batch_id = str(uuid4())
        _, total = self._call_paged(
            repo.get_paginated, sample_document, batch_id=batch_id
        )
        assert total == 1

    def test_get_paginated_with_category_filter(self, repo, sample_document):
        """Test get_paginated with category filter."""
        _, total = self._call_paged(
            repo.get_paginated, sample_document, category="invoice"
        )
        assert total == 1

    def test_get_paginated_with_has_annotations_true(self, repo, sample_document):
        """Test get_paginated with has_annotations=True."""
        _, total = self._call_paged(
            repo.get_paginated, sample_document, has_annotations=True
        )
        assert total == 1

    def test_get_paginated_with_has_annotations_false(self, repo, sample_document):
        """Test get_paginated with has_annotations=False."""
        _, total = self._call_paged(
            repo.get_paginated, sample_document, has_annotations=False
        )
        assert total == 1

    # ==========================================================================
    # update_status() tests
    # ==========================================================================

    def test_update_status(self, repo, sample_document):
        """Test update_status updates document status."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            repo.update_status(str(sample_document.document_id), "labeled")

        assert sample_document.status == "labeled"
        session.add.assert_called_once()

    def test_update_status_with_auto_label_status(self, repo, sample_document):
        """Test update_status with auto_label_status."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            repo.update_status(
                str(sample_document.document_id),
                "labeled",
                auto_label_status="completed",
            )

        assert sample_document.auto_label_status == "completed"

    def test_update_status_with_auto_label_error(self, repo, sample_document):
        """Test update_status with auto_label_error."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            repo.update_status(
                str(sample_document.document_id),
                "failed",
                auto_label_error="OCR failed",
            )

        assert sample_document.auto_label_error == "OCR failed"

    def test_update_status_document_not_found(self, repo):
        """Test update_status when document not found."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = None
            repo.update_status(str(uuid4()), "labeled")

        session.add.assert_not_called()

    # ==========================================================================
    # update_file_path() tests
    # ==========================================================================

    def test_update_file_path(self, repo, sample_document):
        """Test update_file_path updates document file path."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            repo.update_file_path(str(sample_document.document_id), "/new/path.pdf")

        assert sample_document.file_path == "/new/path.pdf"
        session.add.assert_called_once()

    def test_update_file_path_document_not_found(self, repo):
        """Test update_file_path when document not found."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = None
            repo.update_file_path(str(uuid4()), "/new/path.pdf")

        session.add.assert_not_called()

    # ==========================================================================
    # update_group_key() tests
    # ==========================================================================

    def test_update_group_key_returns_true(self, repo, sample_document):
        """Test update_group_key returns True when document exists."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            result = repo.update_group_key(
                str(sample_document.document_id), "new-group"
            )

        assert result is True
        assert sample_document.group_key == "new-group"

    def test_update_group_key_returns_false(self, repo):
        """Test update_group_key returns False when document not found."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = None
            result = repo.update_group_key(str(uuid4()), "new-group")

        assert result is False

    # ==========================================================================
    # update_category() tests
    # ==========================================================================

    def test_update_category(self, repo, sample_document):
        """Test update_category updates document category."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            repo.update_category(str(sample_document.document_id), "receipt")

        assert sample_document.category == "receipt"
        session.add.assert_called()

    def test_update_category_returns_none_when_not_found(self, repo):
        """Test update_category returns None when document not found."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = None
            result = repo.update_category(str(uuid4()), "receipt")

        assert result is None

    # ==========================================================================
    # delete() tests
    # ==========================================================================

    def test_delete_returns_true_when_exists(self, repo, sample_document):
        """Test delete returns True when document exists."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            session.exec.return_value.all.return_value = []
            result = repo.delete(str(sample_document.document_id))

        assert result is True
        session.delete.assert_called_once_with(sample_document)

    def test_delete_with_annotations(self, repo, sample_document):
        """Test delete removes annotations before deleting document."""
        annotation = MagicMock()
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            session.exec.return_value.all.return_value = [annotation]
            result = repo.delete(str(sample_document.document_id))

        assert result is True
        # One delete for the annotation plus one for the document itself.
        assert session.delete.call_count == 2

    def test_delete_returns_false_when_not_exists(self, repo):
        """Test delete returns False when document not found."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = None
            result = repo.delete(str(uuid4()))

        assert result is False

    # ==========================================================================
    # get_categories() tests
    # ==========================================================================

    def test_get_categories(self, repo):
        """Test get_categories drops the None entry returned by the query."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.exec.return_value.all.return_value = ["invoice", "receipt", None]
            result = repo.get_categories()

        assert result == ["invoice", "receipt"]

    # ==========================================================================
    # get_labeled_for_export() tests
    # ==========================================================================

    def test_get_labeled_for_export(self, repo, labeled_document):
        """Test get_labeled_for_export returns labeled documents."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.exec.return_value.all.return_value = [labeled_document]
            result = repo.get_labeled_for_export()

        assert len(result) == 1
        assert result[0].status == "labeled"

    def test_get_labeled_for_export_with_token(self, repo, labeled_document):
        """Test get_labeled_for_export with admin_token filter."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.exec.return_value.all.return_value = [labeled_document]
            result = repo.get_labeled_for_export(admin_token="token-123")

        assert len(result) == 1

    # ==========================================================================
    # count_by_status() tests
    # ==========================================================================

    def test_count_by_status(self, repo):
        """Test count_by_status turns (status, count) rows into a dict."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.exec.return_value.all.return_value = [
                ("pending", 10),
                ("labeled", 5),
            ]
            result = repo.count_by_status()

        assert result == {"pending": 10, "labeled": 5}

    # ==========================================================================
    # get_by_ids() tests
    # ==========================================================================

    def test_get_by_ids(self, repo, sample_document):
        """Test get_by_ids returns documents by IDs."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.exec.return_value.all.return_value = [sample_document]
            result = repo.get_by_ids([str(sample_document.document_id)])

        assert len(result) == 1

    # ==========================================================================
    # get_for_training() tests
    # ==========================================================================

    def test_get_for_training_basic(self, repo, labeled_document):
        """Test get_for_training with default parameters."""
        results, total = self._call_paged(repo.get_for_training, labeled_document)
        assert total == 1
        assert len(results) == 1

    def test_get_for_training_with_min_annotation_count(self, repo, labeled_document):
        """Test get_for_training with min_annotation_count."""
        _, total = self._call_paged(
            repo.get_for_training, labeled_document, min_annotation_count=3
        )
        assert total == 1

    def test_get_for_training_exclude_used(self, repo, labeled_document):
        """Test get_for_training with exclude_used_in_training."""
        _, total = self._call_paged(
            repo.get_for_training, labeled_document, exclude_used_in_training=True
        )
        assert total == 1

    def test_get_for_training_no_annotations(self, repo, labeled_document):
        """Test get_for_training with has_annotations=False."""
        _, total = self._call_paged(
            repo.get_for_training, labeled_document, has_annotations=False
        )
        assert total == 1

    # ==========================================================================
    # acquire_annotation_lock() tests
    # ==========================================================================

    def test_acquire_annotation_lock_success(self, repo, sample_document):
        """Test acquire_annotation_lock when no lock exists."""
        sample_document.annotation_lock_until = None
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            result = repo.acquire_annotation_lock(str(sample_document.document_id))

        assert result is not None
        assert sample_document.annotation_lock_until is not None

    def test_acquire_annotation_lock_fails_when_locked(self, repo, locked_document):
        """Test acquire_annotation_lock fails when document is already locked."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = locked_document
            result = repo.acquire_annotation_lock(str(locked_document.document_id))

        assert result is None

    def test_acquire_annotation_lock_document_not_found(self, repo):
        """Test acquire_annotation_lock when document not found."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = None
            result = repo.acquire_annotation_lock(str(uuid4()))

        assert result is None

    # ==========================================================================
    # release_annotation_lock() tests
    # ==========================================================================

    def test_release_annotation_lock_success(self, repo, locked_document):
        """Test release_annotation_lock releases the lock."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = locked_document
            result = repo.release_annotation_lock(str(locked_document.document_id))

        assert result is not None
        assert locked_document.annotation_lock_until is None

    def test_release_annotation_lock_document_not_found(self, repo):
        """Test release_annotation_lock when document not found."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = None
            result = repo.release_annotation_lock(str(uuid4()))

        assert result is None

    # ==========================================================================
    # extend_annotation_lock() tests
    # ==========================================================================

    def test_extend_annotation_lock_success(self, repo, locked_document):
        """Test extend_annotation_lock extends the lock."""
        original_lock = locked_document.annotation_lock_until
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = locked_document
            result = repo.extend_annotation_lock(str(locked_document.document_id))

        assert result is not None
        assert locked_document.annotation_lock_until > original_lock

    def test_extend_annotation_lock_fails_when_no_lock(self, repo, sample_document):
        """Test extend_annotation_lock fails when no lock exists."""
        sample_document.annotation_lock_until = None
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = sample_document
            result = repo.extend_annotation_lock(str(sample_document.document_id))

        assert result is None

    def test_extend_annotation_lock_fails_when_expired(self, repo, expired_lock_document):
        """Test extend_annotation_lock fails when lock is expired."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = expired_lock_document
            result = repo.extend_annotation_lock(
                str(expired_lock_document.document_id)
            )

        assert result is None

    def test_extend_annotation_lock_document_not_found(self, repo):
        """Test extend_annotation_lock when document not found."""
        with patch(self._SESSION_CTX) as mock_ctx:
            session = self._bind_session(mock_ctx)
            session.get.return_value = None
            result = repo.extend_annotation_lock(str(uuid4()))

        assert result is None

View File

@@ -0,0 +1,582 @@
"""
Tests for ModelVersionRepository
100% coverage tests for model version management.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import ModelVersion
from inference.data.repositories.model_version_repository import ModelVersionRepository
class TestModelVersionRepository:
"""Tests for ModelVersionRepository."""
@pytest.fixture
def sample_model(self) -> ModelVersion:
    """Provide an inactive "ready" model version with metrics, training config and file size set."""
    fields = {
        "version_id": uuid4(),
        "version": "v1.0.0",
        "name": "Test Model",
        "description": "A test model",
        "model_path": "/path/to/model.pt",
        "status": "ready",
        "is_active": False,
        "metrics_mAP": 0.95,
        "metrics_precision": 0.92,
        "metrics_recall": 0.88,
        "document_count": 100,
        "training_config": {"epochs": 100},
        "file_size": 1024000,
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
    }
    return ModelVersion(**fields)
@pytest.fixture
def active_model(self) -> ModelVersion:
    """Provide a model version flagged active (status "active", is_active True)."""
    fields = {
        "version_id": uuid4(),
        "version": "v1.0.0",
        "name": "Active Model",
        "model_path": "/path/to/active_model.pt",
        "status": "active",
        "is_active": True,
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
    }
    return ModelVersion(**fields)
@pytest.fixture
def repo(self) -> ModelVersionRepository:
"""Create a ModelVersionRepository instance."""
return ModelVersionRepository()
# =========================================================================
# create() tests
# =========================================================================
def test_create_returns_model(self, repo):
"""Test create returns created model version."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.create(
version="v1.0.0",
name="Test Model",
model_path="/path/to/model.pt",
)
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
def test_create_with_all_params(self, repo):
"""Test create with all parameters."""
task_id = uuid4()
dataset_id = uuid4()
trained_at = datetime.now(timezone.utc)
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.create(
version="v2.0.0",
name="Full Model",
model_path="/path/to/full_model.pt",
description="A complete model",
task_id=str(task_id),
dataset_id=str(dataset_id),
metrics_mAP=0.95,
metrics_precision=0.92,
metrics_recall=0.88,
document_count=500,
training_config={"epochs": 200},
file_size=2048000,
trained_at=trained_at,
)
added_model = mock_session.add.call_args[0][0]
assert added_model.version == "v2.0.0"
assert added_model.description == "A complete model"
assert added_model.task_id == task_id
assert added_model.dataset_id == dataset_id
assert added_model.metrics_mAP == 0.95
def test_create_with_uuid_objects(self, repo):
"""Test create works with UUID objects."""
task_id = uuid4()
dataset_id = uuid4()
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.create(
version="v1.0.0",
name="Test Model",
model_path="/path/to/model.pt",
task_id=task_id,
dataset_id=dataset_id,
)
added_model = mock_session.add.call_args[0][0]
assert added_model.task_id == task_id
assert added_model.dataset_id == dataset_id
def test_create_without_optional_ids(self, repo):
"""Test create without task_id and dataset_id."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.create(
version="v1.0.0",
name="Test Model",
model_path="/path/to/model.pt",
)
added_model = mock_session.add.call_args[0][0]
assert added_model.task_id is None
assert added_model.dataset_id is None
# =========================================================================
# get() tests
# =========================================================================
def test_get_returns_model(self, repo, sample_model):
"""Test get returns model when exists."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get(str(sample_model.version_id))
assert result is not None
assert result.name == "Test Model"
mock_session.expunge.assert_called_once()
def test_get_with_uuid(self, repo, sample_model):
"""Test get works with UUID object."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get(sample_model.version_id)
assert result is not None
def test_get_returns_none_when_not_found(self, repo):
"""Test get returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get(str(uuid4()))
assert result is None
mock_session.expunge.assert_not_called()
# =========================================================================
# get_paginated() tests
# =========================================================================
def test_get_paginated_returns_models_and_total(self, repo, sample_model):
"""Test get_paginated returns list of models and total count."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_model]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
models, total = repo.get_paginated()
assert len(models) == 1
assert total == 1
def test_get_paginated_with_status_filter(self, repo, sample_model):
"""Test get_paginated filters by status."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_model]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
models, total = repo.get_paginated(status="ready")
assert len(models) == 1
def test_get_paginated_with_pagination(self, repo, sample_model):
"""Test get_paginated with limit and offset."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 50
mock_session.exec.return_value.all.return_value = [sample_model]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
models, total = repo.get_paginated(limit=10, offset=20)
assert total == 50
def test_get_paginated_empty_results(self, repo):
"""Test get_paginated with no results."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 0
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
models, total = repo.get_paginated()
assert models == []
assert total == 0
# =========================================================================
# get_active() tests
# =========================================================================
def test_get_active_returns_active_model(self, repo, active_model):
"""Test get_active returns the active model."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.first.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get_active()
assert result is not None
assert result.is_active is True
mock_session.expunge.assert_called_once()
def test_get_active_returns_none(self, repo):
"""Test get_active returns None when no active model."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.first.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.get_active()
assert result is None
mock_session.expunge.assert_not_called()
# =========================================================================
# activate() tests
# =========================================================================
def test_activate_activates_model(self, repo, sample_model, active_model):
"""Test activate sets model as active and deactivates others."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [active_model]
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.activate(str(sample_model.version_id))
assert result is not None
assert sample_model.is_active is True
assert sample_model.status == "active"
assert active_model.is_active is False
assert active_model.status == "inactive"
def test_activate_with_uuid(self, repo, sample_model):
"""Test activate works with UUID object."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.activate(sample_model.version_id)
assert result is not None
assert sample_model.is_active is True
def test_activate_returns_none_when_not_found(self, repo):
"""Test activate returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.activate(str(uuid4()))
assert result is None
def test_activate_sets_activated_at(self, repo, sample_model):
"""Test activate sets activated_at timestamp."""
sample_model.activated_at = None
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
repo.activate(str(sample_model.version_id))
assert sample_model.activated_at is not None
# =========================================================================
# deactivate() tests
# =========================================================================
def test_deactivate_deactivates_model(self, repo, active_model):
"""Test deactivate sets model as inactive."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.deactivate(str(active_model.version_id))
assert result is not None
assert active_model.is_active is False
assert active_model.status == "inactive"
mock_session.commit.assert_called_once()
def test_deactivate_with_uuid(self, repo, active_model):
"""Test deactivate works with UUID object."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.deactivate(active_model.version_id)
assert result is not None
def test_deactivate_returns_none_when_not_found(self, repo):
"""Test deactivate returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.deactivate(str(uuid4()))
assert result is None
# =========================================================================
# update() tests
# =========================================================================
def test_update_updates_model(self, repo, sample_model):
"""Test update updates model metadata."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.update(
str(sample_model.version_id),
name="Updated Model",
)
assert result is not None
assert sample_model.name == "Updated Model"
mock_session.commit.assert_called_once()
def test_update_all_fields(self, repo, sample_model):
"""Test update can update all fields."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.update(
str(sample_model.version_id),
name="New Name",
description="New Description",
status="archived",
)
assert sample_model.name == "New Name"
assert sample_model.description == "New Description"
assert sample_model.status == "archived"
def test_update_with_uuid(self, repo, sample_model):
"""Test update works with UUID object."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.update(sample_model.version_id, name="Updated")
assert result is not None
def test_update_returns_none_when_not_found(self, repo):
"""Test update returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.update(str(uuid4()), name="New Name")
assert result is None
def test_update_partial_fields(self, repo, sample_model):
"""Test update only updates provided fields."""
original_name = sample_model.name
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.update(
str(sample_model.version_id),
description="Only description changed",
)
assert sample_model.name == original_name
assert sample_model.description == "Only description changed"
# =========================================================================
# archive() tests
# =========================================================================
def test_archive_archives_model(self, repo, sample_model):
"""Test archive sets model status to archived."""
sample_model.is_active = False
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.archive(str(sample_model.version_id))
assert result is not None
assert sample_model.status == "archived"
mock_session.commit.assert_called_once()
def test_archive_with_uuid(self, repo, sample_model):
"""Test archive works with UUID object."""
sample_model.is_active = False
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.archive(sample_model.version_id)
assert result is not None
def test_archive_returns_none_when_not_found(self, repo):
"""Test archive returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.archive(str(uuid4()))
assert result is None
def test_archive_returns_none_when_active(self, repo, active_model):
"""Test archive returns None when model is active."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.archive(str(active_model.version_id))
assert result is None
# =========================================================================
# delete() tests
# =========================================================================
def test_delete_returns_true(self, repo, sample_model):
"""Test delete returns True when model exists and not active."""
sample_model.is_active = False
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.delete(str(sample_model.version_id))
assert result is True
mock_session.delete.assert_called_once()
mock_session.commit.assert_called_once()
def test_delete_with_uuid(self, repo, sample_model):
"""Test delete works with UUID object."""
sample_model.is_active = False
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.delete(sample_model.version_id)
assert result is True
def test_delete_returns_false_when_not_found(self, repo):
"""Test delete returns False when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.delete(str(uuid4()))
assert result is False
mock_session.delete.assert_not_called()
def test_delete_returns_false_when_active(self, repo, active_model):
"""Test delete returns False when model is active."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
result = repo.delete(str(active_model.version_id))
assert result is False
mock_session.delete.assert_not_called()

View File

@@ -0,0 +1,199 @@
"""
Tests for TokenRepository
TDD tests for admin token management.
"""
import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from inference.data.admin_models import AdminToken
from inference.data.repositories.token_repository import TokenRepository
class TestTokenRepository:
    """Tests for TokenRepository, run against a mocked session context."""

    # Patch target used by every test in this class.
    _CTX_PATH = "inference.data.repositories.base.get_session_context"

    @staticmethod
    def _session_from(ctx_factory, lookup_result=None) -> MagicMock:
        """Wire a mock session into the patched context manager.

        ``lookup_result`` becomes the value returned by ``session.get(...)``.
        """
        session = MagicMock()
        session.get.return_value = lookup_result
        manager = ctx_factory.return_value
        manager.__enter__ = MagicMock(return_value=session)
        manager.__exit__ = MagicMock(return_value=False)
        return session

    @pytest.fixture
    def sample_token(self) -> AdminToken:
        """Build an active token with no expiry and no recorded usage."""
        fields = {
            "token": "test-token-123",
            "name": "Test Token",
            "is_active": True,
            "created_at": datetime.now(timezone.utc),
            "last_used_at": None,
            "expires_at": None,
        }
        return AdminToken(**fields)

    @pytest.fixture
    def expired_token(self) -> AdminToken:
        """Build an active token whose expiry lies one day in the past."""
        fields = {
            "token": "expired-token",
            "name": "Expired Token",
            "is_active": True,
            "created_at": datetime.now(timezone.utc) - timedelta(days=30),
            "expires_at": datetime.now(timezone.utc) - timedelta(days=1),
        }
        return AdminToken(**fields)

    @pytest.fixture
    def inactive_token(self) -> AdminToken:
        """Build a token flagged as inactive."""
        fields = {
            "token": "inactive-token",
            "name": "Inactive Token",
            "is_active": False,
            "created_at": datetime.now(timezone.utc),
        }
        return AdminToken(**fields)

    @pytest.fixture
    def repo(self) -> TokenRepository:
        """Provide the repository under test."""
        return TokenRepository()

    def test_is_valid_returns_true_for_active_token(self, repo, sample_token):
        """is_valid is True for an active, non-expired token."""
        with patch(self._CTX_PATH) as ctx_factory:
            session = self._session_from(ctx_factory, sample_token)
            verdict = repo.is_valid("test-token-123")
            assert verdict is True
            session.get.assert_called_once_with(AdminToken, "test-token-123")

    def test_is_valid_returns_false_for_nonexistent_token(self, repo):
        """is_valid is False when the lookup finds nothing."""
        with patch(self._CTX_PATH) as ctx_factory:
            self._session_from(ctx_factory, None)
            verdict = repo.is_valid("nonexistent-token")
            assert verdict is False

    def test_is_valid_returns_false_for_inactive_token(self, repo, inactive_token):
        """is_valid is False for a token marked inactive."""
        with patch(self._CTX_PATH) as ctx_factory:
            self._session_from(ctx_factory, inactive_token)
            verdict = repo.is_valid("inactive-token")
            assert verdict is False

    def test_is_valid_returns_false_for_expired_token(self, repo, expired_token):
        """is_valid is False once expires_at has passed."""
        with patch(self._CTX_PATH) as ctx_factory:
            self._session_from(ctx_factory, expired_token)
            verdict = repo.is_valid("expired-token")
            assert verdict is False

    def test_get_returns_token_when_exists(self, repo, sample_token):
        """get hands back the stored token and expunges it from the session."""
        with patch(self._CTX_PATH) as ctx_factory:
            session = self._session_from(ctx_factory, sample_token)
            fetched = repo.get("test-token-123")
            assert fetched is not None
            assert fetched.token == "test-token-123"
            assert fetched.name == "Test Token"
            session.expunge.assert_called_once_with(sample_token)

    def test_get_returns_none_when_not_exists(self, repo):
        """get returns None when the token is unknown."""
        with patch(self._CTX_PATH) as ctx_factory:
            self._session_from(ctx_factory, None)
            fetched = repo.get("nonexistent-token")
            assert fetched is None

    def test_create_new_token(self, repo):
        """create adds a fresh AdminToken when none exists yet."""
        with patch(self._CTX_PATH) as ctx_factory:
            # session.get() -> None means the token doesn't exist
            session = self._session_from(ctx_factory, None)
            repo.create("new-token", "New Token", expires_at=None)
            session.add.assert_called_once()
            positional, _keywords = session.add.call_args
            added_token = positional[0]
            assert isinstance(added_token, AdminToken)
            assert added_token.token == "new-token"
            assert added_token.name == "New Token"

    def test_create_updates_existing_token(self, repo, sample_token):
        """create on an existing token updates it in place."""
        with patch(self._CTX_PATH) as ctx_factory:
            session = self._session_from(ctx_factory, sample_token)
            repo.create("test-token-123", "Updated Name", expires_at=None)
            session.add.assert_called_once_with(sample_token)
            assert sample_token.name == "Updated Name"
            assert sample_token.is_active is True

    def test_update_usage(self, repo, sample_token):
        """update_usage stamps last_used_at and re-adds the token."""
        with patch(self._CTX_PATH) as ctx_factory:
            session = self._session_from(ctx_factory, sample_token)
            repo.update_usage("test-token-123")
            assert sample_token.last_used_at is not None
            session.add.assert_called_once_with(sample_token)

    def test_deactivate_returns_true_when_token_exists(self, repo, sample_token):
        """deactivate flips is_active off and reports True."""
        with patch(self._CTX_PATH) as ctx_factory:
            session = self._session_from(ctx_factory, sample_token)
            outcome = repo.deactivate("test-token-123")
            assert outcome is True
            assert sample_token.is_active is False
            session.add.assert_called_once_with(sample_token)

    def test_deactivate_returns_false_when_token_not_exists(self, repo):
        """deactivate reports False for an unknown token."""
        with patch(self._CTX_PATH) as ctx_factory:
            self._session_from(ctx_factory, None)
            outcome = repo.deactivate("nonexistent-token")
            assert outcome is False

View File

@@ -0,0 +1,615 @@
"""
Tests for TrainingTaskRepository
100% coverage tests for training task management.
"""
import pytest
from datetime import datetime, timezone, timedelta
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
from inference.data.repositories.training_task_repository import TrainingTaskRepository
class TestTrainingTaskRepository:
"""Tests for TrainingTaskRepository."""
@pytest.fixture
def sample_task(self) -> TrainingTask:
    """Build a pending ``train`` task with a small config for use in tests."""
    fields = {
        "task_id": uuid4(),
        "admin_token": "admin-token",
        "name": "Test Training Task",
        "task_type": "train",
        "description": "A test training task",
        "status": "pending",
        "config": {"epochs": 100, "batch_size": 16},
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
    }
    return TrainingTask(**fields)
@pytest.fixture
def sample_log(self) -> TrainingLog:
    """Build an INFO-level training log entry for use in tests."""
    fields = {
        "log_id": uuid4(),
        "task_id": uuid4(),
        "level": "INFO",
        "message": "Training started",
        "details": {"epoch": 1},
        "created_at": datetime.now(timezone.utc),
    }
    return TrainingLog(**fields)
@pytest.fixture
def sample_link(self) -> TrainingDocumentLink:
    """Build a task-to-document link with an empty annotation snapshot."""
    fields = {
        "link_id": uuid4(),
        "task_id": uuid4(),
        "document_id": uuid4(),
        "annotation_snapshot": {"annotations": []},
        "created_at": datetime.now(timezone.utc),
    }
    return TrainingDocumentLink(**fields)
@pytest.fixture
def repo(self) -> TrainingTaskRepository:
    """Provide the TrainingTaskRepository under test."""
    repository = TrainingTaskRepository()
    return repository
# =========================================================================
# create() tests
# =========================================================================
def test_create_returns_task_id(self, repo):
    """Test create yields a non-None result and adds/flushes exactly once."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    with patch(target) as ctx_factory:
        # Entering the patched context manager hands back the mock session.
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        created = repo.create(admin_token="admin-token", name="Test Task")
        assert created is not None
        session.add.assert_called_once()
        session.flush.assert_called_once()
def test_create_with_all_params(self, repo):
    """Test create with every optional parameter supplied."""
    scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    with patch(target) as ctx_factory:
        # Entering the patched context manager hands back the mock session.
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        create_kwargs = {
            "admin_token": "admin-token",
            "name": "Test Task",
            "task_type": "finetune",
            "description": "Full test",
            "config": {"epochs": 50},
            "scheduled_at": scheduled_time,
            "cron_expression": "0 0 * * *",
            "is_recurring": True,
            "dataset_id": str(uuid4()),
        }
        created = repo.create(**create_kwargs)
        assert created is not None
        # Inspect the task object that was handed to session.add().
        positional, _keywords = session.add.call_args
        added_task = positional[0]
        assert added_task.task_type == "finetune"
        assert added_task.description == "Full test"
        assert added_task.is_recurring is True
        assert added_task.status == "scheduled"  # because scheduled_at is set
def test_create_pending_status_when_not_scheduled(self, repo):
    """Test create stores status "pending" when scheduled_at is omitted."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    with patch(target) as ctx_factory:
        # Entering the patched context manager hands back the mock session.
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        repo.create(admin_token="admin-token", name="Test Task")
        positional, _keywords = session.add.call_args
        added_task = positional[0]
        assert added_task.status == "pending"
def test_create_scheduled_status_when_scheduled(self, repo):
    """With scheduled_at supplied, the task passed to session.add() has status 'scheduled'."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    run_at = datetime.now(timezone.utc) + timedelta(hours=1)
    session = MagicMock()
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        repo.create(admin_token="admin-token", name="Test Task", scheduled_at=run_at)
        stored = session.add.call_args.args[0]
        assert stored.status == "scheduled"
# =========================================================================
# get() tests
# =========================================================================
def test_get_returns_task(self, repo, sample_task):
    """get() hands back the task found by the session and expunges it once."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        fetched = repo.get(task_key)
        assert fetched is not None
        assert fetched.name == "Test Training Task"
        session.expunge.assert_called_once()
def test_get_returns_none_when_not_found(self, repo):
    """get() yields None and never expunges when the session finds nothing."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.get.return_value = None
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        unknown_id = str(uuid4())
        fetched = repo.get(unknown_id)
        assert fetched is None
        session.expunge.assert_not_called()
# =========================================================================
# get_by_token() tests
# =========================================================================
def test_get_by_token_returns_task(self, repo, sample_task):
    """get_by_token() with an explicit token returns a non-None task."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        fetched = repo.get_by_token(task_key, "admin-token")
        assert fetched is not None
def test_get_by_token_without_token_param(self, repo, sample_task):
    """get_by_token() called with only the task id still returns a non-None task."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        fetched = repo.get_by_token(task_key)
        assert fetched is not None
# =========================================================================
# get_paginated() tests
# =========================================================================
def test_get_paginated_returns_tasks_and_total(self, repo, sample_task):
    """get_paginated() with defaults returns the task list together with the total."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.one.return_value = 1
    query_result.all.return_value = [sample_task]
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        page, count = repo.get_paginated()
        assert len(page) == 1
        assert count == 1
def test_get_paginated_with_status_filter(self, repo, sample_task):
    """Test get_paginated returns tasks and total when a status filter is passed.

    The session is mocked, so the filter itself is not evaluated here; the
    test only checks that passing ``status`` still yields both parts of the
    result.
    """
    with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
        mock_session = MagicMock()
        # Every session.exec() call shares one result mock: .one() feeds the
        # count, .all() feeds the rows.
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_task]
        mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
        tasks, total = repo.get_paginated(status="pending")
        assert len(tasks) == 1
        # The total was unpacked but never checked; the mocked count is 1.
        assert total == 1
def test_get_paginated_with_pagination(self, repo, sample_task):
    """Test get_paginated returns tasks and total when limit and offset are passed.

    The session is mocked, so limit/offset are not applied to real rows; the
    test checks that the mocked rows and the mocked count are both returned.
    """
    with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
        mock_session = MagicMock()
        # Every session.exec() call shares one result mock: .one() feeds the
        # count, .all() feeds the rows.
        mock_session.exec.return_value.one.return_value = 50
        mock_session.exec.return_value.all.return_value = [sample_task]
        mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
        tasks, total = repo.get_paginated(limit=10, offset=20)
        # The page was unpacked but never checked; the mock supplies one row.
        assert len(tasks) == 1
        assert total == 50
def test_get_paginated_empty_results(self, repo):
    """get_paginated() returns an empty list and a zero total when nothing matches."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.one.return_value = 0
    query_result.all.return_value = []
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        page, count = repo.get_paginated()
        assert page == []
        assert count == 0
# =========================================================================
# get_pending() tests
# =========================================================================
def test_get_pending_returns_pending_tasks(self, repo, sample_task):
    """get_pending() passes through the rows produced by the session query."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.exec.return_value.all.return_value = [sample_task]
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        pending = repo.get_pending()
        assert len(pending) == 1
def test_get_pending_returns_empty_list(self, repo):
    """get_pending() yields an empty list when the query has no rows."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.exec.return_value.all.return_value = []
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        pending = repo.get_pending()
        assert pending == []
# =========================================================================
# update_status() tests
# =========================================================================
def test_update_status_updates_task(self, repo, sample_task):
    """update_status() writes the new status onto the task returned by the session."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        repo.update_status(task_key, "running")
        assert sample_task.status == "running"
def test_update_status_sets_started_at_for_running(self, repo, sample_task):
    """Moving a task to 'running' populates a previously empty started_at."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    # Clear the timestamp first so the assertion proves it was filled in.
    sample_task.started_at = None
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        repo.update_status(task_key, "running")
        assert sample_task.started_at is not None
def test_update_status_sets_completed_at_for_completed(self, repo, sample_task):
    """Moving a task to 'completed' populates a previously empty completed_at."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    # Clear the timestamp first so the assertion proves it was filled in.
    sample_task.completed_at = None
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        repo.update_status(task_key, "completed")
        assert sample_task.completed_at is not None
def test_update_status_sets_completed_at_for_failed(self, repo, sample_task):
    """Moving a task to 'failed' populates completed_at and stores the error message."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    # Clear the timestamp first so the assertion proves it was filled in.
    sample_task.completed_at = None
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        repo.update_status(task_key, "failed", error_message="Error occurred")
        assert sample_task.completed_at is not None
        assert sample_task.error_message == "Error occurred"
def test_update_status_with_result_metrics(self, repo, sample_task):
    """update_status() stores the supplied result_metrics on the task."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        repo.update_status(task_key, "completed", result_metrics={"mAP": 0.95})
        assert sample_task.result_metrics == {"mAP": 0.95}
def test_update_status_with_model_path(self, repo, sample_task):
    """update_status() stores the supplied model_path on the task."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        repo.update_status(task_key, "completed", model_path="/path/to/model.pt")
        assert sample_task.model_path == "/path/to/model.pt"
def test_update_status_not_found(self, repo):
    """update_status() adds nothing to the session when the task id is unknown."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.get.return_value = None
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        unknown_id = str(uuid4())
        repo.update_status(unknown_id, "running")
        session.add.assert_not_called()
# =========================================================================
# cancel() tests
# =========================================================================
def test_cancel_returns_true_for_pending(self, repo, sample_task):
    """cancel() on a 'pending' task returns True and marks it 'cancelled'."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    sample_task.status = "pending"
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        was_cancelled = repo.cancel(task_key)
        assert was_cancelled is True
        assert sample_task.status == "cancelled"
def test_cancel_returns_true_for_scheduled(self, repo, sample_task):
    """cancel() on a 'scheduled' task returns True and marks it 'cancelled'."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    sample_task.status = "scheduled"
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        was_cancelled = repo.cancel(task_key)
        assert was_cancelled is True
        assert sample_task.status == "cancelled"
def test_cancel_returns_false_for_running(self, repo, sample_task):
    """cancel() on a 'running' task returns False."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    sample_task.status = "running"
    session = MagicMock()
    session.get.return_value = sample_task
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_task.task_id)
        was_cancelled = repo.cancel(task_key)
        assert was_cancelled is False
def test_cancel_returns_false_when_not_found(self, repo):
    """cancel() returns False when the session finds no task for the id."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.get.return_value = None
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        unknown_id = str(uuid4())
        was_cancelled = repo.cancel(unknown_id)
        assert was_cancelled is False
# =========================================================================
# add_log() tests
# =========================================================================
def test_add_log_creates_log_entry(self, repo):
    """add_log() adds exactly one entry carrying the given level and message."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(uuid4())
        repo.add_log(task_id=task_key, level="INFO", message="Training started")
        session.add.assert_called_once()
        # Inspect the object that was handed to session.add().
        entry = session.add.call_args.args[0]
        assert entry.level == "INFO"
        assert entry.message == "Training started"
def test_add_log_with_details(self, repo):
    """add_log() keeps the supplied details dict on the added entry."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        log_kwargs = {
            "task_id": str(uuid4()),
            "level": "DEBUG",
            "message": "Epoch complete",
            "details": {"epoch": 5, "loss": 0.05},
        }
        repo.add_log(**log_kwargs)
        entry = session.add.call_args.args[0]
        assert entry.details == {"epoch": 5, "loss": 0.05}
# =========================================================================
# get_logs() tests
# =========================================================================
def test_get_logs_returns_list(self, repo, sample_log):
    """get_logs() passes through the log rows produced by the session query."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.exec.return_value.all.return_value = [sample_log]
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_log.task_id)
        logs = repo.get_logs(task_key)
        assert len(logs) == 1
        assert logs[0].level == "INFO"
def test_get_logs_with_pagination(self, repo, sample_log):
    """get_logs() accepts limit and offset and still returns the queried rows."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.exec.return_value.all.return_value = [sample_log]
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        task_key = str(sample_log.task_id)
        logs = repo.get_logs(task_key, limit=50, offset=10)
        assert len(logs) == 1
def test_get_logs_returns_empty_list(self, repo):
    """get_logs() yields an empty list when the query has no rows."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.exec.return_value.all.return_value = []
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        unknown_id = str(uuid4())
        logs = repo.get_logs(unknown_id)
        assert logs == []
# =========================================================================
# create_document_link() tests
# =========================================================================
def test_create_document_link_returns_link(self, repo):
    """Test create_document_link returns created link.

    Checks that the link is added and committed exactly once and that the
    call hands something back to the caller.
    """
    with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
        mock_session = MagicMock()
        mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
        task_id = uuid4()
        document_id = uuid4()
        result = repo.create_document_link(
            task_id=task_id,
            document_id=document_id,
        )
        # The return value was captured but never checked, so the test did
        # not verify what its name promises.
        # NOTE(review): assumes create_document_link returns the link object
        # rather than None - confirm against the repository implementation.
        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
def test_create_document_link_with_snapshot(self, repo):
    """create_document_link() keeps the supplied annotation snapshot on the added link."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        frozen_annotations = {"annotations": [{"class_name": "invoice_number"}]}
        repo.create_document_link(
            task_id=uuid4(),
            document_id=uuid4(),
            annotation_snapshot=frozen_annotations,
        )
        # Inspect the object that was handed to session.add().
        stored_link = session.add.call_args.args[0]
        assert stored_link.annotation_snapshot == frozen_annotations
# =========================================================================
# get_document_links() tests
# =========================================================================
def test_get_document_links_returns_list(self, repo, sample_link):
    """get_document_links() passes through the link rows produced by the session query."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.exec.return_value.all.return_value = [sample_link]
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        links = repo.get_document_links(sample_link.task_id)
        assert len(links) == 1
def test_get_document_links_returns_empty_list(self, repo):
    """get_document_links() yields an empty list when the query has no rows."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.exec.return_value.all.return_value = []
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        unknown_task = uuid4()
        links = repo.get_document_links(unknown_task)
        assert links == []
# =========================================================================
# get_document_training_tasks() tests
# =========================================================================
def test_get_document_training_tasks_returns_list(self, repo, sample_link):
    """get_document_training_tasks() passes through the rows produced by the session query."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.exec.return_value.all.return_value = [sample_link]
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        links = repo.get_document_training_tasks(sample_link.document_id)
        assert len(links) == 1
def test_get_document_training_tasks_returns_empty_list(self, repo):
    """get_document_training_tasks() yields an empty list when the query has no rows."""
    target = "inference.data.repositories.training_task_repository.get_session_context"
    session = MagicMock()
    session.exec.return_value.all.return_value = []
    with patch(target) as session_ctx:
        ctx_manager = session_ctx.return_value
        ctx_manager.__enter__.return_value = session
        ctx_manager.__exit__.return_value = False
        unknown_document = uuid4()
        links = repo.get_document_training_tasks(unknown_document)
        assert links == []