""" Tests for TrainingTaskRepository 100% coverage tests for training task management. """ import pytest from datetime import datetime, timezone, timedelta from unittest.mock import MagicMock, patch from uuid import uuid4, UUID from inference.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink from inference.data.repositories.training_task_repository import TrainingTaskRepository class TestTrainingTaskRepository: """Tests for TrainingTaskRepository.""" @pytest.fixture def sample_task(self) -> TrainingTask: """Create a sample training task for testing.""" return TrainingTask( task_id=uuid4(), admin_token="admin-token", name="Test Training Task", task_type="train", description="A test training task", status="pending", config={"epochs": 100, "batch_size": 16}, created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) @pytest.fixture def sample_log(self) -> TrainingLog: """Create a sample training log for testing.""" return TrainingLog( log_id=uuid4(), task_id=uuid4(), level="INFO", message="Training started", details={"epoch": 1}, created_at=datetime.now(timezone.utc), ) @pytest.fixture def sample_link(self) -> TrainingDocumentLink: """Create a sample training document link for testing.""" return TrainingDocumentLink( link_id=uuid4(), task_id=uuid4(), document_id=uuid4(), annotation_snapshot={"annotations": []}, created_at=datetime.now(timezone.utc), ) @pytest.fixture def repo(self) -> TrainingTaskRepository: """Create a TrainingTaskRepository instance.""" return TrainingTaskRepository() # ========================================================================= # create() tests # ========================================================================= def test_create_returns_task_id(self, repo): """Test create returns task ID.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.create( admin_token="admin-token", name="Test Task", ) assert result is not None mock_session.add.assert_called_once() mock_session.flush.assert_called_once() def test_create_with_all_params(self, repo): """Test create with all parameters.""" scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1) with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.create( admin_token="admin-token", name="Test Task", task_type="finetune", description="Full test", config={"epochs": 50}, scheduled_at=scheduled_time, cron_expression="0 0 * * *", is_recurring=True, dataset_id=str(uuid4()), ) assert result is not None added_task = mock_session.add.call_args[0][0] assert added_task.task_type == "finetune" assert added_task.description == "Full test" assert added_task.is_recurring is True assert added_task.status == "scheduled" # because scheduled_at is set def test_create_pending_status_when_not_scheduled(self, repo): """Test create sets pending status when no scheduled_at.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.create( admin_token="admin-token", name="Test Task", ) added_task = mock_session.add.call_args[0][0] assert added_task.status == "pending" def test_create_scheduled_status_when_scheduled(self, repo): """Test create sets scheduled status when scheduled_at is provided.""" scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1) with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.create( admin_token="admin-token", name="Test Task", scheduled_at=scheduled_time, ) added_task = mock_session.add.call_args[0][0] assert added_task.status == "scheduled" # ========================================================================= # get() tests # ========================================================================= def test_get_returns_task(self, repo, sample_task): """Test get returns task when exists.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get(str(sample_task.task_id)) assert result is not None assert result.name == "Test Training Task" mock_session.expunge.assert_called_once() def test_get_returns_none_when_not_found(self, repo): """Test get returns None when task not found.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get(str(uuid4())) assert result is None mock_session.expunge.assert_not_called() # ========================================================================= # get_by_token() tests # ========================================================================= def test_get_by_token_returns_task(self, repo, sample_task): """Test get_by_token returns task (delegates to get).""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_by_token(str(sample_task.task_id), "admin-token") assert result is not None def test_get_by_token_without_token_param(self, repo, sample_task): """Test get_by_token works without token parameter.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_by_token(str(sample_task.task_id)) assert result is not None # ========================================================================= # get_paginated() tests # ========================================================================= def test_get_paginated_returns_tasks_and_total(self, repo, sample_task): """Test get_paginated returns list of tasks and total count.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 1 mock_session.exec.return_value.all.return_value = [sample_task] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) tasks, total = repo.get_paginated() assert len(tasks) == 1 assert total == 1 def test_get_paginated_with_status_filter(self, repo, sample_task): """Test get_paginated filters by status.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 1 mock_session.exec.return_value.all.return_value = [sample_task] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) tasks, total = repo.get_paginated(status="pending") assert len(tasks) == 1 def test_get_paginated_with_pagination(self, repo, sample_task): """Test get_paginated with limit and offset.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 50 mock_session.exec.return_value.all.return_value = [sample_task] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) tasks, total = repo.get_paginated(limit=10, offset=20) assert total == 50 def test_get_paginated_empty_results(self, repo): """Test get_paginated with no results.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 0 mock_session.exec.return_value.all.return_value = [] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) tasks, total = repo.get_paginated() assert tasks == [] assert total == 0 # ========================================================================= # get_pending() tests # ========================================================================= def test_get_pending_returns_pending_tasks(self, repo, sample_task): """Test get_pending returns pending and scheduled tasks.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [sample_task] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_pending() assert len(result) == 1 def test_get_pending_returns_empty_list(self, repo): """Test get_pending returns empty list when no pending tasks.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_pending() assert result == [] # ========================================================================= # update_status() tests # ========================================================================= def test_update_status_updates_task(self, repo, sample_task): """Test update_status updates task status.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status(str(sample_task.task_id), "running") assert sample_task.status == "running" def test_update_status_sets_started_at_for_running(self, repo, sample_task): """Test update_status sets started_at when status is running.""" sample_task.started_at = None with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status(str(sample_task.task_id), "running") assert sample_task.started_at is not None def test_update_status_sets_completed_at_for_completed(self, repo, sample_task): """Test update_status sets completed_at when status is completed.""" sample_task.completed_at = None with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status(str(sample_task.task_id), "completed") assert sample_task.completed_at is not None def test_update_status_sets_completed_at_for_failed(self, repo, sample_task): """Test update_status sets completed_at when status is failed.""" sample_task.completed_at = None with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status(str(sample_task.task_id), "failed", error_message="Error occurred") assert sample_task.completed_at is not None assert sample_task.error_message == "Error occurred" def test_update_status_with_result_metrics(self, repo, sample_task): """Test update_status with result metrics.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status( str(sample_task.task_id), "completed", result_metrics={"mAP": 0.95}, ) assert sample_task.result_metrics == {"mAP": 0.95} def test_update_status_with_model_path(self, repo, sample_task): """Test update_status with model path.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status( str(sample_task.task_id), "completed", model_path="/path/to/model.pt", ) assert sample_task.model_path == "/path/to/model.pt" def test_update_status_not_found(self, repo): """Test update_status does nothing when task not found.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status(str(uuid4()), "running") mock_session.add.assert_not_called() # ========================================================================= # cancel() tests # ========================================================================= def test_cancel_returns_true_for_pending(self, repo, sample_task): """Test cancel returns True for pending task.""" sample_task.status = "pending" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.cancel(str(sample_task.task_id)) assert result is True assert sample_task.status == "cancelled" def test_cancel_returns_true_for_scheduled(self, repo, sample_task): """Test cancel returns True for scheduled task.""" sample_task.status = "scheduled" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.cancel(str(sample_task.task_id)) assert result is True assert sample_task.status == "cancelled" def test_cancel_returns_false_for_running(self, repo, sample_task): """Test cancel returns False for running task.""" sample_task.status = "running" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_task mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.cancel(str(sample_task.task_id)) assert result is False def test_cancel_returns_false_when_not_found(self, repo): """Test cancel returns False when task not found.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.cancel(str(uuid4())) assert result is False # ========================================================================= # add_log() tests # ========================================================================= def test_add_log_creates_log_entry(self, repo): """Test add_log creates a log entry.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.add_log( task_id=str(uuid4()), level="INFO", message="Training started", ) mock_session.add.assert_called_once() added_log = mock_session.add.call_args[0][0] assert added_log.level == "INFO" assert added_log.message == "Training started" def test_add_log_with_details(self, repo): """Test add_log with details.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.add_log( task_id=str(uuid4()), level="DEBUG", message="Epoch complete", details={"epoch": 5, "loss": 0.05}, ) added_log = mock_session.add.call_args[0][0] assert added_log.details == {"epoch": 5, "loss": 0.05} # ========================================================================= # get_logs() tests # ========================================================================= def test_get_logs_returns_list(self, repo, sample_log): """Test get_logs returns list of logs.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [sample_log] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_logs(str(sample_log.task_id)) assert len(result) == 1 assert result[0].level == "INFO" def test_get_logs_with_pagination(self, repo, sample_log): """Test get_logs with limit and offset.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [sample_log] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_logs(str(sample_log.task_id), limit=50, offset=10) assert len(result) == 1 def test_get_logs_returns_empty_list(self, repo): """Test get_logs returns empty list when no logs.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_logs(str(uuid4())) assert result == [] # ========================================================================= # create_document_link() tests # ========================================================================= def test_create_document_link_returns_link(self, repo): """Test create_document_link returns created link.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) task_id = uuid4() document_id = uuid4() result = repo.create_document_link( task_id=task_id, document_id=document_id, ) mock_session.add.assert_called_once() mock_session.commit.assert_called_once() def test_create_document_link_with_snapshot(self, repo): """Test create_document_link with annotation snapshot.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) snapshot = {"annotations": [{"class_name": "invoice_number"}]} repo.create_document_link( task_id=uuid4(), document_id=uuid4(), annotation_snapshot=snapshot, ) added_link = mock_session.add.call_args[0][0] assert added_link.annotation_snapshot == snapshot # ========================================================================= # get_document_links() tests # ========================================================================= def test_get_document_links_returns_list(self, repo, sample_link): """Test get_document_links returns list of links.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [sample_link] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_document_links(sample_link.task_id) assert len(result) == 1 def test_get_document_links_returns_empty_list(self, repo): """Test get_document_links returns empty list when no links.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_document_links(uuid4()) assert result == [] # ========================================================================= # get_document_training_tasks() tests # ========================================================================= def test_get_document_training_tasks_returns_list(self, repo, sample_link): """Test get_document_training_tasks returns list of links.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [sample_link] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_document_training_tasks(sample_link.document_id) assert len(result) == 1 def test_get_document_training_tasks_returns_empty_list(self, repo): """Test get_document_training_tasks returns empty list when no links.""" with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_document_training_tasks(uuid4()) assert result == []