WIP
This commit is contained in:
1
tests/data/repositories/__init__.py
Normal file
1
tests/data/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for repository pattern implementation."""
|
||||
711
tests/data/repositories/test_annotation_repository.py
Normal file
711
tests/data/repositories/test_annotation_repository.py
Normal file
@@ -0,0 +1,711 @@
|
||||
"""
|
||||
Tests for AnnotationRepository
|
||||
|
||||
100% coverage tests for annotation management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
|
||||
from inference.data.repositories.annotation_repository import AnnotationRepository
|
||||
|
||||
|
||||
class TestAnnotationRepository:
|
||||
"""Tests for AnnotationRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_annotation(self) -> AdminAnnotation:
|
||||
"""Create a sample annotation for testing."""
|
||||
return AdminAnnotation(
|
||||
annotation_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.3,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=100,
|
||||
bbox_y=200,
|
||||
bbox_width=150,
|
||||
bbox_height=30,
|
||||
text_value="INV-001",
|
||||
confidence=0.95,
|
||||
source="auto",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_history(self) -> AnnotationHistory:
|
||||
"""Create a sample annotation history for testing."""
|
||||
return AnnotationHistory(
|
||||
history_id=uuid4(),
|
||||
annotation_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
action="override",
|
||||
previous_value={"class_name": "old_class"},
|
||||
new_value={"class_name": "new_class"},
|
||||
changed_by="admin-token",
|
||||
change_reason="Correction",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> AnnotationRepository:
|
||||
"""Create an AnnotationRepository instance."""
|
||||
return AnnotationRepository()
|
||||
|
||||
# =========================================================================
|
||||
# create() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_returns_annotation_id(self, repo):
|
||||
"""Test create returns annotation ID."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
document_id=str(uuid4()),
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.3,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=100,
|
||||
bbox_y=200,
|
||||
bbox_width=150,
|
||||
bbox_height=30,
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
def test_create_with_optional_params(self, repo):
|
||||
"""Test create with optional text_value and confidence."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
document_id=str(uuid4()),
|
||||
page_number=2,
|
||||
class_id=1,
|
||||
class_name="invoice_date",
|
||||
x_center=0.6,
|
||||
y_center=0.4,
|
||||
width=0.15,
|
||||
height=0.04,
|
||||
bbox_x=200,
|
||||
bbox_y=300,
|
||||
bbox_width=100,
|
||||
bbox_height=25,
|
||||
text_value="2024-01-15",
|
||||
confidence=0.88,
|
||||
source="auto",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
mock_session.add.assert_called_once()
|
||||
added_annotation = mock_session.add.call_args[0][0]
|
||||
assert added_annotation.text_value == "2024-01-15"
|
||||
assert added_annotation.confidence == 0.88
|
||||
assert added_annotation.source == "auto"
|
||||
|
||||
def test_create_default_source_is_manual(self, repo):
|
||||
"""Test create uses manual as default source."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
document_id=str(uuid4()),
|
||||
page_number=1,
|
||||
class_id=0,
|
||||
class_name="invoice_number",
|
||||
x_center=0.5,
|
||||
y_center=0.3,
|
||||
width=0.2,
|
||||
height=0.05,
|
||||
bbox_x=100,
|
||||
bbox_y=200,
|
||||
bbox_width=150,
|
||||
bbox_height=30,
|
||||
)
|
||||
|
||||
added_annotation = mock_session.add.call_args[0][0]
|
||||
assert added_annotation.source == "manual"
|
||||
|
||||
# =========================================================================
|
||||
# create_batch() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_batch_returns_ids(self, repo):
|
||||
"""Test create_batch returns list of annotation IDs."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
annotations = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"class_id": 0,
|
||||
"class_name": "invoice_number",
|
||||
"x_center": 0.5,
|
||||
"y_center": 0.3,
|
||||
"width": 0.2,
|
||||
"height": 0.05,
|
||||
"bbox_x": 100,
|
||||
"bbox_y": 200,
|
||||
"bbox_width": 150,
|
||||
"bbox_height": 30,
|
||||
},
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"class_id": 1,
|
||||
"class_name": "invoice_date",
|
||||
"x_center": 0.6,
|
||||
"y_center": 0.4,
|
||||
"width": 0.15,
|
||||
"height": 0.04,
|
||||
"bbox_x": 200,
|
||||
"bbox_y": 300,
|
||||
"bbox_width": 100,
|
||||
"bbox_height": 25,
|
||||
},
|
||||
]
|
||||
|
||||
result = repo.create_batch(annotations)
|
||||
|
||||
assert len(result) == 2
|
||||
assert mock_session.add.call_count == 2
|
||||
assert mock_session.flush.call_count == 2
|
||||
|
||||
def test_create_batch_default_page_number(self, repo):
|
||||
"""Test create_batch uses page_number=1 by default."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
annotations = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"class_id": 0,
|
||||
"class_name": "invoice_number",
|
||||
"x_center": 0.5,
|
||||
"y_center": 0.3,
|
||||
"width": 0.2,
|
||||
"height": 0.05,
|
||||
"bbox_x": 100,
|
||||
"bbox_y": 200,
|
||||
"bbox_width": 150,
|
||||
"bbox_height": 30,
|
||||
# no page_number
|
||||
},
|
||||
]
|
||||
|
||||
repo.create_batch(annotations)
|
||||
|
||||
added_annotation = mock_session.add.call_args[0][0]
|
||||
assert added_annotation.page_number == 1
|
||||
|
||||
def test_create_batch_with_all_optional_params(self, repo):
|
||||
"""Test create_batch with all optional parameters."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
annotations = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"page_number": 3,
|
||||
"class_id": 0,
|
||||
"class_name": "invoice_number",
|
||||
"x_center": 0.5,
|
||||
"y_center": 0.3,
|
||||
"width": 0.2,
|
||||
"height": 0.05,
|
||||
"bbox_x": 100,
|
||||
"bbox_y": 200,
|
||||
"bbox_width": 150,
|
||||
"bbox_height": 30,
|
||||
"text_value": "INV-123",
|
||||
"confidence": 0.92,
|
||||
"source": "ocr",
|
||||
},
|
||||
]
|
||||
|
||||
repo.create_batch(annotations)
|
||||
|
||||
added_annotation = mock_session.add.call_args[0][0]
|
||||
assert added_annotation.page_number == 3
|
||||
assert added_annotation.text_value == "INV-123"
|
||||
assert added_annotation.confidence == 0.92
|
||||
assert added_annotation.source == "ocr"
|
||||
|
||||
def test_create_batch_empty_list(self, repo):
|
||||
"""Test create_batch with empty list returns empty."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create_batch([])
|
||||
|
||||
assert result == []
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# get() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_returns_annotation(self, repo, sample_annotation):
|
||||
"""Test get returns annotation when exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(sample_annotation.annotation_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.class_name == "invoice_number"
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
mock_session.expunge.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# get_for_document() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_for_document_returns_all_annotations(self, repo, sample_annotation):
|
||||
"""Test get_for_document returns all annotations for document."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_for_document(str(sample_annotation.document_id))
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].class_name == "invoice_number"
|
||||
|
||||
def test_get_for_document_with_page_filter(self, repo, sample_annotation):
|
||||
"""Test get_for_document filters by page number."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_for_document(str(sample_annotation.document_id), page_number=1)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_for_document_returns_empty_list(self, repo):
|
||||
"""Test get_for_document returns empty list when no annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_for_document(str(uuid4()))
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# update() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_returns_true(self, repo, sample_annotation):
|
||||
"""Test update returns True when annotation exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(
|
||||
str(sample_annotation.annotation_id),
|
||||
text_value="INV-002",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert sample_annotation.text_value == "INV-002"
|
||||
|
||||
def test_update_returns_false_when_not_found(self, repo):
|
||||
"""Test update returns False when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(str(uuid4()), text_value="INV-002")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_update_all_fields(self, repo, sample_annotation):
|
||||
"""Test update can update all fields."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(
|
||||
str(sample_annotation.annotation_id),
|
||||
x_center=0.6,
|
||||
y_center=0.4,
|
||||
width=0.25,
|
||||
height=0.06,
|
||||
bbox_x=150,
|
||||
bbox_y=250,
|
||||
bbox_width=175,
|
||||
bbox_height=35,
|
||||
text_value="NEW-VALUE",
|
||||
class_id=5,
|
||||
class_name="new_class",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert sample_annotation.x_center == 0.6
|
||||
assert sample_annotation.y_center == 0.4
|
||||
assert sample_annotation.width == 0.25
|
||||
assert sample_annotation.height == 0.06
|
||||
assert sample_annotation.bbox_x == 150
|
||||
assert sample_annotation.bbox_y == 250
|
||||
assert sample_annotation.bbox_width == 175
|
||||
assert sample_annotation.bbox_height == 35
|
||||
assert sample_annotation.text_value == "NEW-VALUE"
|
||||
assert sample_annotation.class_id == 5
|
||||
assert sample_annotation.class_name == "new_class"
|
||||
|
||||
def test_update_partial_fields(self, repo, sample_annotation):
|
||||
"""Test update only updates provided fields."""
|
||||
original_x = sample_annotation.x_center
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(
|
||||
str(sample_annotation.annotation_id),
|
||||
text_value="UPDATED",
|
||||
)
|
||||
|
||||
assert result is True
|
||||
assert sample_annotation.text_value == "UPDATED"
|
||||
assert sample_annotation.x_center == original_x # unchanged
|
||||
|
||||
# =========================================================================
|
||||
# delete() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_delete_returns_true(self, repo, sample_annotation):
|
||||
"""Test delete returns True when annotation exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(sample_annotation.annotation_id))
|
||||
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once()
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(uuid4()))
|
||||
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# delete_for_document() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_delete_for_document_returns_count(self, repo, sample_annotation):
|
||||
"""Test delete_for_document returns count of deleted annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete_for_document(str(sample_annotation.document_id))
|
||||
|
||||
assert result == 1
|
||||
mock_session.delete.assert_called_once()
|
||||
|
||||
def test_delete_for_document_with_source_filter(self, repo, sample_annotation):
|
||||
"""Test delete_for_document filters by source."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete_for_document(str(sample_annotation.document_id), source="auto")
|
||||
|
||||
assert result == 1
|
||||
|
||||
def test_delete_for_document_returns_zero(self, repo):
|
||||
"""Test delete_for_document returns 0 when no annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete_for_document(str(uuid4()))
|
||||
|
||||
assert result == 0
|
||||
mock_session.delete.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# verify() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_verify_marks_annotation_verified(self, repo, sample_annotation):
|
||||
"""Test verify marks annotation as verified."""
|
||||
sample_annotation.is_verified = False
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.verify(str(sample_annotation.annotation_id), "admin-token")
|
||||
|
||||
assert result is not None
|
||||
assert sample_annotation.is_verified is True
|
||||
assert sample_annotation.verified_by == "admin-token"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_verify_returns_none_when_not_found(self, repo):
|
||||
"""Test verify returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.verify(str(uuid4()), "admin-token")
|
||||
|
||||
assert result is None
|
||||
|
||||
# =========================================================================
|
||||
# override() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_override_updates_annotation(self, repo, sample_annotation):
|
||||
"""Test override updates annotation and creates history."""
|
||||
sample_annotation.source = "auto"
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.override(
|
||||
str(sample_annotation.annotation_id),
|
||||
"admin-token",
|
||||
change_reason="Correction",
|
||||
text_value="NEW-VALUE",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert sample_annotation.text_value == "NEW-VALUE"
|
||||
assert sample_annotation.source == "manual"
|
||||
assert sample_annotation.override_source == "auto"
|
||||
assert mock_session.add.call_count >= 2 # annotation + history
|
||||
|
||||
def test_override_returns_none_when_not_found(self, repo):
|
||||
"""Test override returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.override(str(uuid4()), "admin-token", text_value="NEW")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_override_does_not_change_source_if_already_manual(self, repo, sample_annotation):
|
||||
"""Test override does not change override_source if already manual."""
|
||||
sample_annotation.source = "manual"
|
||||
sample_annotation.override_source = None
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.override(
|
||||
str(sample_annotation.annotation_id),
|
||||
"admin-token",
|
||||
text_value="NEW-VALUE",
|
||||
)
|
||||
|
||||
assert sample_annotation.source == "manual"
|
||||
assert sample_annotation.override_source is None
|
||||
|
||||
def test_override_skips_unknown_attributes(self, repo, sample_annotation):
|
||||
"""Test override ignores unknown attributes."""
|
||||
sample_annotation.source = "auto"
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.override(
|
||||
str(sample_annotation.annotation_id),
|
||||
"admin-token",
|
||||
unknown_field="should_be_ignored",
|
||||
text_value="VALID",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert sample_annotation.text_value == "VALID"
|
||||
assert not hasattr(sample_annotation, "unknown_field") or getattr(sample_annotation, "unknown_field", None) != "should_be_ignored"
|
||||
|
||||
# =========================================================================
|
||||
# create_history() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_history_returns_history(self, repo):
|
||||
"""Test create_history returns created history record."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
annotation_id = uuid4()
|
||||
document_id = uuid4()
|
||||
result = repo.create_history(
|
||||
annotation_id=annotation_id,
|
||||
document_id=document_id,
|
||||
action="create",
|
||||
previous_value=None,
|
||||
new_value={"class_name": "invoice_number"},
|
||||
changed_by="admin-token",
|
||||
change_reason="Initial creation",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_history_with_minimal_params(self, repo):
|
||||
"""Test create_history with minimal parameters."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create_history(
|
||||
annotation_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
action="delete",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
added_history = mock_session.add.call_args[0][0]
|
||||
assert added_history.action == "delete"
|
||||
assert added_history.previous_value is None
|
||||
assert added_history.new_value is None
|
||||
|
||||
# =========================================================================
|
||||
# get_history() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_history_returns_list(self, repo, sample_history):
|
||||
"""Test get_history returns list of history records."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_history]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_history(sample_history.annotation_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].action == "override"
|
||||
|
||||
def test_get_history_returns_empty_list(self, repo):
|
||||
"""Test get_history returns empty list when no history."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_history(uuid4())
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# get_document_history() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_document_history_returns_list(self, repo, sample_history):
|
||||
"""Test get_document_history returns list of history records."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_history]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_document_history(sample_history.document_id)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_document_history_returns_empty_list(self, repo):
|
||||
"""Test get_document_history returns empty list when no history."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_document_history(uuid4())
|
||||
|
||||
assert result == []
|
||||
142
tests/data/repositories/test_base_repository.py
Normal file
142
tests/data/repositories/test_base_repository.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Tests for BaseRepository
|
||||
|
||||
100% coverage tests for base repository utilities.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class ConcreteRepository(BaseRepository[MagicMock]):
|
||||
"""Concrete implementation for testing abstract base class."""
|
||||
pass
|
||||
|
||||
|
||||
class TestBaseRepository:
|
||||
"""Tests for BaseRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> ConcreteRepository:
|
||||
"""Create a ConcreteRepository instance."""
|
||||
return ConcreteRepository()
|
||||
|
||||
# =========================================================================
|
||||
# _session() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_session_yields_session(self, repo):
|
||||
"""Test _session yields a database session."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with repo._session() as session:
|
||||
assert session is mock_session
|
||||
|
||||
# =========================================================================
|
||||
# _expunge() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_expunge_detaches_entity(self, repo):
|
||||
"""Test _expunge detaches entity from session."""
|
||||
mock_session = MagicMock()
|
||||
mock_entity = MagicMock()
|
||||
|
||||
result = repo._expunge(mock_session, mock_entity)
|
||||
|
||||
mock_session.expunge.assert_called_once_with(mock_entity)
|
||||
assert result is mock_entity
|
||||
|
||||
# =========================================================================
|
||||
# _expunge_all() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_expunge_all_detaches_all_entities(self, repo):
|
||||
"""Test _expunge_all detaches all entities from session."""
|
||||
mock_session = MagicMock()
|
||||
mock_entity1 = MagicMock()
|
||||
mock_entity2 = MagicMock()
|
||||
entities = [mock_entity1, mock_entity2]
|
||||
|
||||
result = repo._expunge_all(mock_session, entities)
|
||||
|
||||
assert mock_session.expunge.call_count == 2
|
||||
mock_session.expunge.assert_any_call(mock_entity1)
|
||||
mock_session.expunge.assert_any_call(mock_entity2)
|
||||
assert result is entities
|
||||
|
||||
def test_expunge_all_empty_list(self, repo):
|
||||
"""Test _expunge_all with empty list."""
|
||||
mock_session = MagicMock()
|
||||
entities = []
|
||||
|
||||
result = repo._expunge_all(mock_session, entities)
|
||||
|
||||
mock_session.expunge.assert_not_called()
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# _now() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_now_returns_utc_datetime(self, repo):
|
||||
"""Test _now returns timezone-aware UTC datetime."""
|
||||
result = repo._now()
|
||||
|
||||
assert result.tzinfo == timezone.utc
|
||||
assert isinstance(result, datetime)
|
||||
|
||||
def test_now_is_recent(self, repo):
|
||||
"""Test _now returns a recent datetime."""
|
||||
before = datetime.now(timezone.utc)
|
||||
result = repo._now()
|
||||
after = datetime.now(timezone.utc)
|
||||
|
||||
assert before <= result <= after
|
||||
|
||||
# =========================================================================
|
||||
# _validate_uuid() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_validate_uuid_with_valid_string(self, repo):
|
||||
"""Test _validate_uuid with valid UUID string."""
|
||||
valid_uuid_str = str(uuid4())
|
||||
|
||||
result = repo._validate_uuid(valid_uuid_str)
|
||||
|
||||
assert isinstance(result, UUID)
|
||||
assert str(result) == valid_uuid_str
|
||||
|
||||
def test_validate_uuid_with_invalid_string(self, repo):
|
||||
"""Test _validate_uuid raises ValueError for invalid UUID."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
repo._validate_uuid("not-a-valid-uuid")
|
||||
|
||||
assert "Invalid id" in str(exc_info.value)
|
||||
|
||||
def test_validate_uuid_with_custom_field_name(self, repo):
|
||||
"""Test _validate_uuid uses custom field name in error."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
repo._validate_uuid("invalid", field_name="document_id")
|
||||
|
||||
assert "Invalid document_id" in str(exc_info.value)
|
||||
|
||||
def test_validate_uuid_with_none(self, repo):
|
||||
"""Test _validate_uuid raises ValueError for None."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
repo._validate_uuid(None)
|
||||
|
||||
assert "Invalid id" in str(exc_info.value)
|
||||
|
||||
def test_validate_uuid_with_empty_string(self, repo):
|
||||
"""Test _validate_uuid raises ValueError for empty string."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
repo._validate_uuid("")
|
||||
|
||||
assert "Invalid id" in str(exc_info.value)
|
||||
386
tests/data/repositories/test_batch_upload_repository.py
Normal file
386
tests/data/repositories/test_batch_upload_repository.py
Normal file
@@ -0,0 +1,386 @@
|
||||
"""
|
||||
Tests for BatchUploadRepository
|
||||
|
||||
100% coverage tests for batch upload management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import BatchUpload, BatchUploadFile
|
||||
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
|
||||
|
||||
class TestBatchUploadRepository:
|
||||
"""Tests for BatchUploadRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch(self) -> BatchUpload:
|
||||
"""Create a sample batch upload for testing."""
|
||||
return BatchUpload(
|
||||
batch_id=uuid4(),
|
||||
admin_token="admin-token",
|
||||
filename="invoices.zip",
|
||||
file_size=1024000,
|
||||
upload_source="ui",
|
||||
status="pending",
|
||||
total_files=10,
|
||||
processed_files=0,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_file(self) -> BatchUploadFile:
|
||||
"""Create a sample batch upload file for testing."""
|
||||
return BatchUploadFile(
|
||||
file_id=uuid4(),
|
||||
batch_id=uuid4(),
|
||||
filename="invoice_001.pdf",
|
||||
status="pending",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> BatchUploadRepository:
|
||||
"""Create a BatchUploadRepository instance."""
|
||||
return BatchUploadRepository()
|
||||
|
||||
# =========================================================================
|
||||
# create() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_returns_batch(self, repo):
|
||||
"""Test create returns created batch upload."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
admin_token="admin-token",
|
||||
filename="test.zip",
|
||||
file_size=1024,
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_with_upload_source(self, repo):
|
||||
"""Test create with custom upload source."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
admin_token="admin-token",
|
||||
filename="test.zip",
|
||||
file_size=1024,
|
||||
upload_source="api",
|
||||
)
|
||||
|
||||
added_batch = mock_session.add.call_args[0][0]
|
||||
assert added_batch.upload_source == "api"
|
||||
|
||||
def test_create_default_upload_source(self, repo):
|
||||
"""Test create uses default upload source."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
admin_token="admin-token",
|
||||
filename="test.zip",
|
||||
file_size=1024,
|
||||
)
|
||||
|
||||
added_batch = mock_session.add.call_args[0][0]
|
||||
assert added_batch.upload_source == "ui"
|
||||
|
||||
# =========================================================================
|
||||
# get() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_returns_batch(self, repo, sample_batch):
|
||||
"""Test get returns batch when exists."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(sample_batch.batch_id)
|
||||
|
||||
assert result is not None
|
||||
assert result.filename == "invoices.zip"
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when batch not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(uuid4())
|
||||
|
||||
assert result is None
|
||||
mock_session.expunge.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# update() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_updates_batch(self, repo, sample_batch):
|
||||
"""Test update updates batch fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update(
|
||||
sample_batch.batch_id,
|
||||
status="processing",
|
||||
processed_files=5,
|
||||
)
|
||||
|
||||
assert sample_batch.status == "processing"
|
||||
assert sample_batch.processed_files == 5
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_ignores_unknown_fields(self, repo, sample_batch):
|
||||
"""Test update ignores unknown fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update(
|
||||
sample_batch.batch_id,
|
||||
unknown_field="should_be_ignored",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_not_found(self, repo):
|
||||
"""Test update does nothing when batch not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update(uuid4(), status="processing")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
def test_update_multiple_fields(self, repo, sample_batch):
|
||||
"""Test update can update multiple fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update(
|
||||
sample_batch.batch_id,
|
||||
status="completed",
|
||||
processed_files=10,
|
||||
total_files=10,
|
||||
)
|
||||
|
||||
assert sample_batch.status == "completed"
|
||||
assert sample_batch.processed_files == 10
|
||||
assert sample_batch.total_files == 10
|
||||
|
||||
# =========================================================================
|
||||
# create_file() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_file_returns_file(self, repo):
|
||||
"""Test create_file returns created file record."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create_file(
|
||||
batch_id=uuid4(),
|
||||
filename="invoice_001.pdf",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_file_with_kwargs(self, repo):
|
||||
"""Test create_file with additional kwargs."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create_file(
|
||||
batch_id=uuid4(),
|
||||
filename="invoice_001.pdf",
|
||||
status="processing",
|
||||
file_size=1024,
|
||||
)
|
||||
|
||||
added_file = mock_session.add.call_args[0][0]
|
||||
assert added_file.filename == "invoice_001.pdf"
|
||||
|
||||
# =========================================================================
|
||||
# update_file() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_file_updates_file(self, repo, sample_file):
|
||||
"""Test update_file updates file fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file(
|
||||
sample_file.file_id,
|
||||
status="completed",
|
||||
)
|
||||
|
||||
assert sample_file.status == "completed"
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_file_ignores_unknown_fields(self, repo, sample_file):
|
||||
"""Test update_file ignores unknown fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file(
|
||||
sample_file.file_id,
|
||||
unknown_field="should_be_ignored",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_file_not_found(self, repo):
|
||||
"""Test update_file does nothing when file not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file(uuid4(), status="completed")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
def test_update_file_multiple_fields(self, repo, sample_file):
|
||||
"""Test update_file can update multiple fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file(
|
||||
sample_file.file_id,
|
||||
status="failed",
|
||||
)
|
||||
|
||||
assert sample_file.status == "failed"
|
||||
|
||||
# =========================================================================
|
||||
# get_files() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_files_returns_list(self, repo, sample_file):
|
||||
"""Test get_files returns list of files."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_file]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_files(sample_file.batch_id)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].filename == "invoice_001.pdf"
|
||||
|
||||
def test_get_files_returns_empty_list(self, repo):
|
||||
"""Test get_files returns empty list when no files."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_files(uuid4())
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# get_paginated() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_paginated_returns_batches_and_total(self, repo, sample_batch):
|
||||
"""Test get_paginated returns list of batches and total count."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
batches, total = repo.get_paginated()
|
||||
|
||||
assert len(batches) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_batch):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 100
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
batches, total = repo.get_paginated(limit=25, offset=50)
|
||||
|
||||
assert total == 100
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
batches, total = repo.get_paginated()
|
||||
|
||||
assert batches == []
|
||||
assert total == 0
|
||||
|
||||
def test_get_paginated_with_admin_token(self, repo, sample_batch):
|
||||
"""Test get_paginated with admin_token parameter (deprecated, ignored)."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
batches, total = repo.get_paginated(admin_token="admin-token")
|
||||
|
||||
assert len(batches) == 1
|
||||
597
tests/data/repositories/test_dataset_repository.py
Normal file
597
tests/data/repositories/test_dataset_repository.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""
|
||||
Tests for DatasetRepository
|
||||
|
||||
100% coverage tests for dataset management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
|
||||
from inference.data.repositories.dataset_repository import DatasetRepository
|
||||
|
||||
|
||||
class TestDatasetRepository:
|
||||
"""Tests for DatasetRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset(self) -> TrainingDataset:
|
||||
"""Create a sample dataset for testing."""
|
||||
return TrainingDataset(
|
||||
dataset_id=uuid4(),
|
||||
name="Test Dataset",
|
||||
description="A test dataset",
|
||||
status="ready",
|
||||
train_ratio=0.8,
|
||||
val_ratio=0.1,
|
||||
seed=42,
|
||||
total_documents=100,
|
||||
total_images=100,
|
||||
total_annotations=500,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_dataset_document(self) -> DatasetDocument:
|
||||
"""Create a sample dataset document for testing."""
|
||||
return DatasetDocument(
|
||||
id=uuid4(),
|
||||
dataset_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
split="train",
|
||||
page_count=2,
|
||||
annotation_count=10,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_training_task(self) -> TrainingTask:
|
||||
"""Create a sample training task for testing."""
|
||||
return TrainingTask(
|
||||
task_id=uuid4(),
|
||||
admin_token="admin-token",
|
||||
name="Test Task",
|
||||
status="running",
|
||||
dataset_id=uuid4(),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> DatasetRepository:
|
||||
"""Create a DatasetRepository instance."""
|
||||
return DatasetRepository()
|
||||
|
||||
# =========================================================================
|
||||
# create() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_returns_dataset(self, repo):
|
||||
"""Test create returns created dataset."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(name="Test Dataset")
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_with_all_params(self, repo):
|
||||
"""Test create with all parameters."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
name="Full Dataset",
|
||||
description="A complete dataset",
|
||||
train_ratio=0.7,
|
||||
val_ratio=0.15,
|
||||
seed=123,
|
||||
)
|
||||
|
||||
added_dataset = mock_session.add.call_args[0][0]
|
||||
assert added_dataset.name == "Full Dataset"
|
||||
assert added_dataset.description == "A complete dataset"
|
||||
assert added_dataset.train_ratio == 0.7
|
||||
assert added_dataset.val_ratio == 0.15
|
||||
assert added_dataset.seed == 123
|
||||
|
||||
def test_create_default_values(self, repo):
|
||||
"""Test create uses default values."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(name="Minimal Dataset")
|
||||
|
||||
added_dataset = mock_session.add.call_args[0][0]
|
||||
assert added_dataset.train_ratio == 0.8
|
||||
assert added_dataset.val_ratio == 0.1
|
||||
assert added_dataset.seed == 42
|
||||
|
||||
# =========================================================================
|
||||
# get() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_returns_dataset(self, repo, sample_dataset):
|
||||
"""Test get returns dataset when exists."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(sample_dataset.dataset_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.name == "Test Dataset"
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_with_uuid(self, repo, sample_dataset):
|
||||
"""Test get works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(sample_dataset.dataset_id)
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
mock_session.expunge.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# get_paginated() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_paginated_returns_datasets_and_total(self, repo, sample_dataset):
|
||||
"""Test get_paginated returns list of datasets and total count."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
datasets, total = repo.get_paginated()
|
||||
|
||||
assert len(datasets) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_dataset):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
datasets, total = repo.get_paginated(status="ready")
|
||||
|
||||
assert len(datasets) == 1
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_dataset):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
datasets, total = repo.get_paginated(limit=10, offset=20)
|
||||
|
||||
assert total == 50
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
datasets, total = repo.get_paginated()
|
||||
|
||||
assert datasets == []
|
||||
assert total == 0
|
||||
|
||||
# =========================================================================
|
||||
# get_active_training_tasks() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_active_training_tasks_returns_dict(self, repo, sample_training_task):
|
||||
"""Test get_active_training_tasks returns dict of active tasks."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_training_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_active_training_tasks([str(sample_training_task.dataset_id)])
|
||||
|
||||
assert str(sample_training_task.dataset_id) in result
|
||||
|
||||
def test_get_active_training_tasks_empty_input(self, repo):
|
||||
"""Test get_active_training_tasks with empty input."""
|
||||
result = repo.get_active_training_tasks([])
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_get_active_training_tasks_invalid_uuid(self, repo):
|
||||
"""Test get_active_training_tasks filters invalid UUIDs."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_active_training_tasks(["invalid-uuid", str(uuid4())])
|
||||
|
||||
# Should still query with valid UUID
|
||||
assert result == {}
|
||||
|
||||
def test_get_active_training_tasks_all_invalid_uuids(self, repo):
|
||||
"""Test get_active_training_tasks with all invalid UUIDs."""
|
||||
result = repo.get_active_training_tasks(["invalid-uuid-1", "invalid-uuid-2"])
|
||||
|
||||
assert result == {}
|
||||
|
||||
# =========================================================================
|
||||
# update_status() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_status_updates_dataset(self, repo, sample_dataset):
|
||||
"""Test update_status updates dataset status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(sample_dataset.dataset_id), "training")
|
||||
|
||||
assert sample_dataset.status == "training"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_update_status_with_error_message(self, repo, sample_dataset):
|
||||
"""Test update_status with error message."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"failed",
|
||||
error_message="Training failed",
|
||||
)
|
||||
|
||||
assert sample_dataset.error_message == "Training failed"
|
||||
|
||||
def test_update_status_with_totals(self, repo, sample_dataset):
|
||||
"""Test update_status with total counts."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"ready",
|
||||
total_documents=200,
|
||||
total_images=200,
|
||||
total_annotations=1000,
|
||||
)
|
||||
|
||||
assert sample_dataset.total_documents == 200
|
||||
assert sample_dataset.total_images == 200
|
||||
assert sample_dataset.total_annotations == 1000
|
||||
|
||||
def test_update_status_with_dataset_path(self, repo, sample_dataset):
|
||||
"""Test update_status with dataset path."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"ready",
|
||||
dataset_path="/path/to/dataset",
|
||||
)
|
||||
|
||||
assert sample_dataset.dataset_path == "/path/to/dataset"
|
||||
|
||||
def test_update_status_with_uuid(self, repo, sample_dataset):
|
||||
"""Test update_status works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(sample_dataset.dataset_id, "ready")
|
||||
|
||||
assert sample_dataset.status == "ready"
|
||||
|
||||
def test_update_status_not_found(self, repo):
|
||||
"""Test update_status does nothing when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(uuid4()), "ready")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# update_training_status() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_training_status_updates_dataset(self, repo, sample_dataset):
|
||||
"""Test update_training_status updates training status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(str(sample_dataset.dataset_id), "running")
|
||||
|
||||
assert sample_dataset.training_status == "running"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_update_training_status_with_task_id(self, repo, sample_dataset):
|
||||
"""Test update_training_status with active task ID."""
|
||||
task_id = uuid4()
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"running",
|
||||
active_training_task_id=str(task_id),
|
||||
)
|
||||
|
||||
assert sample_dataset.active_training_task_id == task_id
|
||||
|
||||
def test_update_training_status_updates_main_status(self, repo, sample_dataset):
|
||||
"""Test update_training_status updates main status when completed."""
|
||||
sample_dataset.status = "ready"
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
"completed",
|
||||
update_main_status=True,
|
||||
)
|
||||
|
||||
assert sample_dataset.training_status == "completed"
|
||||
assert sample_dataset.status == "trained"
|
||||
|
||||
def test_update_training_status_clears_task_id(self, repo, sample_dataset):
|
||||
"""Test update_training_status clears task ID when None."""
|
||||
sample_dataset.active_training_task_id = uuid4()
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(
|
||||
str(sample_dataset.dataset_id),
|
||||
None,
|
||||
active_training_task_id=None,
|
||||
)
|
||||
|
||||
assert sample_dataset.active_training_task_id is None
|
||||
|
||||
def test_update_training_status_not_found(self, repo):
|
||||
"""Test update_training_status does nothing when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_training_status(str(uuid4()), "running")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# add_documents() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_add_documents_creates_links(self, repo):
|
||||
"""Test add_documents creates dataset document links."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
documents = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"split": "train",
|
||||
"page_count": 2,
|
||||
"annotation_count": 10,
|
||||
},
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"split": "val",
|
||||
"page_count": 1,
|
||||
"annotation_count": 5,
|
||||
},
|
||||
]
|
||||
|
||||
repo.add_documents(str(uuid4()), documents)
|
||||
|
||||
assert mock_session.add.call_count == 2
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_add_documents_default_counts(self, repo):
|
||||
"""Test add_documents uses default counts."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
documents = [
|
||||
{
|
||||
"document_id": str(uuid4()),
|
||||
"split": "train",
|
||||
},
|
||||
]
|
||||
|
||||
repo.add_documents(str(uuid4()), documents)
|
||||
|
||||
added_doc = mock_session.add.call_args[0][0]
|
||||
assert added_doc.page_count == 0
|
||||
assert added_doc.annotation_count == 0
|
||||
|
||||
def test_add_documents_with_uuid(self, repo):
|
||||
"""Test add_documents works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
documents = [
|
||||
{
|
||||
"document_id": uuid4(),
|
||||
"split": "train",
|
||||
},
|
||||
]
|
||||
|
||||
repo.add_documents(uuid4(), documents)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_add_documents_empty_list(self, repo):
|
||||
"""Test add_documents with empty list."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.add_documents(str(uuid4()), [])
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# =========================================================================
|
||||
# get_documents() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_documents_returns_list(self, repo, sample_dataset_document):
|
||||
"""Test get_documents returns list of dataset documents."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_documents(str(sample_dataset_document.dataset_id))
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].split == "train"
|
||||
|
||||
def test_get_documents_with_uuid(self, repo, sample_dataset_document):
|
||||
"""Test get_documents works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_documents(sample_dataset_document.dataset_id)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_documents_returns_empty_list(self, repo):
|
||||
"""Test get_documents returns empty list when no documents."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_documents(str(uuid4()))
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# delete() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_delete_returns_true(self, repo, sample_dataset):
|
||||
"""Test delete returns True when dataset exists."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(sample_dataset.dataset_id))
|
||||
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_with_uuid(self, repo, sample_dataset):
|
||||
"""Test delete works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(sample_dataset.dataset_id)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(uuid4()))
|
||||
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
748
tests/data/repositories/test_document_repository.py
Normal file
748
tests/data/repositories/test_document_repository.py
Normal file
@@ -0,0 +1,748 @@
|
||||
"""
|
||||
Tests for DocumentRepository
|
||||
|
||||
Comprehensive TDD tests for document management - targeting 100% coverage.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from inference.data.admin_models import AdminDocument, AdminAnnotation
|
||||
from inference.data.repositories.document_repository import DocumentRepository
|
||||
|
||||
|
||||
class TestDocumentRepository:
|
||||
"""Tests for DocumentRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_document(self) -> AdminDocument:
|
||||
"""Create a sample document for testing."""
|
||||
return AdminDocument(
|
||||
document_id=uuid4(),
|
||||
filename="test.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/tmp/test.pdf",
|
||||
page_count=1,
|
||||
status="pending",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def labeled_document(self) -> AdminDocument:
|
||||
"""Create a labeled document for testing."""
|
||||
return AdminDocument(
|
||||
document_id=uuid4(),
|
||||
filename="labeled.pdf",
|
||||
file_size=2048,
|
||||
content_type="application/pdf",
|
||||
file_path="/tmp/labeled.pdf",
|
||||
page_count=2,
|
||||
status="labeled",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def locked_document(self) -> AdminDocument:
|
||||
"""Create a document with annotation lock."""
|
||||
doc = AdminDocument(
|
||||
document_id=uuid4(),
|
||||
filename="locked.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/tmp/locked.pdf",
|
||||
page_count=1,
|
||||
status="pending",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
doc.annotation_lock_until = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||
return doc
|
||||
|
||||
@pytest.fixture
|
||||
def expired_lock_document(self) -> AdminDocument:
|
||||
"""Create a document with expired annotation lock."""
|
||||
doc = AdminDocument(
|
||||
document_id=uuid4(),
|
||||
filename="expired_lock.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/tmp/expired_lock.pdf",
|
||||
page_count=1,
|
||||
status="pending",
|
||||
category="invoice",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
doc.annotation_lock_until = datetime.now(timezone.utc) - timedelta(minutes=5)
|
||||
return doc
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> DocumentRepository:
|
||||
"""Create a DocumentRepository instance."""
|
||||
return DocumentRepository()
|
||||
|
||||
# ==========================================================================
|
||||
# create() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_create_returns_document_id(self, repo):
|
||||
"""Test create returns document ID."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
filename="test.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/tmp/test.pdf",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
def test_create_with_all_parameters(self, repo):
|
||||
"""Test create with all optional parameters."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
filename="test.pdf",
|
||||
file_size=1024,
|
||||
content_type="application/pdf",
|
||||
file_path="/tmp/test.pdf",
|
||||
page_count=5,
|
||||
upload_source="api",
|
||||
csv_field_values={"InvoiceNumber": "INV-001"},
|
||||
group_key="batch-001",
|
||||
category="receipt",
|
||||
admin_token="token-123",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
added_doc = mock_session.add.call_args[0][0]
|
||||
assert added_doc.page_count == 5
|
||||
assert added_doc.upload_source == "api"
|
||||
assert added_doc.csv_field_values == {"InvoiceNumber": "INV-001"}
|
||||
assert added_doc.group_key == "batch-001"
|
||||
assert added_doc.category == "receipt"
|
||||
|
||||
# ==========================================================================
|
||||
# get() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_get_returns_document(self, repo, sample_document):
|
||||
"""Test get returns document when exists."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(sample_document.document_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.filename == "test.pdf"
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
|
||||
# ==========================================================================
|
||||
# get_by_token() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_get_by_token_delegates_to_get(self, repo, sample_document):
|
||||
"""Test get_by_token delegates to get method."""
|
||||
with patch.object(repo, "get", return_value=sample_document) as mock_get:
|
||||
result = repo.get_by_token(str(sample_document.document_id), "token-123")
|
||||
|
||||
assert result == sample_document
|
||||
mock_get.assert_called_once_with(str(sample_document.document_id))
|
||||
|
||||
# ==========================================================================
|
||||
# get_paginated() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_get_paginated_no_filters(self, repo, sample_document):
|
||||
"""Test get_paginated with no filters."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_paginated()
|
||||
|
||||
assert total == 1
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with status filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_paginated(status="pending")
|
||||
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_upload_source_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with upload_source filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_paginated(upload_source="ui")
|
||||
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_auto_label_status_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with auto_label_status filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_paginated(auto_label_status="completed")
|
||||
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_batch_id_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with batch_id filter."""
|
||||
batch_id = str(uuid4())
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_paginated(batch_id=batch_id)
|
||||
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_category_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with category filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_paginated(category="invoice")
|
||||
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_has_annotations_true(self, repo, sample_document):
|
||||
"""Test get_paginated with has_annotations=True."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_paginated(has_annotations=True)
|
||||
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_has_annotations_false(self, repo, sample_document):
|
||||
"""Test get_paginated with has_annotations=False."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_paginated(has_annotations=False)
|
||||
|
||||
assert total == 1
|
||||
|
||||
# ==========================================================================
|
||||
# update_status() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_update_status(self, repo, sample_document):
|
||||
"""Test update_status updates document status."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(sample_document.document_id), "labeled")
|
||||
|
||||
assert sample_document.status == "labeled"
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_status_with_auto_label_status(self, repo, sample_document):
|
||||
"""Test update_status with auto_label_status."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_document.document_id),
|
||||
"labeled",
|
||||
auto_label_status="completed",
|
||||
)
|
||||
|
||||
assert sample_document.auto_label_status == "completed"
|
||||
|
||||
def test_update_status_with_auto_label_error(self, repo, sample_document):
|
||||
"""Test update_status with auto_label_error."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_document.document_id),
|
||||
"failed",
|
||||
auto_label_error="OCR failed",
|
||||
)
|
||||
|
||||
assert sample_document.auto_label_error == "OCR failed"
|
||||
|
||||
def test_update_status_document_not_found(self, repo):
|
||||
"""Test update_status when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(uuid4()), "labeled")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# ==========================================================================
|
||||
# update_file_path() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_update_file_path(self, repo, sample_document):
|
||||
"""Test update_file_path updates document file path."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file_path(str(sample_document.document_id), "/new/path.pdf")
|
||||
|
||||
assert sample_document.file_path == "/new/path.pdf"
|
||||
mock_session.add.assert_called_once()
|
||||
|
||||
def test_update_file_path_document_not_found(self, repo):
|
||||
"""Test update_file_path when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_file_path(str(uuid4()), "/new/path.pdf")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# ==========================================================================
|
||||
# update_group_key() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_update_group_key_returns_true(self, repo, sample_document):
|
||||
"""Test update_group_key returns True when document exists."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update_group_key(str(sample_document.document_id), "new-group")
|
||||
|
||||
assert result is True
|
||||
assert sample_document.group_key == "new-group"
|
||||
|
||||
def test_update_group_key_returns_false(self, repo):
|
||||
"""Test update_group_key returns False when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update_group_key(str(uuid4()), "new-group")
|
||||
|
||||
assert result is False
|
||||
|
||||
# ==========================================================================
|
||||
# update_category() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_update_category(self, repo, sample_document):
|
||||
"""Test update_category updates document category."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update_category(str(sample_document.document_id), "receipt")
|
||||
|
||||
assert sample_document.category == "receipt"
|
||||
mock_session.add.assert_called()
|
||||
|
||||
def test_update_category_returns_none_when_not_found(self, repo):
|
||||
"""Test update_category returns None when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update_category(str(uuid4()), "receipt")
|
||||
|
||||
assert result is None
|
||||
|
||||
# ==========================================================================
|
||||
# delete() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_delete_returns_true_when_exists(self, repo, sample_document):
|
||||
"""Test delete returns True when document exists."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(sample_document.document_id))
|
||||
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once_with(sample_document)
|
||||
|
||||
def test_delete_with_annotations(self, repo, sample_document):
|
||||
"""Test delete removes annotations before deleting document."""
|
||||
annotation = MagicMock()
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_session.exec.return_value.all.return_value = [annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(sample_document.document_id))
|
||||
|
||||
assert result is True
|
||||
assert mock_session.delete.call_count == 2
|
||||
|
||||
def test_delete_returns_false_when_not_exists(self, repo):
|
||||
"""Test delete returns False when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(uuid4()))
|
||||
|
||||
assert result is False
|
||||
|
||||
# ==========================================================================
|
||||
# get_categories() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_get_categories(self, repo):
|
||||
"""Test get_categories returns unique categories."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = ["invoice", "receipt", None]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_categories()
|
||||
|
||||
assert result == ["invoice", "receipt"]
|
||||
|
||||
# ==========================================================================
|
||||
# get_labeled_for_export() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_get_labeled_for_export(self, repo, labeled_document):
|
||||
"""Test get_labeled_for_export returns labeled documents."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_labeled_for_export()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].status == "labeled"
|
||||
|
||||
def test_get_labeled_for_export_with_token(self, repo, labeled_document):
|
||||
"""Test get_labeled_for_export with admin_token filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_labeled_for_export(admin_token="token-123")
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
# ==========================================================================
|
||||
# count_by_status() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_count_by_status(self, repo):
|
||||
"""Test count_by_status returns status counts."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [
|
||||
("pending", 10),
|
||||
("labeled", 5),
|
||||
]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.count_by_status()
|
||||
|
||||
assert result == {"pending": 10, "labeled": 5}
|
||||
|
||||
# ==========================================================================
|
||||
# get_by_ids() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_get_by_ids(self, repo, sample_document):
|
||||
"""Test get_by_ids returns documents by IDs."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_by_ids([str(sample_document.document_id)])
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
# ==========================================================================
|
||||
# get_for_training() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_get_for_training_basic(self, repo, labeled_document):
|
||||
"""Test get_for_training with default parameters."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_for_training()
|
||||
|
||||
assert total == 1
|
||||
assert len(results) == 1
|
||||
|
||||
def test_get_for_training_with_min_annotation_count(self, repo, labeled_document):
|
||||
"""Test get_for_training with min_annotation_count."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_for_training(min_annotation_count=3)
|
||||
|
||||
assert total == 1
|
||||
|
||||
def test_get_for_training_exclude_used(self, repo, labeled_document):
|
||||
"""Test get_for_training with exclude_used_in_training."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_for_training(exclude_used_in_training=True)
|
||||
|
||||
assert total == 1
|
||||
|
||||
def test_get_for_training_no_annotations(self, repo, labeled_document):
|
||||
"""Test get_for_training with has_annotations=False."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
results, total = repo.get_for_training(has_annotations=False)
|
||||
|
||||
assert total == 1
|
||||
|
||||
# ==========================================================================
|
||||
# acquire_annotation_lock() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_acquire_annotation_lock_success(self, repo, sample_document):
|
||||
"""Test acquire_annotation_lock when no lock exists."""
|
||||
sample_document.annotation_lock_until = None
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.acquire_annotation_lock(str(sample_document.document_id))
|
||||
|
||||
assert result is not None
|
||||
assert sample_document.annotation_lock_until is not None
|
||||
|
||||
def test_acquire_annotation_lock_fails_when_locked(self, repo, locked_document):
|
||||
"""Test acquire_annotation_lock fails when document is already locked."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = locked_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.acquire_annotation_lock(str(locked_document.document_id))
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_acquire_annotation_lock_document_not_found(self, repo):
|
||||
"""Test acquire_annotation_lock when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.acquire_annotation_lock(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
|
||||
# ==========================================================================
|
||||
# release_annotation_lock() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_release_annotation_lock_success(self, repo, locked_document):
|
||||
"""Test release_annotation_lock releases the lock."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = locked_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.release_annotation_lock(str(locked_document.document_id))
|
||||
|
||||
assert result is not None
|
||||
assert locked_document.annotation_lock_until is None
|
||||
|
||||
def test_release_annotation_lock_document_not_found(self, repo):
|
||||
"""Test release_annotation_lock when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.release_annotation_lock(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
|
||||
# ==========================================================================
|
||||
# extend_annotation_lock() tests
|
||||
# ==========================================================================
|
||||
|
||||
def test_extend_annotation_lock_success(self, repo, locked_document):
|
||||
"""Test extend_annotation_lock extends the lock."""
|
||||
original_lock = locked_document.annotation_lock_until
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = locked_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.extend_annotation_lock(str(locked_document.document_id))
|
||||
|
||||
assert result is not None
|
||||
assert locked_document.annotation_lock_until > original_lock
|
||||
|
||||
def test_extend_annotation_lock_fails_when_no_lock(self, repo, sample_document):
|
||||
"""Test extend_annotation_lock fails when no lock exists."""
|
||||
sample_document.annotation_lock_until = None
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.extend_annotation_lock(str(sample_document.document_id))
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_extend_annotation_lock_fails_when_expired(self, repo, expired_lock_document):
|
||||
"""Test extend_annotation_lock fails when lock is expired."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = expired_lock_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.extend_annotation_lock(str(expired_lock_document.document_id))
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_extend_annotation_lock_document_not_found(self, repo):
|
||||
"""Test extend_annotation_lock when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.extend_annotation_lock(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
582
tests/data/repositories/test_model_version_repository.py
Normal file
582
tests/data/repositories/test_model_version_repository.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""
|
||||
Tests for ModelVersionRepository
|
||||
|
||||
100% coverage tests for model version management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import ModelVersion
|
||||
from inference.data.repositories.model_version_repository import ModelVersionRepository
|
||||
|
||||
|
||||
class TestModelVersionRepository:
|
||||
"""Tests for ModelVersionRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_model(self) -> ModelVersion:
|
||||
"""Create a sample model version for testing."""
|
||||
return ModelVersion(
|
||||
version_id=uuid4(),
|
||||
version="v1.0.0",
|
||||
name="Test Model",
|
||||
description="A test model",
|
||||
model_path="/path/to/model.pt",
|
||||
status="ready",
|
||||
is_active=False,
|
||||
metrics_mAP=0.95,
|
||||
metrics_precision=0.92,
|
||||
metrics_recall=0.88,
|
||||
document_count=100,
|
||||
training_config={"epochs": 100},
|
||||
file_size=1024000,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def active_model(self) -> ModelVersion:
|
||||
"""Create an active model version for testing."""
|
||||
return ModelVersion(
|
||||
version_id=uuid4(),
|
||||
version="v1.0.0",
|
||||
name="Active Model",
|
||||
model_path="/path/to/active_model.pt",
|
||||
status="active",
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> ModelVersionRepository:
|
||||
"""Create a ModelVersionRepository instance."""
|
||||
return ModelVersionRepository()
|
||||
|
||||
# =========================================================================
|
||||
# create() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_returns_model(self, repo):
|
||||
"""Test create returns created model version."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
version="v1.0.0",
|
||||
name="Test Model",
|
||||
model_path="/path/to/model.pt",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_with_all_params(self, repo):
|
||||
"""Test create with all parameters."""
|
||||
task_id = uuid4()
|
||||
dataset_id = uuid4()
|
||||
trained_at = datetime.now(timezone.utc)
|
||||
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
version="v2.0.0",
|
||||
name="Full Model",
|
||||
model_path="/path/to/full_model.pt",
|
||||
description="A complete model",
|
||||
task_id=str(task_id),
|
||||
dataset_id=str(dataset_id),
|
||||
metrics_mAP=0.95,
|
||||
metrics_precision=0.92,
|
||||
metrics_recall=0.88,
|
||||
document_count=500,
|
||||
training_config={"epochs": 200},
|
||||
file_size=2048000,
|
||||
trained_at=trained_at,
|
||||
)
|
||||
|
||||
added_model = mock_session.add.call_args[0][0]
|
||||
assert added_model.version == "v2.0.0"
|
||||
assert added_model.description == "A complete model"
|
||||
assert added_model.task_id == task_id
|
||||
assert added_model.dataset_id == dataset_id
|
||||
assert added_model.metrics_mAP == 0.95
|
||||
|
||||
def test_create_with_uuid_objects(self, repo):
|
||||
"""Test create works with UUID objects."""
|
||||
task_id = uuid4()
|
||||
dataset_id = uuid4()
|
||||
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
version="v1.0.0",
|
||||
name="Test Model",
|
||||
model_path="/path/to/model.pt",
|
||||
task_id=task_id,
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
|
||||
added_model = mock_session.add.call_args[0][0]
|
||||
assert added_model.task_id == task_id
|
||||
assert added_model.dataset_id == dataset_id
|
||||
|
||||
def test_create_without_optional_ids(self, repo):
|
||||
"""Test create without task_id and dataset_id."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
version="v1.0.0",
|
||||
name="Test Model",
|
||||
model_path="/path/to/model.pt",
|
||||
)
|
||||
|
||||
added_model = mock_session.add.call_args[0][0]
|
||||
assert added_model.task_id is None
|
||||
assert added_model.dataset_id is None
|
||||
|
||||
# =========================================================================
|
||||
# get() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_returns_model(self, repo, sample_model):
|
||||
"""Test get returns model when exists."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(sample_model.version_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.name == "Test Model"
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_with_uuid(self, repo, sample_model):
|
||||
"""Test get works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(sample_model.version_id)
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
mock_session.expunge.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# get_paginated() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_paginated_returns_models_and_total(self, repo, sample_model):
|
||||
"""Test get_paginated returns list of models and total count."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_model]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
models, total = repo.get_paginated()
|
||||
|
||||
assert len(models) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_model):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_model]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
models, total = repo.get_paginated(status="ready")
|
||||
|
||||
assert len(models) == 1
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_model):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_model]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
models, total = repo.get_paginated(limit=10, offset=20)
|
||||
|
||||
assert total == 50
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
models, total = repo.get_paginated()
|
||||
|
||||
assert models == []
|
||||
assert total == 0
|
||||
|
||||
# =========================================================================
|
||||
# get_active() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_active_returns_active_model(self, repo, active_model):
|
||||
"""Test get_active returns the active model."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.first.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_active()
|
||||
|
||||
assert result is not None
|
||||
assert result.is_active is True
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_active_returns_none(self, repo):
|
||||
"""Test get_active returns None when no active model."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.first.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_active()
|
||||
|
||||
assert result is None
|
||||
mock_session.expunge.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# activate() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_activate_activates_model(self, repo, sample_model, active_model):
|
||||
"""Test activate sets model as active and deactivates others."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [active_model]
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.activate(str(sample_model.version_id))
|
||||
|
||||
assert result is not None
|
||||
assert sample_model.is_active is True
|
||||
assert sample_model.status == "active"
|
||||
assert active_model.is_active is False
|
||||
assert active_model.status == "inactive"
|
||||
|
||||
def test_activate_with_uuid(self, repo, sample_model):
|
||||
"""Test activate works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.activate(sample_model.version_id)
|
||||
|
||||
assert result is not None
|
||||
assert sample_model.is_active is True
|
||||
|
||||
def test_activate_returns_none_when_not_found(self, repo):
|
||||
"""Test activate returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.activate(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_activate_sets_activated_at(self, repo, sample_model):
|
||||
"""Test activate sets activated_at timestamp."""
|
||||
sample_model.activated_at = None
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.activate(str(sample_model.version_id))
|
||||
|
||||
assert sample_model.activated_at is not None
|
||||
|
||||
# =========================================================================
|
||||
# deactivate() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_deactivate_deactivates_model(self, repo, active_model):
|
||||
"""Test deactivate sets model as inactive."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.deactivate(str(active_model.version_id))
|
||||
|
||||
assert result is not None
|
||||
assert active_model.is_active is False
|
||||
assert active_model.status == "inactive"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_deactivate_with_uuid(self, repo, active_model):
|
||||
"""Test deactivate works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.deactivate(active_model.version_id)
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_deactivate_returns_none_when_not_found(self, repo):
|
||||
"""Test deactivate returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.deactivate(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
|
||||
# =========================================================================
|
||||
# update() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_updates_model(self, repo, sample_model):
|
||||
"""Test update updates model metadata."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(
|
||||
str(sample_model.version_id),
|
||||
name="Updated Model",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert sample_model.name == "Updated Model"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_update_all_fields(self, repo, sample_model):
|
||||
"""Test update can update all fields."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(
|
||||
str(sample_model.version_id),
|
||||
name="New Name",
|
||||
description="New Description",
|
||||
status="archived",
|
||||
)
|
||||
|
||||
assert sample_model.name == "New Name"
|
||||
assert sample_model.description == "New Description"
|
||||
assert sample_model.status == "archived"
|
||||
|
||||
def test_update_with_uuid(self, repo, sample_model):
|
||||
"""Test update works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(sample_model.version_id, name="Updated")
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_update_returns_none_when_not_found(self, repo):
|
||||
"""Test update returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(str(uuid4()), name="New Name")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_update_partial_fields(self, repo, sample_model):
|
||||
"""Test update only updates provided fields."""
|
||||
original_name = sample_model.name
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.update(
|
||||
str(sample_model.version_id),
|
||||
description="Only description changed",
|
||||
)
|
||||
|
||||
assert sample_model.name == original_name
|
||||
assert sample_model.description == "Only description changed"
|
||||
|
||||
# =========================================================================
|
||||
# archive() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_archive_archives_model(self, repo, sample_model):
|
||||
"""Test archive sets model status to archived."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.archive(str(sample_model.version_id))
|
||||
|
||||
assert result is not None
|
||||
assert sample_model.status == "archived"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_archive_with_uuid(self, repo, sample_model):
|
||||
"""Test archive works with UUID object."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.archive(sample_model.version_id)
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_archive_returns_none_when_not_found(self, repo):
|
||||
"""Test archive returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.archive(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_archive_returns_none_when_active(self, repo, active_model):
|
||||
"""Test archive returns None when model is active."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.archive(str(active_model.version_id))
|
||||
|
||||
assert result is None
|
||||
|
||||
# =========================================================================
|
||||
# delete() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_delete_returns_true(self, repo, sample_model):
|
||||
"""Test delete returns True when model exists and not active."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(sample_model.version_id))
|
||||
|
||||
assert result is True
|
||||
mock_session.delete.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_with_uuid(self, repo, sample_model):
|
||||
"""Test delete works with UUID object."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(sample_model.version_id)
|
||||
|
||||
assert result is True
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(uuid4()))
|
||||
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
|
||||
def test_delete_returns_false_when_active(self, repo, active_model):
|
||||
"""Test delete returns False when model is active."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.delete(str(active_model.version_id))
|
||||
|
||||
assert result is False
|
||||
mock_session.delete.assert_not_called()
|
||||
199
tests/data/repositories/test_token_repository.py
Normal file
199
tests/data/repositories/test_token_repository.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Tests for TokenRepository
|
||||
|
||||
TDD tests for admin token management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from inference.data.admin_models import AdminToken
|
||||
from inference.data.repositories.token_repository import TokenRepository
|
||||
|
||||
|
||||
class TestTokenRepository:
|
||||
"""Tests for TokenRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_token(self) -> AdminToken:
|
||||
"""Create a sample token for testing."""
|
||||
return AdminToken(
|
||||
token="test-token-123",
|
||||
name="Test Token",
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
last_used_at=None,
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def expired_token(self) -> AdminToken:
|
||||
"""Create an expired token."""
|
||||
return AdminToken(
|
||||
token="expired-token",
|
||||
name="Expired Token",
|
||||
is_active=True,
|
||||
created_at=datetime.now(timezone.utc) - timedelta(days=30),
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(days=1),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def inactive_token(self) -> AdminToken:
|
||||
"""Create an inactive token."""
|
||||
return AdminToken(
|
||||
token="inactive-token",
|
||||
name="Inactive Token",
|
||||
is_active=False,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> TokenRepository:
|
||||
"""Create a TokenRepository instance."""
|
||||
return TokenRepository()
|
||||
|
||||
def test_is_valid_returns_true_for_active_token(self, repo, sample_token):
|
||||
"""Test is_valid returns True for an active, non-expired token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.is_valid("test-token-123")
|
||||
|
||||
assert result is True
|
||||
mock_session.get.assert_called_once_with(AdminToken, "test-token-123")
|
||||
|
||||
def test_is_valid_returns_false_for_nonexistent_token(self, repo):
|
||||
"""Test is_valid returns False for a non-existent token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.is_valid("nonexistent-token")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_is_valid_returns_false_for_inactive_token(self, repo, inactive_token):
|
||||
"""Test is_valid returns False for an inactive token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = inactive_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.is_valid("inactive-token")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_is_valid_returns_false_for_expired_token(self, repo, expired_token):
|
||||
"""Test is_valid returns False for an expired token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = expired_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.is_valid("expired-token")
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_get_returns_token_when_exists(self, repo, sample_token):
|
||||
"""Test get returns token when it exists."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get("test-token-123")
|
||||
|
||||
assert result is not None
|
||||
assert result.token == "test-token-123"
|
||||
assert result.name == "Test Token"
|
||||
mock_session.expunge.assert_called_once_with(sample_token)
|
||||
|
||||
def test_get_returns_none_when_not_exists(self, repo):
|
||||
"""Test get returns None when token doesn't exist."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get("nonexistent-token")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_create_new_token(self, repo):
|
||||
"""Test creating a new token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None # Token doesn't exist
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create("new-token", "New Token", expires_at=None)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
added_token = mock_session.add.call_args[0][0]
|
||||
assert isinstance(added_token, AdminToken)
|
||||
assert added_token.token == "new-token"
|
||||
assert added_token.name == "New Token"
|
||||
|
||||
def test_create_updates_existing_token(self, repo, sample_token):
|
||||
"""Test create updates an existing token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create("test-token-123", "Updated Name", expires_at=None)
|
||||
|
||||
mock_session.add.assert_called_once_with(sample_token)
|
||||
assert sample_token.name == "Updated Name"
|
||||
assert sample_token.is_active is True
|
||||
|
||||
def test_update_usage(self, repo, sample_token):
|
||||
"""Test updating token last_used_at timestamp."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_usage("test-token-123")
|
||||
|
||||
assert sample_token.last_used_at is not None
|
||||
mock_session.add.assert_called_once_with(sample_token)
|
||||
|
||||
def test_deactivate_returns_true_when_token_exists(self, repo, sample_token):
|
||||
"""Test deactivate returns True when token exists."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.deactivate("test-token-123")
|
||||
|
||||
assert result is True
|
||||
assert sample_token.is_active is False
|
||||
mock_session.add.assert_called_once_with(sample_token)
|
||||
|
||||
def test_deactivate_returns_false_when_token_not_exists(self, repo):
|
||||
"""Test deactivate returns False when token doesn't exist."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.deactivate("nonexistent-token")
|
||||
|
||||
assert result is False
|
||||
615
tests/data/repositories/test_training_task_repository.py
Normal file
615
tests/data/repositories/test_training_task_repository.py
Normal file
@@ -0,0 +1,615 @@
|
||||
"""
|
||||
Tests for TrainingTaskRepository
|
||||
|
||||
100% coverage tests for training task management.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
|
||||
from inference.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
|
||||
|
||||
class TestTrainingTaskRepository:
|
||||
"""Tests for TrainingTaskRepository."""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_task(self) -> TrainingTask:
|
||||
"""Create a sample training task for testing."""
|
||||
return TrainingTask(
|
||||
task_id=uuid4(),
|
||||
admin_token="admin-token",
|
||||
name="Test Training Task",
|
||||
task_type="train",
|
||||
description="A test training task",
|
||||
status="pending",
|
||||
config={"epochs": 100, "batch_size": 16},
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_log(self) -> TrainingLog:
|
||||
"""Create a sample training log for testing."""
|
||||
return TrainingLog(
|
||||
log_id=uuid4(),
|
||||
task_id=uuid4(),
|
||||
level="INFO",
|
||||
message="Training started",
|
||||
details={"epoch": 1},
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_link(self) -> TrainingDocumentLink:
|
||||
"""Create a sample training document link for testing."""
|
||||
return TrainingDocumentLink(
|
||||
link_id=uuid4(),
|
||||
task_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
annotation_snapshot={"annotations": []},
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def repo(self) -> TrainingTaskRepository:
|
||||
"""Create a TrainingTaskRepository instance."""
|
||||
return TrainingTaskRepository()
|
||||
|
||||
# =========================================================================
|
||||
# create() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_returns_task_id(self, repo):
|
||||
"""Test create returns task ID."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
admin_token="admin-token",
|
||||
name="Test Task",
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.flush.assert_called_once()
|
||||
|
||||
def test_create_with_all_params(self, repo):
|
||||
"""Test create with all parameters."""
|
||||
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.create(
|
||||
admin_token="admin-token",
|
||||
name="Test Task",
|
||||
task_type="finetune",
|
||||
description="Full test",
|
||||
config={"epochs": 50},
|
||||
scheduled_at=scheduled_time,
|
||||
cron_expression="0 0 * * *",
|
||||
is_recurring=True,
|
||||
dataset_id=str(uuid4()),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
added_task = mock_session.add.call_args[0][0]
|
||||
assert added_task.task_type == "finetune"
|
||||
assert added_task.description == "Full test"
|
||||
assert added_task.is_recurring is True
|
||||
assert added_task.status == "scheduled" # because scheduled_at is set
|
||||
|
||||
def test_create_pending_status_when_not_scheduled(self, repo):
|
||||
"""Test create sets pending status when no scheduled_at."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
admin_token="admin-token",
|
||||
name="Test Task",
|
||||
)
|
||||
|
||||
added_task = mock_session.add.call_args[0][0]
|
||||
assert added_task.status == "pending"
|
||||
|
||||
def test_create_scheduled_status_when_scheduled(self, repo):
|
||||
"""Test create sets scheduled status when scheduled_at is provided."""
|
||||
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.create(
|
||||
admin_token="admin-token",
|
||||
name="Test Task",
|
||||
scheduled_at=scheduled_time,
|
||||
)
|
||||
|
||||
added_task = mock_session.add.call_args[0][0]
|
||||
assert added_task.status == "scheduled"
|
||||
|
||||
# =========================================================================
|
||||
# get() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_returns_task(self, repo, sample_task):
|
||||
"""Test get returns task when exists."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(sample_task.task_id))
|
||||
|
||||
assert result is not None
|
||||
assert result.name == "Test Training Task"
|
||||
mock_session.expunge.assert_called_once()
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when task not found."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get(str(uuid4()))
|
||||
|
||||
assert result is None
|
||||
mock_session.expunge.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# get_by_token() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_by_token_returns_task(self, repo, sample_task):
|
||||
"""Test get_by_token returns task (delegates to get)."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_by_token(str(sample_task.task_id), "admin-token")
|
||||
|
||||
assert result is not None
|
||||
|
||||
def test_get_by_token_without_token_param(self, repo, sample_task):
|
||||
"""Test get_by_token works without token parameter."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_by_token(str(sample_task.task_id))
|
||||
|
||||
assert result is not None
|
||||
|
||||
# =========================================================================
|
||||
# get_paginated() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_paginated_returns_tasks_and_total(self, repo, sample_task):
|
||||
"""Test get_paginated returns list of tasks and total count."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
tasks, total = repo.get_paginated()
|
||||
|
||||
assert len(tasks) == 1
|
||||
assert total == 1
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_task):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
tasks, total = repo.get_paginated(status="pending")
|
||||
|
||||
assert len(tasks) == 1
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_task):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
tasks, total = repo.get_paginated(limit=10, offset=20)
|
||||
|
||||
assert total == 50
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
tasks, total = repo.get_paginated()
|
||||
|
||||
assert tasks == []
|
||||
assert total == 0
|
||||
|
||||
# =========================================================================
|
||||
# get_pending() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_pending_returns_pending_tasks(self, repo, sample_task):
|
||||
"""Test get_pending returns pending and scheduled tasks."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_pending()
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_pending_returns_empty_list(self, repo):
|
||||
"""Test get_pending returns empty list when no pending tasks."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_pending()
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# update_status() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_update_status_updates_task(self, repo, sample_task):
|
||||
"""Test update_status updates task status."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(sample_task.task_id), "running")
|
||||
|
||||
assert sample_task.status == "running"
|
||||
|
||||
def test_update_status_sets_started_at_for_running(self, repo, sample_task):
|
||||
"""Test update_status sets started_at when status is running."""
|
||||
sample_task.started_at = None
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(sample_task.task_id), "running")
|
||||
|
||||
assert sample_task.started_at is not None
|
||||
|
||||
def test_update_status_sets_completed_at_for_completed(self, repo, sample_task):
|
||||
"""Test update_status sets completed_at when status is completed."""
|
||||
sample_task.completed_at = None
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(sample_task.task_id), "completed")
|
||||
|
||||
assert sample_task.completed_at is not None
|
||||
|
||||
def test_update_status_sets_completed_at_for_failed(self, repo, sample_task):
|
||||
"""Test update_status sets completed_at when status is failed."""
|
||||
sample_task.completed_at = None
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(sample_task.task_id), "failed", error_message="Error occurred")
|
||||
|
||||
assert sample_task.completed_at is not None
|
||||
assert sample_task.error_message == "Error occurred"
|
||||
|
||||
def test_update_status_with_result_metrics(self, repo, sample_task):
|
||||
"""Test update_status with result metrics."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_task.task_id),
|
||||
"completed",
|
||||
result_metrics={"mAP": 0.95},
|
||||
)
|
||||
|
||||
assert sample_task.result_metrics == {"mAP": 0.95}
|
||||
|
||||
def test_update_status_with_model_path(self, repo, sample_task):
|
||||
"""Test update_status with model path."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(
|
||||
str(sample_task.task_id),
|
||||
"completed",
|
||||
model_path="/path/to/model.pt",
|
||||
)
|
||||
|
||||
assert sample_task.model_path == "/path/to/model.pt"
|
||||
|
||||
def test_update_status_not_found(self, repo):
|
||||
"""Test update_status does nothing when task not found."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.update_status(str(uuid4()), "running")
|
||||
|
||||
mock_session.add.assert_not_called()
|
||||
|
||||
# =========================================================================
|
||||
# cancel() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_cancel_returns_true_for_pending(self, repo, sample_task):
|
||||
"""Test cancel returns True for pending task."""
|
||||
sample_task.status = "pending"
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.cancel(str(sample_task.task_id))
|
||||
|
||||
assert result is True
|
||||
assert sample_task.status == "cancelled"
|
||||
|
||||
def test_cancel_returns_true_for_scheduled(self, repo, sample_task):
|
||||
"""Test cancel returns True for scheduled task."""
|
||||
sample_task.status = "scheduled"
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.cancel(str(sample_task.task_id))
|
||||
|
||||
assert result is True
|
||||
assert sample_task.status == "cancelled"
|
||||
|
||||
def test_cancel_returns_false_for_running(self, repo, sample_task):
|
||||
"""Test cancel returns False for running task."""
|
||||
sample_task.status = "running"
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.cancel(str(sample_task.task_id))
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_cancel_returns_false_when_not_found(self, repo):
|
||||
"""Test cancel returns False when task not found."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.cancel(str(uuid4()))
|
||||
|
||||
assert result is False
|
||||
|
||||
# =========================================================================
|
||||
# add_log() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_add_log_creates_log_entry(self, repo):
|
||||
"""Test add_log creates a log entry."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.add_log(
|
||||
task_id=str(uuid4()),
|
||||
level="INFO",
|
||||
message="Training started",
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
added_log = mock_session.add.call_args[0][0]
|
||||
assert added_log.level == "INFO"
|
||||
assert added_log.message == "Training started"
|
||||
|
||||
def test_add_log_with_details(self, repo):
|
||||
"""Test add_log with details."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
repo.add_log(
|
||||
task_id=str(uuid4()),
|
||||
level="DEBUG",
|
||||
message="Epoch complete",
|
||||
details={"epoch": 5, "loss": 0.05},
|
||||
)
|
||||
|
||||
added_log = mock_session.add.call_args[0][0]
|
||||
assert added_log.details == {"epoch": 5, "loss": 0.05}
|
||||
|
||||
# =========================================================================
|
||||
# get_logs() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_logs_returns_list(self, repo, sample_log):
|
||||
"""Test get_logs returns list of logs."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_log]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_logs(str(sample_log.task_id))
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].level == "INFO"
|
||||
|
||||
def test_get_logs_with_pagination(self, repo, sample_log):
|
||||
"""Test get_logs with limit and offset."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_log]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_logs(str(sample_log.task_id), limit=50, offset=10)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_logs_returns_empty_list(self, repo):
|
||||
"""Test get_logs returns empty list when no logs."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_logs(str(uuid4()))
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# create_document_link() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_create_document_link_returns_link(self, repo):
|
||||
"""Test create_document_link returns created link."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
task_id = uuid4()
|
||||
document_id = uuid4()
|
||||
result = repo.create_document_link(
|
||||
task_id=task_id,
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_create_document_link_with_snapshot(self, repo):
|
||||
"""Test create_document_link with annotation snapshot."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
snapshot = {"annotations": [{"class_name": "invoice_number"}]}
|
||||
repo.create_document_link(
|
||||
task_id=uuid4(),
|
||||
document_id=uuid4(),
|
||||
annotation_snapshot=snapshot,
|
||||
)
|
||||
|
||||
added_link = mock_session.add.call_args[0][0]
|
||||
assert added_link.annotation_snapshot == snapshot
|
||||
|
||||
# =========================================================================
|
||||
# get_document_links() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_document_links_returns_list(self, repo, sample_link):
|
||||
"""Test get_document_links returns list of links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_link]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_document_links(sample_link.task_id)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_document_links_returns_empty_list(self, repo):
|
||||
"""Test get_document_links returns empty list when no links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_document_links(uuid4())
|
||||
|
||||
assert result == []
|
||||
|
||||
# =========================================================================
|
||||
# get_document_training_tasks() tests
|
||||
# =========================================================================
|
||||
|
||||
def test_get_document_training_tasks_returns_list(self, repo, sample_link):
|
||||
"""Test get_document_training_tasks returns list of links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_link]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_document_training_tasks(sample_link.document_id)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_get_document_training_tasks_returns_empty_list(self, repo):
|
||||
"""Test get_document_training_tasks returns empty list when no links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
result = repo.get_document_training_tasks(uuid4())
|
||||
|
||||
assert result == []
|
||||
Reference in New Issue
Block a user