This commit is contained in:
Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

View File

@@ -11,7 +11,12 @@ from fastapi.testclient import TestClient
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.config import StorageConfig
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.core.auth import (
validate_admin_token,
get_document_repository,
get_annotation_repository,
get_training_task_repository,
)
class MockAdminDocument:
@@ -59,14 +64,14 @@ class MockAnnotation:
self.created_at = kwargs.get('created_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing enhanced features."""
class MockDocumentRepository:
"""Mock DocumentRepository for testing enhanced features."""
def __init__(self):
self.documents = {}
self.annotations = {}
self.annotations = {} # Shared reference for filtering
def get_documents_by_token(
def get_paginated(
self,
admin_token=None,
status=None,
@@ -103,32 +108,51 @@ class MockAdminDB:
total = len(docs)
return docs[offset:offset+limit], total
def get_annotations_for_document(self, document_id):
"""Get annotations for document."""
return self.annotations.get(str(document_id), [])
def count_documents_by_status(self, admin_token):
def count_by_status(self, admin_token=None):
"""Count documents by status."""
counts = {}
for doc in self.documents.values():
if doc.admin_token == admin_token:
if admin_token is None or doc.admin_token == admin_token:
counts[doc.status] = counts.get(doc.status, 0) + 1
return counts
def get_document_by_token(self, document_id, admin_token):
def get(self, document_id):
"""Get single document by ID."""
return self.documents.get(document_id)
def get_by_token(self, document_id, admin_token=None):
"""Get single document by ID and token."""
doc = self.documents.get(document_id)
if doc and doc.admin_token == admin_token:
if doc and (admin_token is None or doc.admin_token == admin_token):
return doc
return None
class MockAnnotationRepository:
"""Mock AnnotationRepository for testing enhanced features."""
def __init__(self):
self.annotations = {}
def get_for_document(self, document_id, page_number=None):
"""Get annotations for document."""
return self.annotations.get(str(document_id), [])
class MockTrainingTaskRepository:
"""Mock TrainingTaskRepository for testing enhanced features."""
def __init__(self):
self.training_tasks = {}
self.training_links = {}
def get_document_training_tasks(self, document_id):
"""Get training tasks that used this document."""
return [] # No training history in this test
return self.training_links.get(str(document_id), [])
def get_training_task(self, task_id):
def get(self, task_id):
"""Get training task by ID."""
return None # No training tasks in this test
return self.training_tasks.get(str(task_id))
@pytest.fixture
@@ -136,8 +160,10 @@ def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Create mock repositories
mock_document_repo = MockDocumentRepository()
mock_annotation_repo = MockAnnotationRepository()
mock_training_task_repo = MockTrainingTaskRepository()
# Add test documents
doc1 = MockAdminDocument(
@@ -162,19 +188,19 @@ def app():
batch_id=None
)
mock_db.documents[str(doc1.document_id)] = doc1
mock_db.documents[str(doc2.document_id)] = doc2
mock_db.documents[str(doc3.document_id)] = doc3
mock_document_repo.documents[str(doc1.document_id)] = doc1
mock_document_repo.documents[str(doc2.document_id)] = doc2
mock_document_repo.documents[str(doc3.document_id)] = doc3
# Add annotations to doc1 and doc2
mock_db.annotations[str(doc1.document_id)] = [
mock_annotation_repo.annotations[str(doc1.document_id)] = [
MockAnnotation(
document_id=doc1.document_id,
class_name="invoice_number",
text_value="INV-001"
)
]
mock_db.annotations[str(doc2.document_id)] = [
mock_annotation_repo.annotations[str(doc2.document_id)] = [
MockAnnotation(
document_id=doc2.document_id,
class_id=6,
@@ -189,9 +215,14 @@ def app():
)
]
# Share annotation data with document repo for filtering
mock_document_repo.annotations = mock_annotation_repo.annotations
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo
app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo
# Include router
router = create_documents_router(StorageConfig())