528 lines
18 KiB
Python
528 lines
18 KiB
Python
"""
|
|
Integration tests for the inference API endpoints.

Exercises the /api/v1/health, /api/v1/infer, and /api/v1/results/{filename}
endpoints end-to-end through a FastAPI test client.
|
|
"""
|
|
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch
|
|
from fastapi.testclient import TestClient
|
|
from PIL import Image
|
|
import io
|
|
|
|
from backend.web.app import create_app
|
|
from backend.web.config import ModelConfig, StorageConfig, AppConfig
|
|
|
|
|
|
@pytest.fixture
def test_app(tmp_path):
    """Build a FastAPI application configured against per-test directories.

    The upload and result directories are created under pytest's ``tmp_path``
    and handed to ``create_app`` through an ``AppConfig``.
    """
    # Create both storage directories up front (uploads first, then results).
    storage_dirs = {name: tmp_path / name for name in ("uploads", "results")}
    for directory in storage_dirs.values():
        directory.mkdir()

    # Model settings: CPU-only, fixed weights path.
    model_settings = ModelConfig(
        model_path=Path("runs/train/invoice_fields/weights/best.pt"),
        confidence_threshold=0.5,
        use_gpu=False,
        dpi=150,
    )

    # Storage settings pointing at the temporary directories created above.
    storage_settings = StorageConfig(
        upload_dir=storage_dirs["uploads"],
        result_dir=storage_dirs["results"],
        allowed_extensions=(".pdf", ".png", ".jpg", ".jpeg"),
        max_file_size_mb=50,
    )

    return create_app(AppConfig(model=model_settings, file=storage_settings))
|
|
|
|
|
|
@pytest.fixture
def client(test_app):
    """Provide a ``TestClient`` bound to the application from ``test_app``."""
    http_client = TestClient(test_app)
    return http_client
|
|
|
|
|
|
@pytest.fixture
def sample_png_bytes():
    """Return an in-memory PNG (white, 800x1200) positioned at offset 0."""
    buffer = io.BytesIO()
    Image.new('RGB', (800, 1200), color='white').save(buffer, format='PNG')
    # Rewind so a consumer reads the image from its first byte.
    buffer.seek(0)
    return buffer
|
|
|
|
|
|
class TestHealthEndpoint:
    """Checks for the /api/v1/health endpoint."""

    def test_health_check_returns_200(self, client):
        """The health endpoint answers with HTTP 200."""
        assert client.get("/api/v1/health").status_code == 200

    def test_health_check_response_structure(self, client):
        """The health payload carries the expected keys and value types."""
        payload = client.get("/api/v1/health").json()

        # Every documented key must be present.
        for key in ("status", "model_loaded", "gpu_available", "version"):
            assert key in payload

        assert payload["status"] == "healthy"

        # Both capability flags must be real booleans.
        for flag in ("model_loaded", "gpu_available"):
            assert isinstance(payload[flag], bool)
