WIP
This commit is contained in:
711
tests/data/repositories/test_annotation_repository.py
Normal file
711
tests/data/repositories/test_annotation_repository.py
Normal file
@@ -0,0 +1,711 @@
|
||||
"""
|
||||
Tests for AnnotationRepository
|
||||
|
||||
100% coverage tests for annotation management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
|
||||
from inference.data.repositories.annotation_repository import AnnotationRepository
|
||||
|
||||
|
||||
class TestAnnotationRepository:
|
||||
"""Tests for AnnotationRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_annotation(self) -> AdminAnnotation:
|
||||
"""Create a sample annotation for testing."""
|
||||
return AdminAnnotation(
|
||||
annotation_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.3,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=100,
|
||||
bbox_y=200,
|
||||
bbox_width=150,
|
||||
bbox_height=30,
|
||||
text_value="INV-001",
|
||||
confidence=0.95,
|
||||
source="auto",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_history(self) -> AnnotationHistory:
|
||||
"""Create a sample annotation history for testing."""
|
||||
return AnnotationHistory(
|
||||
history_id=uuid4(),
|
||||
annotation_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
action="override",
|
||||
previous_value={"class_name": "old_class"},
|
||||
new_value={"class_name": "new_class"},
|
||||
changed_by="admin-token",
|
||||
change_reason="Correction",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> AnnotationRepository:
|
||||
"""Create an AnnotationRepository instance."""
|
||||
return AnnotationRepository()
|
||||
|
||||
# =========================================================================
|
||||
# create() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_returns_annotation_id(self, repo):
|
||||
"""Test create returns annotation ID."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
document_id=str(uuid4()),
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.3,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=100,
|
||||
bbox_y=200,
|
||||
bbox_width=150,
|
||||
bbox_height=30,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
def test_create_with_optional_params(self, repo):
|
||||
"""Test create with optional text_value and confidence."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
document_id=str(uuid4()),
|
||||
page_number=2,
|
||||
class_id=1,
|
||||
class_name="invoice_date",
|
||||
x_center=0.6,
|
||||
y_center=0.4,
|
||||
width=0.15,
|
||||
height=0.04,
|
||||
bbox_x=200,
|
||||
bbox_y=300,
|
||||
bbox_width=100,
|
||||
bbox_height=25,
|
||||
text_value="2024-01-15",
|
||||
confidence=0.88,
|
||||
source="auto",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
mock_session.add.assert_called_once()
|
||||
added_annotation = mock_session.add.call_args[0][0]
|
||||
assert added_annotation.text_value == "2024-01-15"
|
||||
assert added_annotation.confidence == 0.88
|
||||
assert added_annotation.source == "auto"
|
||||
|
||||
def test_create_default_source_is_manual(self, repo):
|
||||
"""Test create uses manual as default source."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
document_id=str(uuid4()),
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.3,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=100,
|
||||
bbox_y=200,
|
||||
bbox_width=150,
|
||||
bbox_height=30,
|
||||
)
|
||||
|
||||
added_annotation = mock_session.add.call_args[0][0]
|
||||
assert added_annotation.source == "manual"
|
||||
|
||||
# =========================================================================
|
||||
# create_batch() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_batch_returns_ids(self, repo):
|
||||
"""Test create_batch returns list of annotation IDs."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
annotations = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"class_id": 0,
|
||||
"class_name": "invoice_number",
|
||||
"x_center": 0.5,
|
||||
"y_center": 0.3,
|
||||
"width": 0.2,
|
||||
"height": 0.05,
|
||||
"bbox_x": 100,
|
||||
"bbox_y": 200,
|
||||
"bbox_width": 150,
|
||||
"bbox_height": 30,
|
||||
},
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"class_id": 1,
|
||||
"class_name": "invoice_date",
|
||||
"x_center": 0.6,
|
||||
"y_center": 0.4,
|
||||
"width": 0.15,
|
||||
"height": 0.04,
|
||||
"bbox_x": 200,
|
||||
"bbox_y": 300,
|
||||
"bbox_width": 100,
|
||||
"bbox_height": 25,
|
||||
},
|
||||
]
|
||||
|
||||
result = repo.create_batch(annotations)
|
||||
|
||||
assert len(result) == 2
|
||||
assert mock_session.add.call_count == 2
|
||||
assert mock_session.flush.call_count == 2
|
||||
|
||||
def test_create_batch_default_page_number(self, repo):
|
||||
"""Test create_batch uses page_number=1 by default."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
annotations = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"class_id": 0,
|
||||
"class_name": "invoice_number",
|
||||
"x_center": 0.5,
|
||||
"y_center": 0.3,
|
||||
"width": 0.2,
|
||||
"height": 0.05,
|
||||
"bbox_x": 100,
|
||||
"bbox_y": 200,
|
||||
"bbox_width": 150,
|
||||
"bbox_height": 30,
|
||||
# no page_number
|
||||
},
|
||||
]
|
||||
|
||||
repo.create_batch(annotations)
|
||||
|
||||
added_annotation = mock_session.add.call_args[0][0]
|
||||
assert added_annotation.page_number == 1
|
||||
|
||||
def test_create_batch_with_all_optional_params(self, repo):
|
||||
"""Test create_batch with all optional parameters."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
annotations = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"page_number": 3,
|
||||
"class_id": 0,
|
||||
"class_name": "invoice_number",
|
||||
"x_center": 0.5,
|
||||
"y_center": 0.3,
|
||||
"width": 0.2,
|
||||
"height": 0.05,
|
||||
"bbox_x": 100,
|
||||
"bbox_y": 200,
|
||||
"bbox_width": 150,
|
||||
"bbox_height": 30,
|
||||
"text_value": "INV-123",
|
||||
"confidence": 0.92,
|
||||
"source": "ocr",
|
||||
},
|
||||
]
|
||||
|
||||
repo.create_batch(annotations)
|
||||
|
||||
added_annotation = mock_session.add.call_args[0][0]
|
||||
assert added_annotation.page_number == 3
|
||||
assert added_annotation.text_value == "INV-123"
|
||||
assert added_annotation.confidence == 0.92
|
||||
assert added_annotation.source == "ocr"
|
||||
|
||||
def test_create_batch_empty_list(self, repo):
|
||||
"""Test create_batch with empty list returns empty."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create_batch([])
|
||||
|
||||
assert result == []
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# get() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_returns_annotation(self, repo, sample_annotation):
|
||||
"""Test get returns annotation when exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(sample_annotation.annotation_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.class_name == "invoice_number"
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
mock_session.expunge.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# get_for_document() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_for_document_returns_all_annotations(self, repo, sample_annotation):
|
||||
"""Test get_for_document returns all annotations for document."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_for_document(str(sample_annotation.document_id))
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].class_name == "invoice_number"
|
||||
|
||||
def test_get_for_document_with_page_filter(self, repo, sample_annotation):
|
||||
"""Test get_for_document filters by page number."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_for_document(str(sample_annotation.document_id), page_number=1)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_for_document_returns_empty_list(self, repo):
|
||||
"""Test get_for_document returns empty list when no annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_for_document(str(uuid4()))
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# update() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_returns_true(self, repo, sample_annotation):
|
||||
"""Test update returns True when annotation exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(
|
||||
str(sample_annotation.annotation_id),
|
||||
text_value="INV-002",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert sample_annotation.text_value == "INV-002"
|
||||
|
||||
def test_update_returns_false_when_not_found(self, repo):
|
||||
"""Test update returns False when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(str(uuid4()), text_value="INV-002")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_update_all_fields(self, repo, sample_annotation):
|
||||
"""Test update can update all fields."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(
|
||||
str(sample_annotation.annotation_id),
|
||||
x_center=0.6,
|
||||
y_center=0.4,
|
||||
width=0.25,
|
||||
height=0.06,
|
||||
bbox_x=150,
|
||||
bbox_y=250,
|
||||
bbox_width=175,
|
||||
bbox_height=35,
|
||||
text_value="NEW-VALUE",
|
||||
class_id=5,
|
||||
class_name="new_class",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert sample_annotation.x_center == 0.6
|
||||
assert sample_annotation.y_center == 0.4
|
||||
assert sample_annotation.width == 0.25
|
||||
assert sample_annotation.height == 0.06
|
||||
assert sample_annotation.bbox_x == 150
|
||||
assert sample_annotation.bbox_y == 250
|
||||
assert sample_annotation.bbox_width == 175
|
||||
assert sample_annotation.bbox_height == 35
|
||||
assert sample_annotation.text_value == "NEW-VALUE"
|
||||
assert sample_annotation.class_id == 5
|
||||
assert sample_annotation.class_name == "new_class"
|
||||
|
||||
def test_update_partial_fields(self, repo, sample_annotation):
|
||||
"""Test update only updates provided fields."""
|
||||
original_x = sample_annotation.x_center
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(
|
||||
str(sample_annotation.annotation_id),
|
||||
text_value="UPDATED",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert sample_annotation.text_value == "UPDATED"
|
||||
assert sample_annotation.x_center == original_x # unchanged
|
||||
|
||||
# =========================================================================
|
||||
# delete() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_delete_returns_true(self, repo, sample_annotation):
|
||||
"""Test delete returns True when annotation exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(sample_annotation.annotation_id))
|
||||
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once()
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(uuid4()))
|
||||
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# delete_for_document() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_delete_for_document_returns_count(self, repo, sample_annotation):
|
||||
"""Test delete_for_document returns count of deleted annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete_for_document(str(sample_annotation.document_id))
|
||||
|
||||
assert result == 1
|
||||
mock_session.delete.assert_called_once()
|
||||
|
||||
def test_delete_for_document_with_source_filter(self, repo, sample_annotation):
|
||||
"""Test delete_for_document filters by source."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete_for_document(str(sample_annotation.document_id), source="auto")
|
||||
|
||||
assert result == 1
|
||||
|
||||
def test_delete_for_document_returns_zero(self, repo):
|
||||
"""Test delete_for_document returns 0 when no annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete_for_document(str(uuid4()))
|
||||
|
||||
assert result == 0
|
||||
mock_session.delete.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# verify() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_verify_marks_annotation_verified(self, repo, sample_annotation):
|
||||
"""Test verify marks annotation as verified."""
|
||||
sample_annotation.is_verified = False
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.verify(str(sample_annotation.annotation_id), "admin-token")
|
||||
|
||||
assert result is not None
|
||||
assert sample_annotation.is_verified is True
|
||||
assert sample_annotation.verified_by == "admin-token"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_verify_returns_none_when_not_found(self, repo):
|
||||
"""Test verify returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.verify(str(uuid4()), "admin-token")
|
||||
|
||||
assert result is None
|
||||
|
||||
# =========================================================================
|
||||
# override() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_override_updates_annotation(self, repo, sample_annotation):
|
||||
"""Test override updates annotation and creates history."""
|
||||
sample_annotation.source = "auto"
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.override(
|
||||
str(sample_annotation.annotation_id),
|
||||
"admin-token",
|
||||
change_reason="Correction",
|
||||
text_value="NEW-VALUE",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert sample_annotation.text_value == "NEW-VALUE"
|
||||
assert sample_annotation.source == "manual"
|
||||
assert sample_annotation.override_source == "auto"
|
||||
assert mock_session.add.call_count >= 2 # annotation + history
|
||||
|
||||
def test_override_returns_none_when_not_found(self, repo):
|
||||
"""Test override returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.override(str(uuid4()), "admin-token", text_value="NEW")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_override_does_not_change_source_if_already_manual(self, repo, sample_annotation):
|
||||
"""Test override does not change override_source if already manual."""
|
||||
sample_annotation.source = "manual"
|
||||
sample_annotation.override_source = None
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.override(
|
||||
str(sample_annotation.annotation_id),
|
||||
"admin-token",
|
||||
text_value="NEW-VALUE",
|
||||
)
|
||||
|
||||
assert sample_annotation.source == "manual"
|
||||
assert sample_annotation.override_source is None
|
||||
|
||||
def test_override_skips_unknown_attributes(self, repo, sample_annotation):
|
||||
"""Test override ignores unknown attributes."""
|
||||
sample_annotation.source = "auto"
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.override(
|
||||
str(sample_annotation.annotation_id),
|
||||
"admin-token",
|
||||
unknown_field="should_be_ignored",
|
||||
text_value="VALID",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert sample_annotation.text_value == "VALID"
|
||||
assert not hasattr(sample_annotation, "unknown_field") or getattr(sample_annotation, "unknown_field", None) != "should_be_ignored"
|
||||
|
||||
# =========================================================================
|
||||
# create_history() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_history_returns_history(self, repo):
|
||||
"""Test create_history returns created history record."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
annotation_id = uuid4()
|
||||
document_id = uuid4()
|
||||
result = repo.create_history(
|
||||
annotation_id=annotation_id,
|
||||
document_id=document_id,
|
||||
action="create",
|
||||
previous_value=None,
|
||||
new_value={"class_name": "invoice_number"},
|
||||
changed_by="admin-token",
|
||||
change_reason="Initial creation",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_history_with_minimal_params(self, repo):
|
||||
"""Test create_history with minimal parameters."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create_history(
|
||||
annotation_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
action="delete",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
added_history = mock_session.add.call_args[0][0]
|
||||
assert added_history.action == "delete"
|
||||
assert added_history.previous_value is None
|
||||
assert added_history.new_value is None
|
||||
|
||||
# =========================================================================
|
||||
# get_history() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_history_returns_list(self, repo, sample_history):
|
||||
"""Test get_history returns list of history records."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_history]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_history(sample_history.annotation_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].action == "override"
|
||||
|
||||
def test_get_history_returns_empty_list(self, repo):
|
||||
"""Test get_history returns empty list when no history."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_history(uuid4())
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# get_document_history() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_document_history_returns_list(self, repo, sample_history):
|
||||
"""Test get_document_history returns list of history records."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_history]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_document_history(sample_history.document_id)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_document_history_returns_empty_list(self, repo):
|
||||
"""Test get_document_history returns empty list when no history."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_document_history(uuid4())
|
||||
|
||||
assert result == []
|
||||
Reference in New Issue
Block a user