""" Tests for Enhanced Admin Document Routes (Phase 3). """ import pytest from datetime import datetime from uuid import uuid4 from fastapi import FastAPI from fastapi.testclient import TestClient from inference.web.api.v1.admin.documents import create_documents_router from inference.web.config import StorageConfig from inference.web.core.auth import ( validate_admin_token, get_document_repository, get_annotation_repository, get_training_task_repository, ) class MockAdminDocument: """Mock AdminDocument for testing.""" def __init__(self, **kwargs): self.document_id = kwargs.get('document_id', uuid4()) self.admin_token = kwargs.get('admin_token', 'test-token') self.filename = kwargs.get('filename', 'test.pdf') self.file_size = kwargs.get('file_size', 100000) self.content_type = kwargs.get('content_type', 'application/pdf') self.page_count = kwargs.get('page_count', 1) self.status = kwargs.get('status', 'pending') self.auto_label_status = kwargs.get('auto_label_status', None) self.auto_label_error = kwargs.get('auto_label_error', None) self.upload_source = kwargs.get('upload_source', 'ui') self.batch_id = kwargs.get('batch_id', None) self.csv_field_values = kwargs.get('csv_field_values', None) self.annotation_lock_until = kwargs.get('annotation_lock_until', None) self.category = kwargs.get('category', 'invoice') self.created_at = kwargs.get('created_at', datetime.utcnow()) self.updated_at = kwargs.get('updated_at', datetime.utcnow()) class MockAnnotation: """Mock AdminAnnotation for testing.""" def __init__(self, **kwargs): self.annotation_id = kwargs.get('annotation_id', uuid4()) self.document_id = kwargs.get('document_id') self.page_number = kwargs.get('page_number', 1) self.class_id = kwargs.get('class_id', 0) self.class_name = kwargs.get('class_name', 'invoice_number') self.bbox_x = kwargs.get('bbox_x', 100.0) self.bbox_y = kwargs.get('bbox_y', 100.0) self.bbox_width = kwargs.get('bbox_width', 200.0) self.bbox_height = kwargs.get('bbox_height', 50.0) self.x_center = kwargs.get('x_center', 0.5) self.y_center = kwargs.get('y_center', 0.5) self.width = kwargs.get('width', 0.3) self.height = kwargs.get('height', 0.1) self.text_value = kwargs.get('text_value', 'INV-001') self.confidence = kwargs.get('confidence', 0.95) self.source = kwargs.get('source', 'manual') self.created_at = kwargs.get('created_at', datetime.utcnow()) class MockDocumentRepository: """Mock DocumentRepository for testing enhanced features.""" def __init__(self): self.documents = {} self.annotations = {} # Shared reference for filtering def get_paginated( self, admin_token=None, status=None, upload_source=None, has_annotations=None, auto_label_status=None, batch_id=None, category=None, limit=20, offset=0 ): """Get filtered documents.""" docs = list(self.documents.values()) # Apply filters if status: docs = [d for d in docs if d.status == status] if upload_source: docs = [d for d in docs if d.upload_source == upload_source] if has_annotations is not None: for d in docs[:]: ann_count = len(self.annotations.get(str(d.document_id), [])) if has_annotations and ann_count == 0: docs.remove(d) elif not has_annotations and ann_count > 0: docs.remove(d) if auto_label_status: docs = [d for d in docs if d.auto_label_status == auto_label_status] if batch_id: docs = [d for d in docs if str(d.batch_id) == str(batch_id)] if category: docs = [d for d in docs if d.category == category] total = len(docs) return docs[offset:offset+limit], total def count_by_status(self, admin_token=None): """Count documents by status.""" counts = {} for doc in self.documents.values(): if admin_token is None or doc.admin_token == admin_token: counts[doc.status] = counts.get(doc.status, 0) + 1 return counts def get(self, document_id): """Get single document by ID.""" return self.documents.get(document_id) def get_by_token(self, document_id, admin_token=None): """Get single document by ID and token.""" doc = self.documents.get(document_id) if doc and (admin_token is None or doc.admin_token == admin_token): return doc return None class MockAnnotationRepository: """Mock AnnotationRepository for testing enhanced features.""" def __init__(self): self.annotations = {} def get_for_document(self, document_id, page_number=None): """Get annotations for document.""" return self.annotations.get(str(document_id), []) class MockTrainingTaskRepository: """Mock TrainingTaskRepository for testing enhanced features.""" def __init__(self): self.training_tasks = {} self.training_links = {} def get_document_training_tasks(self, document_id): """Get training tasks that used this document.""" return self.training_links.get(str(document_id), []) def get(self, task_id): """Get training task by ID.""" return self.training_tasks.get(str(task_id)) @pytest.fixture def app(): """Create test FastAPI app.""" app = FastAPI() # Create mock repositories mock_document_repo = MockDocumentRepository() mock_annotation_repo = MockAnnotationRepository() mock_training_task_repo = MockTrainingTaskRepository() # Add test documents doc1 = MockAdminDocument( filename="INV001.pdf", status="labeled", upload_source="ui", auto_label_status=None, batch_id=None ) doc2 = MockAdminDocument( filename="INV002.pdf", status="labeled", upload_source="api", auto_label_status="completed", batch_id=uuid4() ) doc3 = MockAdminDocument( filename="INV003.pdf", status="pending", upload_source="ui", auto_label_status=None, # Not auto-labeled yet batch_id=None ) mock_document_repo.documents[str(doc1.document_id)] = doc1 mock_document_repo.documents[str(doc2.document_id)] = doc2 mock_document_repo.documents[str(doc3.document_id)] = doc3 # Add annotations to doc1 and doc2 mock_annotation_repo.annotations[str(doc1.document_id)] = [ MockAnnotation( document_id=doc1.document_id, class_name="invoice_number", text_value="INV-001" ) ] mock_annotation_repo.annotations[str(doc2.document_id)] = [ MockAnnotation( document_id=doc2.document_id, class_id=6, class_name="amount", text_value="1500.00" ), MockAnnotation( document_id=doc2.document_id, class_id=1, class_name="invoice_date", text_value="2024-01-15" ) ] # Share annotation data with document repo for filtering mock_document_repo.annotations = mock_annotation_repo.annotations # Override dependencies app.dependency_overrides[validate_admin_token] = lambda: "test-token" app.dependency_overrides[get_document_repository] = lambda: mock_document_repo app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo # Include router router = create_documents_router(StorageConfig()) app.include_router(router) return app @pytest.fixture def client(app): """Create test client.""" return TestClient(app) class TestEnhancedDocumentList: """Tests for enhanced document list endpoint.""" def test_list_documents_filter_by_upload_source_ui(self, client): """Test filtering documents by upload_source=ui.""" response = client.get("/admin/documents?upload_source=ui") assert response.status_code == 200 data = response.json() assert data["total"] == 2 assert all(doc["filename"].startswith("INV") for doc in data["documents"]) def test_list_documents_filter_by_upload_source_api(self, client): """Test filtering documents by upload_source=api.""" response = client.get("/admin/documents?upload_source=api") assert response.status_code == 200 data = response.json() assert data["total"] == 1 assert data["documents"][0]["filename"] == "INV002.pdf" def test_list_documents_filter_by_has_annotations_true(self, client): """Test filtering documents with annotations.""" response = client.get("/admin/documents?has_annotations=true") assert response.status_code == 200 data = response.json() assert data["total"] == 2 def test_list_documents_filter_by_has_annotations_false(self, client): """Test filtering documents without annotations.""" response = client.get("/admin/documents?has_annotations=false") assert response.status_code == 200 data = response.json() assert data["total"] == 1 def test_list_documents_filter_by_auto_label_status(self, client): """Test filtering by auto_label_status.""" response = client.get("/admin/documents?auto_label_status=completed") assert response.status_code == 200 data = response.json() assert data["total"] == 1 assert data["documents"][0]["filename"] == "INV002.pdf" def test_list_documents_filter_by_batch_id(self, client): """Test filtering by batch_id.""" # Get a batch_id from the test data response_all = client.get("/admin/documents?upload_source=api") batch_id = response_all.json()["documents"][0]["batch_id"] response = client.get(f"/admin/documents?batch_id={batch_id}") assert response.status_code == 200 data = response.json() assert data["total"] == 1 def test_list_documents_combined_filters(self, client): """Test combining multiple filters.""" response = client.get( "/admin/documents?status=labeled&upload_source=api" ) assert response.status_code == 200 data = response.json() assert data["total"] == 1 assert data["documents"][0]["filename"] == "INV002.pdf" def test_document_item_includes_new_fields(self, client): """Test DocumentItem includes new Phase 2/3 fields.""" response = client.get("/admin/documents?upload_source=api") assert response.status_code == 200 data = response.json() doc = data["documents"][0] # Check new fields exist assert "upload_source" in doc assert doc["upload_source"] == "api" assert "batch_id" in doc assert doc["batch_id"] is not None assert "can_annotate" in doc assert isinstance(doc["can_annotate"], bool) class TestEnhancedDocumentDetail: """Tests for enhanced document detail endpoint.""" def test_document_detail_includes_new_fields(self, client, app): """Test DocumentDetailResponse includes new Phase 2/3 fields.""" # Get a document ID from list response = client.get("/admin/documents?upload_source=api") assert response.status_code == 200 doc_list = response.json() document_id = doc_list["documents"][0]["document_id"] # Get document detail response = client.get(f"/admin/documents/{document_id}") assert response.status_code == 200 doc = response.json() # Check new fields exist assert "upload_source" in doc assert doc["upload_source"] == "api" assert "batch_id" in doc assert doc["batch_id"] is not None assert "can_annotate" in doc assert isinstance(doc["can_annotate"], bool) assert "csv_field_values" in doc assert "annotation_lock_until" in doc def test_document_detail_ui_upload_defaults(self, client, app): """Test UI-uploaded document has correct defaults.""" # Get a UI-uploaded document response = client.get("/admin/documents?upload_source=ui") assert response.status_code == 200 doc_list = response.json() document_id = doc_list["documents"][0]["document_id"] # Get document detail response = client.get(f"/admin/documents/{document_id}") assert response.status_code == 200 doc = response.json() # UI uploads should have these defaults assert doc["upload_source"] == "ui" assert doc["batch_id"] is None assert doc["csv_field_values"] is None assert doc["can_annotate"] is True assert doc["annotation_lock_until"] is None def test_document_detail_with_annotations(self, client, app): """Test document detail includes annotations.""" # Get a document with annotations response = client.get("/admin/documents?has_annotations=true") assert response.status_code == 200 doc_list = response.json() document_id = doc_list["documents"][0]["document_id"] # Get document detail response = client.get(f"/admin/documents/{document_id}") assert response.status_code == 200 doc = response.json() # Should have annotations assert "annotations" in doc assert len(doc["annotations"]) > 0