Add more tests
This commit is contained in:
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
Annotation Repository Integration Tests
|
||||
|
||||
Tests AnnotationRepository with real database operations.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.annotation_repository import AnnotationRepository
|
||||
|
||||
|
||||
class TestAnnotationRepositoryCreate:
|
||||
"""Tests for annotation creation."""
|
||||
|
||||
def test_create_annotation(self, patched_session, sample_document):
|
||||
"""Test creating a single annotation."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
ann_id = repo.create(
|
||||
document_id=str(sample_document.document_id),
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.3,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=400,
|
||||
bbox_y=240,
|
||||
bbox_width=160,
|
||||
bbox_height=40,
|
||||
text_value="INV-2024-001",
|
||||
confidence=0.95,
|
||||
source="auto",
|
||||
)
|
||||
|
||||
assert ann_id is not None
|
||||
|
||||
ann = repo.get(ann_id)
|
||||
assert ann is not None
|
||||
assert ann.class_name == "invoice_number"
|
||||
assert ann.text_value == "INV-2024-001"
|
||||
assert ann.confidence == 0.95
|
||||
assert ann.source == "auto"
|
||||
|
||||
def test_create_batch_annotations(self, patched_session, sample_document):
|
||||
"""Test batch creation of annotations."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
annotations_data = [
|
||||
{
|
||||
"document_id": str(sample_document.document_id),
|
||||
"page_number": 1,
|
||||
"class_id": 0,
|
||||
"class_name": "invoice_number",
|
||||
"x_center": 0.5,
|
||||
"y_center": 0.1,
|
||||
"width": 0.2,
|
||||
"height": 0.05,
|
||||
"bbox_x": 400,
|
||||
"bbox_y": 80,
|
||||
"bbox_width": 160,
|
||||
"bbox_height": 40,
|
||||
"text_value": "INV-001",
|
||||
"confidence": 0.95,
|
||||
},
|
||||
{
|
||||
"document_id": str(sample_document.document_id),
|
||||
"page_number": 1,
|
||||
"class_id": 1,
|
||||
"class_name": "invoice_date",
|
||||
"x_center": 0.5,
|
||||
"y_center": 0.2,
|
||||
"width": 0.15,
|
||||
"height": 0.04,
|
||||
"bbox_x": 400,
|
||||
"bbox_y": 160,
|
||||
"bbox_width": 120,
|
||||
"bbox_height": 32,
|
||||
"text_value": "2024-01-15",
|
||||
"confidence": 0.92,
|
||||
},
|
||||
{
|
||||
"document_id": str(sample_document.document_id),
|
||||
"page_number": 1,
|
||||
"class_id": 6,
|
||||
"class_name": "amount",
|
||||
"x_center": 0.7,
|
||||
"y_center": 0.8,
|
||||
"width": 0.1,
|
||||
"height": 0.04,
|
||||
"bbox_x": 560,
|
||||
"bbox_y": 640,
|
||||
"bbox_width": 80,
|
||||
"bbox_height": 32,
|
||||
"text_value": "1500.00",
|
||||
"confidence": 0.98,
|
||||
},
|
||||
]
|
||||
|
||||
ids = repo.create_batch(annotations_data)
|
||||
|
||||
assert len(ids) == 3
|
||||
|
||||
# Verify all annotations exist
|
||||
for ann_id in ids:
|
||||
ann = repo.get(ann_id)
|
||||
assert ann is not None
|
||||
|
||||
|
||||
class TestAnnotationRepositoryRead:
|
||||
"""Tests for annotation retrieval."""
|
||||
|
||||
def test_get_nonexistent_annotation(self, patched_session):
|
||||
"""Test getting an annotation that doesn't exist."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
ann = repo.get(str(uuid4()))
|
||||
assert ann is None
|
||||
|
||||
def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation):
|
||||
"""Test getting all annotations for a document."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
# Add another annotation
|
||||
repo.create(
|
||||
document_id=str(sample_document.document_id),
|
||||
page_number=1,
|
||||
class_id=1,
|
||||
class_name="invoice_date",
|
||||
x_center=0.5,
|
||||
y_center=0.4,
|
||||
width=0.15,
|
||||
height=0.04,
|
||||
bbox_x=400,
|
||||
bbox_y=320,
|
||||
bbox_width=120,
|
||||
bbox_height=32,
|
||||
text_value="2024-01-15",
|
||||
)
|
||||
|
||||
annotations = repo.get_for_document(str(sample_document.document_id))
|
||||
|
||||
assert len(annotations) == 2
|
||||
# Should be ordered by class_id
|
||||
assert annotations[0].class_id == 0
|
||||
assert annotations[1].class_id == 1
|
||||
|
||||
def test_get_annotations_for_specific_page(self, patched_session, sample_document):
|
||||
"""Test getting annotations for a specific page."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
# Create annotations on different pages
|
||||
repo.create(
|
||||
document_id=str(sample_document.document_id),
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.1,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=400,
|
||||
bbox_y=80,
|
||||
bbox_width=160,
|
||||
bbox_height=40,
|
||||
)
|
||||
repo.create(
|
||||
document_id=str(sample_document.document_id),
|
||||
page_number=2,
|
||||
class_id=6,
|
||||
class_name="amount",
|
||||
x_center=0.7,
|
||||
y_center=0.8,
|
||||
width=0.1,
|
||||
height=0.04,
|
||||
bbox_x=560,
|
||||
bbox_y=640,
|
||||
bbox_width=80,
|
||||
bbox_height=32,
|
||||
)
|
||||
|
||||
page1_annotations = repo.get_for_document(
|
||||
str(sample_document.document_id),
|
||||
page_number=1,
|
||||
)
|
||||
page2_annotations = repo.get_for_document(
|
||||
str(sample_document.document_id),
|
||||
page_number=2,
|
||||
)
|
||||
|
||||
assert len(page1_annotations) == 1
|
||||
assert len(page2_annotations) == 1
|
||||
assert page1_annotations[0].page_number == 1
|
||||
assert page2_annotations[0].page_number == 2
|
||||
|
||||
|
||||
class TestAnnotationRepositoryUpdate:
|
||||
"""Tests for annotation updates."""
|
||||
|
||||
def test_update_annotation_bbox(self, patched_session, sample_annotation):
|
||||
"""Test updating annotation bounding box."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
result = repo.update(
|
||||
str(sample_annotation.annotation_id),
|
||||
x_center=0.6,
|
||||
y_center=0.4,
|
||||
width=0.25,
|
||||
height=0.06,
|
||||
bbox_x=480,
|
||||
bbox_y=320,
|
||||
bbox_width=200,
|
||||
bbox_height=48,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
ann = repo.get(str(sample_annotation.annotation_id))
|
||||
assert ann is not None
|
||||
assert ann.x_center == 0.6
|
||||
assert ann.y_center == 0.4
|
||||
assert ann.bbox_x == 480
|
||||
assert ann.bbox_width == 200
|
||||
|
||||
def test_update_annotation_text(self, patched_session, sample_annotation):
|
||||
"""Test updating annotation text value."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
result = repo.update(
|
||||
str(sample_annotation.annotation_id),
|
||||
text_value="INV-2024-002",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
ann = repo.get(str(sample_annotation.annotation_id))
|
||||
assert ann is not None
|
||||
assert ann.text_value == "INV-2024-002"
|
||||
|
||||
def test_update_annotation_class(self, patched_session, sample_annotation):
|
||||
"""Test updating annotation class."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
result = repo.update(
|
||||
str(sample_annotation.annotation_id),
|
||||
class_id=1,
|
||||
class_name="invoice_date",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
ann = repo.get(str(sample_annotation.annotation_id))
|
||||
assert ann is not None
|
||||
assert ann.class_id == 1
|
||||
assert ann.class_name == "invoice_date"
|
||||
|
||||
def test_update_nonexistent_annotation(self, patched_session):
|
||||
"""Test updating annotation that doesn't exist."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
result = repo.update(
|
||||
str(uuid4()),
|
||||
text_value="new value",
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestAnnotationRepositoryDelete:
|
||||
"""Tests for annotation deletion."""
|
||||
|
||||
def test_delete_annotation(self, patched_session, sample_annotation):
|
||||
"""Test deleting a single annotation."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
result = repo.delete(str(sample_annotation.annotation_id))
|
||||
assert result is True
|
||||
|
||||
ann = repo.get(str(sample_annotation.annotation_id))
|
||||
assert ann is None
|
||||
|
||||
def test_delete_nonexistent_annotation(self, patched_session):
|
||||
"""Test deleting annotation that doesn't exist."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
result = repo.delete(str(uuid4()))
|
||||
assert result is False
|
||||
|
||||
def test_delete_annotations_for_document(self, patched_session, sample_document):
|
||||
"""Test deleting all annotations for a document."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
# Create multiple annotations
|
||||
for i in range(3):
|
||||
repo.create(
|
||||
document_id=str(sample_document.document_id),
|
||||
page_number=1,
|
||||
class_id=i,
|
||||
class_name=f"field_{i}",
|
||||
x_center=0.5,
|
||||
y_center=0.1 + i * 0.2,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=400,
|
||||
bbox_y=80 + i * 160,
|
||||
bbox_width=160,
|
||||
bbox_height=40,
|
||||
)
|
||||
|
||||
# Delete all
|
||||
count = repo.delete_for_document(str(sample_document.document_id))
|
||||
|
||||
assert count == 3
|
||||
|
||||
annotations = repo.get_for_document(str(sample_document.document_id))
|
||||
assert len(annotations) == 0
|
||||
|
||||
def test_delete_annotations_by_source(self, patched_session, sample_document):
|
||||
"""Test deleting annotations by source type."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
# Create auto and manual annotations
|
||||
repo.create(
|
||||
document_id=str(sample_document.document_id),
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.1,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=400,
|
||||
bbox_y=80,
|
||||
bbox_width=160,
|
||||
bbox_height=40,
|
||||
source="auto",
|
||||
)
|
||||
repo.create(
|
||||
document_id=str(sample_document.document_id),
|
||||
page_number=1,
|
||||
class_id=1,
|
||||
class_name="invoice_date",
|
||||
x_center=0.5,
|
||||
y_center=0.2,
|
||||
width=0.15,
|
||||
height=0.04,
|
||||
bbox_x=400,
|
||||
bbox_y=160,
|
||||
bbox_width=120,
|
||||
bbox_height=32,
|
||||
source="manual",
|
||||
)
|
||||
|
||||
# Delete only auto annotations
|
||||
count = repo.delete_for_document(str(sample_document.document_id), source="auto")
|
||||
|
||||
assert count == 1
|
||||
|
||||
remaining = repo.get_for_document(str(sample_document.document_id))
|
||||
assert len(remaining) == 1
|
||||
assert remaining[0].source == "manual"
|
||||
|
||||
|
||||
class TestAnnotationVerification:
|
||||
"""Tests for annotation verification."""
|
||||
|
||||
def test_verify_annotation(self, patched_session, admin_token, sample_annotation):
|
||||
"""Test marking annotation as verified."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
ann = repo.verify(str(sample_annotation.annotation_id), admin_token.token)
|
||||
|
||||
assert ann is not None
|
||||
assert ann.is_verified is True
|
||||
assert ann.verified_by == admin_token.token
|
||||
assert ann.verified_at is not None
|
||||
|
||||
|
||||
class TestAnnotationOverride:
|
||||
"""Tests for annotation override functionality."""
|
||||
|
||||
def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation):
|
||||
"""Test overriding an auto-generated annotation."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
# Override the annotation
|
||||
ann = repo.override(
|
||||
str(sample_annotation.annotation_id),
|
||||
admin_token.token,
|
||||
change_reason="Correcting OCR error",
|
||||
text_value="INV-2024-CORRECTED",
|
||||
x_center=0.55,
|
||||
)
|
||||
|
||||
assert ann is not None
|
||||
assert ann.text_value == "INV-2024-CORRECTED"
|
||||
assert ann.x_center == 0.55
|
||||
assert ann.source == "manual" # Changed from auto to manual
|
||||
assert ann.override_source == "auto"
|
||||
|
||||
|
||||
class TestAnnotationHistory:
|
||||
"""Tests for annotation history tracking."""
|
||||
|
||||
def test_create_history_record(self, patched_session, sample_annotation):
|
||||
"""Test creating annotation history record."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
history = repo.create_history(
|
||||
annotation_id=sample_annotation.annotation_id,
|
||||
document_id=sample_annotation.document_id,
|
||||
action="created",
|
||||
new_value={"text_value": "INV-001"},
|
||||
changed_by="test-user",
|
||||
)
|
||||
|
||||
assert history is not None
|
||||
assert history.action == "created"
|
||||
assert history.changed_by == "test-user"
|
||||
|
||||
def test_get_annotation_history(self, patched_session, sample_annotation):
|
||||
"""Test getting history for an annotation."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
# Create history records
|
||||
repo.create_history(
|
||||
annotation_id=sample_annotation.annotation_id,
|
||||
document_id=sample_annotation.document_id,
|
||||
action="created",
|
||||
new_value={"text_value": "INV-001"},
|
||||
)
|
||||
repo.create_history(
|
||||
annotation_id=sample_annotation.annotation_id,
|
||||
document_id=sample_annotation.document_id,
|
||||
action="updated",
|
||||
previous_value={"text_value": "INV-001"},
|
||||
new_value={"text_value": "INV-002"},
|
||||
)
|
||||
|
||||
history = repo.get_history(sample_annotation.annotation_id)
|
||||
|
||||
assert len(history) == 2
|
||||
# Should be ordered by created_at desc
|
||||
assert history[0].action == "updated"
|
||||
assert history[1].action == "created"
|
||||
|
||||
def test_get_document_history(self, patched_session, sample_document, sample_annotation):
|
||||
"""Test getting all annotation history for a document."""
|
||||
repo = AnnotationRepository()
|
||||
|
||||
repo.create_history(
|
||||
annotation_id=sample_annotation.annotation_id,
|
||||
document_id=sample_document.document_id,
|
||||
action="created",
|
||||
new_value={"class_name": "invoice_number"},
|
||||
)
|
||||
|
||||
history = repo.get_document_history(sample_document.document_id)
|
||||
|
||||
assert len(history) >= 1
|
||||
assert all(h.document_id == sample_document.document_id for h in history)
|
||||
Reference in New Issue
Block a user