Files
invoice-master-poc-v2/tests/integration/repositories/test_annotation_repo_integration.py
Yaojia Wang b602d0a340 re-structure
2026-02-01 22:55:31 +01:00

465 lines
14 KiB
Python

"""
Annotation Repository Integration Tests
Tests AnnotationRepository with real database operations.
"""
from uuid import uuid4
import pytest
from backend.data.repositories.annotation_repository import AnnotationRepository
class TestAnnotationRepositoryCreate:
    """Tests for annotation creation."""

    def test_create_annotation(self, patched_session, sample_document):
        """Test creating a single annotation and reading it back."""
        repo = AnnotationRepository()

        ann_id = repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=240,
            bbox_width=160,
            bbox_height=40,
            text_value="INV-2024-001",
            confidence=0.95,
            source="auto",
        )

        assert ann_id is not None

        ann = repo.get(ann_id)
        assert ann is not None
        assert ann.class_name == "invoice_number"
        assert ann.text_value == "INV-2024-001"
        # The confidence is a float that round-trips through a database
        # column; compare with a tolerance so the column's precision cannot
        # make this assertion flaky.
        assert ann.confidence == pytest.approx(0.95)
        assert ann.source == "auto"

    def test_create_batch_annotations(self, patched_session, sample_document):
        """Test batch creation of annotations."""
        repo = AnnotationRepository()
        document_id = str(sample_document.document_id)
        annotations_data = [
            {
                "document_id": document_id,
                "page_number": 1,
                "class_id": 0,
                "class_name": "invoice_number",
                "x_center": 0.5,
                "y_center": 0.1,
                "width": 0.2,
                "height": 0.05,
                "bbox_x": 400,
                "bbox_y": 80,
                "bbox_width": 160,
                "bbox_height": 40,
                "text_value": "INV-001",
                "confidence": 0.95,
            },
            {
                "document_id": document_id,
                "page_number": 1,
                "class_id": 1,
                "class_name": "invoice_date",
                "x_center": 0.5,
                "y_center": 0.2,
                "width": 0.15,
                "height": 0.04,
                "bbox_x": 400,
                "bbox_y": 160,
                "bbox_width": 120,
                "bbox_height": 32,
                "text_value": "2024-01-15",
                "confidence": 0.92,
            },
            {
                "document_id": document_id,
                "page_number": 1,
                "class_id": 6,
                "class_name": "amount",
                "x_center": 0.7,
                "y_center": 0.8,
                "width": 0.1,
                "height": 0.04,
                "bbox_x": 560,
                "bbox_y": 640,
                "bbox_width": 80,
                "bbox_height": 32,
                "text_value": "1500.00",
                "confidence": 0.98,
            },
        ]

        ids = repo.create_batch(annotations_data)

        assert len(ids) == 3
        # Each created row must receive its own identifier.
        assert len({str(ann_id) for ann_id in ids}) == 3

        # Verify all annotations exist and belong to the document.
        stored = [repo.get(ann_id) for ann_id in ids]
        assert all(ann is not None for ann in stored)
        assert all(str(ann.document_id) == document_id for ann in stored)

        # The batch must persist exactly what was submitted. Compare as sets
        # because the order of the returned ids is not part of the contract
        # exercised here.
        submitted = {(d["class_name"], d["text_value"]) for d in annotations_data}
        persisted = {(ann.class_name, ann.text_value) for ann in stored}
        assert persisted == submitted
