Add more tests
This commit is contained in:
321
tests/integration/repositories/test_dataset_repo_integration.py
Normal file
321
tests/integration/repositories/test_dataset_repo_integration.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""
|
||||
Dataset Repository Integration Tests
|
||||
|
||||
Tests DatasetRepository with real database operations.
|
||||
"""
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.dataset_repository import DatasetRepository
|
||||
|
||||
|
||||
class TestDatasetRepositoryCreate:
    """Creation-path tests for DatasetRepository."""

    def test_create_dataset(self, patched_session):
        """Explicitly supplied fields persist, and a new dataset starts in 'building'."""
        repository = DatasetRepository()

        created = repository.create(
            name="Test Dataset",
            description="Dataset for integration testing",
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
        )

        assert created is not None
        assert created.name == "Test Dataset"
        assert created.description == "Dataset for integration testing"
        assert created.train_ratio == 0.8
        assert created.val_ratio == 0.1
        assert created.seed == 42
        # Freshly created datasets have not been built yet.
        assert created.status == "building"

    def test_create_dataset_with_defaults(self, patched_session):
        """Omitted split ratios and seed fall back to the repository defaults."""
        repository = DatasetRepository()

        created = repository.create(name="Minimal Dataset")

        assert created is not None
        assert created.train_ratio == 0.8
        assert created.val_ratio == 0.1
        assert created.seed == 42
class TestDatasetRepositoryRead:
    """Tests for dataset retrieval.

    Covers single-record lookup by ID, the miss path for unknown IDs,
    pagination, and status filtering.
    """

    def test_get_dataset_by_id(self, patched_session, sample_dataset):
        """Test getting dataset by ID."""
        repo = DatasetRepository()

        dataset = repo.get(str(sample_dataset.dataset_id))

        assert dataset is not None
        assert dataset.dataset_id == sample_dataset.dataset_id
        assert dataset.name == sample_dataset.name

    def test_get_nonexistent_dataset(self, patched_session):
        """Test getting dataset that doesn't exist returns None (no raise)."""
        repo = DatasetRepository()

        dataset = repo.get(str(uuid4()))
        assert dataset is None

    def test_get_paginated_datasets(self, patched_session):
        """Test paginated dataset listing: total reflects all rows, page is limited."""
        repo = DatasetRepository()

        # Create multiple datasets
        for i in range(5):
            repo.create(name=f"Dataset {i}")

        datasets, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(datasets) == 2

    def test_get_paginated_with_status_filter(self, patched_session):
        """Test filtering datasets by status only returns matching rows."""
        repo = DatasetRepository()

        # Create datasets with different statuses
        d1 = repo.create(name="Building Dataset")
        repo.update_status(str(d1.dataset_id), "ready")

        # Second dataset stays as "building"; the binding was unused, so the
        # create call is made without assigning the result.
        repo.create(name="Another Building Dataset")

        datasets, total = repo.get_paginated(status="ready")

        assert total == 1
        assert datasets[0].status == "ready"
