283 lines
7.8 KiB
Python
283 lines
7.8 KiB
Python
"""
|
|
Tests for Batch Upload Queue
|
|
"""
|
|
|
|
import time
|
|
from datetime import datetime
|
|
from threading import Event
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
from inference.web.workers.batch_queue import BatchTask, BatchTaskQueue
|
|
|
|
|
|
class MockBatchService:
    """Fake batch upload service that records calls instead of doing real work.

    Attributes are plain and mutable so individual tests can tweak behaviour:
    ``process_delay`` controls how long each successful call sleeps and
    ``should_fail`` makes every call raise.
    """

    def __init__(self):
        self.should_fail = False  # flip to True to make every call raise
        self.process_delay = 0.1  # seconds slept per successful call (fake work)
        self.processed_tasks = []  # one record per successful call, in call order

    def process_zip_upload(self, admin_token, zip_filename, zip_content, upload_source):
        """Simulate processing one uploaded ZIP archive.

        ``zip_content`` is accepted to match the expected call signature but
        is ignored.

        Returns:
            A fixed summary dict reporting one successful and zero failed files.

        Raises:
            Exception: "Simulated processing error" when ``should_fail`` is set;
            nothing is recorded in that case.
        """
        if not self.should_fail:
            time.sleep(self.process_delay)  # pretend to do work
            self.processed_tasks.append(
                {
                    "admin_token": admin_token,
                    "zip_filename": zip_filename,
                    "upload_source": upload_source,
                }
            )
            return {"status": "completed", "successful_files": 1, "failed_files": 0}
        raise Exception("Simulated processing error")
|
|
|
|
|
|
class TestBatchTask:
    """Tests for BatchTask dataclass."""

    def test_batch_task_creation(self):
        """BatchTask can be created with required fields and keeps their values."""
        batch_id = uuid4()
        created_at = datetime.utcnow()

        task = BatchTask(
            batch_id=batch_id,
            admin_token="test-token",
            zip_content=b"test",
            zip_filename="test.zip",
            upload_source="ui",
            auto_label=True,
            created_at=created_at,
        )

        # Compare every field against the exact input instead of only checking
        # "is not None", so a swapped or dropped field is caught.
        assert task.batch_id == batch_id
        assert task.admin_token == "test-token"
        assert task.zip_content == b"test"
        assert task.zip_filename == "test.zip"
        assert task.upload_source == "ui"
        assert task.auto_label is True
        assert task.created_at == created_at
