# File: invoice-master-poc-v2/tests/web/test_dataset_training_status.py
# Commit: a564ac9d70 "WIP" (Yaojia Wang, 2026-02-01 18:51:54 +01:00)
# Size: 370 lines, 13 KiB (Python)

"""
Tests for dataset training status feature.
Tests cover:
1. Database model fields (training_status, active_training_task_id)
2. DatasetRepository update_training_status method
3. API response includes training status fields
4. Scheduler sets the dataset training status when a training task starts
5. Accepted values for the dataset status and training_status fields
"""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
# =============================================================================
# Test Database Model
# =============================================================================
class TestTrainingDatasetModel:
    """Check that the TrainingDataset model exposes the training-status columns."""

    @staticmethod
    def _build(**columns):
        """Instantiate a TrainingDataset named 'test-dataset' with extra columns.

        The model import is done lazily so test collection does not require the
        project package to be importable.
        """
        from inference.data.admin_models import TrainingDataset

        return TrainingDataset(name="test-dataset", **columns)

    def test_training_dataset_has_training_status_field(self):
        """A training_status given to the constructor is stored on the model."""
        model = self._build(training_status="running")
        assert model.training_status == "running"

    def test_training_dataset_has_active_training_task_id_field(self):
        """An active_training_task_id given to the constructor is stored as-is."""
        expected_task = uuid4()
        model = self._build(active_training_task_id=expected_task)
        assert model.active_training_task_id == expected_task

    def test_training_dataset_defaults(self):
        """Both training-status columns are None when not supplied."""
        model = self._build()
        assert model.training_status is None
        assert model.active_training_task_id is None
# =============================================================================
# Test DatasetRepository Methods
# =============================================================================
class TestDatasetRepositoryTrainingStatus:
    """Tests for DatasetRepository.update_training_status method.

    Every test injects a mock session by patching the repository module's
    ``get_session_context``; the private helpers keep that plumbing in one
    place so each test only states its inputs and expectations.
    """

    # Patch target for the session factory used by DatasetRepository.
    _SESSION_CONTEXT_TARGET = (
        "inference.data.repositories.dataset_repository.get_session_context"
    )

    @pytest.fixture
    def mock_session(self):
        """Create mock database session."""
        return MagicMock()

    @staticmethod
    def _make_dataset(**fields):
        """Build a 'ready' TrainingDataset named 'test-dataset' with extra fields."""
        from inference.data.admin_models import TrainingDataset

        return TrainingDataset(name="test-dataset", status="ready", **fields)

    def _update_training_status(self, mock_session, **kwargs):
        """Call DatasetRepository.update_training_status with mock_session injected.

        Args:
            mock_session: Mock returned to the repository as its DB session.
            **kwargs: Forwarded unchanged to ``update_training_status``.
        """
        with patch(self._SESSION_CONTEXT_TARGET) as mock_ctx:
            # The session factory is mocked as a context manager, so the
            # session the repository sees is whatever ``__enter__`` returns.
            mock_ctx.return_value.__enter__.return_value = mock_session
            from inference.data.repositories import DatasetRepository

            DatasetRepository().update_training_status(**kwargs)

    def test_update_training_status_sets_status(self, mock_session):
        """update_training_status should set training_status."""
        dataset_id = uuid4()
        dataset = self._make_dataset(dataset_id=dataset_id)
        mock_session.get.return_value = dataset

        self._update_training_status(
            mock_session,
            dataset_id=str(dataset_id),
            training_status="running",
        )

        assert dataset.training_status == "running"
        mock_session.add.assert_called_once_with(dataset)
        mock_session.commit.assert_called_once()

    def test_update_training_status_sets_task_id(self, mock_session):
        """update_training_status should set active_training_task_id."""
        dataset_id = uuid4()
        task_id = uuid4()
        dataset = self._make_dataset(dataset_id=dataset_id)
        mock_session.get.return_value = dataset

        # The task id is passed as a string; the model is expected to hold a UUID.
        self._update_training_status(
            mock_session,
            dataset_id=str(dataset_id),
            training_status="running",
            active_training_task_id=str(task_id),
        )

        assert dataset.active_training_task_id == task_id

    def test_update_training_status_updates_main_status_on_complete(
        self, mock_session
    ):
        """update_training_status should update main status to 'trained' when completed."""
        dataset_id = uuid4()
        dataset = self._make_dataset(dataset_id=dataset_id)
        mock_session.get.return_value = dataset

        self._update_training_status(
            mock_session,
            dataset_id=str(dataset_id),
            training_status="completed",
            update_main_status=True,
        )

        assert dataset.status == "trained"
        assert dataset.training_status == "completed"

    def test_update_training_status_clears_task_id_on_complete(
        self, mock_session
    ):
        """update_training_status should clear task_id when training completes."""
        dataset_id = uuid4()
        dataset = self._make_dataset(
            dataset_id=dataset_id,
            training_status="running",
            active_training_task_id=uuid4(),
        )
        mock_session.get.return_value = dataset

        self._update_training_status(
            mock_session,
            dataset_id=str(dataset_id),
            training_status="completed",
            active_training_task_id=None,
        )

        assert dataset.active_training_task_id is None

    def test_update_training_status_handles_missing_dataset(self, mock_session):
        """update_training_status should handle missing dataset gracefully."""
        mock_session.get.return_value = None

        # Should not raise
        self._update_training_status(
            mock_session,
            dataset_id=str(uuid4()),
            training_status="running",
        )

        mock_session.add.assert_not_called()
        mock_session.commit.assert_not_called()
