WIP
This commit is contained in:
298
tests/web/test_inference_api.py
Normal file
298
tests/web/test_inference_api.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
Integration tests for inference API endpoints.
|
||||
|
||||
Tests the /api/v1/infer endpoint to ensure it works end-to-end.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from fastapi.testclient import TestClient
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from src.web.app import create_app
|
||||
from src.web.config import ModelConfig, StorageConfig, AppConfig
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_app(tmp_path):
|
||||
"""Create test FastAPI application."""
|
||||
# Setup test directories
|
||||
upload_dir = tmp_path / "uploads"
|
||||
result_dir = tmp_path / "results"
|
||||
upload_dir.mkdir()
|
||||
result_dir.mkdir()
|
||||
|
||||
# Create test config
|
||||
app_config = AppConfig(
|
||||
model=ModelConfig(
|
||||
model_path=Path("runs/train/invoice_fields/weights/best.pt"),
|
||||
confidence_threshold=0.5,
|
||||
use_gpu=False,
|
||||
dpi=150,
|
||||
),
|
||||
storage=StorageConfig(
|
||||
upload_dir=upload_dir,
|
||||
result_dir=result_dir,
|
||||
allowed_extensions={".pdf", ".png", ".jpg", ".jpeg"},
|
||||
max_file_size_mb=50,
|
||||
),
|
||||
)
|
||||
|
||||
# Create app
|
||||
app = create_app(app_config)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(test_app):
|
||||
"""Create test client."""
|
||||
return TestClient(test_app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_png_bytes():
|
||||
"""Create sample PNG image bytes."""
|
||||
img = Image.new('RGB', (800, 1200), color='white')
|
||||
img_bytes = io.BytesIO()
|
||||
img.save(img_bytes, format='PNG')
|
||||
img_bytes.seek(0)
|
||||
return img_bytes
|
||||
|
||||
|
||||
class TestHealthEndpoint:
|
||||
"""Test /api/v1/health endpoint."""
|
||||
|
||||
def test_health_check_returns_200(self, client):
|
||||
"""Test health check returns 200 OK."""
|
||||
response = client.get("/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_health_check_response_structure(self, client):
|
||||
"""Test health check response has correct structure."""
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
|
||||
assert "status" in data
|
||||
assert "model_loaded" in data
|
||||
assert "gpu_available" in data
|
||||
assert "version" in data
|
||||
|
||||
assert data["status"] == "healthy"
|
||||
assert isinstance(data["model_loaded"], bool)
|
||||
assert isinstance(data["gpu_available"], bool)
|
||||
|
||||
|
||||
class TestInferEndpoint:
|
||||
"""Test /api/v1/infer endpoint."""
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_infer_accepts_png_file(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
client,
|
||||
sample_png_bytes,
|
||||
):
|
||||
"""Test that /infer endpoint accepts PNG files."""
|
||||
# Setup mocks
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Mock pipeline result
|
||||
mock_result = Mock()
|
||||
mock_result.fields = {"InvoiceNumber": "12345"}
|
||||
mock_result.confidence = {"InvoiceNumber": 0.95}
|
||||
mock_result.success = True
|
||||
mock_result.errors = []
|
||||
mock_result.raw_detections = []
|
||||
mock_result.document_id = "test123"
|
||||
mock_result.document_type = "invoice"
|
||||
mock_result.processing_time_ms = 100.0
|
||||
mock_result.visualization_path = None
|
||||
mock_result.detections = []
|
||||
mock_pipeline_instance.process_image.return_value = mock_result
|
||||
|
||||
# Make request
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("test.png", sample_png_bytes, "image/png")},
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
assert data["status"] == "success"
|
||||
assert "result" in data
|
||||
assert data["result"]["fields"]["InvoiceNumber"] == "12345"
|
||||
assert data["result"]["confidence"]["InvoiceNumber"] == 0.95
|
||||
|
||||
def test_infer_rejects_invalid_file_type(self, client):
|
||||
"""Test that /infer rejects unsupported file types."""
|
||||
invalid_file = io.BytesIO(b"fake txt content")
|
||||
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("test.txt", invalid_file, "text/plain")},
|
||||
)
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Unsupported file type" in response.json()["detail"]
|
||||
|
||||
def test_infer_requires_file(self, client):
|
||||
"""Test that /infer requires a file parameter."""
|
||||
response = client.post("/api/v1/infer")
|
||||
|
||||
assert response.status_code == 422 # Unprocessable Entity
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_infer_returns_cross_validation_if_available(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
client,
|
||||
sample_png_bytes,
|
||||
):
|
||||
"""Test that cross-validation results are included if available."""
|
||||
# Setup mocks
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Mock pipeline result with cross-validation
|
||||
mock_result = Mock()
|
||||
mock_result.fields = {
|
||||
"InvoiceNumber": "12345",
|
||||
"OCR": "1234567",
|
||||
"Amount": "100.00",
|
||||
}
|
||||
mock_result.confidence = {
|
||||
"InvoiceNumber": 0.95,
|
||||
"OCR": 0.90,
|
||||
"Amount": 0.88,
|
||||
}
|
||||
mock_result.success = True
|
||||
mock_result.errors = []
|
||||
mock_result.raw_detections = []
|
||||
mock_result.document_id = "test123"
|
||||
mock_result.document_type = "invoice"
|
||||
mock_result.processing_time_ms = 100.0
|
||||
mock_result.visualization_path = None
|
||||
mock_result.detections = []
|
||||
|
||||
# Add cross-validation result
|
||||
mock_cv = Mock()
|
||||
mock_cv.is_valid = True
|
||||
mock_cv.payment_line_ocr = "1234567"
|
||||
mock_cv.ocr_match = True
|
||||
mock_result.cross_validation = mock_cv
|
||||
|
||||
mock_pipeline_instance.process_image.return_value = mock_result
|
||||
|
||||
# Make request
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("test.png", sample_png_bytes, "image/png")},
|
||||
)
|
||||
|
||||
# Verify response includes cross-validation
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
|
||||
# Note: cross_validation is not currently in the response schema
|
||||
# This test documents that it should be added
|
||||
|
||||
|
||||
@patch('src.inference.pipeline.InferencePipeline')
|
||||
@patch('src.inference.yolo_detector.YOLODetector')
|
||||
def test_infer_handles_processing_errors_gracefully(
|
||||
self,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
client,
|
||||
sample_png_bytes,
|
||||
):
|
||||
"""Test that processing errors are handled gracefully."""
|
||||
# Setup mocks
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Make pipeline raise an error
|
||||
mock_pipeline_instance.process_image.side_effect = Exception("Model inference failed")
|
||||
|
||||
# Make request
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("test.png", sample_png_bytes, "image/png")},
|
||||
)
|
||||
|
||||
# Verify error handling - service catches exceptions and returns partial results
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "partial"
|
||||
assert data["result"]["success"] is False
|
||||
assert len(data["result"]["errors"]) > 0
|
||||
assert "Model inference failed" in data["result"]["errors"][0]
|
||||
|
||||
|
||||
class TestResultsEndpoint:
|
||||
"""Test /api/v1/results/{filename} endpoint."""
|
||||
|
||||
def test_get_result_image_returns_404_if_not_found(self, client):
|
||||
"""Test that getting non-existent result returns 404."""
|
||||
response = client.get("/api/v1/results/nonexistent.png")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_result_image_returns_file_if_exists(self, client, test_app, tmp_path):
|
||||
"""Test that existing result file is returned."""
|
||||
# Get storage config from app
|
||||
storage_config = test_app.extra.get("storage_config")
|
||||
if not storage_config:
|
||||
pytest.skip("Storage config not available in test app")
|
||||
|
||||
# Create a test result file
|
||||
result_file = storage_config.result_dir / "test_result.png"
|
||||
img = Image.new('RGB', (100, 100), color='red')
|
||||
img.save(result_file)
|
||||
|
||||
# Request the file
|
||||
response = client.get("/api/v1/results/test_result.png")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "image/png"
|
||||
|
||||
|
||||
class TestInferenceServiceImports:
|
||||
"""Critical test to catch import errors."""
|
||||
|
||||
def test_inference_service_can_import_modules(self):
|
||||
"""
|
||||
Test that InferenceService can import its dependencies.
|
||||
|
||||
This test will fail if there are ImportError issues like:
|
||||
- from ..inference.pipeline (wrong relative import)
|
||||
- from src.web.inference (non-existent module)
|
||||
|
||||
It ensures the imports are correct before runtime.
|
||||
"""
|
||||
from src.web.services.inference import InferenceService
|
||||
|
||||
# Import the modules that InferenceService tries to import
|
||||
from src.inference.pipeline import InferencePipeline
|
||||
from src.inference.yolo_detector import YOLODetector
|
||||
from src.pdf.renderer import render_pdf_to_images
|
||||
|
||||
# If we got here, all imports work correctly
|
||||
assert InferencePipeline is not None
|
||||
assert YOLODetector is not None
|
||||
assert render_pdf_to_images is not None
|
||||
assert InferenceService is not None
|
||||
Reference in New Issue
Block a user