218 lines
5.9 KiB
Python
218 lines
5.9 KiB
Python
"""
|
|
Tests for the AsyncTaskQueue class.
|
|
"""
|
|
|
|
import tempfile
|
|
import time
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from threading import Event
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from src.web.workers.async_queue import AsyncTask, AsyncTaskQueue
|
|
|
|
|
|
class TestAsyncTask:
    """Tests for the AsyncTask dataclass."""

    def test_create_task(self):
        """Test that AsyncTask stores the given fields and fills in defaults."""
        file_path = Path("/tmp/test.pdf")
        task = AsyncTask(
            request_id="test-id",
            api_key="test-key",
            file_path=file_path,
            filename="test.pdf",
        )

        assert task.request_id == "test-id"
        assert task.api_key == "test-key"
        assert task.file_path == file_path
        assert task.filename == "test.pdf"
        # Fields the caller did not pass must still be populated.
        assert task.priority == 0
        assert task.created_at is not None
|
|
|
|
|
|
class TestAsyncTaskQueue:
    """Tests for AsyncTaskQueue.

    Every test that starts worker threads stops the queue in a ``finally``
    block, so a failing assertion cannot leave workers running into later
    tests. Waits use events or deadline polling rather than fixed sleeps.
    """

    @staticmethod
    def _wait_until(predicate, timeout=5.0, interval=0.01):
        """Poll ``predicate`` until it is truthy or ``timeout`` seconds pass.

        Returns True as soon as the predicate is truthy; otherwise returns the
        truthiness of one final check, so callers can ``assert`` the result.
        """
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            time.sleep(interval)
        return bool(predicate())

    def test_init(self):
        """Test queue initialization."""
        queue = AsyncTaskQueue(max_size=50, worker_count=2)

        assert queue._worker_count == 2
        assert queue._queue.maxsize == 50
        assert not queue._started

    def test_submit_task(self, task_queue, sample_task):
        """Test submitting a task to the queue."""
        success = task_queue.submit(sample_task)

        assert success is True
        assert task_queue.get_queue_depth() == 1

    def test_submit_when_full(self, sample_task):
        """Test that submit returns False once the queue is at capacity."""
        queue = AsyncTaskQueue(max_size=1, worker_count=1)

        # The queue is never started, so the first task occupies the only slot.
        assert queue.submit(sample_task) is True

        task2 = AsyncTask(
            request_id="test-2",
            api_key="test-key",
            file_path=sample_task.file_path,
            filename="test2.pdf",
        )

        # Queue should be full
        assert queue.submit(task2) is False

    def test_get_queue_depth(self, task_queue, sample_task):
        """Test getting queue depth."""
        assert task_queue.get_queue_depth() == 0

        task_queue.submit(sample_task)
        assert task_queue.get_queue_depth() == 1

    def test_start_and_stop(self, task_queue):
        """Test starting and stopping the queue."""
        handler = MagicMock()

        task_queue.start(handler)
        try:
            assert task_queue._started is True
            assert task_queue.is_running is True
            assert len(task_queue._workers) == 1
        finally:
            task_queue.stop(timeout=5.0)

        assert task_queue._started is False
        assert task_queue.is_running is False
        assert len(task_queue._workers) == 0

    def test_worker_processes_task(self, sample_task):
        """Test that a worker thread hands the submitted task to the handler."""
        queue = AsyncTaskQueue(max_size=10, worker_count=1)
        processed = Event()
        seen_ids = []

        def handler(task):
            seen_ids.append(task.request_id)
            processed.set()

        queue.start(handler)
        try:
            queue.submit(sample_task)

            # Wait for processing
            assert processed.wait(timeout=5.0)
            assert seen_ids == [sample_task.request_id]
        finally:
            queue.stop()

    def test_worker_handles_errors(self, sample_task):
        """Test that a handler exception does not kill the worker."""
        queue = AsyncTaskQueue(max_size=10, worker_count=1)
        failed = Event()
        recovered = Event()
        calls = []

        def flaky_handler(task):
            calls.append(task.request_id)
            if len(calls) == 1:
                failed.set()
                raise ValueError("Test error")
            recovered.set()

        queue.start(flaky_handler)
        try:
            assert queue.submit(sample_task) is True
            assert failed.wait(timeout=5.0)

            follow_up = AsyncTask(
                request_id="after-error",
                api_key="test-key",
                file_path=sample_task.file_path,
                filename="after-error.pdf",
            )
            assert queue.submit(follow_up) is True

            # With a single worker, this can only be handled if that worker
            # survived the exception raised by the first task.
            assert recovered.wait(timeout=5.0)
            assert queue.is_running
        finally:
            queue.stop()

    def test_processing_tracking(self, task_queue, sample_task):
        """Test that a task is tracked as processing only while it runs."""
        started = Event()
        release = Event()

        def blocking_handler(task):
            started.set()
            # Keep the task in flight until the test has inspected the state,
            # instead of racing against a fixed sleep.
            release.wait(timeout=5.0)

        task_queue.start(blocking_handler)
        try:
            task_queue.submit(sample_task)

            # Wait for processing to start
            assert started.wait(timeout=5.0)

            # Task should be in processing set
            assert task_queue.get_processing_count() == 1
            assert task_queue.is_processing(sample_task.request_id)

            release.set()

            # Wait for the worker to finish and clear the processing set.
            assert self._wait_until(
                lambda: task_queue.get_processing_count() == 0
            )
            assert not task_queue.is_processing(sample_task.request_id)
        finally:
            release.set()  # never leave the worker blocked if an assert failed
            task_queue.stop()

    def test_multiple_workers(self, sample_task):
        """Test that a multi-worker queue processes every submitted task once."""
        queue = AsyncTaskQueue(max_size=10, worker_count=3)
        processed_ids = []
        expected_ids = [f"task-{i}" for i in range(5)]

        def handler(task):
            processed_ids.append(task.request_id)
            time.sleep(0.1)

        queue.start(handler)
        try:
            assert len(queue._workers) == 3

            # Submit multiple tasks
            for i, request_id in enumerate(expected_ids):
                task = AsyncTask(
                    request_id=request_id,
                    api_key="test-key",
                    file_path=sample_task.file_path,
                    filename=f"test-{i}.pdf",
                )
                assert queue.submit(task) is True

            # Wait for all tasks (bounded, instead of a fixed 2 s sleep)
            assert self._wait_until(
                lambda: len(processed_ids) == len(expected_ids)
            )
            assert sorted(processed_ids) == expected_ids
        finally:
            queue.stop()

    def test_graceful_shutdown(self, sample_task):
        """Test graceful shutdown waits for current task."""
        queue = AsyncTaskQueue(max_size=10, worker_count=1)
        started = Event()
        finished = Event()

        def slow_handler(task):
            started.set()
            time.sleep(0.5)
            finished.set()

        queue.start(slow_handler)
        try:
            queue.submit(sample_task)

            # Wait for processing to start
            assert started.wait(timeout=5.0)
        finally:
            # Stop should wait for the in-flight task to finish
            queue.stop(timeout=5.0)

        assert finished.is_set()

    def test_double_start(self, task_queue):
        """Test that starting twice doesn't create duplicate workers."""
        handler = MagicMock()

        task_queue.start(handler)
        try:
            assert len(task_queue._workers) == 1

            # Starting again should not add more workers
            task_queue.start(handler)
            assert len(task_queue._workers) == 1
        finally:
            task_queue.stop()
|