This commit is contained in:
Yaojia Wang
2026-02-01 18:51:54 +01:00
parent 4126196dea
commit a564ac9d70
82 changed files with 13123 additions and 3282 deletions

View File

@@ -9,7 +9,6 @@ from unittest.mock import Mock, MagicMock
from uuid import uuid4
from inference.web.services.autolabel import AutoLabelService
from inference.data.admin_db import AdminDB
class MockDocument:
@@ -23,19 +22,18 @@ class MockDocument:
self.auto_label_error = None
class MockAdminDB:
"""Mock AdminDB for testing."""
class MockDocumentRepository:
"""Mock DocumentRepository for testing."""
def __init__(self):
self.documents = {}
self.annotations = []
self.status_updates = []
def get_document(self, document_id):
def get(self, document_id):
"""Get document by ID."""
return self.documents.get(str(document_id))
def update_document_status(
def update_status(
self,
document_id,
status=None,
@@ -58,19 +56,32 @@ class MockAdminDB:
if auto_label_error:
doc.auto_label_error = auto_label_error
def delete_annotations_for_document(self, document_id, source=None):
class MockAnnotationRepository:
"""Mock AnnotationRepository for testing."""
def __init__(self):
self.annotations = []
def delete_for_document(self, document_id, source=None):
"""Mock delete annotations."""
return 0
def create_annotations_batch(self, annotations):
def create_batch(self, annotations):
"""Mock create annotations."""
self.annotations.extend(annotations)
@pytest.fixture
def mock_db():
"""Create mock admin DB."""
return MockAdminDB()
def mock_doc_repo():
"""Create mock document repository."""
return MockDocumentRepository()
@pytest.fixture
def mock_ann_repo():
"""Create mock annotation repository."""
return MockAnnotationRepository()
@pytest.fixture
@@ -82,10 +93,14 @@ def auto_label_service(monkeypatch):
service._ocr_engine.extract_from_image = Mock(return_value=[])
# Mock the image processing methods to avoid file I/O errors
def mock_process_image(self, document_id, image_path, field_values, db, page_number=1):
def mock_process_image(self, document_id, image_path, field_values, ann_repo, page_number=1):
return 0 # No annotations created (mocked)
def mock_process_pdf(self, document_id, pdf_path, field_values, ann_repo):
return 0 # No annotations created (mocked)
monkeypatch.setattr(AutoLabelService, "_process_image", mock_process_image)
monkeypatch.setattr(AutoLabelService, "_process_pdf", mock_process_pdf)
return service
@@ -93,11 +108,11 @@ def auto_label_service(monkeypatch):
class TestAutoLabelWithLocks:
"""Tests for auto-label service with lock integration."""
def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_unlocked_document_succeeds(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling succeeds on unlocked document."""
# Create test document (unlocked)
document_id = str(uuid4())
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=None,
)
@@ -111,21 +126,22 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
)
# Should succeed
assert result["status"] == "completed"
# Verify status was updated to running and then completed
assert len(mock_db.status_updates) >= 2
assert mock_db.status_updates[0]["auto_label_status"] == "running"
assert len(mock_doc_repo.status_updates) >= 2
assert mock_doc_repo.status_updates[0]["auto_label_status"] == "running"
def test_auto_label_locked_document_fails(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_locked_document_fails(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling fails on locked document."""
# Create test document (locked for 1 hour)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
@@ -139,7 +155,8 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
)
# Should fail
@@ -150,15 +167,15 @@ class TestAutoLabelWithLocks:
# Verify status was updated to failed
assert any(
update["auto_label_status"] == "failed"
for update in mock_db.status_updates
for update in mock_doc_repo.status_updates
)
def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_expired_lock_succeeds(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling succeeds when lock has expired."""
# Create test document (lock expired 1 hour ago)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) - timedelta(hours=1)
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
@@ -172,18 +189,19 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
)
# Should succeed (lock expired)
assert result["status"] == "completed"
def test_auto_label_skip_lock_check(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_skip_lock_check(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling with skip_lock_check=True bypasses lock."""
# Create test document (locked)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) + timedelta(hours=1)
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
@@ -197,14 +215,15 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
skip_lock_check=True, # Bypass lock check
)
# Should succeed even though document is locked
assert result["status"] == "completed"
def test_auto_label_document_not_found(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_document_not_found(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test auto-labeling fails when document doesn't exist."""
# Create dummy file
test_file = tmp_path / "test.png"
@@ -215,19 +234,20 @@ class TestAutoLabelWithLocks:
document_id=str(uuid4()),
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
)
# Should fail
assert result["status"] == "failed"
assert "not found" in result["error"]
def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_db, tmp_path):
def test_auto_label_respects_lock_by_default(self, auto_label_service, mock_doc_repo, mock_ann_repo, tmp_path):
"""Test that lock check is enabled by default."""
# Create test document (locked)
document_id = str(uuid4())
lock_until = datetime.now(timezone.utc) + timedelta(minutes=30)
mock_db.documents[document_id] = MockDocument(
mock_doc_repo.documents[document_id] = MockDocument(
document_id=document_id,
annotation_lock_until=lock_until,
)
@@ -241,7 +261,8 @@ class TestAutoLabelWithLocks:
document_id=document_id,
file_path=str(test_file),
field_values={"invoice_number": "INV-001"},
db=mock_db,
doc_repo=mock_doc_repo,
ann_repo=mock_ann_repo,
# skip_lock_check not specified, should default to False
)