Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions


@@ -0,0 +1 @@
"""Tests for web core components."""


@@ -0,0 +1,672 @@
"""Tests for unified task management interface.
TDD: These tests are written first (RED phase).
"""
from abc import ABC
from unittest.mock import MagicMock, patch
import pytest
class TestTaskStatus:
"""Tests for TaskStatus dataclass."""
def test_task_status_basic_fields(self) -> None:
"""TaskStatus has all required fields."""
from inference.web.core.task_interface import TaskStatus
status = TaskStatus(
name="test_runner",
is_running=True,
pending_count=5,
processing_count=2,
)
assert status.name == "test_runner"
assert status.is_running is True
assert status.pending_count == 5
assert status.processing_count == 2
def test_task_status_with_error(self) -> None:
"""TaskStatus can include optional error message."""
from inference.web.core.task_interface import TaskStatus
status = TaskStatus(
name="failed_runner",
is_running=False,
pending_count=0,
processing_count=0,
error="Connection failed",
)
assert status.error == "Connection failed"
def test_task_status_default_error_is_none(self) -> None:
"""TaskStatus error defaults to None."""
from inference.web.core.task_interface import TaskStatus
status = TaskStatus(
name="test",
is_running=True,
pending_count=0,
processing_count=0,
)
assert status.error is None
def test_task_status_is_frozen(self) -> None:
"""TaskStatus is immutable (frozen dataclass)."""
from inference.web.core.task_interface import TaskStatus
status = TaskStatus(
name="test",
is_running=True,
pending_count=0,
processing_count=0,
)
with pytest.raises(AttributeError):
status.name = "changed" # type: ignore[misc]
class TestTaskRunnerInterface:
"""Tests for TaskRunner abstract base class."""
def test_cannot_instantiate_directly(self) -> None:
"""TaskRunner is abstract and cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner
with pytest.raises(TypeError):
TaskRunner() # type: ignore[abstract]
def test_is_abstract_base_class(self) -> None:
"""TaskRunner inherits from ABC."""
from inference.web.core.task_interface import TaskRunner
assert issubclass(TaskRunner, ABC)
def test_subclass_missing_name_cannot_instantiate(self) -> None:
"""Subclass without name property cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class MissingName(TaskRunner):
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("", False, 0, 0)
with pytest.raises(TypeError):
MissingName() # type: ignore[abstract]
def test_subclass_missing_start_cannot_instantiate(self) -> None:
"""Subclass without start method cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class MissingStart(TaskRunner):
@property
def name(self) -> str:
return "test"
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("", False, 0, 0)
with pytest.raises(TypeError):
MissingStart() # type: ignore[abstract]
def test_subclass_missing_stop_cannot_instantiate(self) -> None:
"""Subclass without stop method cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class MissingStop(TaskRunner):
@property
def name(self) -> str:
return "test"
def start(self) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("", False, 0, 0)
with pytest.raises(TypeError):
MissingStop() # type: ignore[abstract]
def test_subclass_missing_is_running_cannot_instantiate(self) -> None:
"""Subclass without is_running property cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class MissingIsRunning(TaskRunner):
@property
def name(self) -> str:
return "test"
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
def get_status(self) -> TaskStatus:
return TaskStatus("", False, 0, 0)
with pytest.raises(TypeError):
MissingIsRunning() # type: ignore[abstract]
def test_subclass_missing_get_status_cannot_instantiate(self) -> None:
"""Subclass without get_status method cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner
class MissingGetStatus(TaskRunner):
@property
def name(self) -> str:
return "test"
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
with pytest.raises(TypeError):
MissingGetStatus() # type: ignore[abstract]
def test_complete_subclass_can_instantiate(self) -> None:
"""Complete subclass implementing all methods can be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class CompleteRunner(TaskRunner):
def __init__(self) -> None:
self._running = False
@property
def name(self) -> str:
return "complete_runner"
def start(self) -> None:
self._running = True
def stop(self, timeout: float | None = None) -> None:
self._running = False
@property
def is_running(self) -> bool:
return self._running
def get_status(self) -> TaskStatus:
return TaskStatus(
name=self.name,
is_running=self._running,
pending_count=0,
processing_count=0,
)
runner = CompleteRunner()
assert runner.name == "complete_runner"
assert runner.is_running is False
runner.start()
assert runner.is_running is True
status = runner.get_status()
assert status.name == "complete_runner"
assert status.is_running is True
runner.stop()
assert runner.is_running is False
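A sketch of the TaskRunner ABC implied by the instantiation tests above; the abstract members match the overrides the complete subclass provides, though the actual module may declare more.

# Hypothetical sketch, not part of this commit; TaskStatus is the dataclass
# sketched earlier.
from __future__ import annotations

from abc import ABC, abstractmethod

class TaskRunner(ABC):
    """Common lifecycle contract for background task runners."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Stable identifier used for registration and status reporting."""

    @abstractmethod
    def start(self) -> None:
        """Begin background processing."""

    @abstractmethod
    def stop(self, timeout: float | None = None) -> None:
        """Stop processing, waiting up to timeout seconds when given."""

    @property
    @abstractmethod
    def is_running(self) -> bool:
        """Whether the runner is currently active."""

    @abstractmethod
    def get_status(self) -> TaskStatus:
        """Return a point-in-time TaskStatus snapshot."""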
class TestTaskManager:
"""Tests for TaskManager facade."""
def test_register_runner(self) -> None:
"""Can register a task runner."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
@property
def name(self) -> str:
return "mock"
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("mock", False, 0, 0)
manager = TaskManager()
runner = MockRunner()
manager.register(runner)
assert manager.get_runner("mock") is runner
def test_get_runner_returns_none_for_unknown(self) -> None:
"""get_runner returns None for unknown runner name."""
from inference.web.core.task_interface import TaskManager
manager = TaskManager()
assert manager.get_runner("unknown") is None
def test_start_all_runners(self) -> None:
"""start_all starts all registered runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
self._name = runner_name
self._running = False
@property
def name(self) -> str:
return self._name
def start(self) -> None:
self._running = True
def stop(self, timeout: float | None = None) -> None:
self._running = False
@property
def is_running(self) -> bool:
return self._running
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, self._running, 0, 0)
manager = TaskManager()
runner1 = MockRunner("runner1")
runner2 = MockRunner("runner2")
manager.register(runner1)
manager.register(runner2)
assert runner1.is_running is False
assert runner2.is_running is False
manager.start_all()
assert runner1.is_running is True
assert runner2.is_running is True
def test_stop_all_runners(self) -> None:
"""stop_all stops all registered runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
self._name = runner_name
self._running = True
@property
def name(self) -> str:
return self._name
def start(self) -> None:
self._running = True
def stop(self, timeout: float | None = None) -> None:
self._running = False
@property
def is_running(self) -> bool:
return self._running
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, self._running, 0, 0)
manager = TaskManager()
runner1 = MockRunner("runner1")
runner2 = MockRunner("runner2")
manager.register(runner1)
manager.register(runner2)
assert runner1.is_running is True
assert runner2.is_running is True
manager.stop_all()
assert runner1.is_running is False
assert runner2.is_running is False
def test_get_all_status(self) -> None:
"""get_all_status returns status of all runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str, pending: int) -> None:
self._name = runner_name
self._pending = pending
@property
def name(self) -> str:
return self._name
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return True
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, True, self._pending, 0)
manager = TaskManager()
manager.register(MockRunner("runner1", 5))
manager.register(MockRunner("runner2", 10))
all_status = manager.get_all_status()
assert len(all_status) == 2
assert all_status["runner1"].pending_count == 5
assert all_status["runner2"].pending_count == 10
def test_get_all_status_empty_when_no_runners(self) -> None:
"""get_all_status returns empty dict when no runners registered."""
from inference.web.core.task_interface import TaskManager
manager = TaskManager()
assert manager.get_all_status() == {}
def test_runner_names_property(self) -> None:
"""runner_names returns list of all registered runner names."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
self._name = runner_name
@property
def name(self) -> str:
return self._name
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, False, 0, 0)
manager = TaskManager()
manager.register(MockRunner("alpha"))
manager.register(MockRunner("beta"))
names = manager.runner_names
assert set(names) == {"alpha", "beta"}
def test_stop_all_with_timeout_distribution(self) -> None:
"""stop_all distributes timeout across runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
received_timeouts: list[float | None] = []
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
self._name = runner_name
@property
def name(self) -> str:
return self._name
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
received_timeouts.append(timeout)
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, False, 0, 0)
manager = TaskManager()
manager.register(MockRunner("r1"))
manager.register(MockRunner("r2"))
manager.stop_all(timeout=20.0)
# Timeout should be distributed (20 / 2 = 10 each)
assert len(received_timeouts) == 2
assert all(t == 10.0 for t in received_timeouts)
def test_start_all_skips_runners_requiring_arguments(self) -> None:
"""start_all skips runners that require arguments."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
no_args_started = []
with_args_started = []
class NoArgsRunner(TaskRunner):
@property
def name(self) -> str:
return "no_args"
def start(self) -> None:
no_args_started.append(True)
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("no_args", False, 0, 0)
class RequiresArgsRunner(TaskRunner):
@property
def name(self) -> str:
return "requires_args"
def start(self, handler: object) -> None: # type: ignore[override]
# This runner requires an argument
with_args_started.append(True)
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("requires_args", False, 0, 0)
manager = TaskManager()
manager.register(NoArgsRunner())
manager.register(RequiresArgsRunner())
# start_all should start no_args runner but skip requires_args
manager.start_all()
assert len(no_args_started) == 1
assert len(with_args_started) == 0 # Skipped due to TypeError
def test_stop_all_with_no_runners(self) -> None:
"""stop_all does nothing when no runners registered."""
from inference.web.core.task_interface import TaskManager
manager = TaskManager()
# Should not raise any exception
manager.stop_all()
# Just verify it returns without error
assert manager.runner_names == []
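A sketch of the TaskManager facade these tests exercise, assuming an even timeout split and a TypeError-based skip in start_all (the comment above attributes the skip to TypeError); the real module could instead inspect signatures.

# Hypothetical sketch, not part of this commit; TaskRunner and TaskStatus are
# the interfaces sketched above.
from __future__ import annotations

class TaskManager:
    """Registry and lifecycle facade over multiple TaskRunner instances."""

    def __init__(self) -> None:
        self._runners: dict[str, TaskRunner] = {}

    def register(self, runner: TaskRunner) -> None:
        self._runners[runner.name] = runner

    def get_runner(self, name: str) -> TaskRunner | None:
        return self._runners.get(name)

    @property
    def runner_names(self) -> list[str]:
        return list(self._runners)

    def start_all(self) -> None:
        for runner in self._runners.values():
            try:
                runner.start()
            except TypeError:
                # Runners whose start() needs extra arguments are skipped and
                # must be started explicitly by their owner.
                continue

    def stop_all(self, timeout: float | None = None) -> None:
        if not self._runners:
            return
        # Split the overall budget evenly, e.g. 20s across 2 runners -> 10s each.
        per_runner = timeout / len(self._runners) if timeout is not None else None
        for runner in self._runners.values():
            runner.stop(timeout=per_runner)

    def get_all_status(self) -> dict[str, TaskStatus]:
        return {name: r.get_status() for name, r in self._runners.items()}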
class TestTrainingSchedulerInterface:
"""Tests for TrainingScheduler implementing TaskRunner."""
def test_training_scheduler_is_task_runner(self) -> None:
"""TrainingScheduler inherits from TaskRunner."""
from inference.web.core.scheduler import TrainingScheduler
from inference.web.core.task_interface import TaskRunner
scheduler = TrainingScheduler()
assert isinstance(scheduler, TaskRunner)
def test_training_scheduler_name(self) -> None:
"""TrainingScheduler has correct name."""
from inference.web.core.scheduler import TrainingScheduler
scheduler = TrainingScheduler()
assert scheduler.name == "training_scheduler"
def test_training_scheduler_get_status(self) -> None:
"""TrainingScheduler provides status via get_status."""
from inference.web.core.scheduler import TrainingScheduler
from inference.web.core.task_interface import TaskStatus
scheduler = TrainingScheduler()
# Mock the training tasks repository
mock_tasks = MagicMock()
mock_tasks.get_pending.return_value = [MagicMock(), MagicMock()]
scheduler._training_tasks = mock_tasks
status = scheduler.get_status()
assert isinstance(status, TaskStatus)
assert status.name == "training_scheduler"
assert status.is_running is False
assert status.pending_count == 2
class TestAutoLabelSchedulerInterface:
"""Tests for AutoLabelScheduler implementing TaskRunner."""
def test_autolabel_scheduler_is_task_runner(self) -> None:
"""AutoLabelScheduler inherits from TaskRunner."""
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
from inference.web.core.task_interface import TaskRunner
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
scheduler = AutoLabelScheduler()
assert isinstance(scheduler, TaskRunner)
def test_autolabel_scheduler_name(self) -> None:
"""AutoLabelScheduler has correct name."""
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
scheduler = AutoLabelScheduler()
assert scheduler.name == "autolabel_scheduler"
def test_autolabel_scheduler_get_status(self) -> None:
"""AutoLabelScheduler provides status via get_status."""
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
from inference.web.core.task_interface import TaskStatus
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
with patch(
"inference.web.core.autolabel_scheduler.get_pending_autolabel_documents"
) as mock_get:
mock_get.return_value = [MagicMock(), MagicMock(), MagicMock()]
scheduler = AutoLabelScheduler()
status = scheduler.get_status()
assert isinstance(status, TaskStatus)
assert status.name == "autolabel_scheduler"
assert status.is_running is False
assert status.pending_count == 3
class TestAsyncTaskQueueInterface:
"""Tests for AsyncTaskQueue implementing TaskRunner."""
def test_async_queue_is_task_runner(self) -> None:
"""AsyncTaskQueue inherits from TaskRunner."""
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.core.task_interface import TaskRunner
queue = AsyncTaskQueue()
assert isinstance(queue, TaskRunner)
def test_async_queue_name(self) -> None:
"""AsyncTaskQueue has correct name."""
from inference.web.workers.async_queue import AsyncTaskQueue
queue = AsyncTaskQueue()
assert queue.name == "async_task_queue"
def test_async_queue_get_status(self) -> None:
"""AsyncTaskQueue provides status via get_status."""
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.core.task_interface import TaskStatus
queue = AsyncTaskQueue()
status = queue.get_status()
assert isinstance(status, TaskStatus)
assert status.name == "async_task_queue"
assert status.is_running is False
assert status.pending_count == 0
assert status.processing_count == 0
class TestBatchTaskQueueInterface:
"""Tests for BatchTaskQueue implementing TaskRunner."""
def test_batch_queue_is_task_runner(self) -> None:
"""BatchTaskQueue inherits from TaskRunner."""
from inference.web.workers.batch_queue import BatchTaskQueue
from inference.web.core.task_interface import TaskRunner
queue = BatchTaskQueue()
assert isinstance(queue, TaskRunner)
def test_batch_queue_name(self) -> None:
"""BatchTaskQueue has correct name."""
from inference.web.workers.batch_queue import BatchTaskQueue
queue = BatchTaskQueue()
assert queue.name == "batch_task_queue"
def test_batch_queue_get_status(self) -> None:
"""BatchTaskQueue provides status via get_status."""
from inference.web.workers.batch_queue import BatchTaskQueue
from inference.web.core.task_interface import TaskStatus
queue = BatchTaskQueue()
status = queue.get_status()
assert isinstance(status, TaskStatus)
assert status.name == "batch_task_queue"
assert status.is_running is False
assert status.pending_count == 0
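Taken together, the interface tests above suggest a wiring pattern like the following at application startup. This is only an illustration: it uses the import paths and no-argument constructors from the tests and omits AutoLabelScheduler, whose constructor needs storage patched in these tests; the application's actual startup code may differ.

# Hypothetical usage sketch, not part of this commit.
from inference.web.core.scheduler import TrainingScheduler
from inference.web.core.task_interface import TaskManager
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.workers.batch_queue import BatchTaskQueue

manager = TaskManager()
for runner in (TrainingScheduler(), AsyncTaskQueue(), BatchTaskQueue()):
    manager.register(runner)

manager.start_all()                   # starts every runner whose start() takes no arguments
statuses = manager.get_all_status()   # {"training_scheduler": TaskStatus(...), ...}
manager.stop_all(timeout=30.0)        # stop budget shared evenly across runners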


@@ -8,80 +8,80 @@ from unittest.mock import MagicMock, patch
from fastapi import HTTPException
from inference.data.admin_db import AdminDB
from inference.data.repositories import TokenRepository
from inference.data.admin_models import AdminToken
from inference.web.core.auth import (
get_admin_db,
reset_admin_db,
get_token_repository,
reset_token_repository,
validate_admin_token,
)
@pytest.fixture
def mock_admin_db():
"""Create a mock AdminDB."""
db = MagicMock(spec=AdminDB)
db.is_valid_admin_token.return_value = True
return db
def mock_token_repo():
"""Create a mock TokenRepository."""
repo = MagicMock(spec=TokenRepository)
repo.is_valid.return_value = True
return repo
@pytest.fixture(autouse=True)
def reset_db():
"""Reset admin DB after each test."""
def reset_repo():
"""Reset token repository after each test."""
yield
reset_admin_db()
reset_token_repository()
class TestValidateAdminToken:
"""Tests for validate_admin_token dependency."""
def test_missing_token_raises_401(self, mock_admin_db):
def test_missing_token_raises_401(self, mock_token_repo):
"""Test that missing token raises 401."""
import asyncio
with pytest.raises(HTTPException) as exc_info:
asyncio.get_event_loop().run_until_complete(
validate_admin_token(None, mock_admin_db)
validate_admin_token(None, mock_token_repo)
)
assert exc_info.value.status_code == 401
assert "Admin token required" in exc_info.value.detail
def test_invalid_token_raises_401(self, mock_admin_db):
def test_invalid_token_raises_401(self, mock_token_repo):
"""Test that invalid token raises 401."""
import asyncio
mock_admin_db.is_valid_admin_token.return_value = False
mock_token_repo.is_valid.return_value = False
with pytest.raises(HTTPException) as exc_info:
asyncio.get_event_loop().run_until_complete(
validate_admin_token("invalid-token", mock_admin_db)
validate_admin_token("invalid-token", mock_token_repo)
)
assert exc_info.value.status_code == 401
assert "Invalid or expired" in exc_info.value.detail
def test_valid_token_returns_token(self, mock_admin_db):
def test_valid_token_returns_token(self, mock_token_repo):
"""Test that valid token is returned."""
import asyncio
token = "valid-test-token"
mock_admin_db.is_valid_admin_token.return_value = True
mock_token_repo.is_valid.return_value = True
result = asyncio.get_event_loop().run_until_complete(
validate_admin_token(token, mock_admin_db)
validate_admin_token(token, mock_token_repo)
)
assert result == token
mock_admin_db.update_admin_token_usage.assert_called_once_with(token)
mock_token_repo.update_usage.assert_called_once_with(token)
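A sketch of the validate_admin_token dependency these tests describe: missing or invalid tokens raise 401 with the messages asserted above, and a valid token is recorded via update_usage and returned. The X-Admin-Token header name and the Depends wiring are assumptions; the tests only pin the (token, repository) argument order.

# Hypothetical sketch, not part of this commit.
from fastapi import Depends, Header, HTTPException

from inference.data.repositories import TokenRepository
from inference.web.core.auth import get_token_repository

async def validate_admin_token(
    x_admin_token: str | None = Header(default=None),
    token_repo: TokenRepository = Depends(get_token_repository),
) -> str:
    if not x_admin_token:
        raise HTTPException(status_code=401, detail="Admin token required")
    if not token_repo.is_valid(x_admin_token):
        raise HTTPException(status_code=401, detail="Invalid or expired admin token")
    token_repo.update_usage(x_admin_token)  # record token usage, as asserted above
    return x_admin_token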
class TestAdminDB:
"""Tests for AdminDB operations."""
class TestTokenRepository:
"""Tests for TokenRepository operations."""
def test_is_valid_admin_token_active(self):
def test_is_valid_active_token(self):
"""Test valid active token."""
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -93,12 +93,12 @@ class TestAdminDB:
)
mock_session.get.return_value = mock_token
db = AdminDB()
assert db.is_valid_admin_token("test-token") is True
repo = TokenRepository()
assert repo.is_valid("test-token") is True
def test_is_valid_admin_token_inactive(self):
def test_is_valid_inactive_token(self):
"""Test inactive token."""
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -110,12 +110,12 @@ class TestAdminDB:
)
mock_session.get.return_value = mock_token
db = AdminDB()
assert db.is_valid_admin_token("test-token") is False
repo = TokenRepository()
assert repo.is_valid("test-token") is False
def test_is_valid_admin_token_expired(self):
def test_is_valid_expired_token(self):
"""Test expired token."""
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -127,36 +127,38 @@ class TestAdminDB:
)
mock_session.get.return_value = mock_token
db = AdminDB()
assert db.is_valid_admin_token("test-token") is False
repo = TokenRepository()
# Also mock _now() so the expiry comparison uses a known current time
with patch.object(repo, "_now", return_value=datetime.utcnow()):
assert repo.is_valid("test-token") is False
def test_is_valid_admin_token_not_found(self):
def test_is_valid_token_not_found(self):
"""Test token not found."""
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
mock_session.get.return_value = None
db = AdminDB()
assert db.is_valid_admin_token("nonexistent") is False
repo = TokenRepository()
assert repo.is_valid("nonexistent") is False
class TestGetAdminDb:
"""Tests for get_admin_db function."""
class TestGetTokenRepository:
"""Tests for get_token_repository function."""
def test_returns_singleton(self):
"""Test that get_admin_db returns singleton."""
reset_admin_db()
"""Test that get_token_repository returns singleton."""
reset_token_repository()
db1 = get_admin_db()
db2 = get_admin_db()
repo1 = get_token_repository()
repo2 = get_token_repository()
assert db1 is db2
assert repo1 is repo2
def test_reset_clears_singleton(self):
"""Test that reset clears singleton."""
db1 = get_admin_db()
reset_admin_db()
db2 = get_admin_db()
repo1 = get_token_repository()
reset_token_repository()
repo2 = get_token_repository()
assert db1 is not db2
assert repo1 is not repo2
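The singleton tests above imply module-level accessors along these lines; a minimal sketch, assuming a no-argument TokenRepository constructor as used in the repository tests.

# Hypothetical sketch, not part of this commit.
from inference.data.repositories import TokenRepository

_token_repository: TokenRepository | None = None

def get_token_repository() -> TokenRepository:
    """Return a process-wide TokenRepository, creating it on first use."""
    global _token_repository
    if _token_repository is None:
        _token_repository = TokenRepository()
    return _token_repository

def reset_token_repository() -> None:
    """Drop the cached instance so tests get a fresh repository."""
    global _token_repository
    _token_repository = None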


@@ -11,7 +11,12 @@ from fastapi.testclient import TestClient
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.config import StorageConfig
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.core.auth import (
validate_admin_token,
get_document_repository,
get_annotation_repository,
get_training_task_repository,
)
class MockAdminDocument:
@@ -59,14 +64,14 @@ class MockAnnotation:
self.created_at = kwargs.get('created_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing enhanced features."""
class MockDocumentRepository:
"""Mock DocumentRepository for testing enhanced features."""
def __init__(self):
self.documents = {}
self.annotations = {}
self.annotations = {} # Shared reference for filtering
def get_documents_by_token(
def get_paginated(
self,
admin_token=None,
status=None,
@@ -103,32 +108,51 @@ class MockAdminDB:
total = len(docs)
return docs[offset:offset+limit], total
def get_annotations_for_document(self, document_id):
"""Get annotations for document."""
return self.annotations.get(str(document_id), [])
def count_documents_by_status(self, admin_token):
def count_by_status(self, admin_token=None):
"""Count documents by status."""
counts = {}
for doc in self.documents.values():
if doc.admin_token == admin_token:
if admin_token is None or doc.admin_token == admin_token:
counts[doc.status] = counts.get(doc.status, 0) + 1
return counts
def get_document_by_token(self, document_id, admin_token):
def get(self, document_id):
"""Get single document by ID."""
return self.documents.get(document_id)
def get_by_token(self, document_id, admin_token=None):
"""Get single document by ID and token."""
doc = self.documents.get(document_id)
if doc and doc.admin_token == admin_token:
if doc and (admin_token is None or doc.admin_token == admin_token):
return doc
return None
class MockAnnotationRepository:
"""Mock AnnotationRepository for testing enhanced features."""
def __init__(self):
self.annotations = {}
def get_for_document(self, document_id, page_number=None):
"""Get annotations for document."""
return self.annotations.get(str(document_id), [])
class MockTrainingTaskRepository:
"""Mock TrainingTaskRepository for testing enhanced features."""
def __init__(self):
self.training_tasks = {}
self.training_links = {}
def get_document_training_tasks(self, document_id):
"""Get training tasks that used this document."""
return [] # No training history in this test
return self.training_links.get(str(document_id), [])
def get_training_task(self, task_id):
def get(self, task_id):
"""Get training task by ID."""
return None # No training tasks in this test
return self.training_tasks.get(str(task_id))
@pytest.fixture
@@ -136,8 +160,10 @@ def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Create mock repositories
mock_document_repo = MockDocumentRepository()
mock_annotation_repo = MockAnnotationRepository()
mock_training_task_repo = MockTrainingTaskRepository()
# Add test documents
doc1 = MockAdminDocument(
@@ -162,19 +188,19 @@ def app():
batch_id=None
)
mock_db.documents[str(doc1.document_id)] = doc1
mock_db.documents[str(doc2.document_id)] = doc2
mock_db.documents[str(doc3.document_id)] = doc3
mock_document_repo.documents[str(doc1.document_id)] = doc1
mock_document_repo.documents[str(doc2.document_id)] = doc2
mock_document_repo.documents[str(doc3.document_id)] = doc3
# Add annotations to doc1 and doc2
mock_db.annotations[str(doc1.document_id)] = [
mock_annotation_repo.annotations[str(doc1.document_id)] = [
MockAnnotation(
document_id=doc1.document_id,
class_name="invoice_number",
text_value="INV-001"
)
]
mock_db.annotations[str(doc2.document_id)] = [
mock_annotation_repo.annotations[str(doc2.document_id)] = [
MockAnnotation(
document_id=doc2.document_id,
class_id=6,
@@ -189,9 +215,14 @@ def app():
)
]
# Share annotation data with document repo for filtering
mock_document_repo.annotations = mock_annotation_repo.annotations
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo
app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo
# Include router
router = create_documents_router(StorageConfig())


@@ -10,7 +10,10 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.locks import create_locks_router
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.core.auth import (
validate_admin_token,
get_document_repository,
)
class MockAdminDocument:
@@ -34,23 +37,27 @@ class MockAdminDocument:
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing annotation locks."""
class MockDocumentRepository:
"""Mock DocumentRepository for testing annotation locks."""
def __init__(self):
self.documents = {}
def get_document_by_token(self, document_id, admin_token):
def get(self, document_id):
"""Get single document by ID."""
return self.documents.get(document_id)
def get_by_token(self, document_id, admin_token=None):
"""Get single document by ID and token."""
doc = self.documents.get(document_id)
if doc and doc.admin_token == admin_token:
if doc and (admin_token is None or doc.admin_token == admin_token):
return doc
return None
def acquire_annotation_lock(self, document_id, admin_token, duration_seconds=300):
def acquire_annotation_lock(self, document_id, admin_token=None, duration_seconds=300):
"""Acquire annotation lock for a document."""
doc = self.documents.get(document_id)
if not doc or doc.admin_token != admin_token:
if not doc:
return None
# Check if already locked
@@ -62,20 +69,20 @@ class MockAdminDB:
doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
return doc
def release_annotation_lock(self, document_id, admin_token, force=False):
def release_annotation_lock(self, document_id, admin_token=None, force=False):
"""Release annotation lock for a document."""
doc = self.documents.get(document_id)
if not doc or doc.admin_token != admin_token:
if not doc:
return None
# Release lock
doc.annotation_lock_until = None
return doc
def extend_annotation_lock(self, document_id, admin_token, additional_seconds=300):
def extend_annotation_lock(self, document_id, admin_token=None, additional_seconds=300):
"""Extend an existing annotation lock."""
doc = self.documents.get(document_id)
if not doc or doc.admin_token != admin_token:
if not doc:
return None
# Check if lock exists and is still valid
@@ -93,8 +100,8 @@ def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Create mock repository
mock_document_repo = MockDocumentRepository()
# Add test document
doc1 = MockAdminDocument(
@@ -103,11 +110,11 @@ def app():
upload_source="ui",
)
mock_db.documents[str(doc1.document_id)] = doc1
mock_document_repo.documents[str(doc1.document_id)] = doc1
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
# Include router
router = create_locks_router()
@@ -124,9 +131,9 @@ def client(app):
@pytest.fixture
def document_id(app):
"""Get document ID from the mock DB."""
mock_db = app.dependency_overrides[get_admin_db]()
return str(list(mock_db.documents.keys())[0])
"""Get document ID from the mock repository."""
mock_document_repo = app.dependency_overrides[get_document_repository]()
return str(list(mock_document_repo.documents.keys())[0])
class TestAnnotationLocks:


@@ -9,8 +9,12 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.annotations import create_annotation_router
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.admin.annotations import (
create_annotation_router,
get_doc_repository,
get_ann_repository,
)
from inference.web.core.auth import validate_admin_token
class MockAdminDocument:
@@ -73,22 +77,40 @@ class MockAnnotationHistory:
self.created_at = kwargs.get('created_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing Phase 5."""
class MockDocumentRepository:
"""Mock DocumentRepository for testing Phase 5."""
def __init__(self):
self.documents = {}
self.annotations = {}
self.annotation_history = {}
def get_document_by_token(self, document_id, admin_token):
def get(self, document_id):
"""Get document by ID."""
return self.documents.get(str(document_id))
def get_by_token(self, document_id, admin_token=None):
"""Get document by ID and token."""
doc = self.documents.get(str(document_id))
if doc and doc.admin_token == admin_token:
if doc and (admin_token is None or doc.admin_token == admin_token):
return doc
return None
def verify_annotation(self, annotation_id, admin_token):
class MockAnnotationRepository:
"""Mock AnnotationRepository for testing Phase 5."""
def __init__(self):
self.annotations = {}
self.annotation_history = {}
def get(self, annotation_id):
"""Get annotation by ID."""
return self.annotations.get(str(annotation_id))
def get_for_document(self, document_id, page_number=None):
"""Get annotations for a document."""
return [a for a in self.annotations.values() if str(a.document_id) == str(document_id)]
def verify(self, annotation_id, admin_token):
"""Mark annotation as verified."""
annotation = self.annotations.get(str(annotation_id))
if annotation:
@@ -98,7 +120,7 @@ class MockAdminDB:
return annotation
return None
def override_annotation(
def override(
self,
annotation_id,
admin_token,
@@ -131,7 +153,7 @@ class MockAdminDB:
return annotation
return None
def get_annotation_history(self, annotation_id):
def get_history(self, annotation_id):
"""Get annotation history."""
return self.annotation_history.get(str(annotation_id), [])
@@ -141,15 +163,16 @@ def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Create mock repositories
mock_document_repo = MockDocumentRepository()
mock_annotation_repo = MockAnnotationRepository()
# Add test document
doc1 = MockAdminDocument(
filename="TEST001.pdf",
status="labeled",
)
mock_db.documents[str(doc1.document_id)] = doc1
mock_document_repo.documents[str(doc1.document_id)] = doc1
# Add test annotations
ann1 = MockAnnotation(
@@ -169,8 +192,8 @@ def app():
confidence=0.98,
)
mock_db.annotations[str(ann1.annotation_id)] = ann1
mock_db.annotations[str(ann2.annotation_id)] = ann2
mock_annotation_repo.annotations[str(ann1.annotation_id)] = ann1
mock_annotation_repo.annotations[str(ann2.annotation_id)] = ann2
# Store document ID and annotation IDs for tests
app.state.document_id = str(doc1.document_id)
@@ -179,7 +202,8 @@ def app():
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
app.dependency_overrides[get_doc_repository] = lambda: mock_document_repo
app.dependency_overrides[get_ann_repository] = lambda: mock_annotation_repo
# Include router
router = create_annotation_router()


@@ -11,7 +11,11 @@ from fastapi.testclient import TestClient
import numpy as np
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.core.auth import (
validate_admin_token,
get_document_repository,
get_dataset_repository,
)
TEST_ADMIN_TOKEN = "test-admin-token-12345"
@@ -26,18 +30,27 @@ def admin_token() -> str:
@pytest.fixture
def mock_admin_db() -> MagicMock:
"""Create a mock AdminDB for testing."""
def mock_document_repo() -> MagicMock:
"""Create a mock DocumentRepository for testing."""
mock = MagicMock()
# Default return values
mock.get_document_by_token.return_value = None
mock.get_dataset.return_value = None
mock.get_augmented_datasets.return_value = ([], 0)
mock.get.return_value = None
mock.get_by_token.return_value = None
return mock
@pytest.fixture
def admin_client(mock_admin_db: MagicMock) -> TestClient:
def mock_dataset_repo() -> MagicMock:
"""Create a mock DatasetRepository for testing."""
mock = MagicMock()
# Default return values
mock.get.return_value = None
mock.get_paginated.return_value = ([], 0)
return mock
@pytest.fixture
def admin_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
"""Create test client with admin authentication."""
app = FastAPI()
@@ -45,11 +58,15 @@ def admin_client(mock_admin_db: MagicMock) -> TestClient:
def get_token_override():
return TEST_ADMIN_TOKEN
def get_db_override():
return mock_admin_db
def get_document_repo_override():
return mock_document_repo
def get_dataset_repo_override():
return mock_dataset_repo
app.dependency_overrides[validate_admin_token] = get_token_override
app.dependency_overrides[get_admin_db] = get_db_override
app.dependency_overrides[get_document_repository] = get_document_repo_override
app.dependency_overrides[get_dataset_repository] = get_dataset_repo_override
# Include router - the router already has /augmentation prefix
# so we add /api/v1/admin to get /api/v1/admin/augmentation
@@ -60,15 +77,19 @@ def admin_client(mock_admin_db: MagicMock) -> TestClient:
@pytest.fixture
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
def unauthenticated_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
"""Create test client WITHOUT admin authentication override."""
app = FastAPI()
# Only override the database, NOT the token validation
def get_db_override():
return mock_admin_db
# Only override the repositories, NOT the token validation
def get_document_repo_override():
return mock_document_repo
app.dependency_overrides[get_admin_db] = get_db_override
def get_dataset_repo_override():
return mock_dataset_repo
app.dependency_overrides[get_document_repository] = get_document_repo_override
app.dependency_overrides[get_dataset_repository] = get_dataset_repo_override
router = create_augmentation_router()
app.include_router(router, prefix="/api/v1/admin")
@@ -142,13 +163,13 @@ class TestAugmentationPreviewEndpoint:
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
mock_admin_db: MagicMock,
mock_document_repo: MagicMock,
) -> None:
"""Test previewing augmentation on a document."""
# Mock document exists
mock_document = MagicMock()
mock_document.images_dir = "/fake/path"
mock_admin_db.get_document.return_value = mock_document
mock_document_repo.get.return_value = mock_document
# Create a fake image (100x100 RGB)
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
@@ -218,13 +239,13 @@ class TestAugmentationPreviewConfigEndpoint:
admin_client: TestClient,
admin_token: str,
sample_document_id: str,
mock_admin_db: MagicMock,
mock_document_repo: MagicMock,
) -> None:
"""Test previewing full config on a document."""
# Mock document exists
mock_document = MagicMock()
mock_document.images_dir = "/fake/path"
mock_admin_db.get_document.return_value = mock_document
mock_document_repo.get.return_value = mock_document
# Create a fake image (100x100 RGB)
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
@@ -260,13 +281,13 @@ class TestAugmentationBatchEndpoint:
admin_client: TestClient,
admin_token: str,
sample_dataset_id: str,
mock_admin_db: MagicMock,
mock_dataset_repo: MagicMock,
) -> None:
"""Test creating augmented dataset."""
# Mock dataset exists
mock_dataset = MagicMock()
mock_dataset.total_images = 100
mock_admin_db.get_dataset.return_value = mock_dataset
mock_dataset_repo.get.return_value = mock_dataset
response = admin_client.post(
"/api/v1/admin/augmentation/batch",


@@ -9,7 +9,6 @@ from unittest.mock import Mock, MagicMock
from uuid import uuid4
from inference.web.services.autolabel import AutoLabelService
from inference.data.admin_db import AdminDB
class MockDocument:
@@ -23,19 +22,18 @@ class MockDocument:
self.auto_label_error = None
class MockAdminDB:
"""Mock AdminDB for testing."""
class MockDocumentRepository:
"""Mock DocumentRepository for testing."""
def __init__(self):
self.documents = {}
self.annotations = []
self.status_updates = []
def get_document(self, document_id):
def get(self, document_id):
"""Get document by ID."""
return self.documents.get(str(document_id))
def update_document_status(
def update_status(
self,
document_id,
status=None,
@@ -58,19 +56,32 @@ class MockAdminDB:
if auto_label_error:
doc.auto_label_error = auto_label_error
def delete_annotations_for_document(self, document_id, source=None):
class MockAnnotationRepository:
"""Mock AnnotationRepository for testing."""
def __init__(self):
self.annotations = []
def delete_for_document(self, document_id, source=None):
"""Mock delete annotations."""
return 0
def create_annotations_batch(self, annotations):
def create_batch(self, annotations):
"""Mock create annotations."""
self.annotations.extend(annotations)
@pytest.fixture
def mock_db():
"""Create mock admin DB."""
return MockAdminDB()
def mock_doc_repo():
"""Create mock document repository."""
return MockDocumentRepository()
@pytest.fixture
def mock_ann_repo():
"""Create mock annotation repository."""
return MockAnnotationRepository()
@pytest.fixture
@@ -82,10 +93,14 @@ def auto_label_service(monkeypatch):
service._ocr_engine.extract_from_image = Mock(return_value=[])
# Mock the image processing methods to avoid file I/O errors
def mock_process_image(self, document_id, image_path, field_values, db, page_number=1):
def mock_process_image(self, document_id, image_path, field_values, ann_repo, page_number=1):
return 0 # No annotations created (mocked)
def mock_process_pdf(self, document_id, pdf_path, field_values, ann_repo):
return 0 # No annotations created (mocked)
monkeypatch.setattr(AutoLabelService, "_process_image", mock_process_image)
monkeypatch.setattr(AutoLabelService, "_process_pdf", mock_process_pdf)
return service
@@ -93,11 +108,11 @@ def auto_label_service(monkeypatch):
class TestAutoLabelWithLocks:
"""Tests for auto-label service with lock integration."""
def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling succeeds on unlocked document."""
# Create test document (unlocked)
document_id = str(uuid4())
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=None,
)
@@ -111,21 +126,22 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
)
# Should succeed
assert result["status"] == "completed"
# Verify status was updated to running and then completed
assert len(mock_db.status_updates) >= 2
assert mock_db.status_updates[0]["auto_label_status"] == "running"
assert len(mock_doc_repo.status_updates) >= 2
assert mock_doc_repo.status_updates[0]["auto_label_status"] == "running"
def test_auto_label_locked_document_fails(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_locked_document_fails(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling fails on locked document."""
# Create test document (locked for 1 hour)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
@@ -139,7 +155,8 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
)
# Should fail
@@ -150,15 +167,15 @@ class TestAutoLabelWithLocks:
# Verify status was updated to failed
assert any(
update["auto_label_status"] == "failed"
for update in mock_db.status_updates
for update in mock_doc_repo.status_updates
)
def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling succeeds when lock has expired."""
# Create test document (lock expired 1 hour ago)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) - timedelta(hours=1)
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
@@ -172,18 +189,19 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
)
# Should succeed (lock expired)
assert result["status"] == "completed"
def test_auto_label_skip_lock_check(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_skip_lock_check(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling with skip_lock_check=True bypasses lock."""
# Create test document (locked)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
@@ -197,14 +215,15 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
skip_lock_check=True, # Bypass lock check
)
# Should succeed even though document is locked
assert result["status"] == "completed"
def test_auto_label_document_not_found(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_document_not_found(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling fails when document doesn't exist."""
# Create dummy file
test_file = tmp_path / "test.png"
@@ -215,19 +234,20 @@ class TestAutoLabelWithLocks:
document_id=str(uuid4()),
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
)
# Should fail
assert result["status"] == "failed"
assert "not found" in result["error"]
def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test that lock check is enabled by default."""
# Create test document (locked)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) + timedelta(minutes=30)
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
@@ -241,7 +261,8 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
# skip_lock_check not specified, should default to False
)
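The lock scenarios above (future lock fails, expired lock succeeds, skip_lock_check bypasses) reduce to a comparison against the document's annotation_lock_until. A minimal sketch of that check follows; the helper name is an assumption and the real AutoLabelService may inline the logic.

# Hypothetical sketch, not part of this commit.
from datetime import datetime, timezone

def _is_annotation_locked(document, skip_lock_check: bool = False) -> bool:
    """True when the document holds an unexpired annotation lock."""
    if skip_lock_check:
        return False
    lock_until = getattr(document, "annotation_lock_until", None)
    if lock_until is None:
        return False
    # Lock timestamps in these tests are timezone-aware (datetime.now(timezone.utc)).
    return lock_until > datetime.now(timezone.utc)

A locked document would then short-circuit auto_label to a failed result and a "failed" auto_label_status update, matching the assertions above.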


@@ -11,20 +11,20 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.batch.routes import router
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.api.v1.batch.routes import router, get_batch_repository
from inference.web.core.auth import validate_admin_token
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from inference.web.services.batch_upload import BatchUploadService
class MockAdminDB:
"""Mock AdminDB for testing."""
class MockBatchUploadRepository:
"""Mock BatchUploadRepository for testing."""
def __init__(self):
self.batches = {}
self.batch_files = {}
def create_batch_upload(self, admin_token, filename, file_size, upload_source):
def create(self, admin_token, filename, file_size, upload_source="ui"):
batch_id = uuid4()
batch = type('BatchUpload', (), {
'batch_id': batch_id,
@@ -46,13 +46,13 @@ class MockAdminDB:
self.batches[batch_id] = batch
return batch
def update_batch_upload(self, batch_id, **kwargs):
def update(self, batch_id, **kwargs):
if batch_id in self.batches:
batch = self.batches[batch_id]
for key, value in kwargs.items():
setattr(batch, key, value)
def create_batch_upload_file(self, batch_id, filename, **kwargs):
def create_file(self, batch_id, filename, **kwargs):
file_id = uuid4()
defaults = {
'file_id': file_id,
@@ -70,7 +70,7 @@ class MockAdminDB:
self.batch_files[batch_id].append(file_record)
return file_record
def update_batch_upload_file(self, file_id, **kwargs):
def update_file(self, file_id, **kwargs):
for files in self.batch_files.values():
for file_record in files:
if file_record.file_id == file_id:
@@ -78,7 +78,7 @@ class MockAdminDB:
setattr(file_record, key, value)
return
def get_batch_upload(self, batch_id):
def get(self, batch_id):
return self.batches.get(batch_id, type('BatchUpload', (), {
'batch_id': batch_id,
'admin_token': 'test-token',
@@ -95,12 +95,15 @@ class MockAdminDB:
'completed_at': datetime.utcnow(),
})())
def get_batch_upload_files(self, batch_id):
def get_files(self, batch_id):
return self.batch_files.get(batch_id, [])
def get_batch_uploads_by_token(self, admin_token, limit=50, offset=0):
def get_paginated(self, admin_token=None, limit=50, offset=0):
"""Get batches filtered by admin token with pagination."""
token_batches = [b for b in self.batches.values() if b.admin_token == admin_token]
if admin_token:
token_batches = [b for b in self.batches.values() if b.admin_token == admin_token]
else:
token_batches = list(self.batches.values())
total = len(token_batches)
return token_batches[offset:offset+limit], total
@@ -110,15 +113,15 @@ def app():
"""Create test FastAPI app with mocked dependencies."""
app = FastAPI()
# Create mock admin DB
mock_admin_db = MockAdminDB()
# Create mock batch upload repository
mock_batch_upload_repo = MockBatchUploadRepository()
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_admin_db
app.dependency_overrides[get_batch_repository] = lambda: mock_batch_upload_repo
# Initialize batch queue with mock service
batch_service = BatchUploadService(mock_admin_db)
batch_service = BatchUploadService(mock_batch_upload_repo)
init_batch_queue(batch_service)
app.include_router(router)


@@ -9,19 +9,18 @@ from uuid import uuid4
import pytest
from inference.data.admin_db import AdminDB
from inference.web.services.batch_upload import BatchUploadService
@pytest.fixture
def admin_db():
"""Mock admin database for testing."""
class MockAdminDB:
def batch_repo():
"""Mock batch upload repository for testing."""
class MockBatchUploadRepository:
def __init__(self):
self.batches = {}
self.batch_files = {}
def create_batch_upload(self, admin_token, filename, file_size, upload_source):
def create(self, admin_token, filename, file_size, upload_source):
batch_id = uuid4()
batch = type('BatchUpload', (), {
'batch_id': batch_id,
@@ -43,13 +42,13 @@ def admin_db():
self.batches[batch_id] = batch
return batch
def update_batch_upload(self, batch_id, **kwargs):
def update(self, batch_id, **kwargs):
if batch_id in self.batches:
batch = self.batches[batch_id]
for key, value in kwargs.items():
setattr(batch, key, value)
def create_batch_upload_file(self, batch_id, filename, **kwargs):
def create_file(self, batch_id, filename, **kwargs):
file_id = uuid4()
# Set defaults for attributes
defaults = {
@@ -68,7 +67,7 @@ def admin_db():
self.batch_files[batch_id].append(file_record)
return file_record
def update_batch_upload_file(self, file_id, **kwargs):
def update_file(self, file_id, **kwargs):
for files in self.batch_files.values():
for file_record in files:
if file_record.file_id == file_id:
@@ -76,19 +75,19 @@ def admin_db():
setattr(file_record, key, value)
return
def get_batch_upload(self, batch_id):
def get(self, batch_id):
return self.batches.get(batch_id)
def get_batch_upload_files(self, batch_id):
def get_files(self, batch_id):
return self.batch_files.get(batch_id, [])
return MockAdminDB()
return MockBatchUploadRepository()
@pytest.fixture
def batch_service(admin_db):
def batch_service(batch_repo):
"""Batch upload service instance."""
return BatchUploadService(admin_db)
return BatchUploadService(batch_repo)
def create_test_zip(files):
@@ -194,7 +193,7 @@ INV002,F2024-002,2024-01-16,2500.00,7350087654321,123-4567,C124
assert csv_data["INV001"]["Amount"] == "1500.00"
assert csv_data["INV001"]["customer_number"] == "C123"
def test_get_batch_status(self, batch_service, admin_db):
def test_get_batch_status(self, batch_service, batch_repo):
"""Test getting batch upload status."""
# Create a batch
zip_content = create_test_zip({"INV001.pdf": b"%PDF-1.4 test"})


@@ -16,7 +16,6 @@ from inference.data.admin_models import (
AdminAnnotation,
AdminDocument,
TrainingDataset,
FIELD_CLASSES,
)
@@ -35,10 +34,10 @@ def tmp_admin_images(tmp_path):
@pytest.fixture
def mock_admin_db():
"""Mock AdminDB with dataset and document methods."""
db = MagicMock()
db.create_dataset.return_value = TrainingDataset(
def mock_datasets_repo():
"""Mock DatasetRepository."""
repo = MagicMock()
repo.create.return_value = TrainingDataset(
dataset_id=uuid4(),
name="test-dataset",
status="building",
@@ -46,7 +45,19 @@ def mock_admin_db():
val_ratio=0.1,
seed=42,
)
return db
return repo
@pytest.fixture
def mock_documents_repo():
"""Mock DocumentRepository."""
return MagicMock()
@pytest.fixture
def mock_annotations_repo():
"""Mock AnnotationRepository."""
return MagicMock()
@pytest.fixture
@@ -60,6 +71,7 @@ def sample_documents(tmp_admin_images):
doc.filename = f"{doc_id}.pdf"
doc.page_count = 2
doc.file_path = str(tmp_path / "admin_images" / str(doc_id))
doc.group_key = None # Default to no group
docs.append(doc)
return docs
@@ -89,21 +101,27 @@ class TestDatasetBuilder:
"""Tests for DatasetBuilder."""
def test_build_creates_directory_structure(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
sample_documents, sample_annotations
):
"""Dataset builder should create images/ and labels/ with train/val/test subdirs."""
from inference.web.services.dataset_builder import DatasetBuilder
dataset_dir = tmp_path / "datasets" / "test"
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
# Mock DB calls
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
# Mock repo calls
mock_documents_repo.get_by_ids.return_value = sample_documents
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
@@ -119,18 +137,24 @@ class TestDatasetBuilder:
assert (result_dir / "labels" / split).exists()
def test_build_copies_images(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
sample_documents, sample_annotations
):
"""Images should be copied from admin_images to dataset folder."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
mock_documents_repo.get_by_ids.return_value = sample_documents
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
result = builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
@@ -149,18 +173,24 @@ class TestDatasetBuilder:
assert total_images == 10 # 5 docs * 2 pages
def test_build_generates_yolo_labels(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
sample_documents, sample_annotations
):
"""YOLO label files should be generated with correct format."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
mock_documents_repo.get_by_ids.return_value = sample_documents
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
@@ -187,18 +217,24 @@ class TestDatasetBuilder:
assert 0 <= float(parts[2]) <= 1 # y_center
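For readers unfamiliar with the label format these assertions check, a normalised YOLO label line looks like the sketch below; the helper name and the top-left pixel-box convention are illustrative assumptions, not code from this commit.

def to_yolo_line(class_id: int, x: float, y: float, w: float, h: float,
                 img_w: int, img_h: int) -> str:
    # Convert a top-left pixel box (x, y, w, h) into the normalised YOLO form
    # "<class_id> <x_center> <y_center> <width> <height>", all values in [0, 1].
    x_center = (x + w / 2) / img_w
    y_center = (y + h / 2) / img_h
    return f"{class_id} {x_center:.6f} {y_center:.6f} {w / img_w:.6f} {h / img_h:.6f}"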
def test_build_generates_data_yaml(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
sample_documents, sample_annotations
):
"""data.yaml should be generated with correct field classes."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
mock_documents_repo.get_by_ids.return_value = sample_documents
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
@@ -217,18 +253,24 @@ class TestDatasetBuilder:
assert "invoice_number" in content
def test_build_splits_documents_correctly(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
sample_documents, sample_annotations
):
"""Documents should be split into train/val/test according to ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
mock_documents_repo.get_by_ids.return_value = sample_documents
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
@@ -238,8 +280,8 @@ class TestDatasetBuilder:
admin_images_dir=tmp_path / "admin_images",
)
# Verify add_dataset_documents was called with correct splits
call_args = mock_admin_db.add_dataset_documents.call_args
# Verify add_documents was called with correct splits
call_args = mock_datasets_repo.add_documents.call_args
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
splits = [d["split"] for d in docs_added]
assert "train" in splits
@@ -248,18 +290,24 @@ class TestDatasetBuilder:
assert train_count >= 3 # At least 3 of 5 should be train
def test_build_updates_status_to_ready(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
sample_documents, sample_annotations
):
"""After successful build, dataset status should be updated to 'ready'."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
mock_documents_repo.get_by_ids.return_value = sample_documents
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
sample_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
@@ -269,22 +317,27 @@ class TestDatasetBuilder:
admin_images_dir=tmp_path / "admin_images",
)
mock_admin_db.update_dataset_status.assert_called_once()
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
mock_datasets_repo.update_status.assert_called_once()
call_kwargs = mock_datasets_repo.update_status.call_args[1]
assert call_kwargs["status"] == "ready"
assert call_kwargs["total_documents"] == 5
assert call_kwargs["total_images"] == 10
def test_build_sets_failed_on_error(
self, tmp_path, mock_admin_db
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""If build fails, dataset status should be set to 'failed'."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = [] # No docs found
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
mock_documents_repo.get_by_ids.return_value = [] # No docs found
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
with pytest.raises(ValueError):
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
@@ -295,27 +348,33 @@ class TestDatasetBuilder:
admin_images_dir=tmp_path / "admin_images",
)
mock_admin_db.update_dataset_status.assert_called_once()
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
mock_datasets_repo.update_status.assert_called_once()
call_kwargs = mock_datasets_repo.update_status.call_args[1]
assert call_kwargs["status"] == "failed"
def test_build_with_seed_produces_deterministic_splits(
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
sample_documents, sample_annotations
):
"""Same seed should produce same splits."""
from inference.web.services.dataset_builder import DatasetBuilder
results = []
for _ in range(2):
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = sample_documents
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
mock_documents_repo.get_by_ids.return_value = sample_documents
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
sample_annotations.get(str(doc_id), [])
)
mock_admin_db.add_dataset_documents.reset_mock()
mock_admin_db.update_dataset_status.reset_mock()
mock_datasets_repo.add_documents.reset_mock()
mock_datasets_repo.update_status.reset_mock()
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in sample_documents],
@@ -324,7 +383,7 @@ class TestDatasetBuilder:
seed=42,
admin_images_dir=tmp_path / "admin_images",
)
call_args = mock_admin_db.add_dataset_documents.call_args
call_args = mock_datasets_repo.add_documents.call_args
docs = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
results.append([(d["document_id"], d["split"]) for d in docs])
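Taken together, the build tests above pin down a constructor and call flow roughly like the following sketch. It is illustrative only: the repository method names (get_by_ids, get_for_document, add_documents, update_status) come from the tests, while the private helpers and the exact build_dataset signature are assumptions.

from pathlib import Path


class DatasetBuilderSketch:
    """Sketch only: mirrors the dependency shape and call flow the tests assert."""

    def __init__(self, datasets_repo, documents_repo, annotations_repo, base_dir):
        self.datasets_repo = datasets_repo
        self.documents_repo = documents_repo
        self.annotations_repo = annotations_repo
        self.base_dir = Path(base_dir)

    def build_dataset(self, dataset_id, document_ids, train_ratio=0.7,
                      val_ratio=0.2, seed=42, admin_images_dir=None):
        try:
            docs = self.documents_repo.get_by_ids(document_ids)
            if not docs:
                raise ValueError("No documents found for dataset")
            splits = self._assign_splits_by_group(docs, train_ratio, val_ratio, seed)
            total_images = self._export_images_and_labels(docs, splits, admin_images_dir)
            self.datasets_repo.add_documents(
                dataset_id,
                documents=[
                    {"document_id": str(d.document_id), "split": splits[str(d.document_id)]}
                    for d in docs
                ],
            )
            self.datasets_repo.update_status(
                dataset_id=dataset_id,
                status="ready",
                total_documents=len(docs),
                total_images=total_images,
            )
        except Exception:
            # any failure flips the dataset to 'failed' and re-raises
            self.datasets_repo.update_status(dataset_id=dataset_id, status="failed")
            raise

    def _assign_splits_by_group(self, docs, train_ratio, val_ratio, seed):
        # placeholder; a group-aware version is sketched alongside the split tests below
        return {str(d.document_id): "train" for d in docs}

    def _export_images_and_labels(self, docs, splits, admin_images_dir):
        # placeholder for the image copy plus YOLO label/data.yaml generation;
        # the real code would call self.annotations_repo.get_for_document per page
        return sum(d.page_count for d in docs)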
@@ -342,11 +401,18 @@ class TestAssignSplitsByGroup:
doc.page_count = 1
return doc
def test_single_doc_groups_are_distributed(self, tmp_path, mock_admin_db):
def test_single_doc_groups_are_distributed(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Documents with unique group_key are distributed across splits."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
# 3 documents, each with unique group_key
docs = [
@@ -363,11 +429,18 @@ class TestAssignSplitsByGroup:
assert train_count >= 1
assert val_count >= 1 # Ensure val is not empty
def test_null_group_key_treated_as_single_doc_group(self, tmp_path, mock_admin_db):
def test_null_group_key_treated_as_single_doc_group(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Documents with null/empty group_key are each treated as independent single-doc groups."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
docs = [
self._make_mock_doc(uuid4(), group_key=None),
@@ -384,11 +457,18 @@ class TestAssignSplitsByGroup:
assert train_count >= 1
assert val_count >= 1
def test_multi_doc_groups_stay_together(self, tmp_path, mock_admin_db):
def test_multi_doc_groups_stay_together(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Documents with same group_key should be assigned to the same split."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
# 6 documents in 2 groups
docs = [
@@ -410,11 +490,18 @@ class TestAssignSplitsByGroup:
splits_b = [result[str(d.document_id)] for d in docs[3:]]
assert len(set(splits_b)) == 1, "All docs in supplier-B should be in same split"
def test_multi_doc_groups_split_by_ratio(self, tmp_path, mock_admin_db):
def test_multi_doc_groups_split_by_ratio(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Multi-doc groups should be split according to train/val/test ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
# 10 groups with 2 docs each
docs = []
@@ -445,11 +532,18 @@ class TestAssignSplitsByGroup:
assert split_counts["val"] >= 1
assert split_counts["val"] <= 3
def test_mixed_single_and_multi_doc_groups(self, tmp_path, mock_admin_db):
def test_mixed_single_and_multi_doc_groups(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Mix of single-doc and multi-doc groups should be handled correctly."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
docs = [
# Single-doc groups
@@ -476,11 +570,18 @@ class TestAssignSplitsByGroup:
assert result[str(docs[3].document_id)] == result[str(docs[4].document_id)]
assert result[str(docs[5].document_id)] == result[str(docs[6].document_id)]
def test_deterministic_with_seed(self, tmp_path, mock_admin_db):
def test_deterministic_with_seed(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Same seed should produce same split assignments."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
docs = [
self._make_mock_doc(uuid4(), group_key="group-A"),
@@ -496,11 +597,18 @@ class TestAssignSplitsByGroup:
assert result1 == result2
def test_different_seed_may_produce_different_splits(self, tmp_path, mock_admin_db):
def test_different_seed_may_produce_different_splits(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Different seeds should potentially produce different split assignments."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
# Many groups to increase chance of different results
docs = []
@@ -515,11 +623,18 @@ class TestAssignSplitsByGroup:
# Results should be different (very likely with 20 groups)
assert result1 != result2
def test_all_docs_assigned(self, tmp_path, mock_admin_db):
def test_all_docs_assigned(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Every document should be assigned a split."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
docs = [
self._make_mock_doc(uuid4(), group_key="group-A"),
@@ -535,21 +650,35 @@ class TestAssignSplitsByGroup:
assert str(doc.document_id) in result
assert result[str(doc.document_id)] in ["train", "val", "test"]
def test_empty_documents_list(self, tmp_path, mock_admin_db):
def test_empty_documents_list(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Empty document list should return empty result."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
result = builder._assign_splits_by_group([], train_ratio=0.7, val_ratio=0.2, seed=42)
assert result == {}
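A minimal group-aware assignment that satisfies the behaviour specified in this class could look like the sketch below. It assumes documents expose group_key and document_id attributes; it is not the shipped _assign_splits_by_group, just one seed-deterministic way to keep groups intact while following the ratios.

import random
from collections import defaultdict


def assign_splits_by_group(docs, train_ratio=0.7, val_ratio=0.2, seed=42):
    """Assign each document a split while keeping same-group documents together."""
    groups = defaultdict(list)
    for doc in docs:
        # null/empty group_key behaves as an independent single-document group
        key = doc.group_key or f"__single__{doc.document_id}"
        groups[key].append(doc)

    group_keys = sorted(groups)      # stable ordering before the seeded shuffle
    rng = random.Random(seed)        # same seed => same assignment
    rng.shuffle(group_keys)

    n = len(group_keys)
    n_train = round(n * train_ratio)
    n_val = max(1, round(n * val_ratio)) if n > 1 else 0

    result = {}
    for idx, key in enumerate(group_keys):
        if idx < n_train:
            split = "train"
        elif idx < n_train + n_val:
            split = "val"
        else:
            split = "test"
        for doc in groups[key]:
            result[str(doc.document_id)] = split
    return result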
def test_only_multi_doc_groups(self, tmp_path, mock_admin_db):
def test_only_multi_doc_groups(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""When all groups have multiple docs, splits should follow ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
# 5 groups with 3 docs each
docs = []
@@ -574,11 +703,18 @@ class TestAssignSplitsByGroup:
assert split_counts["train"] >= 2
assert split_counts["train"] <= 4
def test_only_single_doc_groups(self, tmp_path, mock_admin_db):
def test_only_single_doc_groups(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""When all groups have single doc, they are distributed across splits."""
from inference.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
docs = [
self._make_mock_doc(uuid4(), group_key="unique-1"),
@@ -658,20 +794,26 @@ class TestBuildDatasetWithGroupKey:
return annotations
def test_build_respects_group_key_splits(
self, grouped_documents, grouped_annotations, mock_admin_db
self, grouped_documents, grouped_annotations,
mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""build_dataset should use group_key for split assignment."""
from inference.web.services.dataset_builder import DatasetBuilder
tmp_path, docs = grouped_documents
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = docs
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
mock_documents_repo.get_by_ids.return_value = docs
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
grouped_annotations.get(str(doc_id), [])
)
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in docs],
@@ -681,8 +823,8 @@ class TestBuildDatasetWithGroupKey:
admin_images_dir=tmp_path / "admin_images",
)
# Get the document splits from add_dataset_documents call
call_args = mock_admin_db.add_dataset_documents.call_args
# Get the document splits from add_documents call
call_args = mock_datasets_repo.add_documents.call_args
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
# Build mapping of doc_id -> split
@@ -701,7 +843,9 @@ class TestBuildDatasetWithGroupKey:
supplier_b_splits = [doc_split_map[doc_id] for doc_id in supplier_b_ids]
assert len(set(supplier_b_splits)) == 1, "supplier-B docs should be in same split"
def test_build_with_all_same_group_key(self, tmp_path, mock_admin_db):
def test_build_with_all_same_group_key(
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""All docs with same group_key should go to same split."""
from inference.web.services.dataset_builder import DatasetBuilder
@@ -720,11 +864,16 @@ class TestBuildDatasetWithGroupKey:
doc.group_key = "same-group"
docs.append(doc)
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
mock_admin_db.get_documents_by_ids.return_value = docs
mock_admin_db.get_annotations_for_document.return_value = []
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
documents_repo=mock_documents_repo,
annotations_repo=mock_annotations_repo,
base_dir=tmp_path / "datasets",
)
mock_documents_repo.get_by_ids.return_value = docs
mock_annotations_repo.get_for_document.return_value = []
dataset = mock_admin_db.create_dataset.return_value
dataset = mock_datasets_repo.create.return_value
builder.build_dataset(
dataset_id=str(dataset.dataset_id),
document_ids=[str(d.document_id) for d in docs],
@@ -734,7 +883,7 @@ class TestBuildDatasetWithGroupKey:
admin_images_dir=tmp_path / "admin_images",
)
call_args = mock_admin_db.add_dataset_documents.call_args
call_args = mock_datasets_repo.add_documents.call_args
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
splits = [d["split"] for d in docs_added]

View File

@@ -72,6 +72,36 @@ def _find_endpoint(name: str):
raise AssertionError(f"Endpoint {name} not found")
@pytest.fixture
def mock_datasets_repo():
"""Mock DatasetRepository."""
return MagicMock()
@pytest.fixture
def mock_documents_repo():
"""Mock DocumentRepository."""
return MagicMock()
@pytest.fixture
def mock_annotations_repo():
"""Mock AnnotationRepository."""
return MagicMock()
@pytest.fixture
def mock_models_repo():
"""Mock ModelVersionRepository."""
return MagicMock()
@pytest.fixture
def mock_tasks_repo():
"""Mock TrainingTaskRepository."""
return MagicMock()
class TestCreateDatasetRoute:
"""Tests for POST /admin/training/datasets."""
@@ -80,11 +110,12 @@ class TestCreateDatasetRoute:
paths = [route.path for route in router.routes]
assert any("datasets" in p for p in paths)
def test_create_dataset_calls_builder(self):
def test_create_dataset_calls_builder(
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
mock_db.create_dataset.return_value = _make_dataset(status="building")
mock_datasets_repo.create.return_value = _make_dataset(status="building")
mock_builder = MagicMock()
mock_builder.build_dataset.return_value = {
@@ -101,20 +132,30 @@ class TestCreateDatasetRoute:
with patch(
"inference.web.services.dataset_builder.DatasetBuilder",
return_value=mock_builder,
) as mock_cls:
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
), patch(
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
) as mock_storage:
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
result = asyncio.run(fn(
request=request,
admin_token=TEST_TOKEN,
datasets=mock_datasets_repo,
docs=mock_documents_repo,
annotations=mock_annotations_repo,
))
mock_db.create_dataset.assert_called_once()
mock_datasets_repo.create.assert_called_once()
mock_builder.build_dataset.assert_called_once()
assert result.dataset_id == TEST_DATASET_UUID
assert result.name == "test-dataset"
def test_create_dataset_fails_with_less_than_10_documents(self):
def test_create_dataset_fails_with_less_than_10_documents(
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Test that creating dataset fails if fewer than 10 documents provided."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
# Only 2 documents - should fail
request = DatasetCreateRequest(
name="test-dataset",
@@ -124,20 +165,26 @@ class TestCreateDatasetRoute:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
request=request,
admin_token=TEST_TOKEN,
datasets=mock_datasets_repo,
docs=mock_documents_repo,
annotations=mock_annotations_repo,
))
assert exc_info.value.status_code == 400
assert "Minimum 10 documents required" in exc_info.value.detail
assert "got 2" in exc_info.value.detail
# Ensure DB was never called since validation failed first
mock_db.create_dataset.assert_not_called()
# Ensure repo was never called since validation failed first
mock_datasets_repo.create.assert_not_called()
def test_create_dataset_fails_with_9_documents(self):
def test_create_dataset_fails_with_9_documents(
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Test boundary condition: 9 documents should fail."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
# 9 documents - just under the limit
request = DatasetCreateRequest(
name="test-dataset",
@@ -147,17 +194,24 @@ class TestCreateDatasetRoute:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
request=request,
admin_token=TEST_TOKEN,
datasets=mock_datasets_repo,
docs=mock_documents_repo,
annotations=mock_annotations_repo,
))
assert exc_info.value.status_code == 400
assert "Minimum 10 documents required" in exc_info.value.detail
def test_create_dataset_succeeds_with_exactly_10_documents(self):
def test_create_dataset_succeeds_with_exactly_10_documents(
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Test boundary condition: exactly 10 documents should succeed."""
fn = _find_endpoint("create_dataset")
mock_db = MagicMock()
mock_db.create_dataset.return_value = _make_dataset(status="building")
mock_datasets_repo.create.return_value = _make_dataset(status="building")
mock_builder = MagicMock()
@@ -170,25 +224,40 @@ class TestCreateDatasetRoute:
with patch(
"inference.web.services.dataset_builder.DatasetBuilder",
return_value=mock_builder,
):
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
), patch(
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
) as mock_storage:
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
result = asyncio.run(fn(
request=request,
admin_token=TEST_TOKEN,
datasets=mock_datasets_repo,
docs=mock_documents_repo,
annotations=mock_annotations_repo,
))
mock_db.create_dataset.assert_called_once()
mock_datasets_repo.create.assert_called_once()
assert result.dataset_id == TEST_DATASET_UUID
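The three boundary tests above fix the validation order: the document-count check must raise before any repository call is made. A hedged sketch of that guard, with the wording approximated from the assertions rather than taken from the endpoint itself:

from fastapi import HTTPException

MIN_DOCUMENTS = 10


def validate_document_count(document_ids: list[str]) -> None:
    # must run before datasets_repo.create() so validation failures leave no rows
    if len(document_ids) < MIN_DOCUMENTS:
        raise HTTPException(
            status_code=400,
            detail=f"Minimum {MIN_DOCUMENTS} documents required, got {len(document_ids)}",
        )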
class TestListDatasetsRoute:
"""Tests for GET /admin/training/datasets."""
def test_list_datasets(self):
def test_list_datasets(self, mock_datasets_repo):
fn = _find_endpoint("list_datasets")
mock_db = MagicMock()
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
mock_datasets_repo.get_paginated.return_value = ([_make_dataset()], 1)
# Mock the active training tasks lookup to return empty dict
mock_db.get_active_training_tasks_for_datasets.return_value = {}
mock_datasets_repo.get_active_training_tasks.return_value = {}
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
result = asyncio.run(fn(
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
status=None,
limit=20,
offset=0,
))
assert result.total == 1
assert len(result.datasets) == 1
@@ -198,82 +267,103 @@ class TestListDatasetsRoute:
class TestGetDatasetRoute:
"""Tests for GET /admin/training/datasets/{dataset_id}."""
def test_get_dataset_returns_detail(self):
def test_get_dataset_returns_detail(self, mock_datasets_repo):
fn = _find_endpoint("get_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset()
mock_db.get_dataset_documents.return_value = [
mock_datasets_repo.get.return_value = _make_dataset()
mock_datasets_repo.get_documents.return_value = [
_make_dataset_doc(TEST_DOC_UUID_1, "train"),
_make_dataset_doc(TEST_DOC_UUID_2, "val"),
]
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
))
assert result.dataset_id == TEST_DATASET_UUID
assert len(result.documents) == 2
def test_get_dataset_not_found(self):
def test_get_dataset_not_found(self, mock_datasets_repo):
fn = _find_endpoint("get_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = None
mock_datasets_repo.get.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
))
assert exc_info.value.status_code == 404
class TestDeleteDatasetRoute:
"""Tests for DELETE /admin/training/datasets/{dataset_id}."""
def test_delete_dataset(self):
def test_delete_dataset(self, mock_datasets_repo):
fn = _find_endpoint("delete_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(dataset_path=None)
mock_datasets_repo.get.return_value = _make_dataset(dataset_path=None)
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
))
mock_db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID)
mock_datasets_repo.delete.assert_called_once_with(TEST_DATASET_UUID)
assert result["message"] == "Dataset deleted"
class TestTrainFromDatasetRoute:
"""Tests for POST /admin/training/datasets/{dataset_id}/train."""
def test_train_from_ready_dataset(self):
def test_train_from_ready_dataset(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.create_training_task.return_value = TEST_TASK_UUID
mock_datasets_repo.get.return_value = _make_dataset(status="ready")
mock_tasks_repo.create.return_value = TEST_TASK_UUID
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
request=request,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
models=mock_models_repo,
tasks=mock_tasks_repo,
))
assert result.task_id == TEST_TASK_UUID
assert result.status == TrainingStatus.PENDING
mock_db.create_training_task.assert_called_once()
mock_tasks_repo.create.assert_called_once()
def test_train_from_building_dataset_fails(self):
def test_train_from_building_dataset_fails(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="building")
mock_datasets_repo.get.return_value = _make_dataset(status="building")
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
request=request,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
models=mock_models_repo,
tasks=mock_tasks_repo,
))
assert exc_info.value.status_code == 400
def test_incremental_training_with_base_model(self):
def test_incremental_training_with_base_model(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
"""Test training with base_model_version_id for incremental training."""
fn = _find_endpoint("train_from_dataset")
@@ -281,22 +371,28 @@ class TestTrainFromDatasetRoute:
mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
mock_model_version.version = "1.0.0"
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.get_model_version.return_value = mock_model_version
mock_db.create_training_task.return_value = TEST_TASK_UUID
mock_datasets_repo.get.return_value = _make_dataset(status="ready")
mock_models_repo.get.return_value = mock_model_version
mock_tasks_repo.create.return_value = TEST_TASK_UUID
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
config = TrainingConfig(base_model_version_id=base_model_uuid)
request = DatasetTrainRequest(name="incremental-train", config=config)
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
request=request,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
models=mock_models_repo,
tasks=mock_tasks_repo,
))
# Verify model version was looked up
mock_db.get_model_version.assert_called_once_with(base_model_uuid)
mock_models_repo.get.assert_called_once_with(base_model_uuid)
# Verify task was created with finetune type
call_kwargs = mock_db.create_training_task.call_args[1]
call_kwargs = mock_tasks_repo.create.call_args[1]
assert call_kwargs["task_type"] == "finetune"
assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
assert call_kwargs["config"]["base_model_version"] == "1.0.0"
@@ -304,13 +400,14 @@ class TestTrainFromDatasetRoute:
assert result.task_id == TEST_TASK_UUID
assert "Incremental training" in result.message
def test_incremental_training_with_invalid_base_model_fails(self):
def test_incremental_training_with_invalid_base_model_fails(
self, mock_datasets_repo, mock_models_repo, mock_tasks_repo
):
"""Test that training fails if base_model_version_id doesn't exist."""
fn = _find_endpoint("train_from_dataset")
mock_db = MagicMock()
mock_db.get_dataset.return_value = _make_dataset(status="ready")
mock_db.get_model_version.return_value = None
mock_datasets_repo.get.return_value = _make_dataset(status="ready")
mock_models_repo.get.return_value = None
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
config = TrainingConfig(base_model_version_id=base_model_uuid)
@@ -319,6 +416,13 @@ class TestTrainFromDatasetRoute:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(
dataset_id=TEST_DATASET_UUID,
request=request,
admin_token=TEST_TOKEN,
datasets_repo=mock_datasets_repo,
models=mock_models_repo,
tasks=mock_tasks_repo,
))
assert exc_info.value.status_code == 404
assert "Base model version not found" in exc_info.value.detail

View File

@@ -3,7 +3,7 @@ Tests for dataset training status feature.
Tests cover:
1. Database model fields (training_status, active_training_task_id)
2. AdminDB update_dataset_training_status method
2. DatasetRepository update_training_status method
3. API response includes training status fields
4. Scheduler updates dataset status during training lifecycle
"""
@@ -56,12 +56,12 @@ class TestTrainingDatasetModel:
# =============================================================================
# Test AdminDB Methods
# Test DatasetRepository Methods
# =============================================================================
class TestAdminDBDatasetTrainingStatus:
"""Tests for AdminDB.update_dataset_training_status method."""
class TestDatasetRepositoryTrainingStatus:
"""Tests for DatasetRepository.update_training_status method."""
@pytest.fixture
def mock_session(self):
@@ -69,8 +69,8 @@ class TestAdminDBDatasetTrainingStatus:
session = MagicMock()
return session
def test_update_dataset_training_status_sets_status(self, mock_session):
"""update_dataset_training_status should set training_status."""
def test_update_training_status_sets_status(self, mock_session):
"""update_training_status should set training_status."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
@@ -81,13 +81,13 @@ class TestAdminDBDatasetTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
db.update_dataset_training_status(
repo = DatasetRepository()
repo.update_training_status(
dataset_id=str(dataset_id),
training_status="running",
)
@@ -96,8 +96,8 @@ class TestAdminDBDatasetTrainingStatus:
mock_session.add.assert_called_once_with(dataset)
mock_session.commit.assert_called_once()
def test_update_dataset_training_status_sets_task_id(self, mock_session):
"""update_dataset_training_status should set active_training_task_id."""
def test_update_training_status_sets_task_id(self, mock_session):
"""update_training_status should set active_training_task_id."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
@@ -109,13 +109,13 @@ class TestAdminDBDatasetTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
db.update_dataset_training_status(
repo = DatasetRepository()
repo.update_training_status(
dataset_id=str(dataset_id),
training_status="running",
active_training_task_id=str(task_id),
@@ -123,10 +123,10 @@ class TestAdminDBDatasetTrainingStatus:
assert dataset.active_training_task_id == task_id
def test_update_dataset_training_status_updates_main_status_on_complete(
def test_update_training_status_updates_main_status_on_complete(
self, mock_session
):
"""update_dataset_training_status should update main status to 'trained' when completed."""
"""update_training_status should update main status to 'trained' when completed."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
@@ -137,13 +137,13 @@ class TestAdminDBDatasetTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
db.update_dataset_training_status(
repo = DatasetRepository()
repo.update_training_status(
dataset_id=str(dataset_id),
training_status="completed",
update_main_status=True,
@@ -152,10 +152,10 @@ class TestAdminDBDatasetTrainingStatus:
assert dataset.status == "trained"
assert dataset.training_status == "completed"
def test_update_dataset_training_status_clears_task_id_on_complete(
def test_update_training_status_clears_task_id_on_complete(
self, mock_session
):
"""update_dataset_training_status should clear task_id when training completes."""
"""update_training_status should clear task_id when training completes."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
@@ -169,13 +169,13 @@ class TestAdminDBDatasetTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
db.update_dataset_training_status(
repo = DatasetRepository()
repo.update_training_status(
dataset_id=str(dataset_id),
training_status="completed",
active_training_task_id=None,
@@ -183,18 +183,18 @@ class TestAdminDBDatasetTrainingStatus:
assert dataset.active_training_task_id is None
def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
"""update_dataset_training_status should handle missing dataset gracefully."""
def test_update_training_status_handles_missing_dataset(self, mock_session):
"""update_training_status should handle missing dataset gracefully."""
mock_session.get.return_value = None
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
repo = DatasetRepository()
# Should not raise
db.update_dataset_training_status(
repo.update_training_status(
dataset_id=str(uuid4()),
training_status="running",
)
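Read as a specification, these tests describe update_training_status roughly as sketched below. The stand-in get_session_context context manager only keeps the sketch self-contained; the UUID conversion and the exact clearing semantics for active_training_task_id are assumptions.

from contextlib import contextmanager
from unittest.mock import MagicMock
from uuid import UUID

from inference.data.admin_models import TrainingDataset  # model named in the tests above


@contextmanager
def get_session_context():  # stand-in for the project's real session factory
    yield MagicMock()


class DatasetRepositorySketch:
    def update_training_status(
        self,
        dataset_id: str,
        training_status: str,
        active_training_task_id: str | None = None,
        update_main_status: bool = False,
    ) -> None:
        with get_session_context() as session:
            dataset = session.get(TrainingDataset, UUID(dataset_id))
            if dataset is None:
                return  # a missing dataset is tolerated, not an error
            dataset.training_status = training_status
            # simplified: always overwrite the task id with whatever was passed
            dataset.active_training_task_id = (
                UUID(active_training_task_id) if active_training_task_id else None
            )
            if update_main_status and training_status == "completed":
                dataset.status = "trained"
            session.add(dataset)
            session.commit()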
@@ -275,19 +275,24 @@ class TestSchedulerDatasetStatusUpdates:
"""Tests for scheduler updating dataset status during training."""
@pytest.fixture
def mock_db(self):
"""Create mock AdminDB."""
def mock_datasets_repo(self):
"""Create mock DatasetRepository."""
mock = MagicMock()
mock.get_dataset.return_value = MagicMock(
mock.get.return_value = MagicMock(
dataset_id=uuid4(),
name="test-dataset",
dataset_path="/path/to/dataset",
total_images=100,
)
mock.get_pending_training_tasks.return_value = []
return mock
def test_scheduler_sets_running_status_on_task_start(self, mock_db):
@pytest.fixture
def mock_training_tasks_repo(self):
"""Create mock TrainingTaskRepository."""
mock = MagicMock()
return mock
def test_scheduler_sets_running_status_on_task_start(self, mock_datasets_repo, mock_training_tasks_repo):
"""Scheduler should set dataset training_status to 'running' when task starts."""
from inference.web.core.scheduler import TrainingScheduler
@@ -295,7 +300,8 @@ class TestSchedulerDatasetStatusUpdates:
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
scheduler = TrainingScheduler()
scheduler._db = mock_db
scheduler._datasets = mock_datasets_repo
scheduler._training_tasks = mock_training_tasks_repo
task_id = str(uuid4())
dataset_id = str(uuid4())
@@ -311,8 +317,8 @@ class TestSchedulerDatasetStatusUpdates:
pass # Expected to fail in test environment
# Check that training status was updated to running
mock_db.update_dataset_training_status.assert_called()
first_call = mock_db.update_dataset_training_status.call_args_list[0]
mock_datasets_repo.update_training_status.assert_called()
first_call = mock_datasets_repo.update_training_status.call_args_list[0]
assert first_call.kwargs["training_status"] == "running"
assert first_call.kwargs["active_training_task_id"] == task_id
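The scheduler test only fixes the attribute names (_datasets, _training_tasks) and the first status transition; a hypothetical hook matching it might look like this, with the method name _on_task_start invented for illustration.

class TrainingSchedulerSketch:
    def __init__(self, datasets_repo, training_tasks_repo):
        # attribute names follow what the test assigns onto the real scheduler
        self._datasets = datasets_repo
        self._training_tasks = training_tasks_repo

    def _on_task_start(self, task_id: str, dataset_id: str) -> None:
        # first transition the test checks: dataset marked running, task id recorded
        self._datasets.update_training_status(
            dataset_id=dataset_id,
            training_status="running",
            active_training_task_id=task_id,
        )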

View File

@@ -45,10 +45,10 @@ class TestDocumentListFilterByCategory:
"""Tests for filtering documents by category."""
@pytest.fixture
def mock_admin_db(self):
"""Create mock AdminDB."""
db = MagicMock()
db.is_valid_admin_token.return_value = True
def mock_document_repo(self):
"""Create mock DocumentRepository."""
repo = MagicMock()
repo.is_valid.return_value = True
# Mock documents with different categories
invoice_doc = MagicMock()
@@ -61,11 +61,11 @@ class TestDocumentListFilterByCategory:
letter_doc.category = "letter"
letter_doc.filename = "letter1.pdf"
db.get_documents.return_value = ([invoice_doc], 1)
db.get_document_categories.return_value = ["invoice", "letter", "receipt"]
return db
repo.get_paginated.return_value = ([invoice_doc], 1)
repo.get_categories.return_value = ["invoice", "letter", "receipt"]
return repo
def test_list_documents_accepts_category_filter(self, mock_admin_db):
def test_list_documents_accepts_category_filter(self, mock_document_repo):
"""Test list documents endpoint accepts category query parameter."""
# The endpoint should accept ?category=invoice parameter
# This test verifies the schema/query parameter exists
@@ -74,9 +74,9 @@ class TestDocumentListFilterByCategory:
# Schema should work with category filter applied
assert DocumentListResponse is not None
def test_get_document_categories_from_db(self, mock_admin_db):
"""Test fetching unique categories from database."""
categories = mock_admin_db.get_document_categories()
def test_get_document_categories_from_repo(self, mock_document_repo):
"""Test fetching unique categories from repository."""
categories = mock_document_repo.get_categories()
assert "invoice" in categories
assert "letter" in categories
assert len(categories) == 3
@@ -122,24 +122,24 @@ class TestDocumentUploadWithCategory:
assert response.category == "invoice"
class TestAdminDBCategoryMethods:
"""Tests for AdminDB category-related methods."""
class TestDocumentRepositoryCategoryMethods:
"""Tests for DocumentRepository category-related methods."""
def test_get_document_categories_method_exists(self):
"""Test AdminDB has get_document_categories method."""
from inference.data.admin_db import AdminDB
def test_get_categories_method_exists(self):
"""Test DocumentRepository has get_categories method."""
from inference.data.repositories import DocumentRepository
db = AdminDB()
assert hasattr(db, "get_document_categories")
repo = DocumentRepository()
assert hasattr(repo, "get_categories")
def test_get_documents_accepts_category_filter(self):
"""Test get_documents_by_token method accepts category parameter."""
from inference.data.admin_db import AdminDB
def test_get_paginated_accepts_category_filter(self):
"""Test get_paginated method accepts category parameter."""
from inference.data.repositories import DocumentRepository
import inspect
db = AdminDB()
repo = DocumentRepository()
# Check the method exists and accepts category parameter
method = getattr(db, "get_documents_by_token", None)
method = getattr(repo, "get_paginated", None)
assert callable(method)
# Check category is in the method signature
@@ -150,12 +150,12 @@ class TestAdminDBCategoryMethods:
class TestUpdateDocumentCategory:
"""Tests for updating document category."""
def test_update_document_category_method_exists(self):
"""Test AdminDB has method to update document category."""
from inference.data.admin_db import AdminDB
def test_update_category_method_exists(self):
"""Test DocumentRepository has method to update document category."""
from inference.data.repositories import DocumentRepository
db = AdminDB()
assert hasattr(db, "update_document_category")
repo = DocumentRepository()
assert hasattr(repo, "update_category")
def test_update_request_schema(self):
"""Test DocumentUpdateRequest can update category."""

View File

@@ -63,6 +63,12 @@ def _find_endpoint(name: str):
raise AssertionError(f"Endpoint {name} not found")
@pytest.fixture
def mock_models_repo():
"""Mock ModelVersionRepository."""
return MagicMock()
class TestModelVersionRouterRegistration:
"""Tests that model version endpoints are registered."""
@@ -91,11 +97,10 @@ class TestModelVersionRouterRegistration:
class TestCreateModelVersionRoute:
"""Tests for POST /admin/training/models."""
def test_create_model_version(self):
def test_create_model_version(self, mock_models_repo):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
mock_models_repo.create.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
@@ -106,18 +111,17 @@ class TestCreateModelVersionRoute:
document_count=100,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
mock_db.create_model_version.assert_called_once()
mock_models_repo.create.assert_called_once()
assert result.version_id == TEST_VERSION_UUID
assert result.status == "inactive"
assert result.message == "Model version created successfully"
def test_create_model_version_with_task_and_dataset(self):
def test_create_model_version_with_task_and_dataset(self, mock_models_repo):
fn = _find_endpoint("create_model_version")
mock_db = MagicMock()
mock_db.create_model_version.return_value = _make_model_version()
mock_models_repo.create.return_value = _make_model_version()
request = ModelVersionCreateRequest(
version="1.0.0",
@@ -127,9 +131,9 @@ class TestCreateModelVersionRoute:
dataset_id=TEST_DATASET_UUID,
)
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
call_kwargs = mock_db.create_model_version.call_args[1]
call_kwargs = mock_models_repo.create.call_args[1]
assert call_kwargs["task_id"] == TEST_TASK_UUID
assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
@@ -137,30 +141,28 @@ class TestCreateModelVersionRoute:
class TestListModelVersionsRoute:
"""Tests for GET /admin/training/models."""
def test_list_model_versions(self):
def test_list_model_versions(self, mock_models_repo):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = (
mock_models_repo.get_paginated.return_value = (
[_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
2,
)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo, status=None, limit=20, offset=0))
assert result.total == 2
assert len(result.models) == 2
assert result.models[0].version == "1.0.0"
def test_list_model_versions_with_status_filter(self):
def test_list_model_versions_with_status_filter(self, mock_models_repo):
fn = _find_endpoint("list_model_versions")
mock_db = MagicMock()
mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)
mock_models_repo.get_paginated.return_value = ([_make_model_version(status="active", is_active=True)], 1)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo, status="active", limit=20, offset=0))
mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
mock_models_repo.get_paginated.assert_called_once_with(status="active", limit=20, offset=0)
assert result.total == 1
assert result.models[0].status == "active"
@@ -168,25 +170,23 @@ class TestListModelVersionsRoute:
class TestGetActiveModelRoute:
"""Tests for GET /admin/training/models/active."""
def test_get_active_model_when_exists(self):
def test_get_active_model_when_exists(self, mock_models_repo):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)
mock_models_repo.get_active.return_value = _make_model_version(status="active", is_active=True)
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.has_active_model is True
assert result.model is not None
assert result.model.is_active is True
def test_get_active_model_when_none(self):
def test_get_active_model_when_none(self, mock_models_repo):
fn = _find_endpoint("get_active_model")
mock_db = MagicMock()
mock_db.get_active_model_version.return_value = None
mock_models_repo.get_active.return_value = None
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.has_active_model is False
assert result.model is None
@@ -195,46 +195,43 @@ class TestGetActiveModelRoute:
class TestGetModelVersionRoute:
"""Tests for GET /admin/training/models/{version_id}."""
def test_get_model_version(self):
def test_get_model_version(self, mock_models_repo):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = _make_model_version()
mock_models_repo.get.return_value = _make_model_version()
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.version_id == TEST_VERSION_UUID
assert result.version == "1.0.0"
assert result.name == "test-model-v1"
assert result.metrics_mAP == 0.935
def test_get_model_version_not_found(self):
def test_get_model_version_not_found(self, mock_models_repo):
fn = _find_endpoint("get_model_version")
mock_db = MagicMock()
mock_db.get_model_version.return_value = None
mock_models_repo.get.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 404
class TestUpdateModelVersionRoute:
"""Tests for PATCH /admin/training/models/{version_id}."""
def test_update_model_version(self):
def test_update_model_version(self, mock_models_repo):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = _make_model_version(name="updated-name")
mock_models_repo.update.return_value = _make_model_version(name="updated-name")
request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
mock_db.update_model_version.assert_called_once_with(
mock_models_repo.update.assert_called_once_with(
version_id=TEST_VERSION_UUID,
name="updated-name",
description="Updated description",
@@ -242,45 +239,42 @@ class TestUpdateModelVersionRoute:
)
assert result.message == "Model version updated successfully"
def test_update_model_version_not_found(self):
def test_update_model_version_not_found(self, mock_models_repo):
fn = _find_endpoint("update_model_version")
mock_db = MagicMock()
mock_db.update_model_version.return_value = None
mock_models_repo.update.return_value = None
request = ModelVersionUpdateRequest(name="updated-name")
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 404
class TestActivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/activate."""
def test_activate_model_version(self):
def test_activate_model_version(self, mock_models_repo):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
mock_models_repo.activate.return_value = _make_model_version(status="active", is_active=True)
# Create mock request with app state
mock_request = MagicMock()
mock_request.app.state.inference_service = None
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, models=mock_models_repo))
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
mock_models_repo.activate.assert_called_once_with(TEST_VERSION_UUID)
assert result.status == "active"
assert result.message == "Model version activated for inference"
def test_activate_model_version_not_found(self):
def test_activate_model_version_not_found(self, mock_models_repo):
fn = _find_endpoint("activate_model_version")
mock_db = MagicMock()
mock_db.activate_model_version.return_value = None
mock_models_repo.activate.return_value = None
# Create mock request with app state
mock_request = MagicMock()
@@ -289,88 +283,82 @@ class TestActivateModelVersionRoute:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 404
class TestDeactivateModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/deactivate."""
def test_deactivate_model_version(self):
def test_deactivate_model_version(self, mock_models_repo):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)
mock_models_repo.deactivate.return_value = _make_model_version(status="inactive", is_active=False)
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.status == "inactive"
assert result.message == "Model version deactivated"
def test_deactivate_model_version_not_found(self):
def test_deactivate_model_version_not_found(self, mock_models_repo):
fn = _find_endpoint("deactivate_model_version")
mock_db = MagicMock()
mock_db.deactivate_model_version.return_value = None
mock_models_repo.deactivate.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 404
class TestArchiveModelVersionRoute:
"""Tests for POST /admin/training/models/{version_id}/archive."""
def test_archive_model_version(self):
def test_archive_model_version(self, mock_models_repo):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = _make_model_version(status="archived")
mock_models_repo.archive.return_value = _make_model_version(status="archived")
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.status == "archived"
assert result.message == "Model version archived"
def test_archive_active_model_fails(self):
def test_archive_active_model_fails(self, mock_models_repo):
fn = _find_endpoint("archive_model_version")
mock_db = MagicMock()
mock_db.archive_model_version.return_value = None
mock_models_repo.archive.return_value = None
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 400
class TestDeleteModelVersionRoute:
"""Tests for DELETE /admin/training/models/{version_id}."""
def test_delete_model_version(self):
def test_delete_model_version(self, mock_models_repo):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = True
mock_models_repo.delete.return_value = True
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
mock_models_repo.delete.assert_called_once_with(TEST_VERSION_UUID)
assert result["message"] == "Model version deleted"
def test_delete_active_model_fails(self):
def test_delete_active_model_fails(self, mock_models_repo):
fn = _find_endpoint("delete_model_version")
mock_db = MagicMock()
mock_db.delete_model_version.return_value = False
mock_models_repo.delete.return_value = False
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
assert exc_info.value.status_code == 400
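Every test above takes a mock_models_repo fixture whose definition lies outside the hunks shown here. A minimal sketch of such a fixture, assuming it is a plain MagicMock standing in for the model-version repository (the fixture name is the only thing the tests actually constrain):

from unittest.mock import MagicMock

import pytest


@pytest.fixture
def mock_models_repo() -> MagicMock:
    """Stand-in for the model version repository injected into the admin routes."""
    repo = MagicMock()
    # Each test sets only the return values it needs, e.g.
    # repo.activate.return_value = _make_model_version(status="active", is_active=True)
    return repo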

View File

@@ -10,7 +10,13 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.training import create_training_router
from inference.web.core.auth import validate_admin_token, get_admin_db
from inference.web.core.auth import (
validate_admin_token,
get_document_repository,
get_annotation_repository,
get_training_task_repository,
get_model_version_repository,
)
class MockTrainingTask:
@@ -128,19 +134,17 @@ class MockModelVersion:
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
class MockAdminDB:
"""Mock AdminDB for testing Phase 4."""
class MockDocumentRepository:
"""Mock DocumentRepository for testing Phase 4."""
def __init__(self):
self.documents = {}
self.annotations = {}
self.training_tasks = {}
self.training_links = {}
self.model_versions = {}
self.annotations = {} # Shared reference for filtering
self.training_links = {} # Shared reference for filtering
def get_documents_for_training(
def get_for_training(
self,
admin_token,
admin_token=None,
status="labeled",
has_annotations=True,
min_annotation_count=None,
@@ -173,17 +177,28 @@ class MockAdminDB:
total = len(filtered)
return filtered[offset:offset+limit], total
def get_annotations_for_document(self, document_id):
class MockAnnotationRepository:
"""Mock AnnotationRepository for testing Phase 4."""
def __init__(self):
self.annotations = {}
def get_for_document(self, document_id, page_number=None):
"""Get annotations for document."""
return self.annotations.get(str(document_id), [])
def get_document_training_tasks(self, document_id):
"""Get training tasks that used this document."""
return self.training_links.get(str(document_id), [])
def get_training_tasks_by_token(
class MockTrainingTaskRepository:
"""Mock TrainingTaskRepository for testing Phase 4."""
def __init__(self):
self.training_tasks = {}
self.training_links = {}
def get_paginated(
self,
admin_token,
admin_token=None,
status=None,
limit=20,
offset=0,
@@ -196,11 +211,22 @@ class MockAdminDB:
total = len(tasks)
return tasks[offset:offset+limit], total
def get_training_task(self, task_id):
def get(self, task_id):
"""Get training task by ID."""
return self.training_tasks.get(str(task_id))
def get_model_versions(self, status=None, limit=20, offset=0):
def get_document_training_tasks(self, document_id):
"""Get training tasks that used this document."""
return self.training_links.get(str(document_id), [])
class MockModelVersionRepository:
"""Mock ModelVersionRepository for testing Phase 4."""
def __init__(self):
self.model_versions = {}
def get_paginated(self, status=None, limit=20, offset=0):
"""Get model versions with optional filtering."""
models = list(self.model_versions.values())
if status:
@@ -214,8 +240,11 @@ def app():
"""Create test FastAPI app."""
app = FastAPI()
# Create mock DB
mock_db = MockAdminDB()
# Create mock repositories
mock_document_repo = MockDocumentRepository()
mock_annotation_repo = MockAnnotationRepository()
mock_training_task_repo = MockTrainingTaskRepository()
mock_model_version_repo = MockModelVersionRepository()
# Add test documents
doc1 = MockAdminDocument(
@@ -231,22 +260,25 @@ def app():
status="labeled",
)
mock_db.documents[str(doc1.document_id)] = doc1
mock_db.documents[str(doc2.document_id)] = doc2
mock_db.documents[str(doc3.document_id)] = doc3
mock_document_repo.documents[str(doc1.document_id)] = doc1
mock_document_repo.documents[str(doc2.document_id)] = doc2
mock_document_repo.documents[str(doc3.document_id)] = doc3
# Add annotations
mock_db.annotations[str(doc1.document_id)] = [
mock_annotation_repo.annotations[str(doc1.document_id)] = [
MockAnnotation(document_id=doc1.document_id, source="manual"),
MockAnnotation(document_id=doc1.document_id, source="auto"),
]
mock_db.annotations[str(doc2.document_id)] = [
mock_annotation_repo.annotations[str(doc2.document_id)] = [
MockAnnotation(document_id=doc2.document_id, source="auto"),
MockAnnotation(document_id=doc2.document_id, source="auto"),
MockAnnotation(document_id=doc2.document_id, source="auto"),
]
# doc3 has no annotations
# Share annotation data with document repo for filtering
mock_document_repo.annotations = mock_annotation_repo.annotations
# Add training tasks
task1 = MockTrainingTask(
name="Training Run 2024-01",
@@ -265,15 +297,18 @@ def app():
metrics_recall=0.92,
)
mock_db.training_tasks[str(task1.task_id)] = task1
mock_db.training_tasks[str(task2.task_id)] = task2
mock_training_task_repo.training_tasks[str(task1.task_id)] = task1
mock_training_task_repo.training_tasks[str(task2.task_id)] = task2
# Add training links (doc1 used in task1)
link1 = MockTrainingDocumentLink(
task_id=task1.task_id,
document_id=doc1.document_id,
)
mock_db.training_links[str(doc1.document_id)] = [link1]
mock_training_task_repo.training_links[str(doc1.document_id)] = [link1]
# Share training links with document repo for filtering
mock_document_repo.training_links = mock_training_task_repo.training_links
# Add model versions
model1 = MockModelVersion(
@@ -296,12 +331,15 @@ def app():
metrics_recall=0.92,
document_count=600,
)
mock_db.model_versions[str(model1.version_id)] = model1
mock_db.model_versions[str(model2.version_id)] = model2
mock_model_version_repo.model_versions[str(model1.version_id)] = model1
mock_model_version_repo.model_versions[str(model2.version_id)] = model2
# Override dependencies
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
app.dependency_overrides[get_admin_db] = lambda: mock_db
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo
app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo
app.dependency_overrides[get_model_version_repository] = lambda: mock_model_version_repo
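These overrides take effect because FastAPI substitutes whatever callable a route declares with Depends(...) whenever that same callable is registered in app.dependency_overrides. A hypothetical handler wired against get_model_version_repository, sketched only to illustrate that mechanism; the real router's signatures and response models may differ:

from fastapi import APIRouter, Depends, HTTPException

from inference.web.core.auth import validate_admin_token, get_model_version_repository

example_router = APIRouter()  # illustrative only; the app uses create_training_router()


@example_router.post("/models/{version_id}/activate")
async def activate_model_version(
    version_id: str,
    admin_token: str = Depends(validate_admin_token),
    models=Depends(get_model_version_repository),
):
    # Under the test app, `models` resolves to MockModelVersionRepository
    # via app.dependency_overrides[get_model_version_repository].
    version = models.activate(version_id)
    if version is None:
        raise HTTPException(status_code=404, detail="Model version not found")
    return version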
# Include router
router = create_training_router()