Files
invoice-master-poc-v2/tests/web/test_admin_training.py
2026-01-27 23:58:17 +01:00

248 lines
6.9 KiB
Python

"""
Tests for Admin Training Routes and Scheduler.
"""
import pytest
from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from uuid import UUID
from inference.data.admin_models import TrainingTask, TrainingLog
from inference.web.api.v1.admin.training import _validate_uuid, create_training_router
from inference.web.core.scheduler import (
TrainingScheduler,
get_training_scheduler,
start_scheduler,
stop_scheduler,
)
from inference.web.schemas.admin import (
TrainingConfig,
TrainingStatus,
TrainingTaskCreate,
TrainingType,
)
# Shared test constants: a fixed task UUID (kept in string form; tests wrap it
# in UUID(...) where needed) and a dummy admin token passed as ``admin_token``.
TEST_TASK_UUID: str = "770e8400-e29b-41d4-a716-446655440002"
TEST_TOKEN: str = "test-admin-token-12345"
class TestTrainingRouterCreation:
    """Checks that the training router factory registers the expected routes."""

    def test_creates_router_with_endpoints(self):
        """Each expected path fragment appears in at least one registered route."""
        training_router = create_training_router()

        # Route paths as registered on the router (prefix included).
        registered_paths = [entry.path for entry in training_router.routes]

        # Checked in order; the first fragment with no matching route fails the test.
        expected_fragments = ("/tasks", "{task_id}", "cancel", "logs", "export")
        for fragment in expected_fragments:
            assert any(fragment in path for path in registered_paths)
class TestTrainingConfigSchema:
    """Tests for the TrainingConfig schema: defaults, overrides, epoch bounds."""

    def test_default_config(self):
        """A TrainingConfig built without arguments exposes the expected defaults."""
        config = TrainingConfig()

        assert config.model_name == "yolo11n.pt"
        assert config.epochs == 100
        assert config.batch_size == 16
        assert config.image_size == 640
        assert config.learning_rate == 0.01
        assert config.device == "0"

    def test_custom_config(self):
        """Every explicitly supplied field is readable back with the same value."""
        overrides = {
            "model_name": "yolo11s.pt",
            "epochs": 50,
            "batch_size": 8,
            "image_size": 416,
            "learning_rate": 0.001,
            "device": "cpu",
        }
        config = TrainingConfig(**overrides)

        # Verify all six overrides rather than a subset, so a field that
        # silently falls back to its default is caught as well.
        for field_name, expected in overrides.items():
            assert getattr(config, field_name) == expected

    def test_config_validation(self):
        """The inclusive epoch boundaries (1 and 1000) are accepted."""
        # Epochs must be 1-1000; only the two boundary values are exercised here.
        for boundary in (1, 1000):
            config = TrainingConfig(epochs=boundary)
            assert config.epochs == boundary
class TestTrainingTaskCreateSchema:
    """Tests for the TrainingTaskCreate schema."""

    def test_minimal_task(self):
        """A task created with only a name gets the expected defaults."""
        task = TrainingTaskCreate(name="Test Training")

        assert task.name == "Test Training"
        assert task.task_type == TrainingType.TRAIN
        assert task.description is None
        assert task.scheduled_at is None

    def test_scheduled_task(self):
        """A supplied ``scheduled_at`` timestamp is readable back unchanged."""
        # Function-scope import: the module-level datetime import does not
        # bring ``timezone`` into scope.
        from datetime import timezone

        # datetime.utcnow() is deprecated since Python 3.12. Derive the same
        # naive-UTC value from an aware "now" so the input handed to the
        # schema keeps exactly the shape it had before.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        scheduled_time = now_utc + timedelta(hours=1)

        task = TrainingTaskCreate(
            name="Scheduled Training",
            scheduled_at=scheduled_time,
        )

        assert task.scheduled_at == scheduled_time

    def test_recurring_task(self):
        """A cron expression passed at creation is stored as given."""
        task = TrainingTaskCreate(
            name="Recurring Training",
            cron_expression="0 0 * * 0",  # Every Sunday at midnight
        )

        assert task.cron_expression == "0 0 * * 0"
