673 lines
22 KiB
Python
673 lines
22 KiB
Python
"""Tests for unified task management interface.
|
|
|
|
TDD: These tests are written first (RED phase).
|
|
"""
|
|
|
|
from abc import ABC
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
class TestTaskStatus:
    """Tests for TaskStatus dataclass."""

    def test_task_status_basic_fields(self) -> None:
        """TaskStatus has all required fields."""
        from inference.web.core.task_interface import TaskStatus

        status = TaskStatus(
            name="test_runner",
            is_running=True,
            pending_count=5,
            processing_count=2,
        )
        assert status.name == "test_runner"
        assert status.is_running is True
        assert status.pending_count == 5
        assert status.processing_count == 2

    def test_task_status_with_error(self) -> None:
        """TaskStatus can include optional error message."""
        from inference.web.core.task_interface import TaskStatus

        status = TaskStatus(
            name="failed_runner",
            is_running=False,
            pending_count=0,
            processing_count=0,
            error="Connection failed",
        )
        assert status.error == "Connection failed"

    def test_task_status_default_error_is_none(self) -> None:
        """TaskStatus error defaults to None."""
        from inference.web.core.task_interface import TaskStatus

        status = TaskStatus(
            name="test",
            is_running=True,
            pending_count=0,
            processing_count=0,
        )
        assert status.error is None

    def test_task_status_is_frozen(self) -> None:
        """TaskStatus is immutable (frozen dataclass).

        A frozen dataclass raises ``dataclasses.FrozenInstanceError`` (a
        subclass of ``AttributeError``) on attribute assignment. Checking for
        that specific exception, instead of the broad ``AttributeError``,
        ensures the immutability comes from ``@dataclass(frozen=True)``.
        Otherwise the check would also pass for a read-only property or a
        ``NamedTuple``.
        """
        from dataclasses import FrozenInstanceError

        from inference.web.core.task_interface import TaskStatus

        status = TaskStatus(
            name="test",
            is_running=True,
            pending_count=0,
            processing_count=0,
        )
        with pytest.raises(FrozenInstanceError):
            status.name = "changed"  # type: ignore[misc]
        # The rejected assignment must leave the stored value untouched.
        assert status.name == "test"
