317 lines
9.9 KiB
Python
317 lines
9.9 KiB
Python
"""Tests for Admin Annotation Routes."""
|
|
|
|
import pytest
|
|
from datetime import datetime
|
|
from unittest.mock import MagicMock, patch
|
|
from uuid import UUID
|
|
|
|
from fastapi import HTTPException
|
|
|
|
from inference.data.admin_models import AdminAnnotation, AdminDocument
|
|
from shared.fields import FIELD_CLASSES
|
|
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
|
from inference.web.schemas.admin import (
|
|
AnnotationCreate,
|
|
AnnotationUpdate,
|
|
AutoLabelRequest,
|
|
BoundingBox,
|
|
)
|
|
|
|
|
|
# Fixed identifiers shared by the tests in this module. Both UUID strings are
# well-formed (8-4-4-4-12 hex groups), so they can be passed to ``UUID(...)``.
TEST_DOC_UUID: str = "550e8400-e29b-41d4-a716-446655440000"
TEST_ANN_UUID: str = "660e8400-e29b-41d4-a716-446655440001"

# Placeholder admin token string for tests.
TEST_TOKEN: str = "test-admin-token-12345"
|
|
|
|
|
|
class TestAnnotationRouterCreation:
    """Checks that the annotation router exposes its expected routes."""

    def test_creates_router_with_endpoints(self):
        """Every expected path fragment appears in some registered route."""
        route_paths = [r.path for r in create_annotation_router().routes]

        # Registered paths carry the router prefix (e.g. /admin/documents),
        # so match on substrings rather than on complete paths.
        expected_fragments = (
            "{document_id}/annotations",
            "{annotation_id}",
            "auto-label",
            "images",
        )
        for fragment in expected_fragments:
            assert any(fragment in path for path in route_paths)
|
|
|
|
|
|
class TestAnnotationCreateSchema:
    """Behaviour of the ``AnnotationCreate`` request schema."""

    def test_valid_annotation(self):
        """A fully populated payload keeps the values it was given."""
        box = BoundingBox(x=100, y=100, width=200, height=50)
        created = AnnotationCreate(
            page_number=1,
            class_id=0,
            bbox=box,
            text_value="12345",
        )

        assert created.page_number == 1
        assert created.class_id == 0
        assert created.bbox.x == 100
        assert created.text_value == "12345"

    def test_class_id_range(self):
        """Every class_id from 0 to 9 is accepted and stored unchanged."""
        # NOTE(review): only in-range ids are exercised; rejection of
        # out-of-range values (e.g. -1 or 10) is not asserted here.
        for cid in range(10):
            created = AnnotationCreate(
                page_number=1,
                class_id=cid,
                bbox=BoundingBox(x=0, y=0, width=100, height=50),
            )
            assert created.class_id == cid

    def test_bbox_validation(self):
        """A BoundingBox built with positive dimensions reports them as >= 1."""
        box = BoundingBox(x=0, y=0, width=100, height=50)
        for dimension in (box.width, box.height):
            assert dimension >= 1
|
|
|
|
|
|
class TestAnnotationUpdateSchema:
    """Behaviour of the ``AnnotationUpdate`` (partial update) schema."""

    def test_partial_update(self):
        """Fields that are not supplied come back as ``None``."""
        upd = AnnotationUpdate(text_value="new value")

        assert upd.text_value == "new value"
        # Omitted fields stay unset.
        assert upd.class_id is None
        assert upd.bbox is None

    def test_bbox_update(self):
        """A replacement bounding box is stored on the update."""
        new_box = BoundingBox(x=50, y=50, width=150, height=75)
        upd = AnnotationUpdate(bbox=new_box)

        assert upd.bbox.x == 50
        assert upd.bbox.width == 150
|
|
|
|
|
|
class TestAutoLabelRequestSchema:
    """Tests for the ``AutoLabelRequest`` schema."""

    def test_valid_request(self):
        """A request with field values and ``replace_existing`` round-trips."""
        request = AutoLabelRequest(
            field_values={
                "InvoiceNumber": "12345",
                "Amount": "1000.00",
            },
            replace_existing=True,
        )

        assert len(request.field_values) == 2
        assert request.field_values["InvoiceNumber"] == "12345"
        assert request.replace_existing is True

    def test_requires_field_values(self):
        """Omitting ``field_values`` is rejected by schema validation.

        Expects the specific validation error rather than a bare
        ``Exception``, so unrelated failures (e.g. a ``NameError`` or a
        ``TypeError`` from a broken constructor) cannot make this test pass.
        """
        # Local import, matching this file's function-scope import style.
        # Assumes AutoLabelRequest is a pydantic model (it is a FastAPI
        # request schema); pydantic ships as a FastAPI dependency.
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            AutoLabelRequest(replace_existing=True)
|
|
|
|
|
|
class TestFieldClasses:
    """Checks on the ``FIELD_CLASSES`` id -> name mapping."""

    def test_all_classes_defined(self):
        """The mapping holds exactly ten field classes."""
        assert len(FIELD_CLASSES) == 10

    def test_class_ids_sequential(self):
        """The class ids are exactly the integers 0 through 9."""
        assert set(range(10)) == set(FIELD_CLASSES.keys())

    def test_known_field_names(self):
        """Core invoice field names appear among the mapping's values."""
        known = list(FIELD_CLASSES.values())

        for expected in (
            "invoice_number",
            "invoice_date",
            "amount",
            "bankgiro",
            "ocr_number",
        ):
            assert expected in known
