Files
invoice-master-poc-v2/tests/web/test_inference_api.py
2026-02-03 21:28:06 +01:00

528 lines
18 KiB
Python

"""
Integration tests for inference API endpoints.
Tests the /api/v1/health, /api/v1/infer and /api/v1/results/{filename}
endpoints to ensure they work end-to-end.
"""
import pytest
from pathlib import Path
from unittest.mock import Mock, patch
from fastapi.testclient import TestClient
from PIL import Image
import io
from backend.web.app import create_app
from backend.web.config import ModelConfig, StorageConfig, AppConfig
@pytest.fixture
def test_app(tmp_path):
    """Build a FastAPI application wired to per-test temporary storage.

    Creates ``uploads`` and ``results`` directories under ``tmp_path``,
    hands them to the application through ``StorageConfig`` and returns
    whatever ``create_app`` produces for the resulting ``AppConfig``.
    """
    # Storage directories live under pytest's per-test temporary directory.
    upload_dir = tmp_path / "uploads"
    result_dir = tmp_path / "results"
    for directory in (upload_dir, result_dir):
        directory.mkdir()

    # Model settings used by the tests (GPU disabled).
    model_config = ModelConfig(
        model_path=Path("runs/train/invoice_fields/weights/best.pt"),
        confidence_threshold=0.5,
        use_gpu=False,
        dpi=150,
    )

    # Upload/result locations and upload constraints.
    storage_config = StorageConfig(
        upload_dir=upload_dir,
        result_dir=result_dir,
        allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"),
        max_file_size_mb=50,
    )

    return create_app(AppConfig(model=model_config, file=storage_config))
@pytest.fixture
def client(test_app):
    """Return a ``TestClient`` bound to the application from ``test_app``."""
    http_client = TestClient(test_app)
    return http_client
@pytest.fixture
def sample_png_bytes():
    """Return an in-memory PNG of a blank white 800x1200 RGB image.

    The returned ``io.BytesIO`` is rewound to position 0 so it can be read
    from the beginning by whoever consumes it.
    """
    buffer = io.BytesIO()
    Image.new('RGB', (800, 1200), color='white').save(buffer, format='PNG')
    buffer.seek(0)  # rewind: save() leaves the cursor at the end
    return buffer
class TestHealthEndpoint:
    """Test /api/v1/health endpoint."""

    def test_health_check_returns_200(self, client):
        """GET /api/v1/health answers with HTTP 200."""
        assert client.get("/api/v1/health").status_code == 200

    def test_health_check_response_structure(self, client):
        """The health payload exposes the expected keys and value types."""
        payload = client.get("/api/v1/health").json()

        # Every documented key must be present in the payload.
        for key in ("status", "model_loaded", "gpu_available", "version"):
            assert key in payload

        assert payload["status"] == "healthy"

        # Both capability flags must be real booleans.
        for flag in ("model_loaded", "gpu_available"):
            assert isinstance(payload[flag], bool)
