re-structure

This commit is contained in:
Yaojia Wang
2026-02-01 22:55:31 +01:00
parent 400b12a967
commit b602d0a340
176 changed files with 856 additions and 853 deletions

View File

@@ -9,8 +9,8 @@ from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import AdminAnnotation, AnnotationHistory
from inference.data.repositories.annotation_repository import AnnotationRepository
from backend.data.admin_models import AdminAnnotation, AnnotationHistory
from backend.data.repositories.annotation_repository import AnnotationRepository
class TestAnnotationRepository:
@@ -66,7 +66,7 @@ class TestAnnotationRepository:
def test_create_returns_annotation_id(self, repo):
"""Test create returns annotation ID."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -92,7 +92,7 @@ class TestAnnotationRepository:
def test_create_with_optional_params(self, repo):
"""Test create with optional text_value and confidence."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -124,7 +124,7 @@ class TestAnnotationRepository:
def test_create_default_source_is_manual(self, repo):
"""Test create uses manual as default source."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -153,7 +153,7 @@ class TestAnnotationRepository:
def test_create_batch_returns_ids(self, repo):
"""Test create_batch returns list of annotation IDs."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -195,7 +195,7 @@ class TestAnnotationRepository:
def test_create_batch_default_page_number(self, repo):
"""Test create_batch uses page_number=1 by default."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -224,7 +224,7 @@ class TestAnnotationRepository:
def test_create_batch_with_all_optional_params(self, repo):
"""Test create_batch with all optional parameters."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -259,7 +259,7 @@ class TestAnnotationRepository:
def test_create_batch_empty_list(self, repo):
"""Test create_batch with empty list returns empty."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -275,7 +275,7 @@ class TestAnnotationRepository:
def test_get_returns_annotation(self, repo, sample_annotation):
"""Test get returns annotation when exists."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_annotation
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -289,7 +289,7 @@ class TestAnnotationRepository:
def test_get_returns_none_when_not_found(self, repo):
"""Test get returns None when annotation not found."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -306,7 +306,7 @@ class TestAnnotationRepository:
def test_get_for_document_returns_all_annotations(self, repo, sample_annotation):
"""Test get_for_document returns all annotations for document."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_annotation]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -319,7 +319,7 @@ class TestAnnotationRepository:
def test_get_for_document_with_page_filter(self, repo, sample_annotation):
"""Test get_for_document filters by page number."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_annotation]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -331,7 +331,7 @@ class TestAnnotationRepository:
def test_get_for_document_returns_empty_list(self, repo):
"""Test get_for_document returns empty list when no annotations."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -347,7 +347,7 @@ class TestAnnotationRepository:
def test_update_returns_true(self, repo, sample_annotation):
"""Test update returns True when annotation exists."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_annotation
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -363,7 +363,7 @@ class TestAnnotationRepository:
def test_update_returns_false_when_not_found(self, repo):
"""Test update returns False when annotation not found."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -375,7 +375,7 @@ class TestAnnotationRepository:
def test_update_all_fields(self, repo, sample_annotation):
"""Test update can update all fields."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_annotation
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -412,7 +412,7 @@ class TestAnnotationRepository:
def test_update_partial_fields(self, repo, sample_annotation):
"""Test update only updates provided fields."""
original_x = sample_annotation.x_center
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_annotation
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -433,7 +433,7 @@ class TestAnnotationRepository:
def test_delete_returns_true(self, repo, sample_annotation):
"""Test delete returns True when annotation exists."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_annotation
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -446,7 +446,7 @@ class TestAnnotationRepository:
def test_delete_returns_false_when_not_found(self, repo):
"""Test delete returns False when annotation not found."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -463,7 +463,7 @@ class TestAnnotationRepository:
def test_delete_for_document_returns_count(self, repo, sample_annotation):
"""Test delete_for_document returns count of deleted annotations."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_annotation]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -476,7 +476,7 @@ class TestAnnotationRepository:
def test_delete_for_document_with_source_filter(self, repo, sample_annotation):
"""Test delete_for_document filters by source."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_annotation]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -488,7 +488,7 @@ class TestAnnotationRepository:
def test_delete_for_document_returns_zero(self, repo):
"""Test delete_for_document returns 0 when no annotations."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -506,7 +506,7 @@ class TestAnnotationRepository:
def test_verify_marks_annotation_verified(self, repo, sample_annotation):
"""Test verify marks annotation as verified."""
sample_annotation.is_verified = False
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_annotation
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -521,7 +521,7 @@ class TestAnnotationRepository:
def test_verify_returns_none_when_not_found(self, repo):
"""Test verify returns None when annotation not found."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -538,7 +538,7 @@ class TestAnnotationRepository:
def test_override_updates_annotation(self, repo, sample_annotation):
"""Test override updates annotation and creates history."""
sample_annotation.source = "auto"
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_annotation
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -559,7 +559,7 @@ class TestAnnotationRepository:
def test_override_returns_none_when_not_found(self, repo):
"""Test override returns None when annotation not found."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -573,7 +573,7 @@ class TestAnnotationRepository:
"""Test override does not change override_source if already manual."""
sample_annotation.source = "manual"
sample_annotation.override_source = None
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_annotation
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -591,7 +591,7 @@ class TestAnnotationRepository:
def test_override_skips_unknown_attributes(self, repo, sample_annotation):
"""Test override ignores unknown attributes."""
sample_annotation.source = "auto"
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_annotation
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -614,7 +614,7 @@ class TestAnnotationRepository:
def test_create_history_returns_history(self, repo):
"""Test create_history returns created history record."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -636,7 +636,7 @@ class TestAnnotationRepository:
def test_create_history_with_minimal_params(self, repo):
"""Test create_history with minimal parameters."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -659,7 +659,7 @@ class TestAnnotationRepository:
def test_get_history_returns_list(self, repo, sample_history):
"""Test get_history returns list of history records."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_history]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -672,7 +672,7 @@ class TestAnnotationRepository:
def test_get_history_returns_empty_list(self, repo):
"""Test get_history returns empty list when no history."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -688,7 +688,7 @@ class TestAnnotationRepository:
def test_get_document_history_returns_list(self, repo, sample_history):
"""Test get_document_history returns list of history records."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_history]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -700,7 +700,7 @@ class TestAnnotationRepository:
def test_get_document_history_returns_empty_list(self, repo):
"""Test get_document_history returns empty list when no history."""
with patch("inference.data.repositories.annotation_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.annotation_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)

View File

@@ -9,7 +9,7 @@ from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.repositories.base import BaseRepository
from backend.data.repositories.base import BaseRepository
class ConcreteRepository(BaseRepository[MagicMock]):
@@ -31,7 +31,7 @@ class TestBaseRepository:
def test_session_yields_session(self, repo):
"""Test _session yields a database session."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)

View File

