WIP
This commit is contained in:
1
tests/web/core/__init__.py
Normal file
1
tests/web/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for web core components."""
|
||||
672
tests/web/core/test_task_interface.py
Normal file
672
tests/web/core/test_task_interface.py
Normal file
@@ -0,0 +1,672 @@
|
||||
"""Tests for unified task management interface.
|
||||
|
||||
TDD: These tests are written first (RED phase).
|
||||
"""
|
||||
|
||||
from abc import ABC
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestTaskStatus:
|
||||
"""Tests for TaskStatus dataclass."""
|
||||
|
||||
def test_task_status_basic_fields(self) -> None:
|
||||
"""TaskStatus has all required fields."""
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
|
||||
status = TaskStatus(
|
||||
name="test_runner",
|
||||
is_running=True,
|
||||
pending_count=5,
|
||||
processing_count=2,
|
||||
)
|
||||
assert status.name == "test_runner"
|
||||
assert status.is_running is True
|
||||
assert status.pending_count == 5
|
||||
assert status.processing_count == 2
|
||||
|
||||
def test_task_status_with_error(self) -> None:
|
||||
"""TaskStatus can include optional error message."""
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
|
||||
status = TaskStatus(
|
||||
name="failed_runner",
|
||||
is_running=False,
|
||||
pending_count=0,
|
||||
processing_count=0,
|
||||
error="Connection failed",
|
||||
)
|
||||
assert status.error == "Connection failed"
|
||||
|
||||
def test_task_status_default_error_is_none(self) -> None:
|
||||
"""TaskStatus error defaults to None."""
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
|
||||
status = TaskStatus(
|
||||
name="test",
|
||||
is_running=True,
|
||||
pending_count=0,
|
||||
processing_count=0,
|
||||
)
|
||||
assert status.error is None
|
||||
|
||||
def test_task_status_is_frozen(self) -> None:
|
||||
"""TaskStatus is immutable (frozen dataclass)."""
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
|
||||
status = TaskStatus(
|
||||
name="test",
|
||||
is_running=True,
|
||||
pending_count=0,
|
||||
processing_count=0,
|
||||
)
|
||||
with pytest.raises(AttributeError):
|
||||
status.name = "changed" # type: ignore[misc]
|
||||
|
||||
|
||||
class TestTaskRunnerInterface:
|
||||
"""Tests for TaskRunner abstract base class."""
|
||||
|
||||
def test_cannot_instantiate_directly(self) -> None:
|
||||
"""TaskRunner is abstract and cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
TaskRunner() # type: ignore[abstract]
|
||||
|
||||
def test_is_abstract_base_class(self) -> None:
|
||||
"""TaskRunner inherits from ABC."""
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
|
||||
assert issubclass(TaskRunner, ABC)
|
||||
|
||||
def test_subclass_missing_name_cannot_instantiate(self) -> None:
|
||||
"""Subclass without name property cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class MissingName(TaskRunner):
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus("", False, 0, 0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
MissingName() # type: ignore[abstract]
|
||||
|
||||
def test_subclass_missing_start_cannot_instantiate(self) -> None:
|
||||
"""Subclass without start method cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class MissingStart(TaskRunner):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "test"
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus("", False, 0, 0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
MissingStart() # type: ignore[abstract]
|
||||
|
||||
def test_subclass_missing_stop_cannot_instantiate(self) -> None:
|
||||
"""Subclass without stop method cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class MissingStop(TaskRunner):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "test"
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus("", False, 0, 0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
MissingStop() # type: ignore[abstract]
|
||||
|
||||
def test_subclass_missing_is_running_cannot_instantiate(self) -> None:
|
||||
"""Subclass without is_running property cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class MissingIsRunning(TaskRunner):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "test"
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
pass
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus("", False, 0, 0)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
MissingIsRunning() # type: ignore[abstract]
|
||||
|
||||
def test_subclass_missing_get_status_cannot_instantiate(self) -> None:
|
||||
"""Subclass without get_status method cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
|
||||
class MissingGetStatus(TaskRunner):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "test"
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return False
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
MissingGetStatus() # type: ignore[abstract]
|
||||
|
||||
def test_complete_subclass_can_instantiate(self) -> None:
|
||||
"""Complete subclass implementing all methods can be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class CompleteRunner(TaskRunner):
|
||||
def __init__(self) -> None:
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "complete_runner"
|
||||
|
||||
def start(self) -> None:
|
||||
self._running = True
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._running
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus(
|
||||
name=self.name,
|
||||
is_running=self._running,
|
||||
pending_count=0,
|
||||
processing_count=0,
|
||||
)
|
||||
|
||||
runner = CompleteRunner()
|
||||
assert runner.name == "complete_runner"
|
||||
assert runner.is_running is False
|
||||
|
||||
runner.start()
|
||||
assert runner.is_running is True
|
||||
|
||||
status = runner.get_status()
|
||||
assert status.name == "complete_runner"
|
||||
assert status.is_running is True
|
||||
|
||||
runner.stop()
|
||||
assert runner.is_running is False
|
||||
|
||||
|
||||
class TestTaskManager:
|
||||
"""Tests for TaskManager facade."""
|
||||
|
||||
def test_register_runner(self) -> None:
|
||||
"""Can register a task runner."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "mock"
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus("mock", False, 0, 0)
|
||||
|
||||
manager = TaskManager()
|
||||
runner = MockRunner()
|
||||
manager.register(runner)
|
||||
|
||||
assert manager.get_runner("mock") is runner
|
||||
|
||||
def test_get_runner_returns_none_for_unknown(self) -> None:
|
||||
"""get_runner returns None for unknown runner name."""
|
||||
from inference.web.core.task_interface import TaskManager
|
||||
|
||||
manager = TaskManager()
|
||||
assert manager.get_runner("unknown") is None
|
||||
|
||||
def test_start_all_runners(self) -> None:
|
||||
"""start_all starts all registered runners."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
def __init__(self, runner_name: str) -> None:
|
||||
self._name = runner_name
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def start(self) -> None:
|
||||
self._running = True
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._running
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus(self._name, self._running, 0, 0)
|
||||
|
||||
manager = TaskManager()
|
||||
runner1 = MockRunner("runner1")
|
||||
runner2 = MockRunner("runner2")
|
||||
manager.register(runner1)
|
||||
manager.register(runner2)
|
||||
|
||||
assert runner1.is_running is False
|
||||
assert runner2.is_running is False
|
||||
|
||||
manager.start_all()
|
||||
|
||||
assert runner1.is_running is True
|
||||
assert runner2.is_running is True
|
||||
|
||||
def test_stop_all_runners(self) -> None:
|
||||
"""stop_all stops all registered runners."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
def __init__(self, runner_name: str) -> None:
|
||||
self._name = runner_name
|
||||
self._running = True
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def start(self) -> None:
|
||||
self._running = True
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
self._running = False
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return self._running
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus(self._name, self._running, 0, 0)
|
||||
|
||||
manager = TaskManager()
|
||||
runner1 = MockRunner("runner1")
|
||||
runner2 = MockRunner("runner2")
|
||||
manager.register(runner1)
|
||||
manager.register(runner2)
|
||||
|
||||
assert runner1.is_running is True
|
||||
assert runner2.is_running is True
|
||||
|
||||
manager.stop_all()
|
||||
|
||||
assert runner1.is_running is False
|
||||
assert runner2.is_running is False
|
||||
|
||||
def test_get_all_status(self) -> None:
|
||||
"""get_all_status returns status of all runners."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
def __init__(self, runner_name: str, pending: int) -> None:
|
||||
self._name = runner_name
|
||||
self._pending = pending
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus(self._name, True, self._pending, 0)
|
||||
|
||||
manager = TaskManager()
|
||||
manager.register(MockRunner("runner1", 5))
|
||||
manager.register(MockRunner("runner2", 10))
|
||||
|
||||
all_status = manager.get_all_status()
|
||||
|
||||
assert len(all_status) == 2
|
||||
assert all_status["runner1"].pending_count == 5
|
||||
assert all_status["runner2"].pending_count == 10
|
||||
|
||||
def test_get_all_status_empty_when_no_runners(self) -> None:
|
||||
"""get_all_status returns empty dict when no runners registered."""
|
||||
from inference.web.core.task_interface import TaskManager
|
||||
|
||||
manager = TaskManager()
|
||||
assert manager.get_all_status() == {}
|
||||
|
||||
def test_runner_names_property(self) -> None:
|
||||
"""runner_names returns list of all registered runner names."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
def __init__(self, runner_name: str) -> None:
|
||||
self._name = runner_name
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus(self._name, False, 0, 0)
|
||||
|
||||
manager = TaskManager()
|
||||
manager.register(MockRunner("alpha"))
|
||||
manager.register(MockRunner("beta"))
|
||||
|
||||
names = manager.runner_names
|
||||
assert set(names) == {"alpha", "beta"}
|
||||
|
||||
def test_stop_all_with_timeout_distribution(self) -> None:
|
||||
"""stop_all distributes timeout across runners."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
received_timeouts: list[float | None] = []
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
def __init__(self, runner_name: str) -> None:
|
||||
self._name = runner_name
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
def start(self) -> None:
|
||||
pass
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
received_timeouts.append(timeout)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus(self._name, False, 0, 0)
|
||||
|
||||
manager = TaskManager()
|
||||
manager.register(MockRunner("r1"))
|
||||
manager.register(MockRunner("r2"))
|
||||
|
||||
manager.stop_all(timeout=20.0)
|
||||
|
||||
# Timeout should be distributed (20 / 2 = 10 each)
|
||||
assert len(received_timeouts) == 2
|
||||
assert all(t == 10.0 for t in received_timeouts)
|
||||
|
||||
def test_start_all_skips_runners_requiring_arguments(self) -> None:
|
||||
"""start_all skips runners that require arguments."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
no_args_started = []
|
||||
with_args_started = []
|
||||
|
||||
class NoArgsRunner(TaskRunner):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "no_args"
|
||||
|
||||
def start(self) -> None:
|
||||
no_args_started.append(True)
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus("no_args", False, 0, 0)
|
||||
|
||||
class RequiresArgsRunner(TaskRunner):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "requires_args"
|
||||
|
||||
def start(self, handler: object) -> None: # type: ignore[override]
|
||||
# This runner requires an argument
|
||||
with_args_started.append(True)
|
||||
|
||||
def stop(self, timeout: float | None = None) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_status(self) -> TaskStatus:
|
||||
return TaskStatus("requires_args", False, 0, 0)
|
||||
|
||||
manager = TaskManager()
|
||||
manager.register(NoArgsRunner())
|
||||
manager.register(RequiresArgsRunner())
|
||||
|
||||
# start_all should start no_args runner but skip requires_args
|
||||
manager.start_all()
|
||||
|
||||
assert len(no_args_started) == 1
|
||||
assert len(with_args_started) == 0 # Skipped due to TypeError
|
||||
|
||||
def test_stop_all_with_no_runners(self) -> None:
|
||||
"""stop_all does nothing when no runners registered."""
|
||||
from inference.web.core.task_interface import TaskManager
|
||||
|
||||
manager = TaskManager()
|
||||
# Should not raise any exception
|
||||
manager.stop_all()
|
||||
# Just verify it returns without error
|
||||
assert manager.runner_names == []
|
||||
|
||||
|
||||
class TestTrainingSchedulerInterface:
|
||||
"""Tests for TrainingScheduler implementing TaskRunner."""
|
||||
|
||||
def test_training_scheduler_is_task_runner(self) -> None:
|
||||
"""TrainingScheduler inherits from TaskRunner."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
|
||||
scheduler = TrainingScheduler()
|
||||
assert isinstance(scheduler, TaskRunner)
|
||||
|
||||
def test_training_scheduler_name(self) -> None:
|
||||
"""TrainingScheduler has correct name."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
|
||||
scheduler = TrainingScheduler()
|
||||
assert scheduler.name == "training_scheduler"
|
||||
|
||||
def test_training_scheduler_get_status(self) -> None:
|
||||
"""TrainingScheduler provides status via get_status."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
|
||||
scheduler = TrainingScheduler()
|
||||
# Mock the training tasks repository
|
||||
mock_tasks = MagicMock()
|
||||
mock_tasks.get_pending.return_value = [MagicMock(), MagicMock()]
|
||||
scheduler._training_tasks = mock_tasks
|
||||
|
||||
status = scheduler.get_status()
|
||||
|
||||
assert isinstance(status, TaskStatus)
|
||||
assert status.name == "training_scheduler"
|
||||
assert status.is_running is False
|
||||
assert status.pending_count == 2
|
||||
|
||||
|
||||
class TestAutoLabelSchedulerInterface:
|
||||
"""Tests for AutoLabelScheduler implementing TaskRunner."""
|
||||
|
||||
def test_autolabel_scheduler_is_task_runner(self) -> None:
|
||||
"""AutoLabelScheduler inherits from TaskRunner."""
|
||||
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
|
||||
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
|
||||
scheduler = AutoLabelScheduler()
|
||||
assert isinstance(scheduler, TaskRunner)
|
||||
|
||||
def test_autolabel_scheduler_name(self) -> None:
|
||||
"""AutoLabelScheduler has correct name."""
|
||||
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
|
||||
|
||||
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
|
||||
scheduler = AutoLabelScheduler()
|
||||
assert scheduler.name == "autolabel_scheduler"
|
||||
|
||||
def test_autolabel_scheduler_get_status(self) -> None:
|
||||
"""AutoLabelScheduler provides status via get_status."""
|
||||
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
|
||||
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
|
||||
with patch(
|
||||
"inference.web.core.autolabel_scheduler.get_pending_autolabel_documents"
|
||||
) as mock_get:
|
||||
mock_get.return_value = [MagicMock(), MagicMock(), MagicMock()]
|
||||
|
||||
scheduler = AutoLabelScheduler()
|
||||
status = scheduler.get_status()
|
||||
|
||||
assert isinstance(status, TaskStatus)
|
||||
assert status.name == "autolabel_scheduler"
|
||||
assert status.is_running is False
|
||||
assert status.pending_count == 3
|
||||
|
||||
|
||||
class TestAsyncTaskQueueInterface:
|
||||
"""Tests for AsyncTaskQueue implementing TaskRunner."""
|
||||
|
||||
def test_async_queue_is_task_runner(self) -> None:
|
||||
"""AsyncTaskQueue inherits from TaskRunner."""
|
||||
from inference.web.workers.async_queue import AsyncTaskQueue
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
|
||||
queue = AsyncTaskQueue()
|
||||
assert isinstance(queue, TaskRunner)
|
||||
|
||||
def test_async_queue_name(self) -> None:
|
||||
"""AsyncTaskQueue has correct name."""
|
||||
from inference.web.workers.async_queue import AsyncTaskQueue
|
||||
|
||||
queue = AsyncTaskQueue()
|
||||
assert queue.name == "async_task_queue"
|
||||
|
||||
def test_async_queue_get_status(self) -> None:
|
||||
"""AsyncTaskQueue provides status via get_status."""
|
||||
from inference.web.workers.async_queue import AsyncTaskQueue
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
|
||||
queue = AsyncTaskQueue()
|
||||
status = queue.get_status()
|
||||
|
||||
assert isinstance(status, TaskStatus)
|
||||
assert status.name == "async_task_queue"
|
||||
assert status.is_running is False
|
||||
assert status.pending_count == 0
|
||||
assert status.processing_count == 0
|
||||
|
||||
|
||||
class TestBatchTaskQueueInterface:
|
||||
"""Tests for BatchTaskQueue implementing TaskRunner."""
|
||||
|
||||
def test_batch_queue_is_task_runner(self) -> None:
|
||||
"""BatchTaskQueue inherits from TaskRunner."""
|
||||
from inference.web.workers.batch_queue import BatchTaskQueue
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
|
||||
queue = BatchTaskQueue()
|
||||
assert isinstance(queue, TaskRunner)
|
||||
|
||||
def test_batch_queue_name(self) -> None:
|
||||
"""BatchTaskQueue has correct name."""
|
||||
from inference.web.workers.batch_queue import BatchTaskQueue
|
||||
|
||||
queue = BatchTaskQueue()
|
||||
assert queue.name == "batch_task_queue"
|
||||
|
||||
def test_batch_queue_get_status(self) -> None:
|
||||
"""BatchTaskQueue provides status via get_status."""
|
||||
from inference.web.workers.batch_queue import BatchTaskQueue
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
|
||||
queue = BatchTaskQueue()
|
||||
status = queue.get_status()
|
||||
|
||||
assert isinstance(status, TaskStatus)
|
||||
assert status.name == "batch_task_queue"
|
||||
assert status.is_running is False
|
||||
assert status.pending_count == 0
|
||||
@@ -8,80 +8,80 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import TokenRepository
|
||||
from inference.data.admin_models import AdminToken
|
||||
from inference.web.core.auth import (
|
||||
get_admin_db,
|
||||
reset_admin_db,
|
||||
get_token_repository,
|
||||
reset_token_repository,
|
||||
validate_admin_token,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db():
|
||||
"""Create a mock AdminDB."""
|
||||
db = MagicMock(spec=AdminDB)
|
||||
db.is_valid_admin_token.return_value = True
|
||||
return db
|
||||
def mock_token_repo():
|
||||
"""Create a mock TokenRepository."""
|
||||
repo = MagicMock(spec=TokenRepository)
|
||||
repo.is_valid.return_value = True
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_db():
|
||||
"""Reset admin DB after each test."""
|
||||
def reset_repo():
|
||||
"""Reset token repository after each test."""
|
||||
yield
|
||||
reset_admin_db()
|
||||
reset_token_repository()
|
||||
|
||||
|
||||
class TestValidateAdminToken:
|
||||
"""Tests for validate_admin_token dependency."""
|
||||
|
||||
def test_missing_token_raises_401(self, mock_admin_db):
|
||||
def test_missing_token_raises_401(self, mock_token_repo):
|
||||
"""Test that missing token raises 401."""
|
||||
import asyncio
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
validate_admin_token(None, mock_admin_db)
|
||||
validate_admin_token(None, mock_token_repo)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Admin token required" in exc_info.value.detail
|
||||
|
||||
def test_invalid_token_raises_401(self, mock_admin_db):
|
||||
def test_invalid_token_raises_401(self, mock_token_repo):
|
||||
"""Test that invalid token raises 401."""
|
||||
import asyncio
|
||||
|
||||
mock_admin_db.is_valid_admin_token.return_value = False
|
||||
mock_token_repo.is_valid.return_value = False
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.get_event_loop().run_until_complete(
|
||||
validate_admin_token("invalid-token", mock_admin_db)
|
||||
validate_admin_token("invalid-token", mock_token_repo)
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "Invalid or expired" in exc_info.value.detail
|
||||
|
||||
def test_valid_token_returns_token(self, mock_admin_db):
|
||||
def test_valid_token_returns_token(self, mock_token_repo):
|
||||
"""Test that valid token is returned."""
|
||||
import asyncio
|
||||
|
||||
token = "valid-test-token"
|
||||
mock_admin_db.is_valid_admin_token.return_value = True
|
||||
mock_token_repo.is_valid.return_value = True
|
||||
|
||||
result = asyncio.get_event_loop().run_until_complete(
|
||||
validate_admin_token(token, mock_admin_db)
|
||||
validate_admin_token(token, mock_token_repo)
|
||||
)
|
||||
|
||||
assert result == token
|
||||
mock_admin_db.update_admin_token_usage.assert_called_once_with(token)
|
||||
mock_token_repo.update_usage.assert_called_once_with(token)
|
||||
|
||||
|
||||
class TestAdminDB:
|
||||
"""Tests for AdminDB operations."""
|
||||
class TestTokenRepository:
|
||||
"""Tests for TokenRepository operations."""
|
||||
|
||||
def test_is_valid_admin_token_active(self):
|
||||
def test_is_valid_active_token(self):
|
||||
"""Test valid active token."""
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -93,12 +93,12 @@ class TestAdminDB:
|
||||
)
|
||||
mock_session.get.return_value = mock_token
|
||||
|
||||
db = AdminDB()
|
||||
assert db.is_valid_admin_token("test-token") is True
|
||||
repo = TokenRepository()
|
||||
assert repo.is_valid("test-token") is True
|
||||
|
||||
def test_is_valid_admin_token_inactive(self):
|
||||
def test_is_valid_inactive_token(self):
|
||||
"""Test inactive token."""
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -110,12 +110,12 @@ class TestAdminDB:
|
||||
)
|
||||
mock_session.get.return_value = mock_token
|
||||
|
||||
db = AdminDB()
|
||||
assert db.is_valid_admin_token("test-token") is False
|
||||
repo = TokenRepository()
|
||||
assert repo.is_valid("test-token") is False
|
||||
|
||||
def test_is_valid_admin_token_expired(self):
|
||||
def test_is_valid_expired_token(self):
|
||||
"""Test expired token."""
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -127,36 +127,38 @@ class TestAdminDB:
|
||||
)
|
||||
mock_session.get.return_value = mock_token
|
||||
|
||||
db = AdminDB()
|
||||
assert db.is_valid_admin_token("test-token") is False
|
||||
repo = TokenRepository()
|
||||
# Need to also mock _now() to ensure proper comparison
|
||||
with patch.object(repo, "_now", return_value=datetime.utcnow()):
|
||||
assert repo.is_valid("test-token") is False
|
||||
|
||||
def test_is_valid_admin_token_not_found(self):
|
||||
def test_is_valid_token_not_found(self):
|
||||
"""Test token not found."""
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
mock_session.get.return_value = None
|
||||
|
||||
db = AdminDB()
|
||||
assert db.is_valid_admin_token("nonexistent") is False
|
||||
repo = TokenRepository()
|
||||
assert repo.is_valid("nonexistent") is False
|
||||
|
||||
|
||||
class TestGetAdminDb:
|
||||
"""Tests for get_admin_db function."""
|
||||
class TestGetTokenRepository:
|
||||
"""Tests for get_token_repository function."""
|
||||
|
||||
def test_returns_singleton(self):
|
||||
"""Test that get_admin_db returns singleton."""
|
||||
reset_admin_db()
|
||||
"""Test that get_token_repository returns singleton."""
|
||||
reset_token_repository()
|
||||
|
||||
db1 = get_admin_db()
|
||||
db2 = get_admin_db()
|
||||
repo1 = get_token_repository()
|
||||
repo2 = get_token_repository()
|
||||
|
||||
assert db1 is db2
|
||||
assert repo1 is repo2
|
||||
|
||||
def test_reset_clears_singleton(self):
|
||||
"""Test that reset clears singleton."""
|
||||
db1 = get_admin_db()
|
||||
reset_admin_db()
|
||||
db2 = get_admin_db()
|
||||
repo1 = get_token_repository()
|
||||
reset_token_repository()
|
||||
repo2 = get_token_repository()
|
||||
|
||||
assert db1 is not db2
|
||||
assert repo1 is not repo2
|
||||
|
||||
@@ -11,7 +11,12 @@ from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.documents import create_documents_router
|
||||
from inference.web.config import StorageConfig
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
get_annotation_repository,
|
||||
get_training_task_repository,
|
||||
)
|
||||
|
||||
|
||||
class MockAdminDocument:
|
||||
@@ -59,14 +64,14 @@ class MockAnnotation:
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing enhanced features."""
|
||||
class MockDocumentRepository:
|
||||
"""Mock DocumentRepository for testing enhanced features."""
|
||||
|
||||
def __init__(self):
|
||||
self.documents = {}
|
||||
self.annotations = {}
|
||||
self.annotations = {} # Shared reference for filtering
|
||||
|
||||
def get_documents_by_token(
|
||||
def get_paginated(
|
||||
self,
|
||||
admin_token=None,
|
||||
status=None,
|
||||
@@ -103,32 +108,51 @@ class MockAdminDB:
|
||||
total = len(docs)
|
||||
return docs[offset:offset+limit], total
|
||||
|
||||
def get_annotations_for_document(self, document_id):
|
||||
"""Get annotations for document."""
|
||||
return self.annotations.get(str(document_id), [])
|
||||
|
||||
def count_documents_by_status(self, admin_token):
|
||||
def count_by_status(self, admin_token=None):
|
||||
"""Count documents by status."""
|
||||
counts = {}
|
||||
for doc in self.documents.values():
|
||||
if doc.admin_token == admin_token:
|
||||
if admin_token is None or doc.admin_token == admin_token:
|
||||
counts[doc.status] = counts.get(doc.status, 0) + 1
|
||||
return counts
|
||||
|
||||
def get_document_by_token(self, document_id, admin_token):
|
||||
def get(self, document_id):
|
||||
"""Get single document by ID."""
|
||||
return self.documents.get(document_id)
|
||||
|
||||
def get_by_token(self, document_id, admin_token=None):
|
||||
"""Get single document by ID and token."""
|
||||
doc = self.documents.get(document_id)
|
||||
if doc and doc.admin_token == admin_token:
|
||||
if doc and (admin_token is None or doc.admin_token == admin_token):
|
||||
return doc
|
||||
return None
|
||||
|
||||
|
||||
class MockAnnotationRepository:
|
||||
"""Mock AnnotationRepository for testing enhanced features."""
|
||||
|
||||
def __init__(self):
|
||||
self.annotations = {}
|
||||
|
||||
def get_for_document(self, document_id, page_number=None):
|
||||
"""Get annotations for document."""
|
||||
return self.annotations.get(str(document_id), [])
|
||||
|
||||
|
||||
class MockTrainingTaskRepository:
|
||||
"""Mock TrainingTaskRepository for testing enhanced features."""
|
||||
|
||||
def __init__(self):
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
|
||||
def get_document_training_tasks(self, document_id):
|
||||
"""Get training tasks that used this document."""
|
||||
return [] # No training history in this test
|
||||
return self.training_links.get(str(document_id), [])
|
||||
|
||||
def get_training_task(self, task_id):
|
||||
def get(self, task_id):
|
||||
"""Get training task by ID."""
|
||||
return None # No training tasks in this test
|
||||
return self.training_tasks.get(str(task_id))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -136,8 +160,10 @@ def app():
|
||||
"""Create test FastAPI app."""
|
||||
app = FastAPI()
|
||||
|
||||
# Create mock DB
|
||||
mock_db = MockAdminDB()
|
||||
# Create mock repositories
|
||||
mock_document_repo = MockDocumentRepository()
|
||||
mock_annotation_repo = MockAnnotationRepository()
|
||||
mock_training_task_repo = MockTrainingTaskRepository()
|
||||
|
||||
# Add test documents
|
||||
doc1 = MockAdminDocument(
|
||||
@@ -162,19 +188,19 @@ def app():
|
||||
batch_id=None
|
||||
)
|
||||
|
||||
mock_db.documents[str(doc1.document_id)] = doc1
|
||||
mock_db.documents[str(doc2.document_id)] = doc2
|
||||
mock_db.documents[str(doc3.document_id)] = doc3
|
||||
mock_document_repo.documents[str(doc1.document_id)] = doc1
|
||||
mock_document_repo.documents[str(doc2.document_id)] = doc2
|
||||
mock_document_repo.documents[str(doc3.document_id)] = doc3
|
||||
|
||||
# Add annotations to doc1 and doc2
|
||||
mock_db.annotations[str(doc1.document_id)] = [
|
||||
mock_annotation_repo.annotations[str(doc1.document_id)] = [
|
||||
MockAnnotation(
|
||||
document_id=doc1.document_id,
|
||||
class_name="invoice_number",
|
||||
text_value="INV-001"
|
||||
)
|
||||
]
|
||||
mock_db.annotations[str(doc2.document_id)] = [
|
||||
mock_annotation_repo.annotations[str(doc2.document_id)] = [
|
||||
MockAnnotation(
|
||||
document_id=doc2.document_id,
|
||||
class_id=6,
|
||||
@@ -189,9 +215,14 @@ def app():
|
||||
)
|
||||
]
|
||||
|
||||
# Share annotation data with document repo for filtering
|
||||
mock_document_repo.annotations = mock_annotation_repo.annotations
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
|
||||
app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo
|
||||
app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo
|
||||
|
||||
# Include router
|
||||
router = create_documents_router(StorageConfig())
|
||||
|
||||
@@ -10,7 +10,10 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.locks import create_locks_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
)
|
||||
|
||||
|
||||
class MockAdminDocument:
|
||||
@@ -34,23 +37,27 @@ class MockAdminDocument:
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing annotation locks."""
|
||||
class MockDocumentRepository:
|
||||
"""Mock DocumentRepository for testing annotation locks."""
|
||||
|
||||
def __init__(self):
|
||||
self.documents = {}
|
||||
|
||||
def get_document_by_token(self, document_id, admin_token):
|
||||
def get(self, document_id):
|
||||
"""Get single document by ID."""
|
||||
return self.documents.get(document_id)
|
||||
|
||||
def get_by_token(self, document_id, admin_token=None):
|
||||
"""Get single document by ID and token."""
|
||||
doc = self.documents.get(document_id)
|
||||
if doc and doc.admin_token == admin_token:
|
||||
if doc and (admin_token is None or doc.admin_token == admin_token):
|
||||
return doc
|
||||
return None
|
||||
|
||||
def acquire_annotation_lock(self, document_id, admin_token, duration_seconds=300):
|
||||
def acquire_annotation_lock(self, document_id, admin_token=None, duration_seconds=300):
|
||||
"""Acquire annotation lock for a document."""
|
||||
doc = self.documents.get(document_id)
|
||||
if not doc or doc.admin_token != admin_token:
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
# Check if already locked
|
||||
@@ -62,20 +69,20 @@ class MockAdminDB:
|
||||
doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
|
||||
return doc
|
||||
|
||||
def release_annotation_lock(self, document_id, admin_token, force=False):
|
||||
def release_annotation_lock(self, document_id, admin_token=None, force=False):
|
||||
"""Release annotation lock for a document."""
|
||||
doc = self.documents.get(document_id)
|
||||
if not doc or doc.admin_token != admin_token:
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
# Release lock
|
||||
doc.annotation_lock_until = None
|
||||
return doc
|
||||
|
||||
def extend_annotation_lock(self, document_id, admin_token, additional_seconds=300):
|
||||
def extend_annotation_lock(self, document_id, admin_token=None, additional_seconds=300):
|
||||
"""Extend an existing annotation lock."""
|
||||
doc = self.documents.get(document_id)
|
||||
if not doc or doc.admin_token != admin_token:
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
# Check if lock exists and is still valid
|
||||
@@ -93,8 +100,8 @@ def app():
|
||||
"""Create test FastAPI app."""
|
||||
app = FastAPI()
|
||||
|
||||
# Create mock DB
|
||||
mock_db = MockAdminDB()
|
||||
# Create mock repository
|
||||
mock_document_repo = MockDocumentRepository()
|
||||
|
||||
# Add test document
|
||||
doc1 = MockAdminDocument(
|
||||
@@ -103,11 +110,11 @@ def app():
|
||||
upload_source="ui",
|
||||
)
|
||||
|
||||
mock_db.documents[str(doc1.document_id)] = doc1
|
||||
mock_document_repo.documents[str(doc1.document_id)] = doc1
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
|
||||
|
||||
# Include router
|
||||
router = create_locks_router()
|
||||
@@ -124,9 +131,9 @@ def client(app):
|
||||
|
||||
@pytest.fixture
|
||||
def document_id(app):
|
||||
"""Get document ID from the mock DB."""
|
||||
mock_db = app.dependency_overrides[get_admin_db]()
|
||||
return str(list(mock_db.documents.keys())[0])
|
||||
"""Get document ID from the mock repository."""
|
||||
mock_document_repo = app.dependency_overrides[get_document_repository]()
|
||||
return str(list(mock_document_repo.documents.keys())[0])
|
||||
|
||||
|
||||
class TestAnnotationLocks:
|
||||
|
||||
@@ -9,8 +9,12 @@ from uuid import uuid4
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.annotations import create_annotation_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.api.v1.admin.annotations import (
|
||||
create_annotation_router,
|
||||
get_doc_repository,
|
||||
get_ann_repository,
|
||||
)
|
||||
from inference.web.core.auth import validate_admin_token
|
||||
|
||||
|
||||
class MockAdminDocument:
|
||||
@@ -73,22 +77,40 @@ class MockAnnotationHistory:
|
||||
self.created_at = kwargs.get('created_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing Phase 5."""
|
||||
class MockDocumentRepository:
|
||||
"""Mock DocumentRepository for testing Phase 5."""
|
||||
|
||||
def __init__(self):
|
||||
self.documents = {}
|
||||
self.annotations = {}
|
||||
self.annotation_history = {}
|
||||
|
||||
def get_document_by_token(self, document_id, admin_token):
|
||||
def get(self, document_id):
|
||||
"""Get document by ID."""
|
||||
return self.documents.get(str(document_id))
|
||||
|
||||
def get_by_token(self, document_id, admin_token=None):
|
||||
"""Get document by ID and token."""
|
||||
doc = self.documents.get(str(document_id))
|
||||
if doc and doc.admin_token == admin_token:
|
||||
if doc and (admin_token is None or doc.admin_token == admin_token):
|
||||
return doc
|
||||
return None
|
||||
|
||||
def verify_annotation(self, annotation_id, admin_token):
|
||||
|
||||
class MockAnnotationRepository:
|
||||
"""Mock AnnotationRepository for testing Phase 5."""
|
||||
|
||||
def __init__(self):
|
||||
self.annotations = {}
|
||||
self.annotation_history = {}
|
||||
|
||||
def get(self, annotation_id):
|
||||
"""Get annotation by ID."""
|
||||
return self.annotations.get(str(annotation_id))
|
||||
|
||||
def get_for_document(self, document_id, page_number=None):
|
||||
"""Get annotations for a document."""
|
||||
return [a for a in self.annotations.values() if str(a.document_id) == str(document_id)]
|
||||
|
||||
def verify(self, annotation_id, admin_token):
|
||||
"""Mark annotation as verified."""
|
||||
annotation = self.annotations.get(str(annotation_id))
|
||||
if annotation:
|
||||
@@ -98,7 +120,7 @@ class MockAdminDB:
|
||||
return annotation
|
||||
return None
|
||||
|
||||
def override_annotation(
|
||||
def override(
|
||||
self,
|
||||
annotation_id,
|
||||
admin_token,
|
||||
@@ -131,7 +153,7 @@ class MockAdminDB:
|
||||
return annotation
|
||||
return None
|
||||
|
||||
def get_annotation_history(self, annotation_id):
|
||||
def get_history(self, annotation_id):
|
||||
"""Get annotation history."""
|
||||
return self.annotation_history.get(str(annotation_id), [])
|
||||
|
||||
@@ -141,15 +163,16 @@ def app():
|
||||
"""Create test FastAPI app."""
|
||||
app = FastAPI()
|
||||
|
||||
# Create mock DB
|
||||
mock_db = MockAdminDB()
|
||||
# Create mock repositories
|
||||
mock_document_repo = MockDocumentRepository()
|
||||
mock_annotation_repo = MockAnnotationRepository()
|
||||
|
||||
# Add test document
|
||||
doc1 = MockAdminDocument(
|
||||
filename="TEST001.pdf",
|
||||
status="labeled",
|
||||
)
|
||||
mock_db.documents[str(doc1.document_id)] = doc1
|
||||
mock_document_repo.documents[str(doc1.document_id)] = doc1
|
||||
|
||||
# Add test annotations
|
||||
ann1 = MockAnnotation(
|
||||
@@ -169,8 +192,8 @@ def app():
|
||||
confidence=0.98,
|
||||
)
|
||||
|
||||
mock_db.annotations[str(ann1.annotation_id)] = ann1
|
||||
mock_db.annotations[str(ann2.annotation_id)] = ann2
|
||||
mock_annotation_repo.annotations[str(ann1.annotation_id)] = ann1
|
||||
mock_annotation_repo.annotations[str(ann2.annotation_id)] = ann2
|
||||
|
||||
# Store document ID and annotation IDs for tests
|
||||
app.state.document_id = str(doc1.document_id)
|
||||
@@ -179,7 +202,8 @@ def app():
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
app.dependency_overrides[get_doc_repository] = lambda: mock_document_repo
|
||||
app.dependency_overrides[get_ann_repository] = lambda: mock_annotation_repo
|
||||
|
||||
# Include router
|
||||
router = create_annotation_router()
|
||||
|
||||
@@ -11,7 +11,11 @@ from fastapi.testclient import TestClient
|
||||
import numpy as np
|
||||
|
||||
from inference.web.api.v1.admin.augmentation import create_augmentation_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
get_dataset_repository,
|
||||
)
|
||||
|
||||
|
||||
TEST_ADMIN_TOKEN = "test-admin-token-12345"
|
||||
@@ -26,18 +30,27 @@ def admin_token() -> str:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db() -> MagicMock:
|
||||
"""Create a mock AdminDB for testing."""
|
||||
def mock_document_repo() -> MagicMock:
|
||||
"""Create a mock DocumentRepository for testing."""
|
||||
mock = MagicMock()
|
||||
# Default return values
|
||||
mock.get_document_by_token.return_value = None
|
||||
mock.get_dataset.return_value = None
|
||||
mock.get_augmented_datasets.return_value = ([], 0)
|
||||
mock.get.return_value = None
|
||||
mock.get_by_token.return_value = None
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
def mock_dataset_repo() -> MagicMock:
|
||||
"""Create a mock DatasetRepository for testing."""
|
||||
mock = MagicMock()
|
||||
# Default return values
|
||||
mock.get.return_value = None
|
||||
mock.get_paginated.return_value = ([], 0)
|
||||
return mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
|
||||
"""Create test client with admin authentication."""
|
||||
app = FastAPI()
|
||||
|
||||
@@ -45,11 +58,15 @@ def admin_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
def get_token_override():
|
||||
return TEST_ADMIN_TOKEN
|
||||
|
||||
def get_db_override():
|
||||
return mock_admin_db
|
||||
def get_document_repo_override():
|
||||
return mock_document_repo
|
||||
|
||||
def get_dataset_repo_override():
|
||||
return mock_dataset_repo
|
||||
|
||||
app.dependency_overrides[validate_admin_token] = get_token_override
|
||||
app.dependency_overrides[get_admin_db] = get_db_override
|
||||
app.dependency_overrides[get_document_repository] = get_document_repo_override
|
||||
app.dependency_overrides[get_dataset_repository] = get_dataset_repo_override
|
||||
|
||||
# Include router - the router already has /augmentation prefix
|
||||
# so we add /api/v1/admin to get /api/v1/admin/augmentation
|
||||
@@ -60,15 +77,19 @@ def admin_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def unauthenticated_client(mock_admin_db: MagicMock) -> TestClient:
|
||||
def unauthenticated_client(mock_document_repo: MagicMock, mock_dataset_repo: MagicMock) -> TestClient:
|
||||
"""Create test client WITHOUT admin authentication override."""
|
||||
app = FastAPI()
|
||||
|
||||
# Only override the database, NOT the token validation
|
||||
def get_db_override():
|
||||
return mock_admin_db
|
||||
# Only override the repositories, NOT the token validation
|
||||
def get_document_repo_override():
|
||||
return mock_document_repo
|
||||
|
||||
app.dependency_overrides[get_admin_db] = get_db_override
|
||||
def get_dataset_repo_override():
|
||||
return mock_dataset_repo
|
||||
|
||||
app.dependency_overrides[get_document_repository] = get_document_repo_override
|
||||
app.dependency_overrides[get_dataset_repository] = get_dataset_repo_override
|
||||
|
||||
router = create_augmentation_router()
|
||||
app.include_router(router, prefix="/api/v1/admin")
|
||||
@@ -142,13 +163,13 @@ class TestAugmentationPreviewEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
mock_document_repo: MagicMock,
|
||||
) -> None:
|
||||
"""Test previewing augmentation on a document."""
|
||||
# Mock document exists
|
||||
mock_document = MagicMock()
|
||||
mock_document.images_dir = "/fake/path"
|
||||
mock_admin_db.get_document.return_value = mock_document
|
||||
mock_document_repo.get.return_value = mock_document
|
||||
|
||||
# Create a fake image (100x100 RGB)
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
@@ -218,13 +239,13 @@ class TestAugmentationPreviewConfigEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_document_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
mock_document_repo: MagicMock,
|
||||
) -> None:
|
||||
"""Test previewing full config on a document."""
|
||||
# Mock document exists
|
||||
mock_document = MagicMock()
|
||||
mock_document.images_dir = "/fake/path"
|
||||
mock_admin_db.get_document.return_value = mock_document
|
||||
mock_document_repo.get.return_value = mock_document
|
||||
|
||||
# Create a fake image (100x100 RGB)
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
@@ -260,13 +281,13 @@ class TestAugmentationBatchEndpoint:
|
||||
admin_client: TestClient,
|
||||
admin_token: str,
|
||||
sample_dataset_id: str,
|
||||
mock_admin_db: MagicMock,
|
||||
mock_dataset_repo: MagicMock,
|
||||
) -> None:
|
||||
"""Test creating augmented dataset."""
|
||||
# Mock dataset exists
|
||||
mock_dataset = MagicMock()
|
||||
mock_dataset.total_images = 100
|
||||
mock_admin_db.get_dataset.return_value = mock_dataset
|
||||
mock_dataset_repo.get.return_value = mock_dataset
|
||||
|
||||
response = admin_client.post(
|
||||
"/api/v1/admin/augmentation/batch",
|
||||
|
||||
@@ -9,7 +9,6 @@ from unittest.mock import Mock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from inference.web.services.autolabel import AutoLabelService
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
|
||||
class MockDocument:
|
||||
@@ -23,19 +22,18 @@ class MockDocument:
|
||||
self.auto_label_error = None
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing."""
|
||||
class MockDocumentRepository:
|
||||
"""Mock DocumentRepository for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.documents = {}
|
||||
self.annotations = []
|
||||
self.status_updates = []
|
||||
|
||||
def get_document(self, document_id):
|
||||
def get(self, document_id):
|
||||
"""Get document by ID."""
|
||||
return self.documents.get(str(document_id))
|
||||
|
||||
def update_document_status(
|
||||
def update_status(
|
||||
self,
|
||||
document_id,
|
||||
status=None,
|
||||
@@ -58,19 +56,32 @@ class MockAdminDB:
|
||||
if auto_label_error:
|
||||
doc.auto_label_error = auto_label_error
|
||||
|
||||
def delete_annotations_for_document(self, document_id, source=None):
|
||||
|
||||
class MockAnnotationRepository:
|
||||
"""Mock AnnotationRepository for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.annotations = []
|
||||
|
||||
def delete_for_document(self, document_id, source=None):
|
||||
"""Mock delete annotations."""
|
||||
return 0
|
||||
|
||||
def create_annotations_batch(self, annotations):
|
||||
def create_batch(self, annotations):
|
||||
"""Mock create annotations."""
|
||||
self.annotations.extend(annotations)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
"""Create mock admin DB."""
|
||||
return MockAdminDB()
|
||||
def mock_doc_repo():
|
||||
"""Create mock document repository."""
|
||||
return MockDocumentRepository()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ann_repo():
|
||||
"""Create mock annotation repository."""
|
||||
return MockAnnotationRepository()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -82,10 +93,14 @@ def auto_label_service(monkeypatch):
|
||||
service._ocr_engine.extract_from_image = Mock(return_value=[])
|
||||
|
||||
# Mock the image processing methods to avoid file I/O errors
|
||||
def mock_process_image(self, document_id, image_path, field_values, db, page_number=1):
|
||||
def mock_process_image(self, document_id, image_path, field_values, ann_repo, page_number=1):
|
||||
return 0 # No annotations created (mocked)
|
||||
|
||||
def mock_process_pdf(self, document_id, pdf_path, field_values, ann_repo):
|
||||
return 0 # No annotations created (mocked)
|
||||
|
||||
monkeypatch.setattr(AutoLabelService, "_process_image", mock_process_image)
|
||||
monkeypatch.setattr(AutoLabelService, "_process_pdf", mock_process_pdf)
|
||||
|
||||
return service
|
||||
|
||||
@@ -93,11 +108,11 @@ def auto_label_service(monkeypatch):
|
||||
class TestAutoLabelWithLocks:
|
||||
"""Tests for auto-label service with lock integration."""
|
||||
|
||||
def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling succeeds on unlocked document."""
|
||||
# Create test document (unlocked)
|
||||
document_id = str(uuid4())
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=None,
|
||||
)
|
||||
@@ -111,21 +126,22 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
)
|
||||
|
||||
# Should succeed
|
||||
assert result["status"] == "completed"
|
||||
# Verify status was updated to running and then completed
|
||||
assert len(mock_db.status_updates) >= 2
|
||||
assert mock_db.status_updates[0]["auto_label_status"] == "running"
|
||||
assert len(mock_doc_repo.status_updates) >= 2
|
||||
assert mock_doc_repo.status_updates[0]["auto_label_status"] == "running"
|
||||
|
||||
def test_auto_label_locked_document_fails(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_locked_document_fails(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling fails on locked document."""
|
||||
# Create test document (locked for 1 hour)
|
||||
document_id = str(uuid4())
|
||||
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=lock_until,
|
||||
)
|
||||
@@ -139,7 +155,8 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
)
|
||||
|
||||
# Should fail
|
||||
@@ -150,15 +167,15 @@ class TestAutoLabelWithLocks:
|
||||
# Verify status was updated to failed
|
||||
assert any(
|
||||
update["auto_label_status"] == "failed"
|
||||
for update in mock_db.status_updates
|
||||
for update in mock_doc_repo.status_updates
|
||||
)
|
||||
|
||||
def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling succeeds when lock has expired."""
|
||||
# Create test document (lock expired 1 hour ago)
|
||||
document_id = str(uuid4())
|
||||
lock_until = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=lock_until,
|
||||
)
|
||||
@@ -172,18 +189,19 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
)
|
||||
|
||||
# Should succeed (lock expired)
|
||||
assert result["status"] == "completed"
|
||||
|
||||
def test_auto_label_skip_lock_check(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_skip_lock_check(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling with skip_lock_check=True bypasses lock."""
|
||||
# Create test document (locked)
|
||||
document_id = str(uuid4())
|
||||
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=lock_until,
|
||||
)
|
||||
@@ -197,14 +215,15 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
skip_lock_check=True, # Bypass lock check
|
||||
)
|
||||
|
||||
# Should succeed even though document is locked
|
||||
assert result["status"] == "completed"
|
||||
|
||||
def test_auto_label_document_not_found(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_document_not_found(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling fails when document doesn't exist."""
|
||||
# Create dummy file
|
||||
test_file = tmp_path / "test.png"
|
||||
@@ -215,19 +234,20 @@ class TestAutoLabelWithLocks:
|
||||
document_id=str(uuid4()),
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
)
|
||||
|
||||
# Should fail
|
||||
assert result["status"] == "failed"
|
||||
assert "not found" in result["error"]
|
||||
|
||||
def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test that lock check is enabled by default."""
|
||||
# Create test document (locked)
|
||||
document_id = str(uuid4())
|
||||
lock_until = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=lock_until,
|
||||
)
|
||||
@@ -241,7 +261,8 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
# skip_lock_check not specified, should default to False
|
||||
)
|
||||
|
||||
|
||||
@@ -11,20 +11,20 @@ import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.batch.routes import router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.api.v1.batch.routes import router, get_batch_repository
|
||||
from inference.web.core.auth import validate_admin_token
|
||||
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
|
||||
from inference.web.services.batch_upload import BatchUploadService
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing."""
|
||||
class MockBatchUploadRepository:
|
||||
"""Mock BatchUploadRepository for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.batches = {}
|
||||
self.batch_files = {}
|
||||
|
||||
def create_batch_upload(self, admin_token, filename, file_size, upload_source):
|
||||
def create(self, admin_token, filename, file_size, upload_source="ui"):
|
||||
batch_id = uuid4()
|
||||
batch = type('BatchUpload', (), {
|
||||
'batch_id': batch_id,
|
||||
@@ -46,13 +46,13 @@ class MockAdminDB:
|
||||
self.batches[batch_id] = batch
|
||||
return batch
|
||||
|
||||
def update_batch_upload(self, batch_id, **kwargs):
|
||||
def update(self, batch_id, **kwargs):
|
||||
if batch_id in self.batches:
|
||||
batch = self.batches[batch_id]
|
||||
for key, value in kwargs.items():
|
||||
setattr(batch, key, value)
|
||||
|
||||
def create_batch_upload_file(self, batch_id, filename, **kwargs):
|
||||
def create_file(self, batch_id, filename, **kwargs):
|
||||
file_id = uuid4()
|
||||
defaults = {
|
||||
'file_id': file_id,
|
||||
@@ -70,7 +70,7 @@ class MockAdminDB:
|
||||
self.batch_files[batch_id].append(file_record)
|
||||
return file_record
|
||||
|
||||
def update_batch_upload_file(self, file_id, **kwargs):
|
||||
def update_file(self, file_id, **kwargs):
|
||||
for files in self.batch_files.values():
|
||||
for file_record in files:
|
||||
if file_record.file_id == file_id:
|
||||
@@ -78,7 +78,7 @@ class MockAdminDB:
|
||||
setattr(file_record, key, value)
|
||||
return
|
||||
|
||||
def get_batch_upload(self, batch_id):
|
||||
def get(self, batch_id):
|
||||
return self.batches.get(batch_id, type('BatchUpload', (), {
|
||||
'batch_id': batch_id,
|
||||
'admin_token': 'test-token',
|
||||
@@ -95,12 +95,15 @@ class MockAdminDB:
|
||||
'completed_at': datetime.utcnow(),
|
||||
})())
|
||||
|
||||
def get_batch_upload_files(self, batch_id):
|
||||
def get_files(self, batch_id):
|
||||
return self.batch_files.get(batch_id, [])
|
||||
|
||||
def get_batch_uploads_by_token(self, admin_token, limit=50, offset=0):
|
||||
def get_paginated(self, admin_token=None, limit=50, offset=0):
|
||||
"""Get batches filtered by admin token with pagination."""
|
||||
token_batches = [b for b in self.batches.values() if b.admin_token == admin_token]
|
||||
if admin_token:
|
||||
token_batches = [b for b in self.batches.values() if b.admin_token == admin_token]
|
||||
else:
|
||||
token_batches = list(self.batches.values())
|
||||
total = len(token_batches)
|
||||
return token_batches[offset:offset+limit], total
|
||||
|
||||
@@ -110,15 +113,15 @@ def app():
|
||||
"""Create test FastAPI app with mocked dependencies."""
|
||||
app = FastAPI()
|
||||
|
||||
# Create mock admin DB
|
||||
mock_admin_db = MockAdminDB()
|
||||
# Create mock batch upload repository
|
||||
mock_batch_upload_repo = MockBatchUploadRepository()
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_admin_db
|
||||
app.dependency_overrides[get_batch_repository] = lambda: mock_batch_upload_repo
|
||||
|
||||
# Initialize batch queue with mock service
|
||||
batch_service = BatchUploadService(mock_admin_db)
|
||||
batch_service = BatchUploadService(mock_batch_upload_repo)
|
||||
init_batch_queue(batch_service)
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
@@ -9,19 +9,18 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.web.services.batch_upload import BatchUploadService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_db():
|
||||
"""Mock admin database for testing."""
|
||||
class MockAdminDB:
|
||||
def batch_repo():
|
||||
"""Mock batch upload repository for testing."""
|
||||
class MockBatchUploadRepository:
|
||||
def __init__(self):
|
||||
self.batches = {}
|
||||
self.batch_files = {}
|
||||
|
||||
def create_batch_upload(self, admin_token, filename, file_size, upload_source):
|
||||
def create(self, admin_token, filename, file_size, upload_source):
|
||||
batch_id = uuid4()
|
||||
batch = type('BatchUpload', (), {
|
||||
'batch_id': batch_id,
|
||||
@@ -43,13 +42,13 @@ def admin_db():
|
||||
self.batches[batch_id] = batch
|
||||
return batch
|
||||
|
||||
def update_batch_upload(self, batch_id, **kwargs):
|
||||
def update(self, batch_id, **kwargs):
|
||||
if batch_id in self.batches:
|
||||
batch = self.batches[batch_id]
|
||||
for key, value in kwargs.items():
|
||||
setattr(batch, key, value)
|
||||
|
||||
def create_batch_upload_file(self, batch_id, filename, **kwargs):
|
||||
def create_file(self, batch_id, filename, **kwargs):
|
||||
file_id = uuid4()
|
||||
# Set defaults for attributes
|
||||
defaults = {
|
||||
@@ -68,7 +67,7 @@ def admin_db():
|
||||
self.batch_files[batch_id].append(file_record)
|
||||
return file_record
|
||||
|
||||
def update_batch_upload_file(self, file_id, **kwargs):
|
||||
def update_file(self, file_id, **kwargs):
|
||||
for files in self.batch_files.values():
|
||||
for file_record in files:
|
||||
if file_record.file_id == file_id:
|
||||
@@ -76,19 +75,19 @@ def admin_db():
|
||||
setattr(file_record, key, value)
|
||||
return
|
||||
|
||||
def get_batch_upload(self, batch_id):
|
||||
def get(self, batch_id):
|
||||
return self.batches.get(batch_id)
|
||||
|
||||
def get_batch_upload_files(self, batch_id):
|
||||
def get_files(self, batch_id):
|
||||
return self.batch_files.get(batch_id, [])
|
||||
|
||||
return MockAdminDB()
|
||||
return MockBatchUploadRepository()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def batch_service(admin_db):
|
||||
def batch_service(batch_repo):
|
||||
"""Batch upload service instance."""
|
||||
return BatchUploadService(admin_db)
|
||||
return BatchUploadService(batch_repo)
|
||||
|
||||
|
||||
def create_test_zip(files):
|
||||
@@ -194,7 +193,7 @@ INV002,F2024-002,2024-01-16,2500.00,7350087654321,123-4567,C124
|
||||
assert csv_data["INV001"]["Amount"] == "1500.00"
|
||||
assert csv_data["INV001"]["customer_number"] == "C123"
|
||||
|
||||
def test_get_batch_status(self, batch_service, admin_db):
|
||||
def test_get_batch_status(self, batch_service, batch_repo):
|
||||
"""Test getting batch upload status."""
|
||||
# Create a batch
|
||||
zip_content = create_test_zip({"INV001.pdf": b"%PDF-1.4 test"})
|
||||
|
||||
@@ -16,7 +16,6 @@ from inference.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
TrainingDataset,
|
||||
FIELD_CLASSES,
|
||||
)
|
||||
|
||||
|
||||
@@ -35,10 +34,10 @@ def tmp_admin_images(tmp_path):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db():
|
||||
"""Mock AdminDB with dataset and document methods."""
|
||||
db = MagicMock()
|
||||
db.create_dataset.return_value = TrainingDataset(
|
||||
def mock_datasets_repo():
|
||||
"""Mock DatasetRepository."""
|
||||
repo = MagicMock()
|
||||
repo.create.return_value = TrainingDataset(
|
||||
dataset_id=uuid4(),
|
||||
name="test-dataset",
|
||||
status="building",
|
||||
@@ -46,7 +45,19 @@ def mock_admin_db():
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
)
|
||||
return db
|
||||
return repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_documents_repo():
|
||||
"""Mock DocumentRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_annotations_repo():
|
||||
"""Mock AnnotationRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -60,6 +71,7 @@ def sample_documents(tmp_admin_images):
|
||||
doc.filename = f"{doc_id}.pdf"
|
||||
doc.page_count = 2
|
||||
doc.file_path = str(tmp_path / "admin_images" / str(doc_id))
|
||||
doc.group_key = None # Default to no group
|
||||
docs.append(doc)
|
||||
return docs
|
||||
|
||||
@@ -89,21 +101,27 @@ class TestDatasetBuilder:
|
||||
"""Tests for DatasetBuilder."""
|
||||
|
||||
def test_build_creates_directory_structure(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Dataset builder should create images/ and labels/ with train/val/test subdirs."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
dataset_dir = tmp_path / "datasets" / "test"
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# Mock DB calls
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
# Mock repo calls
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -119,18 +137,24 @@ class TestDatasetBuilder:
|
||||
assert (result_dir / "labels" / split).exists()
|
||||
|
||||
def test_build_copies_images(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Images should be copied from admin_images to dataset folder."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
result = builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -149,18 +173,24 @@ class TestDatasetBuilder:
|
||||
assert total_images == 10 # 5 docs * 2 pages
|
||||
|
||||
def test_build_generates_yolo_labels(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""YOLO label files should be generated with correct format."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -187,18 +217,24 @@ class TestDatasetBuilder:
|
||||
assert 0 <= float(parts[2]) <= 1 # y_center
|
||||
|
||||
def test_build_generates_data_yaml(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""data.yaml should be generated with correct field classes."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -217,18 +253,24 @@ class TestDatasetBuilder:
|
||||
assert "invoice_number" in content
|
||||
|
||||
def test_build_splits_documents_correctly(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Documents should be split into train/val/test according to ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -238,8 +280,8 @@ class TestDatasetBuilder:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
# Verify add_dataset_documents was called with correct splits
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
# Verify add_documents was called with correct splits
|
||||
call_args = mock_datasets_repo.add_documents.call_args
|
||||
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
splits = [d["split"] for d in docs_added]
|
||||
assert "train" in splits
|
||||
@@ -248,18 +290,24 @@ class TestDatasetBuilder:
|
||||
assert train_count >= 3 # At least 3 of 5 should be train
|
||||
|
||||
def test_build_updates_status_to_ready(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""After successful build, dataset status should be updated to 'ready'."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -269,22 +317,27 @@ class TestDatasetBuilder:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
mock_admin_db.update_dataset_status.assert_called_once()
|
||||
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
|
||||
mock_datasets_repo.update_status.assert_called_once()
|
||||
call_kwargs = mock_datasets_repo.update_status.call_args[1]
|
||||
assert call_kwargs["status"] == "ready"
|
||||
assert call_kwargs["total_documents"] == 5
|
||||
assert call_kwargs["total_images"] == 10
|
||||
|
||||
def test_build_sets_failed_on_error(
|
||||
self, tmp_path, mock_admin_db
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""If build fails, dataset status should be set to 'failed'."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = [] # No docs found
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = [] # No docs found
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
with pytest.raises(ValueError):
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
@@ -295,27 +348,33 @@ class TestDatasetBuilder:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
mock_admin_db.update_dataset_status.assert_called_once()
|
||||
call_kwargs = mock_admin_db.update_dataset_status.call_args[1]
|
||||
mock_datasets_repo.update_status.assert_called_once()
|
||||
call_kwargs = mock_datasets_repo.update_status.call_args[1]
|
||||
assert call_kwargs["status"] == "failed"
|
||||
|
||||
def test_build_with_seed_produces_deterministic_splits(
|
||||
self, tmp_path, mock_admin_db, sample_documents, sample_annotations
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo,
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Same seed should produce same splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
results = []
|
||||
for _ in range(2):
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = sample_documents
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = sample_documents
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
sample_annotations.get(str(doc_id), [])
|
||||
)
|
||||
mock_admin_db.add_dataset_documents.reset_mock()
|
||||
mock_admin_db.update_dataset_status.reset_mock()
|
||||
mock_datasets_repo.add_documents.reset_mock()
|
||||
mock_datasets_repo.update_status.reset_mock()
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in sample_documents],
|
||||
@@ -324,7 +383,7 @@ class TestDatasetBuilder:
|
||||
seed=42,
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
call_args = mock_datasets_repo.add_documents.call_args
|
||||
docs = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
results.append([(d["document_id"], d["split"]) for d in docs])
|
||||
|
||||
@@ -342,11 +401,18 @@ class TestAssignSplitsByGroup:
|
||||
doc.page_count = 1
|
||||
return doc
|
||||
|
||||
def test_single_doc_groups_are_distributed(self, tmp_path, mock_admin_db):
|
||||
def test_single_doc_groups_are_distributed(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Documents with unique group_key are distributed across splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# 3 documents, each with unique group_key
|
||||
docs = [
|
||||
@@ -363,11 +429,18 @@ class TestAssignSplitsByGroup:
|
||||
assert train_count >= 1
|
||||
assert val_count >= 1 # Ensure val is not empty
|
||||
|
||||
def test_null_group_key_treated_as_single_doc_group(self, tmp_path, mock_admin_db):
|
||||
def test_null_group_key_treated_as_single_doc_group(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Documents with null/empty group_key are each treated as independent single-doc groups."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
self._make_mock_doc(uuid4(), group_key=None),
|
||||
@@ -384,11 +457,18 @@ class TestAssignSplitsByGroup:
|
||||
assert train_count >= 1
|
||||
assert val_count >= 1
|
||||
|
||||
def test_multi_doc_groups_stay_together(self, tmp_path, mock_admin_db):
|
||||
def test_multi_doc_groups_stay_together(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Documents with same group_key should be assigned to the same split."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# 6 documents in 2 groups
|
||||
docs = [
|
||||
@@ -410,11 +490,18 @@ class TestAssignSplitsByGroup:
|
||||
splits_b = [result[str(d.document_id)] for d in docs[3:]]
|
||||
assert len(set(splits_b)) == 1, "All docs in supplier-B should be in same split"
|
||||
|
||||
def test_multi_doc_groups_split_by_ratio(self, tmp_path, mock_admin_db):
|
||||
def test_multi_doc_groups_split_by_ratio(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Multi-doc groups should be split according to train/val/test ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# 10 groups with 2 docs each
|
||||
docs = []
|
||||
@@ -445,11 +532,18 @@ class TestAssignSplitsByGroup:
|
||||
assert split_counts["val"] >= 1
|
||||
assert split_counts["val"] <= 3
|
||||
|
||||
def test_mixed_single_and_multi_doc_groups(self, tmp_path, mock_admin_db):
|
||||
def test_mixed_single_and_multi_doc_groups(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Mix of single-doc and multi-doc groups should be handled correctly."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
# Single-doc groups
|
||||
@@ -476,11 +570,18 @@ class TestAssignSplitsByGroup:
|
||||
assert result[str(docs[3].document_id)] == result[str(docs[4].document_id)]
|
||||
assert result[str(docs[5].document_id)] == result[str(docs[6].document_id)]
|
||||
|
||||
def test_deterministic_with_seed(self, tmp_path, mock_admin_db):
|
||||
def test_deterministic_with_seed(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Same seed should produce same split assignments."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
self._make_mock_doc(uuid4(), group_key="group-A"),
|
||||
@@ -496,11 +597,18 @@ class TestAssignSplitsByGroup:
|
||||
|
||||
assert result1 == result2
|
||||
|
||||
def test_different_seed_may_produce_different_splits(self, tmp_path, mock_admin_db):
|
||||
def test_different_seed_may_produce_different_splits(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Different seeds should potentially produce different split assignments."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# Many groups to increase chance of different results
|
||||
docs = []
|
||||
@@ -515,11 +623,18 @@ class TestAssignSplitsByGroup:
|
||||
# Results should be different (very likely with 20 groups)
|
||||
assert result1 != result2
|
||||
|
||||
def test_all_docs_assigned(self, tmp_path, mock_admin_db):
|
||||
def test_all_docs_assigned(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Every document should be assigned a split."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
self._make_mock_doc(uuid4(), group_key="group-A"),
|
||||
@@ -535,21 +650,35 @@ class TestAssignSplitsByGroup:
|
||||
assert str(doc.document_id) in result
|
||||
assert result[str(doc.document_id)] in ["train", "val", "test"]
|
||||
|
||||
def test_empty_documents_list(self, tmp_path, mock_admin_db):
|
||||
def test_empty_documents_list(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Empty document list should return empty result."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
result = builder._assign_splits_by_group([], train_ratio=0.7, val_ratio=0.2, seed=42)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_only_multi_doc_groups(self, tmp_path, mock_admin_db):
|
||||
def test_only_multi_doc_groups(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""When all groups have multiple docs, splits should follow ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
# 5 groups with 3 docs each
|
||||
docs = []
|
||||
@@ -574,11 +703,18 @@ class TestAssignSplitsByGroup:
|
||||
assert split_counts["train"] >= 2
|
||||
assert split_counts["train"] <= 4
|
||||
|
||||
def test_only_single_doc_groups(self, tmp_path, mock_admin_db):
|
||||
def test_only_single_doc_groups(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""When all groups have single doc, they are distributed across splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
|
||||
docs = [
|
||||
self._make_mock_doc(uuid4(), group_key="unique-1"),
|
||||
@@ -658,20 +794,26 @@ class TestBuildDatasetWithGroupKey:
|
||||
return annotations
|
||||
|
||||
def test_build_respects_group_key_splits(
|
||||
self, grouped_documents, grouped_annotations, mock_admin_db
|
||||
self, grouped_documents, grouped_annotations,
|
||||
mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""build_dataset should use group_key for split assignment."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
tmp_path, docs = grouped_documents
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = docs
|
||||
mock_admin_db.get_annotations_for_document.side_effect = lambda doc_id: (
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = docs
|
||||
mock_annotations_repo.get_for_document.side_effect = lambda doc_id, page_number=None: (
|
||||
grouped_annotations.get(str(doc_id), [])
|
||||
)
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in docs],
|
||||
@@ -681,8 +823,8 @@ class TestBuildDatasetWithGroupKey:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
# Get the document splits from add_dataset_documents call
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
# Get the document splits from add_documents call
|
||||
call_args = mock_datasets_repo.add_documents.call_args
|
||||
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
|
||||
# Build mapping of doc_id -> split
|
||||
@@ -701,7 +843,9 @@ class TestBuildDatasetWithGroupKey:
|
||||
supplier_b_splits = [doc_split_map[doc_id] for doc_id in supplier_b_ids]
|
||||
assert len(set(supplier_b_splits)) == 1, "supplier-B docs should be in same split"
|
||||
|
||||
def test_build_with_all_same_group_key(self, tmp_path, mock_admin_db):
|
||||
def test_build_with_all_same_group_key(
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""All docs with same group_key should go to same split."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
@@ -720,11 +864,16 @@ class TestBuildDatasetWithGroupKey:
|
||||
doc.group_key = "same-group"
|
||||
docs.append(doc)
|
||||
|
||||
builder = DatasetBuilder(db=mock_admin_db, base_dir=tmp_path / "datasets")
|
||||
mock_admin_db.get_documents_by_ids.return_value = docs
|
||||
mock_admin_db.get_annotations_for_document.return_value = []
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
documents_repo=mock_documents_repo,
|
||||
annotations_repo=mock_annotations_repo,
|
||||
base_dir=tmp_path / "datasets",
|
||||
)
|
||||
mock_documents_repo.get_by_ids.return_value = docs
|
||||
mock_annotations_repo.get_for_document.return_value = []
|
||||
|
||||
dataset = mock_admin_db.create_dataset.return_value
|
||||
dataset = mock_datasets_repo.create.return_value
|
||||
builder.build_dataset(
|
||||
dataset_id=str(dataset.dataset_id),
|
||||
document_ids=[str(d.document_id) for d in docs],
|
||||
@@ -734,7 +883,7 @@ class TestBuildDatasetWithGroupKey:
|
||||
admin_images_dir=tmp_path / "admin_images",
|
||||
)
|
||||
|
||||
call_args = mock_admin_db.add_dataset_documents.call_args
|
||||
call_args = mock_datasets_repo.add_documents.call_args
|
||||
docs_added = call_args[1]["documents"] if "documents" in call_args[1] else call_args[0][1]
|
||||
|
||||
splits = [d["split"] for d in docs_added]
|
||||
|
||||
@@ -72,6 +72,36 @@ def _find_endpoint(name: str):
|
||||
raise AssertionError(f"Endpoint {name} not found")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_datasets_repo():
|
||||
"""Mock DatasetRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_documents_repo():
|
||||
"""Mock DocumentRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_annotations_repo():
|
||||
"""Mock AnnotationRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models_repo():
|
||||
"""Mock ModelVersionRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tasks_repo():
|
||||
"""Mock TrainingTaskRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
class TestCreateDatasetRoute:
|
||||
"""Tests for POST /admin/training/datasets."""
|
||||
|
||||
@@ -80,11 +110,12 @@ class TestCreateDatasetRoute:
|
||||
paths = [route.path for route in router.routes]
|
||||
assert any("datasets" in p for p in paths)
|
||||
|
||||
def test_create_dataset_calls_builder(self):
|
||||
def test_create_dataset_calls_builder(
|
||||
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
fn = _find_endpoint("create_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.create_dataset.return_value = _make_dataset(status="building")
|
||||
mock_datasets_repo.create.return_value = _make_dataset(status="building")
|
||||
|
||||
mock_builder = MagicMock()
|
||||
mock_builder.build_dataset.return_value = {
|
||||
@@ -101,20 +132,30 @@ class TestCreateDatasetRoute:
|
||||
with patch(
|
||||
"inference.web.services.dataset_builder.DatasetBuilder",
|
||||
return_value=mock_builder,
|
||||
) as mock_cls:
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
), patch(
|
||||
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
|
||||
) as mock_storage:
|
||||
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
|
||||
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
|
||||
result = asyncio.run(fn(
|
||||
request=request,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets=mock_datasets_repo,
|
||||
docs=mock_documents_repo,
|
||||
annotations=mock_annotations_repo,
|
||||
))
|
||||
|
||||
mock_db.create_dataset.assert_called_once()
|
||||
mock_datasets_repo.create.assert_called_once()
|
||||
mock_builder.build_dataset.assert_called_once()
|
||||
assert result.dataset_id == TEST_DATASET_UUID
|
||||
assert result.name == "test-dataset"
|
||||
|
||||
def test_create_dataset_fails_with_less_than_10_documents(self):
|
||||
def test_create_dataset_fails_with_less_than_10_documents(
|
||||
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Test that creating dataset fails if fewer than 10 documents provided."""
|
||||
fn = _find_endpoint("create_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
# Only 2 documents - should fail
|
||||
request = DatasetCreateRequest(
|
||||
name="test-dataset",
|
||||
@@ -124,20 +165,26 @@ class TestCreateDatasetRoute:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(
|
||||
request=request,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets=mock_datasets_repo,
|
||||
docs=mock_documents_repo,
|
||||
annotations=mock_annotations_repo,
|
||||
))
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Minimum 10 documents required" in exc_info.value.detail
|
||||
assert "got 2" in exc_info.value.detail
|
||||
# Ensure DB was never called since validation failed first
|
||||
mock_db.create_dataset.assert_not_called()
|
||||
# Ensure repo was never called since validation failed first
|
||||
mock_datasets_repo.create.assert_not_called()
|
||||
|
||||
def test_create_dataset_fails_with_9_documents(self):
|
||||
def test_create_dataset_fails_with_9_documents(
|
||||
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Test boundary condition: 9 documents should fail."""
|
||||
fn = _find_endpoint("create_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
|
||||
# 9 documents - just under the limit
|
||||
request = DatasetCreateRequest(
|
||||
name="test-dataset",
|
||||
@@ -147,17 +194,24 @@ class TestCreateDatasetRoute:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(
|
||||
request=request,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets=mock_datasets_repo,
|
||||
docs=mock_documents_repo,
|
||||
annotations=mock_annotations_repo,
|
||||
))
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "Minimum 10 documents required" in exc_info.value.detail
|
||||
|
||||
def test_create_dataset_succeeds_with_exactly_10_documents(self):
|
||||
def test_create_dataset_succeeds_with_exactly_10_documents(
|
||||
self, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Test boundary condition: exactly 10 documents should succeed."""
|
||||
fn = _find_endpoint("create_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.create_dataset.return_value = _make_dataset(status="building")
|
||||
mock_datasets_repo.create.return_value = _make_dataset(status="building")
|
||||
|
||||
mock_builder = MagicMock()
|
||||
|
||||
@@ -170,25 +224,40 @@ class TestCreateDatasetRoute:
|
||||
with patch(
|
||||
"inference.web.services.dataset_builder.DatasetBuilder",
|
||||
return_value=mock_builder,
|
||||
):
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
), patch(
|
||||
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
|
||||
) as mock_storage:
|
||||
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
|
||||
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
|
||||
result = asyncio.run(fn(
|
||||
request=request,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets=mock_datasets_repo,
|
||||
docs=mock_documents_repo,
|
||||
annotations=mock_annotations_repo,
|
||||
))
|
||||
|
||||
mock_db.create_dataset.assert_called_once()
|
||||
mock_datasets_repo.create.assert_called_once()
|
||||
assert result.dataset_id == TEST_DATASET_UUID
|
||||
|
||||
|
||||
class TestListDatasetsRoute:
|
||||
"""Tests for GET /admin/training/datasets."""
|
||||
|
||||
def test_list_datasets(self):
|
||||
def test_list_datasets(self, mock_datasets_repo):
|
||||
fn = _find_endpoint("list_datasets")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_datasets.return_value = ([_make_dataset()], 1)
|
||||
mock_datasets_repo.get_paginated.return_value = ([_make_dataset()], 1)
|
||||
# Mock the active training tasks lookup to return empty dict
|
||||
mock_db.get_active_training_tasks_for_datasets.return_value = {}
|
||||
mock_datasets_repo.get_active_training_tasks.return_value = {}
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
|
||||
result = asyncio.run(fn(
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets_repo=mock_datasets_repo,
|
||||
status=None,
|
||||
limit=20,
|
||||
offset=0,
|
||||
))
|
||||
|
||||
assert result.total == 1
|
||||
assert len(result.datasets) == 1
|
||||
@@ -198,82 +267,103 @@ class TestListDatasetsRoute:
|
||||
class TestGetDatasetRoute:
|
||||
"""Tests for GET /admin/training/datasets/{dataset_id}."""
|
||||
|
||||
def test_get_dataset_returns_detail(self):
|
||||
def test_get_dataset_returns_detail(self, mock_datasets_repo):
|
||||
fn = _find_endpoint("get_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset()
|
||||
mock_db.get_dataset_documents.return_value = [
|
||||
mock_datasets_repo.get.return_value = _make_dataset()
|
||||
mock_datasets_repo.get_documents.return_value = [
|
||||
_make_dataset_doc(TEST_DOC_UUID_1, "train"),
|
||||
_make_dataset_doc(TEST_DOC_UUID_2, "val"),
|
||||
]
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(
|
||||
dataset_id=TEST_DATASET_UUID,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets_repo=mock_datasets_repo,
|
||||
))
|
||||
|
||||
assert result.dataset_id == TEST_DATASET_UUID
|
||||
assert len(result.documents) == 2
|
||||
|
||||
def test_get_dataset_not_found(self):
|
||||
def test_get_dataset_not_found(self, mock_datasets_repo):
|
||||
fn = _find_endpoint("get_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = None
|
||||
mock_datasets_repo.get.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(
|
||||
dataset_id=TEST_DATASET_UUID,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets_repo=mock_datasets_repo,
|
||||
))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestDeleteDatasetRoute:
|
||||
"""Tests for DELETE /admin/training/datasets/{dataset_id}."""
|
||||
|
||||
def test_delete_dataset(self):
|
||||
def test_delete_dataset(self, mock_datasets_repo):
|
||||
fn = _find_endpoint("delete_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(dataset_path=None)
|
||||
mock_datasets_repo.get.return_value = _make_dataset(dataset_path=None)
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(
|
||||
dataset_id=TEST_DATASET_UUID,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets_repo=mock_datasets_repo,
|
||||
))
|
||||
|
||||
mock_db.delete_dataset.assert_called_once_with(TEST_DATASET_UUID)
|
||||
mock_datasets_repo.delete.assert_called_once_with(TEST_DATASET_UUID)
|
||||
assert result["message"] == "Dataset deleted"
|
||||
|
||||
|
||||
class TestTrainFromDatasetRoute:
|
||||
"""Tests for POST /admin/training/datasets/{dataset_id}/train."""
|
||||
|
||||
def test_train_from_ready_dataset(self):
|
||||
def test_train_from_ready_dataset(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="ready")
|
||||
mock_db.create_training_task.return_value = TEST_TASK_UUID
|
||||
mock_datasets_repo.get.return_value = _make_dataset(status="ready")
|
||||
mock_tasks_repo.create.return_value = TEST_TASK_UUID
|
||||
|
||||
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(
|
||||
dataset_id=TEST_DATASET_UUID,
|
||||
request=request,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets_repo=mock_datasets_repo,
|
||||
models=mock_models_repo,
|
||||
tasks=mock_tasks_repo,
|
||||
))
|
||||
|
||||
assert result.task_id == TEST_TASK_UUID
|
||||
assert result.status == TrainingStatus.PENDING
|
||||
mock_db.create_training_task.assert_called_once()
|
||||
mock_tasks_repo.create.assert_called_once()
|
||||
|
||||
def test_train_from_building_dataset_fails(self):
|
||||
def test_train_from_building_dataset_fails(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="building")
|
||||
mock_datasets_repo.get.return_value = _make_dataset(status="building")
|
||||
|
||||
request = DatasetTrainRequest(name="train-from-dataset", config=TrainingConfig())
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(
|
||||
dataset_id=TEST_DATASET_UUID,
|
||||
request=request,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets_repo=mock_datasets_repo,
|
||||
models=mock_models_repo,
|
||||
tasks=mock_tasks_repo,
|
||||
))
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
def test_incremental_training_with_base_model(self):
|
||||
def test_incremental_training_with_base_model(self, mock_datasets_repo, mock_models_repo, mock_tasks_repo):
|
||||
"""Test training with base_model_version_id for incremental training."""
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
@@ -281,22 +371,28 @@ class TestTrainFromDatasetRoute:
|
||||
mock_model_version.model_path = "runs/train/invoice_fields/weights/best.pt"
|
||||
mock_model_version.version = "1.0.0"
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="ready")
|
||||
mock_db.get_model_version.return_value = mock_model_version
|
||||
mock_db.create_training_task.return_value = TEST_TASK_UUID
|
||||
mock_datasets_repo.get.return_value = _make_dataset(status="ready")
|
||||
mock_models_repo.get.return_value = mock_model_version
|
||||
mock_tasks_repo.create.return_value = TEST_TASK_UUID
|
||||
|
||||
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
|
||||
config = TrainingConfig(base_model_version_id=base_model_uuid)
|
||||
request = DatasetTrainRequest(name="incremental-train", config=config)
|
||||
|
||||
result = asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(
|
||||
dataset_id=TEST_DATASET_UUID,
|
||||
request=request,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets_repo=mock_datasets_repo,
|
||||
models=mock_models_repo,
|
||||
tasks=mock_tasks_repo,
|
||||
))
|
||||
|
||||
# Verify model version was looked up
|
||||
mock_db.get_model_version.assert_called_once_with(base_model_uuid)
|
||||
mock_models_repo.get.assert_called_once_with(base_model_uuid)
|
||||
|
||||
# Verify task was created with finetune type
|
||||
call_kwargs = mock_db.create_training_task.call_args[1]
|
||||
call_kwargs = mock_tasks_repo.create.call_args[1]
|
||||
assert call_kwargs["task_type"] == "finetune"
|
||||
assert call_kwargs["config"]["base_model_path"] == "runs/train/invoice_fields/weights/best.pt"
|
||||
assert call_kwargs["config"]["base_model_version"] == "1.0.0"
|
||||
@@ -304,13 +400,14 @@ class TestTrainFromDatasetRoute:
|
||||
assert result.task_id == TEST_TASK_UUID
|
||||
assert "Incremental training" in result.message
|
||||
|
||||
def test_incremental_training_with_invalid_base_model_fails(self):
|
||||
def test_incremental_training_with_invalid_base_model_fails(
|
||||
self, mock_datasets_repo, mock_models_repo, mock_tasks_repo
|
||||
):
|
||||
"""Test that training fails if base_model_version_id doesn't exist."""
|
||||
fn = _find_endpoint("train_from_dataset")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_dataset.return_value = _make_dataset(status="ready")
|
||||
mock_db.get_model_version.return_value = None
|
||||
mock_datasets_repo.get.return_value = _make_dataset(status="ready")
|
||||
mock_models_repo.get.return_value = None
|
||||
|
||||
base_model_uuid = "550e8400-e29b-41d4-a716-446655440099"
|
||||
config = TrainingConfig(base_model_version_id=base_model_uuid)
|
||||
@@ -319,6 +416,13 @@ class TestTrainFromDatasetRoute:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(dataset_id=TEST_DATASET_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(
|
||||
dataset_id=TEST_DATASET_UUID,
|
||||
request=request,
|
||||
admin_token=TEST_TOKEN,
|
||||
datasets_repo=mock_datasets_repo,
|
||||
models=mock_models_repo,
|
||||
tasks=mock_tasks_repo,
|
||||
))
|
||||
assert exc_info.value.status_code == 404
|
||||
assert "Base model version not found" in exc_info.value.detail
|
||||
|
||||
@@ -3,7 +3,7 @@ Tests for dataset training status feature.
|
||||
|
||||
Tests cover:
|
||||
1. Database model fields (training_status, active_training_task_id)
|
||||
2. AdminDB update_dataset_training_status method
|
||||
2. DatasetRepository update_training_status method
|
||||
3. API response includes training status fields
|
||||
4. Scheduler updates dataset status during training lifecycle
|
||||
"""
|
||||
@@ -56,12 +56,12 @@ class TestTrainingDatasetModel:
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test AdminDB Methods
|
||||
# Test DatasetRepository Methods
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAdminDBDatasetTrainingStatus:
|
||||
"""Tests for AdminDB.update_dataset_training_status method."""
|
||||
class TestDatasetRepositoryTrainingStatus:
|
||||
"""Tests for DatasetRepository.update_training_status method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session(self):
|
||||
@@ -69,8 +69,8 @@ class TestAdminDBDatasetTrainingStatus:
|
||||
session = MagicMock()
|
||||
return session
|
||||
|
||||
def test_update_dataset_training_status_sets_status(self, mock_session):
|
||||
"""update_dataset_training_status should set training_status."""
|
||||
def test_update_training_status_sets_status(self, mock_session):
|
||||
"""update_training_status should set training_status."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
@@ -81,13 +81,13 @@ class TestAdminDBDatasetTrainingStatus:
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import DatasetRepository
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
repo = DatasetRepository()
|
||||
repo.update_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="running",
|
||||
)
|
||||
@@ -96,8 +96,8 @@ class TestAdminDBDatasetTrainingStatus:
|
||||
mock_session.add.assert_called_once_with(dataset)
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_update_dataset_training_status_sets_task_id(self, mock_session):
|
||||
"""update_dataset_training_status should set active_training_task_id."""
|
||||
def test_update_training_status_sets_task_id(self, mock_session):
|
||||
"""update_training_status should set active_training_task_id."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
@@ -109,13 +109,13 @@ class TestAdminDBDatasetTrainingStatus:
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import DatasetRepository
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
repo = DatasetRepository()
|
||||
repo.update_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="running",
|
||||
active_training_task_id=str(task_id),
|
||||
@@ -123,10 +123,10 @@ class TestAdminDBDatasetTrainingStatus:
|
||||
|
||||
assert dataset.active_training_task_id == task_id
|
||||
|
||||
def test_update_dataset_training_status_updates_main_status_on_complete(
|
||||
def test_update_training_status_updates_main_status_on_complete(
|
||||
self, mock_session
|
||||
):
|
||||
"""update_dataset_training_status should update main status to 'trained' when completed."""
|
||||
"""update_training_status should update main status to 'trained' when completed."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
@@ -137,13 +137,13 @@ class TestAdminDBDatasetTrainingStatus:
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import DatasetRepository
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
repo = DatasetRepository()
|
||||
repo.update_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="completed",
|
||||
update_main_status=True,
|
||||
@@ -152,10 +152,10 @@ class TestAdminDBDatasetTrainingStatus:
|
||||
assert dataset.status == "trained"
|
||||
assert dataset.training_status == "completed"
|
||||
|
||||
def test_update_dataset_training_status_clears_task_id_on_complete(
|
||||
def test_update_training_status_clears_task_id_on_complete(
|
||||
self, mock_session
|
||||
):
|
||||
"""update_dataset_training_status should clear task_id when training completes."""
|
||||
"""update_training_status should clear task_id when training completes."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
@@ -169,13 +169,13 @@ class TestAdminDBDatasetTrainingStatus:
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import DatasetRepository
|
||||
|
||||
db = AdminDB()
|
||||
db.update_dataset_training_status(
|
||||
repo = DatasetRepository()
|
||||
repo.update_training_status(
|
||||
dataset_id=str(dataset_id),
|
||||
training_status="completed",
|
||||
active_training_task_id=None,
|
||||
@@ -183,18 +183,18 @@ class TestAdminDBDatasetTrainingStatus:
|
||||
|
||||
assert dataset.active_training_task_id is None
|
||||
|
||||
def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
|
||||
"""update_dataset_training_status should handle missing dataset gracefully."""
|
||||
def test_update_training_status_handles_missing_dataset(self, mock_session):
|
||||
"""update_training_status should handle missing dataset gracefully."""
|
||||
mock_session.get.return_value = None
|
||||
|
||||
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.admin_db import AdminDB
|
||||
from inference.data.repositories import DatasetRepository
|
||||
|
||||
db = AdminDB()
|
||||
repo = DatasetRepository()
|
||||
# Should not raise
|
||||
db.update_dataset_training_status(
|
||||
repo.update_training_status(
|
||||
dataset_id=str(uuid4()),
|
||||
training_status="running",
|
||||
)
|
||||
@@ -275,19 +275,24 @@ class TestSchedulerDatasetStatusUpdates:
|
||||
"""Tests for scheduler updating dataset status during training."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(self):
|
||||
"""Create mock AdminDB."""
|
||||
def mock_datasets_repo(self):
|
||||
"""Create mock DatasetRepository."""
|
||||
mock = MagicMock()
|
||||
mock.get_dataset.return_value = MagicMock(
|
||||
mock.get.return_value = MagicMock(
|
||||
dataset_id=uuid4(),
|
||||
name="test-dataset",
|
||||
dataset_path="/path/to/dataset",
|
||||
total_images=100,
|
||||
)
|
||||
mock.get_pending_training_tasks.return_value = []
|
||||
return mock
|
||||
|
||||
def test_scheduler_sets_running_status_on_task_start(self, mock_db):
|
||||
@pytest.fixture
|
||||
def mock_training_tasks_repo(self):
|
||||
"""Create mock TrainingTaskRepository."""
|
||||
mock = MagicMock()
|
||||
return mock
|
||||
|
||||
def test_scheduler_sets_running_status_on_task_start(self, mock_datasets_repo, mock_training_tasks_repo):
|
||||
"""Scheduler should set dataset training_status to 'running' when task starts."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
|
||||
@@ -295,7 +300,8 @@ class TestSchedulerDatasetStatusUpdates:
|
||||
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
|
||||
|
||||
scheduler = TrainingScheduler()
|
||||
scheduler._db = mock_db
|
||||
scheduler._datasets = mock_datasets_repo
|
||||
scheduler._training_tasks = mock_training_tasks_repo
|
||||
|
||||
task_id = str(uuid4())
|
||||
dataset_id = str(uuid4())
|
||||
@@ -311,8 +317,8 @@ class TestSchedulerDatasetStatusUpdates:
|
||||
pass # Expected to fail in test environment
|
||||
|
||||
# Check that training status was updated to running
|
||||
mock_db.update_dataset_training_status.assert_called()
|
||||
first_call = mock_db.update_dataset_training_status.call_args_list[0]
|
||||
mock_datasets_repo.update_training_status.assert_called()
|
||||
first_call = mock_datasets_repo.update_training_status.call_args_list[0]
|
||||
assert first_call.kwargs["training_status"] == "running"
|
||||
assert first_call.kwargs["active_training_task_id"] == task_id
|
||||
|
||||
|
||||
@@ -45,10 +45,10 @@ class TestDocumentListFilterByCategory:
|
||||
"""Tests for filtering documents by category."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_admin_db(self):
|
||||
"""Create mock AdminDB."""
|
||||
db = MagicMock()
|
||||
db.is_valid_admin_token.return_value = True
|
||||
def mock_document_repo(self):
|
||||
"""Create mock DocumentRepository."""
|
||||
repo = MagicMock()
|
||||
repo.is_valid.return_value = True
|
||||
|
||||
# Mock documents with different categories
|
||||
invoice_doc = MagicMock()
|
||||
@@ -61,11 +61,11 @@ class TestDocumentListFilterByCategory:
|
||||
letter_doc.category = "letter"
|
||||
letter_doc.filename = "letter1.pdf"
|
||||
|
||||
db.get_documents.return_value = ([invoice_doc], 1)
|
||||
db.get_document_categories.return_value = ["invoice", "letter", "receipt"]
|
||||
return db
|
||||
repo.get_paginated.return_value = ([invoice_doc], 1)
|
||||
repo.get_categories.return_value = ["invoice", "letter", "receipt"]
|
||||
return repo
|
||||
|
||||
def test_list_documents_accepts_category_filter(self, mock_admin_db):
|
||||
def test_list_documents_accepts_category_filter(self, mock_document_repo):
|
||||
"""Test list documents endpoint accepts category query parameter."""
|
||||
# The endpoint should accept ?category=invoice parameter
|
||||
# This test verifies the schema/query parameter exists
|
||||
@@ -74,9 +74,9 @@ class TestDocumentListFilterByCategory:
|
||||
# Schema should work with category filter applied
|
||||
assert DocumentListResponse is not None
|
||||
|
||||
def test_get_document_categories_from_db(self, mock_admin_db):
|
||||
"""Test fetching unique categories from database."""
|
||||
categories = mock_admin_db.get_document_categories()
|
||||
def test_get_document_categories_from_repo(self, mock_document_repo):
|
||||
"""Test fetching unique categories from repository."""
|
||||
categories = mock_document_repo.get_categories()
|
||||
assert "invoice" in categories
|
||||
assert "letter" in categories
|
||||
assert len(categories) == 3
|
||||
@@ -122,24 +122,24 @@ class TestDocumentUploadWithCategory:
|
||||
assert response.category == "invoice"
|
||||
|
||||
|
||||
class TestAdminDBCategoryMethods:
|
||||
"""Tests for AdminDB category-related methods."""
|
||||
class TestDocumentRepositoryCategoryMethods:
|
||||
"""Tests for DocumentRepository category-related methods."""
|
||||
|
||||
def test_get_document_categories_method_exists(self):
|
||||
"""Test AdminDB has get_document_categories method."""
|
||||
from inference.data.admin_db import AdminDB
|
||||
def test_get_categories_method_exists(self):
|
||||
"""Test DocumentRepository has get_categories method."""
|
||||
from inference.data.repositories import DocumentRepository
|
||||
|
||||
db = AdminDB()
|
||||
assert hasattr(db, "get_document_categories")
|
||||
repo = DocumentRepository()
|
||||
assert hasattr(repo, "get_categories")
|
||||
|
||||
def test_get_documents_accepts_category_filter(self):
|
||||
"""Test get_documents_by_token method accepts category parameter."""
|
||||
from inference.data.admin_db import AdminDB
|
||||
def test_get_paginated_accepts_category_filter(self):
|
||||
"""Test get_paginated method accepts category parameter."""
|
||||
from inference.data.repositories import DocumentRepository
|
||||
import inspect
|
||||
|
||||
db = AdminDB()
|
||||
repo = DocumentRepository()
|
||||
# Check the method exists and accepts category parameter
|
||||
method = getattr(db, "get_documents_by_token", None)
|
||||
method = getattr(repo, "get_paginated", None)
|
||||
assert callable(method)
|
||||
|
||||
# Check category is in the method signature
|
||||
@@ -150,12 +150,12 @@ class TestAdminDBCategoryMethods:
|
||||
class TestUpdateDocumentCategory:
|
||||
"""Tests for updating document category."""
|
||||
|
||||
def test_update_document_category_method_exists(self):
|
||||
"""Test AdminDB has method to update document category."""
|
||||
from inference.data.admin_db import AdminDB
|
||||
def test_update_category_method_exists(self):
|
||||
"""Test DocumentRepository has method to update document category."""
|
||||
from inference.data.repositories import DocumentRepository
|
||||
|
||||
db = AdminDB()
|
||||
assert hasattr(db, "update_document_category")
|
||||
repo = DocumentRepository()
|
||||
assert hasattr(repo, "update_category")
|
||||
|
||||
def test_update_request_schema(self):
|
||||
"""Test DocumentUpdateRequest can update category."""
|
||||
|
||||
@@ -63,6 +63,12 @@ def _find_endpoint(name: str):
|
||||
raise AssertionError(f"Endpoint {name} not found")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models_repo():
|
||||
"""Mock ModelVersionRepository."""
|
||||
return MagicMock()
|
||||
|
||||
|
||||
class TestModelVersionRouterRegistration:
|
||||
"""Tests that model version endpoints are registered."""
|
||||
|
||||
@@ -91,11 +97,10 @@ class TestModelVersionRouterRegistration:
|
||||
class TestCreateModelVersionRoute:
|
||||
"""Tests for POST /admin/training/models."""
|
||||
|
||||
def test_create_model_version(self):
|
||||
def test_create_model_version(self, mock_models_repo):
|
||||
fn = _find_endpoint("create_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.create_model_version.return_value = _make_model_version()
|
||||
mock_models_repo.create.return_value = _make_model_version()
|
||||
|
||||
request = ModelVersionCreateRequest(
|
||||
version="1.0.0",
|
||||
@@ -106,18 +111,17 @@ class TestCreateModelVersionRoute:
|
||||
document_count=100,
|
||||
)
|
||||
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
mock_db.create_model_version.assert_called_once()
|
||||
mock_models_repo.create.assert_called_once()
|
||||
assert result.version_id == TEST_VERSION_UUID
|
||||
assert result.status == "inactive"
|
||||
assert result.message == "Model version created successfully"
|
||||
|
||||
def test_create_model_version_with_task_and_dataset(self):
|
||||
def test_create_model_version_with_task_and_dataset(self, mock_models_repo):
|
||||
fn = _find_endpoint("create_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.create_model_version.return_value = _make_model_version()
|
||||
mock_models_repo.create.return_value = _make_model_version()
|
||||
|
||||
request = ModelVersionCreateRequest(
|
||||
version="1.0.0",
|
||||
@@ -127,9 +131,9 @@ class TestCreateModelVersionRoute:
|
||||
dataset_id=TEST_DATASET_UUID,
|
||||
)
|
||||
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
call_kwargs = mock_db.create_model_version.call_args[1]
|
||||
call_kwargs = mock_models_repo.create.call_args[1]
|
||||
assert call_kwargs["task_id"] == TEST_TASK_UUID
|
||||
assert call_kwargs["dataset_id"] == TEST_DATASET_UUID
|
||||
|
||||
@@ -137,30 +141,28 @@ class TestCreateModelVersionRoute:
|
||||
class TestListModelVersionsRoute:
|
||||
"""Tests for GET /admin/training/models."""
|
||||
|
||||
def test_list_model_versions(self):
|
||||
def test_list_model_versions(self, mock_models_repo):
|
||||
fn = _find_endpoint("list_model_versions")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_model_versions.return_value = (
|
||||
mock_models_repo.get_paginated.return_value = (
|
||||
[_make_model_version(), _make_model_version(version_id=UUID(TEST_VERSION_UUID_2), version="1.1.0")],
|
||||
2,
|
||||
)
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status=None, limit=20, offset=0))
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo, status=None, limit=20, offset=0))
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.models) == 2
|
||||
assert result.models[0].version == "1.0.0"
|
||||
|
||||
def test_list_model_versions_with_status_filter(self):
|
||||
def test_list_model_versions_with_status_filter(self, mock_models_repo):
|
||||
fn = _find_endpoint("list_model_versions")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_model_versions.return_value = ([_make_model_version(status="active", is_active=True)], 1)
|
||||
mock_models_repo.get_paginated.return_value = ([_make_model_version(status="active", is_active=True)], 1)
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db, status="active", limit=20, offset=0))
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo, status="active", limit=20, offset=0))
|
||||
|
||||
mock_db.get_model_versions.assert_called_once_with(status="active", limit=20, offset=0)
|
||||
mock_models_repo.get_paginated.assert_called_once_with(status="active", limit=20, offset=0)
|
||||
assert result.total == 1
|
||||
assert result.models[0].status == "active"
|
||||
|
||||
@@ -168,25 +170,23 @@ class TestListModelVersionsRoute:
|
||||
class TestGetActiveModelRoute:
|
||||
"""Tests for GET /admin/training/models/active."""
|
||||
|
||||
def test_get_active_model_when_exists(self):
|
||||
def test_get_active_model_when_exists(self, mock_models_repo):
|
||||
fn = _find_endpoint("get_active_model")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_active_model_version.return_value = _make_model_version(status="active", is_active=True)
|
||||
mock_models_repo.get_active.return_value = _make_model_version(status="active", is_active=True)
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
assert result.has_active_model is True
|
||||
assert result.model is not None
|
||||
assert result.model.is_active is True
|
||||
|
||||
def test_get_active_model_when_none(self):
|
||||
def test_get_active_model_when_none(self, mock_models_repo):
|
||||
fn = _find_endpoint("get_active_model")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_active_model_version.return_value = None
|
||||
mock_models_repo.get_active.return_value = None
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
assert result.has_active_model is False
|
||||
assert result.model is None
|
||||
@@ -195,46 +195,43 @@ class TestGetActiveModelRoute:
|
||||
class TestGetModelVersionRoute:
|
||||
"""Tests for GET /admin/training/models/{version_id}."""
|
||||
|
||||
def test_get_model_version(self):
|
||||
def test_get_model_version(self, mock_models_repo):
|
||||
fn = _find_endpoint("get_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_model_version.return_value = _make_model_version()
|
||||
mock_models_repo.get.return_value = _make_model_version()
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
assert result.version_id == TEST_VERSION_UUID
|
||||
assert result.version == "1.0.0"
|
||||
assert result.name == "test-model-v1"
|
||||
assert result.metrics_mAP == 0.935
|
||||
|
||||
def test_get_model_version_not_found(self):
|
||||
def test_get_model_version_not_found(self, mock_models_repo):
|
||||
fn = _find_endpoint("get_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_model_version.return_value = None
|
||||
mock_models_repo.get.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestUpdateModelVersionRoute:
|
||||
"""Tests for PATCH /admin/training/models/{version_id}."""
|
||||
|
||||
def test_update_model_version(self):
|
||||
def test_update_model_version(self, mock_models_repo):
|
||||
fn = _find_endpoint("update_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.update_model_version.return_value = _make_model_version(name="updated-name")
|
||||
mock_models_repo.update.return_value = _make_model_version(name="updated-name")
|
||||
|
||||
request = ModelVersionUpdateRequest(name="updated-name", description="Updated description")
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
mock_db.update_model_version.assert_called_once_with(
|
||||
mock_models_repo.update.assert_called_once_with(
|
||||
version_id=TEST_VERSION_UUID,
|
||||
name="updated-name",
|
||||
description="Updated description",
|
||||
@@ -242,45 +239,42 @@ class TestUpdateModelVersionRoute:
|
||||
)
|
||||
assert result.message == "Model version updated successfully"
|
||||
|
||||
def test_update_model_version_not_found(self):
|
||||
def test_update_model_version_not_found(self, mock_models_repo):
|
||||
fn = _find_endpoint("update_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.update_model_version.return_value = None
|
||||
mock_models_repo.update.return_value = None
|
||||
|
||||
request = ModelVersionUpdateRequest(name="updated-name")
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=request, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestActivateModelVersionRoute:
|
||||
"""Tests for POST /admin/training/models/{version_id}/activate."""
|
||||
|
||||
def test_activate_model_version(self):
|
||||
def test_activate_model_version(self, mock_models_repo):
|
||||
fn = _find_endpoint("activate_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.activate_model_version.return_value = _make_model_version(status="active", is_active=True)
|
||||
mock_models_repo.activate.return_value = _make_model_version(status="active", is_active=True)
|
||||
|
||||
# Create mock request with app state
|
||||
mock_request = MagicMock()
|
||||
mock_request.app.state.inference_service = None
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
mock_db.activate_model_version.assert_called_once_with(TEST_VERSION_UUID)
|
||||
mock_models_repo.activate.assert_called_once_with(TEST_VERSION_UUID)
|
||||
assert result.status == "active"
|
||||
assert result.message == "Model version activated for inference"
|
||||
|
||||
def test_activate_model_version_not_found(self):
|
||||
def test_activate_model_version_not_found(self, mock_models_repo):
|
||||
fn = _find_endpoint("activate_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.activate_model_version.return_value = None
|
||||
mock_models_repo.activate.return_value = None
|
||||
|
||||
# Create mock request with app state
|
||||
mock_request = MagicMock()
|
||||
@@ -289,88 +283,82 @@ class TestActivateModelVersionRoute:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, request=mock_request, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestDeactivateModelVersionRoute:
|
||||
"""Tests for POST /admin/training/models/{version_id}/deactivate."""
|
||||
|
||||
def test_deactivate_model_version(self):
|
||||
def test_deactivate_model_version(self, mock_models_repo):
|
||||
fn = _find_endpoint("deactivate_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.deactivate_model_version.return_value = _make_model_version(status="inactive", is_active=False)
|
||||
mock_models_repo.deactivate.return_value = _make_model_version(status="inactive", is_active=False)
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
assert result.status == "inactive"
|
||||
assert result.message == "Model version deactivated"
|
||||
|
||||
def test_deactivate_model_version_not_found(self):
|
||||
def test_deactivate_model_version_not_found(self, mock_models_repo):
|
||||
fn = _find_endpoint("deactivate_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.deactivate_model_version.return_value = None
|
||||
mock_models_repo.deactivate.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
class TestArchiveModelVersionRoute:
|
||||
"""Tests for POST /admin/training/models/{version_id}/archive."""
|
||||
|
||||
def test_archive_model_version(self):
|
||||
def test_archive_model_version(self, mock_models_repo):
|
||||
fn = _find_endpoint("archive_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.archive_model_version.return_value = _make_model_version(status="archived")
|
||||
mock_models_repo.archive.return_value = _make_model_version(status="archived")
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
assert result.status == "archived"
|
||||
assert result.message == "Model version archived"
|
||||
|
||||
def test_archive_active_model_fails(self):
|
||||
def test_archive_active_model_fails(self, mock_models_repo):
|
||||
fn = _find_endpoint("archive_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.archive_model_version.return_value = None
|
||||
mock_models_repo.archive.return_value = None
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
class TestDeleteModelVersionRoute:
|
||||
"""Tests for DELETE /admin/training/models/{version_id}."""
|
||||
|
||||
def test_delete_model_version(self):
|
||||
def test_delete_model_version(self, mock_models_repo):
|
||||
fn = _find_endpoint("delete_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.delete_model_version.return_value = True
|
||||
mock_models_repo.delete.return_value = True
|
||||
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
result = asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
mock_db.delete_model_version.assert_called_once_with(TEST_VERSION_UUID)
|
||||
mock_models_repo.delete.assert_called_once_with(TEST_VERSION_UUID)
|
||||
assert result["message"] == "Model version deleted"
|
||||
|
||||
def test_delete_active_model_fails(self):
|
||||
def test_delete_active_model_fails(self, mock_models_repo):
|
||||
fn = _find_endpoint("delete_model_version")
|
||||
|
||||
mock_db = MagicMock()
|
||||
mock_db.delete_model_version.return_value = False
|
||||
mock_models_repo.delete.return_value = False
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, db=mock_db))
|
||||
asyncio.run(fn(version_id=TEST_VERSION_UUID, admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,13 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.training import create_training_router
|
||||
from inference.web.core.auth import validate_admin_token, get_admin_db
|
||||
from inference.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
get_annotation_repository,
|
||||
get_training_task_repository,
|
||||
get_model_version_repository,
|
||||
)
|
||||
|
||||
|
||||
class MockTrainingTask:
|
||||
@@ -128,19 +134,17 @@ class MockModelVersion:
|
||||
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing Phase 4."""
|
||||
class MockDocumentRepository:
|
||||
"""Mock DocumentRepository for testing Phase 4."""
|
||||
|
||||
def __init__(self):
|
||||
self.documents = {}
|
||||
self.annotations = {}
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
self.model_versions = {}
|
||||
self.annotations = {} # Shared reference for filtering
|
||||
self.training_links = {} # Shared reference for filtering
|
||||
|
||||
def get_documents_for_training(
|
||||
def get_for_training(
|
||||
self,
|
||||
admin_token,
|
||||
admin_token=None,
|
||||
status="labeled",
|
||||
has_annotations=True,
|
||||
min_annotation_count=None,
|
||||
@@ -173,17 +177,28 @@ class MockAdminDB:
|
||||
total = len(filtered)
|
||||
return filtered[offset:offset+limit], total
|
||||
|
||||
def get_annotations_for_document(self, document_id):
|
||||
|
||||
class MockAnnotationRepository:
|
||||
"""Mock AnnotationRepository for testing Phase 4."""
|
||||
|
||||
def __init__(self):
|
||||
self.annotations = {}
|
||||
|
||||
def get_for_document(self, document_id, page_number=None):
|
||||
"""Get annotations for document."""
|
||||
return self.annotations.get(str(document_id), [])
|
||||
|
||||
def get_document_training_tasks(self, document_id):
|
||||
"""Get training tasks that used this document."""
|
||||
return self.training_links.get(str(document_id), [])
|
||||
|
||||
def get_training_tasks_by_token(
|
||||
class MockTrainingTaskRepository:
|
||||
"""Mock TrainingTaskRepository for testing Phase 4."""
|
||||
|
||||
def __init__(self):
|
||||
self.training_tasks = {}
|
||||
self.training_links = {}
|
||||
|
||||
def get_paginated(
|
||||
self,
|
||||
admin_token,
|
||||
admin_token=None,
|
||||
status=None,
|
||||
limit=20,
|
||||
offset=0,
|
||||
@@ -196,11 +211,22 @@ class MockAdminDB:
|
||||
total = len(tasks)
|
||||
return tasks[offset:offset+limit], total
|
||||
|
||||
def get_training_task(self, task_id):
|
||||
def get(self, task_id):
|
||||
"""Get training task by ID."""
|
||||
return self.training_tasks.get(str(task_id))
|
||||
|
||||
def get_model_versions(self, status=None, limit=20, offset=0):
|
||||
def get_document_training_tasks(self, document_id):
|
||||
"""Get training tasks that used this document."""
|
||||
return self.training_links.get(str(document_id), [])
|
||||
|
||||
|
||||
class MockModelVersionRepository:
|
||||
"""Mock ModelVersionRepository for testing Phase 4."""
|
||||
|
||||
def __init__(self):
|
||||
self.model_versions = {}
|
||||
|
||||
def get_paginated(self, status=None, limit=20, offset=0):
|
||||
"""Get model versions with optional filtering."""
|
||||
models = list(self.model_versions.values())
|
||||
if status:
|
||||
@@ -214,8 +240,11 @@ def app():
|
||||
"""Create test FastAPI app."""
|
||||
app = FastAPI()
|
||||
|
||||
# Create mock DB
|
||||
mock_db = MockAdminDB()
|
||||
# Create mock repositories
|
||||
mock_document_repo = MockDocumentRepository()
|
||||
mock_annotation_repo = MockAnnotationRepository()
|
||||
mock_training_task_repo = MockTrainingTaskRepository()
|
||||
mock_model_version_repo = MockModelVersionRepository()
|
||||
|
||||
# Add test documents
|
||||
doc1 = MockAdminDocument(
|
||||
@@ -231,22 +260,25 @@ def app():
|
||||
status="labeled",
|
||||
)
|
||||
|
||||
mock_db.documents[str(doc1.document_id)] = doc1
|
||||
mock_db.documents[str(doc2.document_id)] = doc2
|
||||
mock_db.documents[str(doc3.document_id)] = doc3
|
||||
mock_document_repo.documents[str(doc1.document_id)] = doc1
|
||||
mock_document_repo.documents[str(doc2.document_id)] = doc2
|
||||
mock_document_repo.documents[str(doc3.document_id)] = doc3
|
||||
|
||||
# Add annotations
|
||||
mock_db.annotations[str(doc1.document_id)] = [
|
||||
mock_annotation_repo.annotations[str(doc1.document_id)] = [
|
||||
MockAnnotation(document_id=doc1.document_id, source="manual"),
|
||||
MockAnnotation(document_id=doc1.document_id, source="auto"),
|
||||
]
|
||||
mock_db.annotations[str(doc2.document_id)] = [
|
||||
mock_annotation_repo.annotations[str(doc2.document_id)] = [
|
||||
MockAnnotation(document_id=doc2.document_id, source="auto"),
|
||||
MockAnnotation(document_id=doc2.document_id, source="auto"),
|
||||
MockAnnotation(document_id=doc2.document_id, source="auto"),
|
||||
]
|
||||
# doc3 has no annotations
|
||||
|
||||
# Share annotation data with document repo for filtering
|
||||
mock_document_repo.annotations = mock_annotation_repo.annotations
|
||||
|
||||
# Add training tasks
|
||||
task1 = MockTrainingTask(
|
||||
name="Training Run 2024-01",
|
||||
@@ -265,15 +297,18 @@ def app():
|
||||
metrics_recall=0.92,
|
||||
)
|
||||
|
||||
mock_db.training_tasks[str(task1.task_id)] = task1
|
||||
mock_db.training_tasks[str(task2.task_id)] = task2
|
||||
mock_training_task_repo.training_tasks[str(task1.task_id)] = task1
|
||||
mock_training_task_repo.training_tasks[str(task2.task_id)] = task2
|
||||
|
||||
# Add training links (doc1 used in task1)
|
||||
link1 = MockTrainingDocumentLink(
|
||||
task_id=task1.task_id,
|
||||
document_id=doc1.document_id,
|
||||
)
|
||||
mock_db.training_links[str(doc1.document_id)] = [link1]
|
||||
mock_training_task_repo.training_links[str(doc1.document_id)] = [link1]
|
||||
|
||||
# Share training links with document repo for filtering
|
||||
mock_document_repo.training_links = mock_training_task_repo.training_links
|
||||
|
||||
# Add model versions
|
||||
model1 = MockModelVersion(
|
||||
@@ -296,12 +331,15 @@ def app():
|
||||
metrics_recall=0.92,
|
||||
document_count=600,
|
||||
)
|
||||
mock_db.model_versions[str(model1.version_id)] = model1
|
||||
mock_db.model_versions[str(model2.version_id)] = model2
|
||||
mock_model_version_repo.model_versions[str(model1.version_id)] = model1
|
||||
mock_model_version_repo.model_versions[str(model2.version_id)] = model2
|
||||
|
||||
# Override dependencies
|
||||
app.dependency_overrides[validate_admin_token] = lambda: "test-token"
|
||||
app.dependency_overrides[get_admin_db] = lambda: mock_db
|
||||
app.dependency_overrides[get_document_repository] = lambda: mock_document_repo
|
||||
app.dependency_overrides[get_annotation_repository] = lambda: mock_annotation_repo
|
||||
app.dependency_overrides[get_training_task_repository] = lambda: mock_training_task_repo
|
||||
app.dependency_overrides[get_model_version_repository] = lambda: mock_model_version_repo
|
||||
|
||||
# Include router
|
||||
router = create_training_router()
|
||||
|
||||
Reference in New Issue
Block a user