class TestInferEndpoint:
    """Test /api/v1/infer endpoint.

    Tests that reach the inference code patch
    ``backend.pipeline.pipeline.InferencePipeline`` and
    ``backend.pipeline.yolo_detector.YOLODetector`` so no real model is
    needed.
    """

    @staticmethod
    def _build_pipeline_result(fields, confidence, **extra_attrs):
        """Return a ``Mock`` standing in for a successful pipeline result.

        Args:
            fields: Mapping stored on the mock as ``fields``.
            confidence: Mapping stored on the mock as ``confidence``.
            **extra_attrs: Additional attributes set on the mock after the
                common ones (e.g. ``cross_validation``).
        """
        result = Mock()
        result.fields = fields
        result.confidence = confidence
        result.success = True
        result.errors = []
        result.raw_detections = []
        result.document_id = "test123"
        result.document_type = "invoice"
        result.processing_time_ms = 100.0
        result.visualization_path = None
        result.detections = []
        for attr_name, attr_value in extra_attrs.items():
            setattr(result, attr_name, attr_value)
        return result

    @staticmethod
    def _post_png(client, png_bytes):
        """POST ``png_bytes`` to /api/v1/infer as ``test.png`` and return the response."""
        return client.post(
            "/api/v1/infer",
            files={"file": ("test.png", png_bytes, "image/png")},
        )

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_accepts_png_file(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that /infer endpoint accepts PNG files."""
        # Setup mocks: both patched classes hand back plain Mock instances.
        mock_pipeline_instance = Mock()
        mock_yolo_detector.return_value = Mock()
        mock_pipeline.return_value = mock_pipeline_instance

        # Mock pipeline result
        mock_pipeline_instance.process_image.return_value = (
            self._build_pipeline_result(
                fields={"InvoiceNumber": "12345"},
                confidence={"InvoiceNumber": 0.95},
            )
        )

        # Make request
        response = self._post_png(client, sample_png_bytes)

        # Verify response
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"
        assert "result" in data
        assert data["result"]["fields"]["InvoiceNumber"] == "12345"
        assert data["result"]["confidence"]["InvoiceNumber"] == 0.95

    def test_infer_rejects_invalid_file_type(self, client):
        """Test that /infer rejects unsupported file types."""
        invalid_file = io.BytesIO(b"fake txt content")

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.txt", invalid_file, "text/plain")},
        )

        assert response.status_code == 400
        assert "Unsupported file type" in response.json()["detail"]

    def test_infer_requires_file(self, client):
        """Test that /infer requires a file parameter."""
        response = client.post("/api/v1/infer")

        assert response.status_code == 422  # Unprocessable Entity

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_returns_cross_validation_if_available(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that a result carrying cross-validation data is still served.

        Cross-validation is not part of the response schema yet, so this
        test can only verify that the request succeeds and the usual
        ``result`` payload is returned.
        """
        # Setup mocks
        mock_pipeline_instance = Mock()
        mock_yolo_detector.return_value = Mock()
        mock_pipeline.return_value = mock_pipeline_instance

        # Cross-validation result attached to the pipeline result
        mock_cv = Mock()
        mock_cv.is_valid = True
        mock_cv.payment_line_ocr = "1234567"
        mock_cv.ocr_match = True

        # Mock pipeline result with cross-validation
        mock_pipeline_instance.process_image.return_value = (
            self._build_pipeline_result(
                fields={
                    "InvoiceNumber": "12345",
                    "OCR": "1234567",
                    "Amount": "100.00",
                },
                confidence={
                    "InvoiceNumber": 0.95,
                    "OCR": 0.90,
                    "Amount": 0.88,
                },
                cross_validation=mock_cv,
            )
        )

        # Make request
        response = self._post_png(client, sample_png_bytes)

        # Verify the request still succeeds with cross-validation present
        assert response.status_code == 200
        # Note: cross_validation is not currently in the response schema.
        # This test documents that it should be added; until then only the
        # presence of the regular result payload can be checked.
        assert "result" in response.json()

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_handles_processing_errors_gracefully(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that processing errors are handled gracefully."""
        # Setup mocks
        mock_pipeline_instance = Mock()
        mock_yolo_detector.return_value = Mock()
        mock_pipeline.return_value = mock_pipeline_instance

        # Make pipeline raise an error
        mock_pipeline_instance.process_image.side_effect = Exception(
            "Model inference failed"
        )

        # Make request
        response = self._post_png(client, sample_png_bytes)

        # Verify error handling - service catches exceptions and returns
        # partial results
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "partial"
        assert data["result"]["success"] is False
        assert len(data["result"]["errors"]) > 0
        assert "Model inference failed" in data["result"]["errors"][0]
class TestResultsEndpoint:
    """Test /api/v1/results/{filename} endpoint."""

    def test_get_result_image_returns_404_if_not_found(self, client):
        """Requesting a result file that does not exist yields HTTP 404."""
        missing = client.get("/api/v1/results/nonexistent.png")
        assert missing.status_code == 404

    def test_get_result_image_returns_file_if_exists(self, client, tmp_path):
        """An existing result file is served back as a PNG."""
        # Write a small red PNG into the temporary results directory.
        # exist_ok=True tolerates the directory already being present.
        results_root = tmp_path / "results"
        results_root.mkdir(exist_ok=True)
        png_path = results_root / "test_result.png"
        Image.new('RGB', (100, 100), color='red').save(png_path)

        # Storage helper stand-in that resolves to the file written above.
        storage_helper = Mock()
        storage_helper.get_result_local_path.return_value = png_path

        # Patch the helper factory for the duration of the request.
        with patch(
            "backend.web.api.v1.public.inference.get_storage_helper",
            return_value=storage_helper,
        ):
            response = client.get("/api/v1/results/test_result.png")

            assert response.status_code == 200
            assert response.headers["content-type"] == "image/png"
class TestInferenceServiceImports:
    """Critical test to catch import errors."""

    def test_inference_service_can_import_modules(self):
        """
        Check that InferenceService and the modules it relies on import.

        A broken import path, for example:
        - from ..inference.pipeline (wrong relative import)
        - from backend.web.inference (non-existent module)
        makes this test fail with ImportError, so the problem surfaces
        here instead of at runtime.
        """
        from backend.web.services.inference import InferenceService

        # The modules that InferenceService tries to import
        from backend.pipeline.pipeline import InferencePipeline
        from backend.pipeline.yolo_detector import YOLODetector
        from shared.pdf.renderer import render_pdf_to_images

        # Reaching this point means every import above succeeded.
        imported_objects = (
            InferencePipeline,
            YOLODetector,
            render_pdf_to_images,
            InferenceService,
        )
        for imported in imported_objects:
            assert imported is not None
class TestBusinessFeaturesAPI:
    """Tests for business features (line items, VAT) in API."""

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_with_extract_line_items_false_by_default(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that extract_line_items defaults to False."""
        # Setup mocks
        mock_detector_instance = Mock()
        mock_pipeline_instance = Mock()
        mock_yolo_detector.return_value = mock_detector_instance
        mock_pipeline.return_value = mock_pipeline_instance

        # Mock pipeline result
        mock_result = Mock()
        mock_result.fields = {"InvoiceNumber": "12345"}
        mock_result.confidence = {"InvoiceNumber": 0.95}
        mock_result.success = True
        mock_result.errors = []
        mock_result.raw_detections = []
        mock_result.document_id = "test123"
        mock_result.document_type = "invoice"
        mock_result.processing_time_ms = 100.0
        mock_result.visualization_path = None
        mock_result.detections = []
        mock_pipeline_instance.process_image.return_value = mock_result

        # Make request without extract_line_items parameter
        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        assert response.status_code == 200
        data = response.json()

        # Business features should be None when not requested
        assert data["result"]["line_items"] is None
        assert data["result"]["vat_summary"] is None
        assert data["result"]["vat_validation"] is None

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_with_extract_line_items_returns_business_features(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        tmp_path,
    ):
        """Test that extract_line_items=True returns business features."""
        # Setup mocks
        mock_detector_instance = Mock()
        mock_pipeline_instance = Mock()
        mock_yolo_detector.return_value = mock_detector_instance
        mock_pipeline.return_value = mock_pipeline_instance

        # Create a test PDF file
        pdf_path = tmp_path / "test.pdf"
        pdf_path.write_bytes(b'%PDF-1.4 fake pdf content')

        # Mock pipeline result with business features
        mock_result = Mock()
        mock_result.fields = {"Amount": "12500,00"}
        mock_result.confidence = {"Amount": 0.95}
        mock_result.success = True
        mock_result.errors = []
        mock_result.raw_detections = []
        mock_result.document_id = "test123"
        mock_result.document_type = "invoice"
        mock_result.processing_time_ms = 150.0
        mock_result.visualization_path = None
        mock_result.detections = []

        # Mock line items
        mock_result.line_items = Mock()
        mock_result._line_items_to_json.return_value = {
            "items": [
                {
                    "row_index": 0,
                    "description": "Product A",
                    "quantity": "2",
                    "unit": "st",
                    "unit_price": "5000,00",
                    "amount": "10000,00",
                    "article_number": "ART001",
                    "vat_rate": "25",
                    "confidence": 0.9,
                }
            ],
            "header_row": ["Beskrivning", "Antal", "Pris", "Belopp"],
            "total_amount": "10000,00",
        }

        # Mock VAT summary
        mock_result.vat_summary = Mock()
        mock_result._vat_summary_to_json.return_value = {
            "breakdowns": [
                {
                    "rate": 25.0,
                    "base_amount": "10000,00",
                    "vat_amount": "2500,00",
                    "source": "regex",
                }
            ],
            "total_excl_vat": "10000,00",
            "total_vat": "2500,00",
            "total_incl_vat": "12500,00",
            "confidence": 0.9,
        }

        # Mock VAT validation
        mock_result.vat_validation = Mock()
        mock_result._vat_validation_to_json.return_value = {
            "is_valid": True,
            "confidence_score": 0.95,
            "math_checks": [
                {
                    "rate": 25.0,
                    "base_amount": 10000.0,
                    "expected_vat": 2500.0,
                    "actual_vat": 2500.0,
                    "is_valid": True,
                    "tolerance": 0.5,
                }
            ],
            "total_check": True,
            "line_items_vs_summary": True,
            "amount_consistency": True,
            "needs_review": False,
            "review_reasons": [],
        }

        # PDFs are routed through process_pdf on the mocked pipeline.
        mock_pipeline_instance.process_pdf.return_value = mock_result

        # Make request with extract_line_items=true.  The upload handle is
        # opened in a with-block so it is closed once the request returns
        # instead of being left for the garbage collector.
        with pdf_path.open("rb") as pdf_file:
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", pdf_file, "application/pdf")},
                data={"extract_line_items": "true"},
            )

        assert response.status_code == 200
        data = response.json()

        # Verify business features are included
        assert data["result"]["line_items"] is not None
        assert len(data["result"]["line_items"]["items"]) == 1
        assert data["result"]["line_items"]["items"][0]["description"] == "Product A"
        assert data["result"]["line_items"]["items"][0]["amount"] == "10000,00"

        assert data["result"]["vat_summary"] is not None
        assert len(data["result"]["vat_summary"]["breakdowns"]) == 1
        assert data["result"]["vat_summary"]["breakdowns"][0]["rate"] == 25.0
        assert data["result"]["vat_summary"]["total_incl_vat"] == "12500,00"

        assert data["result"]["vat_validation"] is not None
        assert data["result"]["vat_validation"]["is_valid"] is True
        assert data["result"]["vat_validation"]["confidence_score"] == 0.95

    def test_schema_imports_work_correctly(self):
        """Test that all business feature schemas can be imported."""
        from backend.web.schemas.inference import (
            LineItemSchema,
            LineItemsResultSchema,
            VATBreakdownSchema,
            VATSummarySchema,
            MathCheckResultSchema,
            VATValidationResultSchema,
            InferenceResult,
        )

        # Verify schemas can be instantiated
        line_item = LineItemSchema(
            row_index=0,
            description="Test",
            amount="100",
        )
        assert line_item.description == "Test"

        vat_breakdown = VATBreakdownSchema(
            rate=25.0,
            base_amount="100",
            vat_amount="25",
        )
        assert vat_breakdown.rate == 25.0

        # Verify InferenceResult includes business feature fields
        result = InferenceResult(
            document_id="test",
            success=True,
            processing_time_ms=100.0,
        )
        assert result.line_items is None
        assert result.vat_summary is None
        assert result.vat_validation is None

    def test_service_result_has_business_feature_fields(self):
        """Test that ServiceResult dataclass includes business feature fields."""
        from backend.web.services.inference import ServiceResult

        result = ServiceResult(document_id="test123")

        # Verify business feature fields exist and default to None
        assert result.line_items is None
        assert result.vat_summary is None
        assert result.vat_validation is None

        # Verify they can be set
        result.line_items = {"items": []}
        result.vat_summary = {"breakdowns": []}
        result.vat_validation = {"is_valid": True}

        assert result.line_items == {"items": []}
        assert result.vat_summary == {"breakdowns": []}
        assert result.vat_validation == {"is_valid": True}