WIP
This commit is contained in:
@@ -9,6 +9,7 @@ Now reads from PostgreSQL database instead of JSONL files.
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
@@ -16,6 +17,9 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from shared.config import get_db_connection_string
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from shared.normalize import normalize_field
|
||||
from shared.matcher import FieldMatcher
|
||||
@@ -104,7 +108,7 @@ class LabelAnalyzer:
|
||||
for row in reader:
|
||||
doc_id = row['DocumentId']
|
||||
self.csv_data[doc_id] = row
|
||||
print(f"Loaded {len(self.csv_data)} records from CSV")
|
||||
logger.info("Loaded %d records from CSV", len(self.csv_data))
|
||||
|
||||
def load_labels(self):
|
||||
"""Load all label files from dataset."""
|
||||
@@ -150,12 +154,12 @@ class LabelAnalyzer:
|
||||
for doc in self.label_data.values()
|
||||
for labels in doc['pages'].values()
|
||||
)
|
||||
print(f"Loaded labels for {total_docs} documents ({total_labels} total labels)")
|
||||
logger.info("Loaded labels for %d documents (%d total labels)", total_docs, total_labels)
|
||||
|
||||
def load_report(self):
|
||||
"""Load autolabel report from database."""
|
||||
if not self.db:
|
||||
print("Database not configured, skipping report loading")
|
||||
logger.info("Database not configured, skipping report loading")
|
||||
return
|
||||
|
||||
# Get document IDs from CSV to query
|
||||
@@ -175,7 +179,7 @@ class LabelAnalyzer:
|
||||
self.report_data[doc_id] = doc
|
||||
loaded += 1
|
||||
|
||||
print(f"Loaded {loaded} autolabel reports from database")
|
||||
logger.info("Loaded %d autolabel reports from database", loaded)
|
||||
|
||||
def analyze_document(self, doc_id: str, skip_missing_pdf: bool = True) -> Optional[DocumentAnalysis]:
|
||||
"""Analyze a single document."""
|
||||
@@ -373,7 +377,7 @@ class LabelAnalyzer:
|
||||
break
|
||||
|
||||
if skipped > 0:
|
||||
print(f"Skipped {skipped} documents without PDF files")
|
||||
logger.info("Skipped %d documents without PDF files", skipped)
|
||||
|
||||
return results
|
||||
|
||||
@@ -447,7 +451,7 @@ class LabelAnalyzer:
|
||||
with open(output, 'w', encoding='utf-8') as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nReport saved to: {output}")
|
||||
logger.info("Report saved to: %s", output)
|
||||
|
||||
return report
|
||||
|
||||
@@ -456,52 +460,52 @@ def print_summary(report: dict):
|
||||
"""Print summary to console."""
|
||||
summary = report['summary']
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("LABEL ANALYSIS SUMMARY")
|
||||
print("=" * 60)
|
||||
logger.info("=" * 60)
|
||||
logger.info("LABEL ANALYSIS SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
|
||||
print(f"\nDocuments:")
|
||||
print(f" Total: {summary['total_documents']}")
|
||||
print(f" With issues: {summary['documents_with_issues']} ({summary['issue_rate']})")
|
||||
logger.info("Documents:")
|
||||
logger.info(" Total: %d", summary['total_documents'])
|
||||
logger.info(" With issues: %d (%s)", summary['documents_with_issues'], summary['issue_rate'])
|
||||
|
||||
print(f"\nFields:")
|
||||
print(f" Expected: {summary['total_expected_fields']}")
|
||||
print(f" Labeled: {summary['total_labeled_fields']} ({summary['label_coverage']})")
|
||||
print(f" Missing: {summary['missing_labels']}")
|
||||
print(f" Extra: {summary['extra_labels']}")
|
||||
logger.info("Fields:")
|
||||
logger.info(" Expected: %d", summary['total_expected_fields'])
|
||||
logger.info(" Labeled: %d (%s)", summary['total_labeled_fields'], summary['label_coverage'])
|
||||
logger.info(" Missing: %d", summary['missing_labels'])
|
||||
logger.info(" Extra: %d", summary['extra_labels'])
|
||||
|
||||
print(f"\nFailure Reasons:")
|
||||
logger.info("Failure Reasons:")
|
||||
for reason, count in sorted(report['failure_reasons'].items(), key=lambda x: -x[1]):
|
||||
print(f" {reason}: {count}")
|
||||
logger.info(" %s: %d", reason, count)
|
||||
|
||||
print(f"\nFailures by Field:")
|
||||
logger.info("Failures by Field:")
|
||||
for field, reasons in report['failures_by_field'].items():
|
||||
total = sum(reasons.values())
|
||||
print(f" {field}: {total}")
|
||||
logger.info(" %s: %d", field, total)
|
||||
for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
|
||||
print(f" - {reason}: {count}")
|
||||
logger.info(" - %s: %d", reason, count)
|
||||
|
||||
# Show sample issues
|
||||
if report['issues']:
|
||||
print(f"\n" + "-" * 60)
|
||||
print("SAMPLE ISSUES (first 10)")
|
||||
print("-" * 60)
|
||||
logger.info("-" * 60)
|
||||
logger.info("SAMPLE ISSUES (first 10)")
|
||||
logger.info("-" * 60)
|
||||
|
||||
for issue in report['issues'][:10]:
|
||||
print(f"\n[{issue['doc_id']}] {issue['field']}")
|
||||
print(f" CSV value: {issue['csv_value']}")
|
||||
print(f" Reason: {issue['reason']}")
|
||||
logger.info("[%s] %s", issue['doc_id'], issue['field'])
|
||||
logger.info(" CSV value: %s", issue['csv_value'])
|
||||
logger.info(" Reason: %s", issue['reason'])
|
||||
|
||||
if issue.get('details'):
|
||||
details = issue['details']
|
||||
if details.get('normalized_candidates'):
|
||||
print(f" Candidates: {details['normalized_candidates'][:5]}")
|
||||
logger.info(" Candidates: %s", details['normalized_candidates'][:5])
|
||||
if details.get('pdf_tokens_sample'):
|
||||
print(f" PDF samples: {details['pdf_tokens_sample'][:5]}")
|
||||
logger.info(" PDF samples: %s", details['pdf_tokens_sample'][:5])
|
||||
if details.get('potential_matches'):
|
||||
print(f" Potential matches:")
|
||||
logger.info(" Potential matches:")
|
||||
for pm in details['potential_matches'][:3]:
|
||||
print(f" - token='{pm['token']}' matches candidate='{pm['candidate']}'")
|
||||
logger.info(" - token='%s' matches candidate='%s'", pm['token'], pm['candidate'])
|
||||
|
||||
|
||||
def main():
|
||||
@@ -551,6 +555,9 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
analyzer = LabelAnalyzer(
|
||||
csv_path=args.csv,
|
||||
pdf_dir=args.pdf_dir,
|
||||
@@ -566,30 +573,30 @@ def main():
|
||||
|
||||
analysis = analyzer.analyze_document(args.single)
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Document: {analysis.doc_id}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"PDF exists: {analysis.pdf_exists}")
|
||||
print(f"PDF type: {analysis.pdf_type}")
|
||||
print(f"Pages: {analysis.total_pages}")
|
||||
print(f"\nFields (CSV: {analysis.csv_fields_count}, Labeled: {analysis.labeled_fields_count}):")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Document: %s", analysis.doc_id)
|
||||
logger.info("=" * 60)
|
||||
logger.info("PDF exists: %s", analysis.pdf_exists)
|
||||
logger.info("PDF type: %s", analysis.pdf_type)
|
||||
logger.info("Pages: %d", analysis.total_pages)
|
||||
logger.info("Fields (CSV: %d, Labeled: %d):", analysis.csv_fields_count, analysis.labeled_fields_count)
|
||||
|
||||
for f in analysis.fields:
|
||||
status = "✓" if f.labeled else ("✗" if f.expected else "-")
|
||||
status = "[OK]" if f.labeled else ("[FAIL]" if f.expected else "[-]")
|
||||
value_str = f.csv_value[:30] if f.csv_value else "(empty)"
|
||||
print(f" [{status}] {f.field_name}: {value_str}")
|
||||
logger.info(" %s %s: %s", status, f.field_name, value_str)
|
||||
|
||||
if f.failure_reason:
|
||||
print(f" Reason: {f.failure_reason}")
|
||||
logger.info(" Reason: %s", f.failure_reason)
|
||||
if f.details.get('normalized_candidates'):
|
||||
print(f" Candidates: {f.details['normalized_candidates']}")
|
||||
logger.info(" Candidates: %s", f.details['normalized_candidates'])
|
||||
if f.details.get('potential_matches'):
|
||||
print(f" Potential matches in PDF:")
|
||||
logger.info(" Potential matches in PDF:")
|
||||
for pm in f.details['potential_matches'][:3]:
|
||||
print(f" - '{pm['token']}'")
|
||||
logger.info(" - '%s'", pm['token'])
|
||||
else:
|
||||
# Full analysis
|
||||
print("Running label analysis...")
|
||||
logger.info("Running label analysis...")
|
||||
results = analyzer.run_analysis(limit=args.limit)
|
||||
report = analyzer.generate_report(results, args.output, verbose=args.verbose)
|
||||
print_summary(report)
|
||||
|
||||
@@ -7,11 +7,15 @@ Generates statistics and insights from database or autolabel_report.jsonl
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import get_db_connection_string
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_reports_from_db() -> dict:
|
||||
@@ -147,9 +151,9 @@ def load_reports_from_file(report_path: str) -> list[dict]:
|
||||
if not report_files:
|
||||
return []
|
||||
|
||||
print(f"Reading {len(report_files)} report file(s):")
|
||||
logger.info("Reading %d report file(s):", len(report_files))
|
||||
for f in report_files:
|
||||
print(f" - {f.name}")
|
||||
logger.info(" - %s", f.name)
|
||||
|
||||
reports = []
|
||||
for report_file in report_files:
|
||||
@@ -231,55 +235,55 @@ def analyze_reports(reports: list[dict]) -> dict:
|
||||
|
||||
def print_report(stats: dict, verbose: bool = False):
|
||||
"""Print analysis report."""
|
||||
print("\n" + "=" * 60)
|
||||
print("AUTO-LABEL REPORT ANALYSIS")
|
||||
print("=" * 60)
|
||||
logger.info("=" * 60)
|
||||
logger.info("AUTO-LABEL REPORT ANALYSIS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Overall stats
|
||||
print(f"\n{'OVERALL STATISTICS':^60}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "OVERALL STATISTICS".center(60))
|
||||
logger.info("-" * 60)
|
||||
total = stats['total']
|
||||
successful = stats['successful']
|
||||
failed = stats['failed']
|
||||
success_rate = successful / total * 100 if total > 0 else 0
|
||||
|
||||
print(f"Total documents: {total:>8}")
|
||||
print(f"Successful: {successful:>8} ({success_rate:.1f}%)")
|
||||
print(f"Failed: {failed:>8} ({100-success_rate:.1f}%)")
|
||||
logger.info("Total documents: %8d", total)
|
||||
logger.info("Successful: %8d (%.1f%%)", successful, success_rate)
|
||||
logger.info("Failed: %8d (%.1f%%)", failed, 100-success_rate)
|
||||
|
||||
# Processing time
|
||||
if 'processing_time_stats' in stats:
|
||||
pts = stats['processing_time_stats']
|
||||
print(f"\nProcessing time (ms):")
|
||||
print(f" Average: {pts['avg_ms']:>8.1f}")
|
||||
print(f" Min: {pts['min_ms']:>8.1f}")
|
||||
print(f" Max: {pts['max_ms']:>8.1f}")
|
||||
logger.info("Processing time (ms):")
|
||||
logger.info(" Average: %8.1f", pts['avg_ms'])
|
||||
logger.info(" Min: %8.1f", pts['min_ms'])
|
||||
logger.info(" Max: %8.1f", pts['max_ms'])
|
||||
elif stats.get('processing_times'):
|
||||
times = stats['processing_times']
|
||||
avg_time = sum(times) / len(times)
|
||||
min_time = min(times)
|
||||
max_time = max(times)
|
||||
print(f"\nProcessing time (ms):")
|
||||
print(f" Average: {avg_time:>8.1f}")
|
||||
print(f" Min: {min_time:>8.1f}")
|
||||
print(f" Max: {max_time:>8.1f}")
|
||||
logger.info("Processing time (ms):")
|
||||
logger.info(" Average: %8.1f", avg_time)
|
||||
logger.info(" Min: %8.1f", min_time)
|
||||
logger.info(" Max: %8.1f", max_time)
|
||||
|
||||
# By PDF type
|
||||
print(f"\n{'BY PDF TYPE':^60}")
|
||||
print("-" * 60)
|
||||
print(f"{'Type':<15} {'Total':>10} {'Success':>10} {'Rate':>10}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "BY PDF TYPE".center(60))
|
||||
logger.info("-" * 60)
|
||||
logger.info("%-15s %10s %10s %10s", 'Type', 'Total', 'Success', 'Rate')
|
||||
logger.info("-" * 60)
|
||||
for pdf_type, type_stats in sorted(stats['by_pdf_type'].items()):
|
||||
type_total = type_stats['total']
|
||||
type_success = type_stats['successful']
|
||||
type_rate = type_success / type_total * 100 if type_total > 0 else 0
|
||||
print(f"{pdf_type:<15} {type_total:>10} {type_success:>10} {type_rate:>9.1f}%")
|
||||
logger.info("%-15s %10d %10d %9.1f%%", pdf_type, type_total, type_success, type_rate)
|
||||
|
||||
# By field
|
||||
print(f"\n{'FIELD MATCH STATISTICS':^60}")
|
||||
print("-" * 60)
|
||||
print(f"{'Field':<18} {'Total':>7} {'Match':>7} {'Rate':>7} {'Exact':>7} {'Flex':>7} {'AvgScore':>8}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "FIELD MATCH STATISTICS".center(60))
|
||||
logger.info("-" * 60)
|
||||
logger.info("%-18s %7s %7s %7s %7s %7s %8s", 'Field', 'Total', 'Match', 'Rate', 'Exact', 'Flex', 'AvgScore')
|
||||
logger.info("-" * 60)
|
||||
|
||||
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
|
||||
if field_name not in stats['by_field']:
|
||||
@@ -299,16 +303,16 @@ def print_report(stats: dict, verbose: bool = False):
|
||||
else:
|
||||
avg_score = 0
|
||||
|
||||
print(f"{field_name:<18} {total:>7} {matched:>7} {rate:>6.1f}% {exact:>7} {flex:>7} {avg_score:>8.3f}")
|
||||
logger.info("%-18s %7d %7d %6.1f%% %7d %7d %8.3f", field_name, total, matched, rate, exact, flex, avg_score)
|
||||
|
||||
# Field match by PDF type
|
||||
print(f"\n{'FIELD MATCH BY PDF TYPE':^60}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "FIELD MATCH BY PDF TYPE".center(60))
|
||||
logger.info("-" * 60)
|
||||
|
||||
for pdf_type in sorted(stats['by_pdf_type'].keys()):
|
||||
print(f"\n[{pdf_type.upper()}]")
|
||||
print(f"{'Field':<18} {'Total':>10} {'Matched':>10} {'Rate':>10}")
|
||||
print("-" * 50)
|
||||
logger.info("[%s]", pdf_type.upper())
|
||||
logger.info("%-18s %10s %10s %10s", 'Field', 'Total', 'Matched', 'Rate')
|
||||
logger.info("-" * 50)
|
||||
|
||||
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
|
||||
if field_name not in stats['by_field']:
|
||||
@@ -317,16 +321,16 @@ def print_report(stats: dict, verbose: bool = False):
|
||||
total = type_stats['total']
|
||||
matched = type_stats['matched']
|
||||
rate = matched / total * 100 if total > 0 else 0
|
||||
print(f"{field_name:<18} {total:>10} {matched:>10} {rate:>9.1f}%")
|
||||
logger.info("%-18s %10d %10d %9.1f%%", field_name, total, matched, rate)
|
||||
|
||||
# Errors
|
||||
if stats.get('errors') and verbose:
|
||||
print(f"\n{'ERRORS':^60}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "ERRORS".center(60))
|
||||
logger.info("-" * 60)
|
||||
for error, count in sorted(stats['errors'].items(), key=lambda x: -x[1])[:20]:
|
||||
print(f"{count:>5}x {error[:50]}")
|
||||
logger.info("%5dx %s", count, error[:50])
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
def export_json(stats: dict, output_path: str):
|
||||
@@ -372,7 +376,7 @@ def export_json(stats: dict, output_path: str):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(export_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nStatistics exported to: {output_path}")
|
||||
logger.info("Statistics exported to: %s", output_path)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -401,25 +405,28 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Decide source
|
||||
use_db = not args.from_file and args.report is None
|
||||
|
||||
if use_db:
|
||||
print("Loading statistics from database...")
|
||||
logger.info("Loading statistics from database...")
|
||||
stats = load_reports_from_db()
|
||||
print(f"Loaded stats for {stats['total']} documents")
|
||||
logger.info("Loaded stats for %d documents", stats['total'])
|
||||
else:
|
||||
report_path = args.report or 'reports/autolabel_report.jsonl'
|
||||
path = Path(report_path)
|
||||
|
||||
# Check if file exists (handle glob patterns)
|
||||
if '*' not in str(path) and '?' not in str(path) and not path.exists():
|
||||
print(f"Error: Report file not found: {path}")
|
||||
logger.error("Report file not found: %s", path)
|
||||
return 1
|
||||
|
||||
print(f"Loading reports from: {report_path}")
|
||||
logger.info("Loading reports from: %s", report_path)
|
||||
reports = load_reports_from_file(report_path)
|
||||
print(f"Loaded {len(reports)} reports")
|
||||
logger.info("Loaded %d reports", len(reports))
|
||||
stats = analyze_reports(reports)
|
||||
|
||||
print_report(stats, verbose=args.verbose)
|
||||
|
||||
@@ -6,6 +6,7 @@ Generates YOLO training data from PDFs and structured CSV data.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
@@ -17,6 +18,10 @@ from tqdm import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
|
||||
import multiprocessing
|
||||
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global flag for graceful shutdown
|
||||
_shutdown_requested = False
|
||||
|
||||
@@ -25,8 +30,8 @@ def _signal_handler(signum, frame):
|
||||
"""Handle interrupt signals for graceful shutdown."""
|
||||
global _shutdown_requested
|
||||
_shutdown_requested = True
|
||||
print("\n\nShutdown requested. Finishing current batch and saving progress...")
|
||||
print("(Press Ctrl+C again to force quit)\n")
|
||||
logger.warning("Shutdown requested. Finishing current batch and saving progress...")
|
||||
logger.warning("(Press Ctrl+C again to force quit)")
|
||||
|
||||
# Windows compatibility: use 'spawn' method for multiprocessing
|
||||
# This is required on Windows and is also safer for libraries like PaddleOCR
|
||||
@@ -350,11 +355,14 @@ def main():
|
||||
if ',' in csv_input and '*' not in csv_input:
|
||||
csv_input = [p.strip() for p in csv_input.split(',')]
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Get list of CSV files (don't load all data at once)
|
||||
temp_loader = CSVLoader(csv_input, args.pdf_dir)
|
||||
csv_files = temp_loader.csv_paths
|
||||
pdf_dir = temp_loader.pdf_dir
|
||||
print(f"Found {len(csv_files)} CSV file(s) to process")
|
||||
logger.info("Found %d CSV file(s) to process", len(csv_files))
|
||||
|
||||
# Setup output directories
|
||||
output_dir = Path(args.output)
|
||||
@@ -371,7 +379,7 @@ def main():
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
db.create_tables() # Ensure tables exist
|
||||
print("Connected to database for status checking")
|
||||
logger.info("Connected to database for status checking")
|
||||
|
||||
# Global stats
|
||||
stats = {
|
||||
@@ -443,7 +451,7 @@ def main():
|
||||
db.save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
if args.verbose:
|
||||
print(f"Error processing {doc_id}: {error}")
|
||||
logger.error("Error processing %s: %s", doc_id, error)
|
||||
|
||||
# Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
|
||||
dual_pool_coordinator = None
|
||||
@@ -453,7 +461,7 @@ def main():
|
||||
from training.processing import DualPoolCoordinator
|
||||
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
|
||||
|
||||
print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
|
||||
logger.info("Starting dual-pool mode: %d CPU + %d GPU workers", args.cpu_workers, args.gpu_workers)
|
||||
dual_pool_coordinator = DualPoolCoordinator(
|
||||
cpu_workers=args.cpu_workers,
|
||||
gpu_workers=args.gpu_workers,
|
||||
@@ -467,10 +475,10 @@ def main():
|
||||
for csv_idx, csv_file in enumerate(csv_files):
|
||||
# Check for shutdown request
|
||||
if _shutdown_requested:
|
||||
print("\nShutdown requested. Stopping after current batch...")
|
||||
logger.warning("Shutdown requested. Stopping after current batch...")
|
||||
break
|
||||
|
||||
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
|
||||
logger.info("[%d/%d] Processing: %s", csv_idx + 1, len(csv_files), csv_file.name)
|
||||
|
||||
# Load only this CSV file
|
||||
single_loader = CSVLoader(str(csv_file), str(pdf_dir))
|
||||
@@ -488,7 +496,7 @@ def main():
|
||||
seen_doc_ids.add(r.DocumentId)
|
||||
|
||||
if not rows:
|
||||
print(f" Skipping CSV (no new documents)")
|
||||
logger.info(" Skipping CSV (no new documents)")
|
||||
continue
|
||||
|
||||
# Batch query database for all document IDs in this CSV
|
||||
@@ -500,13 +508,13 @@ def main():
|
||||
|
||||
# Skip entire CSV if all documents are already processed
|
||||
if already_processed == len(rows):
|
||||
print(f" Skipping CSV (all {len(rows)} documents already processed)")
|
||||
logger.info(" Skipping CSV (all %d documents already processed)", len(rows))
|
||||
stats['skipped_db'] += len(rows)
|
||||
continue
|
||||
|
||||
# Count how many new documents need processing in this CSV
|
||||
new_to_process = len(rows) - already_processed
|
||||
print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")
|
||||
logger.info(" Found %d new documents to process (%d already in DB)", new_to_process, already_processed)
|
||||
|
||||
stats['total'] += len(rows)
|
||||
|
||||
@@ -520,7 +528,7 @@ def main():
|
||||
if args.limit:
|
||||
remaining_limit = args.limit - stats.get('tasks_submitted', 0)
|
||||
if remaining_limit <= 0:
|
||||
print(f" Reached limit of {args.limit} new documents, stopping.")
|
||||
logger.info(" Reached limit of %d new documents, stopping.", args.limit)
|
||||
break
|
||||
else:
|
||||
remaining_limit = float('inf')
|
||||
@@ -583,7 +591,7 @@ def main():
|
||||
))
|
||||
|
||||
if skipped_in_csv > 0 or retry_in_csv > 0:
|
||||
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
|
||||
logger.info(" Skipped %d (already in DB), retrying %d failed", skipped_in_csv, retry_in_csv)
|
||||
|
||||
# Clean up retry documents: delete from database and remove temp folders
|
||||
if retry_doc_ids:
|
||||
@@ -599,7 +607,7 @@ def main():
|
||||
temp_doc_dir = output_dir / 'temp' / doc_id
|
||||
if temp_doc_dir.exists():
|
||||
shutil.rmtree(temp_doc_dir, ignore_errors=True)
|
||||
print(f" Cleaned up {len(retry_doc_ids)} retry documents (DB + temp folders)")
|
||||
logger.info(" Cleaned up %d retry documents (DB + temp folders)", len(retry_doc_ids))
|
||||
|
||||
if not tasks:
|
||||
continue
|
||||
@@ -636,7 +644,7 @@ def main():
|
||||
# Count task types
|
||||
text_count = sum(1 for d in documents if not d["is_scanned"])
|
||||
scan_count = len(documents) - text_count
|
||||
print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")
|
||||
logger.info(" Text PDFs: %d, Scanned PDFs: %d", text_count, scan_count)
|
||||
|
||||
# Progress tracking with tqdm
|
||||
pbar = tqdm(total=len(documents), desc="Processing")
|
||||
@@ -667,11 +675,11 @@ def main():
|
||||
# Log summary
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = len(results) - successful
|
||||
print(f" Batch complete: {successful} successful, {failed} failed")
|
||||
logger.info(" Batch complete: %d successful, %d failed", successful, failed)
|
||||
|
||||
else:
|
||||
# Single-pool mode (original behavior)
|
||||
print(f" Processing {len(tasks)} documents with {args.workers} workers...")
|
||||
logger.info(" Processing %d documents with %d workers...", len(tasks), args.workers)
|
||||
|
||||
# Process documents in parallel (inside CSV loop for streaming)
|
||||
# Use single process for debugging or when workers=1
|
||||
@@ -725,28 +733,28 @@ def main():
|
||||
db.close()
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Auto-labeling Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total documents: {stats['total']}")
|
||||
print(f"Successful: {stats['successful']}")
|
||||
print(f"Failed: {stats['failed']}")
|
||||
print(f"Skipped (no PDF): {stats['skipped']}")
|
||||
print(f"Skipped (in DB): {stats['skipped_db']}")
|
||||
print(f"Retried (failed): {stats['retried']}")
|
||||
print(f"Total annotations: {stats['annotations']}")
|
||||
print(f"\nImages saved to: {output_dir / 'temp'}")
|
||||
print(f"Labels stored in: PostgreSQL database")
|
||||
print(f"\nAnnotations by field:")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Auto-labeling Complete")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Total documents: %d", stats['total'])
|
||||
logger.info("Successful: %d", stats['successful'])
|
||||
logger.info("Failed: %d", stats['failed'])
|
||||
logger.info("Skipped (no PDF): %d", stats['skipped'])
|
||||
logger.info("Skipped (in DB): %d", stats['skipped_db'])
|
||||
logger.info("Retried (failed): %d", stats['retried'])
|
||||
logger.info("Total annotations: %d", stats['annotations'])
|
||||
logger.info("Images saved to: %s", output_dir / 'temp')
|
||||
logger.info("Labels stored in: PostgreSQL database")
|
||||
logger.info("Annotations by field:")
|
||||
for field, count in stats['by_field'].items():
|
||||
print(f" {field}: {count}")
|
||||
logger.info(" %s: %d", field, count)
|
||||
shard_files = report_writer.get_shard_files()
|
||||
if len(shard_files) > 1:
|
||||
print(f"\nReport files ({len(shard_files)}):")
|
||||
logger.info("Report files (%d):", len(shard_files))
|
||||
for sf in shard_files:
|
||||
print(f" - {sf}")
|
||||
logger.info(" - %s", sf)
|
||||
else:
|
||||
print(f"\nReport: {shard_files[0] if shard_files else args.report}")
|
||||
logger.info("Report: %s", shard_files[0] if shard_files else args.report)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -8,6 +8,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -16,6 +17,9 @@ from psycopg2.extras import execute_values
|
||||
|
||||
# Add project root to path
|
||||
from shared.config import get_db_connection_string, PATHS
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_tables(conn):
|
||||
@@ -150,7 +154,7 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
|
||||
try:
|
||||
record = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f" Warning: Line {line_no} - JSON parse error: {e}")
|
||||
logger.warning("Line %d - JSON parse error: %s", line_no, e)
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
@@ -211,7 +215,7 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
|
||||
# Flush batch if needed
|
||||
if len(doc_batch) >= batch_size:
|
||||
flush_batches()
|
||||
print(f" Processed {stats['imported'] + stats['skipped']} records...")
|
||||
logger.info(" Processed %d records...", stats['imported'] + stats['skipped'])
|
||||
|
||||
# Final flush
|
||||
flush_batches()
|
||||
@@ -243,11 +247,14 @@ def main():
|
||||
else:
|
||||
report_files = [report_path] if report_path.exists() else []
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
if not report_files:
|
||||
print(f"No report files found: {args.report}")
|
||||
logger.error("No report files found: %s", args.report)
|
||||
return
|
||||
|
||||
print(f"Found {len(report_files)} report file(s)")
|
||||
logger.info("Found %d report file(s)", len(report_files))
|
||||
|
||||
# Connect to database
|
||||
conn = psycopg2.connect(db_connection)
|
||||
@@ -257,20 +264,20 @@ def main():
|
||||
total_stats = {'imported': 0, 'skipped': 0, 'errors': 0}
|
||||
|
||||
for report_file in report_files:
|
||||
print(f"\nImporting: {report_file.name}")
|
||||
logger.info("Importing: %s", report_file.name)
|
||||
stats = import_jsonl_file(conn, report_file, skip_existing=not args.no_skip, batch_size=args.batch_size)
|
||||
print(f" Imported: {stats['imported']}, Skipped: {stats['skipped']}, Errors: {stats['errors']}")
|
||||
logger.info(" Imported: %d, Skipped: %d, Errors: %d", stats['imported'], stats['skipped'], stats['errors'])
|
||||
|
||||
for key in total_stats:
|
||||
total_stats[key] += stats[key]
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Import Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total imported: {total_stats['imported']}")
|
||||
print(f"Total skipped: {total_stats['skipped']}")
|
||||
print(f"Total errors: {total_stats['errors']}")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Import Complete")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Total imported: %d", total_stats['imported'])
|
||||
logger.info("Total skipped: %d", total_stats['skipped'])
|
||||
logger.info("Total errors: %d", total_stats['errors'])
|
||||
|
||||
# Quick stats from database
|
||||
with conn.cursor() as cursor:
|
||||
@@ -288,11 +295,11 @@ def main():
|
||||
|
||||
conn.close()
|
||||
|
||||
print(f"\nDatabase Stats:")
|
||||
print(f" Documents: {total_docs} ({success_docs} successful)")
|
||||
print(f" Field results: {total_fields} ({matched_fields} matched)")
|
||||
logger.info("Database Stats:")
|
||||
logger.info(" Documents: %d (%d successful)", total_docs, success_docs)
|
||||
logger.info(" Field results: %d (%d matched)", total_fields, matched_fields)
|
||||
if total_fields > 0:
|
||||
print(f" Match rate: {matched_fields / total_fields * 100:.2f}%")
|
||||
logger.info(" Match rate: %.2f%%", matched_fields / total_fields * 100)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -7,6 +7,7 @@ CSV values, and source CSV filename in a new table.
|
||||
import argparse
|
||||
import json
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -20,6 +21,9 @@ from shared.config import DEFAULT_DPI
|
||||
from shared.data.db import DocumentDB
|
||||
from shared.data.csv_loader import CSVLoader
|
||||
from shared.normalize.normalizer import normalize_field
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_failed_match_table(db: DocumentDB):
|
||||
@@ -57,7 +61,7 @@ def create_failed_match_table(db: DocumentDB):
|
||||
CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
|
||||
""")
|
||||
conn.commit()
|
||||
print("Created table: failed_match_details")
|
||||
logger.info("Created table: failed_match_details")
|
||||
|
||||
|
||||
def get_failed_documents(db: DocumentDB) -> list:
|
||||
@@ -332,14 +336,17 @@ def main():
|
||||
parser.add_argument('--limit', type=int, help='Limit number of documents to process')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Expand CSV glob
|
||||
csv_files = sorted(glob.glob(args.csv))
|
||||
print(f"Found {len(csv_files)} CSV files")
|
||||
logger.info("Found %d CSV files", len(csv_files))
|
||||
|
||||
# Build CSV cache
|
||||
print("Building CSV filename cache...")
|
||||
logger.info("Building CSV filename cache...")
|
||||
build_csv_cache(csv_files)
|
||||
print(f"Cached {len(_csv_cache)} document IDs")
|
||||
logger.info("Cached %d document IDs", len(_csv_cache))
|
||||
|
||||
# Connect to database
|
||||
db = DocumentDB()
|
||||
@@ -349,13 +356,13 @@ def main():
|
||||
create_failed_match_table(db)
|
||||
|
||||
# Get all failed documents
|
||||
print("Fetching failed documents...")
|
||||
logger.info("Fetching failed documents...")
|
||||
failed_docs = get_failed_documents(db)
|
||||
print(f"Found {len(failed_docs)} documents with failed matches")
|
||||
logger.info("Found %d documents with failed matches", len(failed_docs))
|
||||
|
||||
if args.limit:
|
||||
failed_docs = failed_docs[:args.limit]
|
||||
print(f"Limited to {len(failed_docs)} documents")
|
||||
logger.info("Limited to %d documents", len(failed_docs))
|
||||
|
||||
# Prepare tasks
|
||||
tasks = []
|
||||
@@ -365,7 +372,7 @@ def main():
|
||||
if failed_fields:
|
||||
tasks.append((doc, failed_fields, csv_filename))
|
||||
|
||||
print(f"Processing {len(tasks)} documents with {args.workers} workers...")
|
||||
logger.info("Processing %d documents with %d workers...", len(tasks), args.workers)
|
||||
|
||||
# Process with multiprocessing
|
||||
total_results = 0
|
||||
@@ -389,15 +396,15 @@ def main():
|
||||
batch_results = []
|
||||
|
||||
except TimeoutError:
|
||||
print(f"\nTimeout processing {doc_id}")
|
||||
logger.warning("Timeout processing %s", doc_id)
|
||||
except Exception as e:
|
||||
print(f"\nError processing {doc_id}: {e}")
|
||||
logger.error("Error processing %s: %s", doc_id, e)
|
||||
|
||||
# Save remaining results
|
||||
if batch_results:
|
||||
save_results_batch(db, batch_results)
|
||||
|
||||
print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")
|
||||
logger.info("Done! Saved %d failed match records to failed_match_details table", total_results)
|
||||
|
||||
# Show summary
|
||||
conn = db.connect()
|
||||
@@ -410,12 +417,12 @@ def main():
|
||||
GROUP BY field_name
|
||||
ORDER BY total DESC
|
||||
""")
|
||||
print("\nSummary by field:")
|
||||
print("-" * 70)
|
||||
print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
|
||||
print("-" * 70)
|
||||
logger.info("Summary by field:")
|
||||
logger.info("-" * 70)
|
||||
logger.info("%-35s %8s %10s %12s", 'Field', 'Total', 'Has OCR', 'Avg Score')
|
||||
logger.info("-" * 70)
|
||||
for row in cursor.fetchall():
|
||||
print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")
|
||||
logger.info("%-35s %8d %10d %12.2f", row[0], row[1], row[2], row[3])
|
||||
|
||||
db.close()
|
||||
|
||||
|
||||
@@ -7,10 +7,14 @@ Images are read from filesystem, labels are dynamically generated from DB.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import DEFAULT_DPI, PATHS
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -119,47 +123,50 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Apply low-memory mode if specified
|
||||
if args.low_memory:
|
||||
print("🔧 Low memory mode enabled")
|
||||
logger.info("Low memory mode enabled")
|
||||
args.batch = min(args.batch, 8) # Reduce from 16 to 8
|
||||
args.workers = min(args.workers, 4) # Reduce from 8 to 4
|
||||
args.cache = False
|
||||
print(f" Batch size: {args.batch}")
|
||||
print(f" Workers: {args.workers}")
|
||||
print(f" Cache: disabled")
|
||||
logger.info(" Batch size: %d", args.batch)
|
||||
logger.info(" Workers: %d", args.workers)
|
||||
logger.info(" Cache: disabled")
|
||||
|
||||
# Validate dataset directory
|
||||
dataset_dir = Path(args.dataset_dir)
|
||||
temp_dir = dataset_dir / 'temp'
|
||||
if not temp_dir.exists():
|
||||
print(f"Error: Temp directory not found: {temp_dir}")
|
||||
print("Run autolabel first to generate images.")
|
||||
logger.error("Temp directory not found: %s", temp_dir)
|
||||
logger.error("Run autolabel first to generate images.")
|
||||
sys.exit(1)
|
||||
|
||||
print("=" * 60)
|
||||
print("YOLO Training with Database Labels")
|
||||
print("=" * 60)
|
||||
print(f"Dataset dir: {dataset_dir}")
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Epochs: {args.epochs}")
|
||||
print(f"Batch size: {args.batch}")
|
||||
print(f"Image size: {args.imgsz}")
|
||||
print(f"Split ratio: {args.train_ratio}/{args.val_ratio}/{1-args.train_ratio-args.val_ratio:.1f}")
|
||||
logger.info("=" * 60)
|
||||
logger.info("YOLO Training with Database Labels")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Dataset dir: %s", dataset_dir)
|
||||
logger.info("Model: %s", args.model)
|
||||
logger.info("Epochs: %d", args.epochs)
|
||||
logger.info("Batch size: %d", args.batch)
|
||||
logger.info("Image size: %d", args.imgsz)
|
||||
logger.info("Split ratio: %s/%s/%.1f", args.train_ratio, args.val_ratio, 1-args.train_ratio-args.val_ratio)
|
||||
if args.limit:
|
||||
print(f"Document limit: {args.limit}")
|
||||
logger.info("Document limit: %d", args.limit)
|
||||
|
||||
# Connect to database
|
||||
from shared.data.db import DocumentDB
|
||||
|
||||
print("\nConnecting to database...")
|
||||
logger.info("Connecting to database...")
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
|
||||
# Create datasets from database
|
||||
from training.yolo.db_dataset import create_datasets
|
||||
|
||||
print("Loading dataset from database...")
|
||||
logger.info("Loading dataset from database...")
|
||||
datasets = create_datasets(
|
||||
images_dir=dataset_dir,
|
||||
db=db,
|
||||
@@ -170,39 +177,39 @@ def main():
|
||||
limit=args.limit
|
||||
)
|
||||
|
||||
print(f"\nDataset splits:")
|
||||
print(f" Train: {len(datasets['train'])} items")
|
||||
print(f" Val: {len(datasets['val'])} items")
|
||||
print(f" Test: {len(datasets['test'])} items")
|
||||
logger.info("Dataset splits:")
|
||||
logger.info(" Train: %d items", len(datasets['train']))
|
||||
logger.info(" Val: %d items", len(datasets['val']))
|
||||
logger.info(" Test: %d items", len(datasets['test']))
|
||||
|
||||
if len(datasets['train']) == 0:
|
||||
print("\nError: No training data found!")
|
||||
print("Make sure autolabel has been run and images exist in temp directory.")
|
||||
logger.error("No training data found!")
|
||||
logger.error("Make sure autolabel has been run and images exist in temp directory.")
|
||||
db.close()
|
||||
sys.exit(1)
|
||||
|
||||
# Export to YOLO format (required for Ultralytics training)
|
||||
print("\nExporting dataset to YOLO format...")
|
||||
logger.info("Exporting dataset to YOLO format...")
|
||||
for split_name, dataset in datasets.items():
|
||||
count = dataset.export_to_yolo_format(dataset_dir, split_name)
|
||||
print(f" {split_name}: {count} items exported")
|
||||
logger.info(" %s: %d items exported", split_name, count)
|
||||
|
||||
# Generate YOLO config files
|
||||
from training.yolo.annotation_generator import AnnotationGenerator
|
||||
|
||||
AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt')
|
||||
AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml')
|
||||
print(f"\nGenerated dataset.yaml at: {dataset_dir / 'dataset.yaml'}")
|
||||
logger.info("Generated dataset.yaml at: %s", dataset_dir / 'dataset.yaml')
|
||||
|
||||
if args.export_only:
|
||||
print("\nExport complete (--export-only specified, skipping training)")
|
||||
logger.info("Export complete (--export-only specified, skipping training)")
|
||||
db.close()
|
||||
return
|
||||
|
||||
# Start training using shared trainer
|
||||
print("\n" + "=" * 60)
|
||||
print("Starting YOLO Training")
|
||||
print("=" * 60)
|
||||
logger.info("=" * 60)
|
||||
logger.info("Starting YOLO Training")
|
||||
logger.info("=" * 60)
|
||||
|
||||
from shared.training import YOLOTrainer, TrainingConfig
|
||||
|
||||
@@ -232,30 +239,30 @@ def main():
|
||||
result = trainer.train()
|
||||
|
||||
if not result.success:
|
||||
print(f"\nError: Training failed - {result.error}")
|
||||
logger.error("Training failed - %s", result.error)
|
||||
db.close()
|
||||
sys.exit(1)
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 60)
|
||||
print("Training Complete")
|
||||
print("=" * 60)
|
||||
print(f"Best model: {result.model_path}")
|
||||
print(f"Save directory: {result.save_dir}")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Training Complete")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Best model: %s", result.model_path)
|
||||
logger.info("Save directory: %s", result.save_dir)
|
||||
if result.metrics:
|
||||
print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
|
||||
print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")
|
||||
logger.info("mAP@0.5: %s", result.metrics.get('mAP50', 'N/A'))
|
||||
logger.info("mAP@0.5-0.95: %s", result.metrics.get('mAP50-95', 'N/A'))
|
||||
|
||||
# Validate on test set
|
||||
print("\nRunning validation on test set...")
|
||||
logger.info("Running validation on test set...")
|
||||
if result.model_path:
|
||||
config.model_path = result.model_path
|
||||
config.data_yaml = str(data_yaml)
|
||||
test_trainer = YOLOTrainer(config=config)
|
||||
test_metrics = test_trainer.validate(split='test')
|
||||
if test_metrics:
|
||||
print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
|
||||
print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")
|
||||
logger.info("mAP50: %.4f", test_metrics.get('mAP50', 0))
|
||||
logger.info("mAP50-95: %.4f", test_metrics.get('mAP50-95', 0))
|
||||
|
||||
# Close database
|
||||
db.close()
|
||||
|
||||
@@ -7,9 +7,14 @@ and comparing the extraction results.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
@@ -73,6 +78,9 @@ def main():
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
from backend.validation import LLMValidator
|
||||
|
||||
validator = LLMValidator()
|
||||
@@ -104,60 +112,58 @@ def show_stats(validator):
|
||||
"""Show statistics about failed matches."""
|
||||
stats = validator.get_failed_match_stats()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Failed Match Statistics")
|
||||
print("=" * 50)
|
||||
print(f"\nDocuments with failures: {stats['documents_with_failures']}")
|
||||
print(f"Already validated: {stats['already_validated']}")
|
||||
print(f"Remaining to validate: {stats['remaining']}")
|
||||
print("\nFailures by field:")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Failed Match Statistics")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Documents with failures: %d", stats['documents_with_failures'])
|
||||
logger.info("Already validated: %d", stats['already_validated'])
|
||||
logger.info("Remaining to validate: %d", stats['remaining'])
|
||||
logger.info("Failures by field:")
|
||||
for field, count in sorted(stats['failures_by_field'].items(), key=lambda x: -x[1]):
|
||||
print(f" {field}: {count}")
|
||||
logger.info(" %s: %d", field, count)
|
||||
|
||||
|
||||
def validate_single(validator, doc_id: str, provider: str, model: str):
|
||||
"""Validate a single document."""
|
||||
print(f"\nValidating document: {doc_id}")
|
||||
print(f"Provider: {provider}, Model: {model or 'default'}")
|
||||
print()
|
||||
logger.info("Validating document: %s", doc_id)
|
||||
logger.info("Provider: %s, Model: %s", provider, model or 'default')
|
||||
|
||||
result = validator.validate_document(doc_id, provider, model)
|
||||
|
||||
if result.error:
|
||||
print(f"ERROR: {result.error}")
|
||||
logger.error("ERROR: %s", result.error)
|
||||
return
|
||||
|
||||
print(f"Processing time: {result.processing_time_ms:.0f}ms")
|
||||
print(f"Model used: {result.model_used}")
|
||||
print("\nExtracted fields:")
|
||||
print(f" Invoice Number: {result.invoice_number}")
|
||||
print(f" Invoice Date: {result.invoice_date}")
|
||||
print(f" Due Date: {result.invoice_due_date}")
|
||||
print(f" OCR: {result.ocr_number}")
|
||||
print(f" Bankgiro: {result.bankgiro}")
|
||||
print(f" Plusgiro: {result.plusgiro}")
|
||||
print(f" Amount: {result.amount}")
|
||||
print(f" Org Number: {result.supplier_organisation_number}")
|
||||
logger.info("Processing time: %.0fms", result.processing_time_ms)
|
||||
logger.info("Model used: %s", result.model_used)
|
||||
logger.info("Extracted fields:")
|
||||
logger.info(" Invoice Number: %s", result.invoice_number)
|
||||
logger.info(" Invoice Date: %s", result.invoice_date)
|
||||
logger.info(" Due Date: %s", result.invoice_due_date)
|
||||
logger.info(" OCR: %s", result.ocr_number)
|
||||
logger.info(" Bankgiro: %s", result.bankgiro)
|
||||
logger.info(" Plusgiro: %s", result.plusgiro)
|
||||
logger.info(" Amount: %s", result.amount)
|
||||
logger.info(" Org Number: %s", result.supplier_organisation_number)
|
||||
|
||||
# Show comparison
|
||||
print("\n" + "-" * 50)
|
||||
print("Comparison with autolabel:")
|
||||
logger.info("-" * 50)
|
||||
logger.info("Comparison with autolabel:")
|
||||
comparison = validator.compare_results(doc_id)
|
||||
for field, data in comparison.items():
|
||||
if data.get('csv_value'):
|
||||
status = "✓" if data['agreement'] else "✗"
|
||||
status = "[OK]" if data['agreement'] else "[FAIL]"
|
||||
auto_status = "matched" if data['autolabel_matched'] else "FAILED"
|
||||
print(f" {status} {field}:")
|
||||
print(f" CSV: {data['csv_value']}")
|
||||
print(f" Autolabel: {data['autolabel_text']} ({auto_status})")
|
||||
print(f" LLM: {data['llm_value']}")
|
||||
logger.info(" %s %s:", status, field)
|
||||
logger.info(" CSV: %s", data['csv_value'])
|
||||
logger.info(" Autolabel: %s (%s)", data['autolabel_text'], auto_status)
|
||||
logger.info(" LLM: %s", data['llm_value'])
|
||||
|
||||
|
||||
def validate_batch(validator, limit: int, provider: str, model: str):
|
||||
"""Validate a batch of documents."""
|
||||
print(f"\nValidating up to {limit} documents with failed matches")
|
||||
print(f"Provider: {provider}, Model: {model or 'default'}")
|
||||
print()
|
||||
logger.info("Validating up to %d documents with failed matches", limit)
|
||||
logger.info("Provider: %s, Model: %s", provider, model or 'default')
|
||||
|
||||
results = validator.validate_batch(
|
||||
limit=limit,
|
||||
@@ -171,15 +177,15 @@ def validate_batch(validator, limit: int, provider: str, model: str):
|
||||
failed = len(results) - success
|
||||
total_time = sum(r.processing_time_ms or 0 for r in results)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Validation Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total: {len(results)}")
|
||||
print(f"Success: {success}")
|
||||
print(f"Failed: {failed}")
|
||||
print(f"Total time: {total_time/1000:.1f}s")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Validation Complete")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Total: %d", len(results))
|
||||
logger.info("Success: %d", success)
|
||||
logger.info("Failed: %d", failed)
|
||||
logger.info("Total time: %.1fs", total_time/1000)
|
||||
if success > 0:
|
||||
print(f"Avg time: {total_time/success:.0f}ms per document")
|
||||
logger.info("Avg time: %.0fms per document", total_time/success)
|
||||
|
||||
|
||||
def compare_single(validator, doc_id: str):
|
||||
@@ -187,23 +193,23 @@ def compare_single(validator, doc_id: str):
|
||||
comparison = validator.compare_results(doc_id)
|
||||
|
||||
if 'error' in comparison:
|
||||
print(f"Error: {comparison['error']}")
|
||||
logger.error("Error: %s", comparison['error'])
|
||||
return
|
||||
|
||||
print(f"\nComparison for document: {doc_id}")
|
||||
print("=" * 60)
|
||||
logger.info("Comparison for document: %s", doc_id)
|
||||
logger.info("=" * 60)
|
||||
|
||||
for field, data in comparison.items():
|
||||
if data.get('csv_value') is None:
|
||||
continue
|
||||
|
||||
status = "✓" if data['agreement'] else "✗"
|
||||
status = "[OK]" if data['agreement'] else "[FAIL]"
|
||||
auto_status = "matched" if data['autolabel_matched'] else "FAILED"
|
||||
|
||||
print(f"\n{status} {field}:")
|
||||
print(f" CSV value: {data['csv_value']}")
|
||||
print(f" Autolabel: {data['autolabel_text']} ({auto_status})")
|
||||
print(f" LLM extracted: {data['llm_value']}")
|
||||
logger.info("%s %s:", status, field)
|
||||
logger.info(" CSV value: %s", data['csv_value'])
|
||||
logger.info(" Autolabel: %s (%s)", data['autolabel_text'], auto_status)
|
||||
logger.info(" LLM extracted: %s", data['llm_value'])
|
||||
|
||||
|
||||
def compare_all(validator, limit: int):
|
||||
@@ -220,11 +226,11 @@ def compare_all(validator, limit: int):
|
||||
doc_ids = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
if not doc_ids:
|
||||
print("No validated documents found.")
|
||||
logger.info("No validated documents found.")
|
||||
return
|
||||
|
||||
print(f"\nComparison Summary ({len(doc_ids)} documents)")
|
||||
print("=" * 80)
|
||||
logger.info("Comparison Summary (%d documents)", len(doc_ids))
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Aggregate stats
|
||||
field_stats = {}
|
||||
@@ -259,12 +265,12 @@ def compare_all(validator, limit: int):
|
||||
if not data['autolabel_matched'] and data['agreement']:
|
||||
stats['llm_correct_auto_wrong'] += 1
|
||||
|
||||
print(f"\n{'Field':<30} {'Total':>6} {'Auto OK':>8} {'LLM Agrees':>10} {'LLM Found':>10}")
|
||||
print("-" * 80)
|
||||
logger.info("%-30s %6s %8s %10s %10s", 'Field', 'Total', 'Auto OK', 'LLM Agrees', 'LLM Found')
|
||||
logger.info("-" * 80)
|
||||
|
||||
for field, stats in sorted(field_stats.items()):
|
||||
print(f"{field:<30} {stats['total']:>6} {stats['autolabel_matched']:>8} "
|
||||
f"{stats['llm_agrees']:>10} {stats['llm_correct_auto_wrong']:>10}")
|
||||
logger.info("%-30s %6d %8d %10d %10d", field, stats['total'], stats['autolabel_matched'],
|
||||
stats['llm_agrees'], stats['llm_correct_auto_wrong'])
|
||||
|
||||
|
||||
def generate_report(validator, output_path: str):
|
||||
@@ -328,8 +334,8 @@ def generate_report(validator, output_path: str):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nReport generated: {output_path}")
|
||||
print(f"Total validations: {len(validations)}")
|
||||
logger.info("Report generated: %s", output_path)
|
||||
logger.info("Total validations: %d", len(validations))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -18,7 +18,8 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
|
||||
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES, CLASS_NAMES
|
||||
from shared.bbox import expand_bbox
|
||||
from .annotation_generator import YOLOAnnotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -156,7 +157,7 @@ class DBYOLODataset:
|
||||
|
||||
# Split items for this split
|
||||
instance.items = instance._split_dataset_from_cache()
|
||||
print(f"Split '{split}': {len(instance.items)} items")
|
||||
logger.info("Split '%s': %d items", split, len(instance.items))
|
||||
|
||||
return instance
|
||||
|
||||
@@ -165,7 +166,7 @@ class DBYOLODataset:
|
||||
# Find all document directories
|
||||
temp_dir = self.images_dir / 'temp'
|
||||
if not temp_dir.exists():
|
||||
print(f"Temp directory not found: {temp_dir}")
|
||||
logger.warning("Temp directory not found: %s", temp_dir)
|
||||
return
|
||||
|
||||
# Collect all document IDs with images
|
||||
@@ -182,13 +183,13 @@ class DBYOLODataset:
|
||||
if images:
|
||||
doc_image_map[doc_dir.name] = sorted(images)
|
||||
|
||||
print(f"Found {len(doc_image_map)} documents with images")
|
||||
logger.info("Found %d documents with images", len(doc_image_map))
|
||||
|
||||
# Query database for all document labels
|
||||
doc_ids = list(doc_image_map.keys())
|
||||
doc_labels = self._load_labels_from_db(doc_ids)
|
||||
|
||||
print(f"Loaded labels for {len(doc_labels)} documents from database")
|
||||
logger.info("Loaded labels for %d documents from database", len(doc_labels))
|
||||
|
||||
# Build dataset items
|
||||
all_items: list[DatasetItem] = []
|
||||
@@ -227,19 +228,19 @@ class DBYOLODataset:
|
||||
else:
|
||||
skipped_no_labels += 1
|
||||
|
||||
print(f"Total images found: {total_images}")
|
||||
print(f"Images with labels: {len(all_items)}")
|
||||
logger.info("Total images found: %d", total_images)
|
||||
logger.info("Images with labels: %d", len(all_items))
|
||||
if skipped_no_db_record > 0:
|
||||
print(f"Skipped {skipped_no_db_record} images (document not in database)")
|
||||
logger.info("Skipped %d images (document not in database)", skipped_no_db_record)
|
||||
if skipped_no_labels > 0:
|
||||
print(f"Skipped {skipped_no_labels} images (no labels for page)")
|
||||
logger.info("Skipped %d images (no labels for page)", skipped_no_labels)
|
||||
|
||||
# Cache all items for sharing with other splits
|
||||
self._all_items = all_items
|
||||
|
||||
# Split dataset
|
||||
self.items, self._doc_ids_ordered = self._split_dataset(all_items)
|
||||
print(f"Split '{self.split}': {len(self.items)} items")
|
||||
logger.info("Split '%s': %d items", self.split, len(self.items))
|
||||
|
||||
def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]:
|
||||
"""
|
||||
@@ -374,7 +375,7 @@ class DBYOLODataset:
|
||||
|
||||
if has_csv_splits:
|
||||
# Use CSV-defined splits
|
||||
print("Using CSV-defined split field for train/val/test assignment")
|
||||
logger.info("Using CSV-defined split field for train/val/test assignment")
|
||||
|
||||
# Map split values: 'train' -> train, 'test' -> test, None -> train (fallback)
|
||||
# 'val' is taken from train set using val_ratio
|
||||
@@ -411,11 +412,11 @@ class DBYOLODataset:
|
||||
# Apply limit if specified
|
||||
if self.limit is not None and self.limit < len(split_doc_ids):
|
||||
split_doc_ids = split_doc_ids[:self.limit]
|
||||
print(f"Limited to {self.limit} documents")
|
||||
logger.info("Limited to %d documents", self.limit)
|
||||
|
||||
else:
|
||||
# Fall back to random splitting (original behavior)
|
||||
print("No CSV split field found, using random splitting")
|
||||
logger.info("No CSV split field found, using random splitting")
|
||||
|
||||
random.seed(self.seed)
|
||||
random.shuffle(doc_ids)
|
||||
@@ -423,7 +424,7 @@ class DBYOLODataset:
|
||||
# Apply limit if specified (before splitting)
|
||||
if self.limit is not None and self.limit < len(doc_ids):
|
||||
doc_ids = doc_ids[:self.limit]
|
||||
print(f"Limited to {self.limit} documents")
|
||||
logger.info("Limited to %d documents", self.limit)
|
||||
|
||||
# Calculate split indices
|
||||
n_total = len(doc_ids)
|
||||
@@ -549,6 +550,8 @@ class DBYOLODataset:
|
||||
"""
|
||||
Convert annotations to normalized YOLO format.
|
||||
|
||||
Uses field-specific bbox expansion strategies via expand_bbox.
|
||||
|
||||
Args:
|
||||
annotations: List of annotations
|
||||
img_width: Actual image width in pixels
|
||||
@@ -568,26 +571,43 @@ class DBYOLODataset:
|
||||
|
||||
labels = []
|
||||
for ann in annotations:
|
||||
# Convert to pixels (if needed)
|
||||
x_center_px = ann.x_center * scale
|
||||
y_center_px = ann.y_center * scale
|
||||
width_px = ann.width * scale
|
||||
height_px = ann.height * scale
|
||||
# Convert center+size to corner coords in PDF points
|
||||
half_w = ann.width / 2
|
||||
half_h = ann.height / 2
|
||||
x0_pdf = ann.x_center - half_w
|
||||
y0_pdf = ann.y_center - half_h
|
||||
x1_pdf = ann.x_center + half_w
|
||||
y1_pdf = ann.y_center + half_h
|
||||
|
||||
# Add padding
|
||||
pad = self.bbox_padding_px
|
||||
width_px += 2 * pad
|
||||
height_px += 2 * pad
|
||||
# Convert to pixels
|
||||
x0_px = x0_pdf * scale
|
||||
y0_px = y0_pdf * scale
|
||||
x1_px = x1_pdf * scale
|
||||
y1_px = y1_pdf * scale
|
||||
|
||||
# Get class name for field-specific expansion
|
||||
class_name = CLASS_NAMES[ann.class_id]
|
||||
|
||||
# Apply field-specific bbox expansion
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0_px, y0_px, x1_px, y1_px),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Ensure minimum height
|
||||
height_px = y1 - y0
|
||||
if height_px < self.min_bbox_height_px:
|
||||
height_px = self.min_bbox_height_px
|
||||
extra = (self.min_bbox_height_px - height_px) / 2
|
||||
y0 = max(0, int(y0 - extra))
|
||||
y1 = min(img_height, int(y1 + extra))
|
||||
|
||||
# Normalize to 0-1
|
||||
x_center = x_center_px / img_width
|
||||
y_center = y_center_px / img_height
|
||||
width = width_px / img_width
|
||||
height = height_px / img_height
|
||||
# Convert to YOLO format (normalized center + size)
|
||||
x_center = (x0 + x1) / 2 / img_width
|
||||
y_center = (y0 + y1) / 2 / img_height
|
||||
width = (x1 - x0) / img_width
|
||||
height = (y1 - y0) / img_height
|
||||
|
||||
# Clamp to valid range
|
||||
x_center = max(0, min(1, x_center))
|
||||
@@ -675,7 +695,7 @@ class DBYOLODataset:
|
||||
|
||||
count += 1
|
||||
|
||||
print(f"Exported {count} items to {output_dir / split_name}")
|
||||
logger.info("Exported %d items to %s", count, output_dir / split_name)
|
||||
return count
|
||||
|
||||
|
||||
@@ -706,7 +726,7 @@ def create_datasets(
|
||||
Dict with 'train', 'val', 'test' datasets
|
||||
"""
|
||||
# Create first dataset which loads all data
|
||||
print("Loading dataset (this may take a few minutes for large datasets)...")
|
||||
logger.info("Loading dataset (this may take a few minutes for large datasets)...")
|
||||
first_dataset = DBYOLODataset(
|
||||
images_dir=images_dir,
|
||||
db=db,
|
||||
|
||||
Reference in New Issue
Block a user