Files
invoice-master-poc-v2/tests/integration/repositories/test_dataset_repo_integration.py
2026-02-01 22:40:41 +01:00

322 lines
10 KiB
Python

"""
Dataset Repository Integration Tests
Tests DatasetRepository with real database operations.
"""
from uuid import uuid4
import pytest
from inference.data.repositories.dataset_repository import DatasetRepository
class TestDatasetRepositoryCreate:
    """Integration tests covering ``DatasetRepository.create``."""

    def test_create_dataset(self, patched_session):
        """Create a dataset with explicit values and check every field comes back."""
        requested = {
            "name": "Test Dataset",
            "description": "Dataset for integration testing",
            "train_ratio": 0.8,
            "val_ratio": 0.1,
            "seed": 42,
        }

        created = DatasetRepository().create(**requested)

        assert created is not None
        # Every value handed to create() must be reflected on the returned record.
        for field, expected in requested.items():
            assert getattr(created, field) == expected
        # The test expects a freshly created dataset to report "building".
        assert created.status == "building"

    def test_create_dataset_with_defaults(self, patched_session):
        """Create a dataset from a name only and check the expected defaults."""
        created = DatasetRepository().create(name="Minimal Dataset")

        assert created is not None
        observed = (created.train_ratio, created.val_ratio, created.seed)
        assert observed == (0.8, 0.1, 42)
class TestDatasetRepositoryRead:
    """Tests for dataset retrieval."""

    def test_get_dataset_by_id(self, patched_session, sample_dataset):
        """Test getting dataset by ID (passed to the repository as a string)."""
        repo = DatasetRepository()

        dataset = repo.get(str(sample_dataset.dataset_id))

        assert dataset is not None
        assert dataset.dataset_id == sample_dataset.dataset_id
        assert dataset.name == sample_dataset.name

    def test_get_nonexistent_dataset(self, patched_session):
        """Test that looking up a freshly generated UUID returns ``None``."""
        repo = DatasetRepository()

        dataset = repo.get(str(uuid4()))

        assert dataset is None

    def test_get_paginated_datasets(self, patched_session):
        """Test paginated dataset listing."""
        repo = DatasetRepository()

        # Create more datasets than fit on one page.
        for i in range(5):
            repo.create(name=f"Dataset {i}")

        datasets, total = repo.get_paginated(limit=2, offset=0)

        # ``total`` is expected to count all datasets, not only the returned page.
        assert total == 5
        assert len(datasets) == 2

    def test_get_paginated_with_status_filter(self, patched_session):
        """Test filtering datasets by status."""
        repo = DatasetRepository()

        # First dataset is moved to "ready".
        ready_dataset = repo.create(name="Building Dataset")
        repo.update_status(str(ready_dataset.dataset_id), "ready")

        # Second dataset stays as "building", so the filter must exclude it.
        # The return value is intentionally not bound: only the row matters.
        repo.create(name="Another Building Dataset")

        datasets, total = repo.get_paginated(status="ready")

        assert total == 1
        # Check the page length before indexing so a mismatch fails as an
        # assertion instead of an IndexError.
        assert len(datasets) == 1
        assert datasets[0].status == "ready"
