421 lines
14 KiB
Python
421 lines
14 KiB
Python
"""
|
|
Tests for Phase 5: Annotation Enhancement (Verification and Override)
|
|
"""
|
|
|
|
from datetime import datetime, timezone
from uuid import uuid4

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.web.api.v1.admin.annotations import create_annotation_router
from src.web.core.auth import validate_admin_token, get_admin_db
|
|
|
|
|
|
class MockAdminDocument:
    """In-memory stand-in for an admin document record, used by these tests.

    Every attribute can be supplied as a keyword argument; anything omitted
    falls back to the default assigned in ``__init__``.
    """

    def __init__(self, **kwargs):
        """Populate the document fields from ``kwargs``, defaulting the rest.

        Keyword arguments that do not match a field below are ignored.
        """
        # Single clock read so that a freshly built document has identical
        # created_at / updated_at defaults.  The value is naive UTC, i.e. the
        # same shape the deprecated datetime.utcnow() used to return.
        now = datetime.now(timezone.utc).replace(tzinfo=None)

        self.document_id = kwargs.get('document_id', uuid4())
        self.admin_token = kwargs.get('admin_token', 'test-token')
        self.filename = kwargs.get('filename', 'test.pdf')
        self.file_size = kwargs.get('file_size', 100000)
        self.content_type = kwargs.get('content_type', 'application/pdf')
        self.page_count = kwargs.get('page_count', 1)
        self.status = kwargs.get('status', 'labeled')
        self.auto_label_status = kwargs.get('auto_label_status', None)
        self.created_at = kwargs.get('created_at', now)
        self.updated_at = kwargs.get('updated_at', now)
|
|
|
|
|
|
class MockAnnotation:
    """In-memory stand-in for an admin annotation record, used by these tests.

    Every attribute can be supplied as a keyword argument; anything omitted
    falls back to the default assigned in ``__init__``.
    """

    def __init__(self, **kwargs):
        """Populate the annotation fields from ``kwargs``, defaulting the rest.

        Keyword arguments that do not match a field below are ignored.
        """
        # Single clock read so that a freshly built annotation has identical
        # created_at / updated_at defaults.  The value is naive UTC, i.e. the
        # same shape the deprecated datetime.utcnow() used to return.
        now = datetime.now(timezone.utc).replace(tzinfo=None)

        # Identity and classification.
        self.annotation_id = kwargs.get('annotation_id', uuid4())
        self.document_id = kwargs.get('document_id')
        self.page_number = kwargs.get('page_number', 1)
        self.class_id = kwargs.get('class_id', 0)
        self.class_name = kwargs.get('class_name', 'invoice_number')

        # Bounding box fields (bbox_*) and the separate centre/size fields.
        self.bbox_x = kwargs.get('bbox_x', 100)
        self.bbox_y = kwargs.get('bbox_y', 100)
        self.bbox_width = kwargs.get('bbox_width', 200)
        self.bbox_height = kwargs.get('bbox_height', 50)
        self.x_center = kwargs.get('x_center', 0.5)
        self.y_center = kwargs.get('y_center', 0.5)
        self.width = kwargs.get('width', 0.3)
        self.height = kwargs.get('height', 0.1)

        # Extracted content and provenance.
        self.text_value = kwargs.get('text_value', 'INV-001')
        self.confidence = kwargs.get('confidence', 0.95)
        self.source = kwargs.get('source', 'auto')

        # Verification / override bookkeeping.
        self.is_verified = kwargs.get('is_verified', False)
        self.verified_at = kwargs.get('verified_at', None)
        self.verified_by = kwargs.get('verified_by', None)
        self.override_source = kwargs.get('override_source', None)
        self.original_annotation_id = kwargs.get('original_annotation_id', None)

        self.created_at = kwargs.get('created_at', now)
        self.updated_at = kwargs.get('updated_at', now)
