712 lines
31 KiB
Python
712 lines
31 KiB
Python
"""
Tests for AnnotationRepository

100% coverage tests for annotation management.
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import MagicMock, patch
|
|
from uuid import uuid4, UUID
|
|
|
|
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
|
|
from inference.data.repositories.annotation_repository import AnnotationRepository
|
|
|
|
|
|
class TestAnnotationRepository:
    """Tests for AnnotationRepository.

    Tests that request the ``mock_session`` fixture run with the repository
    module's ``get_session_context`` replaced by a mock, so the session the
    repository obtains from it is a ``MagicMock``.
    """

    # Dotted path of the session factory as referenced by the repository
    # module; this is the name that gets patched for each test.
    _SESSION_CTX = (
        "inference.data.repositories.annotation_repository.get_session_context"
    )

    @staticmethod
    def _annotation_payload(**overrides) -> dict:
        """Return the annotation fields shared by the create/create_batch tests.

        ``page_number`` and the optional fields are intentionally absent so
        individual tests can either supply them through ``overrides`` or
        exercise the repository defaults.
        """
        payload = {
            "document_id": str(uuid4()),
            "class_id": 0,
            "class_name": "invoice_number",
            "x_center": 0.5,
            "y_center": 0.3,
            "width": 0.2,
            "height": 0.05,
            "bbox_x": 100,
            "bbox_y": 200,
            "bbox_width": 150,
            "bbox_height": 30,
        }
        payload.update(overrides)
        return payload

    @pytest.fixture
    def sample_annotation(self) -> AdminAnnotation:
        """Create a sample annotation for testing."""
        return AdminAnnotation(
            annotation_id=uuid4(),
            document_id=uuid4(),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=100,
            bbox_y=200,
            bbox_width=150,
            bbox_height=30,
            text_value="INV-001",
            confidence=0.95,
            source="auto",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def sample_history(self) -> AnnotationHistory:
        """Create a sample annotation history for testing."""
        return AnnotationHistory(
            history_id=uuid4(),
            annotation_id=uuid4(),
            document_id=uuid4(),
            action="override",
            previous_value={"class_name": "old_class"},
            new_value={"class_name": "new_class"},
            changed_by="admin-token",
            change_reason="Correction",
            created_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def repo(self) -> AnnotationRepository:
        """Create an AnnotationRepository instance."""
        return AnnotationRepository()

    @pytest.fixture
    def mock_session(self):
        """Patch the session context and yield the mocked session.

        The patch stays active until the requesting test finishes. Tests list
        this fixture after ``repo`` so the repository is still constructed
        before the patch is installed.
        """
        with patch(self._SESSION_CTX) as mock_ctx:
            session = MagicMock()
            mock_ctx.return_value.__enter__ = MagicMock(return_value=session)
            # A falsy __exit__ result means exceptions are not suppressed.
            mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
            yield session

    # =========================================================================
    # create() tests
    # =========================================================================

    def test_create_returns_annotation_id(self, repo, mock_session):
        """Test create returns annotation ID."""
        result = repo.create(**self._annotation_payload(page_number=1))

        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.flush.assert_called_once()

    def test_create_with_optional_params(self, repo, mock_session):
        """Test create with optional text_value and confidence."""
        result = repo.create(
            document_id=str(uuid4()),
            page_number=2,
            class_id=1,
            class_name="invoice_date",
            x_center=0.6,
            y_center=0.4,
            width=0.15,
            height=0.04,
            bbox_x=200,
            bbox_y=300,
            bbox_width=100,
            bbox_height=25,
            text_value="2024-01-15",
            confidence=0.88,
            source="auto",
        )

        assert result is not None
        mock_session.add.assert_called_once()
        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.text_value == "2024-01-15"
        assert added_annotation.confidence == 0.88
        assert added_annotation.source == "auto"

    def test_create_default_source_is_manual(self, repo, mock_session):
        """Test create uses manual as default source."""
        repo.create(**self._annotation_payload(page_number=1))

        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.source == "manual"

    # =========================================================================
    # create_batch() tests
    # =========================================================================

    def test_create_batch_returns_ids(self, repo, mock_session):
        """Test create_batch returns list of annotation IDs."""
        annotations = [
            self._annotation_payload(),
            {
                "document_id": str(uuid4()),
                "class_id": 1,
                "class_name": "invoice_date",
                "x_center": 0.6,
                "y_center": 0.4,
                "width": 0.15,
                "height": 0.04,
                "bbox_x": 200,
                "bbox_y": 300,
                "bbox_width": 100,
                "bbox_height": 25,
            },
        ]

        result = repo.create_batch(annotations)

        assert len(result) == 2
        assert mock_session.add.call_count == 2
        assert mock_session.flush.call_count == 2

    def test_create_batch_default_page_number(self, repo, mock_session):
        """Test create_batch uses page_number=1 by default."""
        # The shared payload carries no page_number key.
        annotations = [self._annotation_payload()]

        repo.create_batch(annotations)

        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.page_number == 1

    def test_create_batch_with_all_optional_params(self, repo, mock_session):
        """Test create_batch with all optional parameters."""
        annotations = [
            self._annotation_payload(
                page_number=3,
                text_value="INV-123",
                confidence=0.92,
                source="ocr",
            ),
        ]

        repo.create_batch(annotations)

        added_annotation = mock_session.add.call_args[0][0]
        assert added_annotation.page_number == 3
        assert added_annotation.text_value == "INV-123"
        assert added_annotation.confidence == 0.92
        assert added_annotation.source == "ocr"

    def test_create_batch_empty_list(self, repo, mock_session):
        """Test create_batch with empty list returns empty."""
        result = repo.create_batch([])

        assert result == []
        mock_session.add.assert_not_called()

    # =========================================================================
    # get() tests
    # =========================================================================

    def test_get_returns_annotation(self, repo, sample_annotation, mock_session):
        """Test get returns annotation when exists."""
        mock_session.get.return_value = sample_annotation

        result = repo.get(str(sample_annotation.annotation_id))

        assert result is not None
        assert result.class_name == "invoice_number"
        mock_session.expunge.assert_called_once()

    def test_get_returns_none_when_not_found(self, repo, mock_session):
        """Test get returns None when annotation not found."""
        mock_session.get.return_value = None

        result = repo.get(str(uuid4()))

        assert result is None
        mock_session.expunge.assert_not_called()

    # =========================================================================
    # get_for_document() tests
    # =========================================================================

    def test_get_for_document_returns_all_annotations(
        self, repo, sample_annotation, mock_session
    ):
        """Test get_for_document returns all annotations for document."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.get_for_document(str(sample_annotation.document_id))

        assert len(result) == 1
        assert result[0].class_name == "invoice_number"

    def test_get_for_document_with_page_filter(
        self, repo, sample_annotation, mock_session
    ):
        """Test get_for_document filters by page number."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.get_for_document(
            str(sample_annotation.document_id), page_number=1
        )

        assert len(result) == 1

    def test_get_for_document_returns_empty_list(self, repo, mock_session):
        """Test get_for_document returns empty list when no annotations."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_for_document(str(uuid4()))

        assert result == []

    # =========================================================================
    # update() tests
    # =========================================================================

    def test_update_returns_true(self, repo, sample_annotation, mock_session):
        """Test update returns True when annotation exists."""
        mock_session.get.return_value = sample_annotation

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="INV-002",
        )

        assert result is True
        assert sample_annotation.text_value == "INV-002"

    def test_update_returns_false_when_not_found(self, repo, mock_session):
        """Test update returns False when annotation not found."""
        mock_session.get.return_value = None

        result = repo.update(str(uuid4()), text_value="INV-002")

        assert result is False

    def test_update_all_fields(self, repo, sample_annotation, mock_session):
        """Test update can update all fields."""
        mock_session.get.return_value = sample_annotation

        result = repo.update(
            str(sample_annotation.annotation_id),
            x_center=0.6,
            y_center=0.4,
            width=0.25,
            height=0.06,
            bbox_x=150,
            bbox_y=250,
            bbox_width=175,
            bbox_height=35,
            text_value="NEW-VALUE",
            class_id=5,
            class_name="new_class",
        )

        assert result is True
        assert sample_annotation.x_center == 0.6
        assert sample_annotation.y_center == 0.4
        assert sample_annotation.width == 0.25
        assert sample_annotation.height == 0.06
        assert sample_annotation.bbox_x == 150
        assert sample_annotation.bbox_y == 250
        assert sample_annotation.bbox_width == 175
        assert sample_annotation.bbox_height == 35
        assert sample_annotation.text_value == "NEW-VALUE"
        assert sample_annotation.class_id == 5
        assert sample_annotation.class_name == "new_class"

    def test_update_partial_fields(self, repo, sample_annotation, mock_session):
        """Test update only updates provided fields."""
        original_x = sample_annotation.x_center
        mock_session.get.return_value = sample_annotation

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="UPDATED",
        )

        assert result is True
        assert sample_annotation.text_value == "UPDATED"
        assert sample_annotation.x_center == original_x  # unchanged

    # =========================================================================
    # delete() tests
    # =========================================================================

    def test_delete_returns_true(self, repo, sample_annotation, mock_session):
        """Test delete returns True when annotation exists."""
        mock_session.get.return_value = sample_annotation

        result = repo.delete(str(sample_annotation.annotation_id))

        assert result is True
        mock_session.delete.assert_called_once()

    def test_delete_returns_false_when_not_found(self, repo, mock_session):
        """Test delete returns False when annotation not found."""
        mock_session.get.return_value = None

        result = repo.delete(str(uuid4()))

        assert result is False
        mock_session.delete.assert_not_called()

    # =========================================================================
    # delete_for_document() tests
    # =========================================================================

    def test_delete_for_document_returns_count(
        self, repo, sample_annotation, mock_session
    ):
        """Test delete_for_document returns count of deleted annotations."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.delete_for_document(str(sample_annotation.document_id))

        assert result == 1
        mock_session.delete.assert_called_once()

    def test_delete_for_document_with_source_filter(
        self, repo, sample_annotation, mock_session
    ):
        """Test delete_for_document filters by source."""
        mock_session.exec.return_value.all.return_value = [sample_annotation]

        result = repo.delete_for_document(
            str(sample_annotation.document_id), source="auto"
        )

        assert result == 1

    def test_delete_for_document_returns_zero(self, repo, mock_session):
        """Test delete_for_document returns 0 when no annotations."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.delete_for_document(str(uuid4()))

        assert result == 0
        mock_session.delete.assert_not_called()

    # =========================================================================
    # verify() tests
    # =========================================================================

    def test_verify_marks_annotation_verified(
        self, repo, sample_annotation, mock_session
    ):
        """Test verify marks annotation as verified."""
        sample_annotation.is_verified = False
        mock_session.get.return_value = sample_annotation

        result = repo.verify(str(sample_annotation.annotation_id), "admin-token")

        assert result is not None
        assert sample_annotation.is_verified is True
        assert sample_annotation.verified_by == "admin-token"
        mock_session.commit.assert_called_once()

    def test_verify_returns_none_when_not_found(self, repo, mock_session):
        """Test verify returns None when annotation not found."""
        mock_session.get.return_value = None

        result = repo.verify(str(uuid4()), "admin-token")

        assert result is None

    # =========================================================================
    # override() tests
    # =========================================================================

    def test_override_updates_annotation(self, repo, sample_annotation, mock_session):
        """Test override updates annotation and creates history."""
        sample_annotation.source = "auto"
        mock_session.get.return_value = sample_annotation

        result = repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            change_reason="Correction",
            text_value="NEW-VALUE",
        )

        assert result is not None
        assert sample_annotation.text_value == "NEW-VALUE"
        assert sample_annotation.source == "manual"
        assert sample_annotation.override_source == "auto"
        assert mock_session.add.call_count >= 2  # annotation + history

    def test_override_returns_none_when_not_found(self, repo, mock_session):
        """Test override returns None when annotation not found."""
        mock_session.get.return_value = None

        result = repo.override(str(uuid4()), "admin-token", text_value="NEW")

        assert result is None

    def test_override_does_not_change_source_if_already_manual(
        self, repo, sample_annotation, mock_session
    ):
        """Test override does not change override_source if already manual."""
        sample_annotation.source = "manual"
        sample_annotation.override_source = None
        mock_session.get.return_value = sample_annotation

        repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            text_value="NEW-VALUE",
        )

        assert sample_annotation.source == "manual"
        assert sample_annotation.override_source is None

    def test_override_skips_unknown_attributes(
        self, repo, sample_annotation, mock_session
    ):
        """Test override ignores unknown attributes."""
        sample_annotation.source = "auto"
        mock_session.get.return_value = sample_annotation

        result = repo.override(
            str(sample_annotation.annotation_id),
            "admin-token",
            unknown_field="should_be_ignored",
            text_value="VALID",
        )

        assert result is not None
        assert sample_annotation.text_value == "VALID"
        # Passes both when the attribute is absent (default None) and when it
        # exists with some other value.
        assert (
            getattr(sample_annotation, "unknown_field", None) != "should_be_ignored"
        )

    # =========================================================================
    # create_history() tests
    # =========================================================================

    def test_create_history_returns_history(self, repo, mock_session):
        """Test create_history returns created history record."""
        annotation_id = uuid4()
        document_id = uuid4()

        result = repo.create_history(
            annotation_id=annotation_id,
            document_id=document_id,
            action="create",
            previous_value=None,
            new_value={"class_name": "invoice_number"},
            changed_by="admin-token",
            change_reason="Initial creation",
        )

        # The return value is what this test is named for, so check it rather
        # than leaving the binding unused.
        assert result is not None
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()
        added_history = mock_session.add.call_args[0][0]
        assert added_history.action == "create"
        assert added_history.previous_value is None
        assert added_history.new_value == {"class_name": "invoice_number"}

    def test_create_history_with_minimal_params(self, repo, mock_session):
        """Test create_history with minimal parameters."""
        repo.create_history(
            annotation_id=uuid4(),
            document_id=uuid4(),
            action="delete",
        )

        mock_session.add.assert_called_once()
        added_history = mock_session.add.call_args[0][0]
        assert added_history.action == "delete"
        assert added_history.previous_value is None
        assert added_history.new_value is None

    # =========================================================================
    # get_history() tests
    # =========================================================================

    def test_get_history_returns_list(self, repo, sample_history, mock_session):
        """Test get_history returns list of history records."""
        mock_session.exec.return_value.all.return_value = [sample_history]

        result = repo.get_history(sample_history.annotation_id)

        assert len(result) == 1
        assert result[0].action == "override"

    def test_get_history_returns_empty_list(self, repo, mock_session):
        """Test get_history returns empty list when no history."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_history(uuid4())

        assert result == []

    # =========================================================================
    # get_document_history() tests
    # =========================================================================

    def test_get_document_history_returns_list(
        self, repo, sample_history, mock_session
    ):
        """Test get_document_history returns list of history records."""
        mock_session.exec.return_value.all.return_value = [sample_history]

        result = repo.get_document_history(sample_history.document_id)

        assert len(result) == 1

    def test_get_document_history_returns_empty_list(self, repo, mock_session):
        """Test get_document_history returns empty list when no history."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_document_history(uuid4())

        assert result == []