"""
Annotation Repository Integration Tests

Tests AnnotationRepository against real database operations: creation,
retrieval, updates, deletion, verification, overrides, and history tracking.
"""
|
|
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from inference.data.repositories.annotation_repository import AnnotationRepository
|
|
|
|
|
|
class TestAnnotationRepositoryCreate:
    """Tests for annotation creation."""

    def test_create_annotation(self, patched_session, sample_document):
        """Test creating a single annotation and reading it back."""
        repo = AnnotationRepository()

        ann_id = repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=240,
            bbox_width=160,
            bbox_height=40,
            text_value="INV-2024-001",
            confidence=0.95,
            source="auto",
        )

        assert ann_id is not None

        ann = repo.get(ann_id)
        assert ann is not None
        assert ann.class_name == "invoice_number"
        assert ann.text_value == "INV-2024-001"
        # The value round-trips through a database column whose float
        # precision this test does not control, so compare with a tolerance.
        assert ann.confidence == pytest.approx(0.95)
        assert ann.source == "auto"

    def test_create_batch_annotations(self, patched_session, sample_document):
        """Test batch creation of annotations."""
        repo = AnnotationRepository()
        document_id = str(sample_document.document_id)

        annotations_data = [
            {
                "document_id": document_id,
                "page_number": 1,
                "class_id": 0,
                "class_name": "invoice_number",
                "x_center": 0.5,
                "y_center": 0.1,
                "width": 0.2,
                "height": 0.05,
                "bbox_x": 400,
                "bbox_y": 80,
                "bbox_width": 160,
                "bbox_height": 40,
                "text_value": "INV-001",
                "confidence": 0.95,
            },
            {
                "document_id": document_id,
                "page_number": 1,
                "class_id": 1,
                "class_name": "invoice_date",
                "x_center": 0.5,
                "y_center": 0.2,
                "width": 0.15,
                "height": 0.04,
                "bbox_x": 400,
                "bbox_y": 160,
                "bbox_width": 120,
                "bbox_height": 32,
                "text_value": "2024-01-15",
                "confidence": 0.92,
            },
            {
                "document_id": document_id,
                "page_number": 1,
                "class_id": 6,
                "class_name": "amount",
                "x_center": 0.7,
                "y_center": 0.8,
                "width": 0.1,
                "height": 0.04,
                "bbox_x": 560,
                "bbox_y": 640,
                "bbox_width": 80,
                "bbox_height": 32,
                "text_value": "1500.00",
                "confidence": 0.98,
            },
        ]

        ids = repo.create_batch(annotations_data)

        assert len(ids) == 3
        # Each input row must yield its own annotation.
        assert len(set(ids)) == 3

        # Verify all annotations exist and carry the submitted classes.
        # Compared as a set so the test does not depend on the order in
        # which create_batch returns ids.
        created = [repo.get(ann_id) for ann_id in ids]
        assert all(ann is not None for ann in created)
        assert {ann.class_name for ann in created} == {
            "invoice_number",
            "invoice_date",
            "amount",
        }
|
|
|
|
|
|
class TestAnnotationRepositoryRead:
    """Tests for annotation retrieval."""

    def test_get_nonexistent_annotation(self, patched_session):
        """A lookup by a freshly generated, never-stored id yields None."""
        repo = AnnotationRepository()
        missing_id = str(uuid4())

        assert repo.get(missing_id) is None

    def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation):
        """Fetching by document returns all of its annotations, sorted by class_id."""
        repo = AnnotationRepository()
        doc_id = str(sample_document.document_id)

        # sample_annotation supplies the first annotation; add a second (class_id=1).
        repo.create(
            document_id=doc_id,
            page_number=1,
            class_id=1,
            class_name="invoice_date",
            x_center=0.5,
            y_center=0.4,
            width=0.15,
            height=0.04,
            bbox_x=400,
            bbox_y=320,
            bbox_width=120,
            bbox_height=32,
            text_value="2024-01-15",
        )

        found = repo.get_for_document(doc_id)

        # Expected: exactly two annotations, ordered by class_id.
        assert len(found) == 2
        assert [item.class_id for item in found] == [0, 1]

    def test_get_annotations_for_specific_page(self, patched_session, sample_document):
        """Filtering by page_number returns only that page's annotations."""
        repo = AnnotationRepository()
        doc_id = str(sample_document.document_id)

        # One annotation on page 1 ...
        repo.create(
            document_id=doc_id,
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
        )
        # ... and one on page 2.
        repo.create(
            document_id=doc_id,
            page_number=2,
            class_id=6,
            class_name="amount",
            x_center=0.7,
            y_center=0.8,
            width=0.1,
            height=0.04,
            bbox_x=560,
            bbox_y=640,
            bbox_width=80,
            bbox_height=32,
        )

        per_page = {
            page: repo.get_for_document(doc_id, page_number=page)
            for page in (1, 2)
        }

        for page, found in per_page.items():
            assert len(found) == 1
            assert found[0].page_number == page
