"""
Tests for Enhanced Admin Document Routes (Phase 3).
"""
|
|
|
|
import pytest
|
|
from datetime import datetime
|
|
from uuid import uuid4
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from inference.web.api.v1.admin.documents import create_documents_router
|
|
from inference.web.config import StorageConfig
|
|
from inference.web.core.auth import (
|
|
validate_admin_token,
|
|
get_document_repository,
|
|
get_annotation_repository,
|
|
get_training_task_repository,
|
|
)
|
|
|
|
|
|
class MockAdminDocument:
    """Lightweight stand-in for an AdminDocument record.

    Every field can be supplied as a keyword argument. Fields that are not
    supplied receive the fallback values listed in ``__init__``; keyword
    arguments that do not name a known field are ignored.
    """

    def __init__(self, **kwargs):
        # The fallback table is rebuilt on every call, so each instance gets
        # a fresh random id and fresh timestamps unless the caller overrides
        # them. created_at and updated_at come from two separate clock reads.
        fallbacks = {
            'document_id': uuid4(),
            'admin_token': 'test-token',
            'filename': 'test.pdf',
            'file_size': 100000,
            'content_type': 'application/pdf',
            'page_count': 1,
            'status': 'pending',
            'auto_label_status': None,
            'auto_label_error': None,
            'upload_source': 'ui',
            'batch_id': None,
            'csv_field_values': None,
            'annotation_lock_until': None,
            'category': 'invoice',
            'created_at': datetime.utcnow(),
            'updated_at': datetime.utcnow(),
        }
        for field_name, fallback in fallbacks.items():
            # An explicitly passed value (even None) wins over the fallback.
            setattr(self, field_name, kwargs.get(field_name, fallback))
|
|
|
|
|
|
class MockAnnotation:
    """Lightweight stand-in for an AdminAnnotation record.

    Fields are taken from keyword arguments; anything not supplied falls
    back to the defaults in ``__init__`` (``document_id`` defaults to None).
    Keyword arguments that do not name a known field are ignored.
    """

    def __init__(self, **kwargs):
        # Defaults are created per instance so the id and timestamp are fresh.
        values = {
            'annotation_id': uuid4(),
            'document_id': None,
            'page_number': 1,
            'class_id': 0,
            'class_name': 'invoice_number',
            'bbox_x': 100.0,
            'bbox_y': 100.0,
            'bbox_width': 200.0,
            'bbox_height': 50.0,
            'x_center': 0.5,
            'y_center': 0.5,
            'width': 0.3,
            'height': 0.1,
            'text_value': 'INV-001',
            'confidence': 0.95,
            'source': 'manual',
            'created_at': datetime.utcnow(),
        }
        # Overlay only the caller-supplied values that name a known field.
        overrides = {key: val for key, val in kwargs.items() if key in values}
        values.update(overrides)
        self.__dict__.update(values)
|
|
|
|
|
|
class MockDocumentRepository:
    """In-memory stand-in for DocumentRepository used by the route tests.

    ``documents`` maps a document key to a document object and
    ``annotations`` maps ``str(document_id)`` to a list of annotations.
    Both are plain dicts so tests can seed them directly.
    """

    def __init__(self):
        self.documents = {}
        self.annotations = {}  # Shared reference for filtering

    def _lookup(self, document_id):
        """Return the stored document for ``document_id`` or None.

        The key is tried as given first and then as ``str(document_id)``,
        so a UUID object finds a document stored under its string form.
        """
        doc = self.documents.get(document_id)
        if doc is None:
            doc = self.documents.get(str(document_id))
        return doc

    def get_paginated(
        self,
        admin_token=None,
        status=None,
        upload_source=None,
        has_annotations=None,
        auto_label_status=None,
        batch_id=None,
        category=None,
        limit=20,
        offset=0
    ):
        """Return ``(page, total)`` for the documents matching all filters.

        A filter left at None (or another falsy value for the string
        filters) is not applied. ``has_annotations`` is tri-state: None
        disables it, a truthy value keeps documents with at least one
        annotation, a falsy value keeps documents with none. ``total`` is
        the match count before ``offset``/``limit`` slicing.
        """
        docs = list(self.documents.values())

        # Apply the token filter the same way count_by_status and
        # get_by_token do: None means "any token".
        if admin_token is not None:
            docs = [d for d in docs if d.admin_token == admin_token]
        if status:
            docs = [d for d in docs if d.status == status]
        if upload_source:
            docs = [d for d in docs if d.upload_source == upload_source]
        if has_annotations is not None:
            want_annotated = bool(has_annotations)
            docs = [
                d for d in docs
                if (len(self.annotations.get(str(d.document_id), [])) > 0)
                == want_annotated
            ]
        if auto_label_status:
            docs = [d for d in docs if d.auto_label_status == auto_label_status]
        if batch_id:
            # Compare as strings so UUID objects and their text form match.
            docs = [d for d in docs if str(d.batch_id) == str(batch_id)]
        if category:
            docs = [d for d in docs if d.category == category]

        total = len(docs)
        return docs[offset:offset + limit], total

    def count_by_status(self, admin_token=None):
        """Return a ``{status: count}`` dict, optionally limited to one token."""
        counts = {}
        for doc in self.documents.values():
            if admin_token is None or doc.admin_token == admin_token:
                counts[doc.status] = counts.get(doc.status, 0) + 1
        return counts

    def get(self, document_id):
        """Return the document stored for ``document_id``, or None."""
        return self._lookup(document_id)

    def get_by_token(self, document_id, admin_token=None):
        """Return the document only if it exists and matches ``admin_token``.

        ``admin_token=None`` skips the token check.
        """
        doc = self._lookup(document_id)
        if doc and (admin_token is None or doc.admin_token == admin_token):
            return doc
        return None
