WIP
This commit is contained in:
@@ -10,7 +10,13 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.training import create_training_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
get_annotation_repository,
|
||||
get_training_task_repository,
|
||||
get_model_version_repository,
|
||||
)
|
||||
|
||||
|
||||
class MockTrainingTask:
|
||||
@@ -128,19 +134,17 @@ class MockModelVersion:
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing Phase 4."""
|
||||
class MockDocumentRepository:
|
||||
"""Mock DocumentRepository for testing Phase 4."""
|
||||
|
||||
def __init__(self):
|
||||
self.documents = {}
|
||||
self.annotations = {}
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
self.model_versions = {}
|
||||
self.annotations = {} # Shared reference for filtering
|
||||
self.training_links = {} # Shared reference for filtering
|
||||
|
||||
def get_documents_for_training(
|
||||
def get_for_training(
|
||||
self,
|
||||
admin_token,
|
||||
admin_token=None,
|
||||
status="labeled",
|
||||
has_annotations=True,
|
||||
min_annotation_count=None,
|
||||
@@ -173,17 +177,28 @@ class MockAdminDB:
|
||||
total = len(filtered)
|
||||
return filtered[offset:offset+limit], total
|
||||
|
||||
def get_annotations_for_document(self, document_id):
|
||||
|
||||
class MockAnnotationRepository:
|
||||
"""Mock AnnotationRepository for testing Phase 4."""
|
||||
|
||||
def __init__(self):
|
||||
self.annotations = {}
|
||||
|
||||
def get_for_document(self, document_id, page_number=None):
|
||||
"""Get annotations for document."""
|
||||
return self.annotations.get(str(document_id), [])
|
||||
|
||||
def get_document_training_tasks(self, document_id):
|
||||
"""Get training tasks that used this document."""
|
||||
return self.training_links.get(str(document_id), [])
|
||||
|
||||
def get_training_tasks_by_token(
|
||||
class MockTrainingTaskRepository:
|
||||
"""Mock TrainingTaskRepository for testing Phase 4."""
|
||||
|
||||
def __init__(self):
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
|
||||
def get_paginated(
|
||||
self,
|
||||
admin_token,
|
||||
admin_token=None,
|
||||
status=None,
|
||||
limit=20,
|
||||
offset=0,
|
||||
@@ -196,11 +211,22 @@ class MockAdminDB:
|
||||
total = len(tasks)
|
||||
return tasks[offset:offset+limit], total
|
||||
|
||||
def get_training_task(self, task_id):
|
||||
def get(self, task_id):
|
||||
"""Get training task by ID."""
|
||||
return self.training_tasks.get(str(task_id))
|
||||
|
||||
def get_model_versions(self, status=None, limit=20, offset=0):
|
||||
def get_document_training_tasks(self, document_id):
|
||||
"""Get training tasks that used this document."""
|
||||
return self.training_links.get(str(document_id), [])
|
||||
|
||||
|
||||
class MockModelVersionRepository:
|
||||
"""Mock ModelVersionRepository for testing Phase 4."""
|
||||
|
||||
def __init__(self):
|
||||
self.model_versions = {}
|
||||
|
||||
def get_paginated(self, status=None, limit=20, offset=0):
|
||||
"""Get model versions with optional filtering."""
|
||||
models = list(self.model_versions.values())
|
||||
if status:
|
||||
@@ -214,8 +240,11 @@ def app():
|
||||
"""Create test FastAPI app."""
|
||||
app = FastAPI()
|
||||
|
||||
# Create mock DB
|
||||
mock_db = MockAdminDB()
|
||||
# Create mock repositories
|
||||
mock_document_repo = MockDocumentRepository()
|
||||
mock_annotation_repo = MockAnnotationRepository()
|
||||
mock_training_task_repo = MockTrainingTaskRepository()
|
||||
mock_model_version_repo = MockModelVersionRepository()
|
||||
|
||||
# Add test documents
|
||||
doc1 = MockAdminDocument(
|
||||
@@ -231,22 +260,25 @@ def app():
|
||||
status="labeled",
|
||||
)
|
||||
|
||||
mock_db.documents[str(doc1.document_id)] = doc1
|
||||
mock_db.documents[str(doc2.document_id)] = doc2
|
||||
mock_db.documents[str(doc3.document_id)] = doc3
|
||||
mock_document_repo.documents[str(doc1.document_id)] = doc1
|
||||
mock_document_repo.documents[str(doc2.document_id)] = doc2
|
||||
mock_document_repo.documents[str(doc3.document_id)] = doc3
|
||||
|
||||
# Add annotations
|
||||
mock_db.annotations[str(doc1.document_id)] = [
|
||||
mock_annotation_repo.annotations[str(doc1.document_id)] = [
|
||||
MockAnnotation(document_id=doc1.document_id, source="manual"),
|
||||
MockAnnotation(document_id=doc1.document_id, source="auto"),
|
||||
]
|
||||
mock_db.annotations[str(doc2.document_id)] = [
|
||||
mock_annotation_repo.annotations[str(doc2.document_id)] = [
|
||||
MockAnnotation(document_id=doc2.document_id, source="auto"),
|
||||
MockAnnotation(document_id=doc2.document_id, source="auto"),
|
||||
MockAnnotation(document_id=doc2.document_id, source="auto"),
|
||||
]
|
||||
# doc3 has no annotations
|
||||
|
||||
# Share annotation data with document repo for filtering
|
||||
mock_document_repo.annotations = mock_annotation_repo.annotations
|
||||
|
||||
# Add training tasks
|
||||
task1 = MockTrainingTask(
|
||||
name="Training Run 2024-01",
|
||||
@@ -265,15 +297,18 @@ def app():
|
||||
metrics_recall=0.92,
|
||||
)
|
||||
|
||||
mock_db.training_tasks[str(task1.task_id)] = task1
|
||||
mock_db.training_tasks[str(task2.task_id)] = task2
|
||||
mock_training_task_repo.training_tasks[str(task1.task_id)] = task1
|
||||
mock_training_task_repo.training_tasks[str(task2.task_id)] = task2
|
||||
|
||||
# Add training links (doc1 used in task1)
|
||||
link1 = MockTrainingDocumentLink(
|
||||
task_id=task1.task_id,
|
||||
document_id=doc1.document_id,
|
||||
)
|
||||
mock_db.training_links[str(doc1.document_id)] = [link1]
|
||||
mock_training_task_repo.training_links[str(doc1.document_id)] = [link1]
|
||||
|
||||
# Share training links with document repo for filtering
|
||||
mock_document_repo.training_links = mock_training_task_repo.training_links
|
||||
|
||||
# Add model versions
|
||||
model1 = MockModelVersion(
|
||||
@@ -296,12 +331,15 @@ def app():
|
||||
metrics_recall=0.92,
|
||||
document_count=600,
|
||||
)
|
||||
mock_db.model_versions[str(model1.version_id)] = model1
|
||||
mock_db.model_versions[str(model2.version_id)] = model2
|
||||
mock_model_version_repo.model_versions[str(model1.version_id)] = model1
|
||||
mock_model_version_repo.model_versions[str(model2.version_id)] = model2
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
|
||||
app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo
|
||||
app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo
|
||||
app.dependency_overrides[get_model_version_repository] = lambda: mock_model_version_repo
|
||||
|
||||
# Include router
|
||||
router = create_training_router()
|
||||
|
||||
Reference in New Issue
Block a user