""" Tests for DatasetRepository 100% coverage tests for dataset management. """ import pytest from datetime import datetime, timezone from unittest.mock import MagicMock, patch from uuid import uuid4, UUID from backend.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask from backend.data.repositories.dataset_repository import DatasetRepository class TestDatasetRepository: """Tests for DatasetRepository.""" @pytest.fixture def sample_dataset(self) -> TrainingDataset: """Create a sample dataset for testing.""" return TrainingDataset( dataset_id=uuid4(), name="Test Dataset", description="A test dataset", status="ready", train_ratio=0.8, val_ratio=0.1, seed=42, total_documents=100, total_images=100, total_annotations=500, created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) @pytest.fixture def sample_dataset_document(self) -> DatasetDocument: """Create a sample dataset document for testing.""" return DatasetDocument( id=uuid4(), dataset_id=uuid4(), document_id=uuid4(), split="train", page_count=2, annotation_count=10, created_at=datetime.now(timezone.utc), ) @pytest.fixture def sample_training_task(self) -> TrainingTask: """Create a sample training task for testing.""" return TrainingTask( task_id=uuid4(), admin_token="admin-token", name="Test Task", status="running", dataset_id=uuid4(), ) @pytest.fixture def repo(self) -> DatasetRepository: """Create a DatasetRepository instance.""" return DatasetRepository() # ========================================================================= # create() tests # ========================================================================= def test_create_returns_dataset(self, repo): """Test create returns created dataset.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.create(name="Test Dataset") mock_session.add.assert_called_once() mock_session.commit.assert_called_once() def test_create_with_all_params(self, repo): """Test create with all parameters.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.create( name="Full Dataset", description="A complete dataset", train_ratio=0.7, val_ratio=0.15, seed=123, ) added_dataset = mock_session.add.call_args[0][0] assert added_dataset.name == "Full Dataset" assert added_dataset.description == "A complete dataset" assert added_dataset.train_ratio == 0.7 assert added_dataset.val_ratio == 0.15 assert added_dataset.seed == 123 def test_create_default_values(self, repo): """Test create uses default values.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.create(name="Minimal Dataset") added_dataset = mock_session.add.call_args[0][0] assert added_dataset.train_ratio == 0.8 assert added_dataset.val_ratio == 0.1 assert added_dataset.seed == 42 # ========================================================================= # get() tests # ========================================================================= def test_get_returns_dataset(self, repo, sample_dataset): """Test get returns dataset when exists.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get(str(sample_dataset.dataset_id)) assert result is not None assert result.name == "Test Dataset" mock_session.expunge.assert_called_once() def test_get_with_uuid(self, repo, sample_dataset): """Test get works with UUID object.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get(sample_dataset.dataset_id) assert result is not None def test_get_returns_none_when_not_found(self, repo): """Test get returns None when dataset not found.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get(str(uuid4())) assert result is None mock_session.expunge.assert_not_called() # ========================================================================= # get_paginated() tests # ========================================================================= def test_get_paginated_returns_datasets_and_total(self, repo, sample_dataset): """Test get_paginated returns list of datasets and total count.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 1 mock_session.exec.return_value.all.return_value = [sample_dataset] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) datasets, total = repo.get_paginated() assert len(datasets) == 1 assert total == 1 def test_get_paginated_with_status_filter(self, repo, sample_dataset): """Test get_paginated filters by status.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 1 mock_session.exec.return_value.all.return_value = [sample_dataset] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) datasets, total = repo.get_paginated(status="ready") assert len(datasets) == 1 def test_get_paginated_with_pagination(self, repo, sample_dataset): """Test get_paginated with limit and offset.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 50 mock_session.exec.return_value.all.return_value = [sample_dataset] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) datasets, total = repo.get_paginated(limit=10, offset=20) assert total == 50 def test_get_paginated_empty_results(self, repo): """Test get_paginated with no results.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.one.return_value = 0 mock_session.exec.return_value.all.return_value = [] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) datasets, total = repo.get_paginated() assert datasets == [] assert total == 0 # ========================================================================= # get_active_training_tasks() tests # ========================================================================= def test_get_active_training_tasks_returns_dict(self, repo, sample_training_task): """Test get_active_training_tasks returns dict of active tasks.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [sample_training_task] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_active_training_tasks([str(sample_training_task.dataset_id)]) assert str(sample_training_task.dataset_id) in result def test_get_active_training_tasks_empty_input(self, repo): """Test get_active_training_tasks with empty input.""" result = repo.get_active_training_tasks([]) assert result == {} def test_get_active_training_tasks_invalid_uuid(self, repo): """Test get_active_training_tasks filters invalid UUIDs.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_active_training_tasks(["invalid-uuid", str(uuid4())]) # Should still query with valid UUID assert result == {} def test_get_active_training_tasks_all_invalid_uuids(self, repo): """Test get_active_training_tasks with all invalid UUIDs.""" result = repo.get_active_training_tasks(["invalid-uuid-1", "invalid-uuid-2"]) assert result == {} # ========================================================================= # update_status() tests # ========================================================================= def test_update_status_updates_dataset(self, repo, sample_dataset): """Test update_status updates dataset status.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status(str(sample_dataset.dataset_id), "training") assert sample_dataset.status == "training" mock_session.commit.assert_called_once() def test_update_status_with_error_message(self, repo, sample_dataset): """Test update_status with error message.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status( str(sample_dataset.dataset_id), "failed", error_message="Training failed", ) assert sample_dataset.error_message == "Training failed" def test_update_status_with_totals(self, repo, sample_dataset): """Test update_status with total counts.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status( str(sample_dataset.dataset_id), "ready", total_documents=200, total_images=200, total_annotations=1000, ) assert sample_dataset.total_documents == 200 assert sample_dataset.total_images == 200 assert sample_dataset.total_annotations == 1000 def test_update_status_with_dataset_path(self, repo, sample_dataset): """Test update_status with dataset path.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status( str(sample_dataset.dataset_id), "ready", dataset_path="/path/to/dataset", ) assert sample_dataset.dataset_path == "/path/to/dataset" def test_update_status_with_uuid(self, repo, sample_dataset): """Test update_status works with UUID object.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status(sample_dataset.dataset_id, "ready") assert sample_dataset.status == "ready" def test_update_status_not_found(self, repo): """Test update_status does nothing when dataset not found.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_status(str(uuid4()), "ready") mock_session.add.assert_not_called() # ========================================================================= # update_training_status() tests # ========================================================================= def test_update_training_status_updates_dataset(self, repo, sample_dataset): """Test update_training_status updates training status.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_training_status(str(sample_dataset.dataset_id), "running") assert sample_dataset.training_status == "running" mock_session.commit.assert_called_once() def test_update_training_status_with_task_id(self, repo, sample_dataset): """Test update_training_status with active task ID.""" task_id = uuid4() with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_training_status( str(sample_dataset.dataset_id), "running", active_training_task_id=str(task_id), ) assert sample_dataset.active_training_task_id == task_id def test_update_training_status_updates_main_status(self, repo, sample_dataset): """Test update_training_status updates main status when completed.""" sample_dataset.status = "ready" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_training_status( str(sample_dataset.dataset_id), "completed", update_main_status=True, ) assert sample_dataset.training_status == "completed" assert sample_dataset.status == "trained" def test_update_training_status_clears_task_id(self, repo, sample_dataset): """Test update_training_status clears task ID when None.""" sample_dataset.active_training_task_id = uuid4() with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_training_status( str(sample_dataset.dataset_id), None, active_training_task_id=None, ) assert sample_dataset.active_training_task_id is None def test_update_training_status_not_found(self, repo): """Test update_training_status does nothing when dataset not found.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.update_training_status(str(uuid4()), "running") mock_session.add.assert_not_called() # ========================================================================= # add_documents() tests # ========================================================================= def test_add_documents_creates_links(self, repo): """Test add_documents creates dataset document links.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) documents = [ { "document_id": str(uuid4()), "split": "train", "page_count": 2, "annotation_count": 10, }, { "document_id": str(uuid4()), "split": "val", "page_count": 1, "annotation_count": 5, }, ] repo.add_documents(str(uuid4()), documents) assert mock_session.add.call_count == 2 mock_session.commit.assert_called_once() def test_add_documents_default_counts(self, repo): """Test add_documents uses default counts.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) documents = [ { "document_id": str(uuid4()), "split": "train", }, ] repo.add_documents(str(uuid4()), documents) added_doc = mock_session.add.call_args[0][0] assert added_doc.page_count == 0 assert added_doc.annotation_count == 0 def test_add_documents_with_uuid(self, repo): """Test add_documents works with UUID object.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) documents = [ { "document_id": uuid4(), "split": "train", }, ] repo.add_documents(uuid4(), documents) mock_session.add.assert_called_once() def test_add_documents_empty_list(self, repo): """Test add_documents with empty list.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) repo.add_documents(str(uuid4()), []) mock_session.add.assert_not_called() mock_session.commit.assert_called_once() # ========================================================================= # get_documents() tests # ========================================================================= def test_get_documents_returns_list(self, repo, sample_dataset_document): """Test get_documents returns list of dataset documents.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [sample_dataset_document] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_documents(str(sample_dataset_document.dataset_id)) assert len(result) == 1 assert result[0].split == "train" def test_get_documents_with_uuid(self, repo, sample_dataset_document): """Test get_documents works with UUID object.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [sample_dataset_document] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_documents(sample_dataset_document.dataset_id) assert len(result) == 1 def test_get_documents_returns_empty_list(self, repo): """Test get_documents returns empty list when no documents.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.exec.return_value.all.return_value = [] mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.get_documents(str(uuid4())) assert result == [] # ========================================================================= # delete() tests # ========================================================================= def test_delete_returns_true(self, repo, sample_dataset): """Test delete returns True when dataset exists.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.delete(str(sample_dataset.dataset_id)) assert result is True mock_session.delete.assert_called_once() mock_session.commit.assert_called_once() def test_delete_with_uuid(self, repo, sample_dataset): """Test delete works with UUID object.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = sample_dataset mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.delete(sample_dataset.dataset_id) assert result is True def test_delete_returns_false_when_not_found(self, repo): """Test delete returns False when dataset not found.""" with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx: mock_session = MagicMock() mock_session.get.return_value = None mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) result = repo.delete(str(uuid4())) assert result is False mock_session.delete.assert_not_called()