198 lines
5.6 KiB
Python
198 lines
5.6 KiB
Python
"""
|
|
Tests for Admin Annotation Routes.
|
|
"""
|
|
|
|
import pytest
|
|
from datetime import datetime
|
|
from unittest.mock import MagicMock, patch
|
|
from uuid import UUID
|
|
|
|
from fastapi import HTTPException
|
|
|
|
from src.data.admin_models import AdminAnnotation, AdminDocument, FIELD_CLASSES
|
|
from src.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
|
from src.web.schemas.admin import (
|
|
AnnotationCreate,
|
|
AnnotationUpdate,
|
|
AutoLabelRequest,
|
|
BoundingBox,
|
|
)
|
|
|
|
|
|
# Fixed identifiers used as test fixtures: a document UUID, an annotation
# UUID, and an admin token string.
TEST_DOC_UUID: str = "550e8400-e29b-41d4-a716-446655440000"
TEST_ANN_UUID: str = "660e8400-e29b-41d4-a716-446655440001"
TEST_TOKEN: str = "test-admin-token-12345"
|
|
|
|
|
|
class TestAnnotationRouterCreation:
    """Checks on the router returned by create_annotation_router()."""

    def test_creates_router_with_endpoints(self):
        """The router registers annotation, single-annotation, auto-label and image routes."""
        router = create_annotation_router()

        # Each route's path includes the router prefix (e.g. /admin/documents),
        # so only distinctive fragments are matched rather than full paths.
        registered_paths = [route.path for route in router.routes]

        def has_route(fragment):
            return any(fragment in path for path in registered_paths)

        for fragment in (
            "{document_id}/annotations",
            "{annotation_id}",
            "auto-label",
            "images",
        ):
            assert has_route(fragment)
|
|
|
|
|
|
class TestAnnotationCreateSchema:
    """Construction tests for the AnnotationCreate request schema."""

    def test_valid_annotation(self):
        """A fully populated AnnotationCreate keeps the values it was given."""
        box = BoundingBox(x=100, y=100, width=200, height=50)
        created = AnnotationCreate(
            page_number=1,
            class_id=0,
            bbox=box,
            text_value="12345",
        )

        assert (created.page_number, created.class_id) == (1, 0)
        assert created.bbox.x == 100
        assert created.text_value == "12345"

    def test_class_id_range(self):
        """Every class_id from 0 through 9 is accepted and preserved."""
        # NOTE(review): only the accepted range is exercised here; rejection of
        # out-of-range ids (e.g. -1 or 10) is not tested. Confirm the schema's
        # constraint and add a negative case.
        accepted_ids = [
            AnnotationCreate(
                page_number=1,
                class_id=candidate,
                bbox=BoundingBox(x=0, y=0, width=100, height=50),
            ).class_id
            for candidate in range(10)
        ]
        assert accepted_ids == list(range(10))

    def test_bbox_validation(self):
        """A BoundingBox built with positive dimensions reports width and height >= 1."""
        # NOTE(review): this does not check that non-positive dimensions are
        # rejected. Add a negative case once the schema constraint is confirmed.
        box = BoundingBox(x=0, y=0, width=100, height=50)
        assert min(box.width, box.height) >= 1
|
|
|
|
|
|
class TestAnnotationUpdateSchema:
    """Construction tests for the AnnotationUpdate (partial update) schema."""

    def test_partial_update(self):
        """Supplying only text_value leaves class_id and bbox as None."""
        partial = AnnotationUpdate(text_value="new value")

        assert partial.text_value == "new value"
        assert partial.class_id is None and partial.bbox is None

    def test_bbox_update(self):
        """An update carrying only a bounding box exposes that box's values."""
        new_box = BoundingBox(x=50, y=50, width=150, height=75)
        partial = AnnotationUpdate(bbox=new_box)

        assert (partial.bbox.x, partial.bbox.width) == (50, 150)
|
|
|
|
|
|
class TestAutoLabelRequestSchema:
    """Tests for the AutoLabelRequest schema."""

    def test_valid_request(self):
        """A request with field values and replace_existing keeps both."""
        request = AutoLabelRequest(
            field_values={
                "InvoiceNumber": "12345",
                "Amount": "1000.00",
            },
            replace_existing=True,
        )

        assert len(request.field_values) == 2
        assert request.field_values["InvoiceNumber"] == "12345"
        assert request.replace_existing is True

    def test_requires_field_values(self):
        """Omitting field_values is rejected by schema validation."""
        # Previously this caught bare Exception, so any unrelated error
        # (TypeError, NameError, ...) would also have made the test pass.
        # Assumes AutoLabelRequest is a Pydantic model (FastAPI request schema);
        # Pydantic's ValidationError subclasses ValueError. Matching on the
        # field name pins the failure to the missing field_values.
        with pytest.raises(ValueError, match="field_values"):
            AutoLabelRequest(replace_existing=True)
|
|
|
|
|
|
class TestFieldClasses:
    """Sanity checks on the FIELD_CLASSES id-to-name mapping."""

    def test_all_classes_defined(self):
        """FIELD_CLASSES has exactly ten entries."""
        assert len(FIELD_CLASSES) == 10

    def test_class_ids_sequential(self):
        """The mapping's keys are exactly the integers 0 through 9."""
        assert set(FIELD_CLASSES) == set(range(10))

    def test_known_field_names(self):
        """Core invoice field names all appear among the mapping's values."""
        required_names = {
            "invoice_number",
            "invoice_date",
            "amount",
            "bankgiro",
            "ocr_number",
        }
        missing = required_names - set(FIELD_CLASSES.values())

        assert not missing
|
|
|
|
|
|
class TestAnnotationModel:
    """Construction tests for the AdminAnnotation model."""

    def test_annotation_creation(self):
        """Values passed to AdminAnnotation are readable back from the instance."""
        constructor_args = dict(
            document_id=UUID(TEST_DOC_UUID),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.5,
            width=0.2,
            height=0.05,
            bbox_x=100,
            bbox_y=100,
            bbox_width=200,
            bbox_height=50,
            text_value="12345",
            confidence=0.95,
            source="manual",
        )
        annotation = AdminAnnotation(**constructor_args)

        assert str(annotation.document_id) == TEST_DOC_UUID
        assert (annotation.class_id, annotation.x_center, annotation.source) == (
            0,
            0.5,
            "manual",
        )

    def test_normalized_coordinates(self):
        """Normalized coordinates supplied in [0, 1] are stored within that range."""
        # NOTE(review): only in-range inputs are used, so this does not show
        # that the model rejects or clamps out-of-range values.
        annotation = AdminAnnotation(
            document_id=UUID(TEST_DOC_UUID),
            page_number=1,
            class_id=0,
            class_name="test",
            x_center=0.5,
            y_center=0.5,
            width=0.2,
            height=0.05,
            bbox_x=0,
            bbox_y=0,
            bbox_width=100,
            bbox_height=50,
        )

        for normalized in (
            annotation.x_center,
            annotation.y_center,
            annotation.width,
            annotation.height,
        ):
            assert 0 <= normalized <= 1
|