|
|
|
|
|
|
class TestAnnotationRepositoryUpdate:
    """Tests for annotation updates."""

    def test_update_annotation_bbox(self, patched_session, sample_annotation):
        """Test that every bounding-box field passed to update() is persisted."""
        repo = AnnotationRepository()
        ann_id = str(sample_annotation.annotation_id)

        result = repo.update(
            ann_id,
            x_center=0.6,
            y_center=0.4,
            width=0.25,
            height=0.06,
            bbox_x=480,
            bbox_y=320,
            bbox_width=200,
            bbox_height=48,
        )

        assert result is True

        ann = repo.get(ann_id)
        assert ann is not None
        # Fractional box fields are floats read back from the database,
        # so compare with a tolerance.
        assert ann.x_center == pytest.approx(0.6)
        assert ann.y_center == pytest.approx(0.4)
        assert ann.width == pytest.approx(0.25)
        assert ann.height == pytest.approx(0.06)
        # Pixel box fields are integers and must match exactly.
        assert ann.bbox_x == 480
        assert ann.bbox_y == 320
        assert ann.bbox_width == 200
        assert ann.bbox_height == 48

    def test_update_annotation_text(self, patched_session, sample_annotation):
        """Test updating annotation text value."""
        repo = AnnotationRepository()
        ann_id = str(sample_annotation.annotation_id)

        result = repo.update(
            ann_id,
            text_value="INV-2024-002",
        )

        assert result is True

        ann = repo.get(ann_id)
        assert ann is not None
        assert ann.text_value == "INV-2024-002"

    def test_update_annotation_class(self, patched_session, sample_annotation):
        """Test updating annotation class."""
        repo = AnnotationRepository()
        ann_id = str(sample_annotation.annotation_id)

        result = repo.update(
            ann_id,
            class_id=1,
            class_name="invoice_date",
        )

        assert result is True

        ann = repo.get(ann_id)
        assert ann is not None
        assert ann.class_id == 1
        assert ann.class_name == "invoice_date"

    def test_update_nonexistent_annotation(self, patched_session):
        """Test that updating an unknown annotation id reports False."""
        repo = AnnotationRepository()

        result = repo.update(
            str(uuid4()),
            text_value="new value",
        )

        assert result is False
|
|
|
|
|
|
class TestAnnotationRepositoryDelete:
    """Tests for annotation deletion."""

    def test_delete_annotation(self, patched_session, sample_annotation):
        """Deleting an existing annotation succeeds and makes it unreachable."""
        repo = AnnotationRepository()
        target = str(sample_annotation.annotation_id)

        assert repo.delete(target) is True
        assert repo.get(target) is None

    def test_delete_nonexistent_annotation(self, patched_session):
        """Deleting an id that was never stored reports False."""
        repo = AnnotationRepository()

        assert repo.delete(str(uuid4())) is False

    def test_delete_annotations_for_document(self, patched_session, sample_document):
        """delete_for_document removes every annotation of the document and returns the count."""
        repo = AnnotationRepository()
        doc_id = str(sample_document.document_id)

        # Seed three annotations on page 1, each shifted further down the page.
        for idx in range(3):
            repo.create(
                document_id=doc_id,
                page_number=1,
                class_id=idx,
                class_name=f"field_{idx}",
                x_center=0.5,
                y_center=0.1 + idx * 0.2,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80 + idx * 160,
                bbox_width=160,
                bbox_height=40,
            )

        assert repo.delete_for_document(doc_id) == 3
        assert len(repo.get_for_document(doc_id)) == 0

    def test_delete_annotations_by_source(self, patched_session, sample_document):
        """Passing source= restricts deletion to annotations from that source."""
        repo = AnnotationRepository()
        doc_id = str(sample_document.document_id)
        shared = {
            "document_id": doc_id,
            "page_number": 1,
            "x_center": 0.5,
            "bbox_x": 400,
        }

        # One auto-generated annotation ...
        repo.create(
            **shared,
            class_id=0,
            class_name="invoice_number",
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
            source="auto",
        )
        # ... and one manual annotation.
        repo.create(
            **shared,
            class_id=1,
            class_name="invoice_date",
            y_center=0.2,
            width=0.15,
            height=0.04,
            bbox_y=160,
            bbox_width=120,
            bbox_height=32,
            source="manual",
        )

        # Only the auto annotation should be removed.
        assert repo.delete_for_document(doc_id, source="auto") == 1

        survivors = repo.get_for_document(doc_id)
        assert len(survivors) == 1
        assert survivors[0].source == "manual"