|
|
|
|
|
|
class TestAnnotationModel:
    """Construction of the ``AdminAnnotation`` data model."""

    @staticmethod
    def _make_annotation(**overrides):
        """Build an ``AdminAnnotation`` for TEST_DOC_UUID, page 1, class 0.

        The normalized box is centred at (0.5, 0.5) with size 0.2 x 0.05;
        ``overrides`` supplies the remaining per-test keyword arguments.
        """
        fields = {
            "document_id": UUID(TEST_DOC_UUID),
            "page_number": 1,
            "class_id": 0,
            "x_center": 0.5,
            "y_center": 0.5,
            "width": 0.2,
            "height": 0.05,
        }
        fields.update(overrides)
        return AdminAnnotation(**fields)

    def test_annotation_creation(self):
        """Supplied values are readable back from the model."""
        ann = self._make_annotation(
            class_name="invoice_number",
            bbox_x=100,
            bbox_y=100,
            bbox_width=200,
            bbox_height=50,
            text_value="12345",
            confidence=0.95,
            source="manual",
        )

        assert str(ann.document_id) == TEST_DOC_UUID
        assert ann.class_id == 0
        assert ann.x_center == 0.5
        assert ann.source == "manual"

    def test_normalized_coordinates(self):
        """Normalized centre/size values lie within [0, 1]."""
        ann = self._make_annotation(
            class_name="test",
            bbox_x=0,
            bbox_y=0,
            bbox_width=100,
            bbox_height=50,
        )

        for value in (ann.x_center, ann.y_center, ann.width, ann.height):
            assert 0 <= value <= 1
|
|
|
|
|
|
class TestAutoLabelFilePathResolution:
    """Tests for auto-label file path resolution.

    The auto-label endpoint needs to resolve the storage path (e.g.,
    "raw_pdfs/uuid.pdf") to an actual filesystem path via the storage helper.
    """

    @staticmethod
    def _extract_filename(storage_path):
        """Return the last ``/``-separated component of *storage_path*.

        Mirrors the filename extraction attributed to the annotation
        endpoint; kept in one place so both path tests exercise the same
        expression instead of two copy-pasted ones.
        """
        return storage_path.split("/")[-1] if "/" in storage_path else storage_path

    def test_extracts_filename_from_storage_path(self):
        """Test that filename is extracted from storage path correctly."""
        # Storage paths are like "raw_pdfs/uuid.pdf"
        storage_path = "raw_pdfs/550e8400-e29b-41d4-a716-446655440000.pdf"

        filename = self._extract_filename(storage_path)

        assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"

    def test_handles_path_without_prefix(self):
        """Test that paths without prefix are handled."""
        storage_path = "550e8400-e29b-41d4-a716-446655440000.pdf"

        filename = self._extract_filename(storage_path)

        assert filename == "550e8400-e29b-41d4-a716-446655440000.pdf"

    def test_storage_helper_resolves_path(self):
        """Test that storage helper can resolve the path."""
        # MagicMock and patch come from the module-level unittest.mock import;
        # only Path needs a local import here.
        from pathlib import Path

        # Mock storage helper
        mock_storage = MagicMock()
        mock_path = Path("/storage/raw_pdfs/test.pdf")
        mock_storage.get_raw_pdf_local_path.return_value = mock_path

        with patch(
            "inference.web.services.storage_helpers.get_storage_helper",
            return_value=mock_storage,
        ):
            # Imported inside the patch so the name resolves to the mock.
            from inference.web.services.storage_helpers import get_storage_helper

            storage = get_storage_helper()
            result = storage.get_raw_pdf_local_path("test.pdf")

            assert result == mock_path
            mock_storage.get_raw_pdf_local_path.assert_called_once_with("test.pdf")

    def test_auto_label_request_validation(self):
        """Test AutoLabelRequest validates field_values."""
        # Valid request
        request = AutoLabelRequest(
            field_values={"InvoiceNumber": "12345"},
            replace_existing=False,
        )
        assert request.field_values == {"InvoiceNumber": "12345"}

        # Empty field_values should be valid at schema level
        # (endpoint validates non-empty)
        request_empty = AutoLabelRequest(
            field_values={},
            replace_existing=False,
        )
        assert request_empty.field_values == {}
|
|
|
|
|
|
class TestMatchClassAttributes:
    """Attribute names on ``Match`` that the auto-label service depends on.

    The autolabel service consumes ``Match`` objects produced by
    ``FieldMatcher``; these tests pin the attribute names it reads.
    """

    def test_match_has_matched_text_attribute(self):
        """``Match`` exposes ``matched_text`` and has no ``matched_value``."""
        from shared.matcher.models import Match

        m = Match(
            field="invoice_number",
            value="12345",
            bbox=(100, 100, 200, 150),
            page_no=0,
            score=0.95,
            matched_text="INV-12345",
            context_keywords=["faktura", "nummer"],
        )

        # autolabel.py must read ``matched_text``...
        assert hasattr(m, "matched_text")
        assert m.matched_text == "INV-12345"

        # ...and not ``matched_value``, the name it previously used by mistake.
        assert not hasattr(m, "matched_value")

    def test_match_attributes_for_annotation_creation(self):
        """``Match`` carries every attribute needed to build an annotation."""
        from shared.matcher.models import Match

        m = Match(
            field="amount",
            value="1000.00",
            bbox=(50, 200, 150, 230),
            page_no=0,
            score=0.88,
            matched_text="1 000,00",
            context_keywords=["att betala", "summa"],
        )

        # Attributes read by autolabel._create_annotations_from_matches
        # (note: ``matched_text``, not ``matched_value``).
        for attr in ("bbox", "matched_text", "score"):
            assert hasattr(m, attr)

        # bbox is a 4-tuple: (x0, y0, x1, y1).
        assert len(m.bbox) == 4
|