|
|
|
|
|
|
class MockAnnotationHistory:
    """In-memory stand-in for one annotation history entry, used by these tests."""

    def __init__(self, **kwargs):
        """Populate the history fields from ``kwargs``, defaulting the rest.

        Keyword arguments that do not match a field below are ignored.
        """
        self.history_id = kwargs.get('history_id', uuid4())
        self.annotation_id = kwargs.get('annotation_id')
        self.document_id = kwargs.get('document_id')
        self.action = kwargs.get('action', 'override')
        self.previous_value = kwargs.get('previous_value', {})
        self.new_value = kwargs.get('new_value', {})
        self.changed_by = kwargs.get('changed_by', 'test-token')
        self.change_reason = kwargs.get('change_reason', None)
        # Naive UTC: same shape the deprecated datetime.utcnow() returned.
        self.created_at = kwargs.get(
            'created_at', datetime.now(timezone.utc).replace(tzinfo=None)
        )


class MockAdminDB:
    """Dictionary-backed mock of the admin DB used by the Phase 5 tests.

    ``documents`` and ``annotations`` map ``str(id)`` to the stored object;
    ``annotation_history`` maps ``str(annotation_id)`` to a list of
    ``MockAnnotationHistory`` entries.
    """

    def __init__(self):
        """Start with empty document, annotation and history stores."""
        self.documents = {}
        self.annotations = {}
        self.annotation_history = {}

    def get_document_by_token(self, document_id, admin_token):
        """Return the document stored under ``document_id`` if its token matches.

        Returns ``None`` when the document is unknown or its ``admin_token``
        differs from the one supplied.
        """
        doc = self.documents.get(str(document_id))
        if doc and doc.admin_token == admin_token:
            return doc
        return None

    def verify_annotation(self, annotation_id, admin_token):
        """Mark the annotation as verified by ``admin_token`` and return it.

        Returns ``None`` when no annotation is stored under ``annotation_id``.
        """
        annotation = self.annotations.get(str(annotation_id))
        if not annotation:
            return None

        annotation.is_verified = True
        # Naive UTC: same shape the deprecated datetime.utcnow() returned.
        annotation.verified_at = datetime.now(timezone.utc).replace(tzinfo=None)
        annotation.verified_by = admin_token
        return annotation

    def override_annotation(
        self,
        annotation_id,
        admin_token,
        change_reason=None,
        **updates,
    ):
        """Apply ``updates`` to an annotation and record a history entry.

        Args:
            annotation_id: ID of the annotation to change (looked up via ``str()``).
            admin_token: Recorded as ``changed_by`` on the history entry.
            change_reason: Optional free-text reason stored on the history entry.
            **updates: Attribute name/value pairs; names the annotation does
                not already have are silently skipped.

        Returns:
            The updated annotation, or ``None`` when it is not found.
        """
        annotation = self.annotations.get(str(annotation_id))
        if not annotation:
            return None

        # Best-effort update: only attributes that already exist are touched.
        applied = {key: value for key, value in updates.items()
                   if hasattr(annotation, key)}
        previous = {key: getattr(annotation, key) for key in applied}
        for key, value in applied.items():
            setattr(annotation, key, value)

        # An auto-generated annotation becomes manual once overridden, while
        # remembering where it originally came from.
        if annotation.source == "auto":
            annotation.override_source = "auto"
            annotation.source = "manual"

        # Link the history entry to the annotation's own ID (not a random one)
        # and capture what changed.
        history = MockAnnotationHistory(
            annotation_id=annotation.annotation_id,
            document_id=annotation.document_id,
            action="override",
            previous_value=previous,
            new_value=applied,
            changed_by=admin_token,
            change_reason=change_reason,
        )
        # Accumulate instead of overwriting so earlier overrides stay visible.
        # Newest entry goes first so index 0 is the latest change.
        # NOTE(review): newest-first is assumed to mirror the real DB's
        # ordering; confirm against the production implementation.
        entries = self.annotation_history.setdefault(
            str(annotation.annotation_id), []
        )
        entries.insert(0, history)

        return annotation

    def get_annotation_history(self, annotation_id):
        """Return the history list for ``annotation_id`` (empty list if none)."""
        return self.annotation_history.get(str(annotation_id), [])