|
|
|
|
|
|
class TestAnnotationVerification:
    """Tests for annotation verification."""

    def test_verify_annotation(self, patched_session, admin_token, sample_annotation):
        """verify() flags the annotation and records the verifier and a timestamp."""
        repo = AnnotationRepository()
        verifier = admin_token.token

        verified = repo.verify(str(sample_annotation.annotation_id), verifier)

        assert verified is not None
        assert verified.is_verified is True
        assert verified.verified_by == verifier
        assert verified.verified_at is not None
|
|
|
|
|
|
class TestAnnotationOverride:
    """Tests for annotation override functionality."""

    def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation):
        """Test overriding an auto-generated annotation."""
        repo = AnnotationRepository()
        ann_id = str(sample_annotation.annotation_id)

        # Override the annotation
        ann = repo.override(
            ann_id,
            admin_token.token,
            change_reason="Correcting OCR error",
            text_value="INV-2024-CORRECTED",
            x_center=0.55,
        )

        assert ann is not None
        assert ann.text_value == "INV-2024-CORRECTED"
        # Float read back from the model; compare with a tolerance.
        assert ann.x_center == pytest.approx(0.55)
        # The annotation changes from "auto" to "manual", and the prior
        # source is kept in override_source.
        assert ann.source == "manual"
        assert ann.override_source == "auto"

        # The override must be visible on a fresh read, not only on the
        # object returned by override().
        stored = repo.get(ann_id)
        assert stored is not None
        assert stored.text_value == "INV-2024-CORRECTED"
        assert stored.source == "manual"
|
|
|
|
|
|
class TestAnnotationHistory:
    """Tests for annotation history tracking."""

    def test_create_history_record(self, patched_session, sample_annotation):
        """Test creating annotation history record."""
        repo = AnnotationRepository()

        history = repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
            changed_by="test-user",
        )

        assert history is not None
        assert history.action == "created"
        assert history.changed_by == "test-user"

    def test_get_annotation_history(self, patched_session, sample_annotation):
        """Test getting history for an annotation, newest first."""
        repo = AnnotationRepository()

        # Create history records
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
        )
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="updated",
            previous_value={"text_value": "INV-001"},
            new_value={"text_value": "INV-002"},
        )

        history = repo.get_history(sample_annotation.annotation_id)

        assert len(history) == 2
        actions = [h.action for h in history]
        assert sorted(actions) == ["created", "updated"]

        # Records are expected to be ordered by created_at descending. Two rows
        # written back-to-back can share a timestamp (coarse clock resolution,
        # or a DB-side default evaluated once per transaction), in which case
        # their relative order is unspecified. Only enforce the action order
        # when the timestamps actually differ.
        newer, older = history
        assert newer.created_at >= older.created_at
        if newer.created_at != older.created_at:
            assert actions == ["updated", "created"]

    def test_get_document_history(self, patched_session, sample_document, sample_annotation):
        """Test getting all annotation history for a document."""
        repo = AnnotationRepository()

        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_document.document_id,
            action="created",
            new_value={"class_name": "invoice_number"},
        )

        history = repo.get_document_history(sample_document.document_id)

        # ">= 1" tolerates other history rows for this document (presumably
        # from fixtures), but the record written above must be among them.
        assert len(history) >= 1
        assert all(h.document_id == sample_document.document_id for h in history)
        assert any(
            h.annotation_id == sample_annotation.annotation_id and h.action == "created"
            for h in history
        )
|