Files
invoice-master-poc-v2/packages/inference/inference/web/schemas/inference.py
2026-01-27 23:58:17 +01:00

197 lines
7.1 KiB
Python

"""
API Request/Response Schemas
Pydantic models for API validation and serialization.
"""
from datetime import datetime
from enum import Enum
from pydantic import BaseModel, Field
# =============================================================================
# Enums
# =============================================================================
class AsyncStatus(str, Enum):
    """Async request status enum."""

    # Mixing in ``str`` makes every member an actual string, so a member
    # compares equal to its raw value (``AsyncStatus.PENDING == "pending"``).
    # Member order is the declaration order below.
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
# =============================================================================
# Sync API Schemas (existing)
# =============================================================================
class DetectionResult(BaseModel):
    """Single detection result."""

    # All three attributes are required (``...`` means "no default").
    field: str = Field(
        ...,
        description="Field type (e.g., invoice_number, amount)",
    )
    # Validated to lie within the closed interval [0, 1].
    confidence: float = Field(
        ...,
        ge=0,
        le=1,
        description="Detection confidence",
    )
    # NOTE(review): the description implies exactly four coordinates, but no
    # length constraint is declared here - confirm producers always send four.
    bbox: list[float] = Field(
        ...,
        description="Bounding box [x1, y1, x2, y2]",
    )
class ExtractedField(BaseModel):
    """Extracted and normalized field value."""

    # Required.
    field_name: str = Field(
        ...,
        description="Field name",
    )
    # Optional; defaults to None when omitted.
    value: str | None = Field(
        default=None,
        description="Extracted value",
    )
    # Required; validated to lie within the closed interval [0, 1].
    confidence: float = Field(
        ...,
        ge=0,
        le=1,
        description="Extraction confidence",
    )
    # Defaults to True, i.e. a value is treated as valid unless the caller
    # explicitly says otherwise.
    is_valid: bool = Field(
        default=True,
        description="Whether the value passed validation",
    )
class InferenceResult(BaseModel):
    """Complete inference result for a document."""

    # Required: document_id, success and processing_time_ms.
    # Everything else has a default.
    document_id: str = Field(
        ...,
        description="Document identifier",
    )
    success: bool = Field(
        ...,
        description="Whether inference succeeded",
    )
    # Plain string, not an enum: the two values named in the description are
    # not enforced by validation.
    document_type: str = Field(
        default="invoice",
        description="Document type: 'invoice' or 'letter'",
    )
    # Container fields use default_factory rather than a literal default.
    fields: dict[str, str | None] = Field(
        default_factory=dict,
        description="Extracted field values",
    )
    # NOTE(review): unlike DetectionResult.confidence, these per-field scores
    # carry no ge/le bounds.
    confidence: dict[str, float] = Field(
        default_factory=dict,
        description="Confidence scores per field",
    )
    detections: list[DetectionResult] = Field(
        default_factory=list,
        description="Raw YOLO detections",
    )
    # NOTE(review): AsyncResultResponse.processing_time_ms declares ge=0 but
    # this field does not - confirm whether negative values are ever produced
    # before tightening.
    processing_time_ms: float = Field(
        ...,
        description="Processing time in milliseconds",
    )
    visualization_url: str | None = Field(
        default=None,
        description="URL to visualization image",
    )
    errors: list[str] = Field(
        default_factory=list,
        description="Error messages",
    )
class InferenceResponse(BaseModel):
    """API response for inference endpoint."""

    # Free-form string; the "success"/"error" values in the description are
    # not enforced by validation.
    status: str = Field(
        ...,
        description="Response status: success or error",
    )
    message: str = Field(
        ...,
        description="Response message",
    )
    # Optional payload; defaults to None when omitted.
    result: InferenceResult | None = Field(
        default=None,
        description="Inference result",
    )
class BatchInferenceResponse(BaseModel):
    """API response for batch inference endpoint."""

    status: str = Field(
        ...,
        description="Response status",
    )
    message: str = Field(
        ...,
        description="Response message",
    )
    # Both counters are bounded below by zero, matching the ``ge=0`` already
    # declared on AsyncRequestsListResponse.total; a negative count is never
    # a meaningful value.
    total: int = Field(
        ...,
        ge=0,
        description="Total documents processed",
    )
    successful: int = Field(
        ...,
        ge=0,
        description="Number of successful extractions",
    )
    # default_factory supplies an empty list when no results are provided.
    results: list[InferenceResult] = Field(
        default_factory=list,
        description="Individual results",
    )
class HealthResponse(BaseModel):
    """Health check response."""

    # Every attribute is required; none has a default.
    status: str = Field(
        ...,
        description="Service status",
    )
    model_loaded: bool = Field(
        ...,
        description="Whether model is loaded",
    )
    gpu_available: bool = Field(
        ...,
        description="Whether GPU is available",
    )
    version: str = Field(
        ...,
        description="API version",
    )
