Files
invoice-master-poc-v2/packages/backend/backend/cli/infer.py
Yaojia Wang b602d0a340 re-structure
2026-02-01 22:55:31 +01:00

142 lines
3.4 KiB
Python

#!/usr/bin/env python3
"""
Inference CLI
Runs inference on new PDFs to extract invoice data.
"""
import argparse
import json
import sys
from pathlib import Path
from shared.config import DEFAULT_DPI
def _build_parser():
    """Create the argument parser for the inference CLI."""
    parser = argparse.ArgumentParser(
        description='Extract invoice data from PDFs using trained model'
    )
    parser.add_argument(
        '--model', '-m',
        required=True,
        help='Path to trained YOLO model (.pt file)'
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='Input PDF file or directory'
    )
    parser.add_argument(
        '--output', '-o',
        help='Output JSON file (default: stdout)'
    )
    parser.add_argument(
        '--confidence',
        type=float,
        default=0.5,
        help='Detection confidence threshold (default: 0.5)'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=DEFAULT_DPI,
        help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
    )
    parser.add_argument(
        '--no-fallback',
        action='store_true',
        help='Disable fallback OCR'
    )
    parser.add_argument(
        '--lang',
        default='en',
        help='OCR language (default: en)'
    )
    parser.add_argument(
        '--gpu',
        action='store_true',
        help='Use GPU'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )
    return parser


def _collect_pdf_files(input_path):
    """Return the PDFs named by *input_path*, or None if the path does not exist.

    A file is returned as a single-element list. A directory yields its
    direct ``*.pdf`` children (non-recursive), sorted so that batch output
    order is reproducible across runs and filesystems.
    """
    if input_path.is_file():
        return [input_path]
    if input_path.is_dir():
        return sorted(input_path.glob('*.pdf'))
    return None


def main(argv=None):
    """Run invoice-field inference on one PDF or a directory of PDFs.

    Parses the command line, validates the model and input paths, runs the
    inference pipeline on every PDF found, and emits the results as JSON:
    a single object when exactly one PDF was processed, otherwise a list.
    The JSON goes to the ``--output`` file when given, else to stdout.

    All ``--verbose`` diagnostics are written to stderr so that stdout
    carries nothing but the JSON document and stays machine-parseable.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the normal CLI behavior.

    Raises:
        SystemExit: With status 1 when the model path or input path does not
            exist, or when no PDF files are found; argparse also exits on
            invalid arguments.
    """
    args = _build_parser().parse_args(argv)

    def log(message):
        # Diagnostics must not share a stream with the JSON result.
        if args.verbose:
            print(message, file=sys.stderr)

    # Validate model
    model_path = Path(args.model)
    if not model_path.exists():
        print(f"Error: Model not found: {model_path}", file=sys.stderr)
        sys.exit(1)

    # Get input files
    input_path = Path(args.input)
    pdf_files = _collect_pdf_files(input_path)
    if pdf_files is None:
        print(f"Error: Input not found: {input_path}", file=sys.stderr)
        sys.exit(1)
    if not pdf_files:
        print("Error: No PDF files found", file=sys.stderr)
        sys.exit(1)

    log(f"Processing {len(pdf_files)} PDF file(s)")
    log(f"Model: {model_path}")

    # Imported here rather than at module top, so --help and the path
    # validation above complete before the pipeline module is loaded.
    from backend.pipeline import InferencePipeline

    # Initialize pipeline
    pipeline = InferencePipeline(
        model_path=model_path,
        confidence_threshold=args.confidence,
        ocr_lang=args.lang,
        use_gpu=args.gpu,
        dpi=args.dpi,
        enable_fallback=not args.no_fallback
    )

    # Process files
    results = []
    for pdf_path in pdf_files:
        log(f"Processing: {pdf_path.name}")
        result = pipeline.process_pdf(pdf_path)
        # NOTE(review): assumes to_json() returns a JSON-serializable object
        # rather than an already-encoded string -- confirm against the
        # pipeline's result type.
        results.append(result.to_json())
        log(f" Success: {result.success}")
        log(f" Fields: {len(result.fields)}")
        if result.fallback_used:
            log(" Fallback used: Yes")
        if result.errors:
            log(f" Errors: {result.errors}")

    # Output results: a lone result is emitted unwrapped, several as a list.
    output = results[0] if len(results) == 1 else results
    json_output = json.dumps(output, indent=2, ensure_ascii=False)

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            f.write(json_output)
        log(f"\nResults written to: {args.output}")
    else:
        print(json_output)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()