|
|
|
|
|
|
class TestTaskRunnerInterface:
    """Contract tests for the abstract TaskRunner base class.

    Each ``missing_*`` test defines a subclass that leaves out exactly one
    abstract member and checks that instantiating it raises ``TypeError``.
    The final test checks that a fully implemented subclass works end to end.
    """

    def test_cannot_instantiate_directly(self) -> None:
        """Instantiating TaskRunner itself raises TypeError."""
        from inference.web.core.task_interface import TaskRunner

        with pytest.raises(TypeError):
            TaskRunner()  # type: ignore[abstract]

    def test_is_abstract_base_class(self) -> None:
        """TaskRunner derives from abc.ABC."""
        from inference.web.core.task_interface import TaskRunner

        is_abc_subclass = issubclass(TaskRunner, ABC)
        assert is_abc_subclass

    def test_subclass_missing_name_cannot_instantiate(self) -> None:
        """A subclass lacking the ``name`` property remains abstract."""
        from inference.web.core.task_interface import TaskRunner, TaskStatus

        class NoNameRunner(TaskRunner):
            def start(self) -> None:
                return None

            def stop(self, timeout: float | None = None) -> None:
                return None

            @property
            def is_running(self) -> bool:
                return False

            def get_status(self) -> TaskStatus:
                return TaskStatus("", False, 0, 0)

        with pytest.raises(TypeError):
            NoNameRunner()  # type: ignore[abstract]

    def test_subclass_missing_start_cannot_instantiate(self) -> None:
        """A subclass lacking ``start`` remains abstract."""
        from inference.web.core.task_interface import TaskRunner, TaskStatus

        class NoStartRunner(TaskRunner):
            @property
            def name(self) -> str:
                return "test"

            def stop(self, timeout: float | None = None) -> None:
                return None

            @property
            def is_running(self) -> bool:
                return False

            def get_status(self) -> TaskStatus:
                return TaskStatus("", False, 0, 0)

        with pytest.raises(TypeError):
            NoStartRunner()  # type: ignore[abstract]

    def test_subclass_missing_stop_cannot_instantiate(self) -> None:
        """A subclass lacking ``stop`` remains abstract."""
        from inference.web.core.task_interface import TaskRunner, TaskStatus

        class NoStopRunner(TaskRunner):
            @property
            def name(self) -> str:
                return "test"

            def start(self) -> None:
                return None

            @property
            def is_running(self) -> bool:
                return False

            def get_status(self) -> TaskStatus:
                return TaskStatus("", False, 0, 0)

        with pytest.raises(TypeError):
            NoStopRunner()  # type: ignore[abstract]

    def test_subclass_missing_is_running_cannot_instantiate(self) -> None:
        """A subclass lacking the ``is_running`` property remains abstract."""
        from inference.web.core.task_interface import TaskRunner, TaskStatus

        class NoIsRunningRunner(TaskRunner):
            @property
            def name(self) -> str:
                return "test"

            def start(self) -> None:
                return None

            def stop(self, timeout: float | None = None) -> None:
                return None

            def get_status(self) -> TaskStatus:
                return TaskStatus("", False, 0, 0)

        with pytest.raises(TypeError):
            NoIsRunningRunner()  # type: ignore[abstract]

    def test_subclass_missing_get_status_cannot_instantiate(self) -> None:
        """A subclass lacking ``get_status`` remains abstract."""
        from inference.web.core.task_interface import TaskRunner

        class NoGetStatusRunner(TaskRunner):
            @property
            def name(self) -> str:
                return "test"

            def start(self) -> None:
                return None

            def stop(self, timeout: float | None = None) -> None:
                return None

            @property
            def is_running(self) -> bool:
                return False

        with pytest.raises(TypeError):
            NoGetStatusRunner()  # type: ignore[abstract]

    def test_complete_subclass_can_instantiate(self) -> None:
        """A subclass implementing every abstract member is usable."""
        from inference.web.core.task_interface import TaskRunner, TaskStatus

        class FullRunner(TaskRunner):
            def __init__(self) -> None:
                self._active = False

            @property
            def name(self) -> str:
                return "complete_runner"

            def start(self) -> None:
                self._active = True

            def stop(self, timeout: float | None = None) -> None:
                self._active = False

            @property
            def is_running(self) -> bool:
                return self._active

            def get_status(self) -> TaskStatus:
                return TaskStatus(
                    name=self.name,
                    is_running=self._active,
                    pending_count=0,
                    processing_count=0,
                )

        worker = FullRunner()
        assert worker.name == "complete_runner"
        assert worker.is_running is False

        # Lifecycle: start -> inspect status -> stop.
        worker.start()
        assert worker.is_running is True

        snapshot = worker.get_status()
        assert snapshot.name == "complete_runner"
        assert snapshot.is_running is True

        worker.stop()
        assert worker.is_running is False
