# Source: invoice-master-poc-v2/packages/training/training/cli/validate.py
# Last commit: f1a7bfe6b7 "WIP" — Yaojia Wang, 2026-02-07 13:56:00 +01:00
# (343 lines, 11 KiB, Python)
#!/usr/bin/env python3
"""
CLI for cross-validation of invoice field extraction using LLM.
Validates documents with failed field matches by sending them to an LLM
and comparing the extraction results.
"""
import argparse
import logging
import sys
from pathlib import Path
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
def main():
    """Entry point for the invoice LLM cross-validation CLI.

    Builds the argument parser with four sub-commands (stats, validate,
    compare, report), opens the validator's database connection, dispatches
    to the matching handler, and always closes the connection afterwards.
    """
    parser = argparse.ArgumentParser(
        description='Cross-validate invoice field extraction using LLM'
    )
    subparsers = parser.add_subparsers(dest='command', help='Commands')

    # Stats command (takes no extra arguments, so the parser object is not kept).
    subparsers.add_parser('stats', help='Show failed match statistics')

    # Validate command
    validate_parser = subparsers.add_parser('validate', help='Validate documents with failed matches')
    validate_parser.add_argument(
        '--limit', '-l',
        type=int,
        default=10,
        help='Maximum number of documents to validate (default: 10)'
    )
    validate_parser.add_argument(
        '--provider', '-p',
        choices=['openai', 'anthropic'],
        default='openai',
        help='LLM provider to use (default: openai)'
    )
    validate_parser.add_argument(
        '--model', '-m',
        help='Model to use (default: gpt-4o for OpenAI, claude-sonnet-4-20250514 for Anthropic)'
    )
    validate_parser.add_argument(
        '--single', '-s',
        help='Validate a single document ID'
    )

    # Compare command
    compare_parser = subparsers.add_parser('compare', help='Compare validation results')
    compare_parser.add_argument(
        'document_id',
        nargs='?',
        help='Document ID to compare (or omit to show all)'
    )
    compare_parser.add_argument(
        '--limit', '-l',
        type=int,
        default=20,
        help='Maximum number of results to show (default: 20)'
    )

    # Report command
    report_parser = subparsers.add_parser('report', help='Generate validation report')
    report_parser.add_argument(
        '--output', '-o',
        default='reports/llm_validation_report.json',
        help='Output file path (default: reports/llm_validation_report.json)'
    )

    args = parser.parse_args()
    if not args.command:
        parser.print_help()
        return

    # Configure logging for CLI
    setup_cli_logging()

    # Deferred import: only needed once a command actually runs.
    from backend.validation import LLMValidator
    validator = LLMValidator()
    validator.connect()
    validator.create_validation_table()
    try:
        if args.command == 'stats':
            show_stats(validator)
        elif args.command == 'validate':
            if args.single:
                validate_single(validator, args.single, args.provider, args.model)
            else:
                validate_batch(validator, args.limit, args.provider, args.model)
        elif args.command == 'compare':
            if args.document_id:
                compare_single(validator, args.document_id)
            else:
                compare_all(validator, args.limit)
        elif args.command == 'report':
            generate_report(validator, args.output)
    finally:
        # BUGFIX: close() previously ran only on the success path, so a
        # raising handler leaked the database connection.
        validator.close()
def show_stats(validator):
    """Log aggregate statistics about autolabel field-match failures."""
    summary = validator.get_failed_match_stats()
    banner = "=" * 50
    logger.info(banner)
    logger.info("Failed Match Statistics")
    logger.info(banner)
    logger.info("Documents with failures: %d", summary['documents_with_failures'])
    logger.info("Already validated: %d", summary['already_validated'])
    logger.info("Remaining to validate: %d", summary['remaining'])
    logger.info("Failures by field:")
    # Most frequent failure fields first (sort is stable, ties keep order).
    by_field = summary['failures_by_field']
    for field_name, failure_count in sorted(by_field.items(), key=lambda item: item[1], reverse=True):
        logger.info(" %s: %d", field_name, failure_count)
