# Source: invoice-master-poc-v2/tests/data/repositories/test_dataset_repository.py
# Commit: b602d0a340 "re-structure" by Yaojia Wang, 2026-02-01 22:55:31 +01:00
# Original size: 598 lines, 26 KiB (Python)

"""
Tests for DatasetRepository
100% coverage tests for dataset management.
"""
import pytest
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from backend.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
from backend.data.repositories.dataset_repository import DatasetRepository
class TestDatasetRepository:
    """Tests for DatasetRepository.

    The repository obtains database sessions from ``get_session_context``.
    The ``mock_ctx`` fixture patches that factory in the repository module and
    ``mock_session`` wires a ``MagicMock`` session into its context manager,
    so no real database is touched and each test can inspect session calls.
    """

    # Patch target for the session factory used by the repository module.
    _SESSION_CONTEXT = (
        "backend.data.repositories.dataset_repository.get_session_context"
    )

    # =========================================================================
    # Fixtures
    # =========================================================================

    @pytest.fixture
    def sample_dataset(self) -> TrainingDataset:
        """Create a sample dataset for testing."""
        # One timestamp so created_at and updated_at are identical.
        now = datetime.now(timezone.utc)
        return TrainingDataset(
            dataset_id=uuid4(),
            name="Test Dataset",
            description="A test dataset",
            status="ready",
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
            total_documents=100,
            total_images=100,
            total_annotations=500,
            created_at=now,
            updated_at=now,
        )

    @pytest.fixture
    def sample_dataset_document(self) -> DatasetDocument:
        """Create a sample dataset document for testing."""
        return DatasetDocument(
            id=uuid4(),
            dataset_id=uuid4(),
            document_id=uuid4(),
            split="train",
            page_count=2,
            annotation_count=10,
            created_at=datetime.now(timezone.utc),
        )

    @pytest.fixture
    def sample_training_task(self) -> TrainingTask:
        """Create a sample training task for testing."""
        return TrainingTask(
            task_id=uuid4(),
            admin_token="admin-token",
            name="Test Task",
            status="running",
            dataset_id=uuid4(),
        )

    @pytest.fixture
    def repo(self) -> DatasetRepository:
        """Create a DatasetRepository instance."""
        return DatasetRepository()

    @pytest.fixture
    def mock_ctx(self):
        """Patch ``get_session_context`` for the duration of one test.

        Yields the patched factory so a test can also assert that no session
        was opened at all. The patch is undone when the fixture is torn down.
        """
        with patch(self._SESSION_CONTEXT) as ctx:
            yield ctx

    @pytest.fixture
    def mock_session(self, mock_ctx) -> MagicMock:
        """Return the mock session produced by the patched context manager."""
        session = MagicMock()
        mock_ctx.return_value.__enter__.return_value = session
        # A falsy __exit__ lets exceptions raised inside the `with` propagate.
        mock_ctx.return_value.__exit__.return_value = False
        return session

    # =========================================================================
    # create() tests
    # =========================================================================

    def test_create_returns_dataset(self, repo, mock_session):
        """Test create returns created dataset."""
        result = repo.create(name="Test Dataset")

        assert result is not None
        assert result.name == "Test Dataset"
        mock_session.add.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_create_with_all_params(self, repo, mock_session):
        """Test create with all parameters."""
        repo.create(
            name="Full Dataset",
            description="A complete dataset",
            train_ratio=0.7,
            val_ratio=0.15,
            seed=123,
        )

        added_dataset = mock_session.add.call_args[0][0]
        assert added_dataset.name == "Full Dataset"
        assert added_dataset.description == "A complete dataset"
        assert added_dataset.train_ratio == 0.7
        assert added_dataset.val_ratio == 0.15
        assert added_dataset.seed == 123

    def test_create_default_values(self, repo, mock_session):
        """Test create uses default values."""
        repo.create(name="Minimal Dataset")

        added_dataset = mock_session.add.call_args[0][0]
        assert added_dataset.train_ratio == 0.8
        assert added_dataset.val_ratio == 0.1
        assert added_dataset.seed == 42

    # =========================================================================
    # get() tests
    # =========================================================================

    def test_get_returns_dataset(self, repo, mock_session, sample_dataset):
        """Test get returns dataset when exists."""
        mock_session.get.return_value = sample_dataset

        result = repo.get(str(sample_dataset.dataset_id))

        assert result is not None
        assert result.name == "Test Dataset"
        mock_session.expunge.assert_called_once()

    def test_get_with_uuid(self, repo, mock_session, sample_dataset):
        """Test get works with UUID object."""
        mock_session.get.return_value = sample_dataset

        result = repo.get(sample_dataset.dataset_id)

        assert result is not None
        assert result.dataset_id == sample_dataset.dataset_id

    def test_get_returns_none_when_not_found(self, repo, mock_session):
        """Test get returns None when dataset not found."""
        mock_session.get.return_value = None

        result = repo.get(str(uuid4()))

        assert result is None
        mock_session.expunge.assert_not_called()

    # =========================================================================
    # get_paginated() tests
    # =========================================================================

    def test_get_paginated_returns_datasets_and_total(
        self, repo, mock_session, sample_dataset
    ):
        """Test get_paginated returns list of datasets and total count."""
        # The same exec() mock serves both the count query (.one()) and the
        # page query (.all()).
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_dataset]

        datasets, total = repo.get_paginated()

        assert len(datasets) == 1
        assert total == 1

    def test_get_paginated_with_status_filter(
        self, repo, mock_session, sample_dataset
    ):
        """Test get_paginated filters by status."""
        mock_session.exec.return_value.one.return_value = 1
        mock_session.exec.return_value.all.return_value = [sample_dataset]

        datasets, total = repo.get_paginated(status="ready")

        assert len(datasets) == 1
        assert total == 1

    def test_get_paginated_with_pagination(
        self, repo, mock_session, sample_dataset
    ):
        """Test get_paginated with limit and offset."""
        mock_session.exec.return_value.one.return_value = 50
        mock_session.exec.return_value.all.return_value = [sample_dataset]

        datasets, total = repo.get_paginated(limit=10, offset=20)

        assert len(datasets) == 1
        assert total == 50

    def test_get_paginated_empty_results(self, repo, mock_session):
        """Test get_paginated with no results."""
        mock_session.exec.return_value.one.return_value = 0
        mock_session.exec.return_value.all.return_value = []

        datasets, total = repo.get_paginated()

        assert datasets == []
        assert total == 0

    # =========================================================================
    # get_active_training_tasks() tests
    # =========================================================================

    def test_get_active_training_tasks_returns_dict(
        self, repo, mock_session, sample_training_task
    ):
        """Test get_active_training_tasks returns dict of active tasks."""
        mock_session.exec.return_value.all.return_value = [sample_training_task]

        result = repo.get_active_training_tasks(
            [str(sample_training_task.dataset_id)]
        )

        assert str(sample_training_task.dataset_id) in result

    def test_get_active_training_tasks_empty_input(self, repo, mock_ctx):
        """Test get_active_training_tasks with empty input."""
        result = repo.get_active_training_tasks([])

        assert result == {}
        # Nothing to look up, so no session should be opened.
        mock_ctx.assert_not_called()

    def test_get_active_training_tasks_invalid_uuid(self, repo, mock_session):
        """Test get_active_training_tasks filters invalid UUIDs."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_active_training_tasks(["invalid-uuid", str(uuid4())])

        assert result == {}
        # The remaining valid UUID must still be queried.
        mock_session.exec.assert_called()

    def test_get_active_training_tasks_all_invalid_uuids(self, repo, mock_ctx):
        """Test get_active_training_tasks with all invalid UUIDs."""
        result = repo.get_active_training_tasks(
            ["invalid-uuid-1", "invalid-uuid-2"]
        )

        assert result == {}
        # No valid IDs remain after filtering, so no session should be opened.
        mock_ctx.assert_not_called()

    # =========================================================================
    # update_status() tests
    # =========================================================================

    def test_update_status_updates_dataset(
        self, repo, mock_session, sample_dataset
    ):
        """Test update_status updates dataset status."""
        mock_session.get.return_value = sample_dataset

        repo.update_status(str(sample_dataset.dataset_id), "training")

        assert sample_dataset.status == "training"
        mock_session.commit.assert_called_once()

    def test_update_status_with_error_message(
        self, repo, mock_session, sample_dataset
    ):
        """Test update_status with error message."""
        mock_session.get.return_value = sample_dataset

        repo.update_status(
            str(sample_dataset.dataset_id),
            "failed",
            error_message="Training failed",
        )

        assert sample_dataset.error_message == "Training failed"

    def test_update_status_with_totals(self, repo, mock_session, sample_dataset):
        """Test update_status with total counts."""
        mock_session.get.return_value = sample_dataset

        repo.update_status(
            str(sample_dataset.dataset_id),
            "ready",
            total_documents=200,
            total_images=200,
            total_annotations=1000,
        )

        assert sample_dataset.total_documents == 200
        assert sample_dataset.total_images == 200
        assert sample_dataset.total_annotations == 1000

    def test_update_status_with_dataset_path(
        self, repo, mock_session, sample_dataset
    ):
        """Test update_status with dataset path."""
        mock_session.get.return_value = sample_dataset

        repo.update_status(
            str(sample_dataset.dataset_id),
            "ready",
            dataset_path="/path/to/dataset",
        )

        assert sample_dataset.dataset_path == "/path/to/dataset"

    def test_update_status_with_uuid(self, repo, mock_session, sample_dataset):
        """Test update_status works with UUID object."""
        mock_session.get.return_value = sample_dataset

        repo.update_status(sample_dataset.dataset_id, "ready")

        assert sample_dataset.status == "ready"

    def test_update_status_not_found(self, repo, mock_session):
        """Test update_status does nothing when dataset not found."""
        mock_session.get.return_value = None

        repo.update_status(str(uuid4()), "ready")

        mock_session.add.assert_not_called()

    # =========================================================================
    # update_training_status() tests
    # =========================================================================

    def test_update_training_status_updates_dataset(
        self, repo, mock_session, sample_dataset
    ):
        """Test update_training_status updates training status."""
        mock_session.get.return_value = sample_dataset

        repo.update_training_status(str(sample_dataset.dataset_id), "running")

        assert sample_dataset.training_status == "running"
        mock_session.commit.assert_called_once()

    def test_update_training_status_with_task_id(
        self, repo, mock_session, sample_dataset
    ):
        """Test update_training_status with active task ID."""
        task_id = uuid4()
        mock_session.get.return_value = sample_dataset

        repo.update_training_status(
            str(sample_dataset.dataset_id),
            "running",
            active_training_task_id=str(task_id),
        )

        # The string ID passed in is expected to be stored as a UUID.
        assert sample_dataset.active_training_task_id == task_id

    def test_update_training_status_updates_main_status(
        self, repo, mock_session, sample_dataset
    ):
        """Test update_training_status updates main status when completed."""
        sample_dataset.status = "ready"
        mock_session.get.return_value = sample_dataset

        repo.update_training_status(
            str(sample_dataset.dataset_id),
            "completed",
            update_main_status=True,
        )

        assert sample_dataset.training_status == "completed"
        assert sample_dataset.status == "trained"

    def test_update_training_status_clears_task_id(
        self, repo, mock_session, sample_dataset
    ):
        """Test update_training_status clears task ID when None."""
        sample_dataset.active_training_task_id = uuid4()
        mock_session.get.return_value = sample_dataset

        repo.update_training_status(
            str(sample_dataset.dataset_id),
            None,
            active_training_task_id=None,
        )

        assert sample_dataset.active_training_task_id is None

    def test_update_training_status_not_found(self, repo, mock_session):
        """Test update_training_status does nothing when dataset not found."""
        mock_session.get.return_value = None

        repo.update_training_status(str(uuid4()), "running")

        mock_session.add.assert_not_called()

    # =========================================================================
    # add_documents() tests
    # =========================================================================

    def test_add_documents_creates_links(self, repo, mock_session):
        """Test add_documents creates dataset document links."""
        documents = [
            {
                "document_id": str(uuid4()),
                "split": "train",
                "page_count": 2,
                "annotation_count": 10,
            },
            {
                "document_id": str(uuid4()),
                "split": "val",
                "page_count": 1,
                "annotation_count": 5,
            },
        ]

        repo.add_documents(str(uuid4()), documents)

        assert mock_session.add.call_count == 2
        mock_session.commit.assert_called_once()

    def test_add_documents_default_counts(self, repo, mock_session):
        """Test add_documents uses default counts."""
        documents = [
            {
                "document_id": str(uuid4()),
                "split": "train",
            },
        ]

        repo.add_documents(str(uuid4()), documents)

        added_doc = mock_session.add.call_args[0][0]
        assert added_doc.page_count == 0
        assert added_doc.annotation_count == 0

    def test_add_documents_with_uuid(self, repo, mock_session):
        """Test add_documents works with UUID object."""
        documents = [
            {
                "document_id": uuid4(),
                "split": "train",
            },
        ]

        repo.add_documents(uuid4(), documents)

        mock_session.add.assert_called_once()

    def test_add_documents_empty_list(self, repo, mock_session):
        """Test add_documents with empty list."""
        repo.add_documents(str(uuid4()), [])

        mock_session.add.assert_not_called()
        mock_session.commit.assert_called_once()

    # =========================================================================
    # get_documents() tests
    # =========================================================================

    def test_get_documents_returns_list(
        self, repo, mock_session, sample_dataset_document
    ):
        """Test get_documents returns list of dataset documents."""
        mock_session.exec.return_value.all.return_value = [
            sample_dataset_document
        ]

        result = repo.get_documents(str(sample_dataset_document.dataset_id))

        assert len(result) == 1
        assert result[0].split == "train"

    def test_get_documents_with_uuid(
        self, repo, mock_session, sample_dataset_document
    ):
        """Test get_documents works with UUID object."""
        mock_session.exec.return_value.all.return_value = [
            sample_dataset_document
        ]

        result = repo.get_documents(sample_dataset_document.dataset_id)

        assert len(result) == 1

    def test_get_documents_returns_empty_list(self, repo, mock_session):
        """Test get_documents returns empty list when no documents."""
        mock_session.exec.return_value.all.return_value = []

        result = repo.get_documents(str(uuid4()))

        assert result == []

    # =========================================================================
    # delete() tests
    # =========================================================================

    def test_delete_returns_true(self, repo, mock_session, sample_dataset):
        """Test delete returns True when dataset exists."""
        mock_session.get.return_value = sample_dataset

        result = repo.delete(str(sample_dataset.dataset_id))

        assert result is True
        mock_session.delete.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_delete_with_uuid(self, repo, mock_session, sample_dataset):
        """Test delete works with UUID object."""
        mock_session.get.return_value = sample_dataset

        result = repo.delete(sample_dataset.dataset_id)

        assert result is True
        mock_session.delete.assert_called_once()

    def test_delete_returns_false_when_not_found(self, repo, mock_session):
        """Test delete returns False when dataset not found."""
        mock_session.get.return_value = None

        result = repo.delete(str(uuid4()))

        assert result is False
        mock_session.delete.assert_not_called()