re-structure
This commit is contained in:
@@ -9,8 +9,8 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
|
||||
from inference.data.repositories.annotation_repository import AnnotationRepository
|
||||
from backend.data.admin_models import AdminAnnotation, AnnotationHistory
|
||||
from backend.data.repositories.annotation_repository import AnnotationRepository
|
||||
|
||||
|
||||
class TestAnnotationRepository:
|
||||
@@ -66,7 +66,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_returns_annotation_id(self, repo):
|
||||
"""Test create returns annotation ID."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -92,7 +92,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_with_optional_params(self, repo):
|
||||
"""Test create with optional text_value and confidence."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -124,7 +124,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_default_source_is_manual(self, repo):
|
||||
"""Test create uses manual as default source."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -153,7 +153,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_batch_returns_ids(self, repo):
|
||||
"""Test create_batch returns list of annotation IDs."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -195,7 +195,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_batch_default_page_number(self, repo):
|
||||
"""Test create_batch uses page_number=1 by default."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -224,7 +224,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_batch_with_all_optional_params(self, repo):
|
||||
"""Test create_batch with all optional parameters."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -259,7 +259,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_batch_empty_list(self, repo):
|
||||
"""Test create_batch with empty list returns empty."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -275,7 +275,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_returns_annotation(self, repo, sample_annotation):
|
||||
"""Test get returns annotation when exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -289,7 +289,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -306,7 +306,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_for_document_returns_all_annotations(self, repo, sample_annotation):
|
||||
"""Test get_for_document returns all annotations for document."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -319,7 +319,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_for_document_with_page_filter(self, repo, sample_annotation):
|
||||
"""Test get_for_document filters by page number."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -331,7 +331,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_for_document_returns_empty_list(self, repo):
|
||||
"""Test get_for_document returns empty list when no annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -347,7 +347,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_update_returns_true(self, repo, sample_annotation):
|
||||
"""Test update returns True when annotation exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -363,7 +363,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_update_returns_false_when_not_found(self, repo):
|
||||
"""Test update returns False when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -375,7 +375,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_update_all_fields(self, repo, sample_annotation):
|
||||
"""Test update can update all fields."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -412,7 +412,7 @@ class TestAnnotationRepository:
|
||||
def test_update_partial_fields(self, repo, sample_annotation):
|
||||
"""Test update only updates provided fields."""
|
||||
original_x = sample_annotation.x_center
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -433,7 +433,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_returns_true(self, repo, sample_annotation):
|
||||
"""Test delete returns True when annotation exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -446,7 +446,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -463,7 +463,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_for_document_returns_count(self, repo, sample_annotation):
|
||||
"""Test delete_for_document returns count of deleted annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -476,7 +476,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_for_document_with_source_filter(self, repo, sample_annotation):
|
||||
"""Test delete_for_document filters by source."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -488,7 +488,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_for_document_returns_zero(self, repo):
|
||||
"""Test delete_for_document returns 0 when no annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -506,7 +506,7 @@ class TestAnnotationRepository:
|
||||
def test_verify_marks_annotation_verified(self, repo, sample_annotation):
|
||||
"""Test verify marks annotation as verified."""
|
||||
sample_annotation.is_verified = False
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -521,7 +521,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_verify_returns_none_when_not_found(self, repo):
|
||||
"""Test verify returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -538,7 +538,7 @@ class TestAnnotationRepository:
|
||||
def test_override_updates_annotation(self, repo, sample_annotation):
|
||||
"""Test override updates annotation and creates history."""
|
||||
sample_annotation.source = "auto"
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -559,7 +559,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_override_returns_none_when_not_found(self, repo):
|
||||
"""Test override returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -573,7 +573,7 @@ class TestAnnotationRepository:
|
||||
"""Test override does not change override_source if already manual."""
|
||||
sample_annotation.source = "manual"
|
||||
sample_annotation.override_source = None
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -591,7 +591,7 @@ class TestAnnotationRepository:
|
||||
def test_override_skips_unknown_attributes(self, repo, sample_annotation):
|
||||
"""Test override ignores unknown attributes."""
|
||||
sample_annotation.source = "auto"
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -614,7 +614,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_history_returns_history(self, repo):
|
||||
"""Test create_history returns created history record."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -636,7 +636,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_history_with_minimal_params(self, repo):
|
||||
"""Test create_history with minimal parameters."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -659,7 +659,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_history_returns_list(self, repo, sample_history):
|
||||
"""Test get_history returns list of history records."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_history]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -672,7 +672,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_history_returns_empty_list(self, repo):
|
||||
"""Test get_history returns empty list when no history."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -688,7 +688,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_document_history_returns_list(self, repo, sample_history):
|
||||
"""Test get_document_history returns list of history records."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_history]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -700,7 +700,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_document_history_returns_empty_list(self, repo):
|
||||
"""Test get_document_history returns empty list when no history."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.repositories.base import BaseRepository
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class ConcreteRepository(BaseRepository[MagicMock]):
|
||||
@@ -31,7 +31,7 @@ class TestBaseRepository:
|
||||
|
||||
def test_session_yields_session(self, repo):
|
||||
"""Test _session yields a database session."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import BatchUpload, BatchUploadFile
|
||||
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
from backend.data.admin_models import BatchUpload, BatchUploadFile
|
||||
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
|
||||
|
||||
class TestBatchUploadRepository:
|
||||
@@ -54,7 +54,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_returns_batch(self, repo):
|
||||
"""Test create returns created batch upload."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -70,7 +70,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_with_upload_source(self, repo):
|
||||
"""Test create with custom upload source."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -87,7 +87,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_default_upload_source(self, repo):
|
||||
"""Test create uses default upload source."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -107,7 +107,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_returns_batch(self, repo, sample_batch):
|
||||
"""Test get returns batch when exists."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -121,7 +121,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when batch not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -138,7 +138,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_updates_batch(self, repo, sample_batch):
|
||||
"""Test update updates batch fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -156,7 +156,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_ignores_unknown_fields(self, repo, sample_batch):
|
||||
"""Test update ignores unknown fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -171,7 +171,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_not_found(self, repo):
|
||||
"""Test update does nothing when batch not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -183,7 +183,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_multiple_fields(self, repo, sample_batch):
|
||||
"""Test update can update multiple fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -206,7 +206,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_file_returns_file(self, repo):
|
||||
"""Test create_file returns created file record."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -221,7 +221,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_file_with_kwargs(self, repo):
|
||||
"""Test create_file with additional kwargs."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -242,7 +242,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_file_updates_file(self, repo, sample_file):
|
||||
"""Test update_file updates file fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -258,7 +258,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_file_ignores_unknown_fields(self, repo, sample_file):
|
||||
"""Test update_file ignores unknown fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -273,7 +273,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_file_not_found(self, repo):
|
||||
"""Test update_file does nothing when file not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -285,7 +285,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_file_multiple_fields(self, repo, sample_file):
|
||||
"""Test update_file can update multiple fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -304,7 +304,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_files_returns_list(self, repo, sample_file):
|
||||
"""Test get_files returns list of files."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_file]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -317,7 +317,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_files_returns_empty_list(self, repo):
|
||||
"""Test get_files returns empty list when no files."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -333,7 +333,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_paginated_returns_batches_and_total(self, repo, sample_batch):
|
||||
"""Test get_paginated returns list of batches and total count."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
@@ -347,7 +347,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_batch):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 100
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
@@ -360,7 +360,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -374,7 +374,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_paginated_with_admin_token(self, repo, sample_batch):
|
||||
"""Test get_paginated with admin_token parameter (deprecated, ignored)."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
|
||||
from inference.data.repositories.dataset_repository import DatasetRepository
|
||||
from backend.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
|
||||
from backend.data.repositories.dataset_repository import DatasetRepository
|
||||
|
||||
|
||||
class TestDatasetRepository:
|
||||
@@ -69,7 +69,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_create_returns_dataset(self, repo):
|
||||
"""Test create returns created dataset."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -81,7 +81,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_create_with_all_params(self, repo):
|
||||
"""Test create with all parameters."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -103,7 +103,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_create_default_values(self, repo):
|
||||
"""Test create uses default values."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -121,7 +121,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_returns_dataset(self, repo, sample_dataset):
|
||||
"""Test get returns dataset when exists."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -135,7 +135,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_with_uuid(self, repo, sample_dataset):
|
||||
"""Test get works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -147,7 +147,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -164,7 +164,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_paginated_returns_datasets_and_total(self, repo, sample_dataset):
|
||||
"""Test get_paginated returns list of datasets and total count."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
@@ -178,7 +178,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_dataset):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
@@ -191,7 +191,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_dataset):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
@@ -204,7 +204,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -222,7 +222,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_active_training_tasks_returns_dict(self, repo, sample_training_task):
|
||||
"""Test get_active_training_tasks returns dict of active tasks."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_training_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -240,7 +240,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_active_training_tasks_invalid_uuid(self, repo):
|
||||
"""Test get_active_training_tasks filters invalid UUIDs."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -263,7 +263,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_updates_dataset(self, repo, sample_dataset):
|
||||
"""Test update_status updates dataset status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -276,7 +276,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_with_error_message(self, repo, sample_dataset):
|
||||
"""Test update_status with error message."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -292,7 +292,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_with_totals(self, repo, sample_dataset):
|
||||
"""Test update_status with total counts."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -312,7 +312,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_with_dataset_path(self, repo, sample_dataset):
|
||||
"""Test update_status with dataset path."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -328,7 +328,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_with_uuid(self, repo, sample_dataset):
|
||||
"""Test update_status works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -340,7 +340,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_not_found(self, repo):
|
||||
"""Test update_status does nothing when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -356,7 +356,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_training_status_updates_dataset(self, repo, sample_dataset):
|
||||
"""Test update_training_status updates training status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -370,7 +370,7 @@ class TestDatasetRepository:
|
||||
def test_update_training_status_with_task_id(self, repo, sample_dataset):
|
||||
"""Test update_training_status with active task ID."""
|
||||
task_id = uuid4()
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -387,7 +387,7 @@ class TestDatasetRepository:
|
||||
def test_update_training_status_updates_main_status(self, repo, sample_dataset):
|
||||
"""Test update_training_status updates main status when completed."""
|
||||
sample_dataset.status = "ready"
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -405,7 +405,7 @@ class TestDatasetRepository:
|
||||
def test_update_training_status_clears_task_id(self, repo, sample_dataset):
|
||||
"""Test update_training_status clears task ID when None."""
|
||||
sample_dataset.active_training_task_id = uuid4()
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -421,7 +421,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_training_status_not_found(self, repo):
|
||||
"""Test update_training_status does nothing when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -437,7 +437,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_add_documents_creates_links(self, repo):
|
||||
"""Test add_documents creates dataset document links."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -464,7 +464,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_add_documents_default_counts(self, repo):
|
||||
"""Test add_documents uses default counts."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -484,7 +484,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_add_documents_with_uuid(self, repo):
|
||||
"""Test add_documents works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -502,7 +502,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_add_documents_empty_list(self, repo):
|
||||
"""Test add_documents with empty list."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -518,7 +518,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_documents_returns_list(self, repo, sample_dataset_document):
|
||||
"""Test get_documents returns list of dataset documents."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -531,7 +531,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_documents_with_uuid(self, repo, sample_dataset_document):
|
||||
"""Test get_documents works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -543,7 +543,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_documents_returns_empty_list(self, repo):
|
||||
"""Test get_documents returns empty list when no documents."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -559,7 +559,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_delete_returns_true(self, repo, sample_dataset):
|
||||
"""Test delete returns True when dataset exists."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -573,7 +573,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_delete_with_uuid(self, repo, sample_dataset):
|
||||
"""Test delete works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -585,7 +585,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from inference.data.admin_models import AdminDocument, AdminAnnotation
|
||||
from inference.data.repositories.document_repository import DocumentRepository
|
||||
from backend.data.admin_models import AdminDocument, AdminAnnotation
|
||||
from backend.data.repositories.document_repository import DocumentRepository
|
||||
|
||||
|
||||
class TestDocumentRepository:
|
||||
@@ -95,7 +95,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_create_returns_document_id(self, repo):
|
||||
"""Test create returns document ID."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -113,7 +113,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_create_with_all_parameters(self, repo):
|
||||
"""Test create with all optional parameters."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -145,7 +145,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_returns_document(self, repo, sample_document):
|
||||
"""Test get returns document when exists."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -159,7 +159,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -187,7 +187,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_no_filters(self, repo, sample_document):
|
||||
"""Test get_paginated with no filters."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -201,7 +201,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with status filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -214,7 +214,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_upload_source_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with upload_source filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -227,7 +227,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_auto_label_status_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with auto_label_status filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -241,7 +241,7 @@ class TestDocumentRepository:
|
||||
def test_get_paginated_with_batch_id_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with batch_id filter."""
|
||||
batch_id = str(uuid4())
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -254,7 +254,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_category_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with category filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -267,7 +267,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_has_annotations_true(self, repo, sample_document):
|
||||
"""Test get_paginated with has_annotations=True."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -280,7 +280,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_has_annotations_false(self, repo, sample_document):
|
||||
"""Test get_paginated with has_annotations=False."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -297,7 +297,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_status(self, repo, sample_document):
|
||||
"""Test update_status updates document status."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -310,7 +310,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_status_with_auto_label_status(self, repo, sample_document):
|
||||
"""Test update_status with auto_label_status."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -326,7 +326,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_status_with_auto_label_error(self, repo, sample_document):
|
||||
"""Test update_status with auto_label_error."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -342,7 +342,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_status_document_not_found(self, repo):
|
||||
"""Test update_status when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -358,7 +358,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_file_path(self, repo, sample_document):
|
||||
"""Test update_file_path updates document file path."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -371,7 +371,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_file_path_document_not_found(self, repo):
|
||||
"""Test update_file_path when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -387,7 +387,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_group_key_returns_true(self, repo, sample_document):
|
||||
"""Test update_group_key returns True when document exists."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -400,7 +400,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_group_key_returns_false(self, repo):
|
||||
"""Test update_group_key returns False when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -416,7 +416,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_category(self, repo, sample_document):
|
||||
"""Test update_category updates document category."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -429,7 +429,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_category_returns_none_when_not_found(self, repo):
|
||||
"""Test update_category returns None when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -445,7 +445,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_delete_returns_true_when_exists(self, repo, sample_document):
|
||||
"""Test delete returns True when document exists."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -460,7 +460,7 @@ class TestDocumentRepository:
|
||||
def test_delete_with_annotations(self, repo, sample_document):
|
||||
"""Test delete removes annotations before deleting document."""
|
||||
annotation = MagicMock()
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_session.exec.return_value.all.return_value = [annotation]
|
||||
@@ -474,7 +474,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_delete_returns_false_when_not_exists(self, repo):
|
||||
"""Test delete returns False when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -490,7 +490,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_categories(self, repo):
|
||||
"""Test get_categories returns unique categories."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = ["invoice", "receipt", None]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -506,7 +506,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_labeled_for_export(self, repo, labeled_document):
|
||||
"""Test get_labeled_for_export returns labeled documents."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -519,7 +519,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_labeled_for_export_with_token(self, repo, labeled_document):
|
||||
"""Test get_labeled_for_export with admin_token filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -535,7 +535,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_count_by_status(self, repo):
|
||||
"""Test count_by_status returns status counts."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [
|
||||
("pending", 10),
|
||||
@@ -554,7 +554,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_by_ids(self, repo, sample_document):
|
||||
"""Test get_by_ids returns documents by IDs."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -570,7 +570,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_for_training_basic(self, repo, labeled_document):
|
||||
"""Test get_for_training with default parameters."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
@@ -584,7 +584,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_for_training_with_min_annotation_count(self, repo, labeled_document):
|
||||
"""Test get_for_training with min_annotation_count."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
@@ -597,7 +597,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_for_training_exclude_used(self, repo, labeled_document):
|
||||
"""Test get_for_training with exclude_used_in_training."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
@@ -610,7 +610,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_for_training_no_annotations(self, repo, labeled_document):
|
||||
"""Test get_for_training with has_annotations=False."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
@@ -628,7 +628,7 @@ class TestDocumentRepository:
|
||||
def test_acquire_annotation_lock_success(self, repo, sample_document):
|
||||
"""Test acquire_annotation_lock when no lock exists."""
|
||||
sample_document.annotation_lock_until = None
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -641,7 +641,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_acquire_annotation_lock_fails_when_locked(self, repo, locked_document):
|
||||
"""Test acquire_annotation_lock fails when document is already locked."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = locked_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -653,7 +653,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_acquire_annotation_lock_document_not_found(self, repo):
|
||||
"""Test acquire_annotation_lock when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -669,7 +669,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_release_annotation_lock_success(self, repo, locked_document):
|
||||
"""Test release_annotation_lock releases the lock."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = locked_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -682,7 +682,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_release_annotation_lock_document_not_found(self, repo):
|
||||
"""Test release_annotation_lock when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -699,7 +699,7 @@ class TestDocumentRepository:
|
||||
def test_extend_annotation_lock_success(self, repo, locked_document):
|
||||
"""Test extend_annotation_lock extends the lock."""
|
||||
original_lock = locked_document.annotation_lock_until
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = locked_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -713,7 +713,7 @@ class TestDocumentRepository:
|
||||
def test_extend_annotation_lock_fails_when_no_lock(self, repo, sample_document):
|
||||
"""Test extend_annotation_lock fails when no lock exists."""
|
||||
sample_document.annotation_lock_until = None
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -725,7 +725,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_extend_annotation_lock_fails_when_expired(self, repo, expired_lock_document):
|
||||
"""Test extend_annotation_lock fails when lock is expired."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = expired_lock_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -737,7 +737,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_extend_annotation_lock_document_not_found(self, repo):
|
||||
"""Test extend_annotation_lock when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import ModelVersion
|
||||
from inference.data.repositories.model_version_repository import ModelVersionRepository
|
||||
from backend.data.admin_models import ModelVersion
|
||||
from backend.data.repositories.model_version_repository import ModelVersionRepository
|
||||
|
||||
|
||||
class TestModelVersionRepository:
|
||||
@@ -62,7 +62,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_create_returns_model(self, repo):
|
||||
"""Test create returns created model version."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -82,7 +82,7 @@ class TestModelVersionRepository:
|
||||
dataset_id = uuid4()
|
||||
trained_at = datetime.now(timezone.utc)
|
||||
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -115,7 +115,7 @@ class TestModelVersionRepository:
|
||||
task_id = uuid4()
|
||||
dataset_id = uuid4()
|
||||
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -134,7 +134,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_create_without_optional_ids(self, repo):
|
||||
"""Test create without task_id and dataset_id."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -155,7 +155,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_returns_model(self, repo, sample_model):
|
||||
"""Test get returns model when exists."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -169,7 +169,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_with_uuid(self, repo, sample_model):
|
||||
"""Test get works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -181,7 +181,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -198,7 +198,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_paginated_returns_models_and_total(self, repo, sample_model):
|
||||
"""Test get_paginated returns list of models and total count."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_model]
|
||||
@@ -212,7 +212,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_model):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_model]
|
||||
@@ -225,7 +225,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_model):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_model]
|
||||
@@ -238,7 +238,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -256,7 +256,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_active_returns_active_model(self, repo, active_model):
|
||||
"""Test get_active returns the active model."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.first.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -270,7 +270,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_active_returns_none(self, repo):
|
||||
"""Test get_active returns None when no active model."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.first.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -287,7 +287,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_activate_activates_model(self, repo, sample_model, active_model):
|
||||
"""Test activate sets model as active and deactivates others."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [active_model]
|
||||
mock_session.get.return_value = sample_model
|
||||
@@ -304,7 +304,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_activate_with_uuid(self, repo, sample_model):
|
||||
"""Test activate works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_session.get.return_value = sample_model
|
||||
@@ -318,7 +318,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_activate_returns_none_when_not_found(self, repo):
|
||||
"""Test activate returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_session.get.return_value = None
|
||||
@@ -332,7 +332,7 @@ class TestModelVersionRepository:
|
||||
def test_activate_sets_activated_at(self, repo, sample_model):
|
||||
"""Test activate sets activated_at timestamp."""
|
||||
sample_model.activated_at = None
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_session.get.return_value = sample_model
|
||||
@@ -349,7 +349,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_deactivate_deactivates_model(self, repo, active_model):
|
||||
"""Test deactivate sets model as inactive."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -364,7 +364,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_deactivate_with_uuid(self, repo, active_model):
|
||||
"""Test deactivate works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -376,7 +376,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_deactivate_returns_none_when_not_found(self, repo):
|
||||
"""Test deactivate returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -392,7 +392,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_update_updates_model(self, repo, sample_model):
|
||||
"""Test update updates model metadata."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -409,7 +409,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_update_all_fields(self, repo, sample_model):
|
||||
"""Test update can update all fields."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -428,7 +428,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_update_with_uuid(self, repo, sample_model):
|
||||
"""Test update works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -440,7 +440,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_update_returns_none_when_not_found(self, repo):
|
||||
"""Test update returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -453,7 +453,7 @@ class TestModelVersionRepository:
|
||||
def test_update_partial_fields(self, repo, sample_model):
|
||||
"""Test update only updates provided fields."""
|
||||
original_name = sample_model.name
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -474,7 +474,7 @@ class TestModelVersionRepository:
|
||||
def test_archive_archives_model(self, repo, sample_model):
|
||||
"""Test archive sets model status to archived."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -489,7 +489,7 @@ class TestModelVersionRepository:
|
||||
def test_archive_with_uuid(self, repo, sample_model):
|
||||
"""Test archive works with UUID object."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -501,7 +501,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_archive_returns_none_when_not_found(self, repo):
|
||||
"""Test archive returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -513,7 +513,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_archive_returns_none_when_active(self, repo, active_model):
|
||||
"""Test archive returns None when model is active."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -530,7 +530,7 @@ class TestModelVersionRepository:
|
||||
def test_delete_returns_true(self, repo, sample_model):
|
||||
"""Test delete returns True when model exists and not active."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -545,7 +545,7 @@ class TestModelVersionRepository:
|
||||
def test_delete_with_uuid(self, repo, sample_model):
|
||||
"""Test delete works with UUID object."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -557,7 +557,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -570,7 +570,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_delete_returns_false_when_active(self, repo, active_model):
|
||||
"""Test delete returns False when model is active."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -8,8 +8,8 @@ import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from inference.data.admin_models import AdminToken
|
||||
from inference.data.repositories.token_repository import TokenRepository
|
||||
from backend.data.admin_models import AdminToken
|
||||
from backend.data.repositories.token_repository import TokenRepository
|
||||
|
||||
|
||||
class TestTokenRepository:
|
||||
@@ -55,7 +55,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_returns_true_for_active_token(self, repo, sample_token):
|
||||
"""Test is_valid returns True for an active, non-expired token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -68,7 +68,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_returns_false_for_nonexistent_token(self, repo):
|
||||
"""Test is_valid returns False for a non-existent token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -80,7 +80,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_returns_false_for_inactive_token(self, repo, inactive_token):
|
||||
"""Test is_valid returns False for an inactive token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = inactive_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -92,7 +92,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_returns_false_for_expired_token(self, repo, expired_token):
|
||||
"""Test is_valid returns False for an expired token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = expired_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -104,7 +104,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_get_returns_token_when_exists(self, repo, sample_token):
|
||||
"""Test get returns token when it exists."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -119,7 +119,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_get_returns_none_when_not_exists(self, repo):
|
||||
"""Test get returns None when token doesn't exist."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -131,7 +131,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_create_new_token(self, repo):
|
||||
"""Test creating a new token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None # Token doesn't exist
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -147,7 +147,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_create_updates_existing_token(self, repo, sample_token):
|
||||
"""Test create updates an existing token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -161,7 +161,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_update_usage(self, repo, sample_token):
|
||||
"""Test updating token last_used_at timestamp."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -174,7 +174,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_deactivate_returns_true_when_token_exists(self, repo, sample_token):
|
||||
"""Test deactivate returns True when token exists."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -188,7 +188,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_deactivate_returns_false_when_token_not_exists(self, repo):
|
||||
"""Test deactivate returns False when token doesn't exist."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
|
||||
from inference.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
from backend.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
|
||||
from backend.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
|
||||
|
||||
class TestTrainingTaskRepository:
|
||||
@@ -65,7 +65,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_create_returns_task_id(self, repo):
|
||||
"""Test create returns task ID."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -82,7 +82,7 @@ class TestTrainingTaskRepository:
|
||||
def test_create_with_all_params(self, repo):
|
||||
"""Test create with all parameters."""
|
||||
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -108,7 +108,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_create_pending_status_when_not_scheduled(self, repo):
|
||||
"""Test create sets pending status when no scheduled_at."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -124,7 +124,7 @@ class TestTrainingTaskRepository:
|
||||
def test_create_scheduled_status_when_scheduled(self, repo):
|
||||
"""Test create sets scheduled status when scheduled_at is provided."""
|
||||
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -144,7 +144,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_returns_task(self, repo, sample_task):
|
||||
"""Test get returns task when exists."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -158,7 +158,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when task not found."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -175,7 +175,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_by_token_returns_task(self, repo, sample_task):
|
||||
"""Test get_by_token returns task (delegates to get)."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -187,7 +187,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_by_token_without_token_param(self, repo, sample_task):
|
||||
"""Test get_by_token works without token parameter."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -203,7 +203,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_paginated_returns_tasks_and_total(self, repo, sample_task):
|
||||
"""Test get_paginated returns list of tasks and total count."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
@@ -217,7 +217,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_task):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
@@ -230,7 +230,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_task):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
@@ -243,7 +243,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -261,7 +261,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_pending_returns_pending_tasks(self, repo, sample_task):
|
||||
"""Test get_pending returns pending and scheduled tasks."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -273,7 +273,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_pending_returns_empty_list(self, repo):
|
||||
"""Test get_pending returns empty list when no pending tasks."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -289,7 +289,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_update_status_updates_task(self, repo, sample_task):
|
||||
"""Test update_status updates task status."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -302,7 +302,7 @@ class TestTrainingTaskRepository:
|
||||
def test_update_status_sets_started_at_for_running(self, repo, sample_task):
|
||||
"""Test update_status sets started_at when status is running."""
|
||||
sample_task.started_at = None
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -315,7 +315,7 @@ class TestTrainingTaskRepository:
|
||||
def test_update_status_sets_completed_at_for_completed(self, repo, sample_task):
|
||||
"""Test update_status sets completed_at when status is completed."""
|
||||
sample_task.completed_at = None
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -328,7 +328,7 @@ class TestTrainingTaskRepository:
|
||||
def test_update_status_sets_completed_at_for_failed(self, repo, sample_task):
|
||||
"""Test update_status sets completed_at when status is failed."""
|
||||
sample_task.completed_at = None
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -341,7 +341,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_update_status_with_result_metrics(self, repo, sample_task):
|
||||
"""Test update_status with result metrics."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -357,7 +357,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_update_status_with_model_path(self, repo, sample_task):
|
||||
"""Test update_status with model path."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -373,7 +373,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_update_status_not_found(self, repo):
|
||||
"""Test update_status does nothing when task not found."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -390,7 +390,7 @@ class TestTrainingTaskRepository:
|
||||
def test_cancel_returns_true_for_pending(self, repo, sample_task):
|
||||
"""Test cancel returns True for pending task."""
|
||||
sample_task.status = "pending"
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -404,7 +404,7 @@ class TestTrainingTaskRepository:
|
||||
def test_cancel_returns_true_for_scheduled(self, repo, sample_task):
|
||||
"""Test cancel returns True for scheduled task."""
|
||||
sample_task.status = "scheduled"
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -418,7 +418,7 @@ class TestTrainingTaskRepository:
|
||||
def test_cancel_returns_false_for_running(self, repo, sample_task):
|
||||
"""Test cancel returns False for running task."""
|
||||
sample_task.status = "running"
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -430,7 +430,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_cancel_returns_false_when_not_found(self, repo):
|
||||
"""Test cancel returns False when task not found."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -446,7 +446,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_add_log_creates_log_entry(self, repo):
|
||||
"""Test add_log creates a log entry."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -464,7 +464,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_add_log_with_details(self, repo):
|
||||
"""Test add_log with details."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -485,7 +485,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_logs_returns_list(self, repo, sample_log):
|
||||
"""Test get_logs returns list of logs."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_log]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -498,7 +498,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_logs_with_pagination(self, repo, sample_log):
|
||||
"""Test get_logs with limit and offset."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_log]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -510,7 +510,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_logs_returns_empty_list(self, repo):
|
||||
"""Test get_logs returns empty list when no logs."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -526,7 +526,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_create_document_link_returns_link(self, repo):
|
||||
"""Test create_document_link returns created link."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -543,7 +543,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_create_document_link_with_snapshot(self, repo):
|
||||
"""Test create_document_link with annotation snapshot."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -564,7 +564,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_document_links_returns_list(self, repo, sample_link):
|
||||
"""Test get_document_links returns list of links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_link]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -576,7 +576,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_document_links_returns_empty_list(self, repo):
|
||||
"""Test get_document_links returns empty list when no links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -592,7 +592,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_document_training_tasks_returns_list(self, repo, sample_link):
|
||||
"""Test get_document_training_tasks returns list of links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_link]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -604,7 +604,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_document_training_tasks_returns_empty_list(self, repo):
|
||||
"""Test get_document_training_tasks returns empty list when no links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from inference.data.admin_models import (
|
||||
from backend.data.admin_models import (
|
||||
BatchUpload,
|
||||
BatchUploadFile,
|
||||
TrainingDocumentLink,
|
||||
|
||||
Reference in New Issue
Block a user