def validate_single(validator, doc_id: str, provider: str, model: str):
    """Run LLM validation on one document and log a field-by-field comparison."""
    logger.info("Validating document: %s", doc_id)
    logger.info("Provider: %s, Model: %s", provider, model or 'default')
    outcome = validator.validate_document(doc_id, provider, model)
    if outcome.error:
        logger.error("ERROR: %s", outcome.error)
        return
    logger.info("Processing time: %.0fms", outcome.processing_time_ms)
    logger.info("Model used: %s", outcome.model_used)
    logger.info("Extracted fields:")
    logger.info(" Invoice Number: %s", outcome.invoice_number)
    logger.info(" Invoice Date: %s", outcome.invoice_date)
    logger.info(" Due Date: %s", outcome.invoice_due_date)
    logger.info(" OCR: %s", outcome.ocr_number)
    logger.info(" Bankgiro: %s", outcome.bankgiro)
    logger.info(" Plusgiro: %s", outcome.plusgiro)
    logger.info(" Amount: %s", outcome.amount)
    logger.info(" Org Number: %s", outcome.supplier_organisation_number)

    # Side-by-side comparison against the autolabel/CSV ground truth.
    logger.info("-" * 50)
    logger.info("Comparison with autolabel:")
    for field_name, info in validator.compare_results(doc_id).items():
        if not info.get('csv_value'):
            continue
        agreement_marker = "[OK]" if info['agreement'] else "[FAIL]"
        autolabel_state = "matched" if info['autolabel_matched'] else "FAILED"
        logger.info(" %s %s:", agreement_marker, field_name)
        logger.info(" CSV: %s", info['csv_value'])
        logger.info(" Autolabel: %s (%s)", info['autolabel_text'], autolabel_state)
        logger.info(" LLM: %s", info['llm_value'])
def validate_batch(validator, limit: int, provider: str, model: str):
    """Validate up to `limit` failed-match documents and log a run summary."""
    logger.info("Validating up to %d documents with failed matches", limit)
    logger.info("Provider: %s, Model: %s", provider, model or 'default')
    outcomes = validator.validate_batch(
        limit=limit,
        provider=provider,
        model=model,
        verbose=True
    )

    # Summary statistics for the whole run.
    succeeded = len([r for r in outcomes if not r.error])
    errored = len(outcomes) - succeeded
    elapsed_ms = sum(r.processing_time_ms or 0 for r in outcomes)
    banner = "=" * 50
    logger.info(banner)
    logger.info("Validation Complete")
    logger.info(banner)
    logger.info("Total: %d", len(outcomes))
    logger.info("Success: %d", succeeded)
    logger.info("Failed: %d", errored)
    logger.info("Total time: %.1fs", elapsed_ms / 1000)
    if succeeded:
        logger.info("Avg time: %.0fms per document", elapsed_ms / succeeded)
def compare_single(validator, doc_id: str):
    """Log the LLM-vs-autolabel comparison for one validated document."""
    outcome = validator.compare_results(doc_id)
    if 'error' in outcome:
        logger.error("Error: %s", outcome['error'])
        return
    logger.info("Comparison for document: %s", doc_id)
    logger.info("=" * 60)
    for field_name, info in outcome.items():
        # Fields with no CSV ground truth carry no signal; skip them.
        if info.get('csv_value') is None:
            continue
        agreement_marker = "[OK]" if info['agreement'] else "[FAIL]"
        autolabel_state = "matched" if info['autolabel_matched'] else "FAILED"
        logger.info("%s %s:", agreement_marker, field_name)
        logger.info(" CSV value: %s", info['csv_value'])
        logger.info(" Autolabel: %s (%s)", info['autolabel_text'], autolabel_state)
        logger.info(" LLM extracted: %s", info['llm_value'])
