"""
Tests for Admin Training Routes and Scheduler.
"""

from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from uuid import UUID

import pytest

from inference.data.admin_models import TrainingTask, TrainingLog
from inference.web.api.v1.admin.training import _validate_uuid, create_training_router
from inference.web.core.scheduler import (
    TrainingScheduler,
    get_training_scheduler,
    start_scheduler,
    stop_scheduler,
)
from inference.web.schemas.admin import (
    TrainingConfig,
    TrainingStatus,
    TrainingTaskCreate,
    TrainingType,
)


# Test UUIDs
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
TEST_TOKEN = "test-admin-token-12345"


class TestTrainingRouterCreation:
    """Tests for training router creation."""

    def test_creates_router_with_endpoints(self):
        """Test router is created with expected endpoints."""
        router = create_training_router()

        # Get route paths (include prefix)
        paths = [route.path for route in router.routes]

        # Paths include the /admin/training prefix; match on substrings so the
        # test does not depend on the exact prefix.
        assert any("/tasks" in p for p in paths)
        assert any("{task_id}" in p for p in paths)
        assert any("cancel" in p for p in paths)
        assert any("logs" in p for p in paths)
        assert any("export" in p for p in paths)


class TestTrainingConfigSchema:
    """Tests for TrainingConfig schema."""

    def test_default_config(self):
        """Test default training configuration."""
        config = TrainingConfig()

        assert config.model_name == "yolo11n.pt"
        assert config.epochs == 100
        assert config.batch_size == 16
        assert config.image_size == 640
        assert config.learning_rate == 0.01
        assert config.device == "0"

    def test_custom_config(self):
        """Test custom training configuration."""
        config = TrainingConfig(
            model_name="yolo11s.pt",
            epochs=50,
            batch_size=8,
            image_size=416,
            learning_rate=0.001,
            device="cpu",
        )

        assert config.model_name == "yolo11s.pt"
        assert config.epochs == 50
        assert config.batch_size == 8

    def test_config_validation(self):
        """Test config validation constraints."""
        # The inclusive epoch boundaries 1 and 1000 are accepted.
        # NOTE(review): rejection of out-of-range values is not covered here.
        config = TrainingConfig(epochs=1)
        assert config.epochs == 1

        config = TrainingConfig(epochs=1000)
        assert config.epochs == 1000


class TestTrainingTaskCreateSchema:
    """Tests for TrainingTaskCreate schema."""

    def test_minimal_task(self):
        """Test minimal task creation."""
        task = TrainingTaskCreate(name="Test Training")

        assert task.name == "Test Training"
        assert task.task_type == TrainingType.TRAIN
        assert task.description is None
        assert task.scheduled_at is None

    def test_scheduled_task(self):
        """Test scheduled task creation."""
        scheduled_time = datetime.utcnow() + timedelta(hours=1)
        task = TrainingTaskCreate(
            name="Scheduled Training",
            scheduled_at=scheduled_time,
        )

        assert task.scheduled_at == scheduled_time

    def test_recurring_task(self):
        """Test recurring task with cron expression."""
        task = TrainingTaskCreate(
            name="Recurring Training",
            cron_expression="0 0 * * 0",  # Every Sunday at midnight
        )

        assert task.cron_expression == "0 0 * * 0"


class TestTrainingTaskModel:
    """Tests for TrainingTask model."""

    def test_task_creation(self):
        """Test training task model creation."""
        task = TrainingTask(
            admin_token=TEST_TOKEN,
            name="Test Task",
            task_type="train",
            status="pending",
        )

        assert task.name == "Test Task"
        assert task.task_type == "train"
        assert task.status == "pending"

    def test_task_with_config(self):
        """Test task with configuration."""
        config = {
            "model_name": "yolo11n.pt",
            "epochs": 100,
        }
        task = TrainingTask(
            admin_token=TEST_TOKEN,
            name="Configured Task",
            task_type="train",
            config=config,
        )

        assert task.config == config
        assert task.config["epochs"] == 100


class TestTrainingLogModel:
    """Tests for TrainingLog model."""

    def test_log_creation(self):
        """Test training log creation."""
        log = TrainingLog(
            task_id=UUID(TEST_TASK_UUID),
            level="INFO",
            message="Training started",
        )

        assert str(log.task_id) == TEST_TASK_UUID
        assert log.level == "INFO"
        assert log.message == "Training started"

    def test_log_with_details(self):
        """Test log with additional details."""
        details = {
            "epoch": 10,
            "loss": 0.5,
            "mAP": 0.85,
        }
        log = TrainingLog(
            task_id=UUID(TEST_TASK_UUID),
            level="INFO",
            message="Epoch completed",
            details=details,
        )

        assert log.details == details
        assert log.details["epoch"] == 10


class TestTrainingScheduler:
    """Tests for TrainingScheduler."""

    @pytest.fixture
    def scheduler(self):
        """Create a scheduler for testing."""
        return TrainingScheduler(check_interval_seconds=1)

    def test_scheduler_creation(self, scheduler):
        """Test scheduler creation."""
        assert scheduler._check_interval == 1
        assert scheduler._running is False
        assert scheduler._thread is None

    def test_scheduler_start_stop(self, scheduler):
        """Test scheduler start and stop."""
        # Patch the pending-task check so the thread started by start() never
        # runs the real implementation during this test.
        with patch.object(scheduler, "_check_pending_tasks"):
            scheduler.start()
            try:
                assert scheduler._running is True
                assert scheduler._thread is not None
            finally:
                # Always stop while the patch is still active, even if an
                # assertion above failed, so the started thread does not
                # outlive the test with the real method restored.
                scheduler.stop()

            assert scheduler._running is False

    def test_scheduler_singleton(self):
        """Test get_training_scheduler returns singleton."""
        # Reset any existing scheduler
        stop_scheduler()
        try:
            s1 = get_training_scheduler()
            s2 = get_training_scheduler()

            assert s1 is s2
        finally:
            # Cleanup runs even on assertion failure so the global scheduler
            # state does not leak into later tests.
            stop_scheduler()


class TestTrainingStatusEnum:
    """Tests for TrainingStatus enum."""

    def test_all_statuses(self):
        """Test all training statuses are defined."""
        statuses = [s.value for s in TrainingStatus]

        assert "pending" in statuses
        assert "scheduled" in statuses
        assert "running" in statuses
        assert "completed" in statuses
        assert "failed" in statuses
        assert "cancelled" in statuses


class TestTrainingTypeEnum:
    """Tests for TrainingType enum."""

    def test_all_types(self):
        """Test all training types are defined."""
        types = [t.value for t in TrainingType]

        assert "train" in types
        assert "finetune" in types