277 lines
9.3 KiB
Python
277 lines
9.3 KiB
Python
"""
|
|
Tests for Annotation Lock Mechanism (Phase 3.3).
|
|
"""
|
|
|
|
import pytest
|
|
from datetime import datetime, timedelta, timezone
|
|
from uuid import uuid4
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from inference.web.api.v1.admin.locks import create_locks_router
|
|
from inference.web.core.auth import validate_admin_token, get_admin_db
|
|
|
|
|
|
class MockAdminDocument:
    """Mock AdminDocument for testing.

    Every attribute may be overridden by passing a keyword argument of the
    same name; anything not supplied falls back to the default below.

    Timestamps default to timezone-aware UTC so they can be compared with
    the aware datetimes (``datetime.now(timezone.utc)``) that the mock DB
    uses for ``annotation_lock_until``.
    """

    def __init__(self, **kwargs):
        # A single aware timestamp shared by both defaults.
        # datetime.utcnow() is deprecated since Python 3.12 and yields a
        # naive value that raises TypeError when compared with, or subtracted
        # from, aware datetimes.
        now = datetime.now(timezone.utc)

        self.document_id = kwargs.get('document_id', uuid4())
        self.admin_token = kwargs.get('admin_token', 'test-token')
        self.filename = kwargs.get('filename', 'test.pdf')
        self.file_size = kwargs.get('file_size', 100000)
        self.content_type = kwargs.get('content_type', 'application/pdf')
        self.page_count = kwargs.get('page_count', 1)
        self.status = kwargs.get('status', 'pending')
        self.auto_label_status = kwargs.get('auto_label_status', None)
        self.auto_label_error = kwargs.get('auto_label_error', None)
        self.upload_source = kwargs.get('upload_source', 'ui')
        self.batch_id = kwargs.get('batch_id', None)
        self.csv_field_values = kwargs.get('csv_field_values', None)
        # None means "not locked"; otherwise an aware expiry datetime.
        self.annotation_lock_until = kwargs.get('annotation_lock_until', None)
        self.created_at = kwargs.get('created_at', now)
        self.updated_at = kwargs.get('updated_at', now)
|
|
|
|
|
|
class MockAdminDB:
    """In-memory stand-in for AdminDB used by the annotation-lock tests.

    Documents are stored in ``self.documents`` keyed by document ID. Every
    operation is scoped to an admin token: an unknown ID or a token that
    does not match the document's ``admin_token`` yields ``None``.
    """

    def __init__(self):
        self.documents = {}

    def _owned_document(self, document_id, admin_token):
        """Return the document if it exists and belongs to *admin_token*, else None."""
        candidate = self.documents.get(document_id)
        if not candidate or candidate.admin_token != admin_token:
            return None
        return candidate

    def get_document_by_token(self, document_id, admin_token):
        """Get a single document by ID and token (None if absent or not owned)."""
        return self._owned_document(document_id, admin_token)

    def acquire_annotation_lock(self, document_id, admin_token, duration_seconds=300):
        """Lock the document for *duration_seconds* from now.

        Returns the document on success, or None when it is missing, not
        owned by *admin_token*, or still holds an unexpired lock.
        """
        target = self._owned_document(document_id, admin_token)
        if target is None:
            return None

        current_time = datetime.now(timezone.utc)
        held_until = target.annotation_lock_until
        if held_until and held_until > current_time:
            return None  # an unexpired lock is already in place

        target.annotation_lock_until = current_time + timedelta(seconds=duration_seconds)
        return target

    def release_annotation_lock(self, document_id, admin_token, force=False):
        """Clear the document's lock and return the document.

        The lock is cleared whether or not one was held. ``force`` is
        accepted for interface compatibility but not consulted. Returns None
        when the document is missing or not owned by *admin_token*.
        """
        target = self._owned_document(document_id, admin_token)
        if target is not None:
            target.annotation_lock_until = None
        return target

    def extend_annotation_lock(self, document_id, admin_token, additional_seconds=300):
        """Push an existing, unexpired lock's expiry out by *additional_seconds*.

        The extension is added to the current expiry, not to "now". Returns
        None when the document is missing, not owned, unlocked, or its lock
        has already lapsed.
        """
        target = self._owned_document(document_id, admin_token)
        if target is None:
            return None

        expiry = target.annotation_lock_until
        if not expiry or expiry <= datetime.now(timezone.utc):
            return None  # nothing live to extend

        target.annotation_lock_until = expiry + timedelta(seconds=additional_seconds)
        return target
|
|
|
|
|
|
@pytest.fixture
def app():
    """Build a FastAPI app wired to an in-memory MockAdminDB.

    The DB is seeded with one pending, UI-uploaded document named
    "INV001.pdf", keyed by the string form of its ``document_id``. Admin
    token validation is overridden to always yield "test-token", which is
    also MockAdminDocument's default ``admin_token``.
    """
    api = FastAPI()

    db = MockAdminDB()
    seeded = MockAdminDocument(
        filename="INV001.pdf",
        status="pending",
        upload_source="ui",
    )
    db.documents[str(seeded.document_id)] = seeded

    # Replace the real auth check and DB dependency with test doubles.
    api.dependency_overrides.update({
        validate_admin_token: lambda: "test-token",
        get_admin_db: lambda: db,
    })

    api.include_router(create_locks_router())
    return api
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Return a synchronous TestClient bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
@pytest.fixture
def document_id(app):
    """Return the ID (as a string) of the first document in the app's mock DB.

    The ``get_admin_db`` override on ``app`` is a zero-argument callable;
    calling it here returns the same DB object the routes receive.

    Raises:
        RuntimeError: if the mock DB holds no documents.
    """
    mock_db = app.dependency_overrides[get_admin_db]()
    # Read only the first key instead of materialising the whole key list.
    first_key = next(iter(mock_db.documents), None)
    if first_key is None:
        raise RuntimeError("the app fixture seeded no documents into the mock DB")
    return str(first_key)
