"""
Batch Upload Service
Handles ZIP file uploads with multiple PDFs and optional CSV for auto-labeling.
"""
import csv
import io
import logging
import zipfile
from datetime import datetime
from pathlib import Path
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field, field_validator
from backend.data.repositories import BatchUploadRepository
from shared.fields import CSV_TO_CLASS_MAPPING
logger = logging.getLogger(__name__)
# Security limits
MAX_COMPRESSED_SIZE = 100 * 1024 * 1024 # 100 MB
MAX_UNCOMPRESSED_SIZE = 200 * 1024 * 1024 # 200 MB
MAX_INDIVIDUAL_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
MAX_FILES_IN_ZIP = 1000
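
# Note: MAX_COMPRESSED_SIZE is not checked in this module and is assumed to be
# enforced at the upload layer. Together with the caps above it bounds the
# worst-case expansion of a maximal archive to roughly 2x (100 MB compressed ->
# at most 200 MB uncompressed), with no single member above 50 MB inflated.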


class CSVRowData(BaseModel):
    """Validated CSV row data with security checks."""

    document_id: str = Field(..., min_length=1, max_length=255, pattern=r'^[a-zA-Z0-9\-_\.]+$')
    invoice_number: str | None = Field(None, max_length=255)
    invoice_date: str | None = Field(None, max_length=50)
    invoice_due_date: str | None = Field(None, max_length=50)
    amount: str | None = Field(None, max_length=100)
    ocr: str | None = Field(None, max_length=100)
    bankgiro: str | None = Field(None, max_length=50)
    plusgiro: str | None = Field(None, max_length=50)
    customer_number: str | None = Field(None, max_length=255)
    supplier_organisation_number: str | None = Field(None, max_length=50)

    @field_validator('*', mode='before')
    @classmethod
    def strip_whitespace(cls, v):
        """Strip whitespace from all string fields."""
        if isinstance(v, str):
            return v.strip()
        return v

    @field_validator('*', mode='before')
    @classmethod
    def reject_suspicious_patterns(cls, v):
        """Reject values with suspicious characters."""
        if isinstance(v, str):
            # Reject SQL/shell metacharacters and newlines
            dangerous_chars = [';', '|', '&', '`', '$', '\n', '\r', '\x00']
            if any(char in v for char in dangerous_chars):
                raise ValueError("Suspicious characters detected in value")
        return v
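
# Illustrative usage (values are hypothetical): both validators run in
# "before" mode, so whitespace is stripped and metacharacters are rejected
# ahead of type and length checks:
#
#   CSVRowData(document_id="inv-001", amount=" 1234.56 ").amount  # "1234.56"
#   CSVRowData(document_id="inv-001", ocr="123;456")  # raises ValidationError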


class BatchUploadService:
    """Service for handling batch uploads of documents via ZIP files."""

    def __init__(self, batch_repo: BatchUploadRepository | None = None):
        """Initialize the batch upload service.

        Args:
            batch_repo: Batch upload repository (created if None)
        """
        self.batch_repo = batch_repo or BatchUploadRepository()

    def _safe_extract_filename(self, zip_path: str) -> str:
        """Safely extract filename from ZIP path, preventing path traversal.

        Args:
            zip_path: Path from ZIP file entry

        Returns:
            Safe filename

        Raises:
            ValueError: If path contains traversal attempts or is invalid
        """
        # Reject absolute paths
        if zip_path.startswith('/') or zip_path.startswith('\\'):
            raise ValueError(f"Absolute path rejected: {zip_path}")

        # Reject path traversal attempts
        if '..' in zip_path:
            raise ValueError(f"Path traversal rejected: {zip_path}")

        # Reject Windows drive letters
        if len(zip_path) >= 2 and zip_path[1] == ':':
            raise ValueError(f"Windows path rejected: {zip_path}")

        # Get only the basename
        safe_name = Path(zip_path).name
        if not safe_name or safe_name in ['.', '..']:
            raise ValueError(f"Invalid filename: {zip_path}")

        # Validate filename doesn't contain suspicious characters
        if any(char in safe_name for char in ['\\', '/', '\x00', '\n', '\r']):
            raise ValueError(f"Invalid characters in filename: {safe_name}")

        return safe_name
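
    # Illustrative behaviour (input paths are hypothetical):
    #
    #   _safe_extract_filename("invoices/2024/inv-001.pdf")  # -> "inv-001.pdf"
    #   _safe_extract_filename("../../etc/passwd")           # ValueError (traversal)
    #   _safe_extract_filename("C:\\temp\\doc.pdf")          # ValueError (drive letter)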

    def _validate_zip_safety(self, zip_file: zipfile.ZipFile) -> None:
        """Validate ZIP file against Zip bomb and other attacks.

        Args:
            zip_file: Opened ZIP file

        Raises:
            ValueError: If ZIP file is unsafe
        """
        total_uncompressed = 0
        file_count = 0

        for zip_info in zip_file.infolist():
            file_count += 1

            # Check file count limit
            if file_count > MAX_FILES_IN_ZIP:
                raise ValueError(
                    f"ZIP contains too many files (max {MAX_FILES_IN_ZIP})"
                )

            # Check individual file size
            if zip_info.file_size > MAX_INDIVIDUAL_FILE_SIZE:
                max_mb = MAX_INDIVIDUAL_FILE_SIZE / (1024 * 1024)
                raise ValueError(
                    f"File '{zip_info.filename}' exceeds {max_mb:.0f}MB limit"
                )

            # Accumulate uncompressed size
            total_uncompressed += zip_info.file_size

            # Check total uncompressed size (Zip bomb protection)
            if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
                max_mb = MAX_UNCOMPRESSED_SIZE / (1024 * 1024)
                raise ValueError(
                    f"Total uncompressed size exceeds {max_mb:.0f}MB limit"
                )

            # Validate filename safety
            try:
                self._safe_extract_filename(zip_info.filename)
            except ValueError as e:
                logger.warning(f"Rejecting malicious ZIP entry: {e}")
                raise ValueError(f"Invalid file in ZIP: {zip_info.filename}")

    def process_zip_upload(
        self,
        admin_token: str,
        zip_filename: str,
        zip_content: bytes,
        upload_source: str = "ui",
    ) -> dict[str, Any]:
        """Process a ZIP file containing PDFs and optional CSV.

        Args:
            admin_token: Admin authentication token
            zip_filename: Name of the ZIP file
            zip_content: ZIP file content as bytes
            upload_source: Upload source (ui or api)

        Returns:
            Dictionary with batch upload results
        """
        batch = self.batch_repo.create(
            admin_token=admin_token,
            filename=zip_filename,
            file_size=len(zip_content),
            upload_source=upload_source,
        )

        try:
            with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file:
                # Validate ZIP safety first
                self._validate_zip_safety(zip_file)

                result = self._process_zip_contents(
                    batch_id=batch.batch_id,
                    admin_token=admin_token,
                    zip_file=zip_file,
                )

                # Update batch upload status
                self.batch_repo.update(
                    batch_id=batch.batch_id,
                    status=result["status"],
                    total_files=result["total_files"],
                    processed_files=result["processed_files"],
                    successful_files=result["successful_files"],
                    failed_files=result["failed_files"],
                    csv_filename=result.get("csv_filename"),
                    csv_row_count=result.get("csv_row_count"),
                    completed_at=datetime.utcnow(),
                )

                return {
                    "batch_id": str(batch.batch_id),
                    **result,
                }

        except zipfile.BadZipFile as e:
            logger.error(f"Invalid ZIP file {zip_filename}: {e}")
            self.batch_repo.update(
                batch_id=batch.batch_id,
                status="failed",
                error_message="Invalid ZIP file format",
                completed_at=datetime.utcnow(),
            )
            return {
                "batch_id": str(batch.batch_id),
                "status": "failed",
                "error": "Invalid ZIP file format",
            }
        except ValueError as e:
            # Security validation errors
            logger.warning(f"ZIP validation failed for {zip_filename}: {e}")
            self.batch_repo.update(
                batch_id=batch.batch_id,
                status="failed",
                error_message="ZIP file validation failed",
                completed_at=datetime.utcnow(),
            )
            return {
                "batch_id": str(batch.batch_id),
                "status": "failed",
                "error": "ZIP file validation failed",
            }
        except Exception as e:
            logger.error(f"Error processing ZIP file {zip_filename}: {e}", exc_info=True)
            self.batch_repo.update(
                batch_id=batch.batch_id,
                status="failed",
                error_message="Processing error",
                completed_at=datetime.utcnow(),
            )
            return {
                "batch_id": str(batch.batch_id),
                "status": "failed",
                "error": "Failed to process batch upload",
            }
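
    # Illustrative call (the token value and repository wiring are assumed):
    #
    #   result = BatchUploadService().process_zip_upload(
    #       admin_token="<token>",
    #       zip_filename="batch.zip",
    #       zip_content=zip_bytes,
    #       upload_source="api",
    #   )
    #   result["status"]  # "completed", "partial", or "failed"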

    def _process_zip_contents(
        self,
        batch_id: UUID,
        admin_token: str,
        zip_file: zipfile.ZipFile,
    ) -> dict[str, Any]:
        """Process contents of ZIP file.

        Args:
            batch_id: Batch upload ID
            admin_token: Admin authentication token
            zip_file: Opened ZIP file

        Returns:
            Processing results dictionary
        """
        # Extract file lists
        pdf_files = []
        csv_file = None
        csv_data = {}

        for file_info in zip_file.filelist:
            if file_info.is_dir():
                continue

            try:
                # Use safe filename extraction
                filename = self._safe_extract_filename(file_info.filename)
            except ValueError as e:
                logger.warning(f"Skipping invalid file: {e}")
                continue

            if filename.lower().endswith('.pdf'):
                pdf_files.append(file_info)
            elif filename.lower().endswith('.csv'):
                if csv_file is None:
                    csv_file = file_info
                    # Parse CSV
                    csv_data = self._parse_csv_file(zip_file, file_info)
                else:
                    logger.warning(f"Multiple CSV files found, using first: {csv_file.filename}")

        if not pdf_files:
            return {
                "status": "failed",
                "total_files": 0,
                "processed_files": 0,
                "successful_files": 0,
                "failed_files": 0,
                "error": "No PDF files found in ZIP",
            }

        # Process each PDF file
        total_files = len(pdf_files)
        successful_files = 0
        failed_files = 0

        for pdf_info in pdf_files:
            file_record = None
            try:
                # Use safe filename extraction
                filename = self._safe_extract_filename(pdf_info.filename)

                # Create batch upload file record
                file_record = self.batch_repo.create_file(
                    batch_id=batch_id,
                    filename=filename,
                    status="processing",
                )

                # Get CSV data for this file if available
                document_id_base = Path(filename).stem
                csv_row_data = csv_data.get(document_id_base)

                # Extract PDF content
                pdf_content = zip_file.read(pdf_info.filename)

                # TODO: Save PDF file and create document
                # For now, just mark as completed
                self.batch_repo.update_file(
                    file_id=file_record.file_id,
                    status="completed",
                    csv_row_data=csv_row_data,
                    processed_at=datetime.utcnow(),
                )
                successful_files += 1

            except ValueError as e:
                # Path validation error
                logger.warning(f"Skipping invalid file: {e}")
                if file_record:
                    self.batch_repo.update_file(
                        file_id=file_record.file_id,
                        status="failed",
                        error_message="Invalid filename",
                        processed_at=datetime.utcnow(),
                    )
                failed_files += 1
            except Exception as e:
                logger.error(f"Error processing PDF: {e}", exc_info=True)
                if file_record:
                    self.batch_repo.update_file(
                        file_id=file_record.file_id,
                        status="failed",
                        error_message="Processing error",
                        processed_at=datetime.utcnow(),
                    )
                failed_files += 1

        # Determine overall status
        if failed_files == 0:
            status = "completed"
        elif successful_files == 0:
            status = "failed"
        else:
            status = "partial"

        result = {
            "status": status,
            "total_files": total_files,
            "processed_files": total_files,
            "successful_files": successful_files,
            "failed_files": failed_files,
        }
        if csv_file:
            result["csv_filename"] = Path(csv_file.filename).name
            result["csv_row_count"] = len(csv_data)

        return result
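
    # Matching rule: a PDF is paired with a CSV row when the row's DocumentId
    # equals the PDF filename without its extension, e.g. DocumentId "inv-001"
    # labels "inv-001.pdf". (Example names are illustrative.)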

    def _parse_csv_file(
        self,
        zip_file: zipfile.ZipFile,
        csv_file_info: zipfile.ZipInfo,
    ) -> dict[str, dict[str, Any]]:
        """Parse CSV file and extract field values with validation.

        Args:
            zip_file: Opened ZIP file
            csv_file_info: CSV file info

        Returns:
            Dictionary mapping DocumentId to validated field values
        """
        # Try multiple encodings
        csv_bytes = zip_file.read(csv_file_info.filename)
        encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'cp1252']
        csv_content = None
        for encoding in encodings:
            try:
                csv_content = csv_bytes.decode(encoding)
                logger.info(f"CSV decoded with {encoding}")
                break
            except UnicodeDecodeError:
                continue

        if csv_content is None:
            logger.error("Failed to decode CSV with any encoding")
            raise ValueError("Unable to decode CSV file")

        csv_reader = csv.DictReader(io.StringIO(csv_content))
        result = {}

        # Case-insensitive column mapping
        field_name_map = {
            'DocumentId': ['DocumentId', 'documentid', 'document_id'],
            'InvoiceNumber': ['InvoiceNumber', 'invoicenumber', 'invoice_number'],
            'InvoiceDate': ['InvoiceDate', 'invoicedate', 'invoice_date'],
            'InvoiceDueDate': ['InvoiceDueDate', 'invoiceduedate', 'invoice_due_date'],
            'Amount': ['Amount', 'amount'],
            'OCR': ['OCR', 'ocr'],
            'Bankgiro': ['Bankgiro', 'bankgiro'],
            'Plusgiro': ['Plusgiro', 'plusgiro'],
            'customer_number': ['customer_number', 'customernumber', 'CustomerNumber'],
            'supplier_organisation_number': ['supplier_organisation_number', 'supplierorganisationnumber'],
        }

        # Map CSV field names to model attribute names (loop-invariant, so
        # built once rather than per row)
        csv_to_model_attr = {
            'InvoiceNumber': 'invoice_number',
            'InvoiceDate': 'invoice_date',
            'InvoiceDueDate': 'invoice_due_date',
            'Amount': 'amount',
            'OCR': 'ocr',
            'Bankgiro': 'bankgiro',
            'Plusgiro': 'plusgiro',
            'customer_number': 'customer_number',
            'supplier_organisation_number': 'supplier_organisation_number',
        }
        # Reverse mapping, for converting model attributes back to CSV field names
        model_attr_to_csv = {
            'invoice_number': 'InvoiceNumber',
            'invoice_date': 'InvoiceDate',
            'invoice_due_date': 'InvoiceDueDate',
            'amount': 'Amount',
            'ocr': 'OCR',
            'bankgiro': 'Bankgiro',
            'plusgiro': 'Plusgiro',
            'customer_number': 'customer_number',
            'supplier_organisation_number': 'supplier_organisation_number',
        }

        for row_num, row in enumerate(csv_reader, start=2):
            try:
                # Create case-insensitive lookup
                row_lower = {k.lower(): v for k, v in row.items()}

                # Find DocumentId with case-insensitive matching
                document_id = None
                for variant in field_name_map['DocumentId']:
                    if variant.lower() in row_lower:
                        document_id = row_lower[variant.lower()]
                        break

                if not document_id:
                    logger.warning(f"Row {row_num}: No DocumentId found")
                    continue

                # Collect non-empty values for validation
                csv_row_dict = {'document_id': document_id}
                for csv_field in field_name_map.keys():
                    if csv_field == 'DocumentId':
                        continue
                    model_attr = csv_to_model_attr.get(csv_field)
                    if not model_attr:
                        continue
                    for variant in field_name_map[csv_field]:
                        if variant.lower() in row_lower and row_lower[variant.lower()]:
                            csv_row_dict[model_attr] = row_lower[variant.lower()]
                            break

                # Validate using Pydantic model
                validated_row = CSVRowData(**csv_row_dict)

                # Extract only the fields we care about (map back to CSV field names)
                field_values = {}
                for model_attr, csv_field in model_attr_to_csv.items():
                    value = getattr(validated_row, model_attr, None)
                    if value and csv_field in CSV_TO_CLASS_MAPPING:
                        field_values[csv_field] = value

                if field_values:
                    result[document_id] = field_values

            except Exception as e:
                logger.warning(f"Row {row_num}: Validation error - {e}")
                continue

        return result
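
    # Illustrative input/output (hypothetical rows):
    #
    #   DocumentId,InvoiceNumber,Amount
    #   inv-001,12345,1234.56
    #
    # yields {"inv-001": {"InvoiceNumber": "12345", "Amount": "1234.56"}},
    # assuming both fields appear in CSV_TO_CLASS_MAPPING.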

    def get_batch_status(self, batch_id: str) -> dict[str, Any]:
        """Get batch upload status.

        Args:
            batch_id: Batch upload ID

        Returns:
            Batch status dictionary
        """
        batch = self.batch_repo.get(UUID(batch_id))
        if not batch:
            return {
                "error": "Batch upload not found",
            }

        files = self.batch_repo.get_files(batch.batch_id)

        return {
            "batch_id": str(batch.batch_id),
            "filename": batch.filename,
            "status": batch.status,
            "total_files": batch.total_files,
            "processed_files": batch.processed_files,
            "successful_files": batch.successful_files,
            "failed_files": batch.failed_files,
            "csv_filename": batch.csv_filename,
            "csv_row_count": batch.csv_row_count,
            "error_message": batch.error_message,
            "created_at": batch.created_at.isoformat() if batch.created_at else None,
            "completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
            "files": [
                {
                    "filename": f.filename,
                    "status": f.status,
                    "error_message": f.error_message,
                    "annotation_count": f.annotation_count,
                }
                for f in files
            ],
        }
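

if __name__ == "__main__":
    # Minimal smoke test (illustrative only): exercises the pure validation
    # model, which needs no database-backed repository.
    row = CSVRowData(document_id="inv-001", amount=" 1234.56 ")
    assert row.amount == "1234.56"  # whitespace stripped by the validator

    try:
        CSVRowData(document_id="inv-001", ocr="123;456")
    except ValueError:
        print("metacharacter value rejected as expected")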