|
|
|
|
|
|
class TestInferEndpoint:
    """Test /api/v1/infer endpoint."""

    @staticmethod
    def _wire_pipeline(mock_yolo_detector, mock_pipeline):
        """Attach fresh Mock instances to the patched classes.

        Returns the pipeline instance Mock so a test can configure
        ``process_image`` on it.
        """
        pipeline_instance = Mock()
        mock_yolo_detector.return_value = Mock()
        mock_pipeline.return_value = pipeline_instance
        return pipeline_instance

    @staticmethod
    def _make_pipeline_result(fields, confidence, **extra):
        """Build a Mock pipeline result with the attributes shared by the tests.

        Args:
            fields: Value stored on the result's ``fields`` attribute.
            confidence: Value stored on the result's ``confidence`` attribute.
            **extra: Additional attributes to set (or defaults to override),
                e.g. ``cross_validation``.
        """
        attrs = {
            "fields": fields,
            "confidence": confidence,
            "success": True,
            "errors": [],
            "raw_detections": [],
            "document_id": "test123",
            "document_type": "invoice",
            "processing_time_ms": 100.0,
            "visualization_path": None,
            "detections": [],
        }
        attrs.update(extra)
        # Mock(**attrs) sets each entry as an attribute on the new mock.
        return Mock(**attrs)

    @staticmethod
    def _post_png(client, png_stream):
        """POST ``png_stream`` to /api/v1/infer as ``test.png``."""
        return client.post(
            "/api/v1/infer",
            files={"file": ("test.png", png_stream, "image/png")},
        )

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_accepts_png_file(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that /infer endpoint accepts PNG files."""
        pipeline_instance = self._wire_pipeline(mock_yolo_detector, mock_pipeline)
        pipeline_instance.process_image.return_value = self._make_pipeline_result(
            fields={"InvoiceNumber": "12345"},
            confidence={"InvoiceNumber": 0.95},
        )

        response = self._post_png(client, sample_png_bytes)

        # Verify response
        assert response.status_code == 200
        body = response.json()

        assert body["status"] == "success"
        assert "result" in body
        assert body["result"]["fields"]["InvoiceNumber"] == "12345"
        assert body["result"]["confidence"]["InvoiceNumber"] == 0.95

    def test_infer_rejects_invalid_file_type(self, client):
        """Test that /infer rejects unsupported file types."""
        invalid_file = io.BytesIO(b"fake txt content")

        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.txt", invalid_file, "text/plain")},
        )

        assert response.status_code == 400
        assert "Unsupported file type" in response.json()["detail"]

    def test_infer_requires_file(self, client):
        """Test that /infer requires a file parameter."""
        response = client.post("/api/v1/infer")

        assert response.status_code == 422  # Unprocessable Entity

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_returns_cross_validation_if_available(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that a result carrying cross-validation data is served."""
        pipeline_instance = self._wire_pipeline(mock_yolo_detector, mock_pipeline)

        # Cross-validation result attached to the pipeline result
        mock_cv = Mock()
        mock_cv.is_valid = True
        mock_cv.payment_line_ocr = "1234567"
        mock_cv.ocr_match = True

        pipeline_instance.process_image.return_value = self._make_pipeline_result(
            fields={
                "InvoiceNumber": "12345",
                "OCR": "1234567",
                "Amount": "100.00",
            },
            confidence={
                "InvoiceNumber": 0.95,
                "OCR": 0.90,
                "Amount": 0.88,
            },
            cross_validation=mock_cv,
        )

        response = self._post_png(client, sample_png_bytes)

        assert response.status_code == 200
        # The body must at least decode to a JSON object.
        assert isinstance(response.json(), dict)

        # TODO: cross_validation is not currently in the response schema.
        # This test documents that it should be added; assert on it here
        # once the schema exposes it.

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_handles_processing_errors_gracefully(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that processing errors are handled gracefully."""
        pipeline_instance = self._wire_pipeline(mock_yolo_detector, mock_pipeline)

        # Make pipeline raise an error
        pipeline_instance.process_image.side_effect = Exception("Model inference failed")

        response = self._post_png(client, sample_png_bytes)

        # Verify error handling - service catches exceptions and returns partial results
        assert response.status_code == 200
        body = response.json()
        assert body["status"] == "partial"
        assert body["result"]["success"] is False
        assert len(body["result"]["errors"]) > 0
        assert "Model inference failed" in body["result"]["errors"][0]
|
|
|
|
|
|
class TestResultsEndpoint:
    """Checks for the /api/v1/results/{filename} endpoint."""

    def test_get_result_image_returns_404_if_not_found(self, client):
        """Requesting a result that does not exist yields HTTP 404."""
        assert client.get("/api/v1/results/nonexistent.png").status_code == 404

    def test_get_result_image_returns_file_if_exists(self, client, tmp_path):
        """A result file that exists is returned as a PNG."""
        # Write a small red PNG into the temporary results directory.
        results_root = tmp_path / "results"
        results_root.mkdir(exist_ok=True)
        png_path = results_root / "test_result.png"
        Image.new('RGB', (100, 100), color='red').save(png_path)

        # Stub the storage helper so its lookup resolves to the file above.
        storage_stub = Mock()
        storage_stub.get_result_local_path.return_value = png_path

        with patch(
            "backend.web.api.v1.public.inference.get_storage_helper",
            return_value=storage_stub,
        ):
            response = client.get("/api/v1/results/test_result.png")

            assert response.status_code == 200
            assert response.headers["content-type"] == "image/png"
|
|
|
|
|
|
class TestInferenceServiceImports:
    """Guard against import errors in the inference service."""

    def test_inference_service_can_import_modules(self):
        """
        Verify that InferenceService and its dependencies are importable.

        The test fails with an ImportError for mistakes such as:
        - from ..inference.pipeline (wrong relative import)
        - from backend.web.inference (non-existent module)

        Running it surfaces broken imports before runtime.
        """
        from backend.web.services.inference import InferenceService

        # The modules that InferenceService tries to import
        from backend.pipeline.pipeline import InferencePipeline
        from backend.pipeline.yolo_detector import YOLODetector
        from shared.pdf.renderer import render_pdf_to_images

        # Reaching this point means every import above succeeded.
        imported = (
            InferencePipeline,
            YOLODetector,
            render_pdf_to_images,
            InferenceService,
        )
        for obj in imported:
            assert obj is not None
|
|
|
|
|
|
class TestBusinessFeaturesAPI:
    """Tests for business features (line items, VAT) in API."""

    @staticmethod
    def _wire_pipeline(mock_yolo_detector, mock_pipeline):
        """Attach fresh Mock instances to the patched classes.

        Returns the pipeline instance Mock so a test can configure
        ``process_image`` / ``process_pdf`` on it.
        """
        pipeline_instance = Mock()
        mock_yolo_detector.return_value = Mock()
        mock_pipeline.return_value = pipeline_instance
        return pipeline_instance

    @staticmethod
    def _make_pipeline_result(fields, confidence, processing_time_ms=100.0):
        """Build a Mock pipeline result with the attributes shared by the tests.

        Args:
            fields: Value stored on the result's ``fields`` attribute.
            confidence: Value stored on the result's ``confidence`` attribute.
            processing_time_ms: Value stored on ``processing_time_ms``.
        """
        # Mock(**kwargs) sets each keyword as an attribute on the new mock.
        return Mock(
            fields=fields,
            confidence=confidence,
            success=True,
            errors=[],
            raw_detections=[],
            document_id="test123",
            document_type="invoice",
            processing_time_ms=processing_time_ms,
            visualization_path=None,
            detections=[],
        )

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_with_extract_line_items_false_by_default(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        sample_png_bytes,
    ):
        """Test that extract_line_items defaults to False."""
        pipeline_instance = self._wire_pipeline(mock_yolo_detector, mock_pipeline)
        pipeline_instance.process_image.return_value = self._make_pipeline_result(
            fields={"InvoiceNumber": "12345"},
            confidence={"InvoiceNumber": 0.95},
        )

        # Make request without extract_line_items parameter
        response = client.post(
            "/api/v1/infer",
            files={"file": ("test.png", sample_png_bytes, "image/png")},
        )

        assert response.status_code == 200
        result = response.json()["result"]

        # Business features should be None when not requested
        assert result["line_items"] is None
        assert result["vat_summary"] is None
        assert result["vat_validation"] is None

    @patch('backend.pipeline.pipeline.InferencePipeline')
    @patch('backend.pipeline.yolo_detector.YOLODetector')
    def test_infer_with_extract_line_items_returns_business_features(
        self,
        mock_yolo_detector,
        mock_pipeline,
        client,
        tmp_path,
    ):
        """Test that extract_line_items=True returns business features."""
        pipeline_instance = self._wire_pipeline(mock_yolo_detector, mock_pipeline)

        # Create a test PDF file
        pdf_path = tmp_path / "test.pdf"
        pdf_path.write_bytes(b'%PDF-1.4 fake pdf content')

        # Mock pipeline result with business features
        mock_result = self._make_pipeline_result(
            fields={"Amount": "12500,00"},
            confidence={"Amount": 0.95},
            processing_time_ms=150.0,
        )

        # Mock line items
        mock_result.line_items = Mock()
        mock_result._line_items_to_json.return_value = {
            "items": [
                {
                    "row_index": 0,
                    "description": "Product A",
                    "quantity": "2",
                    "unit": "st",
                    "unit_price": "5000,00",
                    "amount": "10000,00",
                    "article_number": "ART001",
                    "vat_rate": "25",
                    "confidence": 0.9,
                }
            ],
            "header_row": ["Beskrivning", "Antal", "Pris", "Belopp"],
            "total_amount": "10000,00",
        }

        # Mock VAT summary
        mock_result.vat_summary = Mock()
        mock_result._vat_summary_to_json.return_value = {
            "breakdowns": [
                {
                    "rate": 25.0,
                    "base_amount": "10000,00",
                    "vat_amount": "2500,00",
                    "source": "regex",
                }
            ],
            "total_excl_vat": "10000,00",
            "total_vat": "2500,00",
            "total_incl_vat": "12500,00",
            "confidence": 0.9,
        }

        # Mock VAT validation
        mock_result.vat_validation = Mock()
        mock_result._vat_validation_to_json.return_value = {
            "is_valid": True,
            "confidence_score": 0.95,
            "math_checks": [
                {
                    "rate": 25.0,
                    "base_amount": 10000.0,
                    "expected_vat": 2500.0,
                    "actual_vat": 2500.0,
                    "is_valid": True,
                    "tolerance": 0.5,
                }
            ],
            "total_check": True,
            "line_items_vs_summary": True,
            "amount_consistency": True,
            "needs_review": False,
            "review_reasons": [],
        }

        pipeline_instance.process_pdf.return_value = mock_result

        # Make request with extract_line_items=true. The upload handle is
        # opened in a ``with`` block so it is closed once the request returns.
        with pdf_path.open("rb") as pdf_file:
            response = client.post(
                "/api/v1/infer",
                files={"file": ("test.pdf", pdf_file, "application/pdf")},
                data={"extract_line_items": "true"},
            )

        assert response.status_code == 200
        result = response.json()["result"]

        # Verify business features are included
        assert result["line_items"] is not None
        assert len(result["line_items"]["items"]) == 1
        assert result["line_items"]["items"][0]["description"] == "Product A"
        assert result["line_items"]["items"][0]["amount"] == "10000,00"

        assert result["vat_summary"] is not None
        assert len(result["vat_summary"]["breakdowns"]) == 1
        assert result["vat_summary"]["breakdowns"][0]["rate"] == 25.0
        assert result["vat_summary"]["total_incl_vat"] == "12500,00"

        assert result["vat_validation"] is not None
        assert result["vat_validation"]["is_valid"] is True
        assert result["vat_validation"]["confidence_score"] == 0.95

    def test_schema_imports_work_correctly(self):
        """Test that all business feature schemas can be imported."""
        # Names that are imported but not used below are listed on purpose:
        # the import itself is the check.
        from backend.web.schemas.inference import (
            LineItemSchema,
            LineItemsResultSchema,
            VATBreakdownSchema,
            VATSummarySchema,
            MathCheckResultSchema,
            VATValidationResultSchema,
            InferenceResult,
        )

        # Verify schemas can be instantiated
        line_item = LineItemSchema(
            row_index=0,
            description="Test",
            amount="100",
        )
        assert line_item.description == "Test"

        vat_breakdown = VATBreakdownSchema(
            rate=25.0,
            base_amount="100",
            vat_amount="25",
        )
        assert vat_breakdown.rate == 25.0

        # Verify InferenceResult includes business feature fields
        result = InferenceResult(
            document_id="test",
            success=True,
            processing_time_ms=100.0,
        )
        assert result.line_items is None
        assert result.vat_summary is None
        assert result.vat_validation is None

    def test_service_result_has_business_feature_fields(self):
        """Test that ServiceResult dataclass includes business feature fields."""
        from backend.web.services.inference import ServiceResult

        result = ServiceResult(document_id="test123")

        # Verify business feature fields exist and default to None
        assert result.line_items is None
        assert result.vat_summary is None
        assert result.vat_validation is None

        # Verify they can be set
        result.line_items = {"items": []}
        result.vat_summary = {"breakdowns": []}
        result.vat_validation = {"is_valid": True}

        assert result.line_items == {"items": []}
        assert result.vat_summary == {"breakdowns": []}
        assert result.vat_validation == {"is_valid": True}
|