# Source: invoice-master-poc-v2/tests/web/test_admin_routes_enhanced.py
# Last commit: a516de4320 "WIP" by Yaojia Wang (2026-02-01 00:08:40 +01:00)
# File info: 357 lines, 12 KiB, Python

"""
Tests for Enhanced Admin Document Routes (Phase 3).
"""
import pytest
from datetime import datetime
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.config import StorageConfig
from inference.web.core.auth import validate_admin_token, get_admin_db
class MockAdminDocument:
    """Lightweight stand-in for an ``AdminDocument`` row.

    Every attribute listed below is set on the instance. A value passed as a
    keyword argument always wins (even ``None``); anything omitted falls back
    to a default describing a pending, UI-uploaded, one-page PDF invoice.
    Unrecognised keyword arguments are ignored.
    """

    def __init__(self, **kwargs):
        # Built per instance so every mock gets its own fresh UUID and
        # timestamps.
        defaults = {
            'document_id': uuid4(),
            'admin_token': 'test-token',
            'filename': 'test.pdf',
            'file_size': 100000,
            'content_type': 'application/pdf',
            'page_count': 1,
            'status': 'pending',
            'auto_label_status': None,
            'auto_label_error': None,
            'upload_source': 'ui',
            'batch_id': None,
            'csv_field_values': None,
            'annotation_lock_until': None,
            'category': 'invoice',
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for attr_name, fallback in defaults.items():
            setattr(self, attr_name, kwargs.get(attr_name, fallback))
class MockAnnotation:
    """Lightweight stand-in for an ``AdminAnnotation`` row.

    Carries box geometry in two forms: ``bbox_x``/``bbox_y``/``bbox_width``/
    ``bbox_height`` and ``x_center``/``y_center``/``width``/``height``.
    Keyword arguments override the defaults; ``document_id`` has no default
    and stays ``None`` unless supplied. Unrecognised keywords are ignored.
    """

    def __init__(self, **kwargs):
        pick = kwargs.get

        # Identity and ownership.
        self.annotation_id = pick('annotation_id', uuid4())
        self.document_id = pick('document_id')
        self.page_number = pick('page_number', 1)

        # Field classification.
        self.class_id = pick('class_id', 0)
        self.class_name = pick('class_name', 'invoice_number')

        # Box geometry, first form.
        self.bbox_x = pick('bbox_x', 100.0)
        self.bbox_y = pick('bbox_y', 100.0)
        self.bbox_width = pick('bbox_width', 200.0)
        self.bbox_height = pick('bbox_height', 50.0)

        # Box geometry, center/size form.
        self.x_center = pick('x_center', 0.5)
        self.y_center = pick('y_center', 0.5)
        self.width = pick('width', 0.3)
        self.height = pick('height', 0.1)

        # Extracted content and provenance.
        self.text_value = pick('text_value', 'INV-001')
        self.confidence = pick('confidence', 0.95)
        self.source = pick('source', 'manual')
        self.created_at = pick('created_at', datetime.utcnow())
class MockAdminDB:
    """In-memory stand-in for ``AdminDB`` used by the document route tests.

    ``documents`` and ``annotations`` are plain dicts keyed by
    ``str(document_id)``; each ``annotations`` value is a list of annotation
    objects. Only a subset of the real query methods is implemented.
    """

    def __init__(self):
        self.documents = {}
        self.annotations = {}

    def get_documents_by_token(
        self,
        admin_token=None,
        status=None,
        upload_source=None,
        has_annotations=None,
        auto_label_status=None,
        batch_id=None,
        category=None,
        limit=20,
        offset=0
    ):
        """Return one page of documents matching every supplied filter.

        String filters left falsy (``None`` or ``""``) are not applied.
        ``has_annotations`` is tri-state: truthy keeps documents with at least
        one annotation, falsy-but-not-``None`` keeps documents with none, and
        ``None`` disables the filter.

        Returns:
            ``(page, total)`` where ``page`` is ``matches[offset:offset + limit]``
            and ``total`` is the number of matches before pagination.
        """
        docs = list(self.documents.values())

        # Scope to the caller's token like get_document_by_token and
        # count_documents_by_status do, so the mock cannot mask a route that
        # leaks another token's documents.
        if admin_token:
            docs = [d for d in docs if d.admin_token == admin_token]
        if status:
            docs = [d for d in docs if d.status == status]
        if upload_source:
            docs = [d for d in docs if d.upload_source == upload_source]
        if has_annotations is not None:
            want_annotated = bool(has_annotations)
            docs = [
                d for d in docs
                if bool(self.annotations.get(str(d.document_id))) == want_annotated
            ]
        if auto_label_status:
            docs = [d for d in docs if d.auto_label_status == auto_label_status]
        if batch_id:
            # Skip un-batched documents explicitly: str(None) == "None" would
            # otherwise match a batch_id filter of "None".
            docs = [
                d for d in docs
                if d.batch_id is not None and str(d.batch_id) == str(batch_id)
            ]
        if category:
            docs = [d for d in docs if d.category == category]

        total = len(docs)
        return docs[offset:offset + limit], total

    def get_annotations_for_document(self, document_id):
        """Return the stored annotation list for ``document_id`` (``[]`` if none).

        The id is normalised with ``str()``, so ``UUID`` and string ids both
        work. The stored list itself is returned, not a copy.
        """
        return self.annotations.get(str(document_id), [])

    def count_documents_by_status(self, admin_token):
        """Return a ``{status: count}`` dict for documents owned by ``admin_token``."""
        counts = {}
        for doc in self.documents.values():
            if doc.admin_token == admin_token:
                counts[doc.status] = counts.get(doc.status, 0) + 1
        return counts

    def get_document_by_token(self, document_id, admin_token):
        """Return the document if ``admin_token`` owns it, else ``None``.

        ``document_id`` is normalised with ``str()`` to match the keys of
        ``documents`` (as ``get_annotations_for_document`` already does), so
        a ``UUID`` works as well as its string form.
        """
        doc = self.documents.get(str(document_id))
        if doc is not None and doc.admin_token == admin_token:
            return doc
        return None

    def get_document_training_tasks(self, document_id):
        """Get training tasks that used this document."""
        return []  # No training history in this test

    def get_training_task(self, task_id):
        """Get training task by ID."""
        return None  # No training tasks in this test
@pytest.fixture
def app():
    """Build a FastAPI app serving the documents router against a MockAdminDB.

    Seeded documents:
    - INV001.pdf: labeled, ui upload, not auto-labeled, no batch, 1 annotation
    - INV002.pdf: labeled, api upload, auto-label completed, batched,
      2 annotations
    - INV003.pdf: pending, ui upload, not auto-labeled, no batch,
      no annotations

    The admin-token and DB dependencies are overridden, so requests need no
    real credentials or database.
    """
    mock_db = MockAdminDB()

    labeled_ui_doc = MockAdminDocument(
        filename="INV001.pdf",
        status="labeled",
        upload_source="ui",
        auto_label_status=None,
        batch_id=None,
    )
    labeled_api_doc = MockAdminDocument(
        filename="INV002.pdf",
        status="labeled",
        upload_source="api",
        auto_label_status="completed",
        batch_id=uuid4(),
    )
    pending_ui_doc = MockAdminDocument(
        filename="INV003.pdf",
        status="pending",
        upload_source="ui",
        auto_label_status=None,  # not auto-labeled yet
        batch_id=None,
    )
    for seeded in (labeled_ui_doc, labeled_api_doc, pending_ui_doc):
        mock_db.documents[str(seeded.document_id)] = seeded

    # INV001 gets one annotation and INV002 two; INV003 stays unannotated.
    mock_db.annotations[str(labeled_ui_doc.document_id)] = [
        MockAnnotation(
            document_id=labeled_ui_doc.document_id,
            class_name="invoice_number",
            text_value="INV-001",
        ),
    ]
    mock_db.annotations[str(labeled_api_doc.document_id)] = [
        MockAnnotation(
            document_id=labeled_api_doc.document_id,
            class_id=6,
            class_name="amount",
            text_value="1500.00",
        ),
        MockAnnotation(
            document_id=labeled_api_doc.document_id,
            class_id=1,
            class_name="invoice_date",
            text_value="2024-01-15",
        ),
    ]

    test_app = FastAPI()
    # Replace real auth and DB access for every request.
    test_app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    test_app.dependency_overrides[get_admin_db] = lambda: mock_db
    test_app.include_router(create_documents_router(StorageConfig()))
    return test_app
@pytest.fixture
def client(app):
    """Return a ``TestClient`` wrapping the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
class TestEnhancedDocumentList:
    """Tests for the filters and new fields of the document list endpoint.

    Expected results follow from the three documents seeded by the ``app``
    fixture: INV001.pdf (labeled, ui, 1 annotation), INV002.pdf (labeled,
    api, batched, auto-label completed, 2 annotations) and INV003.pdf
    (pending, ui, no annotations).
    """

    @staticmethod
    def _filenames(data):
        """Return the set of filenames in a list-endpoint response body."""
        return {doc["filename"] for doc in data["documents"]}

    def test_list_documents_filter_by_upload_source_ui(self, client):
        """Test filtering documents by upload_source=ui."""
        response = client.get("/admin/documents?upload_source=ui")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2
        # Every seeded filename starts with "INV", so a prefix check proves
        # nothing; assert on the filtered field and the exact documents.
        assert all(doc["upload_source"] == "ui" for doc in data["documents"])
        assert self._filenames(data) == {"INV001.pdf", "INV003.pdf"}

    def test_list_documents_filter_by_upload_source_api(self, client):
        """Test filtering documents by upload_source=api."""
        response = client.get("/admin/documents?upload_source=api")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_has_annotations_true(self, client):
        """Test filtering documents with annotations."""
        response = client.get("/admin/documents?has_annotations=true")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2
        assert self._filenames(data) == {"INV001.pdf", "INV002.pdf"}

    def test_list_documents_filter_by_has_annotations_false(self, client):
        """Test filtering documents without annotations."""
        response = client.get("/admin/documents?has_annotations=false")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert self._filenames(data) == {"INV003.pdf"}

    def test_list_documents_filter_by_auto_label_status(self, client):
        """Test filtering by auto_label_status."""
        response = client.get("/admin/documents?auto_label_status=completed")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_batch_id(self, client):
        """Test filtering by batch_id."""
        # Get a batch_id from the test data (only the api upload is batched).
        response_all = client.get("/admin/documents?upload_source=api")
        assert response_all.status_code == 200
        batch_id = response_all.json()["documents"][0]["batch_id"]
        # Without this, a missing batch_id would silently query "batch_id=None".
        assert batch_id is not None

        response = client.get(f"/admin/documents?batch_id={batch_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_combined_filters(self, client):
        """Test combining multiple filters."""
        response = client.get(
            "/admin/documents?status=labeled&upload_source=api"
        )
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_document_item_includes_new_fields(self, client):
        """Test DocumentItem includes new Phase 2/3 fields."""
        response = client.get("/admin/documents?upload_source=api")
        assert response.status_code == 200
        data = response.json()
        doc = data["documents"][0]
        # Check new fields exist
        assert "upload_source" in doc
        assert doc["upload_source"] == "api"
        assert "batch_id" in doc
        assert doc["batch_id"] is not None
        assert "can_annotate" in doc
        assert isinstance(doc["can_annotate"], bool)
class TestEnhancedDocumentDetail:
    """Tests for the enhanced document detail endpoint."""

    # Annotation counts seeded per document by the ``app`` fixture.
    EXPECTED_ANNOTATION_COUNTS = {"INV001.pdf": 1, "INV002.pdf": 2}

    @staticmethod
    def _first_document(client, query):
        """Return the first list item for ``query``; fail clearly if there is none."""
        response = client.get(f"/admin/documents?{query}")
        assert response.status_code == 200
        documents = response.json()["documents"]
        assert documents, f"no documents matched {query!r}"
        return documents[0]

    def test_document_detail_includes_new_fields(self, client):
        """Test DocumentDetailResponse includes new Phase 2/3 fields."""
        document_id = self._first_document(client, "upload_source=api")["document_id"]

        response = client.get(f"/admin/documents/{document_id}")
        assert response.status_code == 200
        doc = response.json()

        # Check new fields exist
        assert "upload_source" in doc
        assert doc["upload_source"] == "api"
        assert "batch_id" in doc
        assert doc["batch_id"] is not None
        assert "can_annotate" in doc
        assert isinstance(doc["can_annotate"], bool)
        assert "csv_field_values" in doc
        assert "annotation_lock_until" in doc

    def test_document_detail_ui_upload_defaults(self, client):
        """Test UI-uploaded document has correct defaults."""
        document_id = self._first_document(client, "upload_source=ui")["document_id"]

        response = client.get(f"/admin/documents/{document_id}")
        assert response.status_code == 200
        doc = response.json()

        # UI uploads should have these defaults
        assert doc["upload_source"] == "ui"
        assert doc["batch_id"] is None
        assert doc["csv_field_values"] is None
        assert doc["can_annotate"] is True
        assert doc["annotation_lock_until"] is None

    def test_document_detail_with_annotations(self, client):
        """Test document detail includes the document's seeded annotations."""
        listed = self._first_document(client, "has_annotations=true")

        response = client.get(f"/admin/documents/{listed['document_id']}")
        assert response.status_code == 200
        doc = response.json()

        # Should return every annotation seeded for this document, not just some.
        assert "annotations" in doc
        expected = self.EXPECTED_ANNOTATION_COUNTS[listed["filename"]]
        assert len(doc["annotations"]) == expected