re-structure
This commit is contained in:
@@ -9,8 +9,8 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
|
||||
from inference.data.repositories.annotation_repository import AnnotationRepository
|
||||
from backend.data.admin_models import AdminAnnotation, AnnotationHistory
|
||||
from backend.data.repositories.annotation_repository import AnnotationRepository
|
||||
|
||||
|
||||
class TestAnnotationRepository:
|
||||
@@ -66,7 +66,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_returns_annotation_id(self, repo):
|
||||
"""Test create returns annotation ID."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -92,7 +92,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_with_optional_params(self, repo):
|
||||
"""Test create with optional text_value and confidence."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -124,7 +124,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_default_source_is_manual(self, repo):
|
||||
"""Test create uses manual as default source."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -153,7 +153,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_batch_returns_ids(self, repo):
|
||||
"""Test create_batch returns list of annotation IDs."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -195,7 +195,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_batch_default_page_number(self, repo):
|
||||
"""Test create_batch uses page_number=1 by default."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -224,7 +224,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_batch_with_all_optional_params(self, repo):
|
||||
"""Test create_batch with all optional parameters."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -259,7 +259,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_batch_empty_list(self, repo):
|
||||
"""Test create_batch with empty list returns empty."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -275,7 +275,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_returns_annotation(self, repo, sample_annotation):
|
||||
"""Test get returns annotation when exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -289,7 +289,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -306,7 +306,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_for_document_returns_all_annotations(self, repo, sample_annotation):
|
||||
"""Test get_for_document returns all annotations for document."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -319,7 +319,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_for_document_with_page_filter(self, repo, sample_annotation):
|
||||
"""Test get_for_document filters by page number."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -331,7 +331,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_for_document_returns_empty_list(self, repo):
|
||||
"""Test get_for_document returns empty list when no annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -347,7 +347,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_update_returns_true(self, repo, sample_annotation):
|
||||
"""Test update returns True when annotation exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -363,7 +363,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_update_returns_false_when_not_found(self, repo):
|
||||
"""Test update returns False when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -375,7 +375,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_update_all_fields(self, repo, sample_annotation):
|
||||
"""Test update can update all fields."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -412,7 +412,7 @@ class TestAnnotationRepository:
|
||||
def test_update_partial_fields(self, repo, sample_annotation):
|
||||
"""Test update only updates provided fields."""
|
||||
original_x = sample_annotation.x_center
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -433,7 +433,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_returns_true(self, repo, sample_annotation):
|
||||
"""Test delete returns True when annotation exists."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -446,7 +446,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -463,7 +463,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_for_document_returns_count(self, repo, sample_annotation):
|
||||
"""Test delete_for_document returns count of deleted annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -476,7 +476,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_for_document_with_source_filter(self, repo, sample_annotation):
|
||||
"""Test delete_for_document filters by source."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_annotation]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -488,7 +488,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_delete_for_document_returns_zero(self, repo):
|
||||
"""Test delete_for_document returns 0 when no annotations."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -506,7 +506,7 @@ class TestAnnotationRepository:
|
||||
def test_verify_marks_annotation_verified(self, repo, sample_annotation):
|
||||
"""Test verify marks annotation as verified."""
|
||||
sample_annotation.is_verified = False
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -521,7 +521,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_verify_returns_none_when_not_found(self, repo):
|
||||
"""Test verify returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -538,7 +538,7 @@ class TestAnnotationRepository:
|
||||
def test_override_updates_annotation(self, repo, sample_annotation):
|
||||
"""Test override updates annotation and creates history."""
|
||||
sample_annotation.source = "auto"
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -559,7 +559,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_override_returns_none_when_not_found(self, repo):
|
||||
"""Test override returns None when annotation not found."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -573,7 +573,7 @@ class TestAnnotationRepository:
|
||||
"""Test override does not change override_source if already manual."""
|
||||
sample_annotation.source = "manual"
|
||||
sample_annotation.override_source = None
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -591,7 +591,7 @@ class TestAnnotationRepository:
|
||||
def test_override_skips_unknown_attributes(self, repo, sample_annotation):
|
||||
"""Test override ignores unknown attributes."""
|
||||
sample_annotation.source = "auto"
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_annotation
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -614,7 +614,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_history_returns_history(self, repo):
|
||||
"""Test create_history returns created history record."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -636,7 +636,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_create_history_with_minimal_params(self, repo):
|
||||
"""Test create_history with minimal parameters."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -659,7 +659,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_history_returns_list(self, repo, sample_history):
|
||||
"""Test get_history returns list of history records."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_history]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -672,7 +672,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_history_returns_empty_list(self, repo):
|
||||
"""Test get_history returns empty list when no history."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -688,7 +688,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_document_history_returns_list(self, repo, sample_history):
|
||||
"""Test get_document_history returns list of history records."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_history]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -700,7 +700,7 @@ class TestAnnotationRepository:
|
||||
|
||||
def test_get_document_history_returns_empty_list(self, repo):
|
||||
"""Test get_document_history returns empty list when no history."""
|
||||
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,7 +9,7 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.repositories.base import BaseRepository
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
|
||||
class ConcreteRepository(BaseRepository[MagicMock]):
|
||||
@@ -31,7 +31,7 @@ class TestBaseRepository:
|
||||
|
||||
def test_session_yields_session(self, repo):
|
||||
"""Test _session yields a database session."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import BatchUpload, BatchUploadFile
|
||||
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
from backend.data.admin_models import BatchUpload, BatchUploadFile
|
||||
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
|
||||
|
||||
class TestBatchUploadRepository:
|
||||
@@ -54,7 +54,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_returns_batch(self, repo):
|
||||
"""Test create returns created batch upload."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -70,7 +70,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_with_upload_source(self, repo):
|
||||
"""Test create with custom upload source."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -87,7 +87,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_default_upload_source(self, repo):
|
||||
"""Test create uses default upload source."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -107,7 +107,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_returns_batch(self, repo, sample_batch):
|
||||
"""Test get returns batch when exists."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -121,7 +121,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when batch not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -138,7 +138,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_updates_batch(self, repo, sample_batch):
|
||||
"""Test update updates batch fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -156,7 +156,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_ignores_unknown_fields(self, repo, sample_batch):
|
||||
"""Test update ignores unknown fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -171,7 +171,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_not_found(self, repo):
|
||||
"""Test update does nothing when batch not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -183,7 +183,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_multiple_fields(self, repo, sample_batch):
|
||||
"""Test update can update multiple fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_batch
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -206,7 +206,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_file_returns_file(self, repo):
|
||||
"""Test create_file returns created file record."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -221,7 +221,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_create_file_with_kwargs(self, repo):
|
||||
"""Test create_file with additional kwargs."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -242,7 +242,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_file_updates_file(self, repo, sample_file):
|
||||
"""Test update_file updates file fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -258,7 +258,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_file_ignores_unknown_fields(self, repo, sample_file):
|
||||
"""Test update_file ignores unknown fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -273,7 +273,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_file_not_found(self, repo):
|
||||
"""Test update_file does nothing when file not found."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -285,7 +285,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_update_file_multiple_fields(self, repo, sample_file):
|
||||
"""Test update_file can update multiple fields."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_file
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -304,7 +304,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_files_returns_list(self, repo, sample_file):
|
||||
"""Test get_files returns list of files."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_file]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -317,7 +317,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_files_returns_empty_list(self, repo):
|
||||
"""Test get_files returns empty list when no files."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -333,7 +333,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_paginated_returns_batches_and_total(self, repo, sample_batch):
|
||||
"""Test get_paginated returns list of batches and total count."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
@@ -347,7 +347,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_batch):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 100
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
@@ -360,7 +360,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -374,7 +374,7 @@ class TestBatchUploadRepository:
|
||||
|
||||
def test_get_paginated_with_admin_token(self, repo, sample_batch):
|
||||
"""Test get_paginated with admin_token parameter (deprecated, ignored)."""
|
||||
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_batch]
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
|
||||
from inference.data.repositories.dataset_repository import DatasetRepository
|
||||
from backend.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
|
||||
from backend.data.repositories.dataset_repository import DatasetRepository
|
||||
|
||||
|
||||
class TestDatasetRepository:
|
||||
@@ -69,7 +69,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_create_returns_dataset(self, repo):
|
||||
"""Test create returns created dataset."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -81,7 +81,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_create_with_all_params(self, repo):
|
||||
"""Test create with all parameters."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -103,7 +103,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_create_default_values(self, repo):
|
||||
"""Test create uses default values."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -121,7 +121,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_returns_dataset(self, repo, sample_dataset):
|
||||
"""Test get returns dataset when exists."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -135,7 +135,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_with_uuid(self, repo, sample_dataset):
|
||||
"""Test get works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -147,7 +147,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -164,7 +164,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_paginated_returns_datasets_and_total(self, repo, sample_dataset):
|
||||
"""Test get_paginated returns list of datasets and total count."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
@@ -178,7 +178,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_dataset):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
@@ -191,7 +191,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_dataset):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset]
|
||||
@@ -204,7 +204,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -222,7 +222,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_active_training_tasks_returns_dict(self, repo, sample_training_task):
|
||||
"""Test get_active_training_tasks returns dict of active tasks."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_training_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -240,7 +240,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_active_training_tasks_invalid_uuid(self, repo):
|
||||
"""Test get_active_training_tasks filters invalid UUIDs."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -263,7 +263,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_updates_dataset(self, repo, sample_dataset):
|
||||
"""Test update_status updates dataset status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -276,7 +276,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_with_error_message(self, repo, sample_dataset):
|
||||
"""Test update_status with error message."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -292,7 +292,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_with_totals(self, repo, sample_dataset):
|
||||
"""Test update_status with total counts."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -312,7 +312,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_with_dataset_path(self, repo, sample_dataset):
|
||||
"""Test update_status with dataset path."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -328,7 +328,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_with_uuid(self, repo, sample_dataset):
|
||||
"""Test update_status works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -340,7 +340,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_status_not_found(self, repo):
|
||||
"""Test update_status does nothing when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -356,7 +356,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_training_status_updates_dataset(self, repo, sample_dataset):
|
||||
"""Test update_training_status updates training status."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -370,7 +370,7 @@ class TestDatasetRepository:
|
||||
def test_update_training_status_with_task_id(self, repo, sample_dataset):
|
||||
"""Test update_training_status with active task ID."""
|
||||
task_id = uuid4()
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -387,7 +387,7 @@ class TestDatasetRepository:
|
||||
def test_update_training_status_updates_main_status(self, repo, sample_dataset):
|
||||
"""Test update_training_status updates main status when completed."""
|
||||
sample_dataset.status = "ready"
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -405,7 +405,7 @@ class TestDatasetRepository:
|
||||
def test_update_training_status_clears_task_id(self, repo, sample_dataset):
|
||||
"""Test update_training_status clears task ID when None."""
|
||||
sample_dataset.active_training_task_id = uuid4()
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -421,7 +421,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_update_training_status_not_found(self, repo):
|
||||
"""Test update_training_status does nothing when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -437,7 +437,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_add_documents_creates_links(self, repo):
|
||||
"""Test add_documents creates dataset document links."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -464,7 +464,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_add_documents_default_counts(self, repo):
|
||||
"""Test add_documents uses default counts."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -484,7 +484,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_add_documents_with_uuid(self, repo):
|
||||
"""Test add_documents works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -502,7 +502,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_add_documents_empty_list(self, repo):
|
||||
"""Test add_documents with empty list."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -518,7 +518,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_documents_returns_list(self, repo, sample_dataset_document):
|
||||
"""Test get_documents returns list of dataset documents."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -531,7 +531,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_documents_with_uuid(self, repo, sample_dataset_document):
|
||||
"""Test get_documents works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -543,7 +543,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_get_documents_returns_empty_list(self, repo):
|
||||
"""Test get_documents returns empty list when no documents."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -559,7 +559,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_delete_returns_true(self, repo, sample_dataset):
|
||||
"""Test delete returns True when dataset exists."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -573,7 +573,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_delete_with_uuid(self, repo, sample_dataset):
|
||||
"""Test delete works with UUID object."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_dataset
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -585,7 +585,7 @@ class TestDatasetRepository:
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when dataset not found."""
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from inference.data.admin_models import AdminDocument, AdminAnnotation
|
||||
from inference.data.repositories.document_repository import DocumentRepository
|
||||
from backend.data.admin_models import AdminDocument, AdminAnnotation
|
||||
from backend.data.repositories.document_repository import DocumentRepository
|
||||
|
||||
|
||||
class TestDocumentRepository:
|
||||
@@ -95,7 +95,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_create_returns_document_id(self, repo):
|
||||
"""Test create returns document ID."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -113,7 +113,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_create_with_all_parameters(self, repo):
|
||||
"""Test create with all optional parameters."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -145,7 +145,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_returns_document(self, repo, sample_document):
|
||||
"""Test get returns document when exists."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -159,7 +159,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -187,7 +187,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_no_filters(self, repo, sample_document):
|
||||
"""Test get_paginated with no filters."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -201,7 +201,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with status filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -214,7 +214,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_upload_source_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with upload_source filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -227,7 +227,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_auto_label_status_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with auto_label_status filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -241,7 +241,7 @@ class TestDocumentRepository:
|
||||
def test_get_paginated_with_batch_id_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with batch_id filter."""
|
||||
batch_id = str(uuid4())
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -254,7 +254,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_category_filter(self, repo, sample_document):
|
||||
"""Test get_paginated with category filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -267,7 +267,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_has_annotations_true(self, repo, sample_document):
|
||||
"""Test get_paginated with has_annotations=True."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -280,7 +280,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_paginated_with_has_annotations_false(self, repo, sample_document):
|
||||
"""Test get_paginated with has_annotations=False."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
@@ -297,7 +297,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_status(self, repo, sample_document):
|
||||
"""Test update_status updates document status."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -310,7 +310,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_status_with_auto_label_status(self, repo, sample_document):
|
||||
"""Test update_status with auto_label_status."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -326,7 +326,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_status_with_auto_label_error(self, repo, sample_document):
|
||||
"""Test update_status with auto_label_error."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -342,7 +342,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_status_document_not_found(self, repo):
|
||||
"""Test update_status when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -358,7 +358,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_file_path(self, repo, sample_document):
|
||||
"""Test update_file_path updates document file path."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -371,7 +371,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_file_path_document_not_found(self, repo):
|
||||
"""Test update_file_path when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -387,7 +387,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_group_key_returns_true(self, repo, sample_document):
|
||||
"""Test update_group_key returns True when document exists."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -400,7 +400,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_group_key_returns_false(self, repo):
|
||||
"""Test update_group_key returns False when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -416,7 +416,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_category(self, repo, sample_document):
|
||||
"""Test update_category updates document category."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -429,7 +429,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_update_category_returns_none_when_not_found(self, repo):
|
||||
"""Test update_category returns None when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -445,7 +445,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_delete_returns_true_when_exists(self, repo, sample_document):
|
||||
"""Test delete returns True when document exists."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -460,7 +460,7 @@ class TestDocumentRepository:
|
||||
def test_delete_with_annotations(self, repo, sample_document):
|
||||
"""Test delete removes annotations before deleting document."""
|
||||
annotation = MagicMock()
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_session.exec.return_value.all.return_value = [annotation]
|
||||
@@ -474,7 +474,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_delete_returns_false_when_not_exists(self, repo):
|
||||
"""Test delete returns False when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -490,7 +490,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_categories(self, repo):
|
||||
"""Test get_categories returns unique categories."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = ["invoice", "receipt", None]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -506,7 +506,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_labeled_for_export(self, repo, labeled_document):
|
||||
"""Test get_labeled_for_export returns labeled documents."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -519,7 +519,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_labeled_for_export_with_token(self, repo, labeled_document):
|
||||
"""Test get_labeled_for_export with admin_token filter."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -535,7 +535,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_count_by_status(self, repo):
|
||||
"""Test count_by_status returns status counts."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [
|
||||
("pending", 10),
|
||||
@@ -554,7 +554,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_by_ids(self, repo, sample_document):
|
||||
"""Test get_by_ids returns documents by IDs."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_document]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -570,7 +570,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_for_training_basic(self, repo, labeled_document):
|
||||
"""Test get_for_training with default parameters."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
@@ -584,7 +584,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_for_training_with_min_annotation_count(self, repo, labeled_document):
|
||||
"""Test get_for_training with min_annotation_count."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
@@ -597,7 +597,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_for_training_exclude_used(self, repo, labeled_document):
|
||||
"""Test get_for_training with exclude_used_in_training."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
@@ -610,7 +610,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_get_for_training_no_annotations(self, repo, labeled_document):
|
||||
"""Test get_for_training with has_annotations=False."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [labeled_document]
|
||||
@@ -628,7 +628,7 @@ class TestDocumentRepository:
|
||||
def test_acquire_annotation_lock_success(self, repo, sample_document):
|
||||
"""Test acquire_annotation_lock when no lock exists."""
|
||||
sample_document.annotation_lock_until = None
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -641,7 +641,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_acquire_annotation_lock_fails_when_locked(self, repo, locked_document):
|
||||
"""Test acquire_annotation_lock fails when document is already locked."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = locked_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -653,7 +653,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_acquire_annotation_lock_document_not_found(self, repo):
|
||||
"""Test acquire_annotation_lock when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -669,7 +669,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_release_annotation_lock_success(self, repo, locked_document):
|
||||
"""Test release_annotation_lock releases the lock."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = locked_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -682,7 +682,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_release_annotation_lock_document_not_found(self, repo):
|
||||
"""Test release_annotation_lock when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -699,7 +699,7 @@ class TestDocumentRepository:
|
||||
def test_extend_annotation_lock_success(self, repo, locked_document):
|
||||
"""Test extend_annotation_lock extends the lock."""
|
||||
original_lock = locked_document.annotation_lock_until
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = locked_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -713,7 +713,7 @@ class TestDocumentRepository:
|
||||
def test_extend_annotation_lock_fails_when_no_lock(self, repo, sample_document):
|
||||
"""Test extend_annotation_lock fails when no lock exists."""
|
||||
sample_document.annotation_lock_until = None
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -725,7 +725,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_extend_annotation_lock_fails_when_expired(self, repo, expired_lock_document):
|
||||
"""Test extend_annotation_lock fails when lock is expired."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = expired_lock_document
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -737,7 +737,7 @@ class TestDocumentRepository:
|
||||
|
||||
def test_extend_annotation_lock_document_not_found(self, repo):
|
||||
"""Test extend_annotation_lock when document not found."""
|
||||
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import ModelVersion
|
||||
from inference.data.repositories.model_version_repository import ModelVersionRepository
|
||||
from backend.data.admin_models import ModelVersion
|
||||
from backend.data.repositories.model_version_repository import ModelVersionRepository
|
||||
|
||||
|
||||
class TestModelVersionRepository:
|
||||
@@ -62,7 +62,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_create_returns_model(self, repo):
|
||||
"""Test create returns created model version."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -82,7 +82,7 @@ class TestModelVersionRepository:
|
||||
dataset_id = uuid4()
|
||||
trained_at = datetime.now(timezone.utc)
|
||||
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -115,7 +115,7 @@ class TestModelVersionRepository:
|
||||
task_id = uuid4()
|
||||
dataset_id = uuid4()
|
||||
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -134,7 +134,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_create_without_optional_ids(self, repo):
|
||||
"""Test create without task_id and dataset_id."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -155,7 +155,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_returns_model(self, repo, sample_model):
|
||||
"""Test get returns model when exists."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -169,7 +169,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_with_uuid(self, repo, sample_model):
|
||||
"""Test get works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -181,7 +181,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -198,7 +198,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_paginated_returns_models_and_total(self, repo, sample_model):
|
||||
"""Test get_paginated returns list of models and total count."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_model]
|
||||
@@ -212,7 +212,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_model):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_model]
|
||||
@@ -225,7 +225,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_model):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_model]
|
||||
@@ -238,7 +238,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -256,7 +256,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_active_returns_active_model(self, repo, active_model):
|
||||
"""Test get_active returns the active model."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.first.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -270,7 +270,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_get_active_returns_none(self, repo):
|
||||
"""Test get_active returns None when no active model."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.first.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -287,7 +287,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_activate_activates_model(self, repo, sample_model, active_model):
|
||||
"""Test activate sets model as active and deactivates others."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [active_model]
|
||||
mock_session.get.return_value = sample_model
|
||||
@@ -304,7 +304,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_activate_with_uuid(self, repo, sample_model):
|
||||
"""Test activate works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_session.get.return_value = sample_model
|
||||
@@ -318,7 +318,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_activate_returns_none_when_not_found(self, repo):
|
||||
"""Test activate returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_session.get.return_value = None
|
||||
@@ -332,7 +332,7 @@ class TestModelVersionRepository:
|
||||
def test_activate_sets_activated_at(self, repo, sample_model):
|
||||
"""Test activate sets activated_at timestamp."""
|
||||
sample_model.activated_at = None
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_session.get.return_value = sample_model
|
||||
@@ -349,7 +349,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_deactivate_deactivates_model(self, repo, active_model):
|
||||
"""Test deactivate sets model as inactive."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -364,7 +364,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_deactivate_with_uuid(self, repo, active_model):
|
||||
"""Test deactivate works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -376,7 +376,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_deactivate_returns_none_when_not_found(self, repo):
|
||||
"""Test deactivate returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -392,7 +392,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_update_updates_model(self, repo, sample_model):
|
||||
"""Test update updates model metadata."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -409,7 +409,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_update_all_fields(self, repo, sample_model):
|
||||
"""Test update can update all fields."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -428,7 +428,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_update_with_uuid(self, repo, sample_model):
|
||||
"""Test update works with UUID object."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -440,7 +440,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_update_returns_none_when_not_found(self, repo):
|
||||
"""Test update returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -453,7 +453,7 @@ class TestModelVersionRepository:
|
||||
def test_update_partial_fields(self, repo, sample_model):
|
||||
"""Test update only updates provided fields."""
|
||||
original_name = sample_model.name
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -474,7 +474,7 @@ class TestModelVersionRepository:
|
||||
def test_archive_archives_model(self, repo, sample_model):
|
||||
"""Test archive sets model status to archived."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -489,7 +489,7 @@ class TestModelVersionRepository:
|
||||
def test_archive_with_uuid(self, repo, sample_model):
|
||||
"""Test archive works with UUID object."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -501,7 +501,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_archive_returns_none_when_not_found(self, repo):
|
||||
"""Test archive returns None when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -513,7 +513,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_archive_returns_none_when_active(self, repo, active_model):
|
||||
"""Test archive returns None when model is active."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -530,7 +530,7 @@ class TestModelVersionRepository:
|
||||
def test_delete_returns_true(self, repo, sample_model):
|
||||
"""Test delete returns True when model exists and not active."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -545,7 +545,7 @@ class TestModelVersionRepository:
|
||||
def test_delete_with_uuid(self, repo, sample_model):
|
||||
"""Test delete works with UUID object."""
|
||||
sample_model.is_active = False
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -557,7 +557,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_delete_returns_false_when_not_found(self, repo):
|
||||
"""Test delete returns False when model not found."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -570,7 +570,7 @@ class TestModelVersionRepository:
|
||||
|
||||
def test_delete_returns_false_when_active(self, repo, active_model):
|
||||
"""Test delete returns False when model is active."""
|
||||
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = active_model
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -8,8 +8,8 @@ import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from inference.data.admin_models import AdminToken
|
||||
from inference.data.repositories.token_repository import TokenRepository
|
||||
from backend.data.admin_models import AdminToken
|
||||
from backend.data.repositories.token_repository import TokenRepository
|
||||
|
||||
|
||||
class TestTokenRepository:
|
||||
@@ -55,7 +55,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_returns_true_for_active_token(self, repo, sample_token):
|
||||
"""Test is_valid returns True for an active, non-expired token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -68,7 +68,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_returns_false_for_nonexistent_token(self, repo):
|
||||
"""Test is_valid returns False for a non-existent token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -80,7 +80,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_returns_false_for_inactive_token(self, repo, inactive_token):
|
||||
"""Test is_valid returns False for an inactive token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = inactive_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -92,7 +92,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_returns_false_for_expired_token(self, repo, expired_token):
|
||||
"""Test is_valid returns False for an expired token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = expired_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -104,7 +104,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_get_returns_token_when_exists(self, repo, sample_token):
|
||||
"""Test get returns token when it exists."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -119,7 +119,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_get_returns_none_when_not_exists(self, repo):
|
||||
"""Test get returns None when token doesn't exist."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -131,7 +131,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_create_new_token(self, repo):
|
||||
"""Test creating a new token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None # Token doesn't exist
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -147,7 +147,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_create_updates_existing_token(self, repo, sample_token):
|
||||
"""Test create updates an existing token."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -161,7 +161,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_update_usage(self, repo, sample_token):
|
||||
"""Test updating token last_used_at timestamp."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -174,7 +174,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_deactivate_returns_true_when_token_exists(self, repo, sample_token):
|
||||
"""Test deactivate returns True when token exists."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_token
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -188,7 +188,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_deactivate_returns_false_when_token_not_exists(self, repo):
|
||||
"""Test deactivate returns False when token doesn't exist."""
|
||||
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,8 +9,8 @@ from datetime import datetime, timezone, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
from inference.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
|
||||
from inference.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
from backend.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
|
||||
from backend.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
|
||||
|
||||
class TestTrainingTaskRepository:
|
||||
@@ -65,7 +65,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_create_returns_task_id(self, repo):
|
||||
"""Test create returns task ID."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -82,7 +82,7 @@ class TestTrainingTaskRepository:
|
||||
def test_create_with_all_params(self, repo):
|
||||
"""Test create with all parameters."""
|
||||
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -108,7 +108,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_create_pending_status_when_not_scheduled(self, repo):
|
||||
"""Test create sets pending status when no scheduled_at."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -124,7 +124,7 @@ class TestTrainingTaskRepository:
|
||||
def test_create_scheduled_status_when_scheduled(self, repo):
|
||||
"""Test create sets scheduled status when scheduled_at is provided."""
|
||||
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -144,7 +144,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_returns_task(self, repo, sample_task):
|
||||
"""Test get returns task when exists."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -158,7 +158,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_returns_none_when_not_found(self, repo):
|
||||
"""Test get returns None when task not found."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -175,7 +175,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_by_token_returns_task(self, repo, sample_task):
|
||||
"""Test get_by_token returns task (delegates to get)."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -187,7 +187,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_by_token_without_token_param(self, repo, sample_task):
|
||||
"""Test get_by_token works without token parameter."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -203,7 +203,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_paginated_returns_tasks_and_total(self, repo, sample_task):
|
||||
"""Test get_paginated returns list of tasks and total count."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
@@ -217,7 +217,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_paginated_with_status_filter(self, repo, sample_task):
|
||||
"""Test get_paginated filters by status."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 1
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
@@ -230,7 +230,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_paginated_with_pagination(self, repo, sample_task):
|
||||
"""Test get_paginated with limit and offset."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 50
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
@@ -243,7 +243,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_paginated_empty_results(self, repo):
|
||||
"""Test get_paginated with no results."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.one.return_value = 0
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
@@ -261,7 +261,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_pending_returns_pending_tasks(self, repo, sample_task):
|
||||
"""Test get_pending returns pending and scheduled tasks."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_task]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -273,7 +273,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_pending_returns_empty_list(self, repo):
|
||||
"""Test get_pending returns empty list when no pending tasks."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -289,7 +289,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_update_status_updates_task(self, repo, sample_task):
|
||||
"""Test update_status updates task status."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -302,7 +302,7 @@ class TestTrainingTaskRepository:
|
||||
def test_update_status_sets_started_at_for_running(self, repo, sample_task):
|
||||
"""Test update_status sets started_at when status is running."""
|
||||
sample_task.started_at = None
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -315,7 +315,7 @@ class TestTrainingTaskRepository:
|
||||
def test_update_status_sets_completed_at_for_completed(self, repo, sample_task):
|
||||
"""Test update_status sets completed_at when status is completed."""
|
||||
sample_task.completed_at = None
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -328,7 +328,7 @@ class TestTrainingTaskRepository:
|
||||
def test_update_status_sets_completed_at_for_failed(self, repo, sample_task):
|
||||
"""Test update_status sets completed_at when status is failed."""
|
||||
sample_task.completed_at = None
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -341,7 +341,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_update_status_with_result_metrics(self, repo, sample_task):
|
||||
"""Test update_status with result metrics."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -357,7 +357,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_update_status_with_model_path(self, repo, sample_task):
|
||||
"""Test update_status with model path."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -373,7 +373,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_update_status_not_found(self, repo):
|
||||
"""Test update_status does nothing when task not found."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -390,7 +390,7 @@ class TestTrainingTaskRepository:
|
||||
def test_cancel_returns_true_for_pending(self, repo, sample_task):
|
||||
"""Test cancel returns True for pending task."""
|
||||
sample_task.status = "pending"
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -404,7 +404,7 @@ class TestTrainingTaskRepository:
|
||||
def test_cancel_returns_true_for_scheduled(self, repo, sample_task):
|
||||
"""Test cancel returns True for scheduled task."""
|
||||
sample_task.status = "scheduled"
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -418,7 +418,7 @@ class TestTrainingTaskRepository:
|
||||
def test_cancel_returns_false_for_running(self, repo, sample_task):
|
||||
"""Test cancel returns False for running task."""
|
||||
sample_task.status = "running"
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = sample_task
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -430,7 +430,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_cancel_returns_false_when_not_found(self, repo):
|
||||
"""Test cancel returns False when task not found."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.get.return_value = None
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -446,7 +446,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_add_log_creates_log_entry(self, repo):
|
||||
"""Test add_log creates a log entry."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -464,7 +464,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_add_log_with_details(self, repo):
|
||||
"""Test add_log with details."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -485,7 +485,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_logs_returns_list(self, repo, sample_log):
|
||||
"""Test get_logs returns list of logs."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_log]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -498,7 +498,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_logs_with_pagination(self, repo, sample_log):
|
||||
"""Test get_logs with limit and offset."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_log]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -510,7 +510,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_logs_returns_empty_list(self, repo):
|
||||
"""Test get_logs returns empty list when no logs."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -526,7 +526,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_create_document_link_returns_link(self, repo):
|
||||
"""Test create_document_link returns created link."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -543,7 +543,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_create_document_link_with_snapshot(self, repo):
|
||||
"""Test create_document_link with annotation snapshot."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
|
||||
@@ -564,7 +564,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_document_links_returns_list(self, repo, sample_link):
|
||||
"""Test get_document_links returns list of links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_link]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -576,7 +576,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_document_links_returns_empty_list(self, repo):
|
||||
"""Test get_document_links returns empty list when no links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -592,7 +592,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_document_training_tasks_returns_list(self, repo, sample_link):
|
||||
"""Test get_document_training_tasks returns list of links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = [sample_link]
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
@@ -604,7 +604,7 @@ class TestTrainingTaskRepository:
|
||||
|
||||
def test_get_document_training_tasks_returns_empty_list(self, repo):
|
||||
"""Test get_document_training_tasks returns empty list when no links."""
|
||||
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_session.exec.return_value.all.return_value = []
|
||||
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
|
||||
@@ -9,7 +9,7 @@ import pytest
|
||||
from datetime import datetime
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from inference.data.admin_models import (
|
||||
from backend.data.admin_models import (
|
||||
BatchUpload,
|
||||
BatchUploadFile,
|
||||
TrainingDocumentLink,
|
||||
|
||||
@@ -11,8 +11,8 @@ Tests field normalization functions:
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from inference.pipeline.field_extractor import FieldExtractor
|
||||
from inference.pipeline.normalizers import (
|
||||
from backend.pipeline.field_extractor import FieldExtractor
|
||||
from backend.pipeline.normalizers import (
|
||||
InvoiceNumberNormalizer,
|
||||
OcrNumberNormalizer,
|
||||
BankgiroNormalizer,
|
||||
|
||||
@@ -8,7 +8,7 @@ matching variants from known values.
|
||||
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
from inference.pipeline.normalizers import (
|
||||
from backend.pipeline.normalizers import (
|
||||
NormalizationResult,
|
||||
InvoiceNumberNormalizer,
|
||||
OcrNumberNormalizer,
|
||||
@@ -490,7 +490,7 @@ class TestEnhancedAmountNormalizer:
|
||||
# Need input that bypasses labeled patterns AND shared validator
|
||||
# but has decimal pattern matches
|
||||
with patch(
|
||||
"inference.pipeline.normalizers.amount.FieldValidators.parse_amount",
|
||||
"backend.pipeline.normalizers.amount.FieldValidators.parse_amount",
|
||||
return_value=None,
|
||||
):
|
||||
result = normalizer.normalize("Items: 100,00 and 200,00 and 300,00")
|
||||
@@ -560,7 +560,7 @@ class TestDateNormalizer:
|
||||
def test_fallback_pattern_with_mock(self, normalizer):
|
||||
"""Test fallback PATTERNS when shared validator returns None (line 83)."""
|
||||
with patch(
|
||||
"inference.pipeline.normalizers.date.FieldValidators.format_date_iso",
|
||||
"backend.pipeline.normalizers.date.FieldValidators.format_date_iso",
|
||||
return_value=None,
|
||||
):
|
||||
result = normalizer.normalize("2025-08-29")
|
||||
@@ -667,7 +667,7 @@ class TestEnhancedDateNormalizer:
|
||||
def test_swedish_pattern_fallback_with_mock(self, normalizer):
|
||||
"""Test Swedish pattern when shared validator returns None (line 170)."""
|
||||
with patch(
|
||||
"inference.pipeline.normalizers.date.FieldValidators.format_date_iso",
|
||||
"backend.pipeline.normalizers.date.FieldValidators.format_date_iso",
|
||||
return_value=None,
|
||||
):
|
||||
result = normalizer.normalize("15 maj 2025")
|
||||
@@ -677,7 +677,7 @@ class TestEnhancedDateNormalizer:
|
||||
def test_ymd_compact_fallback_with_mock(self, normalizer):
|
||||
"""Test ymd_compact pattern when shared validator returns None (lines 187-192)."""
|
||||
with patch(
|
||||
"inference.pipeline.normalizers.date.FieldValidators.format_date_iso",
|
||||
"backend.pipeline.normalizers.date.FieldValidators.format_date_iso",
|
||||
return_value=None,
|
||||
):
|
||||
result = normalizer.normalize("20250315")
|
||||
|
||||
@@ -10,7 +10,7 @@ Tests the cross-validation logic between payment_line and detected fields:
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from inference.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult
|
||||
from backend.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult
|
||||
|
||||
|
||||
class TestCrossValidationResult:
|
||||
|
||||
@@ -82,7 +82,7 @@ def mock_inference_service():
|
||||
@pytest.fixture
|
||||
def mock_storage_config(temp_storage_dir):
|
||||
"""Create mock storage configuration."""
|
||||
from inference.web.config import StorageConfig
|
||||
from backend.web.config import StorageConfig
|
||||
|
||||
return StorageConfig(
|
||||
upload_dir=temp_storage_dir["uploads"],
|
||||
@@ -104,13 +104,13 @@ def mock_storage_helper(temp_storage_dir):
|
||||
@pytest.fixture
|
||||
def test_app(mock_inference_service, mock_storage_config, mock_storage_helper):
|
||||
"""Create a test FastAPI application with mocked storage."""
|
||||
from inference.web.api.v1.public.inference import create_inference_router
|
||||
from backend.web.api.v1.public.inference import create_inference_router
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Patch get_storage_helper to return our mock
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
inference_router = create_inference_router(mock_inference_service, mock_storage_config)
|
||||
@@ -123,7 +123,7 @@ def test_app(mock_inference_service, mock_storage_config, mock_storage_helper):
|
||||
def client(test_app, mock_storage_helper):
|
||||
"""Create a test client with storage helper patched."""
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
yield TestClient(test_app)
|
||||
@@ -151,7 +151,7 @@ class TestInferenceEndpoint:
|
||||
pdf_content = b"%PDF-1.4\n%test\n"
|
||||
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -171,7 +171,7 @@ class TestInferenceEndpoint:
|
||||
png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
|
||||
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -186,7 +186,7 @@ class TestInferenceEndpoint:
|
||||
def test_infer_invalid_file_type(self, client, mock_storage_helper):
|
||||
"""Test rejection of invalid file types."""
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -199,7 +199,7 @@ class TestInferenceEndpoint:
|
||||
def test_infer_no_file(self, client, mock_storage_helper):
|
||||
"""Test rejection when no file provided."""
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post("/api/v1/infer")
|
||||
@@ -211,7 +211,7 @@ class TestInferenceEndpoint:
|
||||
pdf_content = b"%PDF-1.4\n%test\n"
|
||||
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -238,7 +238,7 @@ class TestInferenceResultFormat:
|
||||
pdf_content = b"%PDF-1.4\n%test\n"
|
||||
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -258,7 +258,7 @@ class TestInferenceResultFormat:
|
||||
pdf_content = b"%PDF-1.4\n%test\n"
|
||||
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -282,7 +282,7 @@ class TestErrorHandling:
|
||||
|
||||
pdf_content = b"%PDF-1.4\n%test\n"
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -297,7 +297,7 @@ class TestErrorHandling:
|
||||
"""Test handling of empty files."""
|
||||
# Empty file still has valid content type
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -318,7 +318,7 @@ class TestResponseFormat:
|
||||
pdf_content = b"%PDF-1.4\n%test\n"
|
||||
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -338,7 +338,7 @@ class TestResponseFormat:
|
||||
pdf_content = b"%PDF-1.4\n%test\n"
|
||||
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -359,7 +359,7 @@ class TestDocumentIdGeneration:
|
||||
pdf_content = b"%PDF-1.4\n%test\n"
|
||||
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
@@ -376,7 +376,7 @@ class TestDocumentIdGeneration:
|
||||
pdf_content = b"%PDF-1.4\n%test\n"
|
||||
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper",
|
||||
"backend.web.api.v1.public.inference.get_storage_helper",
|
||||
return_value=mock_storage_helper,
|
||||
):
|
||||
response = client.post(
|
||||
|
||||
@@ -11,7 +11,7 @@ import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.data.admin_models import (
|
||||
from backend.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
AdminToken,
|
||||
@@ -20,8 +20,8 @@ from inference.data.admin_models import (
|
||||
TrainingDataset,
|
||||
TrainingTask,
|
||||
)
|
||||
from inference.web.api.v1.admin.dashboard import create_dashboard_router
|
||||
from inference.web.core.auth import get_admin_token_dep
|
||||
from backend.web.api.v1.admin.dashboard import create_dashboard_router
|
||||
from backend.web.core.auth import validate_admin_token
|
||||
|
||||
|
||||
def create_test_app(override_token_dep):
|
||||
@@ -31,7 +31,7 @@ def create_test_app(override_token_dep):
|
||||
app.include_router(router)
|
||||
|
||||
# Override auth dependency
|
||||
app.dependency_overrides[get_admin_token_dep] = lambda: override_token_dep
|
||||
app.dependency_overrides[validate_admin_token] = lambda: override_token_dep
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlmodel import Session, SQLModel, create_engine
|
||||
|
||||
from inference.data.admin_models import (
|
||||
from backend.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
AdminToken,
|
||||
@@ -170,15 +170,15 @@ def patched_session(db_session):
|
||||
|
||||
# All modules that import get_session_context
|
||||
patch_targets = [
|
||||
"inference.data.database.get_session_context",
|
||||
"inference.data.repositories.document_repository.get_session_context",
|
||||
"inference.data.repositories.annotation_repository.get_session_context",
|
||||
"inference.data.repositories.dataset_repository.get_session_context",
|
||||
"inference.data.repositories.training_task_repository.get_session_context",
|
||||
"inference.data.repositories.model_version_repository.get_session_context",
|
||||
"inference.data.repositories.batch_upload_repository.get_session_context",
|
||||
"inference.data.repositories.token_repository.get_session_context",
|
||||
"inference.web.services.dashboard_service.get_session_context",
|
||||
"backend.data.database.get_session_context",
|
||||
"backend.data.repositories.document_repository.get_session_context",
|
||||
"backend.data.repositories.annotation_repository.get_session_context",
|
||||
"backend.data.repositories.dataset_repository.get_session_context",
|
||||
"backend.data.repositories.training_task_repository.get_session_context",
|
||||
"backend.data.repositories.model_version_repository.get_session_context",
|
||||
"backend.data.repositories.batch_upload_repository.get_session_context",
|
||||
"backend.data.repositories.token_repository.get_session_context",
|
||||
"backend.web.services.dashboard_service.get_session_context",
|
||||
]
|
||||
|
||||
with ExitStack() as stack:
|
||||
|
||||
@@ -14,13 +14,13 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
import numpy as np
|
||||
|
||||
from inference.pipeline.pipeline import (
|
||||
from backend.pipeline.pipeline import (
|
||||
InferencePipeline,
|
||||
InferenceResult,
|
||||
CrossValidationResult,
|
||||
)
|
||||
from inference.pipeline.yolo_detector import Detection
|
||||
from inference.pipeline.field_extractor import ExtractedField
|
||||
from backend.pipeline.yolo_detector import Detection
|
||||
from backend.pipeline.field_extractor import ExtractedField
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -8,7 +8,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.annotation_repository import AnnotationRepository
|
||||
from backend.data.repositories.annotation_repository import AnnotationRepository
|
||||
|
||||
|
||||
class TestAnnotationRepositoryCreate:
|
||||
|
||||
@@ -9,7 +9,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
|
||||
|
||||
class TestBatchUploadCreate:
|
||||
|
||||
@@ -8,7 +8,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.dataset_repository import DatasetRepository
|
||||
from backend.data.repositories.dataset_repository import DatasetRepository
|
||||
|
||||
|
||||
class TestDatasetRepositoryCreate:
|
||||
@@ -300,7 +300,7 @@ class TestActiveTrainingTasks:
|
||||
repo = DatasetRepository()
|
||||
|
||||
# Update task to running
|
||||
from inference.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
from backend.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
|
||||
task_repo = TrainingTaskRepository()
|
||||
task_repo.update_status(str(sample_training_task.task_id), "running")
|
||||
|
||||
@@ -10,8 +10,8 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlmodel import select
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AdminDocument
|
||||
from inference.data.repositories.document_repository import DocumentRepository
|
||||
from backend.data.admin_models import AdminAnnotation, AdminDocument
|
||||
from backend.data.repositories.document_repository import DocumentRepository
|
||||
|
||||
|
||||
def ensure_utc(dt: datetime | None) -> datetime | None:
|
||||
@@ -243,7 +243,7 @@ class TestDocumentRepositoryDelete:
|
||||
assert result is True
|
||||
|
||||
# Verify annotation is also deleted
|
||||
from inference.data.repositories.annotation_repository import AnnotationRepository
|
||||
from backend.data.repositories.annotation_repository import AnnotationRepository
|
||||
|
||||
ann_repo = AnnotationRepository()
|
||||
annotations = ann_repo.get_for_document(str(sample_document.document_id))
|
||||
|
||||
@@ -9,7 +9,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.model_version_repository import ModelVersionRepository
|
||||
from backend.data.repositories.model_version_repository import ModelVersionRepository
|
||||
|
||||
|
||||
class TestModelVersionCreate:
|
||||
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime, timezone, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.token_repository import TokenRepository
|
||||
from backend.data.repositories.token_repository import TokenRepository
|
||||
|
||||
|
||||
class TestTokenCreate:
|
||||
|
||||
@@ -9,7 +9,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
from backend.data.repositories.training_task_repository import TrainingTaskRepository
|
||||
|
||||
|
||||
class TestTrainingTaskCreate:
|
||||
|
||||
@@ -9,7 +9,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.admin_models import (
|
||||
from backend.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
AnnotationHistory,
|
||||
@@ -17,7 +17,7 @@ from inference.data.admin_models import (
|
||||
TrainingDataset,
|
||||
TrainingTask,
|
||||
)
|
||||
from inference.web.services.dashboard_service import (
|
||||
from backend.web.services.dashboard_service import (
|
||||
DashboardStatsService,
|
||||
DashboardActivityService,
|
||||
is_annotation_complete,
|
||||
|
||||
@@ -12,11 +12,11 @@ from uuid import uuid4
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AdminDocument
|
||||
from inference.data.repositories.annotation_repository import AnnotationRepository
|
||||
from inference.data.repositories.dataset_repository import DatasetRepository
|
||||
from inference.data.repositories.document_repository import DocumentRepository
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.data.admin_models import AdminAnnotation, AdminDocument
|
||||
from backend.data.repositories.annotation_repository import AnnotationRepository
|
||||
from backend.data.repositories.dataset_repository import DatasetRepository
|
||||
from backend.data.repositories.document_repository import DocumentRepository
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -9,7 +9,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.web.services.document_service import DocumentService, DocumentResult
|
||||
from backend.web.services.document_service import DocumentService, DocumentResult
|
||||
|
||||
|
||||
class MockStorageBackend:
|
||||
|
||||
@@ -7,7 +7,7 @@ Tests for database connection, session management, and basic operations.
|
||||
import pytest
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from inference.data.admin_models import AdminDocument, AdminToken
|
||||
from backend.data.admin_models import AdminDocument, AdminToken
|
||||
|
||||
|
||||
class TestDatabaseConnection:
|
||||
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from inference.pipeline.customer_number_parser import (
|
||||
from backend.pipeline.customer_number_parser import (
|
||||
CustomerNumberParser,
|
||||
DashFormatPattern,
|
||||
NoDashFormatPattern,
|
||||
|
||||
@@ -26,7 +26,7 @@ def _collect_modules(package_name: str) -> list[str]:
|
||||
|
||||
|
||||
SHARED_MODULES = _collect_modules("shared")
|
||||
INFERENCE_MODULES = _collect_modules("inference")
|
||||
BACKEND_MODULES = _collect_modules("backend")
|
||||
TRAINING_MODULES = _collect_modules("training")
|
||||
|
||||
|
||||
@@ -36,9 +36,9 @@ def test_shared_module_imports(module_name: str) -> None:
|
||||
importlib.import_module(module_name)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("module_name", INFERENCE_MODULES)
|
||||
def test_inference_module_imports(module_name: str) -> None:
|
||||
"""Every module in the inference package should import without error."""
|
||||
@pytest.mark.parametrize("module_name", BACKEND_MODULES)
|
||||
def test_backend_module_imports(module_name: str) -> None:
|
||||
"""Every module in the backend package should import without error."""
|
||||
importlib.import_module(module_name)
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from pathlib import Path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from inference.pipeline.payment_line_parser import PaymentLineParser, PaymentLineData
|
||||
from backend.pipeline.payment_line_parser import PaymentLineParser, PaymentLineData
|
||||
|
||||
|
||||
class TestPaymentLineParser:
|
||||
|
||||
@@ -10,12 +10,12 @@ from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.async_request_db import ApiKeyConfig, AsyncRequestDB
|
||||
from inference.data.models import AsyncRequest
|
||||
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from inference.web.services.async_processing import AsyncProcessingService
|
||||
from inference.web.config import AsyncConfig, StorageConfig
|
||||
from inference.web.core.rate_limiter import RateLimiter
|
||||
from backend.data.async_request_db import ApiKeyConfig, AsyncRequestDB
|
||||
from backend.data.models import AsyncRequest
|
||||
from backend.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from backend.web.services.async_processing import AsyncProcessingService
|
||||
from backend.web.config import AsyncConfig, StorageConfig
|
||||
from backend.web.core.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -14,7 +14,7 @@ class TestTaskStatus:
|
||||
|
||||
def test_task_status_basic_fields(self) -> None:
|
||||
"""TaskStatus has all required fields."""
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
from backend.web.core.task_interface import TaskStatus
|
||||
|
||||
status = TaskStatus(
|
||||
name="test_runner",
|
||||
@@ -29,7 +29,7 @@ class TestTaskStatus:
|
||||
|
||||
def test_task_status_with_error(self) -> None:
|
||||
"""TaskStatus can include optional error message."""
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
from backend.web.core.task_interface import TaskStatus
|
||||
|
||||
status = TaskStatus(
|
||||
name="failed_runner",
|
||||
@@ -42,7 +42,7 @@ class TestTaskStatus:
|
||||
|
||||
def test_task_status_default_error_is_none(self) -> None:
|
||||
"""TaskStatus error defaults to None."""
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
from backend.web.core.task_interface import TaskStatus
|
||||
|
||||
status = TaskStatus(
|
||||
name="test",
|
||||
@@ -54,7 +54,7 @@ class TestTaskStatus:
|
||||
|
||||
def test_task_status_is_frozen(self) -> None:
|
||||
"""TaskStatus is immutable (frozen dataclass)."""
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
from backend.web.core.task_interface import TaskStatus
|
||||
|
||||
status = TaskStatus(
|
||||
name="test",
|
||||
@@ -71,20 +71,20 @@ class TestTaskRunnerInterface:
|
||||
|
||||
def test_cannot_instantiate_directly(self) -> None:
|
||||
"""TaskRunner is abstract and cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
from backend.web.core.task_interface import TaskRunner
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
TaskRunner() # type: ignore[abstract]
|
||||
|
||||
def test_is_abstract_base_class(self) -> None:
|
||||
"""TaskRunner inherits from ABC."""
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
from backend.web.core.task_interface import TaskRunner
|
||||
|
||||
assert issubclass(TaskRunner, ABC)
|
||||
|
||||
def test_subclass_missing_name_cannot_instantiate(self) -> None:
|
||||
"""Subclass without name property cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class MissingName(TaskRunner):
|
||||
def start(self) -> None:
|
||||
@@ -105,7 +105,7 @@ class TestTaskRunnerInterface:
|
||||
|
||||
def test_subclass_missing_start_cannot_instantiate(self) -> None:
|
||||
"""Subclass without start method cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class MissingStart(TaskRunner):
|
||||
@property
|
||||
@@ -127,7 +127,7 @@ class TestTaskRunnerInterface:
|
||||
|
||||
def test_subclass_missing_stop_cannot_instantiate(self) -> None:
|
||||
"""Subclass without stop method cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class MissingStop(TaskRunner):
|
||||
@property
|
||||
@@ -149,7 +149,7 @@ class TestTaskRunnerInterface:
|
||||
|
||||
def test_subclass_missing_is_running_cannot_instantiate(self) -> None:
|
||||
"""Subclass without is_running property cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class MissingIsRunning(TaskRunner):
|
||||
@property
|
||||
@@ -170,7 +170,7 @@ class TestTaskRunnerInterface:
|
||||
|
||||
def test_subclass_missing_get_status_cannot_instantiate(self) -> None:
|
||||
"""Subclass without get_status method cannot be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
from backend.web.core.task_interface import TaskRunner
|
||||
|
||||
class MissingGetStatus(TaskRunner):
|
||||
@property
|
||||
@@ -192,7 +192,7 @@ class TestTaskRunnerInterface:
|
||||
|
||||
def test_complete_subclass_can_instantiate(self) -> None:
|
||||
"""Complete subclass implementing all methods can be instantiated."""
|
||||
from inference.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskRunner, TaskStatus
|
||||
|
||||
class CompleteRunner(TaskRunner):
|
||||
def __init__(self) -> None:
|
||||
@@ -240,7 +240,7 @@ class TestTaskManager:
|
||||
|
||||
def test_register_runner(self) -> None:
|
||||
"""Can register a task runner."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
@property
|
||||
@@ -268,14 +268,14 @@ class TestTaskManager:
|
||||
|
||||
def test_get_runner_returns_none_for_unknown(self) -> None:
|
||||
"""get_runner returns None for unknown runner name."""
|
||||
from inference.web.core.task_interface import TaskManager
|
||||
from backend.web.core.task_interface import TaskManager
|
||||
|
||||
manager = TaskManager()
|
||||
assert manager.get_runner("unknown") is None
|
||||
|
||||
def test_start_all_runners(self) -> None:
|
||||
"""start_all starts all registered runners."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
def __init__(self, runner_name: str) -> None:
|
||||
@@ -315,7 +315,7 @@ class TestTaskManager:
|
||||
|
||||
def test_stop_all_runners(self) -> None:
|
||||
"""stop_all stops all registered runners."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
def __init__(self, runner_name: str) -> None:
|
||||
@@ -355,7 +355,7 @@ class TestTaskManager:
|
||||
|
||||
def test_get_all_status(self) -> None:
|
||||
"""get_all_status returns status of all runners."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
def __init__(self, runner_name: str, pending: int) -> None:
|
||||
@@ -391,14 +391,14 @@ class TestTaskManager:
|
||||
|
||||
def test_get_all_status_empty_when_no_runners(self) -> None:
|
||||
"""get_all_status returns empty dict when no runners registered."""
|
||||
from inference.web.core.task_interface import TaskManager
|
||||
from backend.web.core.task_interface import TaskManager
|
||||
|
||||
manager = TaskManager()
|
||||
assert manager.get_all_status() == {}
|
||||
|
||||
def test_runner_names_property(self) -> None:
|
||||
"""runner_names returns list of all registered runner names."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
class MockRunner(TaskRunner):
|
||||
def __init__(self, runner_name: str) -> None:
|
||||
@@ -430,7 +430,7 @@ class TestTaskManager:
|
||||
|
||||
def test_stop_all_with_timeout_distribution(self) -> None:
|
||||
"""stop_all distributes timeout across runners."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
received_timeouts: list[float | None] = []
|
||||
|
||||
@@ -467,7 +467,7 @@ class TestTaskManager:
|
||||
|
||||
def test_start_all_skips_runners_requiring_arguments(self) -> None:
|
||||
"""start_all skips runners that require arguments."""
|
||||
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
|
||||
|
||||
no_args_started = []
|
||||
with_args_started = []
|
||||
@@ -521,7 +521,7 @@ class TestTaskManager:
|
||||
|
||||
def test_stop_all_with_no_runners(self) -> None:
|
||||
"""stop_all does nothing when no runners registered."""
|
||||
from inference.web.core.task_interface import TaskManager
|
||||
from backend.web.core.task_interface import TaskManager
|
||||
|
||||
manager = TaskManager()
|
||||
# Should not raise any exception
|
||||
@@ -535,23 +535,23 @@ class TestTrainingSchedulerInterface:
|
||||
|
||||
def test_training_scheduler_is_task_runner(self) -> None:
|
||||
"""TrainingScheduler inherits from TaskRunner."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
from backend.web.core.scheduler import TrainingScheduler
|
||||
from backend.web.core.task_interface import TaskRunner
|
||||
|
||||
scheduler = TrainingScheduler()
|
||||
assert isinstance(scheduler, TaskRunner)
|
||||
|
||||
def test_training_scheduler_name(self) -> None:
|
||||
"""TrainingScheduler has correct name."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
from backend.web.core.scheduler import TrainingScheduler
|
||||
|
||||
scheduler = TrainingScheduler()
|
||||
assert scheduler.name == "training_scheduler"
|
||||
|
||||
def test_training_scheduler_get_status(self) -> None:
|
||||
"""TrainingScheduler provides status via get_status."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
from backend.web.core.scheduler import TrainingScheduler
|
||||
from backend.web.core.task_interface import TaskStatus
|
||||
|
||||
scheduler = TrainingScheduler()
|
||||
# Mock the training tasks repository
|
||||
@@ -572,29 +572,29 @@ class TestAutoLabelSchedulerInterface:
|
||||
|
||||
def test_autolabel_scheduler_is_task_runner(self) -> None:
|
||||
"""AutoLabelScheduler inherits from TaskRunner."""
|
||||
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
from backend.web.core.autolabel_scheduler import AutoLabelScheduler
|
||||
from backend.web.core.task_interface import TaskRunner
|
||||
|
||||
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
|
||||
with patch("backend.web.core.autolabel_scheduler.get_storage_helper"):
|
||||
scheduler = AutoLabelScheduler()
|
||||
assert isinstance(scheduler, TaskRunner)
|
||||
|
||||
def test_autolabel_scheduler_name(self) -> None:
|
||||
"""AutoLabelScheduler has correct name."""
|
||||
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
|
||||
from backend.web.core.autolabel_scheduler import AutoLabelScheduler
|
||||
|
||||
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
|
||||
with patch("backend.web.core.autolabel_scheduler.get_storage_helper"):
|
||||
scheduler = AutoLabelScheduler()
|
||||
assert scheduler.name == "autolabel_scheduler"
|
||||
|
||||
def test_autolabel_scheduler_get_status(self) -> None:
|
||||
"""AutoLabelScheduler provides status via get_status."""
|
||||
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
from backend.web.core.autolabel_scheduler import AutoLabelScheduler
|
||||
from backend.web.core.task_interface import TaskStatus
|
||||
|
||||
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
|
||||
with patch("backend.web.core.autolabel_scheduler.get_storage_helper"):
|
||||
with patch(
|
||||
"inference.web.core.autolabel_scheduler.get_pending_autolabel_documents"
|
||||
"backend.web.core.autolabel_scheduler.get_pending_autolabel_documents"
|
||||
) as mock_get:
|
||||
mock_get.return_value = [MagicMock(), MagicMock(), MagicMock()]
|
||||
|
||||
@@ -612,23 +612,23 @@ class TestAsyncTaskQueueInterface:
|
||||
|
||||
def test_async_queue_is_task_runner(self) -> None:
|
||||
"""AsyncTaskQueue inherits from TaskRunner."""
|
||||
from inference.web.workers.async_queue import AsyncTaskQueue
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
from backend.web.workers.async_queue import AsyncTaskQueue
|
||||
from backend.web.core.task_interface import TaskRunner
|
||||
|
||||
queue = AsyncTaskQueue()
|
||||
assert isinstance(queue, TaskRunner)
|
||||
|
||||
def test_async_queue_name(self) -> None:
|
||||
"""AsyncTaskQueue has correct name."""
|
||||
from inference.web.workers.async_queue import AsyncTaskQueue
|
||||
from backend.web.workers.async_queue import AsyncTaskQueue
|
||||
|
||||
queue = AsyncTaskQueue()
|
||||
assert queue.name == "async_task_queue"
|
||||
|
||||
def test_async_queue_get_status(self) -> None:
|
||||
"""AsyncTaskQueue provides status via get_status."""
|
||||
from inference.web.workers.async_queue import AsyncTaskQueue
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
from backend.web.workers.async_queue import AsyncTaskQueue
|
||||
from backend.web.core.task_interface import TaskStatus
|
||||
|
||||
queue = AsyncTaskQueue()
|
||||
status = queue.get_status()
|
||||
@@ -645,23 +645,23 @@ class TestBatchTaskQueueInterface:
|
||||
|
||||
def test_batch_queue_is_task_runner(self) -> None:
|
||||
"""BatchTaskQueue inherits from TaskRunner."""
|
||||
from inference.web.workers.batch_queue import BatchTaskQueue
|
||||
from inference.web.core.task_interface import TaskRunner
|
||||
from backend.web.workers.batch_queue import BatchTaskQueue
|
||||
from backend.web.core.task_interface import TaskRunner
|
||||
|
||||
queue = BatchTaskQueue()
|
||||
assert isinstance(queue, TaskRunner)
|
||||
|
||||
def test_batch_queue_name(self) -> None:
|
||||
"""BatchTaskQueue has correct name."""
|
||||
from inference.web.workers.batch_queue import BatchTaskQueue
|
||||
from backend.web.workers.batch_queue import BatchTaskQueue
|
||||
|
||||
queue = BatchTaskQueue()
|
||||
assert queue.name == "batch_task_queue"
|
||||
|
||||
def test_batch_queue_get_status(self) -> None:
|
||||
"""BatchTaskQueue provides status via get_status."""
|
||||
from inference.web.workers.batch_queue import BatchTaskQueue
|
||||
from inference.web.core.task_interface import TaskStatus
|
||||
from backend.web.workers.batch_queue import BatchTaskQueue
|
||||
from backend.web.core.task_interface import TaskStatus
|
||||
|
||||
queue = BatchTaskQueue()
|
||||
status = queue.get_status()
|
||||
|
||||
@@ -9,10 +9,10 @@ from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from inference.data.admin_models import AdminAnnotation, AdminDocument
|
||||
from backend.data.admin_models import AdminAnnotation, AdminDocument
|
||||
from shared.fields import FIELD_CLASSES
|
||||
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
||||
from inference.web.schemas.admin import (
|
||||
from backend.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
|
||||
from backend.web.schemas.admin import (
|
||||
AnnotationCreate,
|
||||
AnnotationUpdate,
|
||||
AutoLabelRequest,
|
||||
@@ -234,10 +234,10 @@ class TestAutoLabelFilePathResolution:
|
||||
mock_storage.get_raw_pdf_local_path.return_value = mock_path
|
||||
|
||||
with patch(
|
||||
"inference.web.services.storage_helpers.get_storage_helper",
|
||||
"backend.web.services.storage_helpers.get_storage_helper",
|
||||
return_value=mock_storage,
|
||||
):
|
||||
from inference.web.services.storage_helpers import get_storage_helper
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
storage = get_storage_helper()
|
||||
result = storage.get_raw_pdf_local_path("test.pdf")
|
||||
|
||||
@@ -8,9 +8,9 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from inference.data.repositories import TokenRepository
|
||||
from inference.data.admin_models import AdminToken
|
||||
from inference.web.core.auth import (
|
||||
from backend.data.repositories import TokenRepository
|
||||
from backend.data.admin_models import AdminToken
|
||||
from backend.web.core.auth import (
|
||||
get_token_repository,
|
||||
reset_token_repository,
|
||||
validate_admin_token,
|
||||
@@ -81,7 +81,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_active_token(self):
|
||||
"""Test valid active token."""
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -98,7 +98,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_inactive_token(self):
|
||||
"""Test inactive token."""
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -115,7 +115,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_expired_token(self):
|
||||
"""Test expired token."""
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -138,7 +138,7 @@ class TestTokenRepository:
|
||||
This verifies the fix for comparing timezone-aware and naive datetimes.
|
||||
The auth API now creates tokens with timezone-aware expiration dates.
|
||||
"""
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -157,7 +157,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_not_expired_token_timezone_aware(self):
|
||||
"""Test non-expired token with timezone-aware datetime."""
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
@@ -175,7 +175,7 @@ class TestTokenRepository:
|
||||
|
||||
def test_is_valid_token_not_found(self):
|
||||
"""Test token not found."""
|
||||
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
|
||||
mock_session = MagicMock()
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
mock_session.get.return_value = None
|
||||
|
||||
@@ -12,9 +12,9 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.data.admin_models import AdminDocument, AdminToken
|
||||
from inference.web.api.v1.admin.documents import _validate_uuid, create_documents_router
|
||||
from inference.web.config import StorageConfig
|
||||
from backend.data.admin_models import AdminDocument, AdminToken
|
||||
from backend.web.api.v1.admin.documents import _validate_uuid, create_documents_router
|
||||
from backend.web.config import StorageConfig
|
||||
|
||||
|
||||
# Test UUID
|
||||
@@ -66,7 +66,7 @@ class TestCreateTokenEndpoint:
|
||||
|
||||
def test_create_token_success(self, mock_db):
|
||||
"""Test successful token creation."""
|
||||
from inference.web.schemas.admin import AdminTokenCreate
|
||||
from backend.web.schemas.admin import AdminTokenCreate
|
||||
|
||||
request = AdminTokenCreate(name="Test Token", expires_in_days=30)
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@ from uuid import uuid4
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.documents import create_documents_router
|
||||
from inference.web.config import StorageConfig
|
||||
from inference.web.core.auth import (
|
||||
from backend.web.api.v1.admin.documents import create_documents_router
|
||||
from backend.web.config import StorageConfig
|
||||
from backend.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
get_annotation_repository,
|
||||
|
||||
@@ -1,33 +1,33 @@
|
||||
"""
|
||||
Tests to verify admin schemas split maintains backward compatibility.
|
||||
|
||||
All existing imports from inference.web.schemas.admin must continue to work.
|
||||
All existing imports from backend.web.schemas.admin must continue to work.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestEnumImports:
|
||||
"""All enums importable from inference.web.schemas.admin."""
|
||||
"""All enums importable from backend.web.schemas.admin."""
|
||||
|
||||
def test_document_status(self):
|
||||
from inference.web.schemas.admin import DocumentStatus
|
||||
from backend.web.schemas.admin import DocumentStatus
|
||||
assert DocumentStatus.PENDING == "pending"
|
||||
|
||||
def test_auto_label_status(self):
|
||||
from inference.web.schemas.admin import AutoLabelStatus
|
||||
from backend.web.schemas.admin import AutoLabelStatus
|
||||
assert AutoLabelStatus.RUNNING == "running"
|
||||
|
||||
def test_training_status(self):
|
||||
from inference.web.schemas.admin import TrainingStatus
|
||||
from backend.web.schemas.admin import TrainingStatus
|
||||
assert TrainingStatus.PENDING == "pending"
|
||||
|
||||
def test_training_type(self):
|
||||
from inference.web.schemas.admin import TrainingType
|
||||
from backend.web.schemas.admin import TrainingType
|
||||
assert TrainingType.TRAIN == "train"
|
||||
|
||||
def test_annotation_source(self):
|
||||
from inference.web.schemas.admin import AnnotationSource
|
||||
from backend.web.schemas.admin import AnnotationSource
|
||||
assert AnnotationSource.MANUAL == "manual"
|
||||
|
||||
|
||||
@@ -35,12 +35,12 @@ class TestAuthImports:
|
||||
"""Auth schemas importable."""
|
||||
|
||||
def test_admin_token_create(self):
|
||||
from inference.web.schemas.admin import AdminTokenCreate
|
||||
from backend.web.schemas.admin import AdminTokenCreate
|
||||
token = AdminTokenCreate(name="test")
|
||||
assert token.name == "test"
|
||||
|
||||
def test_admin_token_response(self):
|
||||
from inference.web.schemas.admin import AdminTokenResponse
|
||||
from backend.web.schemas.admin import AdminTokenResponse
|
||||
assert AdminTokenResponse is not None
|
||||
|
||||
|
||||
@@ -48,23 +48,23 @@ class TestDocumentImports:
|
||||
"""Document schemas importable."""
|
||||
|
||||
def test_document_upload_response(self):
|
||||
from inference.web.schemas.admin import DocumentUploadResponse
|
||||
from backend.web.schemas.admin import DocumentUploadResponse
|
||||
assert DocumentUploadResponse is not None
|
||||
|
||||
def test_document_item(self):
|
||||
from inference.web.schemas.admin import DocumentItem
|
||||
from backend.web.schemas.admin import DocumentItem
|
||||
assert DocumentItem is not None
|
||||
|
||||
def test_document_list_response(self):
|
||||
from inference.web.schemas.admin import DocumentListResponse
|
||||
from backend.web.schemas.admin import DocumentListResponse
|
||||
assert DocumentListResponse is not None
|
||||
|
||||
def test_document_detail_response(self):
|
||||
from inference.web.schemas.admin import DocumentDetailResponse
|
||||
from backend.web.schemas.admin import DocumentDetailResponse
|
||||
assert DocumentDetailResponse is not None
|
||||
|
||||
def test_document_stats_response(self):
|
||||
from inference.web.schemas.admin import DocumentStatsResponse
|
||||
from backend.web.schemas.admin import DocumentStatsResponse
|
||||
assert DocumentStatsResponse is not None
|
||||
|
||||
|
||||
@@ -72,60 +72,60 @@ class TestAnnotationImports:
|
||||
"""Annotation schemas importable."""
|
||||
|
||||
def test_bounding_box(self):
|
||||
from inference.web.schemas.admin import BoundingBox
|
||||
from backend.web.schemas.admin import BoundingBox
|
||||
bb = BoundingBox(x=0, y=0, width=100, height=50)
|
||||
assert bb.width == 100
|
||||
|
||||
def test_annotation_create(self):
|
||||
from inference.web.schemas.admin import AnnotationCreate
|
||||
from backend.web.schemas.admin import AnnotationCreate
|
||||
assert AnnotationCreate is not None
|
||||
|
||||
def test_annotation_update(self):
|
||||
from inference.web.schemas.admin import AnnotationUpdate
|
||||
from backend.web.schemas.admin import AnnotationUpdate
|
||||
assert AnnotationUpdate is not None
|
||||
|
||||
def test_annotation_item(self):
|
||||
from inference.web.schemas.admin import AnnotationItem
|
||||
from backend.web.schemas.admin import AnnotationItem
|
||||
assert AnnotationItem is not None
|
||||
|
||||
def test_annotation_response(self):
|
||||
from inference.web.schemas.admin import AnnotationResponse
|
||||
from backend.web.schemas.admin import AnnotationResponse
|
||||
assert AnnotationResponse is not None
|
||||
|
||||
def test_annotation_list_response(self):
|
||||
from inference.web.schemas.admin import AnnotationListResponse
|
||||
from backend.web.schemas.admin import AnnotationListResponse
|
||||
assert AnnotationListResponse is not None
|
||||
|
||||
def test_annotation_lock_request(self):
|
||||
from inference.web.schemas.admin import AnnotationLockRequest
|
||||
from backend.web.schemas.admin import AnnotationLockRequest
|
||||
assert AnnotationLockRequest is not None
|
||||
|
||||
def test_annotation_lock_response(self):
|
||||
from inference.web.schemas.admin import AnnotationLockResponse
|
||||
from backend.web.schemas.admin import AnnotationLockResponse
|
||||
assert AnnotationLockResponse is not None
|
||||
|
||||
def test_auto_label_request(self):
|
||||
from inference.web.schemas.admin import AutoLabelRequest
|
||||
from backend.web.schemas.admin import AutoLabelRequest
|
||||
assert AutoLabelRequest is not None
|
||||
|
||||
def test_auto_label_response(self):
|
||||
from inference.web.schemas.admin import AutoLabelResponse
|
||||
from backend.web.schemas.admin import AutoLabelResponse
|
||||
assert AutoLabelResponse is not None
|
||||
|
||||
def test_annotation_verify_request(self):
|
||||
from inference.web.schemas.admin import AnnotationVerifyRequest
|
||||
from backend.web.schemas.admin import AnnotationVerifyRequest
|
||||
assert AnnotationVerifyRequest is not None
|
||||
|
||||
def test_annotation_verify_response(self):
|
||||
from inference.web.schemas.admin import AnnotationVerifyResponse
|
||||
from backend.web.schemas.admin import AnnotationVerifyResponse
|
||||
assert AnnotationVerifyResponse is not None
|
||||
|
||||
def test_annotation_override_request(self):
|
||||
from inference.web.schemas.admin import AnnotationOverrideRequest
|
||||
from backend.web.schemas.admin import AnnotationOverrideRequest
|
||||
assert AnnotationOverrideRequest is not None
|
||||
|
||||
def test_annotation_override_response(self):
|
||||
from inference.web.schemas.admin import AnnotationOverrideResponse
|
||||
from backend.web.schemas.admin import AnnotationOverrideResponse
|
||||
assert AnnotationOverrideResponse is not None
|
||||
|
||||
|
||||
@@ -133,68 +133,68 @@ class TestTrainingImports:
|
||||
"""Training schemas importable."""
|
||||
|
||||
def test_training_config(self):
|
||||
from inference.web.schemas.admin import TrainingConfig
|
||||
from backend.web.schemas.admin import TrainingConfig
|
||||
config = TrainingConfig()
|
||||
assert config.epochs == 100
|
||||
|
||||
def test_training_task_create(self):
|
||||
from inference.web.schemas.admin import TrainingTaskCreate
|
||||
from backend.web.schemas.admin import TrainingTaskCreate
|
||||
assert TrainingTaskCreate is not None
|
||||
|
||||
def test_training_task_item(self):
|
||||
from inference.web.schemas.admin import TrainingTaskItem
|
||||
from backend.web.schemas.admin import TrainingTaskItem
|
||||
assert TrainingTaskItem is not None
|
||||
|
||||
def test_training_task_list_response(self):
|
||||
from inference.web.schemas.admin import TrainingTaskListResponse
|
||||
from backend.web.schemas.admin import TrainingTaskListResponse
|
||||
assert TrainingTaskListResponse is not None
|
||||
|
||||
def test_training_task_detail_response(self):
|
||||
from inference.web.schemas.admin import TrainingTaskDetailResponse
|
||||
from backend.web.schemas.admin import TrainingTaskDetailResponse
|
||||
assert TrainingTaskDetailResponse is not None
|
||||
|
||||
def test_training_task_response(self):
|
||||
from inference.web.schemas.admin import TrainingTaskResponse
|
||||
from backend.web.schemas.admin import TrainingTaskResponse
|
||||
assert TrainingTaskResponse is not None
|
||||
|
||||
def test_training_log_item(self):
|
||||
from inference.web.schemas.admin import TrainingLogItem
|
||||
from backend.web.schemas.admin import TrainingLogItem
|
||||
assert TrainingLogItem is not None
|
||||
|
||||
def test_training_logs_response(self):
|
||||
from inference.web.schemas.admin import TrainingLogsResponse
|
||||
from backend.web.schemas.admin import TrainingLogsResponse
|
||||
assert TrainingLogsResponse is not None
|
||||
|
||||
def test_export_request(self):
|
||||
from inference.web.schemas.admin import ExportRequest
|
||||
from backend.web.schemas.admin import ExportRequest
|
||||
assert ExportRequest is not None
|
||||
|
||||
def test_export_response(self):
|
||||
from inference.web.schemas.admin import ExportResponse
|
||||
from backend.web.schemas.admin import ExportResponse
|
||||
assert ExportResponse is not None
|
||||
|
||||
def test_training_document_item(self):
|
||||
from inference.web.schemas.admin import TrainingDocumentItem
|
||||
from backend.web.schemas.admin import TrainingDocumentItem
|
||||
assert TrainingDocumentItem is not None
|
||||
|
||||
def test_training_documents_response(self):
|
||||
from inference.web.schemas.admin import TrainingDocumentsResponse
|
||||
from backend.web.schemas.admin import TrainingDocumentsResponse
|
||||
assert TrainingDocumentsResponse is not None
|
||||
|
||||
def test_model_metrics(self):
|
||||
from inference.web.schemas.admin import ModelMetrics
|
||||
from backend.web.schemas.admin import ModelMetrics
|
||||
assert ModelMetrics is not None
|
||||
|
||||
def test_training_model_item(self):
|
||||
from inference.web.schemas.admin import TrainingModelItem
|
||||
from backend.web.schemas.admin import TrainingModelItem
|
||||
assert TrainingModelItem is not None
|
||||
|
||||
def test_training_models_response(self):
|
||||
from inference.web.schemas.admin import TrainingModelsResponse
|
||||
from backend.web.schemas.admin import TrainingModelsResponse
|
||||
assert TrainingModelsResponse is not None
|
||||
|
||||
def test_training_history_item(self):
|
||||
from inference.web.schemas.admin import TrainingHistoryItem
|
||||
from backend.web.schemas.admin import TrainingHistoryItem
|
||||
assert TrainingHistoryItem is not None
|
||||
|
||||
|
||||
@@ -202,31 +202,31 @@ class TestDatasetImports:
|
||||
"""Dataset schemas importable."""
|
||||
|
||||
def test_dataset_create_request(self):
|
||||
from inference.web.schemas.admin import DatasetCreateRequest
|
||||
from backend.web.schemas.admin import DatasetCreateRequest
|
||||
assert DatasetCreateRequest is not None
|
||||
|
||||
def test_dataset_document_item(self):
|
||||
from inference.web.schemas.admin import DatasetDocumentItem
|
||||
from backend.web.schemas.admin import DatasetDocumentItem
|
||||
assert DatasetDocumentItem is not None
|
||||
|
||||
def test_dataset_response(self):
|
||||
from inference.web.schemas.admin import DatasetResponse
|
||||
from backend.web.schemas.admin import DatasetResponse
|
||||
assert DatasetResponse is not None
|
||||
|
||||
def test_dataset_detail_response(self):
|
||||
from inference.web.schemas.admin import DatasetDetailResponse
|
||||
from backend.web.schemas.admin import DatasetDetailResponse
|
||||
assert DatasetDetailResponse is not None
|
||||
|
||||
def test_dataset_list_item(self):
|
||||
from inference.web.schemas.admin import DatasetListItem
|
||||
from backend.web.schemas.admin import DatasetListItem
|
||||
assert DatasetListItem is not None
|
||||
|
||||
def test_dataset_list_response(self):
|
||||
from inference.web.schemas.admin import DatasetListResponse
|
||||
from backend.web.schemas.admin import DatasetListResponse
|
||||
assert DatasetListResponse is not None
|
||||
|
||||
def test_dataset_train_request(self):
|
||||
from inference.web.schemas.admin import DatasetTrainRequest
|
||||
from backend.web.schemas.admin import DatasetTrainRequest
|
||||
assert DatasetTrainRequest is not None
|
||||
|
||||
|
||||
@@ -234,12 +234,12 @@ class TestForwardReferences:
|
||||
"""Forward references resolve correctly."""
|
||||
|
||||
def test_document_detail_has_annotation_items(self):
|
||||
from inference.web.schemas.admin import DocumentDetailResponse
|
||||
from backend.web.schemas.admin import DocumentDetailResponse
|
||||
fields = DocumentDetailResponse.model_fields
|
||||
assert "annotations" in fields
|
||||
assert "training_history" in fields
|
||||
|
||||
def test_dataset_train_request_has_config(self):
|
||||
from inference.web.schemas.admin import DatasetTrainRequest, TrainingConfig
|
||||
from backend.web.schemas.admin import DatasetTrainRequest, TrainingConfig
|
||||
req = DatasetTrainRequest(name="test", config=TrainingConfig())
|
||||
assert req.config.epochs == 100
|
||||
|
||||
@@ -7,15 +7,15 @@ from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
from inference.data.admin_models import TrainingTask, TrainingLog
|
||||
from inference.web.api.v1.admin.training import _validate_uuid, create_training_router
|
||||
from inference.web.core.scheduler import (
|
||||
from backend.data.admin_models import TrainingTask, TrainingLog
|
||||
from backend.web.api.v1.admin.training import _validate_uuid, create_training_router
|
||||
from backend.web.core.scheduler import (
|
||||
TrainingScheduler,
|
||||
get_training_scheduler,
|
||||
start_scheduler,
|
||||
stop_scheduler,
|
||||
)
|
||||
from inference.web.schemas.admin import (
|
||||
from backend.web.schemas.admin import (
|
||||
TrainingConfig,
|
||||
TrainingStatus,
|
||||
TrainingTaskCreate,
|
||||
|
||||
@@ -9,8 +9,8 @@ from uuid import uuid4
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.locks import create_locks_router
|
||||
from inference.web.core.auth import (
|
||||
from backend.web.api.v1.admin.locks import create_locks_router
|
||||
from backend.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
)
|
||||
|
||||
@@ -9,12 +9,12 @@ from uuid import uuid4
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.annotations import (
|
||||
from backend.web.api.v1.admin.annotations import (
|
||||
create_annotation_router,
|
||||
get_doc_repository,
|
||||
get_ann_repository,
|
||||
)
|
||||
from inference.web.core.auth import validate_admin_token
|
||||
from backend.web.core.auth import validate_admin_token
|
||||
|
||||
|
||||
class MockAdminDocument:
|
||||
|
||||
@@ -11,7 +11,7 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from backend.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
|
||||
|
||||
class TestAsyncTask:
|
||||
|
||||
@@ -11,12 +11,12 @@ import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
|
||||
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
|
||||
from inference.web.services.async_processing import AsyncSubmitResult
|
||||
from inference.web.dependencies import init_dependencies
|
||||
from inference.web.rate_limiter import RateLimiter, RateLimitStatus
|
||||
from inference.web.schemas.inference import AsyncStatus
|
||||
from backend.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
|
||||
from backend.web.api.v1.public.async_api import create_async_router, set_async_service
|
||||
from backend.web.services.async_processing import AsyncSubmitResult
|
||||
from backend.web.dependencies import init_dependencies
|
||||
from backend.web.rate_limiter import RateLimiter, RateLimitStatus
|
||||
from backend.web.schemas.inference import AsyncStatus
|
||||
|
||||
# Valid UUID for testing
|
||||
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
@@ -10,11 +10,11 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.async_request_db import AsyncRequest
|
||||
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from inference.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
|
||||
from inference.web.config import AsyncConfig, StorageConfig
|
||||
from inference.web.rate_limiter import RateLimiter
|
||||
from backend.data.async_request_db import AsyncRequest
|
||||
from backend.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
||||
from backend.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
|
||||
from backend.web.config import AsyncConfig, StorageConfig
|
||||
from backend.web.rate_limiter import RateLimiter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -231,7 +231,7 @@ class TestAsyncProcessingService:
|
||||
mock_db.get_request.return_value = None
|
||||
|
||||
# Mock the storage helper to return the same directory as the fixture
|
||||
with patch("inference.web.services.async_processing.get_storage_helper") as mock_storage:
|
||||
with patch("backend.web.services.async_processing.get_storage_helper") as mock_storage:
|
||||
mock_helper = MagicMock()
|
||||
mock_helper.get_uploads_base_path.return_value = temp_dir
|
||||
mock_storage.return_value = mock_helper
|
||||
|
||||
@@ -10,8 +10,8 @@ from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
import numpy as np
|
||||
|
||||
from inference.web.api.v1.admin.augmentation import create_augmentation_router
|
||||
from inference.web.core.auth import (
|
||||
from backend.web.api.v1.admin.augmentation import create_augmentation_router
|
||||
from backend.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
get_dataset_repository,
|
||||
@@ -175,7 +175,7 @@ class TestAugmentationPreviewEndpoint:
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||
"backend.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||
) as mock_load:
|
||||
mock_load.return_value = fake_image
|
||||
|
||||
@@ -251,7 +251,7 @@ class TestAugmentationPreviewConfigEndpoint:
|
||||
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||
"backend.web.services.augmentation_service.AugmentationService._load_document_page"
|
||||
) as mock_load:
|
||||
mock_load.return_value = fake_image
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from pathlib import Path
|
||||
from unittest.mock import Mock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from inference.web.services.autolabel import AutoLabelService
|
||||
from backend.web.services.autolabel import AutoLabelService
|
||||
|
||||
|
||||
class MockDocument:
|
||||
|
||||
@@ -9,7 +9,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.web.workers.batch_queue import BatchTask, BatchTaskQueue
|
||||
from backend.web.workers.batch_queue import BatchTask, BatchTaskQueue
|
||||
|
||||
|
||||
class MockBatchService:
|
||||
|
||||
@@ -11,10 +11,10 @@ import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.batch.routes import router, get_batch_repository
|
||||
from inference.web.core.auth import validate_admin_token
|
||||
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
|
||||
from inference.web.services.batch_upload import BatchUploadService
|
||||
from backend.web.api.v1.batch.routes import router, get_batch_repository
|
||||
from backend.web.core.auth import validate_admin_token
|
||||
from backend.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
|
||||
from backend.web.services.batch_upload import BatchUploadService
|
||||
|
||||
|
||||
class MockBatchUploadRepository:
|
||||
|
||||
@@ -9,7 +9,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.web.services.batch_upload import BatchUploadService
|
||||
from backend.web.services.batch_upload import BatchUploadService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestAnnotationCompletenessLogic:
|
||||
|
||||
def test_document_with_invoice_number_and_bankgiro_is_complete(self):
|
||||
"""Document with invoice_number + bankgiro should be complete."""
|
||||
from inference.web.services.dashboard_service import is_annotation_complete
|
||||
from backend.web.services.dashboard_service import is_annotation_complete
|
||||
|
||||
annotations = [
|
||||
{"class_id": 0, "class_name": "invoice_number"},
|
||||
@@ -39,7 +39,7 @@ class TestAnnotationCompletenessLogic:
|
||||
|
||||
def test_document_with_ocr_number_and_plusgiro_is_complete(self):
|
||||
"""Document with ocr_number + plusgiro should be complete."""
|
||||
from inference.web.services.dashboard_service import is_annotation_complete
|
||||
from backend.web.services.dashboard_service import is_annotation_complete
|
||||
|
||||
annotations = [
|
||||
{"class_id": 3, "class_name": "ocr_number"},
|
||||
@@ -50,7 +50,7 @@ class TestAnnotationCompletenessLogic:
|
||||
|
||||
def test_document_with_invoice_number_and_plusgiro_is_complete(self):
|
||||
"""Document with invoice_number + plusgiro should be complete."""
|
||||
from inference.web.services.dashboard_service import is_annotation_complete
|
||||
from backend.web.services.dashboard_service import is_annotation_complete
|
||||
|
||||
annotations = [
|
||||
{"class_id": 0, "class_name": "invoice_number"},
|
||||
@@ -61,7 +61,7 @@ class TestAnnotationCompletenessLogic:
|
||||
|
||||
def test_document_with_ocr_number_and_bankgiro_is_complete(self):
|
||||
"""Document with ocr_number + bankgiro should be complete."""
|
||||
from inference.web.services.dashboard_service import is_annotation_complete
|
||||
from backend.web.services.dashboard_service import is_annotation_complete
|
||||
|
||||
annotations = [
|
||||
{"class_id": 3, "class_name": "ocr_number"},
|
||||
@@ -72,7 +72,7 @@ class TestAnnotationCompletenessLogic:
|
||||
|
||||
def test_document_with_only_identifier_is_incomplete(self):
|
||||
"""Document with only identifier field should be incomplete."""
|
||||
from inference.web.services.dashboard_service import is_annotation_complete
|
||||
from backend.web.services.dashboard_service import is_annotation_complete
|
||||
|
||||
annotations = [
|
||||
{"class_id": 0, "class_name": "invoice_number"},
|
||||
@@ -82,7 +82,7 @@ class TestAnnotationCompletenessLogic:
|
||||
|
||||
def test_document_with_only_payment_is_incomplete(self):
|
||||
"""Document with only payment field should be incomplete."""
|
||||
from inference.web.services.dashboard_service import is_annotation_complete
|
||||
from backend.web.services.dashboard_service import is_annotation_complete
|
||||
|
||||
annotations = [
|
||||
{"class_id": 4, "class_name": "bankgiro"},
|
||||
@@ -92,13 +92,13 @@ class TestAnnotationCompletenessLogic:
|
||||
|
||||
def test_document_with_no_annotations_is_incomplete(self):
|
||||
"""Document with no annotations should be incomplete."""
|
||||
from inference.web.services.dashboard_service import is_annotation_complete
|
||||
from backend.web.services.dashboard_service import is_annotation_complete
|
||||
|
||||
assert is_annotation_complete([]) is False
|
||||
|
||||
def test_document_with_other_fields_only_is_incomplete(self):
|
||||
"""Document with only non-essential fields should be incomplete."""
|
||||
from inference.web.services.dashboard_service import is_annotation_complete
|
||||
from backend.web.services.dashboard_service import is_annotation_complete
|
||||
|
||||
annotations = [
|
||||
{"class_id": 1, "class_name": "invoice_date"},
|
||||
@@ -109,7 +109,7 @@ class TestAnnotationCompletenessLogic:
|
||||
|
||||
def test_document_with_all_fields_is_complete(self):
|
||||
"""Document with all fields should be complete."""
|
||||
from inference.web.services.dashboard_service import is_annotation_complete
|
||||
from backend.web.services.dashboard_service import is_annotation_complete
|
||||
|
||||
annotations = [
|
||||
{"class_id": 0, "class_name": "invoice_number"},
|
||||
@@ -178,7 +178,7 @@ class TestDashboardSchemas:
|
||||
|
||||
def test_dashboard_stats_response_schema(self):
|
||||
"""Test DashboardStatsResponse schema validation."""
|
||||
from inference.web.schemas.admin import DashboardStatsResponse
|
||||
from backend.web.schemas.admin import DashboardStatsResponse
|
||||
|
||||
response = DashboardStatsResponse(
|
||||
total_documents=38,
|
||||
@@ -195,10 +195,10 @@ class TestDashboardSchemas:
|
||||
assert response.completeness_rate == 75.76
|
||||
|
||||
def test_active_model_response_schema(self):
|
||||
"""Test ActiveModelResponse schema with null model."""
|
||||
from inference.web.schemas.admin import ActiveModelResponse
|
||||
"""Test DashboardActiveModelResponse schema with null model."""
|
||||
from backend.web.schemas.admin import DashboardActiveModelResponse
|
||||
|
||||
response = ActiveModelResponse(
|
||||
response = DashboardActiveModelResponse(
|
||||
model=None,
|
||||
running_training=None,
|
||||
)
|
||||
@@ -208,7 +208,7 @@ class TestDashboardSchemas:
|
||||
|
||||
def test_active_model_info_schema(self):
|
||||
"""Test ActiveModelInfo schema validation."""
|
||||
from inference.web.schemas.admin import ActiveModelInfo
|
||||
from backend.web.schemas.admin import ActiveModelInfo
|
||||
|
||||
model = ActiveModelInfo(
|
||||
version_id=TEST_MODEL_UUID,
|
||||
@@ -227,7 +227,7 @@ class TestDashboardSchemas:
|
||||
|
||||
def test_running_training_info_schema(self):
|
||||
"""Test RunningTrainingInfo schema validation."""
|
||||
from inference.web.schemas.admin import RunningTrainingInfo
|
||||
from backend.web.schemas.admin import RunningTrainingInfo
|
||||
|
||||
task = RunningTrainingInfo(
|
||||
task_id=TEST_TASK_UUID,
|
||||
@@ -243,7 +243,7 @@ class TestDashboardSchemas:
|
||||
|
||||
def test_activity_item_schema(self):
|
||||
"""Test ActivityItem schema validation."""
|
||||
from inference.web.schemas.admin import ActivityItem
|
||||
from backend.web.schemas.admin import ActivityItem
|
||||
|
||||
activity = ActivityItem(
|
||||
type="model_activated",
|
||||
@@ -258,7 +258,7 @@ class TestDashboardSchemas:
|
||||
|
||||
def test_recent_activity_response_schema(self):
|
||||
"""Test RecentActivityResponse schema with empty activities."""
|
||||
from inference.web.schemas.admin import RecentActivityResponse
|
||||
from backend.web.schemas.admin import RecentActivityResponse
|
||||
|
||||
response = RecentActivityResponse(activities=[])
|
||||
|
||||
@@ -270,7 +270,7 @@ class TestDashboardRouterCreation:
|
||||
|
||||
def test_creates_router_with_expected_endpoints(self):
|
||||
"""Test router is created with expected endpoint paths."""
|
||||
from inference.web.api.v1.admin.dashboard import create_dashboard_router
|
||||
from backend.web.api.v1.admin.dashboard import create_dashboard_router
|
||||
|
||||
router = create_dashboard_router()
|
||||
|
||||
@@ -282,7 +282,7 @@ class TestDashboardRouterCreation:
|
||||
|
||||
def test_router_has_correct_prefix(self):
|
||||
"""Test router has /admin/dashboard prefix."""
|
||||
from inference.web.api.v1.admin.dashboard import create_dashboard_router
|
||||
from backend.web.api.v1.admin.dashboard import create_dashboard_router
|
||||
|
||||
router = create_dashboard_router()
|
||||
|
||||
@@ -290,7 +290,7 @@ class TestDashboardRouterCreation:
|
||||
|
||||
def test_router_has_dashboard_tag(self):
|
||||
"""Test router uses Dashboard tag."""
|
||||
from inference.web.api.v1.admin.dashboard import create_dashboard_router
|
||||
from backend.web.api.v1.admin.dashboard import create_dashboard_router
|
||||
|
||||
router = create_dashboard_router()
|
||||
|
||||
@@ -302,7 +302,7 @@ class TestFieldClassIds:
|
||||
|
||||
def test_identifier_class_ids(self):
|
||||
"""Test identifier field class IDs."""
|
||||
from inference.web.services.dashboard_service import IDENTIFIER_CLASS_IDS
|
||||
from backend.web.services.dashboard_service import IDENTIFIER_CLASS_IDS
|
||||
|
||||
# invoice_number = 0, ocr_number = 3
|
||||
assert 0 in IDENTIFIER_CLASS_IDS
|
||||
@@ -310,7 +310,7 @@ class TestFieldClassIds:
|
||||
|
||||
def test_payment_class_ids(self):
|
||||
"""Test payment field class IDs."""
|
||||
from inference.web.services.dashboard_service import PAYMENT_CLASS_IDS
|
||||
from backend.web.services.dashboard_service import PAYMENT_CLASS_IDS
|
||||
|
||||
# bankgiro = 4, plusgiro = 5
|
||||
assert 4 in PAYMENT_CLASS_IDS
|
||||
|
||||
@@ -12,7 +12,7 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.admin_models import (
|
||||
from backend.data.admin_models import (
|
||||
AdminAnnotation,
|
||||
AdminDocument,
|
||||
TrainingDataset,
|
||||
@@ -105,7 +105,7 @@ class TestDatasetBuilder:
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Dataset builder should create images/ and labels/ with train/val/test subdirs."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
dataset_dir = tmp_path / "datasets" / "test"
|
||||
builder = DatasetBuilder(
|
||||
@@ -141,7 +141,7 @@ class TestDatasetBuilder:
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Images should be copied from admin_images to dataset folder."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -177,7 +177,7 @@ class TestDatasetBuilder:
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""YOLO label files should be generated with correct format."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -221,7 +221,7 @@ class TestDatasetBuilder:
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""data.yaml should be generated with correct field classes."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -257,7 +257,7 @@ class TestDatasetBuilder:
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Documents should be split into train/val/test according to ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -294,7 +294,7 @@ class TestDatasetBuilder:
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""After successful build, dataset status should be updated to 'ready'."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -327,7 +327,7 @@ class TestDatasetBuilder:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""If build fails, dataset status should be set to 'failed'."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -357,7 +357,7 @@ class TestDatasetBuilder:
|
||||
sample_documents, sample_annotations
|
||||
):
|
||||
"""Same seed should produce same splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
results = []
|
||||
for _ in range(2):
|
||||
@@ -405,7 +405,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Documents with unique group_key are distributed across splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -433,7 +433,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Documents with null/empty group_key are each treated as independent single-doc groups."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -461,7 +461,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Documents with same group_key should be assigned to the same split."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -494,7 +494,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Multi-doc groups should be split according to train/val/test ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -536,7 +536,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Mix of single-doc and multi-doc groups should be handled correctly."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -574,7 +574,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Same seed should produce same split assignments."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -601,7 +601,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Different seeds should potentially produce different split assignments."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -627,7 +627,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Every document should be assigned a split."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -654,7 +654,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""Empty document list should return empty result."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -671,7 +671,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""When all groups have multiple docs, splits should follow ratios."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -707,7 +707,7 @@ class TestAssignSplitsByGroup:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""When all groups have single doc, they are distributed across splits."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
builder = DatasetBuilder(
|
||||
datasets_repo=mock_datasets_repo,
|
||||
@@ -798,7 +798,7 @@ class TestBuildDatasetWithGroupKey:
|
||||
mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""build_dataset should use group_key for split assignment."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
tmp_path, docs = grouped_documents
|
||||
|
||||
@@ -847,7 +847,7 @@ class TestBuildDatasetWithGroupKey:
|
||||
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
|
||||
):
|
||||
"""All docs with same group_key should go to same split."""
|
||||
from inference.web.services.dataset_builder import DatasetBuilder
|
||||
from backend.web.services.dataset_builder import DatasetBuilder
|
||||
|
||||
# Create 5 docs all with same group_key
|
||||
docs = []
|
||||
|
||||
@@ -9,9 +9,9 @@ from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import UUID
|
||||
|
||||
from inference.data.admin_models import TrainingDataset, DatasetDocument
|
||||
from inference.web.api.v1.admin.training import create_training_router
|
||||
from inference.web.schemas.admin import (
|
||||
from backend.data.admin_models import TrainingDataset, DatasetDocument
|
||||
from backend.web.api.v1.admin.training import create_training_router
|
||||
from backend.web.schemas.admin import (
|
||||
DatasetCreateRequest,
|
||||
DatasetTrainRequest,
|
||||
TrainingConfig,
|
||||
@@ -130,10 +130,10 @@ class TestCreateDatasetRoute:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.dataset_builder.DatasetBuilder",
|
||||
"backend.web.services.dataset_builder.DatasetBuilder",
|
||||
return_value=mock_builder,
|
||||
), patch(
|
||||
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
|
||||
"backend.web.api.v1.admin.training.datasets.get_storage_helper"
|
||||
) as mock_storage:
|
||||
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
|
||||
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
|
||||
@@ -222,10 +222,10 @@ class TestCreateDatasetRoute:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"inference.web.services.dataset_builder.DatasetBuilder",
|
||||
"backend.web.services.dataset_builder.DatasetBuilder",
|
||||
return_value=mock_builder,
|
||||
), patch(
|
||||
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
|
||||
"backend.web.api.v1.admin.training.datasets.get_storage_helper"
|
||||
) as mock_storage:
|
||||
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
|
||||
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestTrainingDatasetModel:
|
||||
|
||||
def test_training_dataset_has_training_status_field(self):
|
||||
"""TrainingDataset model should have training_status field."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(
|
||||
name="test-dataset",
|
||||
@@ -37,7 +37,7 @@ class TestTrainingDatasetModel:
|
||||
|
||||
def test_training_dataset_has_active_training_task_id_field(self):
|
||||
"""TrainingDataset model should have active_training_task_id field."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
task_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
@@ -48,7 +48,7 @@ class TestTrainingDatasetModel:
|
||||
|
||||
def test_training_dataset_defaults(self):
|
||||
"""TrainingDataset should have correct defaults for new fields."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test-dataset")
|
||||
assert dataset.training_status is None
|
||||
@@ -71,7 +71,7 @@ class TestDatasetRepositoryTrainingStatus:
|
||||
|
||||
def test_update_training_status_sets_status(self, mock_session):
|
||||
"""update_training_status should set training_status."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
@@ -81,10 +81,10 @@ class TestDatasetRepositoryTrainingStatus:
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.repositories import DatasetRepository
|
||||
from backend.data.repositories import DatasetRepository
|
||||
|
||||
repo = DatasetRepository()
|
||||
repo.update_training_status(
|
||||
@@ -98,7 +98,7 @@ class TestDatasetRepositoryTrainingStatus:
|
||||
|
||||
def test_update_training_status_sets_task_id(self, mock_session):
|
||||
"""update_training_status should set active_training_task_id."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
task_id = uuid4()
|
||||
@@ -109,10 +109,10 @@ class TestDatasetRepositoryTrainingStatus:
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.repositories import DatasetRepository
|
||||
from backend.data.repositories import DatasetRepository
|
||||
|
||||
repo = DatasetRepository()
|
||||
repo.update_training_status(
|
||||
@@ -127,7 +127,7 @@ class TestDatasetRepositoryTrainingStatus:
|
||||
self, mock_session
|
||||
):
|
||||
"""update_training_status should update main status to 'trained' when completed."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
dataset = TrainingDataset(
|
||||
@@ -137,10 +137,10 @@ class TestDatasetRepositoryTrainingStatus:
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.repositories import DatasetRepository
|
||||
from backend.data.repositories import DatasetRepository
|
||||
|
||||
repo = DatasetRepository()
|
||||
repo.update_training_status(
|
||||
@@ -156,7 +156,7 @@ class TestDatasetRepositoryTrainingStatus:
|
||||
self, mock_session
|
||||
):
|
||||
"""update_training_status should clear task_id when training completes."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset_id = uuid4()
|
||||
task_id = uuid4()
|
||||
@@ -169,10 +169,10 @@ class TestDatasetRepositoryTrainingStatus:
|
||||
)
|
||||
mock_session.get.return_value = dataset
|
||||
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.repositories import DatasetRepository
|
||||
from backend.data.repositories import DatasetRepository
|
||||
|
||||
repo = DatasetRepository()
|
||||
repo.update_training_status(
|
||||
@@ -187,10 +187,10 @@ class TestDatasetRepositoryTrainingStatus:
|
||||
"""update_training_status should handle missing dataset gracefully."""
|
||||
mock_session.get.return_value = None
|
||||
|
||||
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
|
||||
mock_ctx.return_value.__enter__.return_value = mock_session
|
||||
|
||||
from inference.data.repositories import DatasetRepository
|
||||
from backend.data.repositories import DatasetRepository
|
||||
|
||||
repo = DatasetRepository()
|
||||
# Should not raise
|
||||
@@ -213,7 +213,7 @@ class TestDatasetDetailResponseTrainingStatus:
|
||||
|
||||
def test_dataset_detail_response_includes_training_status(self):
|
||||
"""DatasetDetailResponse schema should include training_status field."""
|
||||
from inference.web.schemas.admin.datasets import DatasetDetailResponse
|
||||
from backend.web.schemas.admin.datasets import DatasetDetailResponse
|
||||
|
||||
response = DatasetDetailResponse(
|
||||
dataset_id=str(uuid4()),
|
||||
@@ -240,7 +240,7 @@ class TestDatasetDetailResponseTrainingStatus:
|
||||
|
||||
def test_dataset_detail_response_allows_null_training_status(self):
|
||||
"""DatasetDetailResponse should allow null training_status."""
|
||||
from inference.web.schemas.admin.datasets import DatasetDetailResponse
|
||||
from backend.web.schemas.admin.datasets import DatasetDetailResponse
|
||||
|
||||
response = DatasetDetailResponse(
|
||||
dataset_id=str(uuid4()),
|
||||
@@ -294,7 +294,7 @@ class TestSchedulerDatasetStatusUpdates:
|
||||
|
||||
def test_scheduler_sets_running_status_on_task_start(self, mock_datasets_repo, mock_training_tasks_repo):
|
||||
"""Scheduler should set dataset training_status to 'running' when task starts."""
|
||||
from inference.web.core.scheduler import TrainingScheduler
|
||||
from backend.web.core.scheduler import TrainingScheduler
|
||||
|
||||
with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
|
||||
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
|
||||
@@ -333,35 +333,35 @@ class TestDatasetStatusValues:
|
||||
|
||||
def test_dataset_status_building(self):
|
||||
"""Dataset can have status 'building'."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test", status="building")
|
||||
assert dataset.status == "building"
|
||||
|
||||
def test_dataset_status_ready(self):
|
||||
"""Dataset can have status 'ready'."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test", status="ready")
|
||||
assert dataset.status == "ready"
|
||||
|
||||
def test_dataset_status_trained(self):
|
||||
"""Dataset can have status 'trained'."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test", status="trained")
|
||||
assert dataset.status == "trained"
|
||||
|
||||
def test_dataset_status_failed(self):
|
||||
"""Dataset can have status 'failed'."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
dataset = TrainingDataset(name="test", status="failed")
|
||||
assert dataset.status == "failed"
|
||||
|
||||
def test_training_status_values(self):
|
||||
"""Training status can have various values."""
|
||||
from inference.data.admin_models import TrainingDataset
|
||||
from backend.data.admin_models import TrainingDataset
|
||||
|
||||
valid_statuses = ["pending", "scheduled", "running", "completed", "failed", "cancelled"]
|
||||
for status in valid_statuses:
|
||||
|
||||
@@ -10,7 +10,7 @@ from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from inference.data.admin_models import AdminDocument
|
||||
from backend.data.admin_models import AdminDocument
|
||||
|
||||
|
||||
# Test constants
|
||||
@@ -76,7 +76,7 @@ class TestDocumentCategoryInReadModel:
|
||||
|
||||
def test_admin_document_read_has_category(self):
|
||||
"""Test AdminDocumentRead includes category field."""
|
||||
from inference.data.admin_models import AdminDocumentRead
|
||||
from backend.data.admin_models import AdminDocumentRead
|
||||
|
||||
# Check the model has category field in its schema
|
||||
assert "category" in AdminDocumentRead.model_fields
|
||||
@@ -94,7 +94,7 @@ class TestDocumentCategoryAPI:
|
||||
|
||||
def test_upload_document_with_category(self, mock_admin_db):
|
||||
"""Test uploading document with category parameter."""
|
||||
from inference.web.schemas.admin import DocumentUploadResponse
|
||||
from backend.web.schemas.admin import DocumentUploadResponse
|
||||
|
||||
# Verify response schema supports category
|
||||
response = DocumentUploadResponse(
|
||||
@@ -110,7 +110,7 @@ class TestDocumentCategoryAPI:
|
||||
|
||||
def test_list_documents_returns_category(self, mock_admin_db):
|
||||
"""Test list documents endpoint returns category."""
|
||||
from inference.web.schemas.admin import DocumentItem
|
||||
from backend.web.schemas.admin import DocumentItem
|
||||
|
||||
item = DocumentItem(
|
||||
document_id=TEST_DOC_UUID,
|
||||
@@ -127,7 +127,7 @@ class TestDocumentCategoryAPI:
|
||||
|
||||
def test_document_detail_includes_category(self, mock_admin_db):
|
||||
"""Test document detail response includes category."""
|
||||
from inference.web.schemas.admin import DocumentDetailResponse
|
||||
from backend.web.schemas.admin import DocumentDetailResponse
|
||||
|
||||
# Check schema has category
|
||||
assert "category" in DocumentDetailResponse.model_fields
|
||||
@@ -167,14 +167,14 @@ class TestDocumentCategoryUpdate:
|
||||
|
||||
def test_update_document_category_schema(self):
|
||||
"""Test update document request supports category."""
|
||||
from inference.web.schemas.admin import DocumentUpdateRequest
|
||||
from backend.web.schemas.admin import DocumentUpdateRequest
|
||||
|
||||
request = DocumentUpdateRequest(category="letter")
|
||||
assert request.category == "letter"
|
||||
|
||||
def test_update_document_category_optional(self):
|
||||
"""Test category is optional in update request."""
|
||||
from inference.web.schemas.admin import DocumentUpdateRequest
|
||||
from backend.web.schemas.admin import DocumentUpdateRequest
|
||||
|
||||
# Should not raise - category is optional
|
||||
request = DocumentUpdateRequest()
|
||||
@@ -186,7 +186,7 @@ class TestDatasetWithCategory:
|
||||
|
||||
def test_dataset_create_with_category_filter(self):
|
||||
"""Test creating dataset can filter by document category."""
|
||||
from inference.web.schemas.admin import DatasetCreateRequest
|
||||
from backend.web.schemas.admin import DatasetCreateRequest
|
||||
|
||||
request = DatasetCreateRequest(
|
||||
name="Invoice Training Set",
|
||||
@@ -197,7 +197,7 @@ class TestDatasetWithCategory:
|
||||
|
||||
def test_dataset_create_category_is_optional(self):
|
||||
"""Test category filter is optional when creating dataset."""
|
||||
from inference.web.schemas.admin import DatasetCreateRequest
|
||||
from backend.web.schemas.admin import DatasetCreateRequest
|
||||
|
||||
request = DatasetCreateRequest(
|
||||
name="Mixed Training Set",
|
||||
|
||||
@@ -23,7 +23,7 @@ class TestGetCategoriesEndpoint:
|
||||
|
||||
def test_categories_endpoint_returns_list(self):
|
||||
"""Test categories endpoint returns list of available categories."""
|
||||
from inference.web.schemas.admin import DocumentCategoriesResponse
|
||||
from backend.web.schemas.admin import DocumentCategoriesResponse
|
||||
|
||||
# Test schema exists and works
|
||||
response = DocumentCategoriesResponse(
|
||||
@@ -35,7 +35,7 @@ class TestGetCategoriesEndpoint:
|
||||
|
||||
def test_categories_response_schema(self):
|
||||
"""Test DocumentCategoriesResponse schema structure."""
|
||||
from inference.web.schemas.admin import DocumentCategoriesResponse
|
||||
from backend.web.schemas.admin import DocumentCategoriesResponse
|
||||
|
||||
assert "categories" in DocumentCategoriesResponse.model_fields
|
||||
assert "total" in DocumentCategoriesResponse.model_fields
|
||||
@@ -69,7 +69,7 @@ class TestDocumentListFilterByCategory:
|
||||
"""Test list documents endpoint accepts category query parameter."""
|
||||
# The endpoint should accept ?category=invoice parameter
|
||||
# This test verifies the schema/query parameter exists
|
||||
from inference.web.schemas.admin import DocumentListResponse
|
||||
from backend.web.schemas.admin import DocumentListResponse
|
||||
|
||||
# Schema should work with category filter applied
|
||||
assert DocumentListResponse is not None
|
||||
@@ -93,7 +93,7 @@ class TestDocumentUploadWithCategory:
|
||||
|
||||
def test_upload_response_includes_category(self):
|
||||
"""Test upload response includes the category that was set."""
|
||||
from inference.web.schemas.admin import DocumentUploadResponse
|
||||
from backend.web.schemas.admin import DocumentUploadResponse
|
||||
|
||||
response = DocumentUploadResponse(
|
||||
document_id=TEST_DOC_UUID,
|
||||
@@ -108,7 +108,7 @@ class TestDocumentUploadWithCategory:
|
||||
|
||||
def test_upload_defaults_to_invoice_category(self):
|
||||
"""Test upload defaults to 'invoice' if no category specified."""
|
||||
from inference.web.schemas.admin import DocumentUploadResponse
|
||||
from backend.web.schemas.admin import DocumentUploadResponse
|
||||
|
||||
response = DocumentUploadResponse(
|
||||
document_id=TEST_DOC_UUID,
|
||||
@@ -127,14 +127,14 @@ class TestDocumentRepositoryCategoryMethods:
|
||||
|
||||
def test_get_categories_method_exists(self):
|
||||
"""Test DocumentRepository has get_categories method."""
|
||||
from inference.data.repositories import DocumentRepository
|
||||
from backend.data.repositories import DocumentRepository
|
||||
|
||||
repo = DocumentRepository()
|
||||
assert hasattr(repo, "get_categories")
|
||||
|
||||
def test_get_paginated_accepts_category_filter(self):
|
||||
"""Test get_paginated method accepts category parameter."""
|
||||
from inference.data.repositories import DocumentRepository
|
||||
from backend.data.repositories import DocumentRepository
|
||||
import inspect
|
||||
|
||||
repo = DocumentRepository()
|
||||
@@ -152,14 +152,14 @@ class TestUpdateDocumentCategory:
|
||||
|
||||
def test_update_category_method_exists(self):
|
||||
"""Test DocumentRepository has method to update document category."""
|
||||
from inference.data.repositories import DocumentRepository
|
||||
from backend.data.repositories import DocumentRepository
|
||||
|
||||
repo = DocumentRepository()
|
||||
assert hasattr(repo, "update_category")
|
||||
|
||||
def test_update_request_schema(self):
|
||||
"""Test DocumentUpdateRequest can update category."""
|
||||
from inference.web.schemas.admin import DocumentUpdateRequest
|
||||
from backend.web.schemas.admin import DocumentUpdateRequest
|
||||
|
||||
request = DocumentUpdateRequest(category="receipt")
|
||||
assert request.category == "receipt"
|
||||
|
||||
@@ -11,8 +11,8 @@ from fastapi.testclient import TestClient
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from inference.web.app import create_app
|
||||
from inference.web.config import ModelConfig, StorageConfig, AppConfig
|
||||
from backend.web.app import create_app
|
||||
from backend.web.config import ModelConfig, StorageConfig, AppConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -87,8 +87,8 @@ class TestHealthEndpoint:
|
||||
class TestInferEndpoint:
|
||||
"""Test /api/v1/infer endpoint."""
|
||||
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
def test_infer_accepts_png_file(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -150,8 +150,8 @@ class TestInferEndpoint:
|
||||
|
||||
assert response.status_code == 422 # Unprocessable Entity
|
||||
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
def test_infer_returns_cross_validation_if_available(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -210,8 +210,8 @@ class TestInferEndpoint:
|
||||
# This test documents that it should be added
|
||||
|
||||
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
def test_infer_handles_processing_errors_gracefully(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -263,7 +263,7 @@ class TestResultsEndpoint:
|
||||
|
||||
# Mock the storage helper to return our test file path
|
||||
with patch(
|
||||
"inference.web.api.v1.public.inference.get_storage_helper"
|
||||
"backend.web.api.v1.public.inference.get_storage_helper"
|
||||
) as mock_storage:
|
||||
mock_helper = Mock()
|
||||
mock_helper.get_result_local_path.return_value = result_file
|
||||
@@ -285,15 +285,15 @@ class TestInferenceServiceImports:
|
||||
|
||||
This test will fail if there are ImportError issues like:
|
||||
- from ..inference.pipeline (wrong relative import)
|
||||
- from inference.web.inference (non-existent module)
|
||||
- from backend.web.inference (non-existent module)
|
||||
|
||||
It ensures the imports are correct before runtime.
|
||||
"""
|
||||
from inference.web.services.inference import InferenceService
|
||||
from backend.web.services.inference import InferenceService
|
||||
|
||||
# Import the modules that InferenceService tries to import
|
||||
from inference.pipeline.pipeline import InferencePipeline
|
||||
from inference.pipeline.yolo_detector import YOLODetector
|
||||
from backend.pipeline.pipeline import InferencePipeline
|
||||
from backend.pipeline.yolo_detector import YOLODetector
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
|
||||
# If we got here, all imports work correctly
|
||||
|
||||
@@ -10,8 +10,8 @@ from unittest.mock import Mock, patch
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from inference.web.services.inference import InferenceService
|
||||
from inference.web.config import ModelConfig, StorageConfig
|
||||
from backend.web.services.inference import InferenceService
|
||||
from backend.web.config import ModelConfig, StorageConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -72,8 +72,8 @@ class TestInferenceServiceInitialization:
|
||||
gpu_available = inference_service.gpu_available
|
||||
assert isinstance(gpu_available, bool)
|
||||
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
def test_initialize_imports_correctly(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -102,8 +102,8 @@ class TestInferenceServiceInitialization:
|
||||
mock_yolo_detector.assert_called_once()
|
||||
mock_pipeline.assert_called_once()
|
||||
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
def test_initialize_sets_up_pipeline(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -135,8 +135,8 @@ class TestInferenceServiceInitialization:
|
||||
enable_fallback=True,
|
||||
)
|
||||
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
def test_initialize_idempotent(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -161,8 +161,8 @@ class TestInferenceServiceInitialization:
|
||||
class TestInferenceServiceProcessing:
|
||||
"""Test inference processing methods."""
|
||||
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('ultralytics.YOLO')
|
||||
def test_process_image_basic_flow(
|
||||
self,
|
||||
@@ -197,8 +197,8 @@ class TestInferenceServiceProcessing:
|
||||
assert result.confidence == {"InvoiceNumber": 0.95}
|
||||
assert result.processing_time_ms > 0
|
||||
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
def test_process_image_handles_errors(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
@@ -228,8 +228,8 @@ class TestInferenceServiceProcessing:
|
||||
class TestInferenceServicePDFRendering:
|
||||
"""Test PDF rendering imports."""
|
||||
|
||||
@patch('inference.pipeline.pipeline.InferencePipeline')
|
||||
@patch('inference.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('shared.pdf.renderer.render_pdf_to_images')
|
||||
@patch('ultralytics.YOLO')
|
||||
def test_pdf_visualization_imports_correctly(
|
||||
|
||||
@@ -9,9 +9,9 @@ from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.admin_models import ModelVersion
|
||||
from inference.web.api.v1.admin.training import create_training_router
|
||||
from inference.web.schemas.admin import (
|
||||
from backend.data.admin_models import ModelVersion
|
||||
from backend.web.api.v1.admin.training import create_training_router
|
||||
from backend.web.schemas.admin import (
|
||||
ModelVersionCreateRequest,
|
||||
ModelVersionUpdateRequest,
|
||||
)
|
||||
@@ -173,13 +173,16 @@ class TestGetActiveModelRoute:
|
||||
def test_get_active_model_when_exists(self, mock_models_repo):
|
||||
fn = _find_endpoint("get_active_model")
|
||||
|
||||
mock_models_repo.get_active.return_value = _make_model_version(status="active", is_active=True)
|
||||
mock_model = _make_model_version(status="active", is_active=True)
|
||||
mock_models_repo.get_active.return_value = mock_model
|
||||
|
||||
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo))
|
||||
|
||||
assert result.has_active_model is True
|
||||
assert result.model is not None
|
||||
# model is now a ModelVersionItem, not ModelVersion
|
||||
assert result.model.is_active is True
|
||||
assert result.model.version == "1.0.0"
|
||||
|
||||
def test_get_active_model_when_none(self, mock_models_repo):
|
||||
fn = _find_endpoint("get_active_model")
|
||||
|
||||
@@ -8,8 +8,8 @@ from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.async_request_db import ApiKeyConfig
|
||||
from inference.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
|
||||
from backend.data.async_request_db import ApiKeyConfig
|
||||
from backend.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from inference.web.services.storage_helpers import StorageHelper, get_storage_helper
|
||||
from backend.web.services.storage_helpers import StorageHelper, get_storage_helper
|
||||
from shared.storage import PREFIXES
|
||||
|
||||
|
||||
@@ -402,10 +402,10 @@ class TestGetStorageHelper:
|
||||
|
||||
def test_returns_helper_instance(self) -> None:
|
||||
"""Should return a StorageHelper instance."""
|
||||
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
|
||||
with patch("backend.web.services.storage_helpers.get_default_storage") as mock_get:
|
||||
mock_get.return_value = MagicMock()
|
||||
# Reset the global helper
|
||||
import inference.web.services.storage_helpers as module
|
||||
import backend.web.services.storage_helpers as module
|
||||
module._default_helper = None
|
||||
|
||||
helper = get_storage_helper()
|
||||
@@ -414,9 +414,9 @@ class TestGetStorageHelper:
|
||||
|
||||
def test_returns_same_instance(self) -> None:
|
||||
"""Should return the same instance on subsequent calls."""
|
||||
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
|
||||
with patch("backend.web.services.storage_helpers.get_default_storage") as mock_get:
|
||||
mock_get.return_value = MagicMock()
|
||||
import inference.web.services.storage_helpers as module
|
||||
import backend.web.services.storage_helpers as module
|
||||
module._default_helper = None
|
||||
|
||||
helper1 = get_storage_helper()
|
||||
|
||||
@@ -18,7 +18,7 @@ class TestStorageBackendInitialization:
|
||||
"""Test that get_storage_backend returns a StorageBackend instance."""
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
from inference.web.config import get_storage_backend
|
||||
from backend.web.config import get_storage_backend
|
||||
|
||||
env = {
|
||||
"STORAGE_BACKEND": "local",
|
||||
@@ -36,7 +36,7 @@ class TestStorageBackendInitialization:
|
||||
"""Test that storage config file is used when present."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
from inference.web.config import get_storage_backend
|
||||
from backend.web.config import get_storage_backend
|
||||
|
||||
config_file = tmp_path / "storage.yaml"
|
||||
storage_path = tmp_path / "storage"
|
||||
@@ -55,7 +55,7 @@ local:
|
||||
"""Test fallback to environment variables when no config file."""
|
||||
from shared.storage.local import LocalStorageBackend
|
||||
|
||||
from inference.web.config import get_storage_backend
|
||||
from backend.web.config import get_storage_backend
|
||||
|
||||
env = {
|
||||
"STORAGE_BACKEND": "local",
|
||||
@@ -71,7 +71,7 @@ local:
|
||||
"""Test that AppConfig can be created with storage backend."""
|
||||
from shared.storage.base import StorageBackend
|
||||
|
||||
from inference.web.config import AppConfig, create_app_config
|
||||
from backend.web.config import AppConfig, create_app_config
|
||||
|
||||
env = {
|
||||
"STORAGE_BACKEND": "local",
|
||||
@@ -103,7 +103,7 @@ class TestStorageBackendInDocumentUpload:
|
||||
# Create a mock upload file
|
||||
pdf_content = b"%PDF-1.4 test content"
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
from backend.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
@@ -130,7 +130,7 @@ class TestStorageBackendInDocumentUpload:
|
||||
|
||||
pdf_content = b"%PDF-1.4 test content"
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
from backend.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
@@ -163,7 +163,7 @@ class TestStorageBackendInDocumentDownload:
|
||||
doc_path = "documents/test-doc.pdf"
|
||||
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
from backend.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
@@ -188,7 +188,7 @@ class TestStorageBackendInDocumentDownload:
|
||||
original_content = b"%PDF-1.4 test content"
|
||||
backend.upload_bytes(original_content, doc_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
from backend.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
@@ -214,7 +214,7 @@ class TestStorageBackendInImageServing:
|
||||
image_path = "images/doc-123/page_1.png"
|
||||
backend.upload_bytes(b"fake png content", image_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
from backend.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
@@ -233,7 +233,7 @@ class TestStorageBackendInImageServing:
|
||||
storage_path.mkdir(parents=True, exist_ok=True)
|
||||
backend = LocalStorageBackend(str(storage_path))
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
from backend.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
@@ -261,7 +261,7 @@ class TestStorageBackendInDocumentDeletion:
|
||||
doc_path = "documents/test-doc.pdf"
|
||||
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
from backend.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
@@ -284,7 +284,7 @@ class TestStorageBackendInDocumentDeletion:
|
||||
backend.upload_bytes(b"img1", f"images/{doc_id}/page_1.png")
|
||||
backend.upload_bytes(b"img2", f"images/{doc_id}/page_2.png")
|
||||
|
||||
from inference.web.services.document_service import DocumentService
|
||||
from backend.web.services.document_service import DocumentService
|
||||
|
||||
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@ from uuid import uuid4
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from inference.web.api.v1.admin.training import create_training_router
|
||||
from inference.web.core.auth import (
|
||||
from backend.web.api.v1.admin.training import create_training_router
|
||||
from backend.web.core.auth import (
|
||||
validate_admin_token,
|
||||
get_document_repository,
|
||||
get_annotation_repository,
|
||||
|
||||
Reference in New Issue
Block a user