This commit is contained in:
Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

View File

@@ -10,7 +10,13 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.training import create_training_router
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.core.auth import (
validate_admin_token,
get_document_repository,
get_annotation_repository,
get_training_task_repository,
get_model_version_repository,
)
class MockTrainingTask:
@@ -128,19 +134,17 @@ class MockModelVersion:
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing Phase 4."""
class MockDocumentRepository:
"""Mock DocumentRepository for testing Phase 4."""
def __init__(self):
self.documents = {}
self.annotations = {}
self.training_tasks = {}
self.training_links = {}
self.model_versions = {}
self.annotations = {} # Shared reference for filtering
self.training_links = {} # Shared reference for filtering
def get_documents_for_training(
def get_for_training(
self,
admin_token,
admin_token=None,
status="labeled",
has_annotations=True,
min_annotation_count=None,
@@ -173,17 +177,28 @@ class MockAdminDB:
total = len(filtered)
return filtered[offset:offset+limit], total
def get_annotations_for_document(self, document_id):
class MockAnnotationRepository:
"""Mock AnnotationRepository for testing Phase 4."""
def __init__(self):
self.annotations = {}
def get_for_document(self, document_id, page_number=None):
"""Get annotations for document."""
return self.annotations.get(str(document_id), [])
def get_document_training_tasks(self, document_id):
"""Get training tasks that used this document."""
return self.training_links.get(str(document_id), [])
def get_training_tasks_by_token(
class MockTrainingTaskRepository:
"""Mock TrainingTaskRepository for testing Phase 4."""
def __init__(self):
self.training_tasks = {}
self.training_links = {}
def get_paginated(
self,
admin_token,
admin_token=None,
status=None,
limit=20,
offset=0,
@@ -196,11 +211,22 @@ class MockAdminDB:
total = len(tasks)
return tasks[offset:offset+limit], total
def get_training_task(self, task_id):
def get(self, task_id):
"""Get training task by ID."""
return self.training_tasks.get(str(task_id))
def get_model_versions(self, status=None, limit=20, offset=0):
def get_document_training_tasks(self, document_id):
"""Get training tasks that used this document."""
return self.training_links.get(str(document_id), [])
class MockModelVersionRepository:
"""Mock ModelVersionRepository for testing Phase 4."""
def __init__(self):
self.model_versions = {}
def get_paginated(self, status=None, limit=20, offset=0):
"""Get model versions with optional filtering."""
models = list(self.model_versions.values())
if status:
@@ -214,8 +240,11 @@ def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Create mock repositories
mock_document_repo = MockDocumentRepository()
mock_annotation_repo = MockAnnotationRepository()
mock_training_task_repo = MockTrainingTaskRepository()
mock_model_version_repo = MockModelVersionRepository()
# Add test documents
doc1 = MockAdminDocument(
@@ -231,22 +260,25 @@ def app():
status="labeled",
)
mock_db.documents[str(doc1.document_id)] = doc1
mock_db.documents[str(doc2.document_id)] = doc2
mock_db.documents[str(doc3.document_id)] = doc3
mock_document_repo.documents[str(doc1.document_id)] = doc1
mock_document_repo.documents[str(doc2.document_id)] = doc2
mock_document_repo.documents[str(doc3.document_id)] = doc3
# Add annotations
mock_db.annotations[str(doc1.document_id)] = [
mock_annotation_repo.annotations[str(doc1.document_id)] = [
MockAnnotation(document_id=doc1.document_id, source="manual"),
MockAnnotation(document_id=doc1.document_id, source="auto"),
]
mock_db.annotations[str(doc2.document_id)] = [
mock_annotation_repo.annotations[str(doc2.document_id)] = [
MockAnnotation(document_id=doc2.document_id, source="auto"),
MockAnnotation(document_id=doc2.document_id, source="auto"),
MockAnnotation(document_id=doc2.document_id, source="auto"),
]
# doc3 has no annotations
# Share annotation data with document repo for filtering
mock_document_repo.annotations = mock_annotation_repo.annotations
# Add training tasks
task1 = MockTrainingTask(
name="Training Run 2024-01",
@@ -265,15 +297,18 @@ def app():
metrics_recall=0.92,
)
mock_db.training_tasks[str(task1.task_id)] = task1
mock_db.training_tasks[str(task2.task_id)] = task2
mock_training_task_repo.training_tasks[str(task1.task_id)] = task1
mock_training_task_repo.training_tasks[str(task2.task_id)] = task2
# Add training links (doc1 used in task1)
link1 = MockTrainingDocumentLink(
task_id=task1.task_id,
document_id=doc1.document_id,
)
mock_db.training_links[str(doc1.document_id)] = [link1]
mock_training_task_repo.training_links[str(doc1.document_id)] = [link1]
# Share training links with document repo for filtering
mock_document_repo.training_links = mock_training_task_repo.training_links
# Add model versions
model1 = MockModelVersion(
@@ -296,12 +331,15 @@ def app():
metrics_recall=0.92,
document_count=600,
)
mock_db.model_versions[str(model1.version_id)] = model1
mock_db.model_versions[str(model2.version_id)] = model2
mock_model_version_repo.model_versions[str(model1.version_id)] = model1
mock_model_version_repo.model_versions[str(model2.version_id)] = model2
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo
app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo
app.dependency_overrides[get_model_version_repository] = lambda: mock_model_version_repo
# Include router
router = create_training_router()