class TestTrainingTaskModel:
    """Construction checks for the TrainingTask model."""

    def test_task_creation(self):
        """Fields passed to the constructor are readable back from the instance."""
        created = TrainingTask(
            admin_token=TEST_TOKEN,
            name="Test Task",
            task_type="train",
            status="pending",
        )

        observed = (created.name, created.task_type, created.status)
        assert observed == ("Test Task", "train", "pending")

    def test_task_with_config(self):
        """A config dict passed at construction is kept and can be indexed."""
        training_config = dict(model_name="yolo11n.pt", epochs=100)

        created = TrainingTask(
            admin_token=TEST_TOKEN,
            name="Configured Task",
            task_type="train",
            config=training_config,
        )

        stored = created.config
        assert stored == training_config
        assert stored["epochs"] == 100
class TestTrainingLogModel:
    """Construction checks for the TrainingLog model."""

    def test_log_creation(self):
        """Task id, level and message passed in are readable back from the log."""
        task_uuid = UUID(TEST_TASK_UUID)
        entry = TrainingLog(task_id=task_uuid, level="INFO", message="Training started")

        # The id is compared in string form against the original constant.
        assert str(entry.task_id) == TEST_TASK_UUID
        assert entry.level == "INFO"
        assert entry.message == "Training started"

    def test_log_with_details(self):
        """A details dict passed at construction is kept and can be indexed."""
        metrics = dict(epoch=10, loss=0.5, mAP=0.85)

        entry = TrainingLog(
            task_id=UUID(TEST_TASK_UUID),
            level="INFO",
            message="Epoch completed",
            details=metrics,
        )

        stored = entry.details
        assert stored == metrics
        assert stored["epoch"] == 10
class TestTrainingScheduler:
    """Tests for TrainingScheduler and the module-level scheduler helpers."""

    @pytest.fixture
    def scheduler(self):
        """Provide a fresh, not-yet-started scheduler with a 1-second check interval."""
        return TrainingScheduler(check_interval_seconds=1)

    def test_scheduler_creation(self, scheduler):
        """A new scheduler stores its interval and is idle with no thread."""
        assert scheduler._check_interval == 1
        assert scheduler._running is False
        assert scheduler._thread is None

    def test_scheduler_start_stop(self, scheduler):
        """start() marks the scheduler running with a thread; stop() clears the flag."""
        # Replace the task check with a mock so no real task lookup runs
        # while the scheduler is started.
        with patch.object(scheduler, "_check_pending_tasks"):
            scheduler.start()
            try:
                assert scheduler._running is True
                assert scheduler._thread is not None
            finally:
                # Always stop, even when an assertion above fails; otherwise
                # the started scheduler would outlive this test.
                scheduler.stop()

            assert scheduler._running is False

    def test_scheduler_singleton(self):
        """get_training_scheduler() returns the same object on repeated calls."""
        # Reset any existing scheduler
        stop_scheduler()
        try:
            first = get_training_scheduler()
            second = get_training_scheduler()
            assert first is second
        finally:
            # Cleanup runs regardless of the assertion outcome.
            stop_scheduler()
class TestTrainingStatusEnum:
    """Checks the members exposed by the TrainingStatus enum."""

    def test_all_statuses(self):
        """Every expected status value is present among the enum's values."""
        defined_values = [member.value for member in TrainingStatus]

        expected_values = (
            "pending",
            "scheduled",
            "running",
            "completed",
            "failed",
            "cancelled",
        )
        # Checked in order; the first missing value fails the test.
        for status_value in expected_values:
            assert status_value in defined_values
class TestTrainingTypeEnum:
    """Checks the members exposed by the TrainingType enum."""

    def test_all_types(self):
        """Both expected training type values are present among the enum's values."""
        defined_values = [member.value for member in TrainingType]

        for type_value in ("train", "finetune"):
            assert type_value in defined_values