This commit is contained in:
Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

View File

@@ -0,0 +1 @@
"""Tests for web core components."""

View File

@@ -0,0 +1,672 @@
"""Tests for unified task management interface.
TDD: These tests are written first (RED phase).
"""
from abc import ABC
from unittest.mock import MagicMock, patch
import pytest
class TestTaskStatus:
"""Tests for TaskStatus dataclass."""
def test_task_status_basic_fields(self) -> None:
"""TaskStatus has all required fields."""
from inference.web.core.task_interface import TaskStatus
status = TaskStatus(
name="test_runner",
is_running=True,
pending_count=5,
processing_count=2,
)
assert status.name == "test_runner"
assert status.is_running is True
assert status.pending_count == 5
assert status.processing_count == 2
def test_task_status_with_error(self) -> None:
"""TaskStatus can include optional error message."""
from inference.web.core.task_interface import TaskStatus
status = TaskStatus(
name="failed_runner",
is_running=False,
pending_count=0,
processing_count=0,
error="Connection failed",
)
assert status.error == "Connection failed"
def test_task_status_default_error_is_none(self) -> None:
"""TaskStatus error defaults to None."""
from inference.web.core.task_interface import TaskStatus
status = TaskStatus(
name="test",
is_running=True,
pending_count=0,
processing_count=0,
)
assert status.error is None
def test_task_status_is_frozen(self) -> None:
"""TaskStatus is immutable (frozen dataclass)."""
from inference.web.core.task_interface import TaskStatus
status = TaskStatus(
name="test",
is_running=True,
pending_count=0,
processing_count=0,
)
with pytest.raises(AttributeError):
status.name = "changed" # type: ignore[misc]
class TestTaskRunnerInterface:
"""Tests for TaskRunner abstract base class."""
def test_cannot_instantiate_directly(self) -> None:
"""TaskRunner is abstract and cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner
with pytest.raises(TypeError):
TaskRunner() # type: ignore[abstract]
def test_is_abstract_base_class(self) -> None:
"""TaskRunner inherits from ABC."""
from inference.web.core.task_interface import TaskRunner
assert issubclass(TaskRunner, ABC)
def test_subclass_missing_name_cannot_instantiate(self) -> None:
"""Subclass without name property cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class MissingName(TaskRunner):
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("", False, 0, 0)
with pytest.raises(TypeError):
MissingName() # type: ignore[abstract]
def test_subclass_missing_start_cannot_instantiate(self) -> None:
"""Subclass without start method cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class MissingStart(TaskRunner):
@property
def name(self) -> str:
return "test"
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("", False, 0, 0)
with pytest.raises(TypeError):
MissingStart() # type: ignore[abstract]
def test_subclass_missing_stop_cannot_instantiate(self) -> None:
"""Subclass without stop method cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class MissingStop(TaskRunner):
@property
def name(self) -> str:
return "test"
def start(self) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("", False, 0, 0)
with pytest.raises(TypeError):
MissingStop() # type: ignore[abstract]
def test_subclass_missing_is_running_cannot_instantiate(self) -> None:
"""Subclass without is_running property cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class MissingIsRunning(TaskRunner):
@property
def name(self) -> str:
return "test"
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
def get_status(self) -> TaskStatus:
return TaskStatus("", False, 0, 0)
with pytest.raises(TypeError):
MissingIsRunning() # type: ignore[abstract]
def test_subclass_missing_get_status_cannot_instantiate(self) -> None:
"""Subclass without get_status method cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner
class MissingGetStatus(TaskRunner):
@property
def name(self) -> str:
return "test"
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
with pytest.raises(TypeError):
MissingGetStatus() # type: ignore[abstract]
def test_complete_subclass_can_instantiate(self) -> None:
"""Complete subclass implementing all methods can be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
class CompleteRunner(TaskRunner):
def __init__(self) -> None:
self._running = False
@property
def name(self) -> str:
return "complete_runner"
def start(self) -> None:
self._running = True
def stop(self, timeout: float | None = None) -> None:
self._running = False
@property
def is_running(self) -> bool:
return self._running
def get_status(self) -> TaskStatus:
return TaskStatus(
name=self.name,
is_running=self._running,
pending_count=0,
processing_count=0,
)
runner = CompleteRunner()
assert runner.name == "complete_runner"
assert runner.is_running is False
runner.start()
assert runner.is_running is True
status = runner.get_status()
assert status.name == "complete_runner"
assert status.is_running is True
runner.stop()
assert runner.is_running is False
class TestTaskManager:
"""Tests for TaskManager facade."""
def test_register_runner(self) -> None:
"""Can register a task runner."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
@property
def name(self) -> str:
return "mock"
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("mock", False, 0, 0)
manager = TaskManager()
runner = MockRunner()
manager.register(runner)
assert manager.get_runner("mock") is runner
def test_get_runner_returns_none_for_unknown(self) -> None:
"""get_runner returns None for unknown runner name."""
from inference.web.core.task_interface import TaskManager
manager = TaskManager()
assert manager.get_runner("unknown") is None
def test_start_all_runners(self) -> None:
"""start_all starts all registered runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
self._name = runner_name
self._running = False
@property
def name(self) -> str:
return self._name
def start(self) -> None:
self._running = True
def stop(self, timeout: float | None = None) -> None:
self._running = False
@property
def is_running(self) -> bool:
return self._running
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, self._running, 0, 0)
manager = TaskManager()
runner1 = MockRunner("runner1")
runner2 = MockRunner("runner2")
manager.register(runner1)
manager.register(runner2)
assert runner1.is_running is False
assert runner2.is_running is False
manager.start_all()
assert runner1.is_running is True
assert runner2.is_running is True
def test_stop_all_runners(self) -> None:
"""stop_all stops all registered runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
self._name = runner_name
self._running = True
@property
def name(self) -> str:
return self._name
def start(self) -> None:
self._running = True
def stop(self, timeout: float | None = None) -> None:
self._running = False
@property
def is_running(self) -> bool:
return self._running
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, self._running, 0, 0)
manager = TaskManager()
runner1 = MockRunner("runner1")
runner2 = MockRunner("runner2")
manager.register(runner1)
manager.register(runner2)
assert runner1.is_running is True
assert runner2.is_running is True
manager.stop_all()
assert runner1.is_running is False
assert runner2.is_running is False
def test_get_all_status(self) -> None:
"""get_all_status returns status of all runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str, pending: int) -> None:
self._name = runner_name
self._pending = pending
@property
def name(self) -> str:
return self._name
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return True
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, True, self._pending, 0)
manager = TaskManager()
manager.register(MockRunner("runner1", 5))
manager.register(MockRunner("runner2", 10))
all_status = manager.get_all_status()
assert len(all_status) == 2
assert all_status["runner1"].pending_count == 5
assert all_status["runner2"].pending_count == 10
def test_get_all_status_empty_when_no_runners(self) -> None:
"""get_all_status returns empty dict when no runners registered."""
from inference.web.core.task_interface import TaskManager
manager = TaskManager()
assert manager.get_all_status() == {}
def test_runner_names_property(self) -> None:
"""runner_names returns list of all registered runner names."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
self._name = runner_name
@property
def name(self) -> str:
return self._name
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, False, 0, 0)
manager = TaskManager()
manager.register(MockRunner("alpha"))
manager.register(MockRunner("beta"))
names = manager.runner_names
assert set(names) == {"alpha", "beta"}
def test_stop_all_with_timeout_distribution(self) -> None:
"""stop_all distributes timeout across runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
received_timeouts: list[float | None] = []
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
self._name = runner_name
@property
def name(self) -> str:
return self._name
def start(self) -> None:
pass
def stop(self, timeout: float | None = None) -> None:
received_timeouts.append(timeout)
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus(self._name, False, 0, 0)
manager = TaskManager()
manager.register(MockRunner("r1"))
manager.register(MockRunner("r2"))
manager.stop_all(timeout=20.0)
# Timeout should be distributed (20 / 2 = 10 each)
assert len(received_timeouts) == 2
assert all(t == 10.0 for t in received_timeouts)
def test_start_all_skips_runners_requiring_arguments(self) -> None:
"""start_all skips runners that require arguments."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
no_args_started = []
with_args_started = []
class NoArgsRunner(TaskRunner):
@property
def name(self) -> str:
return "no_args"
def start(self) -> None:
no_args_started.append(True)
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("no_args", False, 0, 0)
class RequiresArgsRunner(TaskRunner):
@property
def name(self) -> str:
return "requires_args"
def start(self, handler: object) -> None: # type: ignore[override]
# This runner requires an argument
with_args_started.append(True)
def stop(self, timeout: float | None = None) -> None:
pass
@property
def is_running(self) -> bool:
return False
def get_status(self) -> TaskStatus:
return TaskStatus("requires_args", False, 0, 0)
manager = TaskManager()
manager.register(NoArgsRunner())
manager.register(RequiresArgsRunner())
# start_all should start no_args runner but skip requires_args
manager.start_all()
assert len(no_args_started) == 1
assert len(with_args_started) == 0 # Skipped due to TypeError
def test_stop_all_with_no_runners(self) -> None:
"""stop_all does nothing when no runners registered."""
from inference.web.core.task_interface import TaskManager
manager = TaskManager()
# Should not raise any exception
manager.stop_all()
# Just verify it returns without error
assert manager.runner_names == []
class TestTrainingSchedulerInterface:
"""Tests for TrainingScheduler implementing TaskRunner."""
def test_training_scheduler_is_task_runner(self) -> None:
"""TrainingScheduler inherits from TaskRunner."""
from inference.web.core.scheduler import TrainingScheduler
from inference.web.core.task_interface import TaskRunner
scheduler = TrainingScheduler()
assert isinstance(scheduler, TaskRunner)
def test_training_scheduler_name(self) -> None:
"""TrainingScheduler has correct name."""
from inference.web.core.scheduler import TrainingScheduler
scheduler = TrainingScheduler()
assert scheduler.name == "training_scheduler"
def test_training_scheduler_get_status(self) -> None:
"""TrainingScheduler provides status via get_status."""
from inference.web.core.scheduler import TrainingScheduler
from inference.web.core.task_interface import TaskStatus
scheduler = TrainingScheduler()
# Mock the training tasks repository
mock_tasks = MagicMock()
mock_tasks.get_pending.return_value = [MagicMock(), MagicMock()]
scheduler._training_tasks = mock_tasks
status = scheduler.get_status()
assert isinstance(status, TaskStatus)
assert status.name == "training_scheduler"
assert status.is_running is False
assert status.pending_count == 2
class TestAutoLabelSchedulerInterface:
"""Tests for AutoLabelScheduler implementing TaskRunner."""
def test_autolabel_scheduler_is_task_runner(self) -> None:
"""AutoLabelScheduler inherits from TaskRunner."""
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
from inference.web.core.task_interface import TaskRunner
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
scheduler = AutoLabelScheduler()
assert isinstance(scheduler, TaskRunner)
def test_autolabel_scheduler_name(self) -> None:
"""AutoLabelScheduler has correct name."""
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
scheduler = AutoLabelScheduler()
assert scheduler.name == "autolabel_scheduler"
def test_autolabel_scheduler_get_status(self) -> None:
"""AutoLabelScheduler provides status via get_status."""
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
from inference.web.core.task_interface import TaskStatus
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
with patch(
"inference.web.core.autolabel_scheduler.get_pending_autolabel_documents"
) as mock_get:
mock_get.return_value = [MagicMock(), MagicMock(), MagicMock()]
scheduler = AutoLabelScheduler()
status = scheduler.get_status()
assert isinstance(status, TaskStatus)
assert status.name == "autolabel_scheduler"
assert status.is_running is False
assert status.pending_count == 3
class TestAsyncTaskQueueInterface:
"""Tests for AsyncTaskQueue implementing TaskRunner."""
def test_async_queue_is_task_runner(self) -> None:
"""AsyncTaskQueue inherits from TaskRunner."""
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.core.task_interface import TaskRunner
queue = AsyncTaskQueue()
assert isinstance(queue, TaskRunner)
def test_async_queue_name(self) -> None:
"""AsyncTaskQueue has correct name."""
from inference.web.workers.async_queue import AsyncTaskQueue
queue = AsyncTaskQueue()
assert queue.name == "async_task_queue"
def test_async_queue_get_status(self) -> None:
"""AsyncTaskQueue provides status via get_status."""
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.core.task_interface import TaskStatus
queue = AsyncTaskQueue()
status = queue.get_status()
assert isinstance(status, TaskStatus)
assert status.name == "async_task_queue"
assert status.is_running is False
assert status.pending_count == 0
assert status.processing_count == 0
class TestBatchTaskQueueInterface:
"""Tests for BatchTaskQueue implementing TaskRunner."""
def test_batch_queue_is_task_runner(self) -> None:
"""BatchTaskQueue inherits from TaskRunner."""
from inference.web.workers.batch_queue import BatchTaskQueue
from inference.web.core.task_interface import TaskRunner
queue = BatchTaskQueue()
assert isinstance(queue, TaskRunner)
def test_batch_queue_name(self) -> None:
"""BatchTaskQueue has correct name."""
from inference.web.workers.batch_queue import BatchTaskQueue
queue = BatchTaskQueue()
assert queue.name == "batch_task_queue"
def test_batch_queue_get_status(self) -> None:
"""BatchTaskQueue provides status via get_status."""
from inference.web.workers.batch_queue import BatchTaskQueue
from inference.web.core.task_interface import TaskStatus
queue = BatchTaskQueue()
status = queue.get_status()
assert isinstance(status, TaskStatus)
assert status.name == "batch_task_queue"
assert status.is_running is False
assert status.pending_count == 0