@@ -9,8 +9,8 @@ from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import BatchUpload, BatchUploadFile
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
from backend.data.admin_models import BatchUpload, BatchUploadFile
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
class TestBatchUploadRepository:
@@ -54,7 +54,7 @@ class TestBatchUploadRepository:
def test_create_returns_batch(self, repo):
"""Test create returns created batch upload."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -70,7 +70,7 @@ class TestBatchUploadRepository:
def test_create_with_upload_source(self, repo):
"""Test create with custom upload source."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -87,7 +87,7 @@ class TestBatchUploadRepository:
def test_create_default_upload_source(self, repo):
"""Test create uses default upload source."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -107,7 +107,7 @@ class TestBatchUploadRepository:
def test_get_returns_batch(self, repo, sample_batch):
"""Test get returns batch when exists."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_batch
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -121,7 +121,7 @@ class TestBatchUploadRepository:
def test_get_returns_none_when_not_found(self, repo):
"""Test get returns None when batch not found."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -138,7 +138,7 @@ class TestBatchUploadRepository:
def test_update_updates_batch(self, repo, sample_batch):
"""Test update updates batch fields."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_batch
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -156,7 +156,7 @@ class TestBatchUploadRepository:
def test_update_ignores_unknown_fields(self, repo, sample_batch):
"""Test update ignores unknown fields."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_batch
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -171,7 +171,7 @@ class TestBatchUploadRepository:
def test_update_not_found(self, repo):
"""Test update does nothing when batch not found."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -183,7 +183,7 @@ class TestBatchUploadRepository:
def test_update_multiple_fields(self, repo, sample_batch):
"""Test update can update multiple fields."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_batch
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -206,7 +206,7 @@ class TestBatchUploadRepository:
def test_create_file_returns_file(self, repo):
"""Test create_file returns created file record."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -221,7 +221,7 @@ class TestBatchUploadRepository:
def test_create_file_with_kwargs(self, repo):
"""Test create_file with additional kwargs."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -242,7 +242,7 @@ class TestBatchUploadRepository:
def test_update_file_updates_file(self, repo, sample_file):
"""Test update_file updates file fields."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_file
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -258,7 +258,7 @@ class TestBatchUploadRepository:
def test_update_file_ignores_unknown_fields(self, repo, sample_file):
"""Test update_file ignores unknown fields."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_file
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -273,7 +273,7 @@ class TestBatchUploadRepository:
def test_update_file_not_found(self, repo):
"""Test update_file does nothing when file not found."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -285,7 +285,7 @@ class TestBatchUploadRepository:
def test_update_file_multiple_fields(self, repo, sample_file):
"""Test update_file can update multiple fields."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_file
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -304,7 +304,7 @@ class TestBatchUploadRepository:
def test_get_files_returns_list(self, repo, sample_file):
"""Test get_files returns list of files."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_file]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -317,7 +317,7 @@ class TestBatchUploadRepository:
def test_get_files_returns_empty_list(self, repo):
"""Test get_files returns empty list when no files."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -333,7 +333,7 @@ class TestBatchUploadRepository:
def test_get_paginated_returns_batches_and_total(self, repo, sample_batch):
"""Test get_paginated returns list of batches and total count."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_batch]
@@ -347,7 +347,7 @@ class TestBatchUploadRepository:
def test_get_paginated_with_pagination(self, repo, sample_batch):
"""Test get_paginated with limit and offset."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 100
mock_session.exec.return_value.all.return_value = [sample_batch]
@@ -360,7 +360,7 @@ class TestBatchUploadRepository:
def test_get_paginated_empty_results(self, repo):
"""Test get_paginated with no results."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 0
mock_session.exec.return_value.all.return_value = []
@@ -374,7 +374,7 @@ class TestBatchUploadRepository:
def test_get_paginated_with_admin_token(self, repo, sample_batch):
"""Test get_paginated with admin_token parameter (deprecated, ignored)."""
with patch("inference.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.batch_upload_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_batch]

View File

@@ -9,8 +9,8 @@ from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
from inference.data.repositories.dataset_repository import DatasetRepository
from backend.data.admin_models import TrainingDataset, DatasetDocument, TrainingTask
from backend.data.repositories.dataset_repository import DatasetRepository
class TestDatasetRepository:
@@ -69,7 +69,7 @@ class TestDatasetRepository:
def test_create_returns_dataset(self, repo):
"""Test create returns created dataset."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -81,7 +81,7 @@ class TestDatasetRepository:
def test_create_with_all_params(self, repo):
"""Test create with all parameters."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -103,7 +103,7 @@ class TestDatasetRepository:
def test_create_default_values(self, repo):
"""Test create uses default values."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -121,7 +121,7 @@ class TestDatasetRepository:
def test_get_returns_dataset(self, repo, sample_dataset):
"""Test get returns dataset when exists."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -135,7 +135,7 @@ class TestDatasetRepository:
def test_get_with_uuid(self, repo, sample_dataset):
"""Test get works with UUID object."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -147,7 +147,7 @@ class TestDatasetRepository:
def test_get_returns_none_when_not_found(self, repo):
"""Test get returns None when dataset not found."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -164,7 +164,7 @@ class TestDatasetRepository:
def test_get_paginated_returns_datasets_and_total(self, repo, sample_dataset):
"""Test get_paginated returns list of datasets and total count."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_dataset]
@@ -178,7 +178,7 @@ class TestDatasetRepository:
def test_get_paginated_with_status_filter(self, repo, sample_dataset):
"""Test get_paginated filters by status."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_dataset]
@@ -191,7 +191,7 @@ class TestDatasetRepository:
def test_get_paginated_with_pagination(self, repo, sample_dataset):
"""Test get_paginated with limit and offset."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 50
mock_session.exec.return_value.all.return_value = [sample_dataset]
@@ -204,7 +204,7 @@ class TestDatasetRepository:
def test_get_paginated_empty_results(self, repo):
"""Test get_paginated with no results."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 0
mock_session.exec.return_value.all.return_value = []
@@ -222,7 +222,7 @@ class TestDatasetRepository:
def test_get_active_training_tasks_returns_dict(self, repo, sample_training_task):
"""Test get_active_training_tasks returns dict of active tasks."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_training_task]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -240,7 +240,7 @@ class TestDatasetRepository:
def test_get_active_training_tasks_invalid_uuid(self, repo):
"""Test get_active_training_tasks filters invalid UUIDs."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -263,7 +263,7 @@ class TestDatasetRepository:
def test_update_status_updates_dataset(self, repo, sample_dataset):
"""Test update_status updates dataset status."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -276,7 +276,7 @@ class TestDatasetRepository:
def test_update_status_with_error_message(self, repo, sample_dataset):
"""Test update_status with error message."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -292,7 +292,7 @@ class TestDatasetRepository:
def test_update_status_with_totals(self, repo, sample_dataset):
"""Test update_status with total counts."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -312,7 +312,7 @@ class TestDatasetRepository:
def test_update_status_with_dataset_path(self, repo, sample_dataset):
"""Test update_status with dataset path."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -328,7 +328,7 @@ class TestDatasetRepository:
def test_update_status_with_uuid(self, repo, sample_dataset):
"""Test update_status works with UUID object."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -340,7 +340,7 @@ class TestDatasetRepository:
def test_update_status_not_found(self, repo):
"""Test update_status does nothing when dataset not found."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -356,7 +356,7 @@ class TestDatasetRepository:
def test_update_training_status_updates_dataset(self, repo, sample_dataset):
"""Test update_training_status updates training status."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -370,7 +370,7 @@ class TestDatasetRepository:
def test_update_training_status_with_task_id(self, repo, sample_dataset):
"""Test update_training_status with active task ID."""
task_id = uuid4()
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -387,7 +387,7 @@ class TestDatasetRepository:
def test_update_training_status_updates_main_status(self, repo, sample_dataset):
"""Test update_training_status updates main status when completed."""
sample_dataset.status = "ready"
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -405,7 +405,7 @@ class TestDatasetRepository:
def test_update_training_status_clears_task_id(self, repo, sample_dataset):
"""Test update_training_status clears task ID when None."""
sample_dataset.active_training_task_id = uuid4()
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -421,7 +421,7 @@ class TestDatasetRepository:
def test_update_training_status_not_found(self, repo):
"""Test update_training_status does nothing when dataset not found."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -437,7 +437,7 @@ class TestDatasetRepository:
def test_add_documents_creates_links(self, repo):
"""Test add_documents creates dataset document links."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -464,7 +464,7 @@ class TestDatasetRepository:
def test_add_documents_default_counts(self, repo):
"""Test add_documents uses default counts."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -484,7 +484,7 @@ class TestDatasetRepository:
def test_add_documents_with_uuid(self, repo):
"""Test add_documents works with UUID object."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -502,7 +502,7 @@ class TestDatasetRepository:
def test_add_documents_empty_list(self, repo):
"""Test add_documents with empty list."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -518,7 +518,7 @@ class TestDatasetRepository:
def test_get_documents_returns_list(self, repo, sample_dataset_document):
"""Test get_documents returns list of dataset documents."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -531,7 +531,7 @@ class TestDatasetRepository:
def test_get_documents_with_uuid(self, repo, sample_dataset_document):
"""Test get_documents works with UUID object."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_dataset_document]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -543,7 +543,7 @@ class TestDatasetRepository:
def test_get_documents_returns_empty_list(self, repo):
"""Test get_documents returns empty list when no documents."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -559,7 +559,7 @@ class TestDatasetRepository:
def test_delete_returns_true(self, repo, sample_dataset):
"""Test delete returns True when dataset exists."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -573,7 +573,7 @@ class TestDatasetRepository:
def test_delete_with_uuid(self, repo, sample_dataset):
"""Test delete works with UUID object."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_dataset
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -585,7 +585,7 @@ class TestDatasetRepository:
def test_delete_returns_false_when_not_found(self, repo):
"""Test delete returns False when dataset not found."""
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)

View File

@@ -9,8 +9,8 @@ from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4
from inference.data.admin_models import AdminDocument, AdminAnnotation
from inference.data.repositories.document_repository import DocumentRepository
from backend.data.admin_models import AdminDocument, AdminAnnotation
from backend.data.repositories.document_repository import DocumentRepository
class TestDocumentRepository:
@@ -95,7 +95,7 @@ class TestDocumentRepository:
def test_create_returns_document_id(self, repo):
"""Test create returns document ID."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -113,7 +113,7 @@ class TestDocumentRepository:
def test_create_with_all_parameters(self, repo):
"""Test create with all optional parameters."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -145,7 +145,7 @@ class TestDocumentRepository:
def test_get_returns_document(self, repo, sample_document):
"""Test get returns document when exists."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -159,7 +159,7 @@ class TestDocumentRepository:
def test_get_returns_none_when_not_found(self, repo):
"""Test get returns None when document not found."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -187,7 +187,7 @@ class TestDocumentRepository:
def test_get_paginated_no_filters(self, repo, sample_document):
"""Test get_paginated with no filters."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_document]
@@ -201,7 +201,7 @@ class TestDocumentRepository:
def test_get_paginated_with_status_filter(self, repo, sample_document):
"""Test get_paginated with status filter."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_document]
@@ -214,7 +214,7 @@ class TestDocumentRepository:
def test_get_paginated_with_upload_source_filter(self, repo, sample_document):
"""Test get_paginated with upload_source filter."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_document]
@@ -227,7 +227,7 @@ class TestDocumentRepository:
def test_get_paginated_with_auto_label_status_filter(self, repo, sample_document):
"""Test get_paginated with auto_label_status filter."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_document]
@@ -241,7 +241,7 @@ class TestDocumentRepository:
def test_get_paginated_with_batch_id_filter(self, repo, sample_document):
"""Test get_paginated with batch_id filter."""
batch_id = str(uuid4())
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_document]
@@ -254,7 +254,7 @@ class TestDocumentRepository:
def test_get_paginated_with_category_filter(self, repo, sample_document):
"""Test get_paginated with category filter."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_document]
@@ -267,7 +267,7 @@ class TestDocumentRepository:
def test_get_paginated_with_has_annotations_true(self, repo, sample_document):
"""Test get_paginated with has_annotations=True."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_document]
@@ -280,7 +280,7 @@ class TestDocumentRepository:
def test_get_paginated_with_has_annotations_false(self, repo, sample_document):
"""Test get_paginated with has_annotations=False."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_document]
@@ -297,7 +297,7 @@ class TestDocumentRepository:
def test_update_status(self, repo, sample_document):
"""Test update_status updates document status."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -310,7 +310,7 @@ class TestDocumentRepository:
def test_update_status_with_auto_label_status(self, repo, sample_document):
"""Test update_status with auto_label_status."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -326,7 +326,7 @@ class TestDocumentRepository:
def test_update_status_with_auto_label_error(self, repo, sample_document):
"""Test update_status with auto_label_error."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -342,7 +342,7 @@ class TestDocumentRepository:
def test_update_status_document_not_found(self, repo):
"""Test update_status when document not found."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -358,7 +358,7 @@ class TestDocumentRepository:
def test_update_file_path(self, repo, sample_document):
"""Test update_file_path updates document file path."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -371,7 +371,7 @@ class TestDocumentRepository:
def test_update_file_path_document_not_found(self, repo):
"""Test update_file_path when document not found."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -387,7 +387,7 @@ class TestDocumentRepository:
def test_update_group_key_returns_true(self, repo, sample_document):
"""Test update_group_key returns True when document exists."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -400,7 +400,7 @@ class TestDocumentRepository:
def test_update_group_key_returns_false(self, repo):
"""Test update_group_key returns False when document not found."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -416,7 +416,7 @@ class TestDocumentRepository:
def test_update_category(self, repo, sample_document):
"""Test update_category updates document category."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -429,7 +429,7 @@ class TestDocumentRepository:
def test_update_category_returns_none_when_not_found(self, repo):
"""Test update_category returns None when document not found."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -445,7 +445,7 @@ class TestDocumentRepository:
def test_delete_returns_true_when_exists(self, repo, sample_document):
"""Test delete returns True when document exists."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_session.exec.return_value.all.return_value = []
@@ -460,7 +460,7 @@ class TestDocumentRepository:
def test_delete_with_annotations(self, repo, sample_document):
"""Test delete removes annotations before deleting document."""
annotation = MagicMock()
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_session.exec.return_value.all.return_value = [annotation]
@@ -474,7 +474,7 @@ class TestDocumentRepository:
def test_delete_returns_false_when_not_exists(self, repo):
"""Test delete returns False when document not found."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -490,7 +490,7 @@ class TestDocumentRepository:
def test_get_categories(self, repo):
"""Test get_categories returns unique categories."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = ["invoice", "receipt", None]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -506,7 +506,7 @@ class TestDocumentRepository:
def test_get_labeled_for_export(self, repo, labeled_document):
"""Test get_labeled_for_export returns labeled documents."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [labeled_document]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -519,7 +519,7 @@ class TestDocumentRepository:
def test_get_labeled_for_export_with_token(self, repo, labeled_document):
"""Test get_labeled_for_export with admin_token filter."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [labeled_document]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -535,7 +535,7 @@ class TestDocumentRepository:
def test_count_by_status(self, repo):
"""Test count_by_status returns status counts."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [
("pending", 10),
@@ -554,7 +554,7 @@ class TestDocumentRepository:
def test_get_by_ids(self, repo, sample_document):
"""Test get_by_ids returns documents by IDs."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_document]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -570,7 +570,7 @@ class TestDocumentRepository:
def test_get_for_training_basic(self, repo, labeled_document):
"""Test get_for_training with default parameters."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [labeled_document]
@@ -584,7 +584,7 @@ class TestDocumentRepository:
def test_get_for_training_with_min_annotation_count(self, repo, labeled_document):
"""Test get_for_training with min_annotation_count."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [labeled_document]
@@ -597,7 +597,7 @@ class TestDocumentRepository:
def test_get_for_training_exclude_used(self, repo, labeled_document):
"""Test get_for_training with exclude_used_in_training."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [labeled_document]
@@ -610,7 +610,7 @@ class TestDocumentRepository:
def test_get_for_training_no_annotations(self, repo, labeled_document):
"""Test get_for_training with has_annotations=False."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [labeled_document]
@@ -628,7 +628,7 @@ class TestDocumentRepository:
def test_acquire_annotation_lock_success(self, repo, sample_document):
"""Test acquire_annotation_lock when no lock exists."""
sample_document.annotation_lock_until = None
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -641,7 +641,7 @@ class TestDocumentRepository:
def test_acquire_annotation_lock_fails_when_locked(self, repo, locked_document):
"""Test acquire_annotation_lock fails when document is already locked."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = locked_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -653,7 +653,7 @@ class TestDocumentRepository:
def test_acquire_annotation_lock_document_not_found(self, repo):
"""Test acquire_annotation_lock when document not found."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -669,7 +669,7 @@ class TestDocumentRepository:
def test_release_annotation_lock_success(self, repo, locked_document):
"""Test release_annotation_lock releases the lock."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = locked_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -682,7 +682,7 @@ class TestDocumentRepository:
def test_release_annotation_lock_document_not_found(self, repo):
"""Test release_annotation_lock when document not found."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -699,7 +699,7 @@ class TestDocumentRepository:
def test_extend_annotation_lock_success(self, repo, locked_document):
"""Test extend_annotation_lock extends the lock."""
original_lock = locked_document.annotation_lock_until
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = locked_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -713,7 +713,7 @@ class TestDocumentRepository:
def test_extend_annotation_lock_fails_when_no_lock(self, repo, sample_document):
"""Test extend_annotation_lock fails when no lock exists."""
sample_document.annotation_lock_until = None
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -725,7 +725,7 @@ class TestDocumentRepository:
def test_extend_annotation_lock_fails_when_expired(self, repo, expired_lock_document):
"""Test extend_annotation_lock fails when lock is expired."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = expired_lock_document
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -737,7 +737,7 @@ class TestDocumentRepository:
def test_extend_annotation_lock_document_not_found(self, repo):
"""Test extend_annotation_lock when document not found."""
with patch("inference.data.repositories.document_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.document_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)

View File

@@ -9,8 +9,8 @@ from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import ModelVersion
from inference.data.repositories.model_version_repository import ModelVersionRepository
from backend.data.admin_models import ModelVersion
from backend.data.repositories.model_version_repository import ModelVersionRepository
class TestModelVersionRepository:
@@ -62,7 +62,7 @@ class TestModelVersionRepository:
def test_create_returns_model(self, repo):
"""Test create returns created model version."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -82,7 +82,7 @@ class TestModelVersionRepository:
dataset_id = uuid4()
trained_at = datetime.now(timezone.utc)
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -115,7 +115,7 @@ class TestModelVersionRepository:
task_id = uuid4()
dataset_id = uuid4()
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -134,7 +134,7 @@ class TestModelVersionRepository:
def test_create_without_optional_ids(self, repo):
"""Test create without task_id and dataset_id."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -155,7 +155,7 @@ class TestModelVersionRepository:
def test_get_returns_model(self, repo, sample_model):
"""Test get returns model when exists."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -169,7 +169,7 @@ class TestModelVersionRepository:
def test_get_with_uuid(self, repo, sample_model):
"""Test get works with UUID object."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -181,7 +181,7 @@ class TestModelVersionRepository:
def test_get_returns_none_when_not_found(self, repo):
"""Test get returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -198,7 +198,7 @@ class TestModelVersionRepository:
def test_get_paginated_returns_models_and_total(self, repo, sample_model):
"""Test get_paginated returns list of models and total count."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_model]
@@ -212,7 +212,7 @@ class TestModelVersionRepository:
def test_get_paginated_with_status_filter(self, repo, sample_model):
"""Test get_paginated filters by status."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_model]
@@ -225,7 +225,7 @@ class TestModelVersionRepository:
def test_get_paginated_with_pagination(self, repo, sample_model):
"""Test get_paginated with limit and offset."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 50
mock_session.exec.return_value.all.return_value = [sample_model]
@@ -238,7 +238,7 @@ class TestModelVersionRepository:
def test_get_paginated_empty_results(self, repo):
"""Test get_paginated with no results."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 0
mock_session.exec.return_value.all.return_value = []
@@ -256,7 +256,7 @@ class TestModelVersionRepository:
def test_get_active_returns_active_model(self, repo, active_model):
"""Test get_active returns the active model."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.first.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -270,7 +270,7 @@ class TestModelVersionRepository:
def test_get_active_returns_none(self, repo):
"""Test get_active returns None when no active model."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.first.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -287,7 +287,7 @@ class TestModelVersionRepository:
def test_activate_activates_model(self, repo, sample_model, active_model):
"""Test activate sets model as active and deactivates others."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [active_model]
mock_session.get.return_value = sample_model
@@ -304,7 +304,7 @@ class TestModelVersionRepository:
def test_activate_with_uuid(self, repo, sample_model):
"""Test activate works with UUID object."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_session.get.return_value = sample_model
@@ -318,7 +318,7 @@ class TestModelVersionRepository:
def test_activate_returns_none_when_not_found(self, repo):
"""Test activate returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_session.get.return_value = None
@@ -332,7 +332,7 @@ class TestModelVersionRepository:
def test_activate_sets_activated_at(self, repo, sample_model):
"""Test activate sets activated_at timestamp."""
sample_model.activated_at = None
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_session.get.return_value = sample_model
@@ -349,7 +349,7 @@ class TestModelVersionRepository:
def test_deactivate_deactivates_model(self, repo, active_model):
"""Test deactivate sets model as inactive."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -364,7 +364,7 @@ class TestModelVersionRepository:
def test_deactivate_with_uuid(self, repo, active_model):
"""Test deactivate works with UUID object."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -376,7 +376,7 @@ class TestModelVersionRepository:
def test_deactivate_returns_none_when_not_found(self, repo):
"""Test deactivate returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -392,7 +392,7 @@ class TestModelVersionRepository:
def test_update_updates_model(self, repo, sample_model):
"""Test update updates model metadata."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -409,7 +409,7 @@ class TestModelVersionRepository:
def test_update_all_fields(self, repo, sample_model):
"""Test update can update all fields."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -428,7 +428,7 @@ class TestModelVersionRepository:
def test_update_with_uuid(self, repo, sample_model):
"""Test update works with UUID object."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -440,7 +440,7 @@ class TestModelVersionRepository:
def test_update_returns_none_when_not_found(self, repo):
"""Test update returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -453,7 +453,7 @@ class TestModelVersionRepository:
def test_update_partial_fields(self, repo, sample_model):
"""Test update only updates provided fields."""
original_name = sample_model.name
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -474,7 +474,7 @@ class TestModelVersionRepository:
def test_archive_archives_model(self, repo, sample_model):
"""Test archive sets model status to archived."""
sample_model.is_active = False
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -489,7 +489,7 @@ class TestModelVersionRepository:
def test_archive_with_uuid(self, repo, sample_model):
"""Test archive works with UUID object."""
sample_model.is_active = False
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -501,7 +501,7 @@ class TestModelVersionRepository:
def test_archive_returns_none_when_not_found(self, repo):
"""Test archive returns None when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -513,7 +513,7 @@ class TestModelVersionRepository:
def test_archive_returns_none_when_active(self, repo, active_model):
"""Test archive returns None when model is active."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -530,7 +530,7 @@ class TestModelVersionRepository:
def test_delete_returns_true(self, repo, sample_model):
"""Test delete returns True when model exists and not active."""
sample_model.is_active = False
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -545,7 +545,7 @@ class TestModelVersionRepository:
def test_delete_with_uuid(self, repo, sample_model):
"""Test delete works with UUID object."""
sample_model.is_active = False
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -557,7 +557,7 @@ class TestModelVersionRepository:
def test_delete_returns_false_when_not_found(self, repo):
"""Test delete returns False when model not found."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -570,7 +570,7 @@ class TestModelVersionRepository:
def test_delete_returns_false_when_active(self, repo, active_model):
"""Test delete returns False when model is active."""
with patch("inference.data.repositories.model_version_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.model_version_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = active_model
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)

View File

@@ -8,8 +8,8 @@ import pytest
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch
from inference.data.admin_models import AdminToken
from inference.data.repositories.token_repository import TokenRepository
from backend.data.admin_models import AdminToken
from backend.data.repositories.token_repository import TokenRepository
class TestTokenRepository:
@@ -55,7 +55,7 @@ class TestTokenRepository:
def test_is_valid_returns_true_for_active_token(self, repo, sample_token):
"""Test is_valid returns True for an active, non-expired token."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_token
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -68,7 +68,7 @@ class TestTokenRepository:
def test_is_valid_returns_false_for_nonexistent_token(self, repo):
"""Test is_valid returns False for a non-existent token."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -80,7 +80,7 @@ class TestTokenRepository:
def test_is_valid_returns_false_for_inactive_token(self, repo, inactive_token):
"""Test is_valid returns False for an inactive token."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = inactive_token
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -92,7 +92,7 @@ class TestTokenRepository:
def test_is_valid_returns_false_for_expired_token(self, repo, expired_token):
"""Test is_valid returns False for an expired token."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = expired_token
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -104,7 +104,7 @@ class TestTokenRepository:
def test_get_returns_token_when_exists(self, repo, sample_token):
"""Test get returns token when it exists."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_token
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -119,7 +119,7 @@ class TestTokenRepository:
def test_get_returns_none_when_not_exists(self, repo):
"""Test get returns None when token doesn't exist."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -131,7 +131,7 @@ class TestTokenRepository:
def test_create_new_token(self, repo):
"""Test creating a new token."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None # Token doesn't exist
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -147,7 +147,7 @@ class TestTokenRepository:
def test_create_updates_existing_token(self, repo, sample_token):
"""Test create updates an existing token."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_token
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -161,7 +161,7 @@ class TestTokenRepository:
def test_update_usage(self, repo, sample_token):
"""Test updating token last_used_at timestamp."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_token
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -174,7 +174,7 @@ class TestTokenRepository:
def test_deactivate_returns_true_when_token_exists(self, repo, sample_token):
"""Test deactivate returns True when token exists."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_token
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -188,7 +188,7 @@ class TestTokenRepository:
def test_deactivate_returns_false_when_token_not_exists(self, repo):
"""Test deactivate returns False when token doesn't exist."""
with patch("inference.data.repositories.base.get_session_context") as mock_ctx:
with patch("backend.data.repositories.base.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)

View File

@@ -9,8 +9,8 @@ from datetime import datetime, timezone, timedelta
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
from inference.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
from inference.data.repositories.training_task_repository import TrainingTaskRepository
from backend.data.admin_models import TrainingTask, TrainingLog, TrainingDocumentLink
from backend.data.repositories.training_task_repository import TrainingTaskRepository
class TestTrainingTaskRepository:
@@ -65,7 +65,7 @@ class TestTrainingTaskRepository:
def test_create_returns_task_id(self, repo):
"""Test create returns task ID."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -82,7 +82,7 @@ class TestTrainingTaskRepository:
def test_create_with_all_params(self, repo):
"""Test create with all parameters."""
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -108,7 +108,7 @@ class TestTrainingTaskRepository:
def test_create_pending_status_when_not_scheduled(self, repo):
"""Test create sets pending status when no scheduled_at."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -124,7 +124,7 @@ class TestTrainingTaskRepository:
def test_create_scheduled_status_when_scheduled(self, repo):
"""Test create sets scheduled status when scheduled_at is provided."""
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -144,7 +144,7 @@ class TestTrainingTaskRepository:
def test_get_returns_task(self, repo, sample_task):
"""Test get returns task when exists."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -158,7 +158,7 @@ class TestTrainingTaskRepository:
def test_get_returns_none_when_not_found(self, repo):
"""Test get returns None when task not found."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -175,7 +175,7 @@ class TestTrainingTaskRepository:
def test_get_by_token_returns_task(self, repo, sample_task):
"""Test get_by_token returns task (delegates to get)."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -187,7 +187,7 @@ class TestTrainingTaskRepository:
def test_get_by_token_without_token_param(self, repo, sample_task):
"""Test get_by_token works without token parameter."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -203,7 +203,7 @@ class TestTrainingTaskRepository:
def test_get_paginated_returns_tasks_and_total(self, repo, sample_task):
"""Test get_paginated returns list of tasks and total count."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_task]
@@ -217,7 +217,7 @@ class TestTrainingTaskRepository:
def test_get_paginated_with_status_filter(self, repo, sample_task):
"""Test get_paginated filters by status."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 1
mock_session.exec.return_value.all.return_value = [sample_task]
@@ -230,7 +230,7 @@ class TestTrainingTaskRepository:
def test_get_paginated_with_pagination(self, repo, sample_task):
"""Test get_paginated with limit and offset."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 50
mock_session.exec.return_value.all.return_value = [sample_task]
@@ -243,7 +243,7 @@ class TestTrainingTaskRepository:
def test_get_paginated_empty_results(self, repo):
"""Test get_paginated with no results."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.one.return_value = 0
mock_session.exec.return_value.all.return_value = []
@@ -261,7 +261,7 @@ class TestTrainingTaskRepository:
def test_get_pending_returns_pending_tasks(self, repo, sample_task):
"""Test get_pending returns pending and scheduled tasks."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_task]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -273,7 +273,7 @@ class TestTrainingTaskRepository:
def test_get_pending_returns_empty_list(self, repo):
"""Test get_pending returns empty list when no pending tasks."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -289,7 +289,7 @@ class TestTrainingTaskRepository:
def test_update_status_updates_task(self, repo, sample_task):
"""Test update_status updates task status."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -302,7 +302,7 @@ class TestTrainingTaskRepository:
def test_update_status_sets_started_at_for_running(self, repo, sample_task):
"""Test update_status sets started_at when status is running."""
sample_task.started_at = None
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -315,7 +315,7 @@ class TestTrainingTaskRepository:
def test_update_status_sets_completed_at_for_completed(self, repo, sample_task):
"""Test update_status sets completed_at when status is completed."""
sample_task.completed_at = None
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -328,7 +328,7 @@ class TestTrainingTaskRepository:
def test_update_status_sets_completed_at_for_failed(self, repo, sample_task):
"""Test update_status sets completed_at when status is failed."""
sample_task.completed_at = None
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -341,7 +341,7 @@ class TestTrainingTaskRepository:
def test_update_status_with_result_metrics(self, repo, sample_task):
"""Test update_status with result metrics."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -357,7 +357,7 @@ class TestTrainingTaskRepository:
def test_update_status_with_model_path(self, repo, sample_task):
"""Test update_status with model path."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -373,7 +373,7 @@ class TestTrainingTaskRepository:
def test_update_status_not_found(self, repo):
"""Test update_status does nothing when task not found."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -390,7 +390,7 @@ class TestTrainingTaskRepository:
def test_cancel_returns_true_for_pending(self, repo, sample_task):
"""Test cancel returns True for pending task."""
sample_task.status = "pending"
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -404,7 +404,7 @@ class TestTrainingTaskRepository:
def test_cancel_returns_true_for_scheduled(self, repo, sample_task):
"""Test cancel returns True for scheduled task."""
sample_task.status = "scheduled"
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -418,7 +418,7 @@ class TestTrainingTaskRepository:
def test_cancel_returns_false_for_running(self, repo, sample_task):
"""Test cancel returns False for running task."""
sample_task.status = "running"
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = sample_task
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -430,7 +430,7 @@ class TestTrainingTaskRepository:
def test_cancel_returns_false_when_not_found(self, repo):
"""Test cancel returns False when task not found."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.get.return_value = None
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -446,7 +446,7 @@ class TestTrainingTaskRepository:
def test_add_log_creates_log_entry(self, repo):
"""Test add_log creates a log entry."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -464,7 +464,7 @@ class TestTrainingTaskRepository:
def test_add_log_with_details(self, repo):
"""Test add_log with details."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -485,7 +485,7 @@ class TestTrainingTaskRepository:
def test_get_logs_returns_list(self, repo, sample_log):
"""Test get_logs returns list of logs."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_log]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -498,7 +498,7 @@ class TestTrainingTaskRepository:
def test_get_logs_with_pagination(self, repo, sample_log):
"""Test get_logs with limit and offset."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_log]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -510,7 +510,7 @@ class TestTrainingTaskRepository:
def test_get_logs_returns_empty_list(self, repo):
"""Test get_logs returns empty list when no logs."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -526,7 +526,7 @@ class TestTrainingTaskRepository:
def test_create_document_link_returns_link(self, repo):
"""Test create_document_link returns created link."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -543,7 +543,7 @@ class TestTrainingTaskRepository:
def test_create_document_link_with_snapshot(self, repo):
"""Test create_document_link with annotation snapshot."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
mock_ctx.return_value.__exit__ = MagicMock(return_value=False)
@@ -564,7 +564,7 @@ class TestTrainingTaskRepository:
def test_get_document_links_returns_list(self, repo, sample_link):
"""Test get_document_links returns list of links."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_link]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -576,7 +576,7 @@ class TestTrainingTaskRepository:
def test_get_document_links_returns_empty_list(self, repo):
"""Test get_document_links returns empty list when no links."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -592,7 +592,7 @@ class TestTrainingTaskRepository:
def test_get_document_training_tasks_returns_list(self, repo, sample_link):
"""Test get_document_training_tasks returns list of links."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = [sample_link]
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)
@@ -604,7 +604,7 @@ class TestTrainingTaskRepository:
def test_get_document_training_tasks_returns_empty_list(self, repo):
"""Test get_document_training_tasks returns empty list when no links."""
with patch("inference.data.repositories.training_task_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.training_task_repository.get_session_context") as mock_ctx:
mock_session = MagicMock()
mock_session.exec.return_value.all.return_value = []
mock_ctx.return_value.__enter__ = MagicMock(return_value=mock_session)

View File

@@ -9,7 +9,7 @@ import pytest
from datetime import datetime
from uuid import UUID, uuid4
from inference.data.admin_models import (
from backend.data.admin_models import (
BatchUpload,
BatchUploadFile,
TrainingDocumentLink,

View File

@@ -11,8 +11,8 @@ Tests field normalization functions:
"""
import pytest
from inference.pipeline.field_extractor import FieldExtractor
from inference.pipeline.normalizers import (
from backend.pipeline.field_extractor import FieldExtractor
from backend.pipeline.normalizers import (
InvoiceNumberNormalizer,
OcrNumberNormalizer,
BankgiroNormalizer,

View File

@@ -8,7 +8,7 @@ matching variants from known values.
from unittest.mock import patch
import pytest
from inference.pipeline.normalizers import (
from backend.pipeline.normalizers import (
NormalizationResult,
InvoiceNumberNormalizer,
OcrNumberNormalizer,
@@ -490,7 +490,7 @@ class TestEnhancedAmountNormalizer:
# Need input that bypasses labeled patterns AND shared validator
# but has decimal pattern matches
with patch(
"inference.pipeline.normalizers.amount.FieldValidators.parse_amount",
"backend.pipeline.normalizers.amount.FieldValidators.parse_amount",
return_value=None,
):
result = normalizer.normalize("Items: 100,00 and 200,00 and 300,00")
@@ -560,7 +560,7 @@ class TestDateNormalizer:
def test_fallback_pattern_with_mock(self, normalizer):
"""Test fallback PATTERNS when shared validator returns None (line 83)."""
with patch(
"inference.pipeline.normalizers.date.FieldValidators.format_date_iso",
"backend.pipeline.normalizers.date.FieldValidators.format_date_iso",
return_value=None,
):
result = normalizer.normalize("2025-08-29")
@@ -667,7 +667,7 @@ class TestEnhancedDateNormalizer:
def test_swedish_pattern_fallback_with_mock(self, normalizer):
"""Test Swedish pattern when shared validator returns None (line 170)."""
with patch(
"inference.pipeline.normalizers.date.FieldValidators.format_date_iso",
"backend.pipeline.normalizers.date.FieldValidators.format_date_iso",
return_value=None,
):
result = normalizer.normalize("15 maj 2025")
@@ -677,7 +677,7 @@ class TestEnhancedDateNormalizer:
def test_ymd_compact_fallback_with_mock(self, normalizer):
"""Test ymd_compact pattern when shared validator returns None (lines 187-192)."""
with patch(
"inference.pipeline.normalizers.date.FieldValidators.format_date_iso",
"backend.pipeline.normalizers.date.FieldValidators.format_date_iso",
return_value=None,
):
result = normalizer.normalize("20250315")

View File

@@ -10,7 +10,7 @@ Tests the cross-validation logic between payment_line and detected fields:
import pytest
from unittest.mock import MagicMock, patch
from inference.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult
from backend.pipeline.pipeline import InferencePipeline, InferenceResult, CrossValidationResult
class TestCrossValidationResult:

View File

@@ -82,7 +82,7 @@ def mock_inference_service():
@pytest.fixture
def mock_storage_config(temp_storage_dir):
"""Create mock storage configuration."""
from inference.web.config import StorageConfig
from backend.web.config import StorageConfig
return StorageConfig(
upload_dir=temp_storage_dir["uploads"],
@@ -104,13 +104,13 @@ def mock_storage_helper(temp_storage_dir):
@pytest.fixture
def test_app(mock_inference_service, mock_storage_config, mock_storage_helper):
"""Create a test FastAPI application with mocked storage."""
from inference.web.api.v1.public.inference import create_inference_router
from backend.web.api.v1.public.inference import create_inference_router
app = FastAPI()
# Patch get_storage_helper to return our mock
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
inference_router = create_inference_router(mock_inference_service, mock_storage_config)
@@ -123,7 +123,7 @@ def test_app(mock_inference_service, mock_storage_config, mock_storage_helper):
def client(test_app, mock_storage_helper):
"""Create a test client with storage helper patched."""
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
yield TestClient(test_app)
@@ -151,7 +151,7 @@ class TestInferenceEndpoint:
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -171,7 +171,7 @@ class TestInferenceEndpoint:
png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -186,7 +186,7 @@ class TestInferenceEndpoint:
def test_infer_invalid_file_type(self, client, mock_storage_helper):
"""Test rejection of invalid file types."""
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -199,7 +199,7 @@ class TestInferenceEndpoint:
def test_infer_no_file(self, client, mock_storage_helper):
"""Test rejection when no file provided."""
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post("/api/v1/infer")
@@ -211,7 +211,7 @@ class TestInferenceEndpoint:
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -238,7 +238,7 @@ class TestInferenceResultFormat:
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -258,7 +258,7 @@ class TestInferenceResultFormat:
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -282,7 +282,7 @@ class TestErrorHandling:
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -297,7 +297,7 @@ class TestErrorHandling:
"""Test handling of empty files."""
# Empty file still has valid content type
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -318,7 +318,7 @@ class TestResponseFormat:
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -338,7 +338,7 @@ class TestResponseFormat:
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -359,7 +359,7 @@ class TestDocumentIdGeneration:
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(
@@ -376,7 +376,7 @@ class TestDocumentIdGeneration:
pdf_content = b"%PDF-1.4\n%test\n"
with patch(
"inference.web.api.v1.public.inference.get_storage_helper",
"backend.web.api.v1.public.inference.get_storage_helper",
return_value=mock_storage_helper,
):
response = client.post(

View File

@@ -11,7 +11,7 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.data.admin_models import (
from backend.data.admin_models import (
AdminAnnotation,
AdminDocument,
AdminToken,
@@ -20,8 +20,8 @@ from inference.data.admin_models import (
TrainingDataset,
TrainingTask,
)
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from inference.web.core.auth import get_admin_token_dep
from backend.web.api.v1.admin.dashboard import create_dashboard_router
from backend.web.core.auth import validate_admin_token
def create_test_app(override_token_dep):
@@ -31,7 +31,7 @@ def create_test_app(override_token_dep):
app.include_router(router)
# Override auth dependency
app.dependency_overrides[get_admin_token_dep] = lambda: override_token_dep
app.dependency_overrides[validate_admin_token] = lambda: override_token_dep
return app

View File

@@ -26,7 +26,7 @@ from uuid import uuid4
import pytest
from sqlmodel import Session, SQLModel, create_engine
from inference.data.admin_models import (
from backend.data.admin_models import (
AdminAnnotation,
AdminDocument,
AdminToken,
@@ -170,15 +170,15 @@ def patched_session(db_session):
# All modules that import get_session_context
patch_targets = [
"inference.data.database.get_session_context",
"inference.data.repositories.document_repository.get_session_context",
"inference.data.repositories.annotation_repository.get_session_context",
"inference.data.repositories.dataset_repository.get_session_context",
"inference.data.repositories.training_task_repository.get_session_context",
"inference.data.repositories.model_version_repository.get_session_context",
"inference.data.repositories.batch_upload_repository.get_session_context",
"inference.data.repositories.token_repository.get_session_context",
"inference.web.services.dashboard_service.get_session_context",
"backend.data.database.get_session_context",
"backend.data.repositories.document_repository.get_session_context",
"backend.data.repositories.annotation_repository.get_session_context",
"backend.data.repositories.dataset_repository.get_session_context",
"backend.data.repositories.training_task_repository.get_session_context",
"backend.data.repositories.model_version_repository.get_session_context",
"backend.data.repositories.batch_upload_repository.get_session_context",
"backend.data.repositories.token_repository.get_session_context",
"backend.web.services.dashboard_service.get_session_context",
]
with ExitStack() as stack:

View File

@@ -14,13 +14,13 @@ from unittest.mock import MagicMock, patch
import pytest
import numpy as np
from inference.pipeline.pipeline import (
from backend.pipeline.pipeline import (
InferencePipeline,
InferenceResult,
CrossValidationResult,
)
from inference.pipeline.yolo_detector import Detection
from inference.pipeline.field_extractor import ExtractedField
from backend.pipeline.yolo_detector import Detection
from backend.pipeline.field_extractor import ExtractedField
@pytest.fixture

View File

@@ -8,7 +8,7 @@ from uuid import uuid4
import pytest
from inference.data.repositories.annotation_repository import AnnotationRepository
from backend.data.repositories.annotation_repository import AnnotationRepository
class TestAnnotationRepositoryCreate:

View File

@@ -9,7 +9,7 @@ from uuid import uuid4
import pytest
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
class TestBatchUploadCreate:

View File

@@ -8,7 +8,7 @@ from uuid import uuid4
import pytest
from inference.data.repositories.dataset_repository import DatasetRepository
from backend.data.repositories.dataset_repository import DatasetRepository
class TestDatasetRepositoryCreate:
@@ -300,7 +300,7 @@ class TestActiveTrainingTasks:
repo = DatasetRepository()
# Update task to running
from inference.data.repositories.training_task_repository import TrainingTaskRepository
from backend.data.repositories.training_task_repository import TrainingTaskRepository
task_repo = TrainingTaskRepository()
task_repo.update_status(str(sample_training_task.task_id), "running")

View File

@@ -10,8 +10,8 @@ from uuid import uuid4
import pytest
from sqlmodel import select
from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.document_repository import DocumentRepository
from backend.data.admin_models import AdminAnnotation, AdminDocument
from backend.data.repositories.document_repository import DocumentRepository
def ensure_utc(dt: datetime | None) -> datetime | None:
@@ -243,7 +243,7 @@ class TestDocumentRepositoryDelete:
assert result is True
# Verify annotation is also deleted
from inference.data.repositories.annotation_repository import AnnotationRepository
from backend.data.repositories.annotation_repository import AnnotationRepository
ann_repo = AnnotationRepository()
annotations = ann_repo.get_for_document(str(sample_document.document_id))

View File

@@ -9,7 +9,7 @@ from uuid import uuid4
import pytest
from inference.data.repositories.model_version_repository import ModelVersionRepository
from backend.data.repositories.model_version_repository import ModelVersionRepository
class TestModelVersionCreate:

View File

@@ -8,7 +8,7 @@ from datetime import datetime, timezone, timedelta
import pytest
from inference.data.repositories.token_repository import TokenRepository
from backend.data.repositories.token_repository import TokenRepository
class TestTokenCreate:

View File

@@ -9,7 +9,7 @@ from uuid import uuid4
import pytest
from inference.data.repositories.training_task_repository import TrainingTaskRepository
from backend.data.repositories.training_task_repository import TrainingTaskRepository
class TestTrainingTaskCreate:

View File

@@ -9,7 +9,7 @@ from uuid import uuid4
import pytest
from inference.data.admin_models import (
from backend.data.admin_models import (
AdminAnnotation,
AdminDocument,
AnnotationHistory,
@@ -17,7 +17,7 @@ from inference.data.admin_models import (
TrainingDataset,
TrainingTask,
)
from inference.web.services.dashboard_service import (
from backend.web.services.dashboard_service import (
DashboardStatsService,
DashboardActivityService,
is_annotation_complete,

View File

@@ -12,11 +12,11 @@ from uuid import uuid4
import pytest
import yaml
from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.annotation_repository import AnnotationRepository
from inference.data.repositories.dataset_repository import DatasetRepository
from inference.data.repositories.document_repository import DocumentRepository
from inference.web.services.dataset_builder import DatasetBuilder
from backend.data.admin_models import AdminAnnotation, AdminDocument
from backend.data.repositories.annotation_repository import AnnotationRepository
from backend.data.repositories.dataset_repository import DatasetRepository
from backend.data.repositories.document_repository import DocumentRepository
from backend.web.services.dataset_builder import DatasetBuilder
@pytest.fixture

View File

@@ -9,7 +9,7 @@ from unittest.mock import MagicMock
import pytest
from inference.web.services.document_service import DocumentService, DocumentResult
from backend.web.services.document_service import DocumentService, DocumentResult
class MockStorageBackend:

View File

@@ -7,7 +7,7 @@ Tests for database connection, session management, and basic operations.
import pytest
from sqlmodel import Session, select
from inference.data.admin_models import AdminDocument, AdminToken
from backend.data.admin_models import AdminDocument, AdminToken
class TestDatabaseConnection:

View File

@@ -10,7 +10,7 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from inference.pipeline.customer_number_parser import (
from backend.pipeline.customer_number_parser import (
CustomerNumberParser,
DashFormatPattern,
NoDashFormatPattern,

View File

@@ -26,7 +26,7 @@ def _collect_modules(package_name: str) -> list[str]:
SHARED_MODULES = _collect_modules("shared")
INFERENCE_MODULES = _collect_modules("inference")
BACKEND_MODULES = _collect_modules("backend")
TRAINING_MODULES = _collect_modules("training")
@@ -36,9 +36,9 @@ def test_shared_module_imports(module_name: str) -> None:
importlib.import_module(module_name)
@pytest.mark.parametrize("module_name", INFERENCE_MODULES)
def test_inference_module_imports(module_name: str) -> None:
"""Every module in the inference package should import without error."""
@pytest.mark.parametrize("module_name", BACKEND_MODULES)
def test_backend_module_imports(module_name: str) -> None:
"""Every module in the backend package should import without error."""
importlib.import_module(module_name)

View File

@@ -10,7 +10,7 @@ from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from inference.pipeline.payment_line_parser import PaymentLineParser, PaymentLineData
from backend.pipeline.payment_line_parser import PaymentLineParser, PaymentLineData
class TestPaymentLineParser:

View File

@@ -10,12 +10,12 @@ from uuid import UUID
import pytest
from inference.data.async_request_db import ApiKeyConfig, AsyncRequestDB
from inference.data.models import AsyncRequest
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.services.async_processing import AsyncProcessingService
from inference.web.config import AsyncConfig, StorageConfig
from inference.web.core.rate_limiter import RateLimiter
from backend.data.async_request_db import ApiKeyConfig, AsyncRequestDB
from backend.data.models import AsyncRequest
from backend.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from backend.web.services.async_processing import AsyncProcessingService
from backend.web.config import AsyncConfig, StorageConfig
from backend.web.core.rate_limiter import RateLimiter
@pytest.fixture

View File

@@ -14,7 +14,7 @@ class TestTaskStatus:
def test_task_status_basic_fields(self) -> None:
"""TaskStatus has all required fields."""
from inference.web.core.task_interface import TaskStatus
from backend.web.core.task_interface import TaskStatus
status = TaskStatus(
name="test_runner",
@@ -29,7 +29,7 @@ class TestTaskStatus:
def test_task_status_with_error(self) -> None:
"""TaskStatus can include optional error message."""
from inference.web.core.task_interface import TaskStatus
from backend.web.core.task_interface import TaskStatus
status = TaskStatus(
name="failed_runner",
@@ -42,7 +42,7 @@ class TestTaskStatus:
def test_task_status_default_error_is_none(self) -> None:
"""TaskStatus error defaults to None."""
from inference.web.core.task_interface import TaskStatus
from backend.web.core.task_interface import TaskStatus
status = TaskStatus(
name="test",
@@ -54,7 +54,7 @@ class TestTaskStatus:
def test_task_status_is_frozen(self) -> None:
"""TaskStatus is immutable (frozen dataclass)."""
from inference.web.core.task_interface import TaskStatus
from backend.web.core.task_interface import TaskStatus
status = TaskStatus(
name="test",
@@ -71,20 +71,20 @@ class TestTaskRunnerInterface:
def test_cannot_instantiate_directly(self) -> None:
"""TaskRunner is abstract and cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner
from backend.web.core.task_interface import TaskRunner
with pytest.raises(TypeError):
TaskRunner() # type: ignore[abstract]
def test_is_abstract_base_class(self) -> None:
"""TaskRunner inherits from ABC."""
from inference.web.core.task_interface import TaskRunner
from backend.web.core.task_interface import TaskRunner
assert issubclass(TaskRunner, ABC)
def test_subclass_missing_name_cannot_instantiate(self) -> None:
"""Subclass without name property cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskRunner, TaskStatus
class MissingName(TaskRunner):
def start(self) -> None:
@@ -105,7 +105,7 @@ class TestTaskRunnerInterface:
def test_subclass_missing_start_cannot_instantiate(self) -> None:
"""Subclass without start method cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskRunner, TaskStatus
class MissingStart(TaskRunner):
@property
@@ -127,7 +127,7 @@ class TestTaskRunnerInterface:
def test_subclass_missing_stop_cannot_instantiate(self) -> None:
"""Subclass without stop method cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskRunner, TaskStatus
class MissingStop(TaskRunner):
@property
@@ -149,7 +149,7 @@ class TestTaskRunnerInterface:
def test_subclass_missing_is_running_cannot_instantiate(self) -> None:
"""Subclass without is_running property cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskRunner, TaskStatus
class MissingIsRunning(TaskRunner):
@property
@@ -170,7 +170,7 @@ class TestTaskRunnerInterface:
def test_subclass_missing_get_status_cannot_instantiate(self) -> None:
"""Subclass without get_status method cannot be instantiated."""
from inference.web.core.task_interface import TaskRunner
from backend.web.core.task_interface import TaskRunner
class MissingGetStatus(TaskRunner):
@property
@@ -192,7 +192,7 @@ class TestTaskRunnerInterface:
def test_complete_subclass_can_instantiate(self) -> None:
"""Complete subclass implementing all methods can be instantiated."""
from inference.web.core.task_interface import TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskRunner, TaskStatus
class CompleteRunner(TaskRunner):
def __init__(self) -> None:
@@ -240,7 +240,7 @@ class TestTaskManager:
def test_register_runner(self) -> None:
"""Can register a task runner."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
@property
@@ -268,14 +268,14 @@ class TestTaskManager:
def test_get_runner_returns_none_for_unknown(self) -> None:
"""get_runner returns None for unknown runner name."""
from inference.web.core.task_interface import TaskManager
from backend.web.core.task_interface import TaskManager
manager = TaskManager()
assert manager.get_runner("unknown") is None
def test_start_all_runners(self) -> None:
"""start_all starts all registered runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
@@ -315,7 +315,7 @@ class TestTaskManager:
def test_stop_all_runners(self) -> None:
"""stop_all stops all registered runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
@@ -355,7 +355,7 @@ class TestTaskManager:
def test_get_all_status(self) -> None:
"""get_all_status returns status of all runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str, pending: int) -> None:
@@ -391,14 +391,14 @@ class TestTaskManager:
def test_get_all_status_empty_when_no_runners(self) -> None:
"""get_all_status returns empty dict when no runners registered."""
from inference.web.core.task_interface import TaskManager
from backend.web.core.task_interface import TaskManager
manager = TaskManager()
assert manager.get_all_status() == {}
def test_runner_names_property(self) -> None:
"""runner_names returns list of all registered runner names."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
class MockRunner(TaskRunner):
def __init__(self, runner_name: str) -> None:
@@ -430,7 +430,7 @@ class TestTaskManager:
def test_stop_all_with_timeout_distribution(self) -> None:
"""stop_all distributes timeout across runners."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
received_timeouts: list[float | None] = []
@@ -467,7 +467,7 @@ class TestTaskManager:
def test_start_all_skips_runners_requiring_arguments(self) -> None:
"""start_all skips runners that require arguments."""
from inference.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
from backend.web.core.task_interface import TaskManager, TaskRunner, TaskStatus
no_args_started = []
with_args_started = []
@@ -521,7 +521,7 @@ class TestTaskManager:
def test_stop_all_with_no_runners(self) -> None:
"""stop_all does nothing when no runners registered."""
from inference.web.core.task_interface import TaskManager
from backend.web.core.task_interface import TaskManager
manager = TaskManager()
# Should not raise any exception
@@ -535,23 +535,23 @@ class TestTrainingSchedulerInterface:
def test_training_scheduler_is_task_runner(self) -> None:
"""TrainingScheduler inherits from TaskRunner."""
from inference.web.core.scheduler import TrainingScheduler
from inference.web.core.task_interface import TaskRunner
from backend.web.core.scheduler import TrainingScheduler
from backend.web.core.task_interface import TaskRunner
scheduler = TrainingScheduler()
assert isinstance(scheduler, TaskRunner)
def test_training_scheduler_name(self) -> None:
"""TrainingScheduler has correct name."""
from inference.web.core.scheduler import TrainingScheduler
from backend.web.core.scheduler import TrainingScheduler
scheduler = TrainingScheduler()
assert scheduler.name == "training_scheduler"
def test_training_scheduler_get_status(self) -> None:
"""TrainingScheduler provides status via get_status."""
from inference.web.core.scheduler import TrainingScheduler
from inference.web.core.task_interface import TaskStatus
from backend.web.core.scheduler import TrainingScheduler
from backend.web.core.task_interface import TaskStatus
scheduler = TrainingScheduler()
# Mock the training tasks repository
@@ -572,29 +572,29 @@ class TestAutoLabelSchedulerInterface:
def test_autolabel_scheduler_is_task_runner(self) -> None:
"""AutoLabelScheduler inherits from TaskRunner."""
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
from inference.web.core.task_interface import TaskRunner
from backend.web.core.autolabel_scheduler import AutoLabelScheduler
from backend.web.core.task_interface import TaskRunner
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
with patch("backend.web.core.autolabel_scheduler.get_storage_helper"):
scheduler = AutoLabelScheduler()
assert isinstance(scheduler, TaskRunner)
def test_autolabel_scheduler_name(self) -> None:
"""AutoLabelScheduler has correct name."""
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
from backend.web.core.autolabel_scheduler import AutoLabelScheduler
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
with patch("backend.web.core.autolabel_scheduler.get_storage_helper"):
scheduler = AutoLabelScheduler()
assert scheduler.name == "autolabel_scheduler"
def test_autolabel_scheduler_get_status(self) -> None:
"""AutoLabelScheduler provides status via get_status."""
from inference.web.core.autolabel_scheduler import AutoLabelScheduler
from inference.web.core.task_interface import TaskStatus
from backend.web.core.autolabel_scheduler import AutoLabelScheduler
from backend.web.core.task_interface import TaskStatus
with patch("inference.web.core.autolabel_scheduler.get_storage_helper"):
with patch("backend.web.core.autolabel_scheduler.get_storage_helper"):
with patch(
"inference.web.core.autolabel_scheduler.get_pending_autolabel_documents"
"backend.web.core.autolabel_scheduler.get_pending_autolabel_documents"
) as mock_get:
mock_get.return_value = [MagicMock(), MagicMock(), MagicMock()]
@@ -612,23 +612,23 @@ class TestAsyncTaskQueueInterface:
def test_async_queue_is_task_runner(self) -> None:
"""AsyncTaskQueue inherits from TaskRunner."""
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.core.task_interface import TaskRunner
from backend.web.workers.async_queue import AsyncTaskQueue
from backend.web.core.task_interface import TaskRunner
queue = AsyncTaskQueue()
assert isinstance(queue, TaskRunner)
def test_async_queue_name(self) -> None:
"""AsyncTaskQueue has correct name."""
from inference.web.workers.async_queue import AsyncTaskQueue
from backend.web.workers.async_queue import AsyncTaskQueue
queue = AsyncTaskQueue()
assert queue.name == "async_task_queue"
def test_async_queue_get_status(self) -> None:
"""AsyncTaskQueue provides status via get_status."""
from inference.web.workers.async_queue import AsyncTaskQueue
from inference.web.core.task_interface import TaskStatus
from backend.web.workers.async_queue import AsyncTaskQueue
from backend.web.core.task_interface import TaskStatus
queue = AsyncTaskQueue()
status = queue.get_status()
@@ -645,23 +645,23 @@ class TestBatchTaskQueueInterface:
def test_batch_queue_is_task_runner(self) -> None:
"""BatchTaskQueue inherits from TaskRunner."""
from inference.web.workers.batch_queue import BatchTaskQueue
from inference.web.core.task_interface import TaskRunner
from backend.web.workers.batch_queue import BatchTaskQueue
from backend.web.core.task_interface import TaskRunner
queue = BatchTaskQueue()
assert isinstance(queue, TaskRunner)
def test_batch_queue_name(self) -> None:
"""BatchTaskQueue has correct name."""
from inference.web.workers.batch_queue import BatchTaskQueue
from backend.web.workers.batch_queue import BatchTaskQueue
queue = BatchTaskQueue()
assert queue.name == "batch_task_queue"
def test_batch_queue_get_status(self) -> None:
"""BatchTaskQueue provides status via get_status."""
from inference.web.workers.batch_queue import BatchTaskQueue
from inference.web.core.task_interface import TaskStatus
from backend.web.workers.batch_queue import BatchTaskQueue
from backend.web.core.task_interface import TaskStatus
queue = BatchTaskQueue()
status = queue.get_status()

View File

@@ -9,10 +9,10 @@ from uuid import UUID
from fastapi import HTTPException
from inference.data.admin_models import AdminAnnotation, AdminDocument
from backend.data.admin_models import AdminAnnotation, AdminDocument
from shared.fields import FIELD_CLASSES
from inference.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from inference.web.schemas.admin import (
from backend.web.api.v1.admin.annotations import _validate_uuid, create_annotation_router
from backend.web.schemas.admin import (
AnnotationCreate,
AnnotationUpdate,
AutoLabelRequest,
@@ -234,10 +234,10 @@ class TestAutoLabelFilePathResolution:
mock_storage.get_raw_pdf_local_path.return_value = mock_path
with patch(
"inference.web.services.storage_helpers.get_storage_helper",
"backend.web.services.storage_helpers.get_storage_helper",
return_value=mock_storage,
):
from inference.web.services.storage_helpers import get_storage_helper
from backend.web.services.storage_helpers import get_storage_helper
storage = get_storage_helper()
result = storage.get_raw_pdf_local_path("test.pdf")

View File

@@ -8,9 +8,9 @@ from unittest.mock import MagicMock, patch
from fastapi import HTTPException
from inference.data.repositories import TokenRepository
from inference.data.admin_models import AdminToken
from inference.web.core.auth import (
from backend.data.repositories import TokenRepository
from backend.data.admin_models import AdminToken
from backend.web.core.auth import (
get_token_repository,
reset_token_repository,
validate_admin_token,
@@ -81,7 +81,7 @@ class TestTokenRepository:
def test_is_valid_active_token(self):
"""Test valid active token."""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -98,7 +98,7 @@ class TestTokenRepository:
def test_is_valid_inactive_token(self):
"""Test inactive token."""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -115,7 +115,7 @@ class TestTokenRepository:
def test_is_valid_expired_token(self):
"""Test expired token."""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -138,7 +138,7 @@ class TestTokenRepository:
This verifies the fix for comparing timezone-aware and naive datetimes.
The auth API now creates tokens with timezone-aware expiration dates.
"""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -157,7 +157,7 @@ class TestTokenRepository:
def test_is_valid_not_expired_token_timezone_aware(self):
"""Test non-expired token with timezone-aware datetime."""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
@@ -175,7 +175,7 @@ class TestTokenRepository:
def test_is_valid_token_not_found(self):
"""Test token not found."""
with patch("inference.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
with patch("backend.data.repositories.token_repository.BaseRepository._session") as mock_ctx:
mock_session = MagicMock()
mock_ctx.return_value.__enter__.return_value = mock_session
mock_session.get.return_value = None

View File

@@ -12,9 +12,9 @@ from uuid import UUID
from fastapi import HTTPException
from fastapi.testclient import TestClient
from inference.data.admin_models import AdminDocument, AdminToken
from inference.web.api.v1.admin.documents import _validate_uuid, create_documents_router
from inference.web.config import StorageConfig
from backend.data.admin_models import AdminDocument, AdminToken
from backend.web.api.v1.admin.documents import _validate_uuid, create_documents_router
from backend.web.config import StorageConfig
# Test UUID
@@ -66,7 +66,7 @@ class TestCreateTokenEndpoint:
def test_create_token_success(self, mock_db):
"""Test successful token creation."""
from inference.web.schemas.admin import AdminTokenCreate
from backend.web.schemas.admin import AdminTokenCreate
request = AdminTokenCreate(name="Test Token", expires_in_days=30)

View File

@@ -9,9 +9,9 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.documents import create_documents_router
from inference.web.config import StorageConfig
from inference.web.core.auth import (
from backend.web.api.v1.admin.documents import create_documents_router
from backend.web.config import StorageConfig
from backend.web.core.auth import (
validate_admin_token,
get_document_repository,
get_annotation_repository,

View File

@@ -1,33 +1,33 @@
"""
Tests to verify admin schemas split maintains backward compatibility.
All existing imports from inference.web.schemas.admin must continue to work.
All existing imports from backend.web.schemas.admin must continue to work.
"""
import pytest
class TestEnumImports:
"""All enums importable from inference.web.schemas.admin."""
"""All enums importable from backend.web.schemas.admin."""
def test_document_status(self):
from inference.web.schemas.admin import DocumentStatus
from backend.web.schemas.admin import DocumentStatus
assert DocumentStatus.PENDING == "pending"
def test_auto_label_status(self):
from inference.web.schemas.admin import AutoLabelStatus
from backend.web.schemas.admin import AutoLabelStatus
assert AutoLabelStatus.RUNNING == "running"
def test_training_status(self):
from inference.web.schemas.admin import TrainingStatus
from backend.web.schemas.admin import TrainingStatus
assert TrainingStatus.PENDING == "pending"
def test_training_type(self):
from inference.web.schemas.admin import TrainingType
from backend.web.schemas.admin import TrainingType
assert TrainingType.TRAIN == "train"
def test_annotation_source(self):
from inference.web.schemas.admin import AnnotationSource
from backend.web.schemas.admin import AnnotationSource
assert AnnotationSource.MANUAL == "manual"
@@ -35,12 +35,12 @@ class TestAuthImports:
"""Auth schemas importable."""
def test_admin_token_create(self):
from inference.web.schemas.admin import AdminTokenCreate
from backend.web.schemas.admin import AdminTokenCreate
token = AdminTokenCreate(name="test")
assert token.name == "test"
def test_admin_token_response(self):
from inference.web.schemas.admin import AdminTokenResponse
from backend.web.schemas.admin import AdminTokenResponse
assert AdminTokenResponse is not None
@@ -48,23 +48,23 @@ class TestDocumentImports:
"""Document schemas importable."""
def test_document_upload_response(self):
from inference.web.schemas.admin import DocumentUploadResponse
from backend.web.schemas.admin import DocumentUploadResponse
assert DocumentUploadResponse is not None
def test_document_item(self):
from inference.web.schemas.admin import DocumentItem
from backend.web.schemas.admin import DocumentItem
assert DocumentItem is not None
def test_document_list_response(self):
from inference.web.schemas.admin import DocumentListResponse
from backend.web.schemas.admin import DocumentListResponse
assert DocumentListResponse is not None
def test_document_detail_response(self):
from inference.web.schemas.admin import DocumentDetailResponse
from backend.web.schemas.admin import DocumentDetailResponse
assert DocumentDetailResponse is not None
def test_document_stats_response(self):
from inference.web.schemas.admin import DocumentStatsResponse
from backend.web.schemas.admin import DocumentStatsResponse
assert DocumentStatsResponse is not None
@@ -72,60 +72,60 @@ class TestAnnotationImports:
"""Annotation schemas importable."""
def test_bounding_box(self):
from inference.web.schemas.admin import BoundingBox
from backend.web.schemas.admin import BoundingBox
bb = BoundingBox(x=0, y=0, width=100, height=50)
assert bb.width == 100
def test_annotation_create(self):
from inference.web.schemas.admin import AnnotationCreate
from backend.web.schemas.admin import AnnotationCreate
assert AnnotationCreate is not None
def test_annotation_update(self):
from inference.web.schemas.admin import AnnotationUpdate
from backend.web.schemas.admin import AnnotationUpdate
assert AnnotationUpdate is not None
def test_annotation_item(self):
from inference.web.schemas.admin import AnnotationItem
from backend.web.schemas.admin import AnnotationItem
assert AnnotationItem is not None
def test_annotation_response(self):
from inference.web.schemas.admin import AnnotationResponse
from backend.web.schemas.admin import AnnotationResponse
assert AnnotationResponse is not None
def test_annotation_list_response(self):
from inference.web.schemas.admin import AnnotationListResponse
from backend.web.schemas.admin import AnnotationListResponse
assert AnnotationListResponse is not None
def test_annotation_lock_request(self):
from inference.web.schemas.admin import AnnotationLockRequest
from backend.web.schemas.admin import AnnotationLockRequest
assert AnnotationLockRequest is not None
def test_annotation_lock_response(self):
from inference.web.schemas.admin import AnnotationLockResponse
from backend.web.schemas.admin import AnnotationLockResponse
assert AnnotationLockResponse is not None
def test_auto_label_request(self):
from inference.web.schemas.admin import AutoLabelRequest
from backend.web.schemas.admin import AutoLabelRequest
assert AutoLabelRequest is not None
def test_auto_label_response(self):
from inference.web.schemas.admin import AutoLabelResponse
from backend.web.schemas.admin import AutoLabelResponse
assert AutoLabelResponse is not None
def test_annotation_verify_request(self):
from inference.web.schemas.admin import AnnotationVerifyRequest
from backend.web.schemas.admin import AnnotationVerifyRequest
assert AnnotationVerifyRequest is not None
def test_annotation_verify_response(self):
from inference.web.schemas.admin import AnnotationVerifyResponse
from backend.web.schemas.admin import AnnotationVerifyResponse
assert AnnotationVerifyResponse is not None
def test_annotation_override_request(self):
from inference.web.schemas.admin import AnnotationOverrideRequest
from backend.web.schemas.admin import AnnotationOverrideRequest
assert AnnotationOverrideRequest is not None
def test_annotation_override_response(self):
from inference.web.schemas.admin import AnnotationOverrideResponse
from backend.web.schemas.admin import AnnotationOverrideResponse
assert AnnotationOverrideResponse is not None
@@ -133,68 +133,68 @@ class TestTrainingImports:
"""Training schemas importable."""
def test_training_config(self):
from inference.web.schemas.admin import TrainingConfig
from backend.web.schemas.admin import TrainingConfig
config = TrainingConfig()
assert config.epochs == 100
def test_training_task_create(self):
from inference.web.schemas.admin import TrainingTaskCreate
from backend.web.schemas.admin import TrainingTaskCreate
assert TrainingTaskCreate is not None
def test_training_task_item(self):
from inference.web.schemas.admin import TrainingTaskItem
from backend.web.schemas.admin import TrainingTaskItem
assert TrainingTaskItem is not None
def test_training_task_list_response(self):
from inference.web.schemas.admin import TrainingTaskListResponse
from backend.web.schemas.admin import TrainingTaskListResponse
assert TrainingTaskListResponse is not None
def test_training_task_detail_response(self):
from inference.web.schemas.admin import TrainingTaskDetailResponse
from backend.web.schemas.admin import TrainingTaskDetailResponse
assert TrainingTaskDetailResponse is not None
def test_training_task_response(self):
from inference.web.schemas.admin import TrainingTaskResponse
from backend.web.schemas.admin import TrainingTaskResponse
assert TrainingTaskResponse is not None
def test_training_log_item(self):
from inference.web.schemas.admin import TrainingLogItem
from backend.web.schemas.admin import TrainingLogItem
assert TrainingLogItem is not None
def test_training_logs_response(self):
from inference.web.schemas.admin import TrainingLogsResponse
from backend.web.schemas.admin import TrainingLogsResponse
assert TrainingLogsResponse is not None
def test_export_request(self):
from inference.web.schemas.admin import ExportRequest
from backend.web.schemas.admin import ExportRequest
assert ExportRequest is not None
def test_export_response(self):
from inference.web.schemas.admin import ExportResponse
from backend.web.schemas.admin import ExportResponse
assert ExportResponse is not None
def test_training_document_item(self):
from inference.web.schemas.admin import TrainingDocumentItem
from backend.web.schemas.admin import TrainingDocumentItem
assert TrainingDocumentItem is not None
def test_training_documents_response(self):
from inference.web.schemas.admin import TrainingDocumentsResponse
from backend.web.schemas.admin import TrainingDocumentsResponse
assert TrainingDocumentsResponse is not None
def test_model_metrics(self):
from inference.web.schemas.admin import ModelMetrics
from backend.web.schemas.admin import ModelMetrics
assert ModelMetrics is not None
def test_training_model_item(self):
from inference.web.schemas.admin import TrainingModelItem
from backend.web.schemas.admin import TrainingModelItem
assert TrainingModelItem is not None
def test_training_models_response(self):
from inference.web.schemas.admin import TrainingModelsResponse
from backend.web.schemas.admin import TrainingModelsResponse
assert TrainingModelsResponse is not None
def test_training_history_item(self):
from inference.web.schemas.admin import TrainingHistoryItem
from backend.web.schemas.admin import TrainingHistoryItem
assert TrainingHistoryItem is not None
@@ -202,31 +202,31 @@ class TestDatasetImports:
"""Dataset schemas importable."""
def test_dataset_create_request(self):
from inference.web.schemas.admin import DatasetCreateRequest
from backend.web.schemas.admin import DatasetCreateRequest
assert DatasetCreateRequest is not None
def test_dataset_document_item(self):
from inference.web.schemas.admin import DatasetDocumentItem
from backend.web.schemas.admin import DatasetDocumentItem
assert DatasetDocumentItem is not None
def test_dataset_response(self):
from inference.web.schemas.admin import DatasetResponse
from backend.web.schemas.admin import DatasetResponse
assert DatasetResponse is not None
def test_dataset_detail_response(self):
from inference.web.schemas.admin import DatasetDetailResponse
from backend.web.schemas.admin import DatasetDetailResponse
assert DatasetDetailResponse is not None
def test_dataset_list_item(self):
from inference.web.schemas.admin import DatasetListItem
from backend.web.schemas.admin import DatasetListItem
assert DatasetListItem is not None
def test_dataset_list_response(self):
from inference.web.schemas.admin import DatasetListResponse
from backend.web.schemas.admin import DatasetListResponse
assert DatasetListResponse is not None
def test_dataset_train_request(self):
from inference.web.schemas.admin import DatasetTrainRequest
from backend.web.schemas.admin import DatasetTrainRequest
assert DatasetTrainRequest is not None
@@ -234,12 +234,12 @@ class TestForwardReferences:
"""Forward references resolve correctly."""
def test_document_detail_has_annotation_items(self):
from inference.web.schemas.admin import DocumentDetailResponse
from backend.web.schemas.admin import DocumentDetailResponse
fields = DocumentDetailResponse.model_fields
assert "annotations" in fields
assert "training_history" in fields
def test_dataset_train_request_has_config(self):
from inference.web.schemas.admin import DatasetTrainRequest, TrainingConfig
from backend.web.schemas.admin import DatasetTrainRequest, TrainingConfig
req = DatasetTrainRequest(name="test", config=TrainingConfig())
assert req.config.epochs == 100

View File

@@ -7,15 +7,15 @@ from datetime import datetime, timedelta
from unittest.mock import MagicMock, patch
from uuid import UUID
from inference.data.admin_models import TrainingTask, TrainingLog
from inference.web.api.v1.admin.training import _validate_uuid, create_training_router
from inference.web.core.scheduler import (
from backend.data.admin_models import TrainingTask, TrainingLog
from backend.web.api.v1.admin.training import _validate_uuid, create_training_router
from backend.web.core.scheduler import (
TrainingScheduler,
get_training_scheduler,
start_scheduler,
stop_scheduler,
)
from inference.web.schemas.admin import (
from backend.web.schemas.admin import (
TrainingConfig,
TrainingStatus,
TrainingTaskCreate,

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.locks import create_locks_router
from inference.web.core.auth import (
from backend.web.api.v1.admin.locks import create_locks_router
from backend.web.core.auth import (
validate_admin_token,
get_document_repository,
)

View File

@@ -9,12 +9,12 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.annotations import (
from backend.web.api.v1.admin.annotations import (
create_annotation_router,
get_doc_repository,
get_ann_repository,
)
from inference.web.core.auth import validate_admin_token
from backend.web.core.auth import validate_admin_token
class MockAdminDocument:

View File

@@ -11,7 +11,7 @@ from unittest.mock import MagicMock
import pytest
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from backend.web.workers.async_queue import AsyncTask, AsyncTaskQueue
class TestAsyncTask:

View File

@@ -11,12 +11,12 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
from inference.web.api.v1.public.async_api import create_async_router, set_async_service
from inference.web.services.async_processing import AsyncSubmitResult
from inference.web.dependencies import init_dependencies
from inference.web.rate_limiter import RateLimiter, RateLimitStatus
from inference.web.schemas.inference import AsyncStatus
from backend.data.async_request_db import ApiKeyConfig, AsyncRequest, AsyncRequestDB
from backend.web.api.v1.public.async_api import create_async_router, set_async_service
from backend.web.services.async_processing import AsyncSubmitResult
from backend.web.dependencies import init_dependencies
from backend.web.rate_limiter import RateLimiter, RateLimitStatus
from backend.web.schemas.inference import AsyncStatus
# Valid UUID for testing
TEST_REQUEST_UUID = "550e8400-e29b-41d4-a716-446655440000"

View File

@@ -10,11 +10,11 @@ from unittest.mock import MagicMock, patch
import pytest
from inference.data.async_request_db import AsyncRequest
from inference.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from inference.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
from inference.web.config import AsyncConfig, StorageConfig
from inference.web.rate_limiter import RateLimiter
from backend.data.async_request_db import AsyncRequest
from backend.web.workers.async_queue import AsyncTask, AsyncTaskQueue
from backend.web.services.async_processing import AsyncProcessingService, AsyncSubmitResult
from backend.web.config import AsyncConfig, StorageConfig
from backend.web.rate_limiter import RateLimiter
@pytest.fixture
@@ -231,7 +231,7 @@ class TestAsyncProcessingService:
mock_db.get_request.return_value = None
# Mock the storage helper to return the same directory as the fixture
with patch("inference.web.services.async_processing.get_storage_helper") as mock_storage:
with patch("backend.web.services.async_processing.get_storage_helper") as mock_storage:
mock_helper = MagicMock()
mock_helper.get_uploads_base_path.return_value = temp_dir
mock_storage.return_value = mock_helper

View File

@@ -10,8 +10,8 @@ from fastapi import FastAPI
from fastapi.testclient import TestClient
import numpy as np
from inference.web.api.v1.admin.augmentation import create_augmentation_router
from inference.web.core.auth import (
from backend.web.api.v1.admin.augmentation import create_augmentation_router
from backend.web.core.auth import (
validate_admin_token,
get_document_repository,
get_dataset_repository,
@@ -175,7 +175,7 @@ class TestAugmentationPreviewEndpoint:
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
with patch(
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
"backend.web.services.augmentation_service.AugmentationService._load_document_page"
) as mock_load:
mock_load.return_value = fake_image
@@ -251,7 +251,7 @@ class TestAugmentationPreviewConfigEndpoint:
fake_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8)
with patch(
"inference.web.services.augmentation_service.AugmentationService._load_document_page"
"backend.web.services.augmentation_service.AugmentationService._load_document_page"
) as mock_load:
mock_load.return_value = fake_image

View File

@@ -8,7 +8,7 @@ from pathlib import Path
from unittest.mock import Mock, MagicMock
from uuid import uuid4
from inference.web.services.autolabel import AutoLabelService
from backend.web.services.autolabel import AutoLabelService
class MockDocument:

View File

@@ -9,7 +9,7 @@ from uuid import uuid4
import pytest
from inference.web.workers.batch_queue import BatchTask, BatchTaskQueue
from backend.web.workers.batch_queue import BatchTask, BatchTaskQueue
class MockBatchService:

View File

@@ -11,10 +11,10 @@ import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.batch.routes import router, get_batch_repository
from inference.web.core.auth import validate_admin_token
from inference.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from inference.web.services.batch_upload import BatchUploadService
from backend.web.api.v1.batch.routes import router, get_batch_repository
from backend.web.core.auth import validate_admin_token
from backend.web.workers.batch_queue import init_batch_queue, shutdown_batch_queue
from backend.web.services.batch_upload import BatchUploadService
class MockBatchUploadRepository:

View File

@@ -9,7 +9,7 @@ from uuid import uuid4
import pytest
from inference.web.services.batch_upload import BatchUploadService
from backend.web.services.batch_upload import BatchUploadService
@pytest.fixture

View File

@@ -28,7 +28,7 @@ class TestAnnotationCompletenessLogic:
def test_document_with_invoice_number_and_bankgiro_is_complete(self):
"""Document with invoice_number + bankgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
from backend.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
@@ -39,7 +39,7 @@ class TestAnnotationCompletenessLogic:
def test_document_with_ocr_number_and_plusgiro_is_complete(self):
"""Document with ocr_number + plusgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
from backend.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 3, "class_name": "ocr_number"},
@@ -50,7 +50,7 @@ class TestAnnotationCompletenessLogic:
def test_document_with_invoice_number_and_plusgiro_is_complete(self):
"""Document with invoice_number + plusgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
from backend.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
@@ -61,7 +61,7 @@ class TestAnnotationCompletenessLogic:
def test_document_with_ocr_number_and_bankgiro_is_complete(self):
"""Document with ocr_number + bankgiro should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
from backend.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 3, "class_name": "ocr_number"},
@@ -72,7 +72,7 @@ class TestAnnotationCompletenessLogic:
def test_document_with_only_identifier_is_incomplete(self):
"""Document with only identifier field should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
from backend.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
@@ -82,7 +82,7 @@ class TestAnnotationCompletenessLogic:
def test_document_with_only_payment_is_incomplete(self):
"""Document with only payment field should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
from backend.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 4, "class_name": "bankgiro"},
@@ -92,13 +92,13 @@ class TestAnnotationCompletenessLogic:
def test_document_with_no_annotations_is_incomplete(self):
"""Document with no annotations should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
from backend.web.services.dashboard_service import is_annotation_complete
assert is_annotation_complete([]) is False
def test_document_with_other_fields_only_is_incomplete(self):
"""Document with only non-essential fields should be incomplete."""
from inference.web.services.dashboard_service import is_annotation_complete
from backend.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 1, "class_name": "invoice_date"},
@@ -109,7 +109,7 @@ class TestAnnotationCompletenessLogic:
def test_document_with_all_fields_is_complete(self):
"""Document with all fields should be complete."""
from inference.web.services.dashboard_service import is_annotation_complete
from backend.web.services.dashboard_service import is_annotation_complete
annotations = [
{"class_id": 0, "class_name": "invoice_number"},
@@ -178,7 +178,7 @@ class TestDashboardSchemas:
def test_dashboard_stats_response_schema(self):
"""Test DashboardStatsResponse schema validation."""
from inference.web.schemas.admin import DashboardStatsResponse
from backend.web.schemas.admin import DashboardStatsResponse
response = DashboardStatsResponse(
total_documents=38,
@@ -195,10 +195,10 @@ class TestDashboardSchemas:
assert response.completeness_rate == 75.76
def test_active_model_response_schema(self):
"""Test ActiveModelResponse schema with null model."""
from inference.web.schemas.admin import ActiveModelResponse
"""Test DashboardActiveModelResponse schema with null model."""
from backend.web.schemas.admin import DashboardActiveModelResponse
response = ActiveModelResponse(
response = DashboardActiveModelResponse(
model=None,
running_training=None,
)
@@ -208,7 +208,7 @@ class TestDashboardSchemas:
def test_active_model_info_schema(self):
"""Test ActiveModelInfo schema validation."""
from inference.web.schemas.admin import ActiveModelInfo
from backend.web.schemas.admin import ActiveModelInfo
model = ActiveModelInfo(
version_id=TEST_MODEL_UUID,
@@ -227,7 +227,7 @@ class TestDashboardSchemas:
def test_running_training_info_schema(self):
"""Test RunningTrainingInfo schema validation."""
from inference.web.schemas.admin import RunningTrainingInfo
from backend.web.schemas.admin import RunningTrainingInfo
task = RunningTrainingInfo(
task_id=TEST_TASK_UUID,
@@ -243,7 +243,7 @@ class TestDashboardSchemas:
def test_activity_item_schema(self):
"""Test ActivityItem schema validation."""
from inference.web.schemas.admin import ActivityItem
from backend.web.schemas.admin import ActivityItem
activity = ActivityItem(
type="model_activated",
@@ -258,7 +258,7 @@ class TestDashboardSchemas:
def test_recent_activity_response_schema(self):
"""Test RecentActivityResponse schema with empty activities."""
from inference.web.schemas.admin import RecentActivityResponse
from backend.web.schemas.admin import RecentActivityResponse
response = RecentActivityResponse(activities=[])
@@ -270,7 +270,7 @@ class TestDashboardRouterCreation:
def test_creates_router_with_expected_endpoints(self):
"""Test router is created with expected endpoint paths."""
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from backend.web.api.v1.admin.dashboard import create_dashboard_router
router = create_dashboard_router()
@@ -282,7 +282,7 @@ class TestDashboardRouterCreation:
def test_router_has_correct_prefix(self):
"""Test router has /admin/dashboard prefix."""
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from backend.web.api.v1.admin.dashboard import create_dashboard_router
router = create_dashboard_router()
@@ -290,7 +290,7 @@ class TestDashboardRouterCreation:
def test_router_has_dashboard_tag(self):
"""Test router uses Dashboard tag."""
from inference.web.api.v1.admin.dashboard import create_dashboard_router
from backend.web.api.v1.admin.dashboard import create_dashboard_router
router = create_dashboard_router()
@@ -302,7 +302,7 @@ class TestFieldClassIds:
def test_identifier_class_ids(self):
"""Test identifier field class IDs."""
from inference.web.services.dashboard_service import IDENTIFIER_CLASS_IDS
from backend.web.services.dashboard_service import IDENTIFIER_CLASS_IDS
# invoice_number = 0, ocr_number = 3
assert 0 in IDENTIFIER_CLASS_IDS
@@ -310,7 +310,7 @@ class TestFieldClassIds:
def test_payment_class_ids(self):
"""Test payment field class IDs."""
from inference.web.services.dashboard_service import PAYMENT_CLASS_IDS
from backend.web.services.dashboard_service import PAYMENT_CLASS_IDS
# bankgiro = 4, plusgiro = 5
assert 4 in PAYMENT_CLASS_IDS

View File

@@ -12,7 +12,7 @@ from uuid import uuid4
import pytest
from inference.data.admin_models import (
from backend.data.admin_models import (
AdminAnnotation,
AdminDocument,
TrainingDataset,
@@ -105,7 +105,7 @@ class TestDatasetBuilder:
sample_documents, sample_annotations
):
"""Dataset builder should create images/ and labels/ with train/val/test subdirs."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
dataset_dir = tmp_path / "datasets" / "test"
builder = DatasetBuilder(
@@ -141,7 +141,7 @@ class TestDatasetBuilder:
sample_documents, sample_annotations
):
"""Images should be copied from admin_images to dataset folder."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -177,7 +177,7 @@ class TestDatasetBuilder:
sample_documents, sample_annotations
):
"""YOLO label files should be generated with correct format."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -221,7 +221,7 @@ class TestDatasetBuilder:
sample_documents, sample_annotations
):
"""data.yaml should be generated with correct field classes."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -257,7 +257,7 @@ class TestDatasetBuilder:
sample_documents, sample_annotations
):
"""Documents should be split into train/val/test according to ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -294,7 +294,7 @@ class TestDatasetBuilder:
sample_documents, sample_annotations
):
"""After successful build, dataset status should be updated to 'ready'."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -327,7 +327,7 @@ class TestDatasetBuilder:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""If build fails, dataset status should be set to 'failed'."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -357,7 +357,7 @@ class TestDatasetBuilder:
sample_documents, sample_annotations
):
"""Same seed should produce same splits."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
results = []
for _ in range(2):
@@ -405,7 +405,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Documents with unique group_key are distributed across splits."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -433,7 +433,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Documents with null/empty group_key are each treated as independent single-doc groups."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -461,7 +461,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Documents with same group_key should be assigned to the same split."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -494,7 +494,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Multi-doc groups should be split according to train/val/test ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -536,7 +536,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Mix of single-doc and multi-doc groups should be handled correctly."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -574,7 +574,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Same seed should produce same split assignments."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -601,7 +601,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Different seeds should potentially produce different split assignments."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -627,7 +627,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Every document should be assigned a split."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -654,7 +654,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""Empty document list should return empty result."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -671,7 +671,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""When all groups have multiple docs, splits should follow ratios."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -707,7 +707,7 @@ class TestAssignSplitsByGroup:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""When all groups have single doc, they are distributed across splits."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
builder = DatasetBuilder(
datasets_repo=mock_datasets_repo,
@@ -798,7 +798,7 @@ class TestBuildDatasetWithGroupKey:
mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""build_dataset should use group_key for split assignment."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
tmp_path, docs = grouped_documents
@@ -847,7 +847,7 @@ class TestBuildDatasetWithGroupKey:
self, tmp_path, mock_datasets_repo, mock_documents_repo, mock_annotations_repo
):
"""All docs with same group_key should go to same split."""
from inference.web.services.dataset_builder import DatasetBuilder
from backend.web.services.dataset_builder import DatasetBuilder
# Create 5 docs all with same group_key
docs = []

View File

@@ -9,9 +9,9 @@ from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
from uuid import UUID
from inference.data.admin_models import TrainingDataset, DatasetDocument
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
from backend.data.admin_models import TrainingDataset, DatasetDocument
from backend.web.api.v1.admin.training import create_training_router
from backend.web.schemas.admin import (
DatasetCreateRequest,
DatasetTrainRequest,
TrainingConfig,
@@ -130,10 +130,10 @@ class TestCreateDatasetRoute:
)
with patch(
"inference.web.services.dataset_builder.DatasetBuilder",
"backend.web.services.dataset_builder.DatasetBuilder",
return_value=mock_builder,
), patch(
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
"backend.web.api.v1.admin.training.datasets.get_storage_helper"
) as mock_storage:
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"
@@ -222,10 +222,10 @@ class TestCreateDatasetRoute:
)
with patch(
"inference.web.services.dataset_builder.DatasetBuilder",
"backend.web.services.dataset_builder.DatasetBuilder",
return_value=mock_builder,
), patch(
"inference.web.api.v1.admin.training.datasets.get_storage_helper"
"backend.web.api.v1.admin.training.datasets.get_storage_helper"
) as mock_storage:
mock_storage.return_value.get_datasets_base_path.return_value = "/data/datasets"
mock_storage.return_value.get_admin_images_base_path.return_value = "/data/admin_images"

View File

@@ -27,7 +27,7 @@ class TestTrainingDatasetModel:
def test_training_dataset_has_training_status_field(self):
"""TrainingDataset model should have training_status field."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset = TrainingDataset(
name="test-dataset",
@@ -37,7 +37,7 @@ class TestTrainingDatasetModel:
def test_training_dataset_has_active_training_task_id_field(self):
"""TrainingDataset model should have active_training_task_id field."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
task_id = uuid4()
dataset = TrainingDataset(
@@ -48,7 +48,7 @@ class TestTrainingDatasetModel:
def test_training_dataset_defaults(self):
"""TrainingDataset should have correct defaults for new fields."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test-dataset")
assert dataset.training_status is None
@@ -71,7 +71,7 @@ class TestDatasetRepositoryTrainingStatus:
def test_update_training_status_sets_status(self, mock_session):
"""update_training_status should set training_status."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset_id = uuid4()
dataset = TrainingDataset(
@@ -81,10 +81,10 @@ class TestDatasetRepositoryTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.repositories import DatasetRepository
from backend.data.repositories import DatasetRepository
repo = DatasetRepository()
repo.update_training_status(
@@ -98,7 +98,7 @@ class TestDatasetRepositoryTrainingStatus:
def test_update_training_status_sets_task_id(self, mock_session):
"""update_training_status should set active_training_task_id."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset_id = uuid4()
task_id = uuid4()
@@ -109,10 +109,10 @@ class TestDatasetRepositoryTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.repositories import DatasetRepository
from backend.data.repositories import DatasetRepository
repo = DatasetRepository()
repo.update_training_status(
@@ -127,7 +127,7 @@ class TestDatasetRepositoryTrainingStatus:
self, mock_session
):
"""update_training_status should update main status to 'trained' when completed."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset_id = uuid4()
dataset = TrainingDataset(
@@ -137,10 +137,10 @@ class TestDatasetRepositoryTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.repositories import DatasetRepository
from backend.data.repositories import DatasetRepository
repo = DatasetRepository()
repo.update_training_status(
@@ -156,7 +156,7 @@ class TestDatasetRepositoryTrainingStatus:
self, mock_session
):
"""update_training_status should clear task_id when training completes."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset_id = uuid4()
task_id = uuid4()
@@ -169,10 +169,10 @@ class TestDatasetRepositoryTrainingStatus:
)
mock_session.get.return_value = dataset
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.repositories import DatasetRepository
from backend.data.repositories import DatasetRepository
repo = DatasetRepository()
repo.update_training_status(
@@ -187,10 +187,10 @@ class TestDatasetRepositoryTrainingStatus:
"""update_training_status should handle missing dataset gracefully."""
mock_session.get.return_value = None
with patch("inference.data.repositories.dataset_repository.get_session_context") as mock_ctx:
with patch("backend.data.repositories.dataset_repository.get_session_context") as mock_ctx:
mock_ctx.return_value.__enter__.return_value = mock_session
from inference.data.repositories import DatasetRepository
from backend.data.repositories import DatasetRepository
repo = DatasetRepository()
# Should not raise
@@ -213,7 +213,7 @@ class TestDatasetDetailResponseTrainingStatus:
def test_dataset_detail_response_includes_training_status(self):
"""DatasetDetailResponse schema should include training_status field."""
from inference.web.schemas.admin.datasets import DatasetDetailResponse
from backend.web.schemas.admin.datasets import DatasetDetailResponse
response = DatasetDetailResponse(
dataset_id=str(uuid4()),
@@ -240,7 +240,7 @@ class TestDatasetDetailResponseTrainingStatus:
def test_dataset_detail_response_allows_null_training_status(self):
"""DatasetDetailResponse should allow null training_status."""
from inference.web.schemas.admin.datasets import DatasetDetailResponse
from backend.web.schemas.admin.datasets import DatasetDetailResponse
response = DatasetDetailResponse(
dataset_id=str(uuid4()),
@@ -294,7 +294,7 @@ class TestSchedulerDatasetStatusUpdates:
def test_scheduler_sets_running_status_on_task_start(self, mock_datasets_repo, mock_training_tasks_repo):
"""Scheduler should set dataset training_status to 'running' when task starts."""
from inference.web.core.scheduler import TrainingScheduler
from backend.web.core.scheduler import TrainingScheduler
with patch.object(TrainingScheduler, "_run_yolo_training") as mock_train:
mock_train.return_value = {"model_path": "/path/to/model.pt", "metrics": {}}
@@ -333,35 +333,35 @@ class TestDatasetStatusValues:
def test_dataset_status_building(self):
"""Dataset can have status 'building'."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="building")
assert dataset.status == "building"
def test_dataset_status_ready(self):
"""Dataset can have status 'ready'."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="ready")
assert dataset.status == "ready"
def test_dataset_status_trained(self):
"""Dataset can have status 'trained'."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="trained")
assert dataset.status == "trained"
def test_dataset_status_failed(self):
"""Dataset can have status 'failed'."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
dataset = TrainingDataset(name="test", status="failed")
assert dataset.status == "failed"
def test_training_status_values(self):
"""Training status can have various values."""
from inference.data.admin_models import TrainingDataset
from backend.data.admin_models import TrainingDataset
valid_statuses = ["pending", "scheduled", "running", "completed", "failed", "cancelled"]
for status in valid_statuses:

View File

@@ -10,7 +10,7 @@ from datetime import datetime
from unittest.mock import MagicMock
from uuid import UUID, uuid4
from inference.data.admin_models import AdminDocument
from backend.data.admin_models import AdminDocument
# Test constants
@@ -76,7 +76,7 @@ class TestDocumentCategoryInReadModel:
def test_admin_document_read_has_category(self):
"""Test AdminDocumentRead includes category field."""
from inference.data.admin_models import AdminDocumentRead
from backend.data.admin_models import AdminDocumentRead
# Check the model has category field in its schema
assert "category" in AdminDocumentRead.model_fields
@@ -94,7 +94,7 @@ class TestDocumentCategoryAPI:
def test_upload_document_with_category(self, mock_admin_db):
"""Test uploading document with category parameter."""
from inference.web.schemas.admin import DocumentUploadResponse
from backend.web.schemas.admin import DocumentUploadResponse
# Verify response schema supports category
response = DocumentUploadResponse(
@@ -110,7 +110,7 @@ class TestDocumentCategoryAPI:
def test_list_documents_returns_category(self, mock_admin_db):
"""Test list documents endpoint returns category."""
from inference.web.schemas.admin import DocumentItem
from backend.web.schemas.admin import DocumentItem
item = DocumentItem(
document_id=TEST_DOC_UUID,
@@ -127,7 +127,7 @@ class TestDocumentCategoryAPI:
def test_document_detail_includes_category(self, mock_admin_db):
"""Test document detail response includes category."""
from inference.web.schemas.admin import DocumentDetailResponse
from backend.web.schemas.admin import DocumentDetailResponse
# Check schema has category
assert "category" in DocumentDetailResponse.model_fields
@@ -167,14 +167,14 @@ class TestDocumentCategoryUpdate:
def test_update_document_category_schema(self):
"""Test update document request supports category."""
from inference.web.schemas.admin import DocumentUpdateRequest
from backend.web.schemas.admin import DocumentUpdateRequest
request = DocumentUpdateRequest(category="letter")
assert request.category == "letter"
def test_update_document_category_optional(self):
"""Test category is optional in update request."""
from inference.web.schemas.admin import DocumentUpdateRequest
from backend.web.schemas.admin import DocumentUpdateRequest
# Should not raise - category is optional
request = DocumentUpdateRequest()
@@ -186,7 +186,7 @@ class TestDatasetWithCategory:
def test_dataset_create_with_category_filter(self):
"""Test creating dataset can filter by document category."""
from inference.web.schemas.admin import DatasetCreateRequest
from backend.web.schemas.admin import DatasetCreateRequest
request = DatasetCreateRequest(
name="Invoice Training Set",
@@ -197,7 +197,7 @@ class TestDatasetWithCategory:
def test_dataset_create_category_is_optional(self):
"""Test category filter is optional when creating dataset."""
from inference.web.schemas.admin import DatasetCreateRequest
from backend.web.schemas.admin import DatasetCreateRequest
request = DatasetCreateRequest(
name="Mixed Training Set",

View File

@@ -23,7 +23,7 @@ class TestGetCategoriesEndpoint:
def test_categories_endpoint_returns_list(self):
"""Test categories endpoint returns list of available categories."""
from inference.web.schemas.admin import DocumentCategoriesResponse
from backend.web.schemas.admin import DocumentCategoriesResponse
# Test schema exists and works
response = DocumentCategoriesResponse(
@@ -35,7 +35,7 @@ class TestGetCategoriesEndpoint:
def test_categories_response_schema(self):
"""Test DocumentCategoriesResponse schema structure."""
from inference.web.schemas.admin import DocumentCategoriesResponse
from backend.web.schemas.admin import DocumentCategoriesResponse
assert "categories" in DocumentCategoriesResponse.model_fields
assert "total" in DocumentCategoriesResponse.model_fields
@@ -69,7 +69,7 @@ class TestDocumentListFilterByCategory:
"""Test list documents endpoint accepts category query parameter."""
# The endpoint should accept ?category=invoice parameter
# This test verifies the schema/query parameter exists
from inference.web.schemas.admin import DocumentListResponse
from backend.web.schemas.admin import DocumentListResponse
# Schema should work with category filter applied
assert DocumentListResponse is not None
@@ -93,7 +93,7 @@ class TestDocumentUploadWithCategory:
def test_upload_response_includes_category(self):
"""Test upload response includes the category that was set."""
from inference.web.schemas.admin import DocumentUploadResponse
from backend.web.schemas.admin import DocumentUploadResponse
response = DocumentUploadResponse(
document_id=TEST_DOC_UUID,
@@ -108,7 +108,7 @@ class TestDocumentUploadWithCategory:
def test_upload_defaults_to_invoice_category(self):
"""Test upload defaults to 'invoice' if no category specified."""
from inference.web.schemas.admin import DocumentUploadResponse
from backend.web.schemas.admin import DocumentUploadResponse
response = DocumentUploadResponse(
document_id=TEST_DOC_UUID,
@@ -127,14 +127,14 @@ class TestDocumentRepositoryCategoryMethods:
def test_get_categories_method_exists(self):
"""Test DocumentRepository has get_categories method."""
from inference.data.repositories import DocumentRepository
from backend.data.repositories import DocumentRepository
repo = DocumentRepository()
assert hasattr(repo, "get_categories")
def test_get_paginated_accepts_category_filter(self):
"""Test get_paginated method accepts category parameter."""
from inference.data.repositories import DocumentRepository
from backend.data.repositories import DocumentRepository
import inspect
repo = DocumentRepository()
@@ -152,14 +152,14 @@ class TestUpdateDocumentCategory:
def test_update_category_method_exists(self):
"""Test DocumentRepository has method to update document category."""
from inference.data.repositories import DocumentRepository
from backend.data.repositories import DocumentRepository
repo = DocumentRepository()
assert hasattr(repo, "update_category")
def test_update_request_schema(self):
"""Test DocumentUpdateRequest can update category."""
from inference.web.schemas.admin import DocumentUpdateRequest
from backend.web.schemas.admin import DocumentUpdateRequest
request = DocumentUpdateRequest(category="receipt")
assert request.category == "receipt"

View File

@@ -11,8 +11,8 @@ from fastapi.testclient import TestClient
from PIL import Image
import io
from inference.web.app import create_app
from inference.web.config import ModelConfig, StorageConfig, AppConfig
from backend.web.app import create_app
from backend.web.config import ModelConfig, StorageConfig, AppConfig
@pytest.fixture
@@ -87,8 +87,8 @@ class TestHealthEndpoint:
class TestInferEndpoint:
"""Test /api/v1/infer endpoint."""
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
def test_infer_accepts_png_file(
self,
mock_yolo_detector,
@@ -150,8 +150,8 @@ class TestInferEndpoint:
assert response.status_code == 422 # Unprocessable Entity
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
def test_infer_returns_cross_validation_if_available(
self,
mock_yolo_detector,
@@ -210,8 +210,8 @@ class TestInferEndpoint:
# This test documents that it should be added
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
def test_infer_handles_processing_errors_gracefully(
self,
mock_yolo_detector,
@@ -263,7 +263,7 @@ class TestResultsEndpoint:
# Mock the storage helper to return our test file path
with patch(
"inference.web.api.v1.public.inference.get_storage_helper"
"backend.web.api.v1.public.inference.get_storage_helper"
) as mock_storage:
mock_helper = Mock()
mock_helper.get_result_local_path.return_value = result_file
@@ -285,15 +285,15 @@ class TestInferenceServiceImports:
This test will fail if there are ImportError issues like:
- from ..inference.pipeline (wrong relative import)
- from inference.web.inference (non-existent module)
- from backend.web.inference (non-existent module)
It ensures the imports are correct before runtime.
"""
from inference.web.services.inference import InferenceService
from backend.web.services.inference import InferenceService
# Import the modules that InferenceService tries to import
from inference.pipeline.pipeline import InferencePipeline
from inference.pipeline.yolo_detector import YOLODetector
from backend.pipeline.pipeline import InferencePipeline
from backend.pipeline.yolo_detector import YOLODetector
from shared.pdf.renderer import render_pdf_to_images
# If we got here, all imports work correctly

View File

@@ -10,8 +10,8 @@ from unittest.mock import Mock, patch
from PIL import Image
import io
from inference.web.services.inference import InferenceService
from inference.web.config import ModelConfig, StorageConfig
from backend.web.services.inference import InferenceService
from backend.web.config import ModelConfig, StorageConfig
@pytest.fixture
@@ -72,8 +72,8 @@ class TestInferenceServiceInitialization:
gpu_available = inference_service.gpu_available
assert isinstance(gpu_available, bool)
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
def test_initialize_imports_correctly(
self,
mock_yolo_detector,
@@ -102,8 +102,8 @@ class TestInferenceServiceInitialization:
mock_yolo_detector.assert_called_once()
mock_pipeline.assert_called_once()
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
def test_initialize_sets_up_pipeline(
self,
mock_yolo_detector,
@@ -135,8 +135,8 @@ class TestInferenceServiceInitialization:
enable_fallback=True,
)
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
def test_initialize_idempotent(
self,
mock_yolo_detector,
@@ -161,8 +161,8 @@ class TestInferenceServiceInitialization:
class TestInferenceServiceProcessing:
"""Test inference processing methods."""
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
@patch('ultralytics.YOLO')
def test_process_image_basic_flow(
self,
@@ -197,8 +197,8 @@ class TestInferenceServiceProcessing:
assert result.confidence == {"InvoiceNumber": 0.95}
assert result.processing_time_ms > 0
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
def test_process_image_handles_errors(
self,
mock_yolo_detector,
@@ -228,8 +228,8 @@ class TestInferenceServiceProcessing:
class TestInferenceServicePDFRendering:
"""Test PDF rendering imports."""
@patch('inference.pipeline.pipeline.InferencePipeline')
@patch('inference.pipeline.yolo_detector.YOLODetector')
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
@patch('shared.pdf.renderer.render_pdf_to_images')
@patch('ultralytics.YOLO')
def test_pdf_visualization_imports_correctly(

View File

@@ -9,9 +9,9 @@ from uuid import UUID
import pytest
from inference.data.admin_models import ModelVersion
from inference.web.api.v1.admin.training import create_training_router
from inference.web.schemas.admin import (
from backend.data.admin_models import ModelVersion
from backend.web.api.v1.admin.training import create_training_router
from backend.web.schemas.admin import (
ModelVersionCreateRequest,
ModelVersionUpdateRequest,
)
@@ -173,13 +173,16 @@ class TestGetActiveModelRoute:
def test_get_active_model_when_exists(self, mock_models_repo):
fn = _find_endpoint("get_active_model")
mock_models_repo.get_active.return_value = _make_model_version(status="active", is_active=True)
mock_model = _make_model_version(status="active", is_active=True)
mock_models_repo.get_active.return_value = mock_model
result = asyncio.run(fn(admin_token=TEST_TOKEN, models=mock_models_repo))
assert result.has_active_model is True
assert result.model is not None
# model is now a ModelVersionItem, not ModelVersion
assert result.model.is_active is True
assert result.model.version == "1.0.0"
def test_get_active_model_when_none(self, mock_models_repo):
fn = _find_endpoint("get_active_model")

View File

@@ -8,8 +8,8 @@ from unittest.mock import MagicMock
import pytest
from inference.data.async_request_db import ApiKeyConfig
from inference.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
from backend.data.async_request_db import ApiKeyConfig
from backend.web.rate_limiter import RateLimiter, RateLimitConfig, RateLimitStatus
class TestRateLimiter:

View File

@@ -3,7 +3,7 @@
import pytest
from unittest.mock import MagicMock, patch
from inference.web.services.storage_helpers import StorageHelper, get_storage_helper
from backend.web.services.storage_helpers import StorageHelper, get_storage_helper
from shared.storage import PREFIXES
@@ -402,10 +402,10 @@ class TestGetStorageHelper:
def test_returns_helper_instance(self) -> None:
"""Should return a StorageHelper instance."""
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
with patch("backend.web.services.storage_helpers.get_default_storage") as mock_get:
mock_get.return_value = MagicMock()
# Reset the global helper
import inference.web.services.storage_helpers as module
import backend.web.services.storage_helpers as module
module._default_helper = None
helper = get_storage_helper()
@@ -414,9 +414,9 @@ class TestGetStorageHelper:
def test_returns_same_instance(self) -> None:
"""Should return the same instance on subsequent calls."""
with patch("inference.web.services.storage_helpers.get_default_storage") as mock_get:
with patch("backend.web.services.storage_helpers.get_default_storage") as mock_get:
mock_get.return_value = MagicMock()
import inference.web.services.storage_helpers as module
import backend.web.services.storage_helpers as module
module._default_helper = None
helper1 = get_storage_helper()

View File

@@ -18,7 +18,7 @@ class TestStorageBackendInitialization:
"""Test that get_storage_backend returns a StorageBackend instance."""
from shared.storage.base import StorageBackend
from inference.web.config import get_storage_backend
from backend.web.config import get_storage_backend
env = {
"STORAGE_BACKEND": "local",
@@ -36,7 +36,7 @@ class TestStorageBackendInitialization:
"""Test that storage config file is used when present."""
from shared.storage.local import LocalStorageBackend
from inference.web.config import get_storage_backend
from backend.web.config import get_storage_backend
config_file = tmp_path / "storage.yaml"
storage_path = tmp_path / "storage"
@@ -55,7 +55,7 @@ local:
"""Test fallback to environment variables when no config file."""
from shared.storage.local import LocalStorageBackend
from inference.web.config import get_storage_backend
from backend.web.config import get_storage_backend
env = {
"STORAGE_BACKEND": "local",
@@ -71,7 +71,7 @@ local:
"""Test that AppConfig can be created with storage backend."""
from shared.storage.base import StorageBackend
from inference.web.config import AppConfig, create_app_config
from backend.web.config import AppConfig, create_app_config
env = {
"STORAGE_BACKEND": "local",
@@ -103,7 +103,7 @@ class TestStorageBackendInDocumentUpload:
# Create a mock upload file
pdf_content = b"%PDF-1.4 test content"
from inference.web.services.document_service import DocumentService
from backend.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
@@ -130,7 +130,7 @@ class TestStorageBackendInDocumentUpload:
pdf_content = b"%PDF-1.4 test content"
from inference.web.services.document_service import DocumentService
from backend.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
@@ -163,7 +163,7 @@ class TestStorageBackendInDocumentDownload:
doc_path = "documents/test-doc.pdf"
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
from inference.web.services.document_service import DocumentService
from backend.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
@@ -188,7 +188,7 @@ class TestStorageBackendInDocumentDownload:
original_content = b"%PDF-1.4 test content"
backend.upload_bytes(original_content, doc_path)
from inference.web.services.document_service import DocumentService
from backend.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
@@ -214,7 +214,7 @@ class TestStorageBackendInImageServing:
image_path = "images/doc-123/page_1.png"
backend.upload_bytes(b"fake png content", image_path)
from inference.web.services.document_service import DocumentService
from backend.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
@@ -233,7 +233,7 @@ class TestStorageBackendInImageServing:
storage_path.mkdir(parents=True, exist_ok=True)
backend = LocalStorageBackend(str(storage_path))
from inference.web.services.document_service import DocumentService
from backend.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
@@ -261,7 +261,7 @@ class TestStorageBackendInDocumentDeletion:
doc_path = "documents/test-doc.pdf"
backend.upload_bytes(b"%PDF-1.4 test", doc_path)
from inference.web.services.document_service import DocumentService
from backend.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)
@@ -284,7 +284,7 @@ class TestStorageBackendInDocumentDeletion:
backend.upload_bytes(b"img1", f"images/{doc_id}/page_1.png")
backend.upload_bytes(b"img2", f"images/{doc_id}/page_2.png")
from inference.web.services.document_service import DocumentService
from backend.web.services.document_service import DocumentService
service = DocumentService(admin_db=mock_admin_db, storage_backend=backend)

View File

@@ -9,8 +9,8 @@ from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from inference.web.api.v1.admin.training import create_training_router
from inference.web.core.auth import (
from backend.web.api.v1.admin.training import create_training_router
from backend.web.core.auth import (
validate_admin_token,
get_document_repository,
get_annotation_repository,