class TestDatasetRepositoryUpdate:
    """Update-path tests for DatasetRepository status and training-status fields."""

    def test_update_status(self, patched_session, sample_dataset):
        """Status plus aggregate counters are persisted by update_status."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        repository.update_status(
            dataset_id,
            status="ready",
            total_documents=100,
            total_images=150,
            total_annotations=500,
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.status == "ready"
        assert refreshed.total_documents == 100
        assert refreshed.total_images == 150
        assert refreshed.total_annotations == 500

    def test_update_status_with_error(self, patched_session, sample_dataset):
        """A failed status update stores the accompanying error message."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        repository.update_status(
            dataset_id,
            status="failed",
            error_message="Failed to build dataset: insufficient documents",
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.status == "failed"
        assert "insufficient documents" in refreshed.error_message

    def test_update_status_with_path(self, patched_session, sample_dataset):
        """The on-disk dataset path can be set alongside the status."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        repository.update_status(
            dataset_id,
            status="ready",
            dataset_path="/datasets/test_dataset_2024",
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.dataset_path == "/datasets/test_dataset_2024"

    def test_update_training_status(self, patched_session, sample_dataset, sample_training_task):
        """Training status and the active task reference are persisted together."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        repository.update_training_status(
            dataset_id,
            training_status="running",
            active_training_task_id=str(sample_training_task.task_id),
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.training_status == "running"
        assert refreshed.active_training_task_id == sample_training_task.task_id

    def test_update_training_status_completed(self, patched_session, sample_dataset):
        """Completing training with update_main_status=True promotes status to 'trained'."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        # First set to ready
        repository.update_status(dataset_id, status="ready")

        # Then complete training
        repository.update_training_status(
            dataset_id,
            training_status="completed",
            update_main_status=True,
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.training_status == "completed"
        assert refreshed.status == "trained"
class TestDatasetDocuments:
    """Tests for linking documents to a dataset and reading them back."""

    def test_add_documents_to_dataset(self, patched_session, sample_dataset, multiple_documents):
        """Three documents are linked with the expected train/val split distribution."""
        repository = DatasetRepository()

        # (split, page_count, annotation_count) for the first three fixture documents.
        per_document = [("train", 1, 5), ("train", 2, 8), ("val", 1, 3)]
        documents_data = [
            {
                "document_id": str(document.document_id),
                "split": split,
                "page_count": pages,
                "annotation_count": annotations,
            }
            for document, (split, pages, annotations) in zip(multiple_documents, per_document)
        ]

        repository.add_documents(str(sample_dataset.dataset_id), documents_data)

        # Verify documents were added
        linked = repository.get_documents(str(sample_dataset.dataset_id))
        assert len(linked) == 3

        by_split = {"train": [], "val": []}
        for entry in linked:
            by_split[entry.split].append(entry)

        assert len(by_split["train"]) == 2
        assert len(by_split["val"]) == 1

    def test_get_dataset_documents(self, patched_session, sample_dataset, sample_document):
        """A single linked document round-trips all of its link metadata."""
        repository = DatasetRepository()

        link = {
            "document_id": str(sample_document.document_id),
            "split": "train",
            "page_count": 1,
            "annotation_count": 5,
        }
        repository.add_documents(str(sample_dataset.dataset_id), [link])

        linked = repository.get_documents(str(sample_dataset.dataset_id))

        assert len(linked) == 1
        assert linked[0].document_id == sample_document.document_id
        assert linked[0].split == "train"
        assert linked[0].page_count == 1
        assert linked[0].annotation_count == 5
class TestDatasetRepositoryDelete:
    """Deletion-path tests, including cascade behavior for document links."""

    def test_delete_dataset(self, patched_session, sample_dataset):
        """Deleting an existing dataset returns True and makes it unreadable."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        assert repository.delete(dataset_id) is True
        assert repository.get(dataset_id) is None

    def test_delete_nonexistent_dataset(self, patched_session):
        """Deleting an unknown ID is a no-op that returns False."""
        repository = DatasetRepository()

        assert repository.delete(str(uuid4())) is False

    def test_delete_dataset_cascades_documents(self, patched_session, sample_dataset, sample_document):
        """Deleting a dataset also removes its document links."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        # Add document to dataset
        repository.add_documents(
            dataset_id,
            [
                {
                    "document_id": str(sample_document.document_id),
                    "split": "train",
                    "page_count": 1,
                    "annotation_count": 5,
                }
            ],
        )

        # Delete dataset
        repository.delete(dataset_id)

        # Document links should be gone
        assert len(repository.get_documents(dataset_id)) == 0
class TestActiveTrainingTasks:
    """Tests for the active-training-task lookup across datasets."""

    def test_get_active_training_tasks(self, patched_session, sample_dataset, sample_training_task):
        """A running task shows up keyed by its dataset's ID."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        # Update task to running
        from inference.data.repositories.training_task_repository import TrainingTaskRepository

        TrainingTaskRepository().update_status(str(sample_training_task.task_id), "running")

        active = repository.get_active_training_tasks([dataset_id])

        assert dataset_id in active
        assert active[dataset_id]["status"] == "running"

    def test_get_active_training_tasks_empty(self, patched_session, sample_dataset):
        """With no tasks at all, the lookup yields an empty mapping."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        active = repository.get_active_training_tasks([dataset_id])

        # No training task exists for this dataset, so result should be empty
        assert dataset_id not in active
        assert active == {}
Reference in New Issue
Block a user