This commit is contained in:
Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

View File

@@ -0,0 +1,711 @@
"""
Tests for AnnotationRepository
100% coverage tests for annotation management.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
from inference.data.repositories.annotation_repository import AnnotationRepository
class TestAnnotationRepository:
"""Tests for AnnotationRepository."""
@pytest.fixture
def sample_annotation(self) -> AdminAnnotation:
    """Build an AdminAnnotation with fixed geometry, text and source values.

    Both IDs are random UUIDs and both timestamps are the current UTC time.
    """
    fields = {
        "annotation_id": uuid4(),
        "document_id": uuid4(),
        "page_number": 1,
        "class_id": 0,
        "class_name": "invoice_number",
        "x_center": 0.5,
        "y_center": 0.3,
        "width": 0.2,
        "height": 0.05,
        "bbox_x": 100,
        "bbox_y": 200,
        "bbox_width": 150,
        "bbox_height": 30,
        "text_value": "INV-001",
        "confidence": 0.95,
        "source": "auto",
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
    }
    return AdminAnnotation(**fields)
@pytest.fixture
def sample_history(self) -> AnnotationHistory:
    """Build an AnnotationHistory record describing an "override" action.

    All three IDs are random UUIDs; the timestamp is the current UTC time.
    """
    fields = {
        "history_id": uuid4(),
        "annotation_id": uuid4(),
        "document_id": uuid4(),
        "action": "override",
        "previous_value": {"class_name": "old_class"},
        "new_value": {"class_name": "new_class"},
        "changed_by": "admin-token",
        "change_reason": "Correction",
        "created_at": datetime.now(timezone.utc),
    }
    return AnnotationHistory(**fields)
@pytest.fixture
def repo(self) -> AnnotationRepository:
    """Provide an AnnotationRepository constructed with no arguments."""
    repository = AnnotationRepository()
    return repository
# =========================================================================
# create() tests
# =========================================================================
def test_create_returns_annotation_id(self, repo):
    """create() gives a non-None result and adds/flushes exactly once."""
    session = MagicMock()
    payload = {
        "document_id": str(uuid4()),
        "page_number": 1,
        "class_id": 0,
        "class_name": "invoice_number",
        "x_center": 0.5,
        "y_center": 0.3,
        "width": 0.2,
        "height": 0.05,
        "bbox_x": 100,
        "bbox_y": 200,
        "bbox_width": 150,
        "bbox_height": 30,
    }
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        # Make the patched context manager yield our mock session.
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        created = repo.create(**payload)
    assert created is not None
    assert session.add.call_count == 1
    assert session.flush.call_count == 1
def test_create_with_optional_params(self, repo):
    """create() carries text_value, confidence and source onto the added object."""
    session = MagicMock()
    payload = {
        "document_id": str(uuid4()),
        "page_number": 2,
        "class_id": 1,
        "class_name": "invoice_date",
        "x_center": 0.6,
        "y_center": 0.4,
        "width": 0.15,
        "height": 0.04,
        "bbox_x": 200,
        "bbox_y": 300,
        "bbox_width": 100,
        "bbox_height": 25,
        "text_value": "2024-01-15",
        "confidence": 0.88,
        "source": "auto",
    }
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        created = repo.create(**payload)
    assert created is not None
    assert session.add.call_count == 1
    # Inspect the object that was handed to session.add().
    stored = session.add.call_args.args[0]
    assert stored.text_value == "2024-01-15"
    assert stored.confidence == 0.88
    assert stored.source == "auto"
def test_create_default_source_is_manual(self, repo):
    """create() without a source stores the annotation with source "manual"."""
    session = MagicMock()
    payload = {
        "document_id": str(uuid4()),
        "page_number": 1,
        "class_id": 0,
        "class_name": "invoice_number",
        "x_center": 0.5,
        "y_center": 0.3,
        "width": 0.2,
        "height": 0.05,
        "bbox_x": 100,
        "bbox_y": 200,
        "bbox_width": 150,
        "bbox_height": 30,
    }
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        repo.create(**payload)
    stored = session.add.call_args.args[0]
    assert stored.source == "manual"
# =========================================================================
# create_batch() tests
# =========================================================================
def test_create_batch_returns_ids(self, repo):
    """create_batch() returns one entry per input and adds/flushes once per input."""
    session = MagicMock()
    first = {
        "document_id": str(uuid4()),
        "class_id": 0,
        "class_name": "invoice_number",
        "x_center": 0.5,
        "y_center": 0.3,
        "width": 0.2,
        "height": 0.05,
        "bbox_x": 100,
        "bbox_y": 200,
        "bbox_width": 150,
        "bbox_height": 30,
    }
    second = {
        "document_id": str(uuid4()),
        "class_id": 1,
        "class_name": "invoice_date",
        "x_center": 0.6,
        "y_center": 0.4,
        "width": 0.15,
        "height": 0.04,
        "bbox_x": 200,
        "bbox_y": 300,
        "bbox_width": 100,
        "bbox_height": 25,
    }
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        created = repo.create_batch([first, second])
    assert len(created) == 2
    assert session.add.call_count == 2
    assert session.flush.call_count == 2
def test_create_batch_default_page_number(self, repo):
    """create_batch() stores page_number 1 when the input dict omits it."""
    session = MagicMock()
    # Deliberately has no "page_number" key.
    entry = {
        "document_id": str(uuid4()),
        "class_id": 0,
        "class_name": "invoice_number",
        "x_center": 0.5,
        "y_center": 0.3,
        "width": 0.2,
        "height": 0.05,
        "bbox_x": 100,
        "bbox_y": 200,
        "bbox_width": 150,
        "bbox_height": 30,
    }
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        repo.create_batch([entry])
    stored = session.add.call_args.args[0]
    assert stored.page_number == 1
def test_create_batch_with_all_optional_params(self, repo):
    """create_batch() keeps page_number, text_value, confidence and source from the input."""
    session = MagicMock()
    entry = {
        "document_id": str(uuid4()),
        "page_number": 3,
        "class_id": 0,
        "class_name": "invoice_number",
        "x_center": 0.5,
        "y_center": 0.3,
        "width": 0.2,
        "height": 0.05,
        "bbox_x": 100,
        "bbox_y": 200,
        "bbox_width": 150,
        "bbox_height": 30,
        "text_value": "INV-123",
        "confidence": 0.92,
        "source": "ocr",
    }
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        repo.create_batch([entry])
    stored = session.add.call_args.args[0]
    assert stored.page_number == 3
    assert stored.text_value == "INV-123"
    assert stored.confidence == 0.92
    assert stored.source == "ocr"
def test_create_batch_empty_list(self, repo):
    """create_batch([]) returns an empty list and never calls session.add()."""
    session = MagicMock()
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        created = repo.create_batch([])
    assert created == []
    assert session.add.call_count == 0
# =========================================================================
# get() tests
# =========================================================================
def test_get_returns_annotation(self, repo, sample_annotation):
    """get() returns the annotation from session.get() and expunges once."""
    session = MagicMock(**{"get.return_value": sample_annotation})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        lookup_id = str(sample_annotation.annotation_id)
        found = repo.get(lookup_id)
    assert found is not None
    assert found.class_name == "invoice_number"
    assert session.expunge.call_count == 1
def test_get_returns_none_when_not_found(self, repo):
    """get() returns None and skips expunge when session.get() finds nothing."""
    session = MagicMock(**{"get.return_value": None})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        found = repo.get(str(uuid4()))
    assert found is None
    assert session.expunge.call_count == 0
# =========================================================================
# get_for_document() tests
# =========================================================================
def test_get_for_document_returns_all_annotations(self, repo, sample_annotation):
    """get_for_document() returns the rows produced by session.exec(...).all()."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = [sample_annotation]
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        rows = repo.get_for_document(str(sample_annotation.document_id))
    assert len(rows) == 1
    assert rows[0].class_name == "invoice_number"
