WIP
This commit is contained in:
@@ -11,7 +11,12 @@ from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.documents import create_documents_router
|
||||
from inference.web.config import StorageConfig
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
get_annotation_repository,
|
||||
get_training_task_repository,
|
||||
)
|
||||
|
||||
|
||||
class MockAdminDocument:
|
||||
@@ -59,14 +64,14 @@ class MockAnnotation:
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing enhanced features."""
|
||||
class MockDocumentRepository:
|
||||
"""Mock DocumentRepository for testing enhanced features."""
|
||||
|
||||
def __init__(self):
|
||||
self.documents = {}
|
||||
self.annotations = {}
|
||||
self.annotations = {} # Shared reference for filtering
|
||||
|
||||
def get_documents_by_token(
|
||||
def get_paginated(
|
||||
self,
|
||||
admin_token=None,
|
||||
status=None,
|
||||
@@ -103,32 +108,51 @@ class MockAdminDB:
|
||||
total = len(docs)
|
||||
return docs[offset:offset+limit], total
|
||||
|
||||
def get_annotations_for_document(self, document_id):
|
||||
"""Get annotations for document."""
|
||||
return self.annotations.get(str(document_id), [])
|
||||
|
||||
def count_documents_by_status(self, admin_token):
|
||||
def count_by_status(self, admin_token=None):
|
||||
"""Count documents by status."""
|
||||
counts = {}
|
||||
for doc in self.documents.values():
|
||||
if doc.admin_token == admin_token:
|
||||
if admin_token is None or doc.admin_token == admin_token:
|
||||
counts[doc.status] = counts.get(doc.status, 0) + 1
|
||||
return counts
|
||||
|
||||
def get_document_by_token(self, document_id, admin_token):
|
||||
def get(self, document_id):
|
||||
"""Get single document by ID."""
|
||||
return self.documents.get(document_id)
|
||||
|
||||
def get_by_token(self, document_id, admin_token=None):
|
||||
"""Get single document by ID and token."""
|
||||
doc = self.documents.get(document_id)
|
||||
if doc and doc.admin_token == admin_token:
|
||||
if doc and (admin_token is None or doc.admin_token == admin_token):
|
||||
return doc
|
||||
return None
|
||||
|
||||
|
||||
class MockAnnotationRepository:
|
||||
"""Mock AnnotationRepository for testing enhanced features."""
|
||||
|
||||
def __init__(self):
|
||||
self.annotations = {}
|
||||
|
||||
def get_for_document(self, document_id, page_number=None):
|
||||
"""Get annotations for document."""
|
||||
return self.annotations.get(str(document_id), [])
|
||||
|
||||
|
||||
class MockTrainingTaskRepository:
|
||||
"""Mock TrainingTaskRepository for testing enhanced features."""
|
||||
|
||||
def __init__(self):
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
|
||||
def get_document_training_tasks(self, document_id):
|
||||
"""Get training tasks that used this document."""
|
||||
return [] # No training history in this test
|
||||
return self.training_links.get(str(document_id), [])
|
||||
|
||||
def get_training_task(self, task_id):
|
||||
def get(self, task_id):
|
||||
"""Get training task by ID."""
|
||||
return None # No training tasks in this test
|
||||
return self.training_tasks.get(str(task_id))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -136,8 +160,10 @@ def app():
|
||||
"""Create test FastAPI app."""
|
||||
app = FastAPI()
|
||||
|
||||
# Create mock DB
|
||||
mock_db = MockAdminDB()
|
||||
# Create mock repositories
|
||||
mock_document_repo = MockDocumentRepository()
|
||||
mock_annotation_repo = MockAnnotationRepository()
|
||||
mock_training_task_repo = MockTrainingTaskRepository()
|
||||
|
||||
# Add test documents
|
||||
doc1 = MockAdminDocument(
|
||||
@@ -162,19 +188,19 @@ def app():
|
||||
batch_id=None
|
||||
)
|
||||
|
||||
mock_db.documents[str(doc1.document_id)] = doc1
|
||||
mock_db.documents[str(doc2.document_id)] = doc2
|
||||
mock_db.documents[str(doc3.document_id)] = doc3
|
||||
mock_document_repo.documents[str(doc1.document_id)] = doc1
|
||||
mock_document_repo.documents[str(doc2.document_id)] = doc2
|
||||
mock_document_repo.documents[str(doc3.document_id)] = doc3
|
||||
|
||||
# Add annotations to doc1 and doc2
|
||||
mock_db.annotations[str(doc1.document_id)] = [
|
||||
mock_annotation_repo.annotations[str(doc1.document_id)] = [
|
||||
MockAnnotation(
|
||||
document_id=doc1.document_id,
|
||||
class_name="invoice_number",
|
||||
text_value="INV-001"
|
||||
)
|
||||
]
|
||||
mock_db.annotations[str(doc2.document_id)] = [
|
||||
mock_annotation_repo.annotations[str(doc2.document_id)] = [
|
||||
MockAnnotation(
|
||||
document_id=doc2.document_id,
|
||||
class_id=6,
|
||||
@@ -189,9 +215,14 @@ def app():
|
||||
)
|
||||
]
|
||||
|
||||
# Share annotation data with document repo for filtering
|
||||
mock_document_repo.annotations = mock_annotation_repo.annotations
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
|
||||
app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo
|
||||
app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo
|
||||
|
||||
# Include router
|
||||
router = create_documents_router(StorageConfig())
|
||||
|
||||
Reference in New Issue
Block a user