This commit is contained in:
Yaojia Wang
2026-02-01 00:08:40 +01:00
parent 33ada0350d
commit a516de4320
90 changed files with 11642 additions and 398 deletions

View File

@@ -0,0 +1,363 @@
"""
Tests for dataset training status feature.
Tests cover:
1. Database model fields (training_status, active_training_task_id)
2. AdminDB update_dataset_training_status method
3. API response includes training status fields
4. Scheduler updates dataset status during training lifecycle
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
# =============================================================================
# Test Database Model
# =============================================================================
class TestTrainingDatasetModel:
"""Tests for TrainingDataset model fields."""
def test_training_dataset_has_training_status_field(self):
"""TrainingDataset model should have training_status field."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(
name="test-dataset",
training_status="running",
)
assert dataset.training_status == "running"
def test_training_dataset_has_active_training_task_id_field(self):
"""TrainingDataset model should have active_training_task_id field."""
from inference.data.admin_models import TrainingDataset
task_id = uuid4()
dataset = TrainingDataset(
name="test-dataset",
active_training_task_id=task_id,
)
assert dataset.active_training_task_id == task_id
def test_training_dataset_defaults(self):
"""TrainingDataset should have correct defaults for new fields."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test-dataset")
assert dataset.training_status is None
assert dataset.active_training_task_id is None
# =============================================================================
# Test AdminDB Methods
# =============================================================================
class TestAdminDBDatasetTrainingStatus:
"""Tests for AdminDB.update_dataset_training_status method."""
@pytest.fixture
def mock_session(self):
"""Create mock database session."""
session = MagicMock()
return session
def test_update_dataset_training_status_sets_status(self, mock_session):
"""update_dataset_training_status should set training_status."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="running",
)
assert dataset.training_status == "running"
mock_session.add.assert_called_once_with(dataset)
mock_session.commit.assert_called_once()
def test_update_dataset_training_status_sets_task_id(self, mock_session):
"""update_dataset_training_status should set active_training_task_id."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
task_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="running",
active_training_task_id=str(task_id),
)
assert dataset.active_training_task_id == task_id
def test_update_dataset_training_status_updates_main_status_on_complete(
self, mock_session
):
"""update_dataset_training_status should update main status to 'trained' when completed."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="completed",
update_main_status=True,
)
assert dataset.status == "trained"
assert dataset.training_status == "completed"
def test_update_dataset_training_status_clears_task_id_on_complete(
self, mock_session
):
"""update_dataset_training_status should clear task_id when training completes."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
task_id = uuid4()
dataset = TrainingDataset(
dataset_id=dataset_id,
name="test-dataset",
status="ready",
training_status="running",
active_training_task_id=task_id,
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
db.update_dataset_training_status(
dataset_id=str(dataset_id),
training_status="completed",
active_training_task_id=None,
)
assert dataset.active_training_task_id is None
def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
"""update_dataset_training_status should handle missing dataset gracefully."""
mock_session.get.return_value = None
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
db = AdminDB()
# Should not raise
db.update_dataset_training_status(
dataset_id=str(uuid4()),
training_status="running",
)
mock_session.add.assert_not_called()
mock_session.commit.assert_not_called()
# =============================================================================
# Test API Response
# =============================================================================
class TestDatasetDetailResponseTrainingStatus:
"""Tests for DatasetDetailResponse including training status fields."""
def test_dataset_detail_response_includes_training_status(self):
"""DatasetDetailResponse schema should include training_status field."""
from inference.web.schemas.admin.datasets import DatasetDetailResponse
response = DatasetDetailResponse(
dataset_id=str(uuid4()),
name="test-dataset",
description=None,
status="ready",
training_status="running",
active_training_task_id=str(uuid4()),
train_ratio=0.8,
val_ratio=0.1,
seed=42,
total_documents=10,
total_images=15,
total_annotations=100,
dataset_path="/path/to/dataset",
error_message=None,
documents=[],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
assert response.training_status == "running"
assert response.active_training_task_id is not None
def test_dataset_detail_response_allows_null_training_status(self):
"""DatasetDetailResponse should allow null training_status."""
from inference.web.schemas.admin.datasets import DatasetDetailResponse
response = DatasetDetailResponse(
dataset_id=str(uuid4()),
name="test-dataset",
description=None,
status="ready",
training_status=None,
active_training_task_id=None,
train_ratio=0.8,
val_ratio=0.1,
seed=42,
total_documents=10,
total_images=15,
total_annotations=100,
dataset_path=None,
error_message=None,
documents=[],
created_at=datetime.utcnow(),
updated_at=datetime.utcnow(),
)
assert response.training_status is None
assert response.active_training_task_id is None
# =============================================================================
# Test Scheduler Training Status Updates
# =============================================================================
class TestSchedulerDatasetStatusUpdates:
"""Tests for scheduler updating dataset status during training."""
@pytest.fixture
def mock_db(self):
"""Create mock AdminDB."""
mock = MagicMock()
mock.get_dataset.return_value = MagicMock(
dataset_id=uuid4(),
name="test-dataset",
dataset_path="/path/to/dataset",
total_images=100,
)
mock.get_pending_training_tasks.return_value = []
return mock
def test_scheduler_sets_running_status_on_task_start(self, mock_db):
"""Scheduler should set dataset training_status to 'running' when task starts."""
from inference.web.core.scheduler import TrainingScheduler
with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
scheduler = TrainingScheduler()
scheduler._db = mock_db
task_id = str(uuid4())
dataset_id = str(uuid4())
# Execute task (will fail but we check the status update call)
try:
scheduler._execute_task(
task_id=task_id,
config={"model_name": "yolo11n.pt"},
dataset_id=dataset_id,
)
except Exception:
pass # Expected to fail in test environment
# Check that training status was updated to running
mock_db.update_dataset_training_status.assert_called()
first_call = mock_db.update_dataset_training_status.call_args_list[0]
assert first_call.kwargs["training_status"] == "running"
assert first_call.kwargs["active_training_task_id"] == task_id
# =============================================================================
# Test Dataset Status Values
# =============================================================================
class TestDatasetStatusValues:
"""Tests for valid dataset status values."""
def test_dataset_status_building(self):
"""Dataset can have status 'building'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="building")
assert dataset.status == "building"
def test_dataset_status_ready(self):
"""Dataset can have status 'ready'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="ready")
assert dataset.status == "ready"
def test_dataset_status_trained(self):
"""Dataset can have status 'trained'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="trained")
assert dataset.status == "trained"
def test_dataset_status_failed(self):
"""Dataset can have status 'failed'."""
from inference.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="failed")
assert dataset.status == "failed"
def test_training_status_values(self):
"""Training status can have various values."""
from inference.data.admin_models import TrainingDataset
valid_statuses = ["pending", "scheduled", "running", "completed", "failed", "cancelled"]
for status in valid_statuses:
dataset = TrainingDataset(name="test", training_status=status)
assert dataset.training_status == status