"""
|
|
Tests for dataset training status feature.
|
|
|
|
Tests cover:
|
|
1. Database model fields (training_status, active_training_task_id)
|
|
2. DatasetRepository update_training_status method
|
|
3. API response includes training status fields
|
|
4. Scheduler updates dataset status during training lifecycle
|
|
"""

import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4

from fastapi import FastAPI
from fastapi.testclient import TestClient


# =============================================================================
# Test Database Model
# =============================================================================


class TestTrainingDatasetModel:
    """Tests for TrainingDataset model fields."""

    def test_training_dataset_has_training_status_field(self):
        """TrainingDataset model should have training_status field."""
        from inference.data.admin_models import TrainingDataset

        dataset = TrainingDataset(
            name="test-dataset",
            training_status="running",
        )
        assert dataset.training_status == "running"

    def test_training_dataset_has_active_training_task_id_field(self):
        """TrainingDataset model should have active_training_task_id field."""
        from inference.data.admin_models import TrainingDataset

        task_id = uuid4()
        dataset = TrainingDataset(
            name="test-dataset",
            active_training_task_id=task_id,
        )
        assert dataset.active_training_task_id == task_id

    def test_training_dataset_defaults(self):
        """TrainingDataset should have correct defaults for new fields."""
        from inference.data.admin_models import TrainingDataset

        dataset = TrainingDataset(name="test-dataset")
        assert dataset.training_status is None
        assert dataset.active_training_task_id is None
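
# The model fields exercised above are assumed to be two nullable columns on
# TrainingDataset, roughly (sketch only; the real definitions live in
# inference.data.admin_models and may carry column options not asserted here):
#
#     training_status: str | None = None           # e.g. "running", "completed"
#     active_training_task_id: UUID | None = None   # task currently training this dataset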


# =============================================================================
# Test DatasetRepository Methods
# =============================================================================


class TestDatasetRepositoryTrainingStatus:
    """Tests for DatasetRepository.update_training_status method."""

    @pytest.fixture
    def mock_session(self):
        """Create mock database session."""
        session = MagicMock()
        return session

    def test_update_training_status_sets_status(self, mock_session):
        """update_training_status should set training_status."""
        from inference.data.admin_models import TrainingDataset

        dataset_id = uuid4()
        dataset = TrainingDataset(
            dataset_id=dataset_id,
            name="test-dataset",
            status="ready",
        )
        mock_session.get.return_value = dataset

        with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
            mock_ctx.return_value.__enter__.return_value = mock_session

            from inference.data.repositories import DatasetRepository

            repo = DatasetRepository()
            repo.update_training_status(
                dataset_id=str(dataset_id),
                training_status="running",
            )

            assert dataset.training_status == "running"
            mock_session.add.assert_called_once_with(dataset)
            mock_session.commit.assert_called_once()

    def test_update_training_status_sets_task_id(self, mock_session):
        """update_training_status should set active_training_task_id."""
        from inference.data.admin_models import TrainingDataset

        dataset_id = uuid4()
        task_id = uuid4()
        dataset = TrainingDataset(
            dataset_id=dataset_id,
            name="test-dataset",
            status="ready",
        )
        mock_session.get.return_value = dataset

        with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
            mock_ctx.return_value.__enter__.return_value = mock_session

            from inference.data.repositories import DatasetRepository

            repo = DatasetRepository()
            repo.update_training_status(
                dataset_id=str(dataset_id),
                training_status="running",
                active_training_task_id=str(task_id),
            )

            assert dataset.active_training_task_id == task_id

    def test_update_training_status_updates_main_status_on_complete(
        self, mock_session
    ):
        """update_training_status should update main status to 'trained' when completed."""
        from inference.data.admin_models import TrainingDataset

        dataset_id = uuid4()
        dataset = TrainingDataset(
            dataset_id=dataset_id,
            name="test-dataset",
            status="ready",
        )
        mock_session.get.return_value = dataset

        with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
            mock_ctx.return_value.__enter__.return_value = mock_session

            from inference.data.repositories import DatasetRepository

            repo = DatasetRepository()
            repo.update_training_status(
                dataset_id=str(dataset_id),
                training_status="completed",
                update_main_status=True,
            )

            assert dataset.status == "trained"
            assert dataset.training_status == "completed"

    def test_update_training_status_clears_task_id_on_complete(
        self, mock_session
    ):
        """update_training_status should clear task_id when training completes."""
        from inference.data.admin_models import TrainingDataset

        dataset_id = uuid4()
        task_id = uuid4()
        dataset = TrainingDataset(
            dataset_id=dataset_id,
            name="test-dataset",
            status="ready",
            training_status="running",
            active_training_task_id=task_id,
        )
        mock_session.get.return_value = dataset

        with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
            mock_ctx.return_value.__enter__.return_value = mock_session

            from inference.data.repositories import DatasetRepository

            repo = DatasetRepository()
            repo.update_training_status(
                dataset_id=str(dataset_id),
                training_status="completed",
                active_training_task_id=None,
            )

            assert dataset.active_training_task_id is None

    def test_update_training_status_handles_missing_dataset(self, mock_session):
        """update_training_status should handle missing dataset gracefully."""
        mock_session.get.return_value = None

        with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
            mock_ctx.return_value.__enter__.return_value = mock_session

            from inference.data.repositories import DatasetRepository

            repo = DatasetRepository()
            # Should not raise
            repo.update_training_status(
                dataset_id=str(uuid4()),
                training_status="running",
            )

            mock_session.add.assert_not_called()
            mock_session.commit.assert_not_called()
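

# A minimal behavioural sketch consistent with the tests above. This is NOT the
# production DatasetRepository.update_training_status; names such as
# get_session_context and the UUID conversions are assumptions inferred from the
# mocks and assertions in this module. It is defined here purely as
# documentation and is never called by the tests.
def _update_training_status_sketch(
    dataset_id,
    training_status,
    active_training_task_id=None,
    update_main_status=False,
):
    from uuid import UUID

    from inference.data.admin_models import TrainingDataset
    from inference.data.repositories.dataset_repository import get_session_context

    with get_session_context() as session:
        dataset = session.get(TrainingDataset, UUID(dataset_id))
        if dataset is None:
            # Missing dataset: return without add/commit, as the tests expect.
            return

        dataset.training_status = training_status
        dataset.active_training_task_id = (
            UUID(active_training_task_id) if active_training_task_id else None
        )
        if update_main_status and training_status == "completed":
            dataset.status = "trained"

        session.add(dataset)
        session.commit()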


# =============================================================================
# Test API Response
# =============================================================================


class TestDatasetDetailResponseTrainingStatus:
    """Tests for DatasetDetailResponse including training status fields."""

    def test_dataset_detail_response_includes_training_status(self):
        """DatasetDetailResponse schema should include training_status field."""
        from inference.web.schemas.admin.datasets import DatasetDetailResponse

        response = DatasetDetailResponse(
            dataset_id=str(uuid4()),
            name="test-dataset",
            description=None,
            status="ready",
            training_status="running",
            active_training_task_id=str(uuid4()),
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            total_documents=10,
            total_images=15,
            total_annotations=100,
            dataset_path="/path/to/dataset",
            error_message=None,
            documents=[],
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
        )

        assert response.training_status == "running"
        assert response.active_training_task_id is not None

    def test_dataset_detail_response_allows_null_training_status(self):
        """DatasetDetailResponse should allow null training_status."""
        from inference.web.schemas.admin.datasets import DatasetDetailResponse

        response = DatasetDetailResponse(
            dataset_id=str(uuid4()),
            name="test-dataset",
            description=None,
            status="ready",
            training_status=None,
            active_training_task_id=None,
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            total_documents=10,
            total_images=15,
            total_annotations=100,
            dataset_path=None,
            error_message=None,
            documents=[],
            created_at=datetime.utcnow(),
            updated_at=datetime.utcnow(),
        )

        assert response.training_status is None
        assert response.active_training_task_id is None
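
# On the schema side, the two tests above imply a pair of optional fields on
# DatasetDetailResponse, roughly (sketch; the actual declarations live in
# inference.web.schemas.admin.datasets):
#
#     training_status: str | None = None
#     active_training_task_id: str | None = None
#
# Both accept string values and None, which is all these tests assert.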


# =============================================================================
# Test Scheduler Training Status Updates
# =============================================================================


class TestSchedulerDatasetStatusUpdates:
    """Tests for scheduler updating dataset status during training."""

    @pytest.fixture
    def mock_datasets_repo(self):
        """Create mock DatasetRepository."""
        mock = MagicMock()
        mock.get.return_value = MagicMock(
            dataset_id=uuid4(),
            name="test-dataset",
            dataset_path="/path/to/dataset",
            total_images=100,
        )
        return mock

    @pytest.fixture
    def mock_training_tasks_repo(self):
        """Create mock TrainingTaskRepository."""
        mock = MagicMock()
        return mock

    def test_scheduler_sets_running_status_on_task_start(
        self, mock_datasets_repo, mock_training_tasks_repo
    ):
        """Scheduler should set dataset training_status to 'running' when task starts."""
        from inference.web.core.scheduler import TrainingScheduler

        with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
            mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}

            scheduler = TrainingScheduler()
            scheduler._datasets = mock_datasets_repo
            scheduler._training_tasks = mock_training_tasks_repo

            task_id = str(uuid4())
            dataset_id = str(uuid4())

            # Execute task (will fail but we check the status update call)
            try:
                scheduler._execute_task(
                    task_id=task_id,
                    config={"model_name": "yolo11n.pt"},
                    dataset_id=dataset_id,
                )
            except Exception:
                pass  # Expected to fail in test environment

            # Check that training status was updated to running
            mock_datasets_repo.update_training_status.assert_called()
            first_call = mock_datasets_repo.update_training_status.call_args_list[0]
            assert first_call.kwargs["training_status"] == "running"
            assert first_call.kwargs["active_training_task_id"] == task_id


# =============================================================================
# Test Dataset Status Values
# =============================================================================


class TestDatasetStatusValues:
    """Tests for valid dataset status values."""

    def test_dataset_status_building(self):
        """Dataset can have status 'building'."""
        from inference.data.admin_models import TrainingDataset

        dataset = TrainingDataset(name="test", status="building")
        assert dataset.status == "building"

    def test_dataset_status_ready(self):
        """Dataset can have status 'ready'."""
        from inference.data.admin_models import TrainingDataset

        dataset = TrainingDataset(name="test", status="ready")
        assert dataset.status == "ready"

    def test_dataset_status_trained(self):
        """Dataset can have status 'trained'."""
        from inference.data.admin_models import TrainingDataset

        dataset = TrainingDataset(name="test", status="trained")
        assert dataset.status == "trained"

    def test_dataset_status_failed(self):
        """Dataset can have status 'failed'."""
        from inference.data.admin_models import TrainingDataset

        dataset = TrainingDataset(name="test", status="failed")
        assert dataset.status == "failed"

    def test_training_status_values(self):
        """Training status can have various values."""
        from inference.data.admin_models import TrainingDataset

        valid_statuses = ["pending", "scheduled", "running", "completed", "failed", "cancelled"]
        for status in valid_statuses:
            dataset = TrainingDataset(name="test", training_status=status)
            assert dataset.training_status == status