|
|
|
|
|
|
class TestTaskManager:
    """Tests for TaskManager facade."""

    def test_register_runner(self) -> None:
        """Can register a task runner and look it up by name."""
        from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus

        class MockRunner(TaskRunner):
            @property
            def name(self) -> str:
                return "mock"

            def start(self) -> None:
                pass

            def stop(self, timeout: float | None = None) -> None:
                pass

            @property
            def is_running(self) -> bool:
                return False

            def get_status(self) -> TaskStatus:
                return TaskStatus("mock", False, 0, 0)

        manager = TaskManager()
        runner = MockRunner()
        manager.register(runner)

        # Lookup must hand back the very same object that was registered.
        assert manager.get_runner("mock") is runner

    def test_get_runner_returns_none_for_unknown(self) -> None:
        """get_runner returns None for unknown runner name."""
        from inference.web.core.task_interface import TaskManager

        manager = TaskManager()
        assert manager.get_runner("unknown") is None

    def test_start_all_runners(self) -> None:
        """start_all starts all registered runners."""
        from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus

        class MockRunner(TaskRunner):
            def __init__(self, runner_name: str) -> None:
                self._name = runner_name
                self._running = False

            @property
            def name(self) -> str:
                return self._name

            def start(self) -> None:
                self._running = True

            def stop(self, timeout: float | None = None) -> None:
                self._running = False

            @property
            def is_running(self) -> bool:
                return self._running

            def get_status(self) -> TaskStatus:
                return TaskStatus(self._name, self._running, 0, 0)

        manager = TaskManager()
        runner1 = MockRunner("runner1")
        runner2 = MockRunner("runner2")
        manager.register(runner1)
        manager.register(runner2)

        # Precondition: registering alone must not start anything.
        assert runner1.is_running is False
        assert runner2.is_running is False

        manager.start_all()

        assert runner1.is_running is True
        assert runner2.is_running is True

    def test_stop_all_runners(self) -> None:
        """stop_all stops all registered runners."""
        from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus

        class MockRunner(TaskRunner):
            def __init__(self, runner_name: str) -> None:
                self._name = runner_name
                # Start in the running state so stop_all has something to do.
                self._running = True

            @property
            def name(self) -> str:
                return self._name

            def start(self) -> None:
                self._running = True

            def stop(self, timeout: float | None = None) -> None:
                self._running = False

            @property
            def is_running(self) -> bool:
                return self._running

            def get_status(self) -> TaskStatus:
                return TaskStatus(self._name, self._running, 0, 0)

        manager = TaskManager()
        runner1 = MockRunner("runner1")
        runner2 = MockRunner("runner2")
        manager.register(runner1)
        manager.register(runner2)

        assert runner1.is_running is True
        assert runner2.is_running is True

        manager.stop_all()

        assert runner1.is_running is False
        assert runner2.is_running is False

    def test_get_all_status(self) -> None:
        """get_all_status returns status of all runners, keyed by name."""
        from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus

        class MockRunner(TaskRunner):
            def __init__(self, runner_name: str, pending: int) -> None:
                self._name = runner_name
                self._pending = pending

            @property
            def name(self) -> str:
                return self._name

            def start(self) -> None:
                pass

            def stop(self, timeout: float | None = None) -> None:
                pass

            @property
            def is_running(self) -> bool:
                return True

            def get_status(self) -> TaskStatus:
                return TaskStatus(self._name, True, self._pending, 0)

        manager = TaskManager()
        manager.register(MockRunner("runner1", 5))
        manager.register(MockRunner("runner2", 10))

        all_status = manager.get_all_status()

        assert len(all_status) == 2
        assert all_status["runner1"].pending_count == 5
        assert all_status["runner2"].pending_count == 10

    def test_get_all_status_empty_when_no_runners(self) -> None:
        """get_all_status returns empty dict when no runners registered."""
        from inference.web.core.task_interface import TaskManager

        manager = TaskManager()
        assert manager.get_all_status() == {}

    def test_runner_names_property(self) -> None:
        """runner_names returns list of all registered runner names."""
        from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus

        class MockRunner(TaskRunner):
            def __init__(self, runner_name: str) -> None:
                self._name = runner_name

            @property
            def name(self) -> str:
                return self._name

            def start(self) -> None:
                pass

            def stop(self, timeout: float | None = None) -> None:
                pass

            @property
            def is_running(self) -> bool:
                return False

            def get_status(self) -> TaskStatus:
                return TaskStatus(self._name, False, 0, 0)

        manager = TaskManager()
        manager.register(MockRunner("alpha"))
        manager.register(MockRunner("beta"))

        names = manager.runner_names
        # Compare as a set: this test does not pin registration order.
        assert set(names) == {"alpha", "beta"}

    def test_stop_all_with_timeout_distribution(self) -> None:
        """stop_all distributes timeout across runners."""
        from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus

        # Records the timeout each runner's stop() receives, in call order.
        received_timeouts: list[float | None] = []

        class MockRunner(TaskRunner):
            def __init__(self, runner_name: str) -> None:
                self._name = runner_name

            @property
            def name(self) -> str:
                return self._name

            def start(self) -> None:
                pass

            def stop(self, timeout: float | None = None) -> None:
                received_timeouts.append(timeout)

            @property
            def is_running(self) -> bool:
                return False

            def get_status(self) -> TaskStatus:
                return TaskStatus(self._name, False, 0, 0)

        manager = TaskManager()
        manager.register(MockRunner("r1"))
        manager.register(MockRunner("r2"))

        manager.stop_all(timeout=20.0)

        # Timeout should be distributed evenly (20 / 2 = 10 each). The shares
        # are computed floats, so compare with pytest.approx rather than exact
        # equality; the sequence form also rejects a missing or None entry.
        assert len(received_timeouts) == 2
        assert received_timeouts == pytest.approx([10.0, 10.0])

    def test_start_all_skips_runners_requiring_arguments(self) -> None:
        """start_all skips runners that require arguments."""
        from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus

        # Each list gains one entry per successful start() call.
        no_args_started: list[bool] = []
        with_args_started: list[bool] = []

        class NoArgsRunner(TaskRunner):
            @property
            def name(self) -> str:
                return "no_args"

            def start(self) -> None:
                no_args_started.append(True)

            def stop(self, timeout: float | None = None) -> None:
                pass

            @property
            def is_running(self) -> bool:
                return False

            def get_status(self) -> TaskStatus:
                return TaskStatus("no_args", False, 0, 0)

        class RequiresArgsRunner(TaskRunner):
            @property
            def name(self) -> str:
                return "requires_args"

            def start(self, handler: object) -> None:  # type: ignore[override]
                # This runner requires an argument, so a bare start() call
                # raises TypeError before this body runs.
                with_args_started.append(True)

            def stop(self, timeout: float | None = None) -> None:
                pass

            @property
            def is_running(self) -> bool:
                return False

            def get_status(self) -> TaskStatus:
                return TaskStatus("requires_args", False, 0, 0)

        manager = TaskManager()
        manager.register(NoArgsRunner())
        manager.register(RequiresArgsRunner())

        # start_all should start no_args runner but skip requires_args
        # (and must not propagate the TypeError).
        manager.start_all()

        assert len(no_args_started) == 1
        assert len(with_args_started) == 0  # Skipped due to TypeError

    def test_stop_all_with_no_runners(self) -> None:
        """stop_all does nothing when no runners registered."""
        from inference.web.core.task_interface import TaskManager

        manager = TaskManager()
        # Should not raise any exception
        manager.stop_all()
        # The manager is still empty afterwards.
        assert manager.runner_names == []
