WIP
This commit is contained in:
@@ -7,10 +7,14 @@ Runs inference on new PDFs to extract invoice data.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -66,10 +70,13 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Validate model
|
||||
model_path = Path(args.model)
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model not found: {model_path}", file=sys.stderr)
|
||||
logger.error("Model not found: %s", model_path)
|
||||
sys.exit(1)
|
||||
|
||||
# Get input files
|
||||
@@ -79,16 +86,16 @@ def main():
|
||||
elif input_path.is_dir():
|
||||
pdf_files = list(input_path.glob('*.pdf'))
|
||||
else:
|
||||
print(f"Error: Input not found: {input_path}", file=sys.stderr)
|
||||
logger.error("Input not found: %s", input_path)
|
||||
sys.exit(1)
|
||||
|
||||
if not pdf_files:
|
||||
print("Error: No PDF files found", file=sys.stderr)
|
||||
logger.error("No PDF files found")
|
||||
sys.exit(1)
|
||||
|
||||
if args.verbose:
|
||||
print(f"Processing {len(pdf_files)} PDF file(s)")
|
||||
print(f"Model: {model_path}")
|
||||
logger.info("Processing %d PDF file(s)", len(pdf_files))
|
||||
logger.info("Model: %s", model_path)
|
||||
|
||||
from backend.pipeline import InferencePipeline
|
||||
|
||||
@@ -107,18 +114,18 @@ def main():
|
||||
|
||||
for pdf_path in pdf_files:
|
||||
if args.verbose:
|
||||
print(f"Processing: {pdf_path.name}")
|
||||
logger.info("Processing: %s", pdf_path.name)
|
||||
|
||||
result = pipeline.process_pdf(pdf_path)
|
||||
results.append(result.to_json())
|
||||
|
||||
if args.verbose:
|
||||
print(f" Success: {result.success}")
|
||||
print(f" Fields: {len(result.fields)}")
|
||||
logger.info(" Success: %s", result.success)
|
||||
logger.info(" Fields: %d", len(result.fields))
|
||||
if result.fallback_used:
|
||||
print(f" Fallback used: Yes")
|
||||
logger.info(" Fallback used: Yes")
|
||||
if result.errors:
|
||||
print(f" Errors: {result.errors}")
|
||||
logger.info(" Errors: %s", result.errors)
|
||||
|
||||
# Output results
|
||||
if len(results) == 1:
|
||||
@@ -132,9 +139,11 @@ def main():
|
||||
with open(args.output, 'w', encoding='utf-8') as f:
|
||||
f.write(json_output)
|
||||
if args.verbose:
|
||||
print(f"\nResults written to: {args.output}")
|
||||
logger.info("Results written to: %s", args.output)
|
||||
else:
|
||||
print(json_output)
|
||||
# Output JSON to stdout (not logged)
|
||||
sys.stdout.write(json_output)
|
||||
sys.stdout.write('\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -417,7 +417,12 @@ class InferencePipeline:
|
||||
result.errors.append(f"Business feature extraction error: {error_detail}")
|
||||
|
||||
def _merge_fields(self, result: InferenceResult) -> None:
|
||||
"""Merge extracted fields, keeping highest confidence for each field."""
|
||||
"""Merge extracted fields, keeping best candidate for each field.
|
||||
|
||||
Selection priority:
|
||||
1. Prefer candidates without validation errors
|
||||
2. Among equal validity, prefer higher confidence
|
||||
"""
|
||||
field_candidates: dict[str, list[ExtractedField]] = {}
|
||||
|
||||
for extracted in result.extracted_fields:
|
||||
@@ -430,7 +435,12 @@ class InferencePipeline:
|
||||
|
||||
# Select best candidate for each field
|
||||
for field_name, candidates in field_candidates.items():
|
||||
best = max(candidates, key=lambda x: x.confidence)
|
||||
# Sort by: (no validation error, confidence) - descending
|
||||
# This prefers candidates without errors, then by confidence
|
||||
best = max(
|
||||
candidates,
|
||||
key=lambda x: (x.validation_error is None, x.confidence)
|
||||
)
|
||||
result.fields[field_name] = best.normalized_value
|
||||
result.confidence[field_name] = best.confidence
|
||||
# Store bbox for each field (useful for payment_line and other fields)
|
||||
|
||||
@@ -7,6 +7,7 @@ the autolabel results to identify potential errors.
|
||||
|
||||
import json
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
@@ -14,6 +15,8 @@ from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
|
||||
import psycopg2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
@@ -648,7 +651,7 @@ Return ONLY the JSON object, no other text."""
|
||||
docs = self.get_documents_with_failed_matches(limit=limit)
|
||||
|
||||
if verbose:
|
||||
print(f"Found {len(docs)} documents with failed matches to validate")
|
||||
logger.info("Found %d documents with failed matches to validate", len(docs))
|
||||
|
||||
results = []
|
||||
for i, doc in enumerate(docs):
|
||||
@@ -656,16 +659,16 @@ Return ONLY the JSON object, no other text."""
|
||||
|
||||
if verbose:
|
||||
failed_fields = [f['field'] for f in doc['failed_fields']]
|
||||
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")
|
||||
logger.info("[%d/%d] Validating %s... (failed: %s)", i+1, len(docs), doc_id[:8], ', '.join(failed_fields))
|
||||
|
||||
result = self.validate_document(doc_id, provider, model)
|
||||
results.append(result)
|
||||
|
||||
if verbose:
|
||||
if result.error:
|
||||
print(f" ERROR: {result.error}")
|
||||
logger.error(" ERROR: %s", result.error)
|
||||
else:
|
||||
print(f" OK ({result.processing_time_ms:.0f}ms)")
|
||||
logger.info(" OK (%.0fms)", result.processing_time_ms)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from backend.web.schemas.admin import (
|
||||
ExportResponse,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
from shared.bbox import expand_bbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -102,12 +103,52 @@ def register_export_routes(router: APIRouter) -> None:
|
||||
dst_image.write_bytes(image_content)
|
||||
total_images += 1
|
||||
|
||||
# Get image dimensions for bbox expansion
|
||||
img_dims = storage.get_admin_image_dimensions(doc_id, page_num)
|
||||
if img_dims is None:
|
||||
# Fall back to standard A4 at 300 DPI if dimensions unavailable
|
||||
img_width, img_height = 2480, 3508
|
||||
else:
|
||||
img_width, img_height = img_dims
|
||||
|
||||
label_name = f"{doc.document_id}_page{page_num}.txt"
|
||||
label_path = export_dir / "labels" / split / label_name
|
||||
|
||||
with open(label_path, "w") as f:
|
||||
for ann in page_annotations:
|
||||
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
|
||||
# Convert normalized coords to pixel coords
|
||||
half_w = (ann.width * img_width) / 2
|
||||
half_h = (ann.height * img_height) / 2
|
||||
x0 = ann.x_center * img_width - half_w
|
||||
y0 = ann.y_center * img_height - half_h
|
||||
x1 = ann.x_center * img_width + half_w
|
||||
y1 = ann.y_center * img_height + half_h
|
||||
|
||||
# Use manual_mode for manual/imported annotations
|
||||
manual_mode = ann.source in ("manual", "imported")
|
||||
|
||||
# Apply field-specific bbox expansion
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann.class_name,
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Convert back to normalized YOLO format
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Clamp to valid range
|
||||
new_x_center = max(0, min(1, new_x_center))
|
||||
new_y_center = max(0, min(1, new_y_center))
|
||||
new_width = max(0, min(1, new_width))
|
||||
new_height = max(0, min(1, new_height))
|
||||
|
||||
line = f"{ann.class_id} {new_x_center:.6f} {new_y_center:.6f} {new_width:.6f} {new_height:.6f}\n"
|
||||
f.write(line)
|
||||
total_annotations += 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user