This commit is contained in:
Yaojia Wang
2026-02-07 13:56:00 +01:00
parent 0990239e9c
commit f1a7bfe6b7
16 changed files with 1121 additions and 307 deletions

View File

@@ -7,10 +7,14 @@ Runs inference on new PDFs to extract invoice data.
import argparse
import json
import logging
import sys
from pathlib import Path
from shared.config import DEFAULT_DPI
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
def main():
@@ -66,10 +70,13 @@ def main():
args = parser.parse_args()
# Configure logging for CLI
setup_cli_logging()
# Validate model
model_path = Path(args.model)
if not model_path.exists():
print(f"Error: Model not found: {model_path}", file=sys.stderr)
logger.error("Model not found: %s", model_path)
sys.exit(1)
# Get input files
@@ -79,16 +86,16 @@ def main():
elif input_path.is_dir():
pdf_files = list(input_path.glob('*.pdf'))
else:
print(f"Error: Input not found: {input_path}", file=sys.stderr)
logger.error("Input not found: %s", input_path)
sys.exit(1)
if not pdf_files:
print("Error: No PDF files found", file=sys.stderr)
logger.error("No PDF files found")
sys.exit(1)
if args.verbose:
print(f"Processing {len(pdf_files)} PDF file(s)")
print(f"Model: {model_path}")
logger.info("Processing %d PDF file(s)", len(pdf_files))
logger.info("Model: %s", model_path)
from backend.pipeline import InferencePipeline
@@ -107,18 +114,18 @@ def main():
for pdf_path in pdf_files:
if args.verbose:
print(f"Processing: {pdf_path.name}")
logger.info("Processing: %s", pdf_path.name)
result = pipeline.process_pdf(pdf_path)
results.append(result.to_json())
if args.verbose:
print(f" Success: {result.success}")
print(f" Fields: {len(result.fields)}")
logger.info(" Success: %s", result.success)
logger.info(" Fields: %d", len(result.fields))
if result.fallback_used:
print(f" Fallback used: Yes")
logger.info(" Fallback used: Yes")
if result.errors:
print(f" Errors: {result.errors}")
logger.info(" Errors: %s", result.errors)
# Output results
if len(results) == 1:
@@ -132,9 +139,11 @@ def main():
with open(args.output, 'w', encoding='utf-8') as f:
f.write(json_output)
if args.verbose:
print(f"\nResults written to: {args.output}")
logger.info("Results written to: %s", args.output)
else:
print(json_output)
# Output JSON to stdout (not logged)
sys.stdout.write(json_output)
sys.stdout.write('\n')
if __name__ == '__main__':

View File

@@ -417,7 +417,12 @@ class InferencePipeline:
result.errors.append(f"Business feature extraction error: {error_detail}")
def _merge_fields(self, result: InferenceResult) -> None:
"""Merge extracted fields, keeping highest confidence for each field."""
"""Merge extracted fields, keeping best candidate for each field.
Selection priority:
1. Prefer candidates without validation errors
2. Among equal validity, prefer higher confidence
"""
field_candidates: dict[str, list[ExtractedField]] = {}
for extracted in result.extracted_fields:
@@ -430,7 +435,12 @@ class InferencePipeline:
# Select best candidate for each field
for field_name, candidates in field_candidates.items():
best = max(candidates, key=lambda x: x.confidence)
# Sort by: (no validation error, confidence) - descending
# This prefers candidates without errors, then by confidence
best = max(
candidates,
key=lambda x: (x.validation_error is None, x.confidence)
)
result.fields[field_name] = best.normalized_value
result.confidence[field_name] = best.confidence
# Store bbox for each field (useful for payment_line and other fields)

View File

@@ -7,6 +7,7 @@ the autolabel results to identify potential errors.
import json
import base64
import logging
import os
from pathlib import Path
from typing import Optional, Dict, Any, List
@@ -14,6 +15,8 @@ from dataclasses import dataclass, asdict
from datetime import datetime
import psycopg2
logger = logging.getLogger(__name__)
from psycopg2.extras import execute_values
from shared.config import DEFAULT_DPI
@@ -648,7 +651,7 @@ Return ONLY the JSON object, no other text."""
docs = self.get_documents_with_failed_matches(limit=limit)
if verbose:
print(f"Found {len(docs)} documents with failed matches to validate")
logger.info("Found %d documents with failed matches to validate", len(docs))
results = []
for i, doc in enumerate(docs):
@@ -656,16 +659,16 @@ Return ONLY the JSON object, no other text."""
if verbose:
failed_fields = [f['field'] for f in doc['failed_fields']]
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")
logger.info("[%d/%d] Validating %s... (failed: %s)", i+1, len(docs), doc_id[:8], ', '.join(failed_fields))
result = self.validate_document(doc_id, provider, model)
results.append(result)
if verbose:
if result.error:
print(f" ERROR: {result.error}")
logger.error(" ERROR: %s", result.error)
else:
print(f" OK ({result.processing_time_ms:.0f}ms)")
logger.info(" OK (%.0fms)", result.processing_time_ms)
return results

View File

@@ -11,6 +11,7 @@ from backend.web.schemas.admin import (
ExportResponse,
)
from backend.web.schemas.common import ErrorResponse
from shared.bbox import expand_bbox
logger = logging.getLogger(__name__)
@@ -102,12 +103,52 @@ def register_export_routes(router: APIRouter) -> None:
dst_image.write_bytes(image_content)
total_images += 1
# Get image dimensions for bbox expansion
img_dims = storage.get_admin_image_dimensions(doc_id, page_num)
if img_dims is None:
# Fall back to standard A4 at 300 DPI if dimensions unavailable
img_width, img_height = 2480, 3508
else:
img_width, img_height = img_dims
label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f:
for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
# Convert normalized coords to pixel coords
half_w = (ann.width * img_width) / 2
half_h = (ann.height * img_height) / 2
x0 = ann.x_center * img_width - half_w
y0 = ann.y_center * img_height - half_h
x1 = ann.x_center * img_width + half_w
y1 = ann.y_center * img_height + half_h
# Use manual_mode for manual/imported annotations
manual_mode = ann.source in ("manual", "imported")
# Apply field-specific bbox expansion
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=ann.class_name,
manual_mode=manual_mode,
)
# Convert back to normalized YOLO format
new_x_center = (ex0 + ex1) / 2 / img_width
new_y_center = (ey0 + ey1) / 2 / img_height
new_width = (ex1 - ex0) / img_width
new_height = (ey1 - ey0) / img_height
# Clamp to valid range
new_x_center = max(0, min(1, new_x_center))
new_y_center = max(0, min(1, new_y_center))
new_width = max(0, min(1, new_width))
new_height = max(0, min(1, new_height))
line = f"{ann.class_id} {new_x_center:.6f} {new_y_center:.6f} {new_width:.6f} {new_height:.6f}\n"
f.write(line)
total_annotations += 1