Files
invoice-master-poc-v2/tests/web/test_admin_routes_enhanced.py
Yaojia Wang a564ac9d70 WIP
2026-02-01 18:51:54 +01:00

388 lines
14 KiB
Python

"""
Tests for Enhanced Admin Document Routes (Phase 3).
"""
from datetime import datetime, timezone
from uuid import uuid4

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.config import StorageConfig
from inference.web.core.auth import (
    validate_admin_token,
    get_document_repository,
    get_annotation_repository,
    get_training_task_repository,
)
class MockAdminDocument:
    """In-memory stand-in for an AdminDocument record.

    Every attribute can be overridden with a keyword argument of the same
    name; attributes that are not supplied receive the defaults assigned in
    ``__init__``. Keyword arguments that do not match an attribute are
    ignored.
    """

    def __init__(self, **kwargs):
        """Populate the document attributes from ``kwargs`` or defaults."""
        # One clock read shared by created_at/updated_at so a default-built
        # document has identical timestamps. Built from an aware "now"
        # because datetime.utcnow() is deprecated since Python 3.12; tzinfo
        # is stripped so the value stays a naive UTC datetime as before.
        now = datetime.now(timezone.utc).replace(tzinfo=None)

        self.document_id = kwargs.get('document_id', uuid4())
        self.admin_token = kwargs.get('admin_token', 'test-token')
        self.filename = kwargs.get('filename', 'test.pdf')
        self.file_size = kwargs.get('file_size', 100000)
        self.content_type = kwargs.get('content_type', 'application/pdf')
        self.page_count = kwargs.get('page_count', 1)
        self.status = kwargs.get('status', 'pending')
        self.auto_label_status = kwargs.get('auto_label_status', None)
        self.auto_label_error = kwargs.get('auto_label_error', None)
        self.upload_source = kwargs.get('upload_source', 'ui')
        self.batch_id = kwargs.get('batch_id', None)
        self.csv_field_values = kwargs.get('csv_field_values', None)
        self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
        self.category = kwargs.get('category', 'invoice')
        self.created_at = kwargs.get('created_at', now)
        self.updated_at = kwargs.get('updated_at', now)
class MockAnnotation:
    """In-memory stand-in for an AdminAnnotation record.

    Every attribute can be overridden with a keyword argument of the same
    name; attributes that are not supplied receive the defaults assigned in
    ``__init__`` (``document_id`` defaults to ``None``). Keyword arguments
    that do not match an attribute are ignored.
    """

    def __init__(self, **kwargs):
        """Populate the annotation attributes from ``kwargs`` or defaults."""
        self.annotation_id = kwargs.get('annotation_id', uuid4())
        self.document_id = kwargs.get('document_id')
        self.page_number = kwargs.get('page_number', 1)
        self.class_id = kwargs.get('class_id', 0)
        self.class_name = kwargs.get('class_name', 'invoice_number')
        self.bbox_x = kwargs.get('bbox_x', 100.0)
        self.bbox_y = kwargs.get('bbox_y', 100.0)
        self.bbox_width = kwargs.get('bbox_width', 200.0)
        self.bbox_height = kwargs.get('bbox_height', 50.0)
        self.x_center = kwargs.get('x_center', 0.5)
        self.y_center = kwargs.get('y_center', 0.5)
        self.width = kwargs.get('width', 0.3)
        self.height = kwargs.get('height', 0.1)
        self.text_value = kwargs.get('text_value', 'INV-001')
        self.confidence = kwargs.get('confidence', 0.95)
        self.source = kwargs.get('source', 'manual')
        # datetime.utcnow() is deprecated since Python 3.12; take an aware
        # "now" and strip tzinfo so the default stays a naive UTC datetime.
        self.created_at = kwargs.get(
            'created_at',
            datetime.now(timezone.utc).replace(tzinfo=None),
        )