|
|
|
|
|
|
class TestAnnotationLocks:
    """Tests for annotation lock endpoints.

    All requests target ``/admin/documents/{document_id}/lock``: POST
    acquires a lock, PATCH extends it and DELETE releases it.
    """

    @staticmethod
    def _lock_url(document_id):
        """Return the lock endpoint path for *document_id*."""
        return f"/admin/documents/{document_id}/lock"

    def _acquire(self, client, document_id, duration_seconds=300):
        """POST a lock request and return the raw response."""
        return client.post(
            self._lock_url(document_id),
            json={"duration_seconds": duration_seconds},
        )

    def _extend(self, client, document_id, duration_seconds=300):
        """PATCH a lock-extension request and return the raw response."""
        return client.patch(
            self._lock_url(document_id),
            json={"duration_seconds": duration_seconds},
        )

    def _release(self, client, document_id):
        """DELETE the lock and return the raw response."""
        return client.delete(self._lock_url(document_id))

    def test_acquire_lock_success(self, client, document_id):
        """Test successfully acquiring an annotation lock."""
        response = self._acquire(client, document_id, duration_seconds=300)

        assert response.status_code == 200
        data = response.json()
        assert data["document_id"] == document_id
        assert data["locked"] is True
        assert data["lock_expires_at"] is not None
        assert "Lock acquired for 300 seconds" in data["message"]

    def test_acquire_lock_already_locked(self, client, document_id):
        """Test acquiring lock on already locked document."""
        # First lock
        response1 = self._acquire(client, document_id)
        assert response1.status_code == 200

        # Try to lock again while the first lock is still live
        response2 = self._acquire(client, document_id)
        assert response2.status_code == 409
        assert "already locked" in response2.json()["detail"]

    def test_release_lock_success(self, client, document_id):
        """Test successfully releasing an annotation lock."""
        # Precondition: the lock must really be held. Otherwise this test
        # silently degenerates into test_release_lock_not_locked.
        assert self._acquire(client, document_id).status_code == 200

        response = self._release(client, document_id)

        assert response.status_code == 200
        data = response.json()
        assert data["document_id"] == document_id
        assert data["locked"] is False
        assert data["lock_expires_at"] is None
        assert "released successfully" in data["message"]

    def test_release_lock_not_locked(self, client, document_id):
        """Test releasing lock on unlocked document."""
        response = self._release(client, document_id)

        # Should succeed even if not locked
        assert response.status_code == 200
        data = response.json()
        assert data["locked"] is False

    def test_extend_lock_success(self, client, document_id):
        """Test successfully extending an annotation lock."""
        # Precondition: acquire first. Reading lock_expires_at from a failed
        # response would make the comparison below meaningless.
        response1 = self._acquire(client, document_id)
        assert response1.status_code == 200
        original_expiry = response1.json()["lock_expires_at"]

        response2 = self._extend(client, document_id, duration_seconds=300)

        assert response2.status_code == 200
        data = response2.json()
        assert data["document_id"] == document_id
        assert data["locked"] is True
        assert data["lock_expires_at"] != original_expiry
        assert "extended by 300 seconds" in data["message"]

    def test_extend_lock_not_locked(self, client, document_id):
        """Test extending lock on unlocked document."""
        response = self._extend(client, document_id)

        assert response.status_code == 409
        assert "doesn't exist or has expired" in response.json()["detail"]

    def test_acquire_lock_custom_duration(self, client, document_id):
        """Test acquiring lock with custom duration."""
        response = self._acquire(client, document_id, duration_seconds=600)

        assert response.status_code == 200
        data = response.json()
        assert "Lock acquired for 600 seconds" in data["message"]

    def test_acquire_lock_invalid_document(self, client):
        """Test acquiring lock on non-existent document."""
        fake_id = str(uuid4())
        response = self._acquire(client, fake_id)

        assert response.status_code == 404
        assert "not found" in response.json()["detail"]

    def test_lock_lifecycle(self, client, document_id):
        """Test complete lock lifecycle: acquire -> extend -> release."""
        # Acquire
        response1 = self._acquire(client, document_id)
        assert response1.status_code == 200
        assert response1.json()["locked"] is True

        # Extend
        response2 = self._extend(client, document_id)
        assert response2.status_code == 200
        assert response2.json()["locked"] is True

        # Release
        response3 = self._release(client, document_id)
        assert response3.status_code == 200
        assert response3.json()["locked"] is False

        # Verify can acquire again after release
        response4 = self._acquire(client, document_id)
        assert response4.status_code == 200
        assert response4.json()["locked"] is True
|