class ErrorResponse(BaseModel):
    """Error response."""

    # Pre-filled with "error", so callers only have to supply ``message``.
    status: str = Field(
        default="error",
        description="Error status",
    )
    message: str = Field(
        ...,
        description="Error message",
    )
    # Optional; defaults to None when omitted.
    detail: str | None = Field(
        default=None,
        description="Detailed error information",
    )
# =============================================================================
# Async API Schemas
# =============================================================================
class AsyncSubmitResponse(BaseModel):
    """Response for async submit endpoint."""

    # Pre-filled with "accepted"; every other attribute is required.
    status: str = Field(
        default="accepted",
        description="Response status",
    )
    message: str = Field(
        ...,
        description="Response message",
    )
    # Typed as a plain string: the UUID format named in the description is
    # not validated here.
    request_id: str = Field(
        ...,
        description="Unique request identifier (UUID)",
    )
    # Must be zero or greater.
    estimated_wait_seconds: int = Field(
        ...,
        ge=0,
        description="Estimated wait time in seconds",
    )
    poll_url: str = Field(
        ...,
        description="URL to poll for status updates",
    )
class AsyncStatusResponse(BaseModel):
    """Response for async status endpoint."""

    # Required: request_id, status, filename and created_at.
    request_id: str = Field(
        ...,
        description="Unique request identifier",
    )
    status: AsyncStatus = Field(
        ...,
        description="Current processing status",
    )
    filename: str = Field(
        ...,
        description="Original filename",
    )
    created_at: datetime = Field(
        ...,
        description="Request creation timestamp",
    )

    # Everything below is optional and defaults to None. The descriptions tie
    # some fields to a particular status, but that pairing is not enforced by
    # any validator in this model.
    started_at: datetime | None = Field(
        default=None,
        description="Processing start timestamp",
    )
    completed_at: datetime | None = Field(
        default=None,
        description="Processing completion timestamp",
    )
    position_in_queue: int | None = Field(
        default=None,
        description="Position in queue (for pending status)",
    )
    error_message: str | None = Field(
        default=None,
        description="Error message (for failed status)",
    )
    result_url: str | None = Field(
        default=None,
        description="URL to fetch results (for completed status)",
    )
class AsyncResultResponse(BaseModel):
    """Response for async result endpoint."""

    request_id: str = Field(
        ...,
        description="Unique request identifier",
    )
    status: AsyncStatus = Field(
        ...,
        description="Processing status",
    )
    # Required and must be zero or greater.
    processing_time_ms: float = Field(
        ...,
        ge=0,
        description="Total processing time in milliseconds",
    )
    # Optional payload; defaults to None when omitted.
    result: InferenceResult | None = Field(
        default=None,
        description="Extraction result (when completed)",
    )
    # Declared here in addition to the field of the same name on
    # InferenceResult; this model does not link the two.
    visualization_url: str | None = Field(
        default=None,
        description="URL to visualization image",
    )
class AsyncRequestItem(BaseModel):
    """Single item in async requests list."""

    request_id: str = Field(
        ...,
        description="Unique request identifier",
    )
    status: AsyncStatus = Field(
        ...,
        description="Current processing status",
    )
    filename: str = Field(
        ...,
        description="Original filename",
    )
    # Must be zero or greater.
    file_size: int = Field(
        ...,
        ge=0,
        description="File size in bytes",
    )
    created_at: datetime = Field(
        ...,
        description="Request creation timestamp",
    )
    # The only optional attribute; defaults to None when omitted.
    completed_at: datetime | None = Field(
        default=None,
        description="Processing completion timestamp",
    )
class AsyncRequestsListResponse(BaseModel):
    """Response for async requests list endpoint."""

    # Pagination envelope: total and offset may be zero, limit must be at
    # least one.
    total: int = Field(
        ...,
        ge=0,
        description="Total number of requests",
    )
    limit: int = Field(
        ...,
        ge=1,
        description="Maximum items per page",
    )
    offset: int = Field(
        ...,
        ge=0,
        description="Current offset",
    )
    # default_factory supplies an empty list when no items are provided.
    requests: list[AsyncRequestItem] = Field(
        default_factory=list,
        description="List of requests",
    )
class RateLimitInfo(BaseModel):
    """Rate limit information (included in headers)."""

    # All three attributes are required.
    # NOTE(review): neither counter declares a lower bound - confirm whether
    # negative values can occur before adding one.
    limit: int = Field(
        ...,
        description="Maximum requests per minute",
    )
    remaining: int = Field(
        ...,
        description="Remaining requests in current window",
    )
    reset_at: datetime = Field(
        ...,
        description="Time when limit resets",
    )