class MockDocumentRepository:
    """In-memory DocumentRepository substitute for the route tests.

    ``documents`` maps ``str(document_id)`` to document objects.
    ``annotations`` maps ``str(document_id)`` to a list of annotations and is
    consulted only by the ``has_annotations`` filter of ``get_paginated``.
    """

    def __init__(self):
        """Start with no documents and no annotations."""
        self.documents = {}
        self.annotations = {}  # Shared reference for filtering

    def get_paginated(
        self,
        admin_token=None,
        status=None,
        upload_source=None,
        has_annotations=None,
        auto_label_status=None,
        batch_id=None,
        category=None,
        limit=20,
        offset=0,
    ):
        """Return one page of documents matching every supplied filter.

        A filter left at ``None`` is not applied. ``has_annotations`` keeps
        documents with at least one annotation when true and documents with
        none when false. ``batch_id`` is compared by string form.

        Returns:
            A ``(page, total)`` tuple: ``page`` is the
            ``[offset:offset + limit]`` slice of the matching documents and
            ``total`` is the number of matches before slicing.
        """
        docs = list(self.documents.values())

        # Scope by token the same way count_by_status/get_by_token do:
        # None means "do not filter on the token".
        if admin_token is not None:
            docs = [d for d in docs if d.admin_token == admin_token]
        if status:
            docs = [d for d in docs if d.status == status]
        if upload_source:
            docs = [d for d in docs if d.upload_source == upload_source]
        if has_annotations is not None:
            wanted = bool(has_annotations)
            # A missing key and an empty list both count as "no annotations".
            docs = [
                d for d in docs
                if bool(self.annotations.get(str(d.document_id))) == wanted
            ]
        if auto_label_status:
            docs = [d for d in docs if d.auto_label_status == auto_label_status]
        if batch_id:
            docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
        if category:
            docs = [d for d in docs if d.category == category]

        total = len(docs)
        return docs[offset:offset + limit], total

    def count_by_status(self, admin_token=None):
        """Return a ``{status: count}`` dict, optionally limited to one token."""
        counts = {}
        for doc in self.documents.values():
            if admin_token is None or doc.admin_token == admin_token:
                counts[doc.status] = counts.get(doc.status, 0) + 1
        return counts

    def get(self, document_id):
        """Return the document stored under ``document_id``, or ``None``.

        The ID is normalised with ``str()`` because ``documents`` is keyed by
        the string form, so a UUID and its string both resolve.
        """
        return self.documents.get(str(document_id))

    def get_by_token(self, document_id, admin_token=None):
        """Return the document if it exists and its token matches.

        ``admin_token=None`` skips the token check. Returns ``None`` when the
        document is missing or belongs to a different token.
        """
        doc = self.documents.get(str(document_id))
        if doc and (admin_token is None or doc.admin_token == admin_token):
            return doc
        return None
class MockAnnotationRepository:
    """In-memory AnnotationRepository substitute for the route tests.

    ``annotations`` maps ``str(document_id)`` to a list of annotation objects.
    """

    def __init__(self):
        """Start with no annotations."""
        self.annotations = {}

    def get_for_document(self, document_id, page_number=None):
        """Return the annotations stored for ``document_id``.

        Args:
            document_id: Document identifier; looked up by its string form.
            page_number: When given, only annotations whose ``page_number``
                attribute equals it are returned. ``None`` returns them all.

        Returns:
            The stored list when no page is requested, a new filtered list
            otherwise, or an empty list for an unknown document.
        """
        found = self.annotations.get(str(document_id), [])
        if page_number is None:
            return found
        # Previously the argument was accepted but silently ignored.
        return [ann for ann in found if ann.page_number == page_number]
class MockTrainingTaskRepository:
    """In-memory TrainingTaskRepository substitute for the route tests.

    ``training_tasks`` maps ``str(task_id)`` to a task object and
    ``training_links`` maps ``str(document_id)`` to the tasks linked to it.
    """

    def __init__(self):
        """Start with no tasks and no document links."""
        self.training_links = {}
        self.training_tasks = {}

    def get_document_training_tasks(self, document_id):
        """Return the tasks linked to ``document_id``, or an empty list."""
        key = str(document_id)
        if key in self.training_links:
            return self.training_links[key]
        return []

    def get(self, task_id):
        """Return the task stored under ``str(task_id)``, or ``None``."""
        key = str(task_id)
        return self.training_tasks[key] if key in self.training_tasks else None
@pytest.fixture
def app():
    """Build a FastAPI app serving the documents router over mock repositories.

    Seeds three documents (INV001/INV002 labeled, INV003 pending) and
    annotations for the first two, then overrides the auth and repository
    dependencies so requests are served from the mocks.
    """
    document_repo = MockDocumentRepository()
    annotation_repo = MockAnnotationRepository()
    training_task_repo = MockTrainingTaskRepository()

    # Seed documents
    labeled_ui_doc = MockAdminDocument(
        filename="INV001.pdf",
        status="labeled",
        upload_source="ui",
        auto_label_status=None,
        batch_id=None,
    )
    labeled_api_doc = MockAdminDocument(
        filename="INV002.pdf",
        status="labeled",
        upload_source="api",
        auto_label_status="completed",
        batch_id=uuid4(),
    )
    pending_ui_doc = MockAdminDocument(
        filename="INV003.pdf",
        status="pending",
        upload_source="ui",
        auto_label_status=None,  # Not auto-labeled yet
        batch_id=None,
    )
    # Insertion order matters: tests read documents[0] from filtered lists.
    for document in (labeled_ui_doc, labeled_api_doc, pending_ui_doc):
        document_repo.documents[str(document.document_id)] = document

    # Only the two labeled documents carry annotations
    annotation_repo.annotations[str(labeled_ui_doc.document_id)] = [
        MockAnnotation(
            document_id=labeled_ui_doc.document_id,
            class_name="invoice_number",
            text_value="INV-001",
        ),
    ]
    annotation_repo.annotations[str(labeled_api_doc.document_id)] = [
        MockAnnotation(
            document_id=labeled_api_doc.document_id,
            class_id=6,
            class_name="amount",
            text_value="1500.00",
        ),
        MockAnnotation(
            document_id=labeled_api_doc.document_id,
            class_id=1,
            class_name="invoice_date",
            text_value="2024-01-15",
        ),
    ]

    # The document repo reads the same dict for its has_annotations filter
    document_repo.annotations = annotation_repo.annotations

    test_app = FastAPI()
    test_app.dependency_overrides.update({
        validate_admin_token: lambda: "test-token",
        get_document_repository: lambda: document_repo,
        get_annotation_repository: lambda: annotation_repo,
        get_training_task_repository: lambda: training_task_repo,
    })
    test_app.include_router(create_documents_router(StorageConfig()))
    return test_app
