"""
Training Task Repository Integration Tests

Tests TrainingTaskRepository with real database operations.

NOTE(review): the fixtures used below (patched_session, admin_token,
sample_dataset, sample_training_task, sample_document, multiple_documents)
are defined outside this file, presumably in a conftest.py; patched_session
appears to point the repository at a test database -- confirm there.
"""

from datetime import datetime, timezone, timedelta
from uuid import uuid4

import pytest

from inference.data.repositories.training_task_repository import TrainingTaskRepository


class TestTrainingTaskCreate:
    """Tests for training task creation."""

    def test_create_training_task(self, patched_session, admin_token):
        """Test creating a training task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Test Training Task",
            task_type="train",
            description="Integration test training task",
            config={"epochs": 100, "batch_size": 16},
        )

        assert task_id is not None

        task = repo.get(task_id)
        assert task is not None
        assert task.name == "Test Training Task"
        assert task.task_type == "train"
        assert task.status == "pending"
        assert task.config["epochs"] == 100

    def test_create_scheduled_task(self, patched_session, admin_token):
        """Test creating a scheduled training task."""
        repo = TrainingTaskRepository()

        scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
        task_id = repo.create(
            admin_token=admin_token.token,
            name="Scheduled Task",
            scheduled_at=scheduled_time,
        )

        task = repo.get(task_id)
        assert task is not None
        # Supplying scheduled_at puts the task into "scheduled" rather than "pending".
        assert task.status == "scheduled"
        assert task.scheduled_at is not None

    def test_create_recurring_task(self, patched_session, admin_token):
        """Test creating a recurring training task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Recurring Task",
            cron_expression="0 2 * * *",
            is_recurring=True,
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.is_recurring is True
        assert task.cron_expression == "0 2 * * *"

    def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
        """Test creating task linked to a dataset."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Dataset Training Task",
            dataset_id=str(sample_dataset.dataset_id),
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.dataset_id == sample_dataset.dataset_id


class TestTrainingTaskRead:
    """Tests for training task retrieval."""

    def test_get_task_by_id(self, patched_session, sample_training_task):
        """Test getting task by ID."""
        repo = TrainingTaskRepository()

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.task_id == sample_training_task.task_id

    def test_get_nonexistent_task(self, patched_session):
        """Test getting task that doesn't exist."""
        repo = TrainingTaskRepository()

        task = repo.get(str(uuid4()))
        assert task is None

    def test_get_paginated_tasks(self, patched_session, admin_token):
        """Test paginated task listing."""
        repo = TrainingTaskRepository()

        # Create multiple tasks
        for i in range(5):
            repo.create(admin_token=admin_token.token, name=f"Task {i}")

        tasks, total = repo.get_paginated(limit=2, offset=0)

        # total counts every matching row; the page itself is capped by limit.
        assert total == 5
        assert len(tasks) == 2

    def test_get_paginated_with_status_filter(self, patched_session, admin_token):
        """Test filtering tasks by status."""
        repo = TrainingTaskRepository()

        # Create tasks with different statuses
        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")

        repo.create(admin_token=admin_token.token, name="Pending Task")

        tasks, total = repo.get_paginated(status="running")

        assert total == 1
        # Check the page length too, so a count/page mismatch fails clearly
        # instead of raising IndexError on tasks[0].
        assert len(tasks) == 1
        assert tasks[0].status == "running"

    def test_get_pending_tasks(self, patched_session, admin_token):
        """Test that get_pending returns runnable tasks only.

        Pending tasks and tasks scheduled in the past must be returned;
        tasks scheduled in the future must not be.
        """
        repo = TrainingTaskRepository()

        # Create pending task
        repo.create(admin_token=admin_token.token, name="Ready Task")

        # Create scheduled task in the past (should be included)
        past_time = datetime.now(timezone.utc) - timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Past Scheduled Task",
            scheduled_at=past_time,
        )

        # Create scheduled task in the future (should not be included)
        future_time = datetime.now(timezone.utc) + timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Future Scheduled Task",
            scheduled_at=future_time,
        )

        pending = repo.get_pending()

        # Should include pending and past scheduled, not future scheduled
        assert len(pending) >= 2
        names = [t.name for t in pending]
        assert "Ready Task" in names
        assert "Past Scheduled Task" in names
        # Without this check the future-exclusion rule would be untested.
        assert "Future Scheduled Task" not in names

    def test_get_running_task(self, patched_session, admin_token):
        """Test getting currently running task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")

        running = repo.get_running()
        assert running is not None
        assert running.status == "running"

    def test_get_running_task_none(self, patched_session, admin_token):
        """Test getting running task when none is running."""
        repo = TrainingTaskRepository()

        repo.create(admin_token=admin_token.token, name="Pending Task")

        running = repo.get_running()
        assert running is None


class TestTrainingTaskUpdate:
    """Tests for training task updates."""

    def test_update_status_to_running(self, patched_session, sample_training_task):
        """Test updating task status to running."""
        repo = TrainingTaskRepository()

        repo.update_status(str(sample_training_task.task_id), "running")

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "running"
        assert task.started_at is not None

    def test_update_status_to_completed(self, patched_session, sample_training_task):
        """Test updating task status to completed."""
        repo = TrainingTaskRepository()

        metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}
        repo.update_status(
            str(sample_training_task.task_id),
            "completed",
            result_metrics=metrics,
            model_path="/models/trained_model.pt",
        )

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "completed"
        assert task.completed_at is not None
        assert task.result_metrics["mAP"] == 0.92
        assert task.model_path == "/models/trained_model.pt"

    def test_update_status_to_failed(self, patched_session, sample_training_task):
        """Test updating task status to failed with error message."""
        repo = TrainingTaskRepository()

        repo.update_status(
            str(sample_training_task.task_id),
            "failed",
            error_message="CUDA out of memory",
        )

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "failed"
        # A failed task is still stamped with completed_at.
        assert task.completed_at is not None
        assert "CUDA out of memory" in task.error_message

    def test_cancel_pending_task(self, patched_session, sample_training_task):
        """Test cancelling a pending task."""
        repo = TrainingTaskRepository()

        result = repo.cancel(str(sample_training_task.task_id))
        assert result is True

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "cancelled"

    def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
        """Test that running task cannot be cancelled."""
        repo = TrainingTaskRepository()

        repo.update_status(str(sample_training_task.task_id), "running")

        result = repo.cancel(str(sample_training_task.task_id))
        assert result is False

        task = repo.get(str(sample_training_task.task_id))
        # Match the sibling tests: fail with a clear assertion rather than
        # an AttributeError on None.
        assert task is not None
        assert task.status == "running"


class TestTrainingLogs:
    """Tests for training log management."""

    def test_add_log_entry(self, patched_session, sample_training_task):
        """Test adding a training log entry."""
        repo = TrainingTaskRepository()

        repo.add_log(
            str(sample_training_task.task_id),
            level="INFO",
            message="Starting training...",
            details={"epoch": 1, "batch": 0},
        )

        logs = repo.get_logs(str(sample_training_task.task_id))
        assert len(logs) == 1
        assert logs[0].level == "INFO"
        assert logs[0].message == "Starting training..."

    def test_add_multiple_log_entries(self, patched_session, sample_training_task):
        """Test adding multiple log entries."""
        repo = TrainingTaskRepository()

        for i in range(5):
            repo.add_log(
                str(sample_training_task.task_id),
                level="INFO",
                message=f"Epoch {i} completed",
                details={"epoch": i, "loss": 0.5 - i * 0.1},
            )

        logs = repo.get_logs(str(sample_training_task.task_id))
        assert len(logs) == 5

    def test_get_logs_pagination(self, patched_session, sample_training_task):
        """Test paginated log retrieval.

        The two pages must be disjoint and together cover every entry.
        Checking page sizes alone would pass even if offset were ignored.
        """
        repo = TrainingTaskRepository()

        for i in range(10):
            repo.add_log(
                str(sample_training_task.task_id),
                level="INFO",
                message=f"Log entry {i}",
            )

        logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0)
        assert len(logs) == 5

        logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5)
        assert len(logs_page2) == 5

        # NOTE(review): assumes get_logs uses a stable ordering across calls,
        # so consecutive pages do not overlap.
        page1_messages = {log.message for log in logs}
        page2_messages = {log.message for log in logs_page2}
        assert page1_messages.isdisjoint(page2_messages)
        assert page1_messages | page2_messages == {f"Log entry {i}" for i in range(10)}


class TestDocumentLinks:
    """Tests for training document link management."""

    def test_create_document_link(self, patched_session, sample_training_task, sample_document):
        """Test creating a document link."""
        repo = TrainingTaskRepository()

        link = repo.create_document_link(
            task_id=sample_training_task.task_id,
            document_id=sample_document.document_id,
            annotation_snapshot={"count": 5, "verified": 3},
        )

        assert link is not None
        assert link.task_id == sample_training_task.task_id
        assert link.document_id == sample_document.document_id
        assert link.annotation_snapshot["count"] == 5

    def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
        """Test getting all document links for a task."""
        repo = TrainingTaskRepository()

        for doc in multiple_documents[:3]:
            repo.create_document_link(
                task_id=sample_training_task.task_id,
                document_id=doc.document_id,
            )

        links = repo.get_document_links(sample_training_task.task_id)
        assert len(links) == 3

    def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
        """Test getting training tasks that used a document."""
        repo = TrainingTaskRepository()

        # Create multiple tasks using the same document
        task1_id = repo.create(admin_token=admin_token.token, name="Task 1")
        task2_id = repo.create(admin_token=admin_token.token, name="Task 2")

        # repo.create returns an ID that repo.get accepts; create_document_link
        # is given the task's own task_id attribute, as the other tests do.
        task1 = repo.get(task1_id)
        task2 = repo.get(task2_id)
        assert task1 is not None
        assert task2 is not None

        for task in (task1, task2):
            repo.create_document_link(
                task_id=task.task_id,
                document_id=sample_document.document_id,
            )

        links = repo.get_document_training_tasks(sample_document.document_id)
        assert len(links) == 2
        # Both returned rows must reference the two tasks created above.
        assert {link.task_id for link in links} == {task1.task_id, task2.task_id}