def test_get_for_document_with_page_filter(self, repo, sample_annotation):
    """get_for_document() accepts page_number and returns the mocked single row."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = [sample_annotation]
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        document_key = str(sample_annotation.document_id)
        rows = repo.get_for_document(document_key, page_number=1)
    assert len(rows) == 1
def test_get_for_document_returns_empty_list(self, repo):
    """get_for_document() returns [] when the query yields no rows."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = []
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        rows = repo.get_for_document(str(uuid4()))
    assert rows == []
# =========================================================================
# update() tests
# =========================================================================
def test_update_returns_true(self, repo, sample_annotation):
    """update() returns True and writes the new text_value onto the annotation."""
    session = MagicMock(**{"get.return_value": sample_annotation})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        annotation_key = str(sample_annotation.annotation_id)
        outcome = repo.update(annotation_key, text_value="INV-002")
    assert outcome is True
    assert sample_annotation.text_value == "INV-002"
def test_update_returns_false_when_not_found(self, repo):
    """update() returns False when session.get() finds no annotation."""
    session = MagicMock(**{"get.return_value": None})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        outcome = repo.update(str(uuid4()), text_value="INV-002")
    assert outcome is False
def test_update_all_fields(self, repo, sample_annotation):
    """update() applies every supplied geometry, text and class field."""
    session = MagicMock(**{"get.return_value": sample_annotation})
    changes = {
        "x_center": 0.6,
        "y_center": 0.4,
        "width": 0.25,
        "height": 0.06,
        "bbox_x": 150,
        "bbox_y": 250,
        "bbox_width": 175,
        "bbox_height": 35,
        "text_value": "NEW-VALUE",
        "class_id": 5,
        "class_name": "new_class",
    }
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        outcome = repo.update(str(sample_annotation.annotation_id), **changes)
    assert outcome is True
    assert sample_annotation.x_center == 0.6
    assert sample_annotation.y_center == 0.4
    assert sample_annotation.width == 0.25
    assert sample_annotation.height == 0.06
    assert sample_annotation.bbox_x == 150
    assert sample_annotation.bbox_y == 250
    assert sample_annotation.bbox_width == 175
    assert sample_annotation.bbox_height == 35
    assert sample_annotation.text_value == "NEW-VALUE"
    assert sample_annotation.class_id == 5
    assert sample_annotation.class_name == "new_class"
