"""
Tests for Enhanced Admin Document Routes (Phase 3).
"""

import pytest
from datetime import datetime, timezone
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient

from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.config import StorageConfig
from inference.web.core.auth import validate_admin_token, get_admin_db


def _utcnow():
    """Return the current UTC time as a naive ``datetime``.

    Produces the same kind of value as ``datetime.utcnow()`` (naive, UTC)
    without calling that method, which is deprecated since Python 3.12.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


class MockAdminDocument:
    """Mock AdminDocument for testing.

    Every attribute can be overridden through a keyword argument of the
    same name; anything not supplied falls back to the default below.
    """

    def __init__(self, **kwargs):
        self.document_id = kwargs.get('document_id', uuid4())
        self.admin_token = kwargs.get('admin_token', 'test-token')
        self.filename = kwargs.get('filename', 'test.pdf')
        self.file_size = kwargs.get('file_size', 100000)
        self.content_type = kwargs.get('content_type', 'application/pdf')
        self.page_count = kwargs.get('page_count', 1)
        self.status = kwargs.get('status', 'pending')
        self.auto_label_status = kwargs.get('auto_label_status', None)
        self.auto_label_error = kwargs.get('auto_label_error', None)
        self.upload_source = kwargs.get('upload_source', 'ui')
        self.batch_id = kwargs.get('batch_id', None)
        self.csv_field_values = kwargs.get('csv_field_values', None)
        self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
        self.category = kwargs.get('category', 'invoice')
        self.created_at = kwargs.get('created_at', _utcnow())
        self.updated_at = kwargs.get('updated_at', _utcnow())


class MockAnnotation:
    """Mock AdminAnnotation for testing.

    Every attribute can be overridden through a keyword argument of the
    same name. ``document_id`` has no default and is ``None`` when omitted.
    """

    def __init__(self, **kwargs):
        self.annotation_id = kwargs.get('annotation_id', uuid4())
        self.document_id = kwargs.get('document_id')
        self.page_number = kwargs.get('page_number', 1)
        self.class_id = kwargs.get('class_id', 0)
        self.class_name = kwargs.get('class_name', 'invoice_number')
        self.bbox_x = kwargs.get('bbox_x', 100.0)
        self.bbox_y = kwargs.get('bbox_y', 100.0)
        self.bbox_width = kwargs.get('bbox_width', 200.0)
        self.bbox_height = kwargs.get('bbox_height', 50.0)
        self.x_center = kwargs.get('x_center', 0.5)
        self.y_center = kwargs.get('y_center', 0.5)
        self.width = kwargs.get('width', 0.3)
        self.height = kwargs.get('height', 0.1)
        self.text_value = kwargs.get('text_value', 'INV-001')
        self.confidence = kwargs.get('confidence', 0.95)
        self.source = kwargs.get('source', 'manual')
        self.created_at = kwargs.get('created_at', _utcnow())


class MockAdminDB:
    """Mock AdminDB for testing enhanced features.

    ``documents`` maps ``str(document_id)`` to a document object and
    ``annotations`` maps ``str(document_id)`` to a list of annotations.
    """

    def __init__(self):
        self.documents = {}
        self.annotations = {}

    def get_documents_by_token(
        self,
        admin_token=None,
        status=None,
        upload_source=None,
        has_annotations=None,
        auto_label_status=None,
        batch_id=None,
        category=None,
        limit=20,
        offset=0
    ):
        """Get filtered documents.

        Each filter is applied only when its argument is truthy, except
        ``has_annotations`` which is applied whenever it is not ``None``
        (so ``False`` selects documents without annotations).
        ``admin_token`` is accepted but not used for filtering here.

        Returns:
            A ``(page, total)`` tuple: the ``offset``/``limit`` slice of the
            matching documents and the match count before pagination.
        """
        docs = list(self.documents.values())

        # Apply filters
        if status:
            docs = [d for d in docs if d.status == status]

        if upload_source:
            docs = [d for d in docs if d.upload_source == upload_source]

        if has_annotations is not None:
            # Keep a document when "it has at least one annotation" agrees
            # with the requested flag.
            wanted = bool(has_annotations)
            docs = [
                d for d in docs
                if bool(self.annotations.get(str(d.document_id), [])) == wanted
            ]

        if auto_label_status:
            docs = [d for d in docs if d.auto_label_status == auto_label_status]

        if batch_id:
            # Compare as strings so UUID objects and their text form match.
            docs = [d for d in docs if str(d.batch_id) == str(batch_id)]

        if category:
            docs = [d for d in docs if d.category == category]

        total = len(docs)
        return docs[offset:offset+limit], total

    def get_annotations_for_document(self, document_id):
        """Get annotations for document (empty list when none are stored)."""
        return self.annotations.get(str(document_id), [])

    def count_documents_by_status(self, admin_token):
        """Count documents by status for documents owned by ``admin_token``."""
        counts = {}
        for doc in self.documents.values():
            if doc.admin_token == admin_token:
                counts[doc.status] = counts.get(doc.status, 0) + 1
        return counts

    def get_document_by_token(self, document_id, admin_token):
        """Get single document by ID and token.

        ``document_id`` may be a ``UUID`` or its string form; it is
        normalized with ``str()`` because the store is keyed by strings.
        Returns ``None`` when the document is missing or belongs to a
        different token.
        """
        doc = self.documents.get(str(document_id))
        if doc and doc.admin_token == admin_token:
            return doc
        return None

    def get_document_training_tasks(self, document_id):
        """Get training tasks that used this document."""
        return []  # No training history in this test

    def get_training_task(self, task_id):
        """Get training task by ID."""
        return None  # No training tasks in this test


@pytest.fixture
def app():
    """Create test FastAPI app.

    Seeds a ``MockAdminDB`` with three documents (two uploaded via "ui",
    one via "api"), attaches annotations to the first two, and overrides
    the auth and DB dependencies so the router runs against the mock.
    """
    app = FastAPI()

    # Create mock DB
    mock_db = MockAdminDB()

    # Add test documents
    doc1 = MockAdminDocument(
        filename="INV001.pdf",
        status="labeled",
        upload_source="ui",
        auto_label_status=None,
        batch_id=None
    )
    doc2 = MockAdminDocument(
        filename="INV002.pdf",
        status="labeled",
        upload_source="api",
        auto_label_status="completed",
        batch_id=uuid4()
    )
    doc3 = MockAdminDocument(
        filename="INV003.pdf",
        status="pending",
        upload_source="ui",
        auto_label_status=None,  # Not auto-labeled yet
        batch_id=None
    )

    mock_db.documents[str(doc1.document_id)] = doc1
    mock_db.documents[str(doc2.document_id)] = doc2
    mock_db.documents[str(doc3.document_id)] = doc3

    # Add annotations to doc1 and doc2
    mock_db.annotations[str(doc1.document_id)] = [
        MockAnnotation(
            document_id=doc1.document_id,
            class_name="invoice_number",
            text_value="INV-001"
        )
    ]
    mock_db.annotations[str(doc2.document_id)] = [
        MockAnnotation(
            document_id=doc2.document_id,
            class_id=6,
            class_name="amount",
            text_value="1500.00"
        ),
        MockAnnotation(
            document_id=doc2.document_id,
            class_id=1,
            class_name="invoice_date",
            text_value="2024-01-15"
        )
    ]

    # Override dependencies
    app.dependency_overrides[validate_admin_token] = lambda: "test-token"
    app.dependency_overrides[get_admin_db] = lambda: mock_db

    # Include router
    router = create_documents_router(StorageConfig())
    app.include_router(router)

    return app


@pytest.fixture
def client(app):
    """Create test client."""
    return TestClient(app)


class TestEnhancedDocumentList:
    """Tests for enhanced document list endpoint."""

    def test_list_documents_filter_by_upload_source_ui(self, client):
        """Test filtering documents by upload_source=ui."""
        response = client.get("/admin/documents?upload_source=ui")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2
        assert all(doc["filename"].startswith("INV") for doc in data["documents"])
        # Every fixture filename starts with "INV", so also check the
        # filtered field itself to make sure the filter was applied.
        assert all(doc["upload_source"] == "ui" for doc in data["documents"])

    def test_list_documents_filter_by_upload_source_api(self, client):
        """Test filtering documents by upload_source=api."""
        response = client.get("/admin/documents?upload_source=api")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_has_annotations_true(self, client):
        """Test filtering documents with annotations."""
        response = client.get("/admin/documents?has_annotations=true")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 2

    def test_list_documents_filter_by_has_annotations_false(self, client):
        """Test filtering documents without annotations."""
        response = client.get("/admin/documents?has_annotations=false")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1

    def test_list_documents_filter_by_auto_label_status(self, client):
        """Test filtering by auto_label_status."""
        response = client.get("/admin/documents?auto_label_status=completed")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_list_documents_filter_by_batch_id(self, client):
        """Test filtering by batch_id."""
        # Get a batch_id from the test data
        response_all = client.get("/admin/documents?upload_source=api")
        batch_id = response_all.json()["documents"][0]["batch_id"]

        response = client.get(f"/admin/documents?batch_id={batch_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1

    def test_list_documents_combined_filters(self, client):
        """Test combining multiple filters."""
        response = client.get(
            "/admin/documents?status=labeled&upload_source=api"
        )

        assert response.status_code == 200
        data = response.json()
        assert data["total"] == 1
        assert data["documents"][0]["filename"] == "INV002.pdf"

    def test_document_item_includes_new_fields(self, client):
        """Test DocumentItem includes new Phase 2/3 fields."""
        response = client.get("/admin/documents?upload_source=api")

        assert response.status_code == 200
        data = response.json()
        doc = data["documents"][0]

        # Check new fields exist
        assert "upload_source" in doc
        assert doc["upload_source"] == "api"
        assert "batch_id" in doc
        assert doc["batch_id"] is not None
        assert "can_annotate" in doc
        assert isinstance(doc["can_annotate"], bool)


class TestEnhancedDocumentDetail:
    """Tests for enhanced document detail endpoint."""

    def test_document_detail_includes_new_fields(self, client, app):
        """Test DocumentDetailResponse includes new Phase 2/3 fields."""
        # Get a document ID from list
        response = client.get("/admin/documents?upload_source=api")
        assert response.status_code == 200
        doc_list = response.json()
        document_id = doc_list["documents"][0]["document_id"]

        # Get document detail
        response = client.get(f"/admin/documents/{document_id}")

        assert response.status_code == 200
        doc = response.json()

        # Check new fields exist
        assert "upload_source" in doc
        assert doc["upload_source"] == "api"
        assert "batch_id" in doc
        assert doc["batch_id"] is not None
        assert "can_annotate" in doc
        assert isinstance(doc["can_annotate"], bool)
        assert "csv_field_values" in doc
        assert "annotation_lock_until" in doc

    def test_document_detail_ui_upload_defaults(self, client, app):
        """Test UI-uploaded document has correct defaults."""
        # Get a UI-uploaded document
        response = client.get("/admin/documents?upload_source=ui")
        assert response.status_code == 200
        doc_list = response.json()
        document_id = doc_list["documents"][0]["document_id"]

        # Get document detail
        response = client.get(f"/admin/documents/{document_id}")

        assert response.status_code == 200
        doc = response.json()

        # UI uploads should have these defaults
        assert doc["upload_source"] == "ui"
        assert doc["batch_id"] is None
        assert doc["csv_field_values"] is None
        assert doc["can_annotate"] is True
        assert doc["annotation_lock_until"] is None

    def test_document_detail_with_annotations(self, client, app):
        """Test document detail includes annotations."""
        # Get a document with annotations
        response = client.get("/admin/documents?has_annotations=true")
        assert response.status_code == 200
        doc_list = response.json()
        document_id = doc_list["documents"][0]["document_id"]

        # Get document detail
        response = client.get(f"/admin/documents/{document_id}")

        assert response.status_code == 200
        doc = response.json()

        # Should have annotations
        assert "annotations" in doc
        assert len(doc["annotations"]) > 0