|
|
|
|
|
|
class MockAnnotationRepository:
    """In-memory stand-in for AnnotationRepository used by the route tests.

    ``annotations`` maps ``str(document_id)`` to a list of annotation
    objects and can be seeded directly by tests.
    """

    def __init__(self):
        self.annotations = {}

    def get_for_document(self, document_id, page_number=None):
        """Return the annotations stored for ``document_id``.

        With ``page_number=None`` the stored list itself is returned (an
        empty list when the document is unknown). When a page number is
        given, a new list is returned holding only the annotations whose
        ``page_number`` attribute equals it; objects without that attribute
        are left out rather than raising.
        """
        found = self.annotations.get(str(document_id), [])
        if page_number is None:
            return found
        return [
            ann for ann in found
            if getattr(ann, 'page_number', None) == page_number
        ]
|
|
|
|
|
|
class MockTrainingTaskRepository:
    """In-memory stand-in for TrainingTaskRepository used by the route tests.

    ``training_tasks`` maps ``str(task_id)`` to a task object and
    ``training_links`` maps ``str(document_id)`` to the tasks linked to
    that document.
    """

    def __init__(self):
        self.training_tasks, self.training_links = {}, {}

    def get_document_training_tasks(self, document_id):
        """Return the tasks linked to ``document_id`` (empty list if none)."""
        link_key = str(document_id)
        if link_key in self.training_links:
            return self.training_links[link_key]
        return []

    def get(self, task_id):
        """Return the task stored under ``str(task_id)``, or None."""
        task_key = str(task_id)
        if task_key not in self.training_tasks:
            return None
        return self.training_tasks[task_key]
|
|
|
|
|
|
@pytest.fixture
def app():
    """Build a FastAPI app whose document routes run against mock repositories.

    Three documents are seeded (two uploaded via "ui", one via "api" with a
    batch id); the first two have annotations. The document repository
    shares the annotation repository's dict so its ``has_annotations``
    filter sees the same data. Token validation and the three repository
    dependencies are overridden before the documents router is mounted.
    """
    test_app = FastAPI()

    # Create mock repositories
    document_repo = MockDocumentRepository()
    annotation_repo = MockAnnotationRepository()
    training_task_repo = MockTrainingTaskRepository()

    # Test documents, in the order they must be stored (tests read index 0).
    seeded_documents = [
        MockAdminDocument(
            filename="INV001.pdf",
            status="labeled",
            upload_source="ui",
            auto_label_status=None,
            batch_id=None,
        ),
        MockAdminDocument(
            filename="INV002.pdf",
            status="labeled",
            upload_source="api",
            auto_label_status="completed",
            batch_id=uuid4(),
        ),
        MockAdminDocument(
            filename="INV003.pdf",
            status="pending",
            upload_source="ui",
            auto_label_status=None,  # Not auto-labeled yet
            batch_id=None,
        ),
    ]
    for document in seeded_documents:
        document_repo.documents[str(document.document_id)] = document

    # Only the first two documents receive annotations.
    ui_labeled, api_labeled = seeded_documents[0], seeded_documents[1]
    annotation_repo.annotations[str(ui_labeled.document_id)] = [
        MockAnnotation(
            document_id=ui_labeled.document_id,
            class_name="invoice_number",
            text_value="INV-001",
        ),
    ]
    annotation_repo.annotations[str(api_labeled.document_id)] = [
        MockAnnotation(
            document_id=api_labeled.document_id,
            class_id=6,
            class_name="amount",
            text_value="1500.00",
        ),
        MockAnnotation(
            document_id=api_labeled.document_id,
            class_id=1,
            class_name="invoice_date",
            text_value="2024-01-15",
        ),
    ]

    # Share annotation data with document repo for filtering
    document_repo.annotations = annotation_repo.annotations

    # Override dependencies
    test_app.dependency_overrides.update({
        validate_admin_token: lambda: "test-token",
        get_document_repository: lambda: document_repo,
        get_annotation_repository: lambda: annotation_repo,
        get_training_task_repository: lambda: training_task_repo,
    })

    # Include router
    documents_router = create_documents_router(StorageConfig())
    test_app.include_router(documents_router)

    return test_app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Return a TestClient bound to the ``app`` fixture's application."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