# =============================================================================
# Test API Response
# =============================================================================
class TestDatasetDetailResponseTrainingStatus:
    """Tests for DatasetDetailResponse including training status fields."""

    @staticmethod
    def _build_response(**overrides):
        """Build a DatasetDetailResponse from valid defaults plus ``overrides``.

        Both timestamps share one timezone-aware "now"; ``datetime.utcnow()``
        is deprecated since Python 3.12 and returned a slightly different value
        on each call.
        """
        from inference.web.schemas.admin.datasets import DatasetDetailResponse

        now = datetime.now(timezone.utc)
        fields = {
            "dataset_id": str(uuid4()),
            "name": "test-dataset",
            "description": None,
            "status": "ready",
            "training_status": None,
            "active_training_task_id": None,
            "train_ratio": 0.8,
            "val_ratio": 0.1,
            "seed": 42,
            "total_documents": 10,
            "total_images": 15,
            "total_annotations": 100,
            "dataset_path": None,
            "error_message": None,
            "documents": [],
            "created_at": now,
            "updated_at": now,
        }
        fields.update(overrides)
        return DatasetDetailResponse(**fields)

    def test_dataset_detail_response_includes_training_status(self):
        """DatasetDetailResponse schema should include training_status field."""
        task_id = str(uuid4())
        response = self._build_response(
            training_status="running",
            active_training_task_id=task_id,
            dataset_path="/path/to/dataset",
        )
        assert response.training_status == "running"
        # str() keeps the check valid whether the schema stores str or UUID.
        assert str(response.active_training_task_id) == task_id

    def test_dataset_detail_response_allows_null_training_status(self):
        """DatasetDetailResponse should allow null training_status."""
        response = self._build_response(
            training_status=None,
            active_training_task_id=None,
        )
        assert response.training_status is None
        assert response.active_training_task_id is None
# =============================================================================
# Test Scheduler Training Status Updates
# =============================================================================
class TestSchedulerDatasetStatusUpdates:
    """Tests for scheduler updating dataset status during training."""

    @pytest.fixture
    def mock_datasets_repo(self):
        """Create mock DatasetRepository whose get() returns a dataset-like mock."""
        repo = MagicMock()
        repo.get.return_value = MagicMock(
            dataset_id=uuid4(),
            name="test-dataset",
            dataset_path="/path/to/dataset",
            total_images=100,
        )
        return repo

    @pytest.fixture
    def mock_training_tasks_repo(self):
        """Create mock TrainingTaskRepository."""
        return MagicMock()

    def test_scheduler_sets_running_status_on_task_start(
        self, mock_datasets_repo, mock_training_tasks_repo
    ):
        """Scheduler should set dataset training_status to 'running' when task starts."""
        from inference.web.core.scheduler import TrainingScheduler

        task_id = str(uuid4())
        dataset_id = str(uuid4())
        execute_error = None

        with patch.object(
            TrainingScheduler,
            "_run_yolo_training",
            return_value={"model_path": "/path/to/model.pt", "metrics": {}},
        ):
            scheduler = TrainingScheduler()
            scheduler._datasets = mock_datasets_repo
            scheduler._training_tasks = mock_training_tasks_repo

            # Later pipeline stages may fail outside a real training
            # environment; only the initial status update matters here. The
            # exception is kept (not discarded) so that, if the status update
            # never happened, the failure message shows why.
            try:
                scheduler._execute_task(
                    task_id=task_id,
                    config={"model_name": "yolo11n.pt"},
                    dataset_id=dataset_id,
                )
            except Exception as exc:  # deliberate: tolerate env-specific failures
                execute_error = exc

        update_calls = mock_datasets_repo.update_training_status.call_args_list
        assert update_calls, (
            "update_training_status was never called; "
            f"_execute_task raised {execute_error!r}"
        )
        first_call = update_calls[0]
        assert first_call.kwargs["training_status"] == "running"
        assert first_call.kwargs["active_training_task_id"] == task_id
# =============================================================================
# Test Dataset Status Values
# =============================================================================
class TestDatasetStatusValues:
    """Verify the status strings a TrainingDataset is expected to accept."""

    @staticmethod
    def _roundtrip(**columns):
        """Construct a TrainingDataset named 'test' with the given columns."""
        from inference.data.admin_models import TrainingDataset

        return TrainingDataset(name="test", **columns)

    def test_dataset_status_building(self):
        """Dataset can have status 'building'."""
        assert self._roundtrip(status="building").status == "building"

    def test_dataset_status_ready(self):
        """Dataset can have status 'ready'."""
        assert self._roundtrip(status="ready").status == "ready"

    def test_dataset_status_trained(self):
        """Dataset can have status 'trained'."""
        assert self._roundtrip(status="trained").status == "trained"

    def test_dataset_status_failed(self):
        """Dataset can have status 'failed'."""
        assert self._roundtrip(status="failed").status == "failed"

    def test_training_status_values(self):
        """Every training-lifecycle status string is stored unchanged."""
        lifecycle = (
            "pending",
            "scheduled",
            "running",
            "completed",
            "failed",
            "cancelled",
        )
        for candidate in lifecycle:
            stored = self._roundtrip(training_status=candidate).training_status
            assert stored == candidate