def test_update_partial_fields(self, repo, sample_annotation):
    """update() changes text_value and leaves x_center as it was."""
    x_before = sample_annotation.x_center
    session = MagicMock(**{"get.return_value": sample_annotation})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        annotation_key = str(sample_annotation.annotation_id)
        outcome = repo.update(annotation_key, text_value="UPDATED")
    assert outcome is True
    assert sample_annotation.text_value == "UPDATED"
    # A field that was not passed must keep its previous value.
    assert sample_annotation.x_center == x_before
# =========================================================================
# delete() tests
# =========================================================================
def test_delete_returns_true(self, repo, sample_annotation):
    """delete() returns True and calls session.delete() once for a found annotation."""
    session = MagicMock(**{"get.return_value": sample_annotation})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        outcome = repo.delete(str(sample_annotation.annotation_id))
    assert outcome is True
    assert session.delete.call_count == 1
def test_delete_returns_false_when_not_found(self, repo):
    """delete() returns False and never calls session.delete() when nothing is found."""
    session = MagicMock(**{"get.return_value": None})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        outcome = repo.delete(str(uuid4()))
    assert outcome is False
    assert session.delete.call_count == 0
# =========================================================================
# delete_for_document() tests
# =========================================================================
def test_delete_for_document_returns_count(self, repo, sample_annotation):
    """delete_for_document() returns 1 and deletes once when one row matches."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = [sample_annotation]
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        removed = repo.delete_for_document(str(sample_annotation.document_id))
    assert removed == 1
    assert session.delete.call_count == 1
def test_delete_for_document_with_source_filter(self, repo, sample_annotation):
    """delete_for_document() accepts a source argument and returns 1 for one mocked row."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = [sample_annotation]
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        document_key = str(sample_annotation.document_id)
        removed = repo.delete_for_document(document_key, source="auto")
    assert removed == 1