class TestEnhancedDocumentList:
    """Tests for enhanced document list endpoint."""

    def test_list_documents_filter_by_upload_source_ui(self, client):
        """Test filtering documents by upload_source=ui."""
        response = client.get("/admin/documents?upload_source=ui")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2
        assert all(doc["filename"].startswith("INV") for doc in data["documents"])
        # Every seeded filename starts with "INV", so the check above cannot
        # detect a broken filter; verify the filtered field itself.
        assert all(doc["upload_source"] == "ui" for doc in data["documents"])

    def test_list_documents_filter_by_upload_source_api(self, client):
        """Test filtering documents by upload_source=api."""
        response = client.get("/admin/documents?upload_source=api")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_has_annotations_true(self, client):
        """Test filtering documents with annotations."""
        response = client.get("/admin/documents?has_annotations=true")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2

    def test_list_documents_filter_by_has_annotations_false(self, client):
        """Test filtering documents without annotations."""
        response = client.get("/admin/documents?has_annotations=false")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1

    def test_list_documents_filter_by_auto_label_status(self, client):
        """Test filtering by auto_label_status."""
        response = client.get("/admin/documents?auto_label_status=completed")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_batch_id(self, client):
        """Test filtering by batch_id."""
        # Get a batch_id from the test data
        response_all = client.get("/admin/documents?upload_source=api")
        # Fail here, with a clear status, rather than on a KeyError below.
        assert response_all.status_code == 200
        batch_id = response_all.json()["documents"][0]["batch_id"]

        response = client.get(f"/admin/documents?batch_id={batch_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        # The single match must be the document from the requested batch.
        assert data["documents"][0]["batch_id"] == batch_id

    def test_list_documents_combined_filters(self, client):
        """Test combining multiple filters."""
        response = client.get(
            "/admin/documents?status=labeled&upload_source=api"
        )

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_document_item_includes_new_fields(self, client):
        """Test DocumentItem includes new Phase 2/3 fields."""
        response = client.get("/admin/documents?upload_source=api")

        assert response.status_code == 200
        data = response.json()
        doc = data["documents"][0]

        # Check new fields exist
        assert "upload_source" in doc
        assert doc["upload_source"] == "api"
        assert "batch_id" in doc
        assert doc["batch_id"] is not None
        assert "can_annotate" in doc
        assert isinstance(doc["can_annotate"], bool)
|
|
|
|
|
|
class TestEnhancedDocumentDetail:
    """Tests for enhanced document detail endpoint."""

    def _detail_of_first_listed(self, client, list_url):
        """Fetch ``list_url`` and return the detail JSON of its first document.

        Asserts that both the list request and the detail request answer
        with HTTP 200.
        """
        listing = client.get(list_url)
        assert listing.status_code == 200
        document_id = listing.json()["documents"][0]["document_id"]

        detail = client.get(f"/admin/documents/{document_id}")
        assert detail.status_code == 200
        return detail.json()

    def test_document_detail_includes_new_fields(self, client, app):
        """Test DocumentDetailResponse includes new Phase 2/3 fields."""
        # Detail of the API-uploaded document
        doc = self._detail_of_first_listed(
            client, "/admin/documents?upload_source=api"
        )

        # Check new fields exist
        assert "upload_source" in doc
        assert doc["upload_source"] == "api"
        assert "batch_id" in doc
        assert doc["batch_id"] is not None
        assert "can_annotate" in doc
        assert isinstance(doc["can_annotate"], bool)
        assert "csv_field_values" in doc
        assert "annotation_lock_until" in doc

    def test_document_detail_ui_upload_defaults(self, client, app):
        """Test UI-uploaded document has correct defaults."""
        # Detail of a UI-uploaded document
        doc = self._detail_of_first_listed(
            client, "/admin/documents?upload_source=ui"
        )

        # UI uploads should have these defaults
        assert doc["upload_source"] == "ui"
        assert doc["batch_id"] is None
        assert doc["csv_field_values"] is None
        assert doc["can_annotate"] is True
        assert doc["annotation_lock_until"] is None

    def test_document_detail_with_annotations(self, client, app):
        """Test document detail includes annotations."""
        # Detail of a document that has annotations
        doc = self._detail_of_first_listed(
            client, "/admin/documents?has_annotations=true"
        )

        # Should have annotations
        assert "annotations" in doc
        assert len(doc["annotations"]) > 0
|