Add more tests
This commit is contained in:
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
Training Task Repository Integration Tests
|
||||
|
||||
Tests TrainingTaskRepository with real database operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
|
||||
|
||||
class TestTrainingTaskCreate:
|
||||
"""Tests for training task creation."""
|
||||
|
||||
def test_create_training_task(self, patched_session, admin_token):
|
||||
"""Test creating a training task."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
task_id = repo.create(
|
||||
admin_token=admin_token.token,
|
||||
name="Test Training Task",
|
||||
task_type="train",
|
||||
description="Integration test training task",
|
||||
config={"epochs": 100, "batch_size": 16},
|
||||
)
|
||||
|
||||
assert task_id is not None
|
||||
|
||||
task = repo.get(task_id)
|
||||
assert task is not None
|
||||
assert task.name == "Test Training Task"
|
||||
assert task.task_type == "train"
|
||||
assert task.status == "pending"
|
||||
assert task.config["epochs"] == 100
|
||||
|
||||
def test_create_scheduled_task(self, patched_session, admin_token):
|
||||
"""Test creating a scheduled training task."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
|
||||
task_id = repo.create(
|
||||
admin_token=admin_token.token,
|
||||
name="Scheduled Task",
|
||||
scheduled_at=scheduled_time,
|
||||
)
|
||||
|
||||
task = repo.get(task_id)
|
||||
assert task is not None
|
||||
assert task.status == "scheduled"
|
||||
assert task.scheduled_at is not None
|
||||
|
||||
def test_create_recurring_task(self, patched_session, admin_token):
|
||||
"""Test creating a recurring training task."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
task_id = repo.create(
|
||||
admin_token=admin_token.token,
|
||||
name="Recurring Task",
|
||||
cron_expression="0 2 * * *",
|
||||
is_recurring=True,
|
||||
)
|
||||
|
||||
task = repo.get(task_id)
|
||||
assert task is not None
|
||||
assert task.is_recurring is True
|
||||
assert task.cron_expression == "0 2 * * *"
|
||||
|
||||
def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
|
||||
"""Test creating task linked to a dataset."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
task_id = repo.create(
|
||||
admin_token=admin_token.token,
|
||||
name="Dataset Training Task",
|
||||
dataset_id=str(sample_dataset.dataset_id),
|
||||
)
|
||||
|
||||
task = repo.get(task_id)
|
||||
assert task is not None
|
||||
assert task.dataset_id == sample_dataset.dataset_id
|
||||
|
||||
|
||||
class TestTrainingTaskRead:
|
||||
"""Tests for training task retrieval."""
|
||||
|
||||
def test_get_task_by_id(self, patched_session, sample_training_task):
|
||||
"""Test getting task by ID."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
task = repo.get(str(sample_training_task.task_id))
|
||||
|
||||
assert task is not None
|
||||
assert task.task_id == sample_training_task.task_id
|
||||
|
||||
def test_get_nonexistent_task(self, patched_session):
|
||||
"""Test getting task that doesn't exist."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
task = repo.get(str(uuid4()))
|
||||
assert task is None
|
||||
|
||||
def test_get_paginated_tasks(self, patched_session, admin_token):
|
||||
"""Test paginated task listing."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
# Create multiple tasks
|
||||
for i in range(5):
|
||||
repo.create(admin_token=admin_token.token, name=f"Task {i}")
|
||||
|
||||
tasks, total = repo.get_paginated(limit=2, offset=0)
|
||||
|
||||
assert total == 5
|
||||
assert len(tasks) == 2
|
||||
|
||||
def test_get_paginated_with_status_filter(self, patched_session, admin_token):
|
||||
"""Test filtering tasks by status."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
# Create tasks with different statuses
|
||||
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
|
||||
repo.update_status(task_id, "running")
|
||||
|
||||
repo.create(admin_token=admin_token.token, name="Pending Task")
|
||||
|
||||
tasks, total = repo.get_paginated(status="running")
|
||||
|
||||
assert total == 1
|
||||
assert tasks[0].status == "running"
|
||||
|
||||
def test_get_pending_tasks(self, patched_session, admin_token):
|
||||
"""Test getting pending tasks ready to run."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
# Create pending task
|
||||
repo.create(admin_token=admin_token.token, name="Ready Task")
|
||||
|
||||
# Create scheduled task in the past (should be included)
|
||||
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
repo.create(
|
||||
admin_token=admin_token.token,
|
||||
name="Past Scheduled Task",
|
||||
scheduled_at=past_time,
|
||||
)
|
||||
|
||||
# Create scheduled task in the future (should not be included)
|
||||
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
repo.create(
|
||||
admin_token=admin_token.token,
|
||||
name="Future Scheduled Task",
|
||||
scheduled_at=future_time,
|
||||
)
|
||||
|
||||
pending = repo.get_pending()
|
||||
|
||||
# Should include pending and past scheduled, not future scheduled
|
||||
assert len(pending) >= 2
|
||||
names = [t.name for t in pending]
|
||||
assert "Ready Task" in names
|
||||
assert "Past Scheduled Task" in names
|
||||
|
||||
def test_get_running_task(self, patched_session, admin_token):
|
||||
"""Test getting currently running task."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
|
||||
repo.update_status(task_id, "running")
|
||||
|
||||
running = repo.get_running()
|
||||
|
||||
assert running is not None
|
||||
assert running.status == "running"
|
||||
|
||||
def test_get_running_task_none(self, patched_session, admin_token):
|
||||
"""Test getting running task when none is running."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
repo.create(admin_token=admin_token.token, name="Pending Task")
|
||||
|
||||
running = repo.get_running()
|
||||
assert running is None
|
||||
|
||||
|
||||
class TestTrainingTaskUpdate:
|
||||
"""Tests for training task updates."""
|
||||
|
||||
def test_update_status_to_running(self, patched_session, sample_training_task):
|
||||
"""Test updating task status to running."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
repo.update_status(str(sample_training_task.task_id), "running")
|
||||
|
||||
task = repo.get(str(sample_training_task.task_id))
|
||||
assert task is not None
|
||||
assert task.status == "running"
|
||||
assert task.started_at is not None
|
||||
|
||||
def test_update_status_to_completed(self, patched_session, sample_training_task):
|
||||
"""Test updating task status to completed."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}
|
||||
|
||||
repo.update_status(
|
||||
str(sample_training_task.task_id),
|
||||
"completed",
|
||||
result_metrics=metrics,
|
||||
model_path="/models/trained_model.pt",
|
||||
)
|
||||
|
||||
task = repo.get(str(sample_training_task.task_id))
|
||||
assert task is not None
|
||||
assert task.status == "completed"
|
||||
assert task.completed_at is not None
|
||||
assert task.result_metrics["mAP"] == 0.92
|
||||
assert task.model_path == "/models/trained_model.pt"
|
||||
|
||||
def test_update_status_to_failed(self, patched_session, sample_training_task):
|
||||
"""Test updating task status to failed with error message."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
repo.update_status(
|
||||
str(sample_training_task.task_id),
|
||||
"failed",
|
||||
error_message="CUDA out of memory",
|
||||
)
|
||||
|
||||
task = repo.get(str(sample_training_task.task_id))
|
||||
assert task is not None
|
||||
assert task.status == "failed"
|
||||
assert task.completed_at is not None
|
||||
assert "CUDA out of memory" in task.error_message
|
||||
|
||||
def test_cancel_pending_task(self, patched_session, sample_training_task):
|
||||
"""Test cancelling a pending task."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
result = repo.cancel(str(sample_training_task.task_id))
|
||||
|
||||
assert result is True
|
||||
|
||||
task = repo.get(str(sample_training_task.task_id))
|
||||
assert task is not None
|
||||
assert task.status == "cancelled"
|
||||
|
||||
def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
|
||||
"""Test that running task cannot be cancelled."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
repo.update_status(str(sample_training_task.task_id), "running")
|
||||
|
||||
result = repo.cancel(str(sample_training_task.task_id))
|
||||
|
||||
assert result is False
|
||||
|
||||
task = repo.get(str(sample_training_task.task_id))
|
||||
assert task.status == "running"
|
||||
|
||||
|
||||
class TestTrainingLogs:
|
||||
"""Tests for training log management."""
|
||||
|
||||
def test_add_log_entry(self, patched_session, sample_training_task):
|
||||
"""Test adding a training log entry."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
repo.add_log(
|
||||
str(sample_training_task.task_id),
|
||||
level="INFO",
|
||||
message="Starting training...",
|
||||
details={"epoch": 1, "batch": 0},
|
||||
)
|
||||
|
||||
logs = repo.get_logs(str(sample_training_task.task_id))
|
||||
assert len(logs) == 1
|
||||
assert logs[0].level == "INFO"
|
||||
assert logs[0].message == "Starting training..."
|
||||
|
||||
def test_add_multiple_log_entries(self, patched_session, sample_training_task):
|
||||
"""Test adding multiple log entries."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
for i in range(5):
|
||||
repo.add_log(
|
||||
str(sample_training_task.task_id),
|
||||
level="INFO",
|
||||
message=f"Epoch {i} completed",
|
||||
details={"epoch": i, "loss": 0.5 - i * 0.1},
|
||||
)
|
||||
|
||||
logs = repo.get_logs(str(sample_training_task.task_id))
|
||||
assert len(logs) == 5
|
||||
|
||||
def test_get_logs_pagination(self, patched_session, sample_training_task):
|
||||
"""Test paginated log retrieval."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
for i in range(10):
|
||||
repo.add_log(
|
||||
str(sample_training_task.task_id),
|
||||
level="INFO",
|
||||
message=f"Log entry {i}",
|
||||
)
|
||||
|
||||
logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0)
|
||||
assert len(logs) == 5
|
||||
|
||||
logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5)
|
||||
assert len(logs_page2) == 5
|
||||
|
||||
|
||||
class TestDocumentLinks:
|
||||
"""Tests for training document link management."""
|
||||
|
||||
def test_create_document_link(self, patched_session, sample_training_task, sample_document):
|
||||
"""Test creating a document link."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
link = repo.create_document_link(
|
||||
task_id=sample_training_task.task_id,
|
||||
document_id=sample_document.document_id,
|
||||
annotation_snapshot={"count": 5, "verified": 3},
|
||||
)
|
||||
|
||||
assert link is not None
|
||||
assert link.task_id == sample_training_task.task_id
|
||||
assert link.document_id == sample_document.document_id
|
||||
assert link.annotation_snapshot["count"] == 5
|
||||
|
||||
def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
|
||||
"""Test getting all document links for a task."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
for doc in multiple_documents[:3]:
|
||||
repo.create_document_link(
|
||||
task_id=sample_training_task.task_id,
|
||||
document_id=doc.document_id,
|
||||
)
|
||||
|
||||
links = repo.get_document_links(sample_training_task.task_id)
|
||||
assert len(links) == 3
|
||||
|
||||
def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
|
||||
"""Test getting training tasks that used a document."""
|
||||
repo = TrainingTaskRepository()
|
||||
|
||||
# Create multiple tasks using the same document
|
||||
task1_id = repo.create(admin_token=admin_token.token, name="Task 1")
|
||||
task2_id = repo.create(admin_token=admin_token.token, name="Task 2")
|
||||
|
||||
repo.create_document_link(
|
||||
task_id=repo.get(task1_id).task_id,
|
||||
document_id=sample_document.document_id,
|
||||
)
|
||||
repo.create_document_link(
|
||||
task_id=repo.get(task2_id).task_id,
|
||||
document_id=sample_document.document_id,
|
||||
)
|
||||
|
||||
links = repo.get_document_training_tasks(sample_document.document_id)
|
||||
assert len(links) == 2
|
||||
Reference in New Issue
Block a user