def test_delete_for_document_returns_zero(self, repo):
    """delete_for_document() returns 0 and deletes nothing when no rows match."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = []
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        removed = repo.delete_for_document(str(uuid4()))
    assert removed == 0
    assert session.delete.call_count == 0
# =========================================================================
# verify() tests
# =========================================================================
def test_verify_marks_annotation_verified(self, repo, sample_annotation):
    """verify() sets is_verified and verified_by, then commits once."""
    # Start from an explicitly unverified annotation.
    sample_annotation.is_verified = False
    session = MagicMock(**{"get.return_value": sample_annotation})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        annotation_key = str(sample_annotation.annotation_id)
        verified = repo.verify(annotation_key, "admin-token")
    assert verified is not None
    assert sample_annotation.is_verified is True
    assert sample_annotation.verified_by == "admin-token"
    assert session.commit.call_count == 1
def test_verify_returns_none_when_not_found(self, repo):
    """verify() returns None when session.get() finds no annotation."""
    session = MagicMock(**{"get.return_value": None})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        verified = repo.verify(str(uuid4()), "admin-token")
    assert verified is None
# =========================================================================
# override() tests
# =========================================================================
def test_override_updates_annotation(self, repo, sample_annotation):
    """override() rewrites the value, flips source to manual and records the old source."""
    sample_annotation.source = "auto"
    session = MagicMock(**{"get.return_value": sample_annotation})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        annotation_key = str(sample_annotation.annotation_id)
        overridden = repo.override(
            annotation_key,
            "admin-token",
            change_reason="Correction",
            text_value="NEW-VALUE",
        )
    assert overridden is not None
    assert sample_annotation.text_value == "NEW-VALUE"
    assert sample_annotation.source == "manual"
    assert sample_annotation.override_source == "auto"
    # At least two add() calls are expected: the annotation and its history row.
    assert session.add.call_count >= 2
def test_override_returns_none_when_not_found(self, repo):
    """override() returns None when session.get() finds no annotation."""
    session = MagicMock(**{"get.return_value": None})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        overridden = repo.override(str(uuid4()), "admin-token", text_value="NEW")
    assert overridden is None
def test_override_does_not_change_source_if_already_manual(self, repo, sample_annotation):
    """override() on a manual annotation keeps source manual and override_source None."""
    sample_annotation.source = "manual"
    sample_annotation.override_source = None
    session = MagicMock(**{"get.return_value": sample_annotation})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        annotation_key = str(sample_annotation.annotation_id)
        repo.override(annotation_key, "admin-token", text_value="NEW-VALUE")
    assert sample_annotation.source == "manual"
    assert sample_annotation.override_source is None
def test_override_skips_unknown_attributes(self, repo, sample_annotation):
    """override() applies known fields and does not store an unknown keyword's value."""
    sample_annotation.source = "auto"
    session = MagicMock(**{"get.return_value": sample_annotation})
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        overridden = repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            unknown_field="should_be_ignored",
            text_value="VALID",
        )
    assert overridden is not None
    assert sample_annotation.text_value == "VALID"
    # A missing attribute reads as None, so one getattr covers both the
    # "attribute absent" and "attribute present with another value" cases.
    leaked = getattr(sample_annotation, "unknown_field", None)
    assert leaked != "should_be_ignored"
# =========================================================================
# create_history() tests
# =========================================================================
def test_create_history_returns_history(self, repo):
    """create_history() returns a record and persists it with one add and one commit."""
    session = MagicMock()
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        # Make the patched context manager yield our mock session.
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        history = repo.create_history(
            annotation_id=uuid4(),
            document_id=uuid4(),
            action="create",
            previous_value=None,
            new_value={"class_name": "invoice_number"},
            changed_by="admin-token",
            change_reason="Initial creation",
        )
    # The test name promises a returned record, so check the return value
    # in addition to the session interactions.
    assert history is not None
    session.add.assert_called_once()
    session.commit.assert_called_once()
def test_create_history_with_minimal_params(self, repo):
    """create_history() with only IDs and action stores None for both value fields."""
    session = MagicMock()
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        repo.create_history(
            annotation_id=uuid4(),
            document_id=uuid4(),
            action="delete",
        )
    assert session.add.call_count == 1
    stored = session.add.call_args.args[0]
    assert stored.action == "delete"
    assert stored.previous_value is None
    assert stored.new_value is None
# =========================================================================
# get_history() tests
# =========================================================================
def test_get_history_returns_list(self, repo, sample_history):
    """get_history() returns the rows produced by session.exec(...).all()."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = [sample_history]
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        records = repo.get_history(sample_history.annotation_id)
    assert len(records) == 1
    assert records[0].action == "override"
def test_get_history_returns_empty_list(self, repo):
    """get_history() returns [] when the query yields no rows."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = []
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        records = repo.get_history(uuid4())
    assert records == []
# =========================================================================
# get_document_history() tests
# =========================================================================
def test_get_document_history_returns_list(self, repo, sample_history):
    """get_document_history() returns the single mocked history row."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = [sample_history]
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        records = repo.get_document_history(sample_history.document_id)
    assert len(records) == 1
def test_get_document_history_returns_empty_list(self, repo):
    """get_document_history() returns [] when the query yields no rows."""
    session = MagicMock()
    query_result = session.exec.return_value
    query_result.all.return_value = []
    target = "inference.data.repositories.annotation_repository.get_session_context"
    with patch(target) as ctx_factory:
        manager = ctx_factory.return_value
        manager.__enter__.return_value = session
        manager.__exit__.return_value = False
        records = repo.get_document_history(uuid4())
    assert records == []