WIP
This commit is contained in:
597
tests/data/repositories/test_dataset_repository.py
Normal file
597
tests/data/repositories/test_dataset_repository.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""
|
||||
Tests for DatasetRepository
|
||||
|
||||
100% coverage tests for dataset management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
|
||||
from inference.data.repositories.dataset_repository import DatasetRepository
|
||||
|
||||
|
||||
class TestDatasetRepository:
|
||||
"""Tests for DatasetRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset(self) -> TrainingDataset:
|
||||
"""Create a sample dataset for testing."""
|
||||
return TrainingDataset(
|
||||
dataset_id=uuid4(),
|
||||
name="Test Dataset",
|
||||
description="A test dataset",
|
||||
status="ready",
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
total_documents=100,
|
||||
total_images=100,
|
||||
total_annotations=500,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset_document(self) -> DatasetDocument:
|
||||
"""Create a sample dataset document for testing."""
|
||||
return DatasetDocument(
|
||||
id=uuid4(),
|
||||
dataset_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
split="train",
|
||||
page_count=2,
|
||||
annotation_count=10,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_training_task(self) -> TrainingTask:
|
||||
"""Create a sample training task for testing."""
|
||||
return TrainingTask(
|
||||
task_id=uuid4(),
|
||||
admin_token="admin-token",
|
||||
name="Test Task",
|
||||
status="running",
|
||||
dataset_id=uuid4(),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> DatasetRepository:
|
||||
"""Create a DatasetRepository instance."""
|
||||
return DatasetRepository()
|
||||
|
||||
# =========================================================================
|
||||
# create() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_returns_dataset(self, repo):
|
||||
"""Test create returns created dataset."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(name="Test Dataset")
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_with_all_params(self, repo):
|
||||
"""Test create with all parameters."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
name="Full Dataset",
|
||||
description="A complete dataset",
|
||||
train_ratio=0.7,
|
||||
val_ratio=0.15,
|
||||
seed=123,
|
||||
)
|
||||
|
||||
added_dataset = mock_session.add.call_args[0][0]
|
||||
assert added_dataset.name == "Full Dataset"
|
||||
assert added_dataset.description == "A complete dataset"
|
||||
assert added_dataset.train_ratio == 0.7
|
||||
assert added_dataset.val_ratio == 0.15
|
||||
assert added_dataset.seed == 123
|
||||
|
||||
def test_create_default_values(self, repo):
|
||||
"""Test create uses default values."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(name="Minimal Dataset")
|
||||
|
||||
added_dataset = mock_session.add.call_args[0][0]
|
||||
assert added_dataset.train_ratio == 0.8
|
||||
assert added_dataset.val_ratio == 0.1
|
||||
assert added_dataset.seed == 42
|
||||
|
||||
# =========================================================================
|
||||
# get() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_returns_dataset(self, repo, sample_dataset):
|
||||
"""Test get returns dataset when exists."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(sample_dataset.dataset_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.name == "Test Dataset"
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_with_uuid(self, repo, sample_dataset):
|
||||
"""Test get works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(sample_dataset.dataset_id)
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
mock_session.expunge.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# get_paginated() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_paginated_returns_datasets_and_total(self, repo, sample_dataset):
|
||||
"""Test get_paginated returns list of datasets and total count."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
datasets, total = repo.get_paginated()
|
||||
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_dataset):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
datasets, total = repo.get_paginated(status="ready")
|
||||
|
||||
assert len(datasets) == 1
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_dataset):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
datasets, total = repo.get_paginated(limit=10, offset=20)
|
||||
|
||||
assert total == 50
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
datasets, total = repo.get_paginated()
|
||||
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
|
||||
# =========================================================================
|
||||
# get_active_training_tasks() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_active_training_tasks_returns_dict(self, repo, sample_training_task):
|
||||
"""Test get_active_training_tasks returns dict of active tasks."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_training_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_active_training_tasks([str(sample_training_task.dataset_id)])
|
||||
|
||||
assert str(sample_training_task.dataset_id) in result
|
||||
|
||||
def test_get_active_training_tasks_empty_input(self, repo):
|
||||
"""Test get_active_training_tasks with empty input."""
|
||||
result = repo.get_active_training_tasks([])
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_get_active_training_tasks_invalid_uuid(self, repo):
|
||||
"""Test get_active_training_tasks filters invalid UUIDs."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_active_training_tasks(["invalid-uuid", str(uuid4())])
|
||||
|
||||
# Should still query with valid UUID
|
||||
assert result == {}
|
||||
|
||||
def test_get_active_training_tasks_all_invalid_uuids(self, repo):
|
||||
"""Test get_active_training_tasks with all invalid UUIDs."""
|
||||
result = repo.get_active_training_tasks(["invalid-uuid-1", "invalid-uuid-2"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
# =========================================================================
|
||||
# update_status() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_status_updates_dataset(self, repo, sample_dataset):
|
||||
"""Test update_status updates dataset status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(sample_dataset.dataset_id), "training")
|
||||
|
||||
assert sample_dataset.status == "training"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_update_status_with_error_message(self, repo, sample_dataset):
|
||||
"""Test update_status with error message."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"failed",
|
||||
error_message="Training failed",
|
||||
)
|
||||
|
||||
assert sample_dataset.error_message == "Training failed"
|
||||
|
||||
def test_update_status_with_totals(self, repo, sample_dataset):
|
||||
"""Test update_status with total counts."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"ready",
|
||||
total_documents=200,
|
||||
total_images=200,
|
||||
total_annotations=1000,
|
||||
)
|
||||
|
||||
assert sample_dataset.total_documents == 200
|
||||
assert sample_dataset.total_images == 200
|
||||
assert sample_dataset.total_annotations == 1000
|
||||
|
||||
def test_update_status_with_dataset_path(self, repo, sample_dataset):
|
||||
"""Test update_status with dataset path."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"ready",
|
||||
dataset_path="/path/to/dataset",
|
||||
)
|
||||
|
||||
assert sample_dataset.dataset_path == "/path/to/dataset"
|
||||
|
||||
def test_update_status_with_uuid(self, repo, sample_dataset):
|
||||
"""Test update_status works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(sample_dataset.dataset_id, "ready")
|
||||
|
||||
assert sample_dataset.status == "ready"
|
||||
|
||||
def test_update_status_not_found(self, repo):
|
||||
"""Test update_status does nothing when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(uuid4()), "ready")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# update_training_status() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_training_status_updates_dataset(self, repo, sample_dataset):
|
||||
"""Test update_training_status updates training status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(str(sample_dataset.dataset_id), "running")
|
||||
|
||||
assert sample_dataset.training_status == "running"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_update_training_status_with_task_id(self, repo, sample_dataset):
|
||||
"""Test update_training_status with active task ID."""
|
||||
task_id = uuid4()
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"running",
|
||||
active_training_task_id=str(task_id),
|
||||
)
|
||||
|
||||
assert sample_dataset.active_training_task_id == task_id
|
||||
|
||||
def test_update_training_status_updates_main_status(self, repo, sample_dataset):
|
||||
"""Test update_training_status updates main status when completed."""
|
||||
sample_dataset.status = "ready"
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"completed",
|
||||
update_main_status=True,
|
||||
)
|
||||
|
||||
assert sample_dataset.training_status == "completed"
|
||||
assert sample_dataset.status == "trained"
|
||||
|
||||
def test_update_training_status_clears_task_id(self, repo, sample_dataset):
|
||||
"""Test update_training_status clears task ID when None."""
|
||||
sample_dataset.active_training_task_id = uuid4()
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
None,
|
||||
active_training_task_id=None,
|
||||
)
|
||||
|
||||
assert sample_dataset.active_training_task_id is None
|
||||
|
||||
def test_update_training_status_not_found(self, repo):
|
||||
"""Test update_training_status does nothing when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(str(uuid4()), "running")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# add_documents() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_add_documents_creates_links(self, repo):
|
||||
"""Test add_documents creates dataset document links."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
documents = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"split": "train",
|
||||
"page_count": 2,
|
||||
"annotation_count": 10,
|
||||
},
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"split": "val",
|
||||
"page_count": 1,
|
||||
"annotation_count": 5,
|
||||
},
|
||||
]
|
||||
|
||||
repo.add_documents(str(uuid4()), documents)
|
||||
|
||||
assert mock_session.add.call_count == 2
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_add_documents_default_counts(self, repo):
|
||||
"""Test add_documents uses default counts."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
documents = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"split": "train",
|
||||
},
|
||||
]
|
||||
|
||||
repo.add_documents(str(uuid4()), documents)
|
||||
|
||||
added_doc = mock_session.add.call_args[0][0]
|
||||
assert added_doc.page_count == 0
|
||||
assert added_doc.annotation_count == 0
|
||||
|
||||
def test_add_documents_with_uuid(self, repo):
|
||||
"""Test add_documents works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
documents = [
|
||||
{
|
||||
"document_id": uuid4(),
|
||||
"split": "train",
|
||||
},
|
||||
]
|
||||
|
||||
repo.add_documents(uuid4(), documents)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_add_documents_empty_list(self, repo):
|
||||
"""Test add_documents with empty list."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.add_documents(str(uuid4()), [])
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# =========================================================================
|
||||
# get_documents() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_documents_returns_list(self, repo, sample_dataset_document):
|
||||
"""Test get_documents returns list of dataset documents."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_documents(str(sample_dataset_document.dataset_id))
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].split == "train"
|
||||
|
||||
def test_get_documents_with_uuid(self, repo, sample_dataset_document):
|
||||
"""Test get_documents works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_documents(sample_dataset_document.dataset_id)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_documents_returns_empty_list(self, repo):
|
||||
"""Test get_documents returns empty list when no documents."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_documents(str(uuid4()))
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# delete() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_delete_returns_true(self, repo, sample_dataset):
|
||||
"""Test delete returns True when dataset exists."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(sample_dataset.dataset_id))
|
||||
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_with_uuid(self, repo, sample_dataset):
|
||||
"""Test delete works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(sample_dataset.dataset_id)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(uuid4()))
|
||||
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
Reference in New Issue
Block a user