|
|
|
|
|
|
class TestTrainingSchedulerInterface:
    """TrainingScheduler must satisfy the TaskRunner contract."""

    def test_training_scheduler_is_task_runner(self) -> None:
        """A TrainingScheduler instance is a TaskRunner."""
        from inference.web.core.scheduler import TrainingScheduler
        from inference.web.core.task_interface import TaskRunner

        instance = TrainingScheduler()
        assert isinstance(instance, TaskRunner)

    def test_training_scheduler_name(self) -> None:
        """TrainingScheduler reports the expected runner name."""
        from inference.web.core.scheduler import TrainingScheduler

        assert TrainingScheduler().name == "training_scheduler"

    def test_training_scheduler_get_status(self) -> None:
        """get_status reflects the pending tasks from the repository."""
        from inference.web.core.scheduler import TrainingScheduler
        from inference.web.core.task_interface import TaskStatus

        instance = TrainingScheduler()

        # Swap in a fake training-tasks repository reporting two pending tasks.
        fake_repo = MagicMock()
        fake_repo.get_pending.return_value = [MagicMock() for _ in range(2)]
        instance._training_tasks = fake_repo

        result = instance.get_status()

        assert isinstance(result, TaskStatus)
        assert result.name == "training_scheduler"
        assert result.is_running is False
        assert result.pending_count == 2
