WIP
This commit is contained in:
247
tests/web/test_admin_training.py
Normal file
247
tests/web/test_admin_training.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Tests for Admin Training Routes and Scheduler.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
from src.data.admin_models import TrainingTask, TrainingLog
|
||||
from src.web.api.v1.admin.training import _validate_uuid, create_training_router
|
||||
from src.web.core.scheduler import (
|
||||
TrainingScheduler,
|
||||
get_training_scheduler,
|
||||
start_scheduler,
|
||||
stop_scheduler,
|
||||
)
|
||||
from src.web.schemas.admin import (
|
||||
TrainingConfig,
|
||||
TrainingStatus,
|
||||
TrainingTaskCreate,
|
||||
TrainingType,
|
||||
)
|
||||
|
||||
|
||||
# Test UUIDs
|
||||
TEST_TASK_UUID = "770e8400-e29b-41d4-a716-446655440002"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
|
||||
|
||||
class TestTrainingRouterCreation:
|
||||
"""Tests for training router creation."""
|
||||
|
||||
def test_creates_router_with_endpoints(self):
|
||||
"""Test router is created with expected endpoints."""
|
||||
router = create_training_router()
|
||||
|
||||
# Get route paths (include prefix)
|
||||
paths = [route.path for route in router.routes]
|
||||
|
||||
# Paths include the /admin/training prefix
|
||||
assert any("/tasks" in p for p in paths)
|
||||
assert any("{task_id}" in p for p in paths)
|
||||
assert any("cancel" in p for p in paths)
|
||||
assert any("logs" in p for p in paths)
|
||||
assert any("export" in p for p in paths)
|
||||
|
||||
|
||||
class TestTrainingConfigSchema:
|
||||
"""Tests for TrainingConfig schema."""
|
||||
|
||||
def test_default_config(self):
|
||||
"""Test default training configuration."""
|
||||
config = TrainingConfig()
|
||||
|
||||
assert config.model_name == "yolo11n.pt"
|
||||
assert config.epochs == 100
|
||||
assert config.batch_size == 16
|
||||
assert config.image_size == 640
|
||||
assert config.learning_rate == 0.01
|
||||
assert config.device == "0"
|
||||
|
||||
def test_custom_config(self):
|
||||
"""Test custom training configuration."""
|
||||
config = TrainingConfig(
|
||||
model_name="yolo11s.pt",
|
||||
epochs=50,
|
||||
batch_size=8,
|
||||
image_size=416,
|
||||
learning_rate=0.001,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
assert config.model_name == "yolo11s.pt"
|
||||
assert config.epochs == 50
|
||||
assert config.batch_size == 8
|
||||
|
||||
def test_config_validation(self):
|
||||
"""Test config validation constraints."""
|
||||
# Epochs must be 1-1000
|
||||
config = TrainingConfig(epochs=1)
|
||||
assert config.epochs == 1
|
||||
|
||||
config = TrainingConfig(epochs=1000)
|
||||
assert config.epochs == 1000
|
||||
|
||||
|
||||
class TestTrainingTaskCreateSchema:
|
||||
"""Tests for TrainingTaskCreate schema."""
|
||||
|
||||
def test_minimal_task(self):
|
||||
"""Test minimal task creation."""
|
||||
task = TrainingTaskCreate(name="Test Training")
|
||||
|
||||
assert task.name == "Test Training"
|
||||
assert task.task_type == TrainingType.TRAIN
|
||||
assert task.description is None
|
||||
assert task.scheduled_at is None
|
||||
|
||||
def test_scheduled_task(self):
|
||||
"""Test scheduled task creation."""
|
||||
scheduled_time = datetime.utcnow() + timedelta(hours=1)
|
||||
task = TrainingTaskCreate(
|
||||
name="Scheduled Training",
|
||||
scheduled_at=scheduled_time,
|
||||
)
|
||||
|
||||
assert task.scheduled_at == scheduled_time
|
||||
|
||||
def test_recurring_task(self):
|
||||
"""Test recurring task with cron expression."""
|
||||
task = TrainingTaskCreate(
|
||||
name="Recurring Training",
|
||||
cron_expression="0 0 * * 0", # Every Sunday at midnight
|
||||
)
|
||||
|
||||
assert task.cron_expression == "0 0 * * 0"
|
||||
|
||||
|
||||
class TestTrainingTaskModel:
|
||||
"""Tests for TrainingTask model."""
|
||||
|
||||
def test_task_creation(self):
|
||||
"""Test training task model creation."""
|
||||
task = TrainingTask(
|
||||
admin_token=TEST_TOKEN,
|
||||
name="Test Task",
|
||||
task_type="train",
|
||||
status="pending",
|
||||
)
|
||||
|
||||
assert task.name == "Test Task"
|
||||
assert task.task_type == "train"
|
||||
assert task.status == "pending"
|
||||
|
||||
def test_task_with_config(self):
|
||||
"""Test task with configuration."""
|
||||
config = {
|
||||
"model_name": "yolo11n.pt",
|
||||
"epochs": 100,
|
||||
}
|
||||
task = TrainingTask(
|
||||
admin_token=TEST_TOKEN,
|
||||
name="Configured Task",
|
||||
task_type="train",
|
||||
config=config,
|
||||
)
|
||||
|
||||
assert task.config == config
|
||||
assert task.config["epochs"] == 100
|
||||
|
||||
|
||||
class TestTrainingLogModel:
|
||||
"""Tests for TrainingLog model."""
|
||||
|
||||
def test_log_creation(self):
|
||||
"""Test training log creation."""
|
||||
log = TrainingLog(
|
||||
task_id=UUID(TEST_TASK_UUID),
|
||||
level="INFO",
|
||||
message="Training started",
|
||||
)
|
||||
|
||||
assert str(log.task_id) == TEST_TASK_UUID
|
||||
assert log.level == "INFO"
|
||||
assert log.message == "Training started"
|
||||
|
||||
def test_log_with_details(self):
|
||||
"""Test log with additional details."""
|
||||
details = {
|
||||
"epoch": 10,
|
||||
"loss": 0.5,
|
||||
"mAP": 0.85,
|
||||
}
|
||||
log = TrainingLog(
|
||||
task_id=UUID(TEST_TASK_UUID),
|
||||
level="INFO",
|
||||
message="Epoch completed",
|
||||
details=details,
|
||||
)
|
||||
|
||||
assert log.details == details
|
||||
assert log.details["epoch"] == 10
|
||||
|
||||
|
||||
class TestTrainingScheduler:
|
||||
"""Tests for TrainingScheduler."""
|
||||
|
||||
@pytest.fixture
|
||||
def scheduler(self):
|
||||
"""Create a scheduler for testing."""
|
||||
return TrainingScheduler(check_interval_seconds=1)
|
||||
|
||||
def test_scheduler_creation(self, scheduler):
|
||||
"""Test scheduler creation."""
|
||||
assert scheduler._check_interval == 1
|
||||
assert scheduler._running is False
|
||||
assert scheduler._thread is None
|
||||
|
||||
def test_scheduler_start_stop(self, scheduler):
|
||||
"""Test scheduler start and stop."""
|
||||
with patch.object(scheduler, "_check_pending_tasks"):
|
||||
scheduler.start()
|
||||
assert scheduler._running is True
|
||||
assert scheduler._thread is not None
|
||||
|
||||
scheduler.stop()
|
||||
assert scheduler._running is False
|
||||
|
||||
def test_scheduler_singleton(self):
|
||||
"""Test get_training_scheduler returns singleton."""
|
||||
# Reset any existing scheduler
|
||||
stop_scheduler()
|
||||
|
||||
s1 = get_training_scheduler()
|
||||
s2 = get_training_scheduler()
|
||||
|
||||
assert s1 is s2
|
||||
|
||||
# Cleanup
|
||||
stop_scheduler()
|
||||
|
||||
|
||||
class TestTrainingStatusEnum:
|
||||
"""Tests for TrainingStatus enum."""
|
||||
|
||||
def test_all_statuses(self):
|
||||
"""Test all training statuses are defined."""
|
||||
statuses = [s.value for s in TrainingStatus]
|
||||
|
||||
assert "pending" in statuses
|
||||
assert "scheduled" in statuses
|
||||
assert "running" in statuses
|
||||
assert "completed" in statuses
|
||||
assert "failed" in statuses
|
||||
assert "cancelled" in statuses
|
||||
|
||||
|
||||
class TestTrainingTypeEnum:
|
||||
"""Tests for TrainingType enum."""
|
||||
|
||||
def test_all_types(self):
|
||||
"""Test all training types are defined."""
|
||||
types = [t.value for t in TrainingType]
|
||||
|
||||
assert "train" in types
|
||||
assert "finetune" in types
|
||||
Reference in New Issue
Block a user