class TestAnnotationRepositoryRead:
    """Retrieval behaviour of AnnotationRepository."""

    def test_get_nonexistent_annotation(self, patched_session):
        """Looking up a random, never-inserted id yields None."""
        repository = AnnotationRepository()
        missing_id = str(uuid4())

        assert repository.get(missing_id) is None

    def test_get_annotations_for_document(
        self, patched_session, sample_document, sample_annotation
    ):
        """All annotations of a document are returned, ordered by class_id."""
        repository = AnnotationRepository()
        doc_id = str(sample_document.document_id)

        # A second annotation next to the one supplied by the fixture.
        repository.create(
            document_id=doc_id,
            page_number=1,
            class_id=1,
            class_name="invoice_date",
            x_center=0.5,
            y_center=0.4,
            width=0.15,
            height=0.04,
            bbox_x=400,
            bbox_y=320,
            bbox_width=120,
            bbox_height=32,
            text_value="2024-01-15",
        )

        found = repository.get_for_document(doc_id)

        assert len(found) == 2
        # Expected ordering by class_id: 0 first, then 1.
        assert [item.class_id for item in found] == [0, 1]

    def test_get_annotations_for_specific_page(self, patched_session, sample_document):
        """Filtering by page_number returns only that page's annotations."""
        repository = AnnotationRepository()
        doc_id = str(sample_document.document_id)

        # One annotation on page 1 and one on page 2.
        placements = [
            dict(
                page_number=1,
                class_id=0,
                class_name="invoice_number",
                x_center=0.5,
                y_center=0.1,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80,
                bbox_width=160,
                bbox_height=40,
            ),
            dict(
                page_number=2,
                class_id=6,
                class_name="amount",
                x_center=0.7,
                y_center=0.8,
                width=0.1,
                height=0.04,
                bbox_x=560,
                bbox_y=640,
                bbox_width=80,
                bbox_height=32,
            ),
        ]
        for placement in placements:
            repository.create(document_id=doc_id, **placement)

        for page in (1, 2):
            on_page = repository.get_for_document(doc_id, page_number=page)
            assert len(on_page) == 1
            assert on_page[0].page_number == page
class TestAnnotationRepositoryUpdate:
    """Tests for annotation updates."""

    def test_update_annotation_bbox(self, patched_session, sample_annotation):
        """Test updating the normalized and pixel bounding box of an annotation."""
        repo = AnnotationRepository()
        annotation_id = str(sample_annotation.annotation_id)
        new_normalized = {"x_center": 0.6, "y_center": 0.4, "width": 0.25, "height": 0.06}
        new_pixels = {"bbox_x": 480, "bbox_y": 320, "bbox_width": 200, "bbox_height": 48}

        result = repo.update(annotation_id, **new_normalized, **new_pixels)

        assert result is True

        ann = repo.get(annotation_id)
        assert ann is not None
        # Check every field that was updated, not just a sample of them.
        # Normalized coordinates are floats that round-trip through the
        # database, so compare them with a tolerance.
        for field, expected in new_normalized.items():
            assert getattr(ann, field) == pytest.approx(expected), field
        for field, expected in new_pixels.items():
            assert getattr(ann, field) == expected, field

    def test_update_annotation_text(self, patched_session, sample_annotation):
        """Test updating annotation text value."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="INV-2024-002",
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.text_value == "INV-2024-002"

    def test_update_annotation_class(self, patched_session, sample_annotation):
        """Test updating annotation class."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            class_id=1,
            class_name="invoice_date",
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.class_id == 1
        assert ann.class_name == "invoice_date"

    def test_update_nonexistent_annotation(self, patched_session):
        """Test that updating an unknown annotation returns False."""
        repo = AnnotationRepository()

        result = repo.update(
            str(uuid4()),
            text_value="new value",
        )

        assert result is False
class TestAnnotationRepositoryDelete:
    """Deletion behaviour of AnnotationRepository."""

    def test_delete_annotation(self, patched_session, sample_annotation):
        """Deleting an existing annotation reports success and removes it."""
        repository = AnnotationRepository()
        target = str(sample_annotation.annotation_id)

        assert repository.delete(target) is True
        assert repository.get(target) is None

    def test_delete_nonexistent_annotation(self, patched_session):
        """Deleting an unknown id reports failure rather than raising."""
        repository = AnnotationRepository()

        assert repository.delete(str(uuid4())) is False

    def test_delete_annotations_for_document(self, patched_session, sample_document):
        """delete_for_document removes every annotation of the document."""
        repository = AnnotationRepository()
        doc_id = str(sample_document.document_id)

        # Three annotations stacked vertically, one per class id.
        for idx in range(3):
            repository.create(
                document_id=doc_id,
                page_number=1,
                class_id=idx,
                class_name=f"field_{idx}",
                x_center=0.5,
                y_center=0.1 + idx * 0.2,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80 + idx * 160,
                bbox_width=160,
                bbox_height=40,
            )

        removed = repository.delete_for_document(doc_id)

        assert removed == 3
        assert len(repository.get_for_document(doc_id)) == 0

    def test_delete_annotations_by_source(self, patched_session, sample_document):
        """Passing source= limits deletion to annotations from that source."""
        repository = AnnotationRepository()
        doc_id = str(sample_document.document_id)

        auto_fields = dict(
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
            source="auto",
        )
        manual_fields = dict(
            page_number=1,
            class_id=1,
            class_name="invoice_date",
            x_center=0.5,
            y_center=0.2,
            width=0.15,
            height=0.04,
            bbox_x=400,
            bbox_y=160,
            bbox_width=120,
            bbox_height=32,
            source="manual",
        )
        for fields in (auto_fields, manual_fields):
            repository.create(document_id=doc_id, **fields)

        # Only the auto-sourced annotation should be removed.
        assert repository.delete_for_document(doc_id, source="auto") == 1

        survivors = repository.get_for_document(doc_id)
        assert len(survivors) == 1
        assert survivors[0].source == "manual"