@pytest.fixture
def client(app):
    """Return a TestClient bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
class TestEnhancedDocumentList:
    """Tests for the enhanced document list endpoint.

    Expected filenames refer to the three documents seeded by the ``app``
    fixture: INV001.pdf (ui, labeled, annotated), INV002.pdf (api, labeled,
    annotated, batched) and INV003.pdf (ui, pending, no annotations).
    """

    def test_list_documents_filter_by_upload_source_ui(self, client):
        """Test filtering documents by upload_source=ui."""
        response = client.get("/admin/documents?upload_source=ui")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2
        # Pin the exact documents; a prefix check would pass for any result
        # because every seeded filename starts with "INV".
        filenames = {doc["filename"] for doc in data["documents"]}
        assert filenames == {"INV001.pdf", "INV003.pdf"}

    def test_list_documents_filter_by_upload_source_api(self, client):
        """Test filtering documents by upload_source=api."""
        response = client.get("/admin/documents?upload_source=api")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_has_annotations_true(self, client):
        """Test filtering documents with annotations."""
        response = client.get("/admin/documents?has_annotations=true")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2
        filenames = {doc["filename"] for doc in data["documents"]}
        assert filenames == {"INV001.pdf", "INV002.pdf"}

    def test_list_documents_filter_by_has_annotations_false(self, client):
        """Test filtering documents without annotations."""
        response = client.get("/admin/documents?has_annotations=false")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV003.pdf"

    def test_list_documents_filter_by_auto_label_status(self, client):
        """Test filtering by auto_label_status."""
        response = client.get("/admin/documents?auto_label_status=completed")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_batch_id(self, client):
        """Test filtering by batch_id."""
        # Get a batch_id from the test data
        response_all = client.get("/admin/documents?upload_source=api")
        assert response_all.status_code == 200
        batch_id = response_all.json()["documents"][0]["batch_id"]
        assert batch_id is not None

        response = client.get(f"/admin/documents?batch_id={batch_id}")
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["batch_id"] == batch_id

    def test_list_documents_combined_filters(self, client):
        """Test combining multiple filters."""
        response = client.get(
            "/admin/documents?status=labeled&upload_source=api"
        )
        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_document_item_includes_new_fields(self, client):
        """Test DocumentItem includes new Phase 2/3 fields."""
        response = client.get("/admin/documents?upload_source=api")
        assert response.status_code == 200
        data = response.json()
        doc = data["documents"][0]

        # Check new fields exist
        assert "upload_source" in doc
        assert doc["upload_source"] == "api"
        assert "batch_id" in doc
        assert doc["batch_id"] is not None
        assert "can_annotate" in doc
        assert isinstance(doc["can_annotate"], bool)
class TestEnhancedDocumentDetail:
    """Tests for the enhanced document detail endpoint."""

    @staticmethod
    def _first_document_detail(client, list_url):
        """Fetch ``list_url`` and return the detail JSON of its first document.

        Asserts that both the list request and the detail request answer 200.
        """
        listing = client.get(list_url)
        assert listing.status_code == 200
        document_id = listing.json()["documents"][0]["document_id"]

        detail = client.get(f"/admin/documents/{document_id}")
        assert detail.status_code == 200
        return detail.json()

    def test_document_detail_includes_new_fields(self, client, app):
        """Test DocumentDetailResponse includes new Phase 2/3 fields."""
        detail = self._first_document_detail(
            client, "/admin/documents?upload_source=api"
        )

        # New fields must be present, with the API-upload values
        assert "upload_source" in detail
        assert detail["upload_source"] == "api"
        assert "batch_id" in detail
        assert detail["batch_id"] is not None
        assert "can_annotate" in detail
        assert isinstance(detail["can_annotate"], bool)
        assert "csv_field_values" in detail
        assert "annotation_lock_until" in detail

    def test_document_detail_ui_upload_defaults(self, client, app):
        """Test UI-uploaded document has correct defaults."""
        detail = self._first_document_detail(
            client, "/admin/documents?upload_source=ui"
        )

        # Defaults expected for a UI upload
        assert detail["upload_source"] == "ui"
        assert detail["batch_id"] is None
        assert detail["csv_field_values"] is None
        assert detail["can_annotate"] is True
        assert detail["annotation_lock_until"] is None

    def test_document_detail_with_annotations(self, client, app):
        """Test document detail includes annotations."""
        detail = self._first_document_detail(
            client, "/admin/documents?has_annotations=true"
        )

        # The annotated document must expose a non-empty annotation list
        assert "annotations" in detail
        assert len(detail["annotations"]) > 0