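# Listing metadata captured from the code viewer (not part of the module):
#   Files: invoice-master-poc-v2/tests/integration/repositories/test_training_task_repo_integration.py
#   Timestamp: 2026-02-01 22:40:41 +01:00
#
#   Size: 365 lines, 12 KiB
#   Language: Python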

"""
Training Task Repository Integration Tests
Tests TrainingTaskRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from inference.data.repositories.training_task_repository import TrainingTaskRepository
class TestTrainingTaskCreate:
    """Exercise ``TrainingTaskRepository.create`` and read the result back."""

    def test_create_training_task(self, patched_session, admin_token):
        """Create a fully specified task and verify the stored fields."""
        repository = TrainingTaskRepository()
        training_config = {"epochs": 100, "batch_size": 16}

        created_id = repository.create(
            admin_token=admin_token.token,
            name="Test Training Task",
            task_type="train",
            description="Integration test training task",
            config=training_config,
        )
        assert created_id is not None

        stored = repository.get(created_id)
        assert stored is not None
        assert stored.name == "Test Training Task"
        assert stored.task_type == "train"
        assert stored.status == "pending"
        assert stored.config["epochs"] == 100

    def test_create_scheduled_task(self, patched_session, admin_token):
        """Create a task with ``scheduled_at`` set and expect status ``scheduled``."""
        repository = TrainingTaskRepository()
        one_hour_ahead = datetime.now(timezone.utc) + timedelta(hours=1)

        created_id = repository.create(
            admin_token=admin_token.token,
            name="Scheduled Task",
            scheduled_at=one_hour_ahead,
        )

        stored = repository.get(created_id)
        assert stored is not None
        assert stored.status == "scheduled"
        assert stored.scheduled_at is not None

    def test_create_recurring_task(self, patched_session, admin_token):
        """Create a recurring task and verify the cron settings round-trip."""
        repository = TrainingTaskRepository()

        created_id = repository.create(
            admin_token=admin_token.token,
            name="Recurring Task",
            cron_expression="0 2 * * *",
            is_recurring=True,
        )

        stored = repository.get(created_id)
        assert stored is not None
        assert stored.is_recurring is True
        assert stored.cron_expression == "0 2 * * *"

    def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
        """Create a task with a ``dataset_id`` and verify the stored reference."""
        repository = TrainingTaskRepository()
        # The repository is given the dataset id in string form.
        dataset_ref = str(sample_dataset.dataset_id)

        created_id = repository.create(
            admin_token=admin_token.token,
            name="Dataset Training Task",
            dataset_id=dataset_ref,
        )

        stored = repository.get(created_id)
        assert stored is not None
        assert stored.dataset_id == sample_dataset.dataset_id
class TestTrainingTaskRead:
    """Tests for training task retrieval."""

    def test_get_task_by_id(self, patched_session, sample_training_task):
        """Fetch an existing task by its ID and check the same task comes back."""
        repo = TrainingTaskRepository()

        task = repo.get(str(sample_training_task.task_id))

        assert task is not None
        assert task.task_id == sample_training_task.task_id

    def test_get_nonexistent_task(self, patched_session):
        """Looking up a random, never-stored UUID yields ``None``."""
        repo = TrainingTaskRepository()

        task = repo.get(str(uuid4()))

        assert task is None

    def test_get_paginated_tasks(self, patched_session, admin_token):
        """With five tasks stored, a page of size two reports total 5 and 2 rows."""
        repo = TrainingTaskRepository()
        # Create multiple tasks
        for i in range(5):
            repo.create(admin_token=admin_token.token, name=f"Task {i}")

        tasks, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(tasks) == 2

    def test_get_paginated_with_status_filter(self, patched_session, admin_token):
        """Filtering by ``status="running"`` returns only the running task."""
        repo = TrainingTaskRepository()
        # Create tasks with different statuses
        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")
        repo.create(admin_token=admin_token.token, name="Pending Task")

        tasks, total = repo.get_paginated(status="running")

        assert total == 1
        assert tasks[0].status == "running"

    def test_get_pending_tasks(self, patched_session, admin_token):
        """Pending and already-due scheduled tasks are returned; future ones are not."""
        repo = TrainingTaskRepository()
        # Create pending task
        repo.create(admin_token=admin_token.token, name="Ready Task")
        # Create scheduled task in the past (should be included)
        past_time = datetime.now(timezone.utc) - timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Past Scheduled Task",
            scheduled_at=past_time,
        )
        # Create scheduled task in the future (should not be included)
        future_time = datetime.now(timezone.utc) + timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Future Scheduled Task",
            scheduled_at=future_time,
        )

        pending = repo.get_pending()

        # Should include pending and past scheduled, not future scheduled
        assert len(pending) >= 2
        names = [t.name for t in pending]
        assert "Ready Task" in names
        assert "Past Scheduled Task" in names
        # Without this negative check the test would still pass if the
        # repository ignored ``scheduled_at`` and returned every task.
        assert "Future Scheduled Task" not in names

    def test_get_running_task(self, patched_session, admin_token):
        """After a task is moved to ``running``, ``get_running`` returns a running task."""
        repo = TrainingTaskRepository()
        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")

        running = repo.get_running()

        assert running is not None
        assert running.status == "running"

    def test_get_running_task_none(self, patched_session, admin_token):
        """With only a pending task stored, ``get_running`` returns ``None``."""
        repo = TrainingTaskRepository()
        repo.create(admin_token=admin_token.token, name="Pending Task")

        running = repo.get_running()

        assert running is None
class TestTrainingTaskUpdate:
    """Tests for training task updates."""

    def test_update_status_to_running(self, patched_session, sample_training_task):
        """Moving a task to ``running`` stores the status and a start timestamp."""
        repo = TrainingTaskRepository()
        task_id = str(sample_training_task.task_id)

        repo.update_status(task_id, "running")

        task = repo.get(task_id)
        assert task is not None
        assert task.status == "running"
        assert task.started_at is not None

    def test_update_status_to_completed(self, patched_session, sample_training_task):
        """Completing a task stores metrics, model path and a completion timestamp."""
        repo = TrainingTaskRepository()
        task_id = str(sample_training_task.task_id)
        metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}

        repo.update_status(
            task_id,
            "completed",
            result_metrics=metrics,
            model_path="/models/trained_model.pt",
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.status == "completed"
        assert task.completed_at is not None
        assert task.result_metrics["mAP"] == 0.92
        assert task.model_path == "/models/trained_model.pt"

    def test_update_status_to_failed(self, patched_session, sample_training_task):
        """Failing a task stores the error message and a completion timestamp."""
        repo = TrainingTaskRepository()
        task_id = str(sample_training_task.task_id)

        repo.update_status(
            task_id,
            "failed",
            error_message="CUDA out of memory",
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.status == "failed"
        assert task.completed_at is not None
        assert "CUDA out of memory" in task.error_message

    def test_cancel_pending_task(self, patched_session, sample_training_task):
        """Cancelling the sample task returns ``True`` and stores ``cancelled``."""
        repo = TrainingTaskRepository()
        task_id = str(sample_training_task.task_id)

        result = repo.cancel(task_id)

        assert result is True
        task = repo.get(task_id)
        assert task is not None
        assert task.status == "cancelled"

    def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
        """Cancelling a running task returns ``False`` and leaves it running."""
        repo = TrainingTaskRepository()
        task_id = str(sample_training_task.task_id)
        repo.update_status(task_id, "running")

        result = repo.cancel(task_id)

        assert result is False
        task = repo.get(task_id)
        # Guard like the sibling tests so a missing row fails as an
        # assertion rather than an AttributeError on ``None.status``.
        assert task is not None
        assert task.status == "running"
class TestTrainingLogs:
    """Exercise ``add_log`` / ``get_logs`` on ``TrainingTaskRepository``."""

    def test_add_log_entry(self, patched_session, sample_training_task):
        """Write one log entry and read it back with its level and message."""
        repository = TrainingTaskRepository()
        task_key = str(sample_training_task.task_id)

        repository.add_log(
            task_key,
            level="INFO",
            message="Starting training...",
            details={"epoch": 1, "batch": 0},
        )

        entries = repository.get_logs(task_key)
        assert len(entries) == 1
        first_entry = entries[0]
        assert first_entry.level == "INFO"
        assert first_entry.message == "Starting training..."

    def test_add_multiple_log_entries(self, patched_session, sample_training_task):
        """Write five log entries and expect all five to be returned."""
        repository = TrainingTaskRepository()
        task_key = str(sample_training_task.task_id)

        for epoch in range(5):
            repository.add_log(
                task_key,
                level="INFO",
                message=f"Epoch {epoch} completed",
                details={"epoch": epoch, "loss": 0.5 - epoch * 0.1},
            )

        entries = repository.get_logs(task_key)
        assert len(entries) == 5

    def test_get_logs_pagination(self, patched_session, sample_training_task):
        """Write ten entries and expect two consecutive pages of five each."""
        repository = TrainingTaskRepository()
        task_key = str(sample_training_task.task_id)

        for index in range(10):
            repository.add_log(
                task_key,
                level="INFO",
                message=f"Log entry {index}",
            )

        first_page = repository.get_logs(task_key, limit=5, offset=0)
        assert len(first_page) == 5

        second_page = repository.get_logs(task_key, limit=5, offset=5)
        assert len(second_page) == 5
class TestDocumentLinks:
    """Exercise the task/document link methods of ``TrainingTaskRepository``."""

    def test_create_document_link(self, patched_session, sample_training_task, sample_document):
        """Create one link with a snapshot and verify the returned object."""
        repository = TrainingTaskRepository()
        snapshot = {"count": 5, "verified": 3}

        created_link = repository.create_document_link(
            task_id=sample_training_task.task_id,
            document_id=sample_document.document_id,
            annotation_snapshot=snapshot,
        )

        assert created_link is not None
        assert created_link.task_id == sample_training_task.task_id
        assert created_link.document_id == sample_document.document_id
        assert created_link.annotation_snapshot["count"] == 5

    def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
        """Link the first three documents to a task and expect three links back."""
        repository = TrainingTaskRepository()
        owning_task_id = sample_training_task.task_id

        for document in multiple_documents[:3]:
            repository.create_document_link(
                task_id=owning_task_id,
                document_id=document.document_id,
            )

        found_links = repository.get_document_links(owning_task_id)
        assert len(found_links) == 3

    def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
        """Link one document to two tasks and expect two links for that document."""
        repository = TrainingTaskRepository()
        # Create multiple tasks using the same document
        first_id = repository.create(admin_token=admin_token.token, name="Task 1")
        second_id = repository.create(admin_token=admin_token.token, name="Task 2")

        # Each task is re-read so the link receives the stored ``task_id``.
        for created_id in (first_id, second_id):
            repository.create_document_link(
                task_id=repository.get(created_id).task_id,
                document_id=sample_document.document_id,
            )

        found_links = repository.get_document_training_tasks(sample_document.document_id)
        assert len(found_links) == 2