Add more tests

This commit is contained in:
Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions

View File

@@ -0,0 +1,364 @@
"""
Training Task Repository Integration Tests
Tests TrainingTaskRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from inference.data.repositories.training_task_repository import TrainingTaskRepository
class TestTrainingTaskCreate:
"""Tests for training task creation."""
def test_create_training_task(self, patched_session, admin_token):
"""Test creating a training task."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Test Training Task",
task_type="train",
description="Integration test training task",
config={"epochs": 100, "batch_size": 16},
)
assert task_id is not None
task = repo.get(task_id)
assert task is not None
assert task.name == "Test Training Task"
assert task.task_type == "train"
assert task.status == "pending"
assert task.config["epochs"] == 100
def test_create_scheduled_task(self, patched_session, admin_token):
"""Test creating a scheduled training task."""
repo = TrainingTaskRepository()
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
task_id = repo.create(
admin_token=admin_token.token,
name="Scheduled Task",
scheduled_at=scheduled_time,
)
task = repo.get(task_id)
assert task is not None
assert task.status == "scheduled"
assert task.scheduled_at is not None
def test_create_recurring_task(self, patched_session, admin_token):
"""Test creating a recurring training task."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Recurring Task",
cron_expression="0 2 * * *",
is_recurring=True,
)
task = repo.get(task_id)
assert task is not None
assert task.is_recurring is True
assert task.cron_expression == "0 2 * * *"
def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
"""Test creating task linked to a dataset."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Dataset Training Task",
dataset_id=str(sample_dataset.dataset_id),
)
task = repo.get(task_id)
assert task is not None
assert task.dataset_id == sample_dataset.dataset_id
class TestTrainingTaskRead:
"""Tests for training task retrieval."""
def test_get_task_by_id(self, patched_session, sample_training_task):
"""Test getting task by ID."""
repo = TrainingTaskRepository()
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.task_id == sample_training_task.task_id
def test_get_nonexistent_task(self, patched_session):
"""Test getting task that doesn't exist."""
repo = TrainingTaskRepository()
task = repo.get(str(uuid4()))
assert task is None
def test_get_paginated_tasks(self, patched_session, admin_token):
"""Test paginated task listing."""
repo = TrainingTaskRepository()
# Create multiple tasks
for i in range(5):
repo.create(admin_token=admin_token.token, name=f"Task {i}")
tasks, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(tasks) == 2
def test_get_paginated_with_status_filter(self, patched_session, admin_token):
"""Test filtering tasks by status."""
repo = TrainingTaskRepository()
# Create tasks with different statuses
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
repo.update_status(task_id, "running")
repo.create(admin_token=admin_token.token, name="Pending Task")
tasks, total = repo.get_paginated(status="running")
assert total == 1
assert tasks[0].status == "running"
def test_get_pending_tasks(self, patched_session, admin_token):
"""Test getting pending tasks ready to run."""
repo = TrainingTaskRepository()
# Create pending task
repo.create(admin_token=admin_token.token, name="Ready Task")
# Create scheduled task in the past (should be included)
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
repo.create(
admin_token=admin_token.token,
name="Past Scheduled Task",
scheduled_at=past_time,
)
# Create scheduled task in the future (should not be included)
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
repo.create(
admin_token=admin_token.token,
name="Future Scheduled Task",
scheduled_at=future_time,
)
pending = repo.get_pending()
# Should include pending and past scheduled, not future scheduled
assert len(pending) >= 2
names = [t.name for t in pending]
assert "Ready Task" in names
assert "Past Scheduled Task" in names
def test_get_running_task(self, patched_session, admin_token):
"""Test getting currently running task."""
repo = TrainingTaskRepository()
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
repo.update_status(task_id, "running")
running = repo.get_running()
assert running is not None
assert running.status == "running"
def test_get_running_task_none(self, patched_session, admin_token):
"""Test getting running task when none is running."""
repo = TrainingTaskRepository()
repo.create(admin_token=admin_token.token, name="Pending Task")
running = repo.get_running()
assert running is None
class TestTrainingTaskUpdate:
"""Tests for training task updates."""
def test_update_status_to_running(self, patched_session, sample_training_task):
"""Test updating task status to running."""
repo = TrainingTaskRepository()
repo.update_status(str(sample_training_task.task_id), "running")
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "running"
assert task.started_at is not None
def test_update_status_to_completed(self, patched_session, sample_training_task):
"""Test updating task status to completed."""
repo = TrainingTaskRepository()
metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}
repo.update_status(
str(sample_training_task.task_id),
"completed",
result_metrics=metrics,
model_path="/models/trained_model.pt",
)
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "completed"
assert task.completed_at is not None
assert task.result_metrics["mAP"] == 0.92
assert task.model_path == "/models/trained_model.pt"
def test_update_status_to_failed(self, patched_session, sample_training_task):
"""Test updating task status to failed with error message."""
repo = TrainingTaskRepository()
repo.update_status(
str(sample_training_task.task_id),
"failed",
error_message="CUDA out of memory",
)
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "failed"
assert task.completed_at is not None
assert "CUDA out of memory" in task.error_message
def test_cancel_pending_task(self, patched_session, sample_training_task):
"""Test cancelling a pending task."""
repo = TrainingTaskRepository()
result = repo.cancel(str(sample_training_task.task_id))
assert result is True
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "cancelled"
def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
"""Test that running task cannot be cancelled."""
repo = TrainingTaskRepository()
repo.update_status(str(sample_training_task.task_id), "running")
result = repo.cancel(str(sample_training_task.task_id))
assert result is False
task = repo.get(str(sample_training_task.task_id))
assert task.status == "running"
class TestTrainingLogs:
"""Tests for training log management."""
def test_add_log_entry(self, patched_session, sample_training_task):
"""Test adding a training log entry."""
repo = TrainingTaskRepository()
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message="Starting training...",
details={"epoch": 1, "batch": 0},
)
logs = repo.get_logs(str(sample_training_task.task_id))
assert len(logs) == 1
assert logs[0].level == "INFO"
assert logs[0].message == "Starting training..."
def test_add_multiple_log_entries(self, patched_session, sample_training_task):
"""Test adding multiple log entries."""
repo = TrainingTaskRepository()
for i in range(5):
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message=f"Epoch {i} completed",
details={"epoch": i, "loss": 0.5 - i * 0.1},
)
logs = repo.get_logs(str(sample_training_task.task_id))
assert len(logs) == 5
def test_get_logs_pagination(self, patched_session, sample_training_task):
"""Test paginated log retrieval."""
repo = TrainingTaskRepository()
for i in range(10):
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message=f"Log entry {i}",
)
logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0)
assert len(logs) == 5
logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5)
assert len(logs_page2) == 5
class TestDocumentLinks:
"""Tests for training document link management."""
def test_create_document_link(self, patched_session, sample_training_task, sample_document):
"""Test creating a document link."""
repo = TrainingTaskRepository()
link = repo.create_document_link(
task_id=sample_training_task.task_id,
document_id=sample_document.document_id,
annotation_snapshot={"count": 5, "verified": 3},
)
assert link is not None
assert link.task_id == sample_training_task.task_id
assert link.document_id == sample_document.document_id
assert link.annotation_snapshot["count"] == 5
def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
"""Test getting all document links for a task."""
repo = TrainingTaskRepository()
for doc in multiple_documents[:3]:
repo.create_document_link(
task_id=sample_training_task.task_id,
document_id=doc.document_id,
)
links = repo.get_document_links(sample_training_task.task_id)
assert len(links) == 3
def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
"""Test getting training tasks that used a document."""
repo = TrainingTaskRepository()
# Create multiple tasks using the same document
task1_id = repo.create(admin_token=admin_token.token, name="Task 1")
task2_id = repo.create(admin_token=admin_token.token, name="Task 2")
repo.create_document_link(
task_id=repo.get(task1_id).task_id,
document_id=sample_document.document_id,
)
repo.create_document_link(
task_id=repo.get(task2_id).task_id,
document_id=sample_document.document_id,
)
links = repo.get_document_training_tasks(sample_document.document_id)
assert len(links) == 2