def compare_all(validator, limit: int):
    """Log an aggregated LLM-vs-autolabel agreement table for recent validations."""
    conn = validator.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT document_id FROM llm_validations
            WHERE error IS NULL
            ORDER BY created_at DESC
            LIMIT %s
        """, (limit,))
        doc_ids = [record[0] for record in cursor.fetchall()]

    if not doc_ids:
        logger.info("No validated documents found.")
        return

    logger.info("Comparison Summary (%d documents)", len(doc_ids))
    logger.info("=" * 80)

    # Per-field tallies accumulated across all documents.
    per_field = {}
    for doc_id in doc_ids:
        outcome = validator.compare_results(doc_id)
        if 'error' in outcome:
            continue
        for field_name, info in outcome.items():
            if info.get('csv_value') is None:
                continue
            tally = per_field.setdefault(field_name, {
                'total': 0,
                'autolabel_matched': 0,
                'llm_agrees': 0,
                'llm_correct_auto_wrong': 0,
            })
            matched = bool(info['autolabel_matched'])
            agrees = bool(info['agreement'])
            tally['total'] += 1
            if matched:
                tally['autolabel_matched'] += 1
            if agrees:
                tally['llm_agrees'] += 1
            # LLM recovered the right value where the autolabeler missed it.
            if agrees and not matched:
                tally['llm_correct_auto_wrong'] += 1

    logger.info("%-30s %6s %8s %10s %10s", 'Field', 'Total', 'Auto OK', 'LLM Agrees', 'LLM Found')
    logger.info("-" * 80)
    for field_name, tally in sorted(per_field.items()):
        logger.info("%-30s %6d %8d %10d %10d", field_name, tally['total'],
                    tally['autolabel_matched'], tally['llm_agrees'],
                    tally['llm_correct_auto_wrong'])
def generate_report(validator, output_path: str):
    """Write every stored LLM validation (plus comparisons) to a JSON report.

    Creates the parent directory if needed; logs the output location and count.
    """
    import json
    from datetime import datetime

    conn = validator.connect()
    with conn.cursor() as cursor:
        # Fetch all validated documents, newest first.
        cursor.execute("""
            SELECT document_id, invoice_number, invoice_date, invoice_due_date,
                   ocr_number, bankgiro, plusgiro, amount,
                   supplier_organisation_number, model_used, processing_time_ms,
                   error, created_at
            FROM llm_validations
            ORDER BY created_at DESC
        """)
        rows = cursor.fetchall()

    validations = []
    for row in rows:
        (doc_id, inv_no, inv_date, due_date, ocr, bankgiro, plusgiro, amount,
         org_no, model_used, proc_ms, error, created_at) = row
        # Comparisons only exist for runs that completed without error.
        comparison = {} if error else validator.compare_results(doc_id)
        validations.append({
            'document_id': doc_id,
            'llm_extraction': {
                'invoice_number': inv_no,
                'invoice_date': inv_date,
                'invoice_due_date': due_date,
                'ocr_number': ocr,
                'bankgiro': bankgiro,
                'plusgiro': plusgiro,
                'amount': amount,
                'supplier_organisation_number': org_no,
            },
            'model_used': model_used,
            'processing_time_ms': proc_ms,
            'error': error,
            'created_at': str(created_at) if created_at else None,
            'comparison': comparison,
        })

    # Header-level summary derived from the failed-match statistics.
    stats = validator.get_failed_match_stats()
    report = {
        'generated_at': datetime.now().isoformat(),
        'summary': {
            'total_documents_with_failures': stats['documents_with_failures'],
            'documents_validated': stats['already_validated'],
            'failures_by_field': stats['failures_by_field'],
        },
        'validations': validations,
    }

    # Write report
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(json.dumps(report, indent=2, ensure_ascii=False),
                      encoding='utf-8')
    logger.info("Report generated: %s", target)
    logger.info("Total validations: %d", len(validations))
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()