class TestAnnotationVerification:
    """Tests for annotation verification."""

    def test_verify_annotation(self, patched_session, admin_token, sample_annotation):
        """Test marking an annotation as verified and that the change is stored."""
        repo = AnnotationRepository()
        annotation_id = str(sample_annotation.annotation_id)

        ann = repo.verify(annotation_id, admin_token.token)

        assert ann is not None
        assert ann.is_verified is True
        assert ann.verified_by == admin_token.token
        assert ann.verified_at is not None

        # The returned object alone does not prove the verification was
        # persisted; re-read it through the repository, as the update tests do.
        stored = repo.get(annotation_id)
        assert stored is not None
        assert stored.is_verified is True
        assert stored.verified_by == admin_token.token
        assert stored.verified_at is not None
class TestAnnotationOverride:
    """Tests for annotation override functionality."""

    def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation):
        """Test overriding an auto-generated annotation."""
        repo = AnnotationRepository()
        annotation_id = str(sample_annotation.annotation_id)

        # Override the annotation
        ann = repo.override(
            annotation_id,
            admin_token.token,
            change_reason="Correcting OCR error",
            text_value="INV-2024-CORRECTED",
            x_center=0.55,
        )

        assert ann is not None
        assert ann.text_value == "INV-2024-CORRECTED"
        # Float coordinate: compare with a tolerance after the DB round-trip.
        assert ann.x_center == pytest.approx(0.55)
        assert ann.source == "manual"  # Changed from auto to manual
        assert ann.override_source == "auto"

        # Confirm the override was persisted, not just reflected in the
        # returned object.
        stored = repo.get(annotation_id)
        assert stored is not None
        assert stored.text_value == "INV-2024-CORRECTED"
        assert stored.x_center == pytest.approx(0.55)
        assert stored.source == "manual"
        assert stored.override_source == "auto"
class TestAnnotationHistory:
    """Tests for annotation history tracking."""

    def test_create_history_record(self, patched_session, sample_annotation):
        """Test creating annotation history record."""
        repo = AnnotationRepository()

        history = repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
            changed_by="test-user",
        )

        assert history is not None
        assert history.action == "created"
        assert history.changed_by == "test-user"

    def test_get_annotation_history(self, patched_session, sample_annotation):
        """Test getting history for an annotation, newest first."""
        repo = AnnotationRepository()

        # Create history records
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
        )
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="updated",
            previous_value={"text_value": "INV-001"},
            new_value={"text_value": "INV-002"},
        )

        history = repo.get_history(sample_annotation.annotation_id)

        assert len(history) == 2
        assert {h.action for h in history} == {"created", "updated"}
        # Expected ordering is created_at desc. The two rows are written
        # back-to-back and may share a timestamp (coarse clock resolution,
        # or a transaction-scoped default). Asserting fixed positions per
        # action would then be nondeterministic, so assert the ordering
        # property itself.
        assert history[0].created_at >= history[1].created_at

    def test_get_document_history(self, patched_session, sample_document, sample_annotation):
        """Test getting all annotation history for a document."""
        repo = AnnotationRepository()

        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_document.document_id,
            action="created",
            new_value={"class_name": "invoice_number"},
        )

        history = repo.get_document_history(sample_document.document_id)

        # ">=" rather than "==": fixtures may already have written history.
        assert len(history) >= 1
        assert all(h.document_id == sample_document.document_id for h in history)