|
|
|
|
|
|
@pytest.fixture
def app():
    """Build a FastAPI app backed by a mock DB holding one document and two annotations.

    The seeded document ID and both annotation IDs are exposed as strings on
    ``app.state`` (``document_id``, ``annotation_id_1``, ``annotation_id_2``).
    The admin-token and admin-DB dependencies are overridden so requests run
    against the mock DB with the token ``"test-token"``.
    """
    database = MockAdminDB()

    # Seed a single labeled document.
    document = MockAdminDocument(filename="TEST001.pdf", status="labeled")
    database.documents[str(document.document_id)] = document

    # Seed two auto-generated annotations attached to that document.
    invoice_number_ann = MockAnnotation(
        document_id=document.document_id,
        class_id=0,
        class_name="invoice_number",
        text_value="INV-001",
        source="auto",
        confidence=0.95,
    )
    amount_ann = MockAnnotation(
        document_id=document.document_id,
        class_id=6,
        class_name="amount",
        text_value="1500.00",
        source="auto",
        confidence=0.98,
    )
    for seeded in (invoice_number_ann, amount_ann):
        database.annotations[str(seeded.annotation_id)] = seeded

    application = FastAPI()

    # Expose the seeded IDs so tests can build request URLs.
    application.state.document_id = str(document.document_id)
    application.state.annotation_id_1 = str(invoice_number_ann.annotation_id)
    application.state.annotation_id_2 = str(amount_ann.annotation_id)

    # Swap the real auth / DB dependencies for test doubles.
    application.dependency_overrides[validate_admin_token] = lambda: "test-token"
    application.dependency_overrides[get_admin_db] = lambda: database

    application.include_router(create_annotation_router())
    return application
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Return a ``TestClient`` bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
class TestAnnotationVerification:
    """Exercise POST /admin/documents/{document_id}/annotations/{annotation_id}/verify."""

    def test_verify_annotation_success(self, client, app):
        """Verifying a seeded annotation returns 200 with the verification details."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1
        url = f"/admin/documents/{document_id}/annotations/{annotation_id}/verify"

        resp = client.post(url)

        assert resp.status_code == 200
        body = resp.json()
        assert body["annotation_id"] == annotation_id
        assert body["is_verified"] is True
        assert body["verified_at"] is not None
        assert body["verified_by"] == "test-token"
        assert "verified successfully" in body["message"].lower()

    def test_verify_annotation_not_found(self, client, app):
        """An unknown annotation ID under a real document yields 404."""
        document_id = app.state.document_id
        fake_annotation_id = str(uuid4())
        url = f"/admin/documents/{document_id}/annotations/{fake_annotation_id}/verify"

        resp = client.post(url)

        assert resp.status_code == 404
        detail = resp.json()["detail"]
        assert "not found" in detail.lower()

    def test_verify_annotation_document_not_found(self, client):
        """An unknown document ID yields 404."""
        fake_document_id = str(uuid4())
        fake_annotation_id = str(uuid4())
        url = f"/admin/documents/{fake_document_id}/annotations/{fake_annotation_id}/verify"

        resp = client.post(url)

        assert resp.status_code == 404
        detail = resp.json()["detail"]
        assert "not found" in detail.lower()

    def test_verify_annotation_invalid_uuid(self, client, app):
        """A malformed annotation ID yields 400."""
        document_id = app.state.document_id
        url = f"/admin/documents/{document_id}/annotations/invalid-uuid/verify"

        resp = client.post(url)

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "invalid" in detail.lower()
|
|
|
|
|
|
class TestAnnotationOverride:
    """Exercise PATCH /admin/documents/{document_id}/annotations/{annotation_id}/override."""

    def test_override_annotation_text_value(self, client, app):
        """Overriding the text value flips the source to manual and reports a history ID."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1
        url = f"/admin/documents/{document_id}/annotations/{annotation_id}/override"
        payload = {
            "text_value": "INV-001-CORRECTED",
            "reason": "OCR error correction"
        }

        resp = client.patch(url, json=payload)

        assert resp.status_code == 200
        body = resp.json()
        assert body["annotation_id"] == annotation_id
        assert body["source"] == "manual"
        assert body["override_source"] == "auto"
        assert "successfully" in body["message"].lower()
        assert "history_id" in body

    def test_override_annotation_bbox(self, client, app):
        """Overriding the bounding box succeeds and marks the annotation manual."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1
        url = f"/admin/documents/{document_id}/annotations/{annotation_id}/override"
        new_bbox = {
            "x": 110,
            "y": 205,
            "width": 195,
            "height": 48
        }
        payload = {
            "bbox": new_bbox,
            "reason": "Bbox adjustment"
        }

        resp = client.patch(url, json=payload)

        assert resp.status_code == 200
        body = resp.json()
        assert body["annotation_id"] == annotation_id
        assert body["source"] == "manual"

    def test_override_annotation_class(self, client, app):
        """Overriding the class ID and name succeeds."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1
        url = f"/admin/documents/{document_id}/annotations/{annotation_id}/override"
        payload = {
            "class_id": 1,
            "class_name": "invoice_date",
            "reason": "Wrong field classification"
        }

        resp = client.patch(url, json=payload)

        assert resp.status_code == 200
        body = resp.json()
        assert body["annotation_id"] == annotation_id

    def test_override_annotation_multiple_fields(self, client, app):
        """Text value and bounding box can be overridden in one request."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_2
        url = f"/admin/documents/{document_id}/annotations/{annotation_id}/override"
        new_bbox = {
            "x": 120,
            "y": 210,
            "width": 180,
            "height": 45
        }
        payload = {
            "text_value": "1550.00",
            "bbox": new_bbox,
            "reason": "Multiple corrections"
        }

        resp = client.patch(url, json=payload)

        assert resp.status_code == 200
        body = resp.json()
        assert body["annotation_id"] == annotation_id

    def test_override_annotation_no_updates(self, client, app):
        """An empty override payload is rejected with 400."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1
        url = f"/admin/documents/{document_id}/annotations/{annotation_id}/override"

        resp = client.patch(url, json={})

        assert resp.status_code == 400
        detail = resp.json()["detail"]
        assert "no updates" in detail.lower()

    def test_override_annotation_not_found(self, client, app):
        """An unknown annotation ID under a real document yields 404."""
        document_id = app.state.document_id
        fake_annotation_id = str(uuid4())
        url = f"/admin/documents/{document_id}/annotations/{fake_annotation_id}/override"
        payload = {
            "text_value": "TEST"
        }

        resp = client.patch(url, json=payload)

        assert resp.status_code == 404
        detail = resp.json()["detail"]
        assert "not found" in detail.lower()

    def test_override_annotation_document_not_found(self, client):
        """An unknown document ID yields 404."""
        fake_document_id = str(uuid4())
        fake_annotation_id = str(uuid4())
        url = f"/admin/documents/{fake_document_id}/annotations/{fake_annotation_id}/override"
        payload = {
            "text_value": "TEST"
        }

        resp = client.patch(url, json=payload)

        assert resp.status_code == 404
        detail = resp.json()["detail"]
        assert "not found" in detail.lower()

    def test_override_annotation_creates_history(self, client, app):
        """A successful override reports a non-empty history ID."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1
        url = f"/admin/documents/{document_id}/annotations/{annotation_id}/override"
        payload = {
            "text_value": "INV-CORRECTED",
            "reason": "Test history creation"
        }

        resp = client.patch(url, json=payload)

        assert resp.status_code == 200
        body = resp.json()
        # The response must carry a history ID, and it must not be blank.
        assert "history_id" in body
        assert body["history_id"] != ""

    def test_override_annotation_with_reason(self, client, app):
        """Supplying a change reason alongside the override succeeds."""
        document_id = app.state.document_id
        annotation_id = app.state.annotation_id_1
        url = f"/admin/documents/{document_id}/annotations/{annotation_id}/override"

        change_reason = "Correcting OCR misread"
        payload = {
            "text_value": "INV-002",
            "reason": change_reason
        }

        resp = client.patch(url, json=payload)

        assert resp.status_code == 200
        # The reason is kept in history rather than echoed back, so only the
        # annotation ID is checked here.
        body = resp.json()
        assert body["annotation_id"] == annotation_id
|