343 lines
11 KiB
Python
343 lines
11 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
CLI for cross-validation of invoice field extraction using LLM.
|
|
|
|
Validates documents with failed field matches by sending them to an LLM
|
|
and comparing the extraction results.
|
|
"""
|
|
|
|
import argparse
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
from shared.logging_config import setup_cli_logging
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def main():
    """Parse CLI arguments and dispatch to the requested sub-command.

    Sub-commands:
        stats     -- show failed-match statistics
        validate  -- run LLM validation on documents with failed matches
        compare   -- compare LLM results against autolabel results
        report    -- write a JSON validation report
    """
    parser = argparse.ArgumentParser(
        description='Cross-validate invoice field extraction using LLM'
    )

    subparsers = parser.add_subparsers(dest='command', help='Commands')

    # Stats command (takes no extra arguments)
    subparsers.add_parser('stats', help='Show failed match statistics')

    # Validate command
    validate_parser = subparsers.add_parser('validate', help='Validate documents with failed matches')
    validate_parser.add_argument(
        '--limit', '-l',
        type=int,
        default=10,
        help='Maximum number of documents to validate (default: 10)'
    )
    validate_parser.add_argument(
        '--provider', '-p',
        choices=['openai', 'anthropic'],
        default='openai',
        help='LLM provider to use (default: openai)'
    )
    validate_parser.add_argument(
        '--model', '-m',
        help='Model to use (default: gpt-4o for OpenAI, claude-sonnet-4-20250514 for Anthropic)'
    )
    validate_parser.add_argument(
        '--single', '-s',
        help='Validate a single document ID'
    )

    # Compare command
    compare_parser = subparsers.add_parser('compare', help='Compare validation results')
    compare_parser.add_argument(
        'document_id',
        nargs='?',
        help='Document ID to compare (or omit to show all)'
    )
    compare_parser.add_argument(
        '--limit', '-l',
        type=int,
        default=20,
        help='Maximum number of results to show (default: 20)'
    )

    # Report command
    report_parser = subparsers.add_parser('report', help='Generate validation report')
    report_parser.add_argument(
        '--output', '-o',
        default='reports/llm_validation_report.json',
        help='Output file path (default: reports/llm_validation_report.json)'
    )

    args = parser.parse_args()

    if not args.command:
        parser.print_help()
        return

    # Configure logging for CLI
    setup_cli_logging()

    # Imported lazily so `--help` does not pull in the validator's dependencies.
    from backend.validation import LLMValidator

    validator = LLMValidator()
    validator.connect()
    validator.create_validation_table()

    try:
        if args.command == 'stats':
            show_stats(validator)

        elif args.command == 'validate':
            if args.single:
                validate_single(validator, args.single, args.provider, args.model)
            else:
                validate_batch(validator, args.limit, args.provider, args.model)

        elif args.command == 'compare':
            if args.document_id:
                compare_single(validator, args.document_id)
            else:
                compare_all(validator, args.limit)

        elif args.command == 'report':
            generate_report(validator, args.output)
    finally:
        # Always release the DB connection, even if a sub-command raises.
        validator.close()
|
|
|
|
|
|
def show_stats(validator):
    """Log a summary of autolabel failed-match statistics."""
    stats = validator.get_failed_match_stats()

    divider = "=" * 50
    logger.info(divider)
    logger.info("Failed Match Statistics")
    logger.info(divider)
    logger.info("Documents with failures: %d", stats['documents_with_failures'])
    logger.info("Already validated: %d", stats['already_validated'])
    logger.info("Remaining to validate: %d", stats['remaining'])
    logger.info("Failures by field:")

    # Highest failure counts first (stable sort keeps ties in original order).
    by_count = sorted(stats['failures_by_field'].items(),
                      key=lambda kv: kv[1], reverse=True)
    for field_name, n_failed in by_count:
        logger.info("  %s: %d", field_name, n_failed)
|
|
|
|
|
|
def validate_single(validator, doc_id: str, provider: str, model: str):
    """Run LLM validation on one document and log the extraction plus a
    field-by-field comparison against the autolabel results."""
    logger.info("Validating document: %s", doc_id)
    logger.info("Provider: %s, Model: %s", provider, model or 'default')

    result = validator.validate_document(doc_id, provider, model)

    # Bail out early if the LLM call itself failed.
    if result.error:
        logger.error("ERROR: %s", result.error)
        return

    logger.info("Processing time: %.0fms", result.processing_time_ms)
    logger.info("Model used: %s", result.model_used)
    logger.info("Extracted fields:")
    logger.info("  Invoice Number: %s", result.invoice_number)
    logger.info("  Invoice Date: %s", result.invoice_date)
    logger.info("  Due Date: %s", result.invoice_due_date)
    logger.info("  OCR: %s", result.ocr_number)
    logger.info("  Bankgiro: %s", result.bankgiro)
    logger.info("  Plusgiro: %s", result.plusgiro)
    logger.info("  Amount: %s", result.amount)
    logger.info("  Org Number: %s", result.supplier_organisation_number)

    # Field-by-field comparison against the autolabel pipeline.
    logger.info("-" * 50)
    logger.info("Comparison with autolabel:")
    comparison = validator.compare_results(doc_id)
    for field, data in comparison.items():
        # Skip fields with no ground-truth CSV value.
        if not data.get('csv_value'):
            continue
        status = "[OK]" if data['agreement'] else "[FAIL]"
        auto_status = "matched" if data['autolabel_matched'] else "FAILED"
        logger.info("  %s %s:", status, field)
        logger.info("    CSV: %s", data['csv_value'])
        logger.info("    Autolabel: %s (%s)", data['autolabel_text'], auto_status)
        logger.info("    LLM: %s", data['llm_value'])
|
|
|
|
|
|
def validate_batch(validator, limit: int, provider: str, model: str):
    """Run LLM validation on up to *limit* failed-match documents and log
    a success/failure/timing summary."""
    logger.info("Validating up to %d documents with failed matches", limit)
    logger.info("Provider: %s, Model: %s", provider, model or 'default')

    results = validator.validate_batch(
        limit=limit, provider=provider, model=model, verbose=True,
    )

    # Summary counters over the returned results.
    ok_count = len([r for r in results if not r.error])
    err_count = len(results) - ok_count
    elapsed_ms = sum(r.processing_time_ms or 0 for r in results)

    banner = "=" * 50
    logger.info(banner)
    logger.info("Validation Complete")
    logger.info(banner)
    logger.info("Total: %d", len(results))
    logger.info("Success: %d", ok_count)
    logger.info("Failed: %d", err_count)
    logger.info("Total time: %.1fs", elapsed_ms / 1000)
    if ok_count > 0:
        # Average over successful validations only.
        logger.info("Avg time: %.0fms per document", elapsed_ms / ok_count)
|
|
|
|
|
|
def compare_single(validator, doc_id: str):
    """Log the LLM-vs-autolabel comparison for one validated document."""
    comparison = validator.compare_results(doc_id)

    if 'error' in comparison:
        logger.error("Error: %s", comparison['error'])
        return

    logger.info("Comparison for document: %s", doc_id)
    logger.info("=" * 60)

    for field, data in comparison.items():
        # Fields without a ground-truth CSV value are not comparable.
        if data.get('csv_value') is None:
            continue

        agreement_tag = "[OK]" if data['agreement'] else "[FAIL]"
        matched_tag = "matched" if data['autolabel_matched'] else "FAILED"

        logger.info("%s %s:", agreement_tag, field)
        logger.info("  CSV value: %s", data['csv_value'])
        logger.info("  Autolabel: %s (%s)", data['autolabel_text'], matched_tag)
        logger.info("  LLM extracted: %s", data['llm_value'])
|
|
|
|
|
|
def compare_all(validator, limit: int):
    """Log an aggregate LLM-vs-autolabel comparison table over the most
    recently validated documents (up to *limit*)."""
    conn = validator.connect()
    with conn.cursor() as cursor:
        cursor.execute("""
            SELECT document_id FROM llm_validations
            WHERE error IS NULL
            ORDER BY created_at DESC
            LIMIT %s
        """, (limit,))

        doc_ids = [row[0] for row in cursor.fetchall()]

    if not doc_ids:
        logger.info("No validated documents found.")
        return

    logger.info("Comparison Summary (%d documents)", len(doc_ids))
    logger.info("=" * 80)

    # Per-field counters aggregated across all documents.
    field_stats = {}

    for doc_id in doc_ids:
        comparison = validator.compare_results(doc_id)
        if 'error' in comparison:
            continue

        for field, data in comparison.items():
            # No ground-truth CSV value -> nothing to compare.
            if data.get('csv_value') is None:
                continue

            stats = field_stats.setdefault(field, {
                'total': 0,
                'autolabel_matched': 0,
                'llm_agrees': 0,
                'llm_correct_auto_wrong': 0,
            })
            stats['total'] += 1

            if data['autolabel_matched']:
                stats['autolabel_matched'] += 1

            if data['agreement']:
                stats['llm_agrees'] += 1

            # LLM recovered the correct value where autolabel failed.
            if data['agreement'] and not data['autolabel_matched']:
                stats['llm_correct_auto_wrong'] += 1

    logger.info("%-30s %6s %8s %10s %10s", 'Field', 'Total', 'Auto OK', 'LLM Agrees', 'LLM Found')
    logger.info("-" * 80)

    for field in sorted(field_stats):
        counts = field_stats[field]
        logger.info("%-30s %6d %8d %10d %10d", field, counts['total'],
                    counts['autolabel_matched'], counts['llm_agrees'],
                    counts['llm_correct_auto_wrong'])
|
|
|
|
|
|
def generate_report(validator, output_path: str):
    """Write a detailed JSON report of all LLM validations to *output_path*.

    The report contains one entry per validated document (LLM extraction,
    model/timing metadata, and the autolabel comparison) plus summary stats.
    """
    import json
    from datetime import datetime

    conn = validator.connect()
    with conn.cursor() as cursor:
        # All validated documents, newest first.
        cursor.execute("""
            SELECT document_id, invoice_number, invoice_date, invoice_due_date,
                   ocr_number, bankgiro, plusgiro, amount,
                   supplier_organisation_number, model_used, processing_time_ms,
                   error, created_at
            FROM llm_validations
            ORDER BY created_at DESC
        """)
        rows = cursor.fetchall()

    validations = []
    for row in rows:
        (doc_id, invoice_number, invoice_date, invoice_due_date,
         ocr_number, bankgiro, plusgiro, amount,
         supplier_organisation_number, model_used, processing_time_ms,
         error, created_at) = row

        # Comparison data only exists for validations that succeeded.
        comparison = {} if error else validator.compare_results(doc_id)

        validations.append({
            'document_id': doc_id,
            'llm_extraction': {
                'invoice_number': invoice_number,
                'invoice_date': invoice_date,
                'invoice_due_date': invoice_due_date,
                'ocr_number': ocr_number,
                'bankgiro': bankgiro,
                'plusgiro': plusgiro,
                'amount': amount,
                'supplier_organisation_number': supplier_organisation_number,
            },
            'model_used': model_used,
            'processing_time_ms': processing_time_ms,
            'error': error,
            'created_at': str(created_at) if created_at else None,
            'comparison': comparison,
        })

    # Summary statistics for the report header.
    stats = validator.get_failed_match_stats()

    report = {
        'generated_at': datetime.now().isoformat(),
        'summary': {
            'total_documents_with_failures': stats['documents_with_failures'],
            'documents_validated': stats['already_validated'],
            'failures_by_field': stats['failures_by_field'],
        },
        'validations': validations,
    }

    # Ensure the target directory exists, then write the report.
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)

    with open(target, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, ensure_ascii=False)

    logger.info("Report generated: %s", target)
    logger.info("Total validations: %d", len(validations))
|
|
|
|
|
|
# Script entry point: run the CLI when executed directly.
if __name__ == '__main__':
    main()
|