|
|
|
|
|
|
class TestAutoLabelSchedulerInterface:
    """AutoLabelScheduler must satisfy the TaskRunner contract.

    The storage helper is patched in every test so that constructing the
    scheduler does not use the real ``get_storage_helper``.
    """

    def test_autolabel_scheduler_is_task_runner(self) -> None:
        """An AutoLabelScheduler instance is a TaskRunner."""
        from inference.web.core.autolabel_scheduler import AutoLabelScheduler
        from inference.web.core.task_interface import TaskRunner

        with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
            instance = AutoLabelScheduler()
        assert isinstance(instance, TaskRunner)

    def test_autolabel_scheduler_name(self) -> None:
        """AutoLabelScheduler reports the expected runner name."""
        from inference.web.core.autolabel_scheduler import AutoLabelScheduler

        with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
            instance = AutoLabelScheduler()
        assert instance.name == "autolabel_scheduler"

    def test_autolabel_scheduler_get_status(self) -> None:
        """get_status counts the documents awaiting auto-labelling."""
        from inference.web.core.autolabel_scheduler import AutoLabelScheduler
        from inference.web.core.task_interface import TaskStatus

        storage_patch = patch("inference.web.core.autolabel_scheduler.get_storage_helper")
        pending_patch = patch(
            "inference.web.core.autolabel_scheduler.get_pending_autolabel_documents"
        )
        # Entered left to right, same order as two nested `with` blocks.
        with storage_patch, pending_patch as pending_docs:
            pending_docs.return_value = [MagicMock() for _ in range(3)]

            instance = AutoLabelScheduler()
            result = instance.get_status()

        assert isinstance(result, TaskStatus)
        assert result.name == "autolabel_scheduler"
        assert result.is_running is False
        assert result.pending_count == 3
|
|
|
|
|
|
class TestAsyncTaskQueueInterface:
    """AsyncTaskQueue must satisfy the TaskRunner contract."""

    def test_async_queue_is_task_runner(self) -> None:
        """An AsyncTaskQueue instance is a TaskRunner."""
        from inference.web.workers.async_queue import AsyncTaskQueue
        from inference.web.core.task_interface import TaskRunner

        assert isinstance(AsyncTaskQueue(), TaskRunner)

    def test_async_queue_name(self) -> None:
        """AsyncTaskQueue reports the expected runner name."""
        from inference.web.workers.async_queue import AsyncTaskQueue

        q = AsyncTaskQueue()
        assert q.name == "async_task_queue"

    def test_async_queue_get_status(self) -> None:
        """A freshly built, unstarted queue reports an idle, empty status."""
        from inference.web.workers.async_queue import AsyncTaskQueue
        from inference.web.core.task_interface import TaskStatus

        snapshot = AsyncTaskQueue().get_status()

        assert isinstance(snapshot, TaskStatus)
        assert snapshot.name == "async_task_queue"
        assert snapshot.is_running is False
        assert snapshot.pending_count == 0
        assert snapshot.processing_count == 0
|
|
|
|
|
|
class TestBatchTaskQueueInterface:
    """BatchTaskQueue must satisfy the TaskRunner contract."""

    def test_batch_queue_is_task_runner(self) -> None:
        """A BatchTaskQueue instance is a TaskRunner."""
        from inference.web.workers.batch_queue import BatchTaskQueue
        from inference.web.core.task_interface import TaskRunner

        assert isinstance(BatchTaskQueue(), TaskRunner)

    def test_batch_queue_name(self) -> None:
        """BatchTaskQueue reports the expected runner name."""
        from inference.web.workers.batch_queue import BatchTaskQueue

        q = BatchTaskQueue()
        assert q.name == "batch_task_queue"

    def test_batch_queue_get_status(self) -> None:
        """A freshly built, unstarted queue reports an idle status."""
        from inference.web.workers.batch_queue import BatchTaskQueue
        from inference.web.core.task_interface import TaskStatus

        snapshot = BatchTaskQueue().get_status()

        assert isinstance(snapshot, TaskStatus)
        assert snapshot.name == "batch_task_queue"
        assert snapshot.is_running is False
        assert snapshot.pending_count == 0
|