WIP
This commit is contained in:
386
tests/data/repositories/test_batch_upload_repository.py
Normal file
386
tests/data/repositories/test_batch_upload_repository.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Tests for BatchUploadRepository
|
||||
|
||||
100% coverage tests for batch upload management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import BatchUpload, BatchUploadFile
|
||||
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
|
||||
|
||||
class TestBatchUploadRepository:
|
||||
"""Tests for BatchUploadRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch(self) -> BatchUpload:
|
||||
"""Create a sample batch upload for testing."""
|
||||
return BatchUpload(
|
||||
batch_id=uuid4(),
|
||||
admin_token="admin-token",
|
||||
filename="invoices.zip",
|
||||
file_size=1024000,
|
||||
upload_source="ui",
|
||||
status="pending",
|
||||
total_files=10,
|
||||
processed_files=0,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_file(self) -> BatchUploadFile:
|
||||
"""Create a sample batch upload file for testing."""
|
||||
return BatchUploadFile(
|
||||
file_id=uuid4(),
|
||||
batch_id=uuid4(),
|
||||
filename="invoice_001.pdf",
|
||||
status="pending",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> BatchUploadRepository:
|
||||
"""Create a BatchUploadRepository instance."""
|
||||
return BatchUploadRepository()
|
||||
|
||||
# =========================================================================
|
||||
# create() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_returns_batch(self, repo):
|
||||
"""Test create returns created batch upload."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
admin_token="admin-token",
|
||||
filename="test.zip",
|
||||
file_size=1024,
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_with_upload_source(self, repo):
|
||||
"""Test create with custom upload source."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
admin_token="admin-token",
|
||||
filename="test.zip",
|
||||
file_size=1024,
|
||||
upload_source="api",
|
||||
)
|
||||
|
||||
added_batch = mock_session.add.call_args[0][0]
|
||||
assert added_batch.upload_source == "api"
|
||||
|
||||
def test_create_default_upload_source(self, repo):
|
||||
"""Test create uses default upload source."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
admin_token="admin-token",
|
||||
filename="test.zip",
|
||||
file_size=1024,
|
||||
)
|
||||
|
||||
added_batch = mock_session.add.call_args[0][0]
|
||||
assert added_batch.upload_source == "ui"
|
||||
|
||||
# =========================================================================
|
||||
# get() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_returns_batch(self, repo, sample_batch):
|
||||
"""Test get returns batch when exists."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(sample_batch.batch_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.filename == "invoices.zip"
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when batch not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(uuid4())
|
||||
|
||||
assert result is None
|
||||
mock_session.expunge.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# update() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_updates_batch(self, repo, sample_batch):
|
||||
"""Test update updates batch fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update(
|
||||
sample_batch.batch_id,
|
||||
status="processing",
|
||||
processed_files=5,
|
||||
)
|
||||
|
||||
assert sample_batch.status == "processing"
|
||||
assert sample_batch.processed_files == 5
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_ignores_unknown_fields(self, repo, sample_batch):
|
||||
"""Test update ignores unknown fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update(
|
||||
sample_batch.batch_id,
|
||||
unknown_field="should_be_ignored",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_not_found(self, repo):
|
||||
"""Test update does nothing when batch not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update(uuid4(), status="processing")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
def test_update_multiple_fields(self, repo, sample_batch):
|
||||
"""Test update can update multiple fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update(
|
||||
sample_batch.batch_id,
|
||||
status="completed",
|
||||
processed_files=10,
|
||||
total_files=10,
|
||||
)
|
||||
|
||||
assert sample_batch.status == "completed"
|
||||
assert sample_batch.processed_files == 10
|
||||
assert sample_batch.total_files == 10
|
||||
|
||||
# =========================================================================
|
||||
# create_file() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_file_returns_file(self, repo):
|
||||
"""Test create_file returns created file record."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create_file(
|
||||
batch_id=uuid4(),
|
||||
filename="invoice_001.pdf",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_file_with_kwargs(self, repo):
|
||||
"""Test create_file with additional kwargs."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create_file(
|
||||
batch_id=uuid4(),
|
||||
filename="invoice_001.pdf",
|
||||
status="processing",
|
||||
file_size=1024,
|
||||
)
|
||||
|
||||
added_file = mock_session.add.call_args[0][0]
|
||||
assert added_file.filename == "invoice_001.pdf"
|
||||
|
||||
# =========================================================================
|
||||
# update_file() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_file_updates_file(self, repo, sample_file):
|
||||
"""Test update_file updates file fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file(
|
||||
sample_file.file_id,
|
||||
status="completed",
|
||||
)
|
||||
|
||||
assert sample_file.status == "completed"
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_file_ignores_unknown_fields(self, repo, sample_file):
|
||||
"""Test update_file ignores unknown fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file(
|
||||
sample_file.file_id,
|
||||
unknown_field="should_be_ignored",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_file_not_found(self, repo):
|
||||
"""Test update_file does nothing when file not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file(uuid4(), status="completed")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
def test_update_file_multiple_fields(self, repo, sample_file):
|
||||
"""Test update_file can update multiple fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file(
|
||||
sample_file.file_id,
|
||||
status="failed",
|
||||
)
|
||||
|
||||
assert sample_file.status == "failed"
|
||||
|
||||
# =========================================================================
|
||||
# get_files() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_files_returns_list(self, repo, sample_file):
|
||||
"""Test get_files returns list of files."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_file]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_files(sample_file.batch_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].filename == "invoice_001.pdf"
|
||||
|
||||
def test_get_files_returns_empty_list(self, repo):
|
||||
"""Test get_files returns empty list when no files."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_files(uuid4())
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# get_paginated() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_paginated_returns_batches_and_total(self, repo, sample_batch):
|
||||
"""Test get_paginated returns list of batches and total count."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
batches, total = repo.get_paginated()
|
||||
|
||||
assert len(batches) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_batch):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 100
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
batches, total = repo.get_paginated(limit=25, offset=50)
|
||||
|
||||
assert total == 100
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
batches, total = repo.get_paginated()
|
||||
|
||||
assert batches == []
|
||||
assert total == 0
|
||||
|
||||
def test_get_paginated_with_admin_token(self, repo, sample_batch):
|
||||
"""Test get_paginated with admin_token parameter (deprecated, ignored)."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
batches, total = repo.get_paginated(admin_token="admin-token")
|
||||
|
||||
assert len(batches) == 1
|
||||
Reference in New Issue
Block a user