WIP
This commit is contained in:
@@ -9,7 +9,6 @@ from unittest.mock import Mock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from inference.web.services.autolabel import AutoLabelService
|
||||
from inference.data.admin_db import AdminDB
|
||||
|
||||
|
||||
class MockDocument:
|
||||
@@ -23,19 +22,18 @@ class MockDocument:
|
||||
self.auto_label_error = None
|
||||
|
||||
|
||||
class MockAdminDB:
|
||||
"""Mock AdminDB for testing."""
|
||||
class MockDocumentRepository:
|
||||
"""Mock DocumentRepository for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.documents = {}
|
||||
self.annotations = []
|
||||
self.status_updates = []
|
||||
|
||||
def get_document(self, document_id):
|
||||
def get(self, document_id):
|
||||
"""Get document by ID."""
|
||||
return self.documents.get(str(document_id))
|
||||
|
||||
def update_document_status(
|
||||
def update_status(
|
||||
self,
|
||||
document_id,
|
||||
status=None,
|
||||
@@ -58,19 +56,32 @@ class MockAdminDB:
|
||||
if auto_label_error:
|
||||
doc.auto_label_error = auto_label_error
|
||||
|
||||
def delete_annotations_for_document(self, document_id, source=None):
|
||||
|
||||
class MockAnnotationRepository:
|
||||
"""Mock AnnotationRepository for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.annotations = []
|
||||
|
||||
def delete_for_document(self, document_id, source=None):
|
||||
"""Mock delete annotations."""
|
||||
return 0
|
||||
|
||||
def create_annotations_batch(self, annotations):
|
||||
def create_batch(self, annotations):
|
||||
"""Mock create annotations."""
|
||||
self.annotations.extend(annotations)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db():
|
||||
"""Create mock admin DB."""
|
||||
return MockAdminDB()
|
||||
def mock_doc_repo():
|
||||
"""Create mock document repository."""
|
||||
return MockDocumentRepository()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_ann_repo():
|
||||
"""Create mock annotation repository."""
|
||||
return MockAnnotationRepository()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -82,10 +93,14 @@ def auto_label_service(monkeypatch):
|
||||
service._ocr_engine.extract_from_image = Mock(return_value=[])
|
||||
|
||||
# Mock the image processing methods to avoid file I/O errors
|
||||
def mock_process_image(self, document_id, image_path, field_values, db, page_number=1):
|
||||
def mock_process_image(self, document_id, image_path, field_values, ann_repo, page_number=1):
|
||||
return 0 # No annotations created (mocked)
|
||||
|
||||
def mock_process_pdf(self, document_id, pdf_path, field_values, ann_repo):
|
||||
return 0 # No annotations created (mocked)
|
||||
|
||||
monkeypatch.setattr(AutoLabelService, "_process_image", mock_process_image)
|
||||
monkeypatch.setattr(AutoLabelService, "_process_pdf", mock_process_pdf)
|
||||
|
||||
return service
|
||||
|
||||
@@ -93,11 +108,11 @@ def auto_label_service(monkeypatch):
|
||||
class TestAutoLabelWithLocks:
|
||||
"""Tests for auto-label service with lock integration."""
|
||||
|
||||
def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling succeeds on unlocked document."""
|
||||
# Create test document (unlocked)
|
||||
document_id = str(uuid4())
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=None,
|
||||
)
|
||||
@@ -111,21 +126,22 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
)
|
||||
|
||||
# Should succeed
|
||||
assert result["status"] == "completed"
|
||||
# Verify status was updated to running and then completed
|
||||
assert len(mock_db.status_updates) >= 2
|
||||
assert mock_db.status_updates[0]["auto_label_status"] == "running"
|
||||
assert len(mock_doc_repo.status_updates) >= 2
|
||||
assert mock_doc_repo.status_updates[0]["auto_label_status"] == "running"
|
||||
|
||||
def test_auto_label_locked_document_fails(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_locked_document_fails(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling fails on locked document."""
|
||||
# Create test document (locked for 1 hour)
|
||||
document_id = str(uuid4())
|
||||
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=lock_until,
|
||||
)
|
||||
@@ -139,7 +155,8 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
)
|
||||
|
||||
# Should fail
|
||||
@@ -150,15 +167,15 @@ class TestAutoLabelWithLocks:
|
||||
# Verify status was updated to failed
|
||||
assert any(
|
||||
update["auto_label_status"] == "failed"
|
||||
for update in mock_db.status_updates
|
||||
for update in mock_doc_repo.status_updates
|
||||
)
|
||||
|
||||
def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling succeeds when lock has expired."""
|
||||
# Create test document (lock expired 1 hour ago)
|
||||
document_id = str(uuid4())
|
||||
lock_until = datetime.now(timezone.utc) - timedelta(hours=1)
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=lock_until,
|
||||
)
|
||||
@@ -172,18 +189,19 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
)
|
||||
|
||||
# Should succeed (lock expired)
|
||||
assert result["status"] == "completed"
|
||||
|
||||
def test_auto_label_skip_lock_check(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_skip_lock_check(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling with skip_lock_check=True bypasses lock."""
|
||||
# Create test document (locked)
|
||||
document_id = str(uuid4())
|
||||
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=lock_until,
|
||||
)
|
||||
@@ -197,14 +215,15 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
skip_lock_check=True, # Bypass lock check
|
||||
)
|
||||
|
||||
# Should succeed even though document is locked
|
||||
assert result["status"] == "completed"
|
||||
|
||||
def test_auto_label_document_not_found(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_document_not_found(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test auto-labeling fails when document doesn't exist."""
|
||||
# Create dummy file
|
||||
test_file = tmp_path / "test.png"
|
||||
@@ -215,19 +234,20 @@ class TestAutoLabelWithLocks:
|
||||
document_id=str(uuid4()),
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
)
|
||||
|
||||
# Should fail
|
||||
assert result["status"] == "failed"
|
||||
assert "not found" in result["error"]
|
||||
|
||||
def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_db, tmp_path):
|
||||
def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
|
||||
"""Test that lock check is enabled by default."""
|
||||
# Create test document (locked)
|
||||
document_id = str(uuid4())
|
||||
lock_until = datetime.now(timezone.utc) + timedelta(minutes=30)
|
||||
mock_db.documents[document_id] = MockDocument(
|
||||
mock_doc_repo.documents[document_id] = MockDocument(
|
||||
document_id=document_id,
|
||||
annotation_lock_until=lock_until,
|
||||
)
|
||||
@@ -241,7 +261,8 @@ class TestAutoLabelWithLocks:
|
||||
document_id=document_id,
|
||||
file_path=str(test_file),
|
||||
field_values={"invoice_number": "INV-001"},
|
||||
db=mock_db,
|
||||
doc_repo=mock_doc_repo,
|
||||
ann_repo=mock_ann_repo,
|
||||
# skip_lock_check not specified, should default to False
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user