class TestDatasetRepositoryUpdate:
    """Integration tests for the status-updating repository methods."""

    def test_update_status(self, patched_session, sample_dataset):
        """Set status together with the three total counters and read them back."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        repository.update_status(
            dataset_id,
            status="ready",
            total_documents=100,
            total_images=150,
            total_annotations=500,
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.status == "ready"
        assert refreshed.total_documents == 100
        assert refreshed.total_images == 150
        assert refreshed.total_annotations == 500

    def test_update_status_with_error(self, patched_session, sample_dataset):
        """Mark a dataset as failed and check the error message is kept."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        repository.update_status(
            dataset_id,
            status="failed",
            error_message="Failed to build dataset: insufficient documents",
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.status == "failed"
        assert "insufficient documents" in refreshed.error_message

    def test_update_status_with_path(self, patched_session, sample_dataset):
        """Record a dataset path alongside a status change."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        repository.update_status(
            dataset_id,
            status="ready",
            dataset_path="/datasets/test_dataset_2024",
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.dataset_path == "/datasets/test_dataset_2024"

    def test_update_training_status(self, patched_session, sample_dataset, sample_training_task):
        """Attach a running training task to a dataset."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        repository.update_training_status(
            dataset_id,
            training_status="running",
            active_training_task_id=str(sample_training_task.task_id),
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.training_status == "running"
        assert refreshed.active_training_task_id == sample_training_task.task_id

    def test_update_training_status_completed(self, patched_session, sample_dataset):
        """Completing training with ``update_main_status=True`` also changes ``status``."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        # Step 1: bring the dataset to "ready".
        repository.update_status(dataset_id, status="ready")

        # Step 2: report the training run as completed.
        repository.update_training_status(
            dataset_id,
            training_status="completed",
            update_main_status=True,
        )

        refreshed = repository.get(dataset_id)
        assert refreshed is not None
        assert refreshed.training_status == "completed"
        # The main status is expected to follow the training outcome.
        assert refreshed.status == "trained"
class TestDatasetDocuments:
    """Integration tests for linking documents to a dataset."""

    def test_add_documents_to_dataset(self, patched_session, sample_dataset, multiple_documents):
        """Link three documents (two train, one val) and check the stored splits."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        # (split, page_count, annotation_count) for the first three fixture documents.
        layout = [("train", 1, 5), ("train", 2, 8), ("val", 1, 3)]
        payload = [
            {
                "document_id": str(multiple_documents[index].document_id),
                "split": split,
                "page_count": pages,
                "annotation_count": annotations,
            }
            for index, (split, pages, annotations) in enumerate(layout)
        ]

        repository.add_documents(dataset_id, payload)

        # Read the links back and count them per split.
        stored = repository.get_documents(dataset_id)
        assert len(stored) == 3

        train_count = sum(1 for doc in stored if doc.split == "train")
        val_count = sum(1 for doc in stored if doc.split == "val")
        assert train_count == 2
        assert val_count == 1

    def test_get_dataset_documents(self, patched_session, sample_dataset, sample_document):
        """Link a single document and check each stored field on the returned link."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)
        link = {
            "document_id": str(sample_document.document_id),
            "split": "train",
            "page_count": 1,
            "annotation_count": 5,
        }

        repository.add_documents(dataset_id, [link])

        stored = repository.get_documents(dataset_id)
        assert len(stored) == 1

        only = stored[0]
        assert only.document_id == sample_document.document_id
        assert only.split == "train"
        assert only.page_count == 1
        assert only.annotation_count == 5
class TestDatasetRepositoryDelete:
    """Integration tests for ``DatasetRepository.delete``."""

    def test_delete_dataset(self, patched_session, sample_dataset):
        """Deleting an existing dataset returns ``True`` and the row is gone."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)

        deleted = repository.delete(dataset_id)

        assert deleted is True
        assert repository.get(dataset_id) is None

    def test_delete_nonexistent_dataset(self, patched_session):
        """Deleting a freshly generated UUID returns ``False``."""
        repository = DatasetRepository()
        unknown_id = str(uuid4())

        deleted = repository.delete(unknown_id)

        assert deleted is False

    def test_delete_dataset_cascades_documents(self, patched_session, sample_dataset, sample_document):
        """Deleting a dataset is expected to remove its document links too."""
        repository = DatasetRepository()
        dataset_id = str(sample_dataset.dataset_id)
        link = {
            "document_id": str(sample_document.document_id),
            "split": "train",
            "page_count": 1,
            "annotation_count": 5,
        }

        # Link one document, then remove the dataset that owns the link.
        repository.add_documents(dataset_id, [link])
        repository.delete(dataset_id)

        # Nothing should remain linked to the deleted dataset.
        remaining = repository.get_documents(dataset_id)
        assert len(remaining) == 0
class TestActiveTrainingTasks:
    """Integration tests for ``DatasetRepository.get_active_training_tasks``."""

    def test_get_active_training_tasks(self, patched_session, sample_dataset, sample_training_task):
        """A running task is reported under its dataset's ID (as a string key)."""
        repository = DatasetRepository()
        dataset_key = str(sample_dataset.dataset_id)

        # Put the fixture task into the "running" state before querying.
        from inference.data.repositories.training_task_repository import TrainingTaskRepository

        task_repository = TrainingTaskRepository()
        task_repository.update_status(str(sample_training_task.task_id), "running")

        active = repository.get_active_training_tasks([dataset_key])

        assert dataset_key in active
        assert active[dataset_key]["status"] == "running"

    def test_get_active_training_tasks_empty(self, patched_session, sample_dataset):
        """With no training task for the dataset, the result is an empty mapping."""
        repository = DatasetRepository()
        dataset_key = str(sample_dataset.dataset_id)

        active = repository.get_active_training_tasks([dataset_key])

        # No task was created for this dataset, so it must not appear at all.
        assert dataset_key not in active
        assert active == {}