WIP
This commit is contained in:
297
tests/web/test_inference_service.py
Normal file
297
tests/web/test_inference_service.py
Normal file
@@ -0,0 +1,297 @@
|
||||
"""
|
||||
Integration tests for inference service.
|
||||
|
||||
Tests the full initialization and processing flow to catch import errors.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from src.web.services.inference import InferenceService
|
||||
from src.web.config import ModelConfig, StorageConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_config(tmp_path):
|
||||
"""Create model configuration for testing."""
|
||||
return ModelConfig(
|
||||
model_path=Path("runs/train/invoice_fields/weights/best.pt"),
|
||||
confidence_threshold=0.5,
|
||||
use_gpu=False, # Use CPU for tests
|
||||
dpi=150,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def storage_config(tmp_path):
|
||||
"""Create storage configuration for testing."""
|
||||
upload_dir = tmp_path / "uploads"
|
||||
result_dir = tmp_path / "results"
|
||||
upload_dir.mkdir()
|
||||
result_dir.mkdir()
|
||||
|
||||
return StorageConfig(
|
||||
upload_dir=upload_dir,
|
||||
result_dir=result_dir,
|
||||
allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
|
||||
max_file_size_mb=50,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_image(tmp_path):
|
||||
"""Create a sample test image."""
|
||||
image_path = tmp_path / "test_invoice.png"
|
||||
img = Image.new('RGB', (800, 1200), color='white')
|
||||
img.save(image_path)
|
||||
return image_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inference_service(model_config, storage_config):
|
||||
"""Create inference service instance."""
|
||||
return InferenceService(
|
||||
model_config=model_config,
|
||||
storage_config=storage_config,
|
||||
)
|
||||
|
||||
|
||||
class TestInferenceServiceInitialization:
|
||||
"""Test inference service initialization to catch import errors."""
|
||||
|
||||
def test_service_creation(self, inference_service):
|
||||
"""Test that service can be created without errors."""
|
||||
assert inference_service is not None
|
||||
assert not inference_service.is_initialized
|
||||
|
||||
def test_gpu_available_check(self, inference_service):
|
||||
"""Test GPU availability check (should not crash)."""
|
||||
gpu_available = inference_service.gpu_available
|
||||
assert isinstance(gpu_available, bool)
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_initialize_imports_correctly(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
):
|
||||
"""
|
||||
Test that initialize() imports modules correctly.
|
||||
|
||||
This test ensures that the import statements in initialize()
|
||||
use correct paths and don't fail with ImportError.
|
||||
"""
|
||||
# Mock the constructors to avoid actually loading models
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Initialize should not raise ImportError
|
||||
inference_service.initialize()
|
||||
|
||||
# Verify initialization succeeded
|
||||
assert inference_service.is_initialized
|
||||
|
||||
# Verify imports were called with correct parameters
|
||||
mock_yolo_detector.assert_called_once()
|
||||
mock_pipeline.assert_called_once()
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_initialize_sets_up_pipeline(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
model_config,
|
||||
):
|
||||
"""Test that initialize sets up pipeline with correct config."""
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
inference_service.initialize()
|
||||
|
||||
# Check YOLO detector was initialized correctly
|
||||
mock_yolo_detector.assert_called_once_with(
|
||||
str(model_config.model_path),
|
||||
confidence_threshold=model_config.confidence_threshold,
|
||||
device="cpu", # use_gpu=False in fixture
|
||||
)
|
||||
|
||||
# Check pipeline was initialized correctly
|
||||
mock_pipeline.assert_called_once_with(
|
||||
model_path=str(model_config.model_path),
|
||||
confidence_threshold=model_config.confidence_threshold,
|
||||
use_gpu=False,
|
||||
dpi=150,
|
||||
enable_fallback=True,
|
||||
)
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_initialize_idempotent(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
):
|
||||
"""Test that calling initialize() multiple times is safe."""
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Call initialize twice
|
||||
inference_service.initialize()
|
||||
inference_service.initialize()
|
||||
|
||||
# Should only be called once due to is_initialized check
|
||||
assert mock_yolo_detector.call_count == 1
|
||||
assert mock_pipeline.call_count == 1
|
||||
|
||||
|
||||
class TestInferenceServiceProcessing:
|
||||
"""Test inference processing methods."""
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('ultralytics.YOLO')
|
||||
def test_process_image_basic_flow(
|
||||
self,
|
||||
mock_yolo_class,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
sample_image,
|
||||
):
|
||||
"""Test basic image processing flow."""
|
||||
# Setup mocks
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Mock pipeline result
|
||||
mock_result = Mock()
|
||||
mock_result.fields = {"InvoiceNumber": "12345"}
|
||||
mock_result.confidence = {"InvoiceNumber": 0.95}
|
||||
mock_result.success = True
|
||||
mock_result.errors = []
|
||||
mock_result.raw_detections = []
|
||||
mock_pipeline_instance.process_image.return_value = mock_result
|
||||
|
||||
# Process image
|
||||
result = inference_service.process_image(sample_image)
|
||||
|
||||
# Verify result
|
||||
assert result.success
|
||||
assert result.fields == {"InvoiceNumber": "12345"}
|
||||
assert result.confidence == {"InvoiceNumber": 0.95}
|
||||
assert result.processing_time_ms > 0
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_process_image_handles_errors(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
sample_image,
|
||||
):
|
||||
"""Test that processing errors are handled gracefully."""
|
||||
# Setup mocks
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Make pipeline raise an error
|
||||
mock_pipeline_instance.process_image.side_effect = Exception("Test error")
|
||||
|
||||
# Process should not crash
|
||||
result = inference_service.process_image(sample_image)
|
||||
|
||||
# Verify error handling
|
||||
assert not result.success
|
||||
assert len(result.errors) > 0
|
||||
assert "Test error" in result.errors[0]
|
||||
|
||||
|
||||
class TestInferenceServicePDFRendering:
|
||||
"""Test PDF rendering imports."""
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
@patch('src.pdf.renderer.render_pdf_to_images')
|
||||
@patch('ultralytics.YOLO')
|
||||
def test_pdf_visualization_imports_correctly(
|
||||
self,
|
||||
mock_yolo_class,
|
||||
mock_render_pdf,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
inference_service,
|
||||
tmp_path,
|
||||
):
|
||||
"""
|
||||
Test that _save_pdf_visualization imports render_pdf_to_images correctly.
|
||||
|
||||
This catches the import error we had with:
|
||||
from ..pdf.renderer (wrong) vs from src.pdf.renderer (correct)
|
||||
"""
|
||||
# Setup mocks
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Create a fake PDF path
|
||||
pdf_path = tmp_path / "test.pdf"
|
||||
pdf_path.touch()
|
||||
|
||||
# Mock render_pdf_to_images to return an image
|
||||
image_bytes = io.BytesIO()
|
||||
img = Image.new('RGB', (800, 1200), color='white')
|
||||
img.save(image_bytes, format='PNG')
|
||||
mock_render_pdf.return_value = [(1, image_bytes.getvalue())]
|
||||
|
||||
# Mock YOLO
|
||||
mock_model_instance = Mock()
|
||||
mock_result = Mock()
|
||||
mock_result.save = Mock()
|
||||
mock_model_instance.predict.return_value = [mock_result]
|
||||
mock_yolo_class.return_value = mock_model_instance
|
||||
|
||||
# This should not raise ImportError
|
||||
result_path = inference_service._save_pdf_visualization(pdf_path, "test123")
|
||||
|
||||
# Verify import was successful
|
||||
mock_render_pdf.assert_called_once()
|
||||
assert result_path is not None
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not Path("runs/train/invoice_fields/weights/best.pt").exists(),
|
||||
reason="Model file not available"
|
||||
)
|
||||
class TestInferenceServiceRealModel:
|
||||
"""Integration tests with real model (skip if model not available)."""
|
||||
|
||||
def test_real_initialization(self, model_config, storage_config):
|
||||
"""Test real initialization with actual model."""
|
||||
service = InferenceService(model_config, storage_config)
|
||||
|
||||
# This should work with the real imports
|
||||
service.initialize()
|
||||
|
||||
assert service.is_initialized
|
||||
assert service._pipeline is not None
|
||||
assert service._detector is not None
|
||||
Reference in New Issue
Block a user