"""
|
|
Batch Upload Service
|
|
|
|
Handles ZIP file uploads with multiple PDFs and optional CSV for auto-labeling.
|
|
"""
|
|
|
|
import csv
|
|
import io
|
|
import logging
|
|
import zipfile
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from pydantic import BaseModel, Field, field_validator
|
|
|
|
from backend.data.repositories import BatchUploadRepository
|
|
from shared.fields import CSV_TO_CLASS_MAPPING
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Security limits
|
|
MAX_COMPRESSED_SIZE = 100 * 1024 * 1024 # 100 MB
|
|
MAX_UNCOMPRESSED_SIZE = 200 * 1024 * 1024 # 200 MB
|
|
MAX_INDIVIDUAL_FILE_SIZE = 50 * 1024 * 1024 # 50 MB
|
|
MAX_FILES_IN_ZIP = 1000
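# Note (assumption): MAX_COMPRESSED_SIZE is not enforced anywhere in this
# module; it appears intended for callers to check len(zip_content) against
# before invoking process_zip_upload. Only the per-file and total
# uncompressed limits are enforced below, in _validate_zip_safety.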


class CSVRowData(BaseModel):
    """Validated CSV row data with security checks."""

    document_id: str = Field(..., min_length=1, max_length=255, pattern=r'^[a-zA-Z0-9\-_\.]+$')
    invoice_number: str | None = Field(None, max_length=255)
    invoice_date: str | None = Field(None, max_length=50)
    invoice_due_date: str | None = Field(None, max_length=50)
    amount: str | None = Field(None, max_length=100)
    ocr: str | None = Field(None, max_length=100)
    bankgiro: str | None = Field(None, max_length=50)
    plusgiro: str | None = Field(None, max_length=50)
    customer_number: str | None = Field(None, max_length=255)
    supplier_organisation_number: str | None = Field(None, max_length=50)

    @field_validator('*', mode='before')
    @classmethod
    def strip_whitespace(cls, v):
        """Strip whitespace from all string fields."""
        if isinstance(v, str):
            return v.strip()
        return v

    @field_validator('*', mode='before')
    @classmethod
    def reject_suspicious_patterns(cls, v):
        """Reject values containing suspicious characters."""
        if isinstance(v, str):
            # Reject SQL/shell metacharacters, newlines, and NUL bytes
            dangerous_chars = [';', '|', '&', '`', '$', '\n', '\r', '\x00']
            if any(char in v for char in dangerous_chars):
                raise ValueError("Suspicious characters detected in value")
        return v
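    # Illustrative behavior (not part of the service flow): values are
    # stripped before the metacharacter screen runs, so
    #   CSVRowData(document_id="inv-001", amount=" 1 234,56 ").amount == "1 234,56"
    # while CSVRowData(document_id="inv;001") raises pydantic.ValidationError,
    # since both the document_id pattern and the screen reject ';'.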


class BatchUploadService:
    """Service for handling batch uploads of documents via ZIP files."""

    def __init__(self, batch_repo: BatchUploadRepository | None = None):
        """Initialize the batch upload service.

        Args:
            batch_repo: Batch upload repository (created if None)
        """
        self.batch_repo = batch_repo or BatchUploadRepository()

    def _safe_extract_filename(self, zip_path: str) -> str:
        """Safely extract the filename from a ZIP entry path, preventing path traversal.

        Args:
            zip_path: Path from a ZIP file entry

        Returns:
            Safe filename

        Raises:
            ValueError: If the path contains traversal attempts or is invalid
        """
        # Reject absolute paths
        if zip_path.startswith('/') or zip_path.startswith('\\'):
            raise ValueError(f"Absolute path rejected: {zip_path}")

        # Reject path traversal attempts
        if '..' in zip_path:
            raise ValueError(f"Path traversal rejected: {zip_path}")

        # Reject Windows drive letters
        if len(zip_path) >= 2 and zip_path[1] == ':':
            raise ValueError(f"Windows path rejected: {zip_path}")

        # Keep only the basename
        safe_name = Path(zip_path).name
        if not safe_name or safe_name in ['.', '..']:
            raise ValueError(f"Invalid filename: {zip_path}")

        # Validate that the filename contains no separator or control characters
        if any(char in safe_name for char in ['\\', '/', '\x00', '\n', '\r']):
            raise ValueError(f"Invalid characters in filename: {safe_name}")

        return safe_name
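    # Illustrative behavior: "invoices/2024/inv-001.pdf" -> "inv-001.pdf",
    # while "../etc/passwd", "/abs/evil.pdf", and "C:\\evil.pdf" each raise
    # ValueError before any bytes are extracted.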

    def _validate_zip_safety(self, zip_file: zipfile.ZipFile) -> None:
        """Validate a ZIP file against zip bombs and other archive attacks.

        Args:
            zip_file: Opened ZIP file

        Raises:
            ValueError: If the ZIP file is unsafe
        """
        total_uncompressed = 0
        file_count = 0

        for zip_info in zip_file.infolist():
            file_count += 1

            # Check file count limit
            if file_count > MAX_FILES_IN_ZIP:
                raise ValueError(
                    f"ZIP contains too many files (max {MAX_FILES_IN_ZIP})"
                )

            # Check individual file size
            if zip_info.file_size > MAX_INDIVIDUAL_FILE_SIZE:
                max_mb = MAX_INDIVIDUAL_FILE_SIZE / (1024 * 1024)
                raise ValueError(
                    f"File '{zip_info.filename}' exceeds {max_mb:.0f}MB limit"
                )

            # Accumulate uncompressed size
            total_uncompressed += zip_info.file_size

            # Check total uncompressed size (zip bomb protection)
            if total_uncompressed > MAX_UNCOMPRESSED_SIZE:
                max_mb = MAX_UNCOMPRESSED_SIZE / (1024 * 1024)
                raise ValueError(
                    f"Total uncompressed size exceeds {max_mb:.0f}MB limit"
                )

            # Validate filename safety
            try:
                self._safe_extract_filename(zip_info.filename)
            except ValueError as e:
                logger.warning(f"Rejecting malicious ZIP entry: {e}")
                raise ValueError(f"Invalid file in ZIP: {zip_info.filename}") from e

    def process_zip_upload(
        self,
        admin_token: str,
        zip_filename: str,
        zip_content: bytes,
        upload_source: str = "ui",
    ) -> dict[str, Any]:
        """Process a ZIP file containing PDFs and optional CSV.

        Args:
            admin_token: Admin authentication token
            zip_filename: Name of the ZIP file
            zip_content: ZIP file content as bytes
            upload_source: Upload source ("ui" or "api")

        Returns:
            Dictionary with batch upload results
        """
        batch = self.batch_repo.create(
            admin_token=admin_token,
            filename=zip_filename,
            file_size=len(zip_content),
            upload_source=upload_source,
        )

        try:
            with zipfile.ZipFile(io.BytesIO(zip_content)) as zip_file:
                # Validate ZIP safety first
                self._validate_zip_safety(zip_file)

                result = self._process_zip_contents(
                    batch_id=batch.batch_id,
                    admin_token=admin_token,
                    zip_file=zip_file,
                )

                # Update batch upload status
                self.batch_repo.update(
                    batch_id=batch.batch_id,
                    status=result["status"],
                    total_files=result["total_files"],
                    processed_files=result["processed_files"],
                    successful_files=result["successful_files"],
                    failed_files=result["failed_files"],
                    csv_filename=result.get("csv_filename"),
                    csv_row_count=result.get("csv_row_count"),
                    completed_at=datetime.utcnow(),
                )

                return {
                    "batch_id": str(batch.batch_id),
                    **result,
                }

        except zipfile.BadZipFile as e:
            logger.error(f"Invalid ZIP file {zip_filename}: {e}")
            self.batch_repo.update(
                batch_id=batch.batch_id,
                status="failed",
                error_message="Invalid ZIP file format",
                completed_at=datetime.utcnow(),
            )
            return {
                "batch_id": str(batch.batch_id),
                "status": "failed",
                "error": "Invalid ZIP file format",
            }
        except ValueError as e:
            # Security validation errors: log details, return a generic message
            logger.warning(f"ZIP validation failed for {zip_filename}: {e}")
            self.batch_repo.update(
                batch_id=batch.batch_id,
                status="failed",
                error_message="ZIP file validation failed",
                completed_at=datetime.utcnow(),
            )
            return {
                "batch_id": str(batch.batch_id),
                "status": "failed",
                "error": "ZIP file validation failed",
            }
        except Exception as e:
            logger.error(f"Error processing ZIP file {zip_filename}: {e}", exc_info=True)
            self.batch_repo.update(
                batch_id=batch.batch_id,
                status="failed",
                error_message="Processing error",
                completed_at=datetime.utcnow(),
            )
            return {
                "batch_id": str(batch.batch_id),
                "status": "failed",
                "error": "Failed to process batch upload",
            }
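    # Illustrative call (a sketch; file name and token are placeholders):
    #
    #   service = BatchUploadService()
    #   with open("invoices.zip", "rb") as fh:
    #       outcome = service.process_zip_upload(
    #           admin_token="<admin-token>",
    #           zip_filename="invoices.zip",
    #           zip_content=fh.read(),
    #           upload_source="api",
    #       )
    #   # outcome["status"] is "completed", "partial", or "failed"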

    def _process_zip_contents(
        self,
        batch_id: UUID,
        admin_token: str,
        zip_file: zipfile.ZipFile,
    ) -> dict[str, Any]:
        """Process the contents of a ZIP file.

        Args:
            batch_id: Batch upload ID
            admin_token: Admin authentication token
            zip_file: Opened ZIP file

        Returns:
            Processing results dictionary
        """
        # Extract file lists
        pdf_files = []
        csv_file = None
        csv_data = {}

        for file_info in zip_file.infolist():
            if file_info.is_dir():
                continue

            try:
                # Use safe filename extraction
                filename = self._safe_extract_filename(file_info.filename)
            except ValueError as e:
                logger.warning(f"Skipping invalid file: {e}")
                continue

            if filename.lower().endswith('.pdf'):
                pdf_files.append(file_info)
            elif filename.lower().endswith('.csv'):
                if csv_file is None:
                    csv_file = file_info
                    # Parse CSV
                    csv_data = self._parse_csv_file(zip_file, file_info)
                else:
                    logger.warning(f"Multiple CSV files found, using first: {csv_file.filename}")

        if not pdf_files:
            return {
                "status": "failed",
                "total_files": 0,
                "processed_files": 0,
                "successful_files": 0,
                "failed_files": 0,
                "error": "No PDF files found in ZIP",
            }

        # Process each PDF file
        total_files = len(pdf_files)
        successful_files = 0
        failed_files = 0

        for pdf_info in pdf_files:
            file_record = None

            try:
                # Use safe filename extraction
                filename = self._safe_extract_filename(pdf_info.filename)

                # Create batch upload file record
                file_record = self.batch_repo.create_file(
                    batch_id=batch_id,
                    filename=filename,
                    status="processing",
                )

                # Get CSV data for this file if available
                document_id_base = Path(filename).stem
                csv_row_data = csv_data.get(document_id_base)

                # Extract PDF content
                pdf_content = zip_file.read(pdf_info.filename)

                # TODO: Save PDF file and create document
                # For now, just mark as completed

                self.batch_repo.update_file(
                    file_id=file_record.file_id,
                    status="completed",
                    csv_row_data=csv_row_data,
                    processed_at=datetime.utcnow(),
                )

                successful_files += 1

            except ValueError as e:
                # Path validation error
                logger.warning(f"Skipping invalid file: {e}")
                if file_record:
                    self.batch_repo.update_file(
                        file_id=file_record.file_id,
                        status="failed",
                        error_message="Invalid filename",
                        processed_at=datetime.utcnow(),
                    )
                failed_files += 1

            except Exception as e:
                logger.error(f"Error processing PDF: {e}", exc_info=True)
                if file_record:
                    self.batch_repo.update_file(
                        file_id=file_record.file_id,
                        status="failed",
                        error_message="Processing error",
                        processed_at=datetime.utcnow(),
                    )
                failed_files += 1

        # Determine overall status
        if failed_files == 0:
            status = "completed"
        elif successful_files == 0:
            status = "failed"
        else:
            status = "partial"

        result = {
            "status": status,
            "total_files": total_files,
            "processed_files": total_files,
            "successful_files": successful_files,
            "failed_files": failed_files,
        }

        if csv_file:
            result["csv_filename"] = Path(csv_file.filename).name
            result["csv_row_count"] = len(csv_data)

        return result
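    # Matching convention (as implemented above): the CSV lookup key is the
    # PDF's stem, so "scans/inv-001.pdf" is auto-labeled by the CSV row whose
    # DocumentId is "inv-001".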

    def _parse_csv_file(
        self,
        zip_file: zipfile.ZipFile,
        csv_file_info: zipfile.ZipInfo,
    ) -> dict[str, dict[str, Any]]:
        """Parse a CSV file and extract field values with validation.

        Args:
            zip_file: Opened ZIP file
            csv_file_info: CSV file info

        Returns:
            Dictionary mapping DocumentId to validated field values
        """
        # Try multiple encodings
        csv_bytes = zip_file.read(csv_file_info.filename)
        encodings = ['utf-8-sig', 'utf-8', 'latin-1', 'cp1252']
        csv_content = None

        for encoding in encodings:
            try:
                csv_content = csv_bytes.decode(encoding)
                logger.info(f"CSV decoded with {encoding}")
                break
            except UnicodeDecodeError:
                continue

        if csv_content is None:
            logger.error("Failed to decode CSV with any encoding")
            raise ValueError("Unable to decode CSV file")

        csv_reader = csv.DictReader(io.StringIO(csv_content))
        result = {}

        # Case-insensitive column mapping
        field_name_map = {
            'DocumentId': ['DocumentId', 'documentid', 'document_id'],
            'InvoiceNumber': ['InvoiceNumber', 'invoicenumber', 'invoice_number'],
            'InvoiceDate': ['InvoiceDate', 'invoicedate', 'invoice_date'],
            'InvoiceDueDate': ['InvoiceDueDate', 'invoiceduedate', 'invoice_due_date'],
            'Amount': ['Amount', 'amount'],
            'OCR': ['OCR', 'ocr'],
            'Bankgiro': ['Bankgiro', 'bankgiro'],
            'Plusgiro': ['Plusgiro', 'plusgiro'],
            'customer_number': ['customer_number', 'customernumber', 'CustomerNumber'],
            'supplier_organisation_number': ['supplier_organisation_number', 'supplierorganisationnumber'],
        }

        # Map CSV field names to model attribute names (and back); constant
        # across rows, so built once outside the loop
        csv_to_model_attr = {
            'InvoiceNumber': 'invoice_number',
            'InvoiceDate': 'invoice_date',
            'InvoiceDueDate': 'invoice_due_date',
            'Amount': 'amount',
            'OCR': 'ocr',
            'Bankgiro': 'bankgiro',
            'Plusgiro': 'plusgiro',
            'customer_number': 'customer_number',
            'supplier_organisation_number': 'supplier_organisation_number',
        }
        model_attr_to_csv = {v: k for k, v in csv_to_model_attr.items()}

        # Data rows start at 2: row 1 is the header consumed by DictReader
        for row_num, row in enumerate(csv_reader, start=2):
            try:
                # Create case-insensitive lookup (skip unnamed extra columns)
                row_lower = {k.lower(): v for k, v in row.items() if k is not None}

                # Find DocumentId with case-insensitive matching
                document_id = None
                for variant in field_name_map['DocumentId']:
                    if variant.lower() in row_lower:
                        document_id = row_lower[variant.lower()]
                        break

                if not document_id:
                    logger.warning(f"Row {row_num}: No DocumentId found")
                    continue

                # Collect values for validation by the Pydantic model
                csv_row_dict = {'document_id': document_id}

                for csv_field in field_name_map:
                    if csv_field == 'DocumentId':
                        continue

                    model_attr = csv_to_model_attr.get(csv_field)
                    if not model_attr:
                        continue

                    for variant in field_name_map[csv_field]:
                        if variant.lower() in row_lower and row_lower[variant.lower()]:
                            csv_row_dict[model_attr] = row_lower[variant.lower()]
                            break

                # Validate
                validated_row = CSVRowData(**csv_row_dict)

                # Keep only the mapped fields, keyed by their CSV field names
                field_values = {}
                for model_attr, csv_field in model_attr_to_csv.items():
                    value = getattr(validated_row, model_attr, None)
                    if value and csv_field in CSV_TO_CLASS_MAPPING:
                        field_values[csv_field] = value

                if field_values:
                    result[document_id] = field_values

            except Exception as e:
                logger.warning(f"Row {row_num}: Validation error - {e}")
                continue

        return result
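    # Illustrative CSV accepted by the parser (header names may also appear
    # in any of the casings listed in field_name_map):
    #
    #   DocumentId,InvoiceNumber,InvoiceDate,Amount,OCR
    #   inv-001,12345,2024-03-01,"1 234,56",1234567890
    #   inv-002,12346,2024-03-02,"987,00",9876543210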

    def get_batch_status(self, batch_id: str) -> dict[str, Any]:
        """Get batch upload status.

        Args:
            batch_id: Batch upload ID

        Returns:
            Batch status dictionary
        """
        try:
            batch_uuid = UUID(batch_id)
        except ValueError:
            return {
                "error": "Invalid batch ID",
            }

        batch = self.batch_repo.get(batch_uuid)
        if not batch:
            return {
                "error": "Batch upload not found",
            }

        files = self.batch_repo.get_files(batch.batch_id)

        return {
            "batch_id": str(batch.batch_id),
            "filename": batch.filename,
            "status": batch.status,
            "total_files": batch.total_files,
            "processed_files": batch.processed_files,
            "successful_files": batch.successful_files,
            "failed_files": batch.failed_files,
            "csv_filename": batch.csv_filename,
            "csv_row_count": batch.csv_row_count,
            "error_message": batch.error_message,
            "created_at": batch.created_at.isoformat() if batch.created_at else None,
            "completed_at": batch.completed_at.isoformat() if batch.completed_at else None,
            "files": [
                {
                    "filename": f.filename,
                    "status": f.status,
                    "error_message": f.error_message,
                    "annotation_count": f.annotation_count,
                }
                for f in files
            ],
        }
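    # Illustrative polling (a sketch, continuing the process_zip_upload
    # example above):
    #
    #   status = service.get_batch_status(outcome["batch_id"])
    #   print(f"{status['status']}: {status['successful_files']}/{status['total_files']}")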
|