This commit is contained in:
Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

View File

@@ -3,7 +3,7 @@ Tests for dataset training status feature.
Tests cover:
1. Database model fields (training_status, active_training_task_id)
2. AdminDB update_dataset_training_status method
2. DatasetRepository update_training_status method
3. API response includes training status fields
4. Scheduler updates dataset status during training lifecycle
"""
@@ -56,12 +56,12 @@ class TestTrainingDatasetModel:
# =============================================================================
# Test AdminDB Methods
# Test DatasetRepository Methods
# =============================================================================
class TestAdminDBDatasetTrainingStatus:
"""Tests for AdminDB.update_dataset_training_status method."""
class TestDatasetRepositoryTrainingStatus:
"""Tests for DatasetRepository.update_training_status method."""
@pytest.fixture
def mock_session(self):
@@ -69,8 +69,8 @@ class TestAdminDBDatasetTrainingStatus:
session = MagicMock()
return session
def test_update_dataset_training_status_sets_status(self, mock_session):
"""update_dataset_training_status should set training_status."""
def test_update_training_status_sets_status(self, mock_session):
"""update_training_status should set training_status."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
@@ -81,13 +81,13 @@ class TestAdminDBDatasetTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
db.update_dataset_training_status(
repo = DatasetRepository()
repo.update_training_status(
dataset_id=str(dataset_id),
training_status="running",
)
@@ -96,8 +96,8 @@ class TestAdminDBDatasetTrainingStatus:
mock_session.add.assert_called_once_with(dataset)
mock_session.commit.assert_called_once()
def test_update_dataset_training_status_sets_task_id(self, mock_session):
"""update_dataset_training_status should set active_training_task_id."""
def test_update_training_status_sets_task_id(self, mock_session):
"""update_training_status should set active_training_task_id."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
@@ -109,13 +109,13 @@ class TestAdminDBDatasetTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
db.update_dataset_training_status(
repo = DatasetRepository()
repo.update_training_status(
dataset_id=str(dataset_id),
training_status="running",
active_training_task_id=str(task_id),
@@ -123,10 +123,10 @@ class TestAdminDBDatasetTrainingStatus:
assert dataset.active_training_task_id == task_id
def test_update_dataset_training_status_updates_main_status_on_complete(
def test_update_training_status_updates_main_status_on_complete(
self, mock_session
):
"""update_dataset_training_status should update main status to 'trained' when completed."""
"""update_training_status should update main status to 'trained' when completed."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
@@ -137,13 +137,13 @@ class TestAdminDBDatasetTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
db.update_dataset_training_status(
repo = DatasetRepository()
repo.update_training_status(
dataset_id=str(dataset_id),
training_status="completed",
update_main_status=True,
@@ -152,10 +152,10 @@ class TestAdminDBDatasetTrainingStatus:
assert dataset.status == "trained"
assert dataset.training_status == "completed"
def test_update_dataset_training_status_clears_task_id_on_complete(
def test_update_training_status_clears_task_id_on_complete(
self, mock_session
):
"""update_dataset_training_status should clear task_id when training completes."""
"""update_training_status should clear task_id when training completes."""
from inference.data.admin_models import TrainingDataset
dataset_id = uuid4()
@@ -169,13 +169,13 @@ class TestAdminDBDatasetTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
db.update_dataset_training_status(
repo = DatasetRepository()
repo.update_training_status(
dataset_id=str(dataset_id),
training_status="completed",
active_training_task_id=None,
@@ -183,18 +183,18 @@ class TestAdminDBDatasetTrainingStatus:
assert dataset.active_training_task_id is None
def test_update_dataset_training_status_handles_missing_dataset(self, mock_session):
"""update_dataset_training_status should handle missing dataset gracefully."""
def test_update_training_status_handles_missing_dataset(self, mock_session):
"""update_training_status should handle missing dataset gracefully."""
mock_session.get.return_value = None
with patch("inference.data.admin_db.get_session_context") as mock_ctx:
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.admin_db import AdminDB
from inference.data.repositories import DatasetRepository
db = AdminDB()
repo = DatasetRepository()
# Should not raise
db.update_dataset_training_status(
repo.update_training_status(
dataset_id=str(uuid4()),
training_status="running",
)
@@ -275,19 +275,24 @@ class TestSchedulerDatasetStatusUpdates:
"""Tests for scheduler updating dataset status during training."""
@pytest.fixture
def mock_db(self):
"""Create mock AdminDB."""
def mock_datasets_repo(self):
"""Create mock DatasetRepository."""
mock = MagicMock()
mock.get_dataset.return_value = MagicMock(
mock.get.return_value = MagicMock(
dataset_id=uuid4(),
name="test-dataset",
dataset_path="/path/to/dataset",
total_images=100,
)
mock.get_pending_training_tasks.return_value = []
return mock
def test_scheduler_sets_running_status_on_task_start(self, mock_db):
@pytest.fixture
def mock_training_tasks_repo(self):
    """Provide a bare MagicMock standing in for the TrainingTaskRepository."""
    # No return values are preconfigured on this mock.
    return MagicMock()
def test_scheduler_sets_running_status_on_task_start(self, mock_datasets_repo, mock_training_tasks_repo):
"""Scheduler should set dataset training_status to 'running' when task starts."""
from inference.web.core.scheduler import TrainingScheduler
@@ -295,7 +300,8 @@ class TestSchedulerDatasetStatusUpdates:
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
scheduler = TrainingScheduler()
scheduler._db = mock_db
scheduler._datasets = mock_datasets_repo
scheduler._training_tasks = mock_training_tasks_repo
task_id = str(uuid4())
dataset_id = str(uuid4())
@@ -311,8 +317,8 @@ class TestSchedulerDatasetStatusUpdates:
pass # Expected to fail in test environment
# Check that training status was updated to running
mock_db.update_dataset_training_status.assert_called()
first_call = mock_db.update_dataset_training_status.call_args_list[0]
mock_datasets_repo.update_training_status.assert_called()
first_call = mock_datasets_repo.update_training_status.call_args_list[0]
assert first_call.kwargs["training_status"] == "running"
assert first_call.kwargs["active_training_task_id"] == task_id