184 lines
5.7 KiB
Python
184 lines
5.7 KiB
Python
"""
|
|
API Routes
|
|
|
|
FastAPI route definitions for the inference API.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import shutil
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING
|
|
|
|
from fastapi import APIRouter, File, HTTPException, UploadFile, status
|
|
from fastapi.responses import FileResponse
|
|
|
|
from .schemas import (
|
|
BatchInferenceResponse,
|
|
DetectionResult,
|
|
ErrorResponse,
|
|
HealthResponse,
|
|
InferenceResponse,
|
|
InferenceResult,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from .services import InferenceService
|
|
from .config import StorageConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def create_api_router(
    inference_service: "InferenceService",
    storage_config: "StorageConfig",
) -> APIRouter:
    """
    Create API router with inference endpoints.

    All routes are mounted under ``/api/v1`` and tagged ``inference``:

    - ``GET /health``: service status, model-loaded flag and GPU availability.
    - ``POST /infer``: run inference on an uploaded PDF or image file.
    - ``GET /results/{filename}``: download a visualization image.
    - ``DELETE /results/{filename}``: delete a visualization image.

    Args:
        inference_service: Inference service instance. The endpoints read its
            ``is_initialized`` / ``gpu_available`` attributes and call its
            ``process_pdf`` / ``process_image`` methods.
        storage_config: Storage configuration. The endpoints read its
            ``allowed_extensions``, ``upload_dir`` and ``result_dir``.

    Returns:
        Configured APIRouter
    """
    router = APIRouter(prefix="/api/v1", tags=["inference"])

    def _result_not_found(filename: str) -> HTTPException:
        """Build a fresh 404 error for a missing or invalid result filename."""
        return HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Result file not found: {filename}",
        )

    def _result_file_path(filename: str) -> Path:
        """
        Map a client-supplied result filename to an existing file in result_dir.

        Only a single, plain path component is accepted. Names such as ``..``,
        absolute paths, or (on Windows) ``..\\secret`` would otherwise let a
        request read or delete files outside the result directory.

        Raises:
            HTTPException: 404 if the name is not a plain filename or does not
                refer to an existing regular file (directories are rejected).
        """
        if filename in {"", ".", ".."} or Path(filename).name != filename:
            raise _result_not_found(filename)
        file_path = storage_config.result_dir / filename
        if not file_path.is_file():
            raise _result_not_found(filename)
        return file_path

    @router.get("/health", response_model=HealthResponse)
    async def health_check() -> HealthResponse:
        """Check service health status."""
        return HealthResponse(
            status="healthy",
            model_loaded=inference_service.is_initialized,
            gpu_available=inference_service.gpu_available,
            version="1.0.0",
        )

    @router.post(
        "/infer",
        response_model=InferenceResponse,
        responses={
            400: {"model": ErrorResponse, "description": "Invalid file"},
            500: {"model": ErrorResponse, "description": "Processing error"},
        },
    )
    async def infer_document(
        file: UploadFile = File(..., description="PDF or image file to process"),
    ) -> InferenceResponse:
        """
        Process a document and extract invoice fields.

        Accepts PDF or image files (PNG, JPG, JPEG).
        Returns extracted field values with confidence scores.

        The upload is written to ``upload_dir`` under a generated document ID,
        processed, and always removed afterwards.
        """
        # Validate file extension
        if not file.filename:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Filename is required",
            )

        file_ext = Path(file.filename).suffix.lower()
        if file_ext not in storage_config.allowed_extensions:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
            )

        # Generate document ID
        doc_id = str(uuid.uuid4())[:8]

        # Save uploaded file
        upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
        try:
            with open(upload_path, "wb") as f:
                shutil.copyfileobj(file.file, f)
        except Exception as e:
            logger.exception("Failed to save uploaded file: %s", e)
            # A copy that fails midway leaves a truncated file behind; the
            # cleanup in the processing `finally` below is never reached.
            upload_path.unlink(missing_ok=True)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to save uploaded file",
            ) from e

        try:
            # Process based on file type.
            # NOTE(review): process_pdf/process_image are called synchronously
            # inside an `async def` endpoint, so inference blocks the event
            # loop for its full duration. Offloading to a worker thread
            # (e.g. fastapi.concurrency.run_in_threadpool) would fix that, but
            # only if the inference service is safe to call from multiple
            # threads -- confirm before changing.
            if file_ext == ".pdf":
                service_result = inference_service.process_pdf(
                    upload_path, document_id=doc_id
                )
            else:
                service_result = inference_service.process_image(
                    upload_path, document_id=doc_id
                )

            # Build response; the visualization is served by GET /results/{filename}.
            viz_url = None
            if service_result.visualization_path:
                viz_url = f"/api/v1/results/{service_result.visualization_path.name}"

            inference_result = InferenceResult(
                document_id=service_result.document_id,
                success=service_result.success,
                fields=service_result.fields,
                confidence=service_result.confidence,
                detections=[
                    DetectionResult(**d) for d in service_result.detections
                ],
                processing_time_ms=service_result.processing_time_ms,
                visualization_url=viz_url,
                errors=service_result.errors,
            )

            return InferenceResponse(
                status="success" if service_result.success else "partial",
                message=f"Processed document {doc_id}",
                result=inference_result,
            )

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error drops.
            logger.exception("Error processing document: %s", e)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=str(e),
            ) from e

        finally:
            # Cleanup uploaded file
            upload_path.unlink(missing_ok=True)

    @router.get("/results/{filename}")
    async def get_result_image(filename: str) -> FileResponse:
        """Get visualization result image (404 if missing or not a plain filename)."""
        file_path = _result_file_path(filename)

        return FileResponse(
            path=file_path,
            media_type="image/png",
            filename=filename,
        )

    @router.delete("/results/{filename}")
    async def delete_result(filename: str) -> dict:
        """Delete a result file (404 if missing or not a plain filename)."""
        file_path = _result_file_path(filename)

        try:
            file_path.unlink()
        except FileNotFoundError:
            # Removed by a concurrent request between the check and the delete.
            raise _result_not_found(filename) from None

        return {"status": "deleted", "filename": filename}

    return router
|