|
|
|
|
|
|
class TestBatchTaskQueue:
    """Tests for batch task queue functionality."""

    # Upper bound (seconds) on how long a test waits for the background worker.
    # Tests poll instead of sleeping a fixed time: fast machines finish early
    # and slow CI machines get enough headroom not to flake.
    WAIT_TIMEOUT = 5.0

    @staticmethod
    def _make_task(zip_filename="test.zip"):
        """Build a BatchTask with the fixed field values shared by these tests."""
        return BatchTask(
            batch_id=uuid4(),
            admin_token="test-token",
            zip_content=b"test",
            zip_filename=zip_filename,
            upload_source="ui",
            auto_label=True,
            created_at=datetime.utcnow(),
        )

    @classmethod
    def _wait_for(cls, predicate, timeout=None, interval=0.01):
        """Poll ``predicate`` until it is truthy or ``timeout`` seconds elapse.

        Args:
            predicate: zero-argument callable checked repeatedly.
            timeout: seconds to wait; defaults to ``WAIT_TIMEOUT``.
            interval: seconds slept between checks.

        Returns:
            True if the predicate became truthy in time, else its final value
            coerced to bool, so callers can simply ``assert`` the result.
        """
        limit = cls.WAIT_TIMEOUT if timeout is None else timeout
        deadline = time.monotonic() + limit
        while not predicate():
            if time.monotonic() >= deadline:
                return bool(predicate())
            time.sleep(interval)
        return True

    @staticmethod
    def _ensure_stopped(queue):
        """Stop ``queue`` if it is still running.

        Used in ``finally`` clauses so a failing assertion does not leave
        worker threads alive for the rest of the test session.
        """
        if queue.is_running:
            queue.stop()

    def test_queue_initialization(self):
        """Queue initializes with correct defaults."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)

        assert queue.get_queue_depth() == 0
        assert queue.is_running is False
        assert queue._worker_count == 1

    def test_start_queue(self):
        """Queue starts with batch service."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        queue.start(service)
        try:
            assert queue.is_running is True
            assert len(queue._workers) == 1
        finally:
            self._ensure_stopped(queue)

    def test_stop_queue(self):
        """Queue stops gracefully."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        queue.start(service)
        try:
            assert queue.is_running is True

            queue.stop(timeout=5.0)

            assert queue.is_running is False
            assert len(queue._workers) == 0
        finally:
            self._ensure_stopped(queue)

    def test_submit_task_success(self):
        """Task is submitted to queue successfully."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)

        result = queue.submit(self._make_task())

        assert result is True
        assert queue.get_queue_depth() == 1

    def test_submit_task_queue_full(self):
        """Returns False when queue is full."""
        queue = BatchTaskQueue(max_size=2, worker_count=1)

        # Fill the queue; no workers are started, so nothing drains it.
        for i in range(2):
            assert queue.submit(self._make_task(f"test{i}.zip")) is True

        # Try to add one more (should fail)
        result = queue.submit(self._make_task("extra.zip"))

        assert result is False
        assert queue.get_queue_depth() == 2

    def test_worker_processes_task(self):
        """Worker thread processes queued tasks."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        queue.start(service)
        try:
            queue.submit(self._make_task("test.zip"))

            # Wait (bounded) for the worker thread to finish the task.
            assert self._wait_for(lambda: len(service.processed_tasks) >= 1)
            assert len(service.processed_tasks) == 1
            assert service.processed_tasks[0]["zip_filename"] == "test.zip"
        finally:
            self._ensure_stopped(queue)

    def test_multiple_tasks_processed(self):
        """Multiple tasks are processed in order."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()
        filenames = [f"test{i}.zip" for i in range(3)]

        queue.start(service)
        try:
            for name in filenames:
                queue.submit(self._make_task(name))

            assert self._wait_for(
                lambda: len(service.processed_tasks) >= len(filenames)
            )
            # With a single worker, tasks are expected to complete in
            # submission order, as this test's docstring promises.
            processed = [t["zip_filename"] for t in service.processed_tasks]
            assert processed == filenames
        finally:
            self._ensure_stopped(queue)

    def test_get_queue_depth(self):
        """Returns correct queue depth."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)

        assert queue.get_queue_depth() == 0

        for i in range(3):
            queue.submit(self._make_task(f"test{i}.zip"))

        assert queue.get_queue_depth() == 3

    def test_is_running_property(self):
        """is_running reflects queue state."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        assert queue.is_running is False

        queue.start(service)
        try:
            assert queue.is_running is True

            queue.stop()
            assert queue.is_running is False
        finally:
            self._ensure_stopped(queue)

    def test_double_start_ignored(self):
        """Starting queue twice is safely ignored."""
        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = MockBatchService()

        queue.start(service)
        try:
            worker_count_after_first_start = len(queue._workers)

            queue.start(service)  # Second start
            worker_count_after_second_start = len(queue._workers)

            assert worker_count_after_first_start == worker_count_after_second_start
        finally:
            self._ensure_stopped(queue)

    def test_error_handling_in_worker(self):
        """Worker survives a processing error and keeps consuming tasks."""

        class FailOnBadZipService(MockBatchService):
            """Mock that raises only for ``bad.zip`` and processes anything else."""

            def process_zip_upload(
                self, admin_token, zip_filename, zip_content, upload_source
            ):
                if zip_filename == "bad.zip":
                    raise Exception("Simulated processing error")
                return super().process_zip_upload(
                    admin_token, zip_filename, zip_content, upload_source
                )

        queue = BatchTaskQueue(max_size=10, worker_count=1)
        service = FailOnBadZipService()

        queue.start(service)
        try:
            queue.submit(self._make_task("bad.zip"))
            queue.submit(self._make_task("good.zip"))

            # If the exception from bad.zip had killed the only worker,
            # good.zip would never be processed and this wait would time out.
            assert self._wait_for(lambda: len(service.processed_tasks) >= 1)
            processed = [t["zip_filename"] for t in service.processed_tasks]
            assert processed == ["good.zip"]

            # Worker should still be running
            assert queue.is_running is True
        finally:
            self._ensure_stopped(queue)
|