This commit is contained in:
Yaojia Wang
2026-02-07 13:56:00 +01:00
parent 0990239e9c
commit f1a7bfe6b7
16 changed files with 1121 additions and 307 deletions

View File

@@ -109,7 +109,9 @@
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")", "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")", "Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")",
"Skill(tdd)", "Skill(tdd)",
"Skill(tdd:*)" "Skill(tdd:*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m training.cli.train --model runs/train/invoice_fields/weights/best.pt --device 0 --epochs 100\")",
"Bash(git commit -m \"$\\(cat <<''EOF''\nfeat: add field-specific bbox expansion strategies for YOLO training\n\nImplement center-point based bbox scaling with directional compensation\nto capture field labels that typically appear above or to the left of\nfield values. This improves YOLO training data quality by including\ncontextual information around field values.\n\nKey changes:\n- Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function\n- Define field-specific strategies \\(ocr_number, bankgiro, invoice_date, etc.\\)\n- Support manual_mode for minimal padding \\(no scaling\\)\n- Integrate expand_bbox into AnnotationGenerator\n- Add FIELD_TO_CLASS mapping for field_name to class_name lookup\n- Comprehensive tests with 100% coverage \\(45 tests\\)\n\nCo-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>\nEOF\n\\)\")"
], ],
"deny": [], "deny": [],
"ask": [], "ask": [],

View File

@@ -7,10 +7,14 @@ Runs inference on new PDFs to extract invoice data.
import argparse import argparse
import json import json
import logging
import sys import sys
from pathlib import Path from pathlib import Path
from shared.config import DEFAULT_DPI from shared.config import DEFAULT_DPI
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
def main(): def main():
@@ -66,10 +70,13 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# Configure logging for CLI
setup_cli_logging()
# Validate model # Validate model
model_path = Path(args.model) model_path = Path(args.model)
if not model_path.exists(): if not model_path.exists():
print(f"Error: Model not found: {model_path}", file=sys.stderr) logger.error("Model not found: %s", model_path)
sys.exit(1) sys.exit(1)
# Get input files # Get input files
@@ -79,16 +86,16 @@ def main():
elif input_path.is_dir(): elif input_path.is_dir():
pdf_files = list(input_path.glob('*.pdf')) pdf_files = list(input_path.glob('*.pdf'))
else: else:
print(f"Error: Input not found: {input_path}", file=sys.stderr) logger.error("Input not found: %s", input_path)
sys.exit(1) sys.exit(1)
if not pdf_files: if not pdf_files:
print("Error: No PDF files found", file=sys.stderr) logger.error("No PDF files found")
sys.exit(1) sys.exit(1)
if args.verbose: if args.verbose:
print(f"Processing {len(pdf_files)} PDF file(s)") logger.info("Processing %d PDF file(s)", len(pdf_files))
print(f"Model: {model_path}") logger.info("Model: %s", model_path)
from backend.pipeline import InferencePipeline from backend.pipeline import InferencePipeline
@@ -107,18 +114,18 @@ def main():
for pdf_path in pdf_files: for pdf_path in pdf_files:
if args.verbose: if args.verbose:
print(f"Processing: {pdf_path.name}") logger.info("Processing: %s", pdf_path.name)
result = pipeline.process_pdf(pdf_path) result = pipeline.process_pdf(pdf_path)
results.append(result.to_json()) results.append(result.to_json())
if args.verbose: if args.verbose:
print(f" Success: {result.success}") logger.info(" Success: %s", result.success)
print(f" Fields: {len(result.fields)}") logger.info(" Fields: %d", len(result.fields))
if result.fallback_used: if result.fallback_used:
print(f" Fallback used: Yes") logger.info(" Fallback used: Yes")
if result.errors: if result.errors:
print(f" Errors: {result.errors}") logger.info(" Errors: %s", result.errors)
# Output results # Output results
if len(results) == 1: if len(results) == 1:
@@ -132,9 +139,11 @@ def main():
with open(args.output, 'w', encoding='utf-8') as f: with open(args.output, 'w', encoding='utf-8') as f:
f.write(json_output) f.write(json_output)
if args.verbose: if args.verbose:
print(f"\nResults written to: {args.output}") logger.info("Results written to: %s", args.output)
else: else:
print(json_output) # Output JSON to stdout (not logged)
sys.stdout.write(json_output)
sys.stdout.write('\n')
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -417,7 +417,12 @@ class InferencePipeline:
result.errors.append(f"Business feature extraction error: {error_detail}") result.errors.append(f"Business feature extraction error: {error_detail}")
def _merge_fields(self, result: InferenceResult) -> None: def _merge_fields(self, result: InferenceResult) -> None:
"""Merge extracted fields, keeping highest confidence for each field.""" """Merge extracted fields, keeping best candidate for each field.
Selection priority:
1. Prefer candidates without validation errors
2. Among equal validity, prefer higher confidence
"""
field_candidates: dict[str, list[ExtractedField]] = {} field_candidates: dict[str, list[ExtractedField]] = {}
for extracted in result.extracted_fields: for extracted in result.extracted_fields:
@@ -430,7 +435,12 @@ class InferencePipeline:
# Select best candidate for each field # Select best candidate for each field
for field_name, candidates in field_candidates.items(): for field_name, candidates in field_candidates.items():
best = max(candidates, key=lambda x: x.confidence) # Sort by: (no validation error, confidence) - descending
# This prefers candidates without errors, then by confidence
best = max(
candidates,
key=lambda x: (x.validation_error is None, x.confidence)
)
result.fields[field_name] = best.normalized_value result.fields[field_name] = best.normalized_value
result.confidence[field_name] = best.confidence result.confidence[field_name] = best.confidence
# Store bbox for each field (useful for payment_line and other fields) # Store bbox for each field (useful for payment_line and other fields)

View File

@@ -7,6 +7,7 @@ the autolabel results to identify potential errors.
import json import json
import base64 import base64
import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
@@ -14,6 +15,8 @@ from dataclasses import dataclass, asdict
from datetime import datetime from datetime import datetime
import psycopg2 import psycopg2
logger = logging.getLogger(__name__)
from psycopg2.extras import execute_values from psycopg2.extras import execute_values
from shared.config import DEFAULT_DPI from shared.config import DEFAULT_DPI
@@ -648,7 +651,7 @@ Return ONLY the JSON object, no other text."""
docs = self.get_documents_with_failed_matches(limit=limit) docs = self.get_documents_with_failed_matches(limit=limit)
if verbose: if verbose:
print(f"Found {len(docs)} documents with failed matches to validate") logger.info("Found %d documents with failed matches to validate", len(docs))
results = [] results = []
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
@@ -656,16 +659,16 @@ Return ONLY the JSON object, no other text."""
if verbose: if verbose:
failed_fields = [f['field'] for f in doc['failed_fields']] failed_fields = [f['field'] for f in doc['failed_fields']]
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})") logger.info("[%d/%d] Validating %s... (failed: %s)", i+1, len(docs), doc_id[:8], ', '.join(failed_fields))
result = self.validate_document(doc_id, provider, model) result = self.validate_document(doc_id, provider, model)
results.append(result) results.append(result)
if verbose: if verbose:
if result.error: if result.error:
print(f" ERROR: {result.error}") logger.error(" ERROR: %s", result.error)
else: else:
print(f" OK ({result.processing_time_ms:.0f}ms)") logger.info(" OK (%.0fms)", result.processing_time_ms)
return results return results

View File

@@ -11,6 +11,7 @@ from backend.web.schemas.admin import (
ExportResponse, ExportResponse,
) )
from backend.web.schemas.common import ErrorResponse from backend.web.schemas.common import ErrorResponse
from shared.bbox import expand_bbox
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -102,12 +103,52 @@ def register_export_routes(router: APIRouter) -> None:
dst_image.write_bytes(image_content) dst_image.write_bytes(image_content)
total_images += 1 total_images += 1
# Get image dimensions for bbox expansion
img_dims = storage.get_admin_image_dimensions(doc_id, page_num)
if img_dims is None:
# Fall back to standard A4 at 300 DPI if dimensions unavailable
img_width, img_height = 2480, 3508
else:
img_width, img_height = img_dims
label_name = f"{doc.document_id}_page{page_num}.txt" label_name = f"{doc.document_id}_page{page_num}.txt"
label_path = export_dir / "labels" / split / label_name label_path = export_dir / "labels" / split / label_name
with open(label_path, "w") as f: with open(label_path, "w") as f:
for ann in page_annotations: for ann in page_annotations:
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n" # Convert normalized coords to pixel coords
half_w = (ann.width * img_width) / 2
half_h = (ann.height * img_height) / 2
x0 = ann.x_center * img_width - half_w
y0 = ann.y_center * img_height - half_h
x1 = ann.x_center * img_width + half_w
y1 = ann.y_center * img_height + half_h
# Use manual_mode for manual/imported annotations
manual_mode = ann.source in ("manual", "imported")
# Apply field-specific bbox expansion
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=ann.class_name,
manual_mode=manual_mode,
)
# Convert back to normalized YOLO format
new_x_center = (ex0 + ex1) / 2 / img_width
new_y_center = (ey0 + ey1) / 2 / img_height
new_width = (ex1 - ex0) / img_width
new_height = (ey1 - ey0) / img_height
# Clamp to valid range
new_x_center = max(0, min(1, new_x_center))
new_y_center = max(0, min(1, new_y_center))
new_width = max(0, min(1, new_width))
new_height = max(0, min(1, new_height))
line = f"{ann.class_id} {new_x_center:.6f} {new_y_center:.6f} {new_width:.6f} {new_height:.6f}\n"
f.write(line) f.write(line)
total_annotations += 1 total_annotations += 1

View File

@@ -0,0 +1,62 @@
"""
Logging Configuration
Provides consistent logging setup for CLI tools and modules.
"""
import logging
import sys
from typing import Optional
def setup_cli_logging(
level: int = logging.INFO,
name: Optional[str] = None,
format_string: Optional[str] = None,
) -> logging.Logger:
"""
Configure logging for CLI applications.
Args:
level: Logging level (default: INFO)
name: Logger name (default: root logger)
format_string: Custom format string (default: simple CLI format)
Returns:
Configured logger instance
"""
if format_string is None:
format_string = "%(message)s"
# Configure root logger or specific logger
logger = logging.getLogger(name)
logger.setLevel(level)
# Remove existing handlers to avoid duplicates
logger.handlers.clear()
# Create console handler
handler = logging.StreamHandler(sys.stdout)
handler.setLevel(level)
handler.setFormatter(logging.Formatter(format_string))
logger.addHandler(handler)
return logger
def setup_verbose_logging(
level: int = logging.DEBUG,
name: Optional[str] = None,
) -> logging.Logger:
"""
Configure verbose logging with timestamps and module info.
Args:
level: Logging level (default: DEBUG)
name: Logger name (default: root logger)
Returns:
Configured logger instance
"""
format_string = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
return setup_cli_logging(level=level, name=name, format_string=format_string)

View File

@@ -9,6 +9,7 @@ Now reads from PostgreSQL database instead of JSONL files.
import argparse import argparse
import csv import csv
import json import json
import logging
import sys import sys
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -16,6 +17,9 @@ from pathlib import Path
from typing import Optional from typing import Optional
from shared.config import get_db_connection_string from shared.config import get_db_connection_string
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
from shared.normalize import normalize_field from shared.normalize import normalize_field
from shared.matcher import FieldMatcher from shared.matcher import FieldMatcher
@@ -104,7 +108,7 @@ class LabelAnalyzer:
for row in reader: for row in reader:
doc_id = row['DocumentId'] doc_id = row['DocumentId']
self.csv_data[doc_id] = row self.csv_data[doc_id] = row
print(f"Loaded {len(self.csv_data)} records from CSV") logger.info("Loaded %d records from CSV", len(self.csv_data))
def load_labels(self): def load_labels(self):
"""Load all label files from dataset.""" """Load all label files from dataset."""
@@ -150,12 +154,12 @@ class LabelAnalyzer:
for doc in self.label_data.values() for doc in self.label_data.values()
for labels in doc['pages'].values() for labels in doc['pages'].values()
) )
print(f"Loaded labels for {total_docs} documents ({total_labels} total labels)") logger.info("Loaded labels for %d documents (%d total labels)", total_docs, total_labels)
def load_report(self): def load_report(self):
"""Load autolabel report from database.""" """Load autolabel report from database."""
if not self.db: if not self.db:
print("Database not configured, skipping report loading") logger.info("Database not configured, skipping report loading")
return return
# Get document IDs from CSV to query # Get document IDs from CSV to query
@@ -175,7 +179,7 @@ class LabelAnalyzer:
self.report_data[doc_id] = doc self.report_data[doc_id] = doc
loaded += 1 loaded += 1
print(f"Loaded {loaded} autolabel reports from database") logger.info("Loaded %d autolabel reports from database", loaded)
def analyze_document(self, doc_id: str, skip_missing_pdf: bool = True) -> Optional[DocumentAnalysis]: def analyze_document(self, doc_id: str, skip_missing_pdf: bool = True) -> Optional[DocumentAnalysis]:
"""Analyze a single document.""" """Analyze a single document."""
@@ -373,7 +377,7 @@ class LabelAnalyzer:
break break
if skipped > 0: if skipped > 0:
print(f"Skipped {skipped} documents without PDF files") logger.info("Skipped %d documents without PDF files", skipped)
return results return results
@@ -447,7 +451,7 @@ class LabelAnalyzer:
with open(output, 'w', encoding='utf-8') as f: with open(output, 'w', encoding='utf-8') as f:
json.dump(report, f, indent=2, ensure_ascii=False) json.dump(report, f, indent=2, ensure_ascii=False)
print(f"\nReport saved to: {output}") logger.info("Report saved to: %s", output)
return report return report
@@ -456,52 +460,52 @@ def print_summary(report: dict):
"""Print summary to console.""" """Print summary to console."""
summary = report['summary'] summary = report['summary']
print("\n" + "=" * 60) logger.info("=" * 60)
print("LABEL ANALYSIS SUMMARY") logger.info("LABEL ANALYSIS SUMMARY")
print("=" * 60) logger.info("=" * 60)
print(f"\nDocuments:") logger.info("Documents:")
print(f" Total: {summary['total_documents']}") logger.info(" Total: %d", summary['total_documents'])
print(f" With issues: {summary['documents_with_issues']} ({summary['issue_rate']})") logger.info(" With issues: %d (%s)", summary['documents_with_issues'], summary['issue_rate'])
print(f"\nFields:") logger.info("Fields:")
print(f" Expected: {summary['total_expected_fields']}") logger.info(" Expected: %d", summary['total_expected_fields'])
print(f" Labeled: {summary['total_labeled_fields']} ({summary['label_coverage']})") logger.info(" Labeled: %d (%s)", summary['total_labeled_fields'], summary['label_coverage'])
print(f" Missing: {summary['missing_labels']}") logger.info(" Missing: %d", summary['missing_labels'])
print(f" Extra: {summary['extra_labels']}") logger.info(" Extra: %d", summary['extra_labels'])
print(f"\nFailure Reasons:") logger.info("Failure Reasons:")
for reason, count in sorted(report['failure_reasons'].items(), key=lambda x: -x[1]): for reason, count in sorted(report['failure_reasons'].items(), key=lambda x: -x[1]):
print(f" {reason}: {count}") logger.info(" %s: %d", reason, count)
print(f"\nFailures by Field:") logger.info("Failures by Field:")
for field, reasons in report['failures_by_field'].items(): for field, reasons in report['failures_by_field'].items():
total = sum(reasons.values()) total = sum(reasons.values())
print(f" {field}: {total}") logger.info(" %s: %d", field, total)
for reason, count in sorted(reasons.items(), key=lambda x: -x[1]): for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
print(f" - {reason}: {count}") logger.info(" - %s: %d", reason, count)
# Show sample issues # Show sample issues
if report['issues']: if report['issues']:
print(f"\n" + "-" * 60) logger.info("-" * 60)
print("SAMPLE ISSUES (first 10)") logger.info("SAMPLE ISSUES (first 10)")
print("-" * 60) logger.info("-" * 60)
for issue in report['issues'][:10]: for issue in report['issues'][:10]:
print(f"\n[{issue['doc_id']}] {issue['field']}") logger.info("[%s] %s", issue['doc_id'], issue['field'])
print(f" CSV value: {issue['csv_value']}") logger.info(" CSV value: %s", issue['csv_value'])
print(f" Reason: {issue['reason']}") logger.info(" Reason: %s", issue['reason'])
if issue.get('details'): if issue.get('details'):
details = issue['details'] details = issue['details']
if details.get('normalized_candidates'): if details.get('normalized_candidates'):
print(f" Candidates: {details['normalized_candidates'][:5]}") logger.info(" Candidates: %s", details['normalized_candidates'][:5])
if details.get('pdf_tokens_sample'): if details.get('pdf_tokens_sample'):
print(f" PDF samples: {details['pdf_tokens_sample'][:5]}") logger.info(" PDF samples: %s", details['pdf_tokens_sample'][:5])
if details.get('potential_matches'): if details.get('potential_matches'):
print(f" Potential matches:") logger.info(" Potential matches:")
for pm in details['potential_matches'][:3]: for pm in details['potential_matches'][:3]:
print(f" - token='{pm['token']}' matches candidate='{pm['candidate']}'") logger.info(" - token='%s' matches candidate='%s'", pm['token'], pm['candidate'])
def main(): def main():
@@ -551,6 +555,9 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# Configure logging for CLI
setup_cli_logging()
analyzer = LabelAnalyzer( analyzer = LabelAnalyzer(
csv_path=args.csv, csv_path=args.csv,
pdf_dir=args.pdf_dir, pdf_dir=args.pdf_dir,
@@ -566,30 +573,30 @@ def main():
analysis = analyzer.analyze_document(args.single) analysis = analyzer.analyze_document(args.single)
print(f"\n{'=' * 60}") logger.info("=" * 60)
print(f"Document: {analysis.doc_id}") logger.info("Document: %s", analysis.doc_id)
print(f"{'=' * 60}") logger.info("=" * 60)
print(f"PDF exists: {analysis.pdf_exists}") logger.info("PDF exists: %s", analysis.pdf_exists)
print(f"PDF type: {analysis.pdf_type}") logger.info("PDF type: %s", analysis.pdf_type)
print(f"Pages: {analysis.total_pages}") logger.info("Pages: %d", analysis.total_pages)
print(f"\nFields (CSV: {analysis.csv_fields_count}, Labeled: {analysis.labeled_fields_count}):") logger.info("Fields (CSV: %d, Labeled: %d):", analysis.csv_fields_count, analysis.labeled_fields_count)
for f in analysis.fields: for f in analysis.fields:
status = "" if f.labeled else ("" if f.expected else "-") status = "[OK]" if f.labeled else ("[FAIL]" if f.expected else "[-]")
value_str = f.csv_value[:30] if f.csv_value else "(empty)" value_str = f.csv_value[:30] if f.csv_value else "(empty)"
print(f" [{status}] {f.field_name}: {value_str}") logger.info(" %s %s: %s", status, f.field_name, value_str)
if f.failure_reason: if f.failure_reason:
print(f" Reason: {f.failure_reason}") logger.info(" Reason: %s", f.failure_reason)
if f.details.get('normalized_candidates'): if f.details.get('normalized_candidates'):
print(f" Candidates: {f.details['normalized_candidates']}") logger.info(" Candidates: %s", f.details['normalized_candidates'])
if f.details.get('potential_matches'): if f.details.get('potential_matches'):
print(f" Potential matches in PDF:") logger.info(" Potential matches in PDF:")
for pm in f.details['potential_matches'][:3]: for pm in f.details['potential_matches'][:3]:
print(f" - '{pm['token']}'") logger.info(" - '%s'", pm['token'])
else: else:
# Full analysis # Full analysis
print("Running label analysis...") logger.info("Running label analysis...")
results = analyzer.run_analysis(limit=args.limit) results = analyzer.run_analysis(limit=args.limit)
report = analyzer.generate_report(results, args.output, verbose=args.verbose) report = analyzer.generate_report(results, args.output, verbose=args.verbose)
print_summary(report) print_summary(report)

View File

@@ -7,11 +7,15 @@ Generates statistics and insights from database or autolabel_report.jsonl
import argparse import argparse
import json import json
import logging
import sys import sys
from collections import defaultdict from collections import defaultdict
from pathlib import Path from pathlib import Path
from shared.config import get_db_connection_string from shared.config import get_db_connection_string
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
def load_reports_from_db() -> dict: def load_reports_from_db() -> dict:
@@ -147,9 +151,9 @@ def load_reports_from_file(report_path: str) -> list[dict]:
if not report_files: if not report_files:
return [] return []
print(f"Reading {len(report_files)} report file(s):") logger.info("Reading %d report file(s):", len(report_files))
for f in report_files: for f in report_files:
print(f" - {f.name}") logger.info(" - %s", f.name)
reports = [] reports = []
for report_file in report_files: for report_file in report_files:
@@ -231,55 +235,55 @@ def analyze_reports(reports: list[dict]) -> dict:
def print_report(stats: dict, verbose: bool = False): def print_report(stats: dict, verbose: bool = False):
"""Print analysis report.""" """Print analysis report."""
print("\n" + "=" * 60) logger.info("=" * 60)
print("AUTO-LABEL REPORT ANALYSIS") logger.info("AUTO-LABEL REPORT ANALYSIS")
print("=" * 60) logger.info("=" * 60)
# Overall stats # Overall stats
print(f"\n{'OVERALL STATISTICS':^60}") logger.info("%s", "OVERALL STATISTICS".center(60))
print("-" * 60) logger.info("-" * 60)
total = stats['total'] total = stats['total']
successful = stats['successful'] successful = stats['successful']
failed = stats['failed'] failed = stats['failed']
success_rate = successful / total * 100 if total > 0 else 0 success_rate = successful / total * 100 if total > 0 else 0
print(f"Total documents: {total:>8}") logger.info("Total documents: %8d", total)
print(f"Successful: {successful:>8} ({success_rate:.1f}%)") logger.info("Successful: %8d (%.1f%%)", successful, success_rate)
print(f"Failed: {failed:>8} ({100-success_rate:.1f}%)") logger.info("Failed: %8d (%.1f%%)", failed, 100-success_rate)
# Processing time # Processing time
if 'processing_time_stats' in stats: if 'processing_time_stats' in stats:
pts = stats['processing_time_stats'] pts = stats['processing_time_stats']
print(f"\nProcessing time (ms):") logger.info("Processing time (ms):")
print(f" Average: {pts['avg_ms']:>8.1f}") logger.info(" Average: %8.1f", pts['avg_ms'])
print(f" Min: {pts['min_ms']:>8.1f}") logger.info(" Min: %8.1f", pts['min_ms'])
print(f" Max: {pts['max_ms']:>8.1f}") logger.info(" Max: %8.1f", pts['max_ms'])
elif stats.get('processing_times'): elif stats.get('processing_times'):
times = stats['processing_times'] times = stats['processing_times']
avg_time = sum(times) / len(times) avg_time = sum(times) / len(times)
min_time = min(times) min_time = min(times)
max_time = max(times) max_time = max(times)
print(f"\nProcessing time (ms):") logger.info("Processing time (ms):")
print(f" Average: {avg_time:>8.1f}") logger.info(" Average: %8.1f", avg_time)
print(f" Min: {min_time:>8.1f}") logger.info(" Min: %8.1f", min_time)
print(f" Max: {max_time:>8.1f}") logger.info(" Max: %8.1f", max_time)
# By PDF type # By PDF type
print(f"\n{'BY PDF TYPE':^60}") logger.info("%s", "BY PDF TYPE".center(60))
print("-" * 60) logger.info("-" * 60)
print(f"{'Type':<15} {'Total':>10} {'Success':>10} {'Rate':>10}") logger.info("%-15s %10s %10s %10s", 'Type', 'Total', 'Success', 'Rate')
print("-" * 60) logger.info("-" * 60)
for pdf_type, type_stats in sorted(stats['by_pdf_type'].items()): for pdf_type, type_stats in sorted(stats['by_pdf_type'].items()):
type_total = type_stats['total'] type_total = type_stats['total']
type_success = type_stats['successful'] type_success = type_stats['successful']
type_rate = type_success / type_total * 100 if type_total > 0 else 0 type_rate = type_success / type_total * 100 if type_total > 0 else 0
print(f"{pdf_type:<15} {type_total:>10} {type_success:>10} {type_rate:>9.1f}%") logger.info("%-15s %10d %10d %9.1f%%", pdf_type, type_total, type_success, type_rate)
# By field # By field
print(f"\n{'FIELD MATCH STATISTICS':^60}") logger.info("%s", "FIELD MATCH STATISTICS".center(60))
print("-" * 60) logger.info("-" * 60)
print(f"{'Field':<18} {'Total':>7} {'Match':>7} {'Rate':>7} {'Exact':>7} {'Flex':>7} {'AvgScore':>8}") logger.info("%-18s %7s %7s %7s %7s %7s %8s", 'Field', 'Total', 'Match', 'Rate', 'Exact', 'Flex', 'AvgScore')
print("-" * 60) logger.info("-" * 60)
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']: for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
if field_name not in stats['by_field']: if field_name not in stats['by_field']:
@@ -299,16 +303,16 @@ def print_report(stats: dict, verbose: bool = False):
else: else:
avg_score = 0 avg_score = 0
print(f"{field_name:<18} {total:>7} {matched:>7} {rate:>6.1f}% {exact:>7} {flex:>7} {avg_score:>8.3f}") logger.info("%-18s %7d %7d %6.1f%% %7d %7d %8.3f", field_name, total, matched, rate, exact, flex, avg_score)
# Field match by PDF type # Field match by PDF type
print(f"\n{'FIELD MATCH BY PDF TYPE':^60}") logger.info("%s", "FIELD MATCH BY PDF TYPE".center(60))
print("-" * 60) logger.info("-" * 60)
for pdf_type in sorted(stats['by_pdf_type'].keys()): for pdf_type in sorted(stats['by_pdf_type'].keys()):
print(f"\n[{pdf_type.upper()}]") logger.info("[%s]", pdf_type.upper())
print(f"{'Field':<18} {'Total':>10} {'Matched':>10} {'Rate':>10}") logger.info("%-18s %10s %10s %10s", 'Field', 'Total', 'Matched', 'Rate')
print("-" * 50) logger.info("-" * 50)
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']: for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
if field_name not in stats['by_field']: if field_name not in stats['by_field']:
@@ -317,16 +321,16 @@ def print_report(stats: dict, verbose: bool = False):
total = type_stats['total'] total = type_stats['total']
matched = type_stats['matched'] matched = type_stats['matched']
rate = matched / total * 100 if total > 0 else 0 rate = matched / total * 100 if total > 0 else 0
print(f"{field_name:<18} {total:>10} {matched:>10} {rate:>9.1f}%") logger.info("%-18s %10d %10d %9.1f%%", field_name, total, matched, rate)
# Errors # Errors
if stats.get('errors') and verbose: if stats.get('errors') and verbose:
print(f"\n{'ERRORS':^60}") logger.info("%s", "ERRORS".center(60))
print("-" * 60) logger.info("-" * 60)
for error, count in sorted(stats['errors'].items(), key=lambda x: -x[1])[:20]: for error, count in sorted(stats['errors'].items(), key=lambda x: -x[1])[:20]:
print(f"{count:>5}x {error[:50]}") logger.info("%5dx %s", count, error[:50])
print("\n" + "=" * 60) logger.info("=" * 60)
def export_json(stats: dict, output_path: str): def export_json(stats: dict, output_path: str):
@@ -372,7 +376,7 @@ def export_json(stats: dict, output_path: str):
with open(output_path, 'w', encoding='utf-8') as f: with open(output_path, 'w', encoding='utf-8') as f:
json.dump(export_data, f, indent=2, ensure_ascii=False) json.dump(export_data, f, indent=2, ensure_ascii=False)
print(f"\nStatistics exported to: {output_path}") logger.info("Statistics exported to: %s", output_path)
def main(): def main():
@@ -401,25 +405,28 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# Configure logging for CLI
setup_cli_logging()
# Decide source # Decide source
use_db = not args.from_file and args.report is None use_db = not args.from_file and args.report is None
if use_db: if use_db:
print("Loading statistics from database...") logger.info("Loading statistics from database...")
stats = load_reports_from_db() stats = load_reports_from_db()
print(f"Loaded stats for {stats['total']} documents") logger.info("Loaded stats for %d documents", stats['total'])
else: else:
report_path = args.report or 'reports/autolabel_report.jsonl' report_path = args.report or 'reports/autolabel_report.jsonl'
path = Path(report_path) path = Path(report_path)
# Check if file exists (handle glob patterns) # Check if file exists (handle glob patterns)
if '*' not in str(path) and '?' not in str(path) and not path.exists(): if '*' not in str(path) and '?' not in str(path) and not path.exists():
print(f"Error: Report file not found: {path}") logger.error("Report file not found: %s", path)
return 1 return 1
print(f"Loading reports from: {report_path}") logger.info("Loading reports from: %s", report_path)
reports = load_reports_from_file(report_path) reports = load_reports_from_file(report_path)
print(f"Loaded {len(reports)} reports") logger.info("Loaded %d reports", len(reports))
stats = analyze_reports(reports) stats = analyze_reports(reports)
print_report(stats, verbose=args.verbose) print_report(stats, verbose=args.verbose)

View File

@@ -6,6 +6,7 @@ Generates YOLO training data from PDFs and structured CSV data.
""" """
import argparse import argparse
import logging
import sys import sys
import time import time
import os import os
@@ -17,6 +18,10 @@ from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
import multiprocessing import multiprocessing
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
# Global flag for graceful shutdown # Global flag for graceful shutdown
_shutdown_requested = False _shutdown_requested = False
@@ -25,8 +30,8 @@ def _signal_handler(signum, frame):
"""Handle interrupt signals for graceful shutdown.""" """Handle interrupt signals for graceful shutdown."""
global _shutdown_requested global _shutdown_requested
_shutdown_requested = True _shutdown_requested = True
print("\n\nShutdown requested. Finishing current batch and saving progress...") logger.warning("Shutdown requested. Finishing current batch and saving progress...")
print("(Press Ctrl+C again to force quit)\n") logger.warning("(Press Ctrl+C again to force quit)")
# Windows compatibility: use 'spawn' method for multiprocessing # Windows compatibility: use 'spawn' method for multiprocessing
# This is required on Windows and is also safer for libraries like PaddleOCR # This is required on Windows and is also safer for libraries like PaddleOCR
@@ -350,11 +355,14 @@ def main():
if ',' in csv_input and '*' not in csv_input: if ',' in csv_input and '*' not in csv_input:
csv_input = [p.strip() for p in csv_input.split(',')] csv_input = [p.strip() for p in csv_input.split(',')]
# Configure logging for CLI
setup_cli_logging()
# Get list of CSV files (don't load all data at once) # Get list of CSV files (don't load all data at once)
temp_loader = CSVLoader(csv_input, args.pdf_dir) temp_loader = CSVLoader(csv_input, args.pdf_dir)
csv_files = temp_loader.csv_paths csv_files = temp_loader.csv_paths
pdf_dir = temp_loader.pdf_dir pdf_dir = temp_loader.pdf_dir
print(f"Found {len(csv_files)} CSV file(s) to process") logger.info("Found %d CSV file(s) to process", len(csv_files))
# Setup output directories # Setup output directories
output_dir = Path(args.output) output_dir = Path(args.output)
@@ -371,7 +379,7 @@ def main():
db = DocumentDB() db = DocumentDB()
db.connect() db.connect()
db.create_tables() # Ensure tables exist db.create_tables() # Ensure tables exist
print("Connected to database for status checking") logger.info("Connected to database for status checking")
# Global stats # Global stats
stats = { stats = {
@@ -443,7 +451,7 @@ def main():
db.save_documents_batch(db_batch) db.save_documents_batch(db_batch)
db_batch.clear() db_batch.clear()
if args.verbose: if args.verbose:
print(f"Error processing {doc_id}: {error}") logger.error("Error processing %s: %s", doc_id, error)
# Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs) # Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
dual_pool_coordinator = None dual_pool_coordinator = None
@@ -453,7 +461,7 @@ def main():
from training.processing import DualPoolCoordinator from training.processing import DualPoolCoordinator
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers") logger.info("Starting dual-pool mode: %d CPU + %d GPU workers", args.cpu_workers, args.gpu_workers)
dual_pool_coordinator = DualPoolCoordinator( dual_pool_coordinator = DualPoolCoordinator(
cpu_workers=args.cpu_workers, cpu_workers=args.cpu_workers,
gpu_workers=args.gpu_workers, gpu_workers=args.gpu_workers,
@@ -467,10 +475,10 @@ def main():
for csv_idx, csv_file in enumerate(csv_files): for csv_idx, csv_file in enumerate(csv_files):
# Check for shutdown request # Check for shutdown request
if _shutdown_requested: if _shutdown_requested:
print("\nShutdown requested. Stopping after current batch...") logger.warning("Shutdown requested. Stopping after current batch...")
break break
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}") logger.info("[%d/%d] Processing: %s", csv_idx + 1, len(csv_files), csv_file.name)
# Load only this CSV file # Load only this CSV file
single_loader = CSVLoader(str(csv_file), str(pdf_dir)) single_loader = CSVLoader(str(csv_file), str(pdf_dir))
@@ -488,7 +496,7 @@ def main():
seen_doc_ids.add(r.DocumentId) seen_doc_ids.add(r.DocumentId)
if not rows: if not rows:
print(f" Skipping CSV (no new documents)") logger.info(" Skipping CSV (no new documents)")
continue continue
# Batch query database for all document IDs in this CSV # Batch query database for all document IDs in this CSV
@@ -500,13 +508,13 @@ def main():
# Skip entire CSV if all documents are already processed # Skip entire CSV if all documents are already processed
if already_processed == len(rows): if already_processed == len(rows):
print(f" Skipping CSV (all {len(rows)} documents already processed)") logger.info(" Skipping CSV (all %d documents already processed)", len(rows))
stats['skipped_db'] += len(rows) stats['skipped_db'] += len(rows)
continue continue
# Count how many new documents need processing in this CSV # Count how many new documents need processing in this CSV
new_to_process = len(rows) - already_processed new_to_process = len(rows) - already_processed
print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)") logger.info(" Found %d new documents to process (%d already in DB)", new_to_process, already_processed)
stats['total'] += len(rows) stats['total'] += len(rows)
@@ -520,7 +528,7 @@ def main():
if args.limit: if args.limit:
remaining_limit = args.limit - stats.get('tasks_submitted', 0) remaining_limit = args.limit - stats.get('tasks_submitted', 0)
if remaining_limit <= 0: if remaining_limit <= 0:
print(f" Reached limit of {args.limit} new documents, stopping.") logger.info(" Reached limit of %d new documents, stopping.", args.limit)
break break
else: else:
remaining_limit = float('inf') remaining_limit = float('inf')
@@ -583,7 +591,7 @@ def main():
)) ))
if skipped_in_csv > 0 or retry_in_csv > 0: if skipped_in_csv > 0 or retry_in_csv > 0:
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed") logger.info(" Skipped %d (already in DB), retrying %d failed", skipped_in_csv, retry_in_csv)
# Clean up retry documents: delete from database and remove temp folders # Clean up retry documents: delete from database and remove temp folders
if retry_doc_ids: if retry_doc_ids:
@@ -599,7 +607,7 @@ def main():
temp_doc_dir = output_dir / 'temp' / doc_id temp_doc_dir = output_dir / 'temp' / doc_id
if temp_doc_dir.exists(): if temp_doc_dir.exists():
shutil.rmtree(temp_doc_dir, ignore_errors=True) shutil.rmtree(temp_doc_dir, ignore_errors=True)
print(f" Cleaned up {len(retry_doc_ids)} retry documents (DB + temp folders)") logger.info(" Cleaned up %d retry documents (DB + temp folders)", len(retry_doc_ids))
if not tasks: if not tasks:
continue continue
@@ -636,7 +644,7 @@ def main():
# Count task types # Count task types
text_count = sum(1 for d in documents if not d["is_scanned"]) text_count = sum(1 for d in documents if not d["is_scanned"])
scan_count = len(documents) - text_count scan_count = len(documents) - text_count
print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}") logger.info(" Text PDFs: %d, Scanned PDFs: %d", text_count, scan_count)
# Progress tracking with tqdm # Progress tracking with tqdm
pbar = tqdm(total=len(documents), desc="Processing") pbar = tqdm(total=len(documents), desc="Processing")
@@ -667,11 +675,11 @@ def main():
# Log summary # Log summary
successful = sum(1 for r in results if r.success) successful = sum(1 for r in results if r.success)
failed = len(results) - successful failed = len(results) - successful
print(f" Batch complete: {successful} successful, {failed} failed") logger.info(" Batch complete: %d successful, %d failed", successful, failed)
else: else:
# Single-pool mode (original behavior) # Single-pool mode (original behavior)
print(f" Processing {len(tasks)} documents with {args.workers} workers...") logger.info(" Processing %d documents with %d workers...", len(tasks), args.workers)
# Process documents in parallel (inside CSV loop for streaming) # Process documents in parallel (inside CSV loop for streaming)
# Use single process for debugging or when workers=1 # Use single process for debugging or when workers=1
@@ -725,28 +733,28 @@ def main():
db.close() db.close()
# Print summary # Print summary
print("\n" + "=" * 50) logger.info("=" * 50)
print("Auto-labeling Complete") logger.info("Auto-labeling Complete")
print("=" * 50) logger.info("=" * 50)
print(f"Total documents: {stats['total']}") logger.info("Total documents: %d", stats['total'])
print(f"Successful: {stats['successful']}") logger.info("Successful: %d", stats['successful'])
print(f"Failed: {stats['failed']}") logger.info("Failed: %d", stats['failed'])
print(f"Skipped (no PDF): {stats['skipped']}") logger.info("Skipped (no PDF): %d", stats['skipped'])
print(f"Skipped (in DB): {stats['skipped_db']}") logger.info("Skipped (in DB): %d", stats['skipped_db'])
print(f"Retried (failed): {stats['retried']}") logger.info("Retried (failed): %d", stats['retried'])
print(f"Total annotations: {stats['annotations']}") logger.info("Total annotations: %d", stats['annotations'])
print(f"\nImages saved to: {output_dir / 'temp'}") logger.info("Images saved to: %s", output_dir / 'temp')
print(f"Labels stored in: PostgreSQL database") logger.info("Labels stored in: PostgreSQL database")
print(f"\nAnnotations by field:") logger.info("Annotations by field:")
for field, count in stats['by_field'].items(): for field, count in stats['by_field'].items():
print(f" {field}: {count}") logger.info(" %s: %d", field, count)
shard_files = report_writer.get_shard_files() shard_files = report_writer.get_shard_files()
if len(shard_files) > 1: if len(shard_files) > 1:
print(f"\nReport files ({len(shard_files)}):") logger.info("Report files (%d):", len(shard_files))
for sf in shard_files: for sf in shard_files:
print(f" - {sf}") logger.info(" - %s", sf)
else: else:
print(f"\nReport: {shard_files[0] if shard_files else args.report}") logger.info("Report: %s", shard_files[0] if shard_files else args.report)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -8,6 +8,7 @@ Usage:
import argparse import argparse
import json import json
import logging
import sys import sys
from pathlib import Path from pathlib import Path
@@ -16,6 +17,9 @@ from psycopg2.extras import execute_values
# Add project root to path # Add project root to path
from shared.config import get_db_connection_string, PATHS from shared.config import get_db_connection_string, PATHS
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
def create_tables(conn): def create_tables(conn):
@@ -150,7 +154,7 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
try: try:
record = json.loads(line) record = json.loads(line)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print(f" Warning: Line {line_no} - JSON parse error: {e}") logger.warning("Line %d - JSON parse error: %s", line_no, e)
stats['errors'] += 1 stats['errors'] += 1
continue continue
@@ -211,7 +215,7 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
# Flush batch if needed # Flush batch if needed
if len(doc_batch) >= batch_size: if len(doc_batch) >= batch_size:
flush_batches() flush_batches()
print(f" Processed {stats['imported'] + stats['skipped']} records...") logger.info(" Processed %d records...", stats['imported'] + stats['skipped'])
# Final flush # Final flush
flush_batches() flush_batches()
@@ -243,11 +247,14 @@ def main():
else: else:
report_files = [report_path] if report_path.exists() else [] report_files = [report_path] if report_path.exists() else []
# Configure logging for CLI
setup_cli_logging()
if not report_files: if not report_files:
print(f"No report files found: {args.report}") logger.error("No report files found: %s", args.report)
return return
print(f"Found {len(report_files)} report file(s)") logger.info("Found %d report file(s)", len(report_files))
# Connect to database # Connect to database
conn = psycopg2.connect(db_connection) conn = psycopg2.connect(db_connection)
@@ -257,20 +264,20 @@ def main():
total_stats = {'imported': 0, 'skipped': 0, 'errors': 0} total_stats = {'imported': 0, 'skipped': 0, 'errors': 0}
for report_file in report_files: for report_file in report_files:
print(f"\nImporting: {report_file.name}") logger.info("Importing: %s", report_file.name)
stats = import_jsonl_file(conn, report_file, skip_existing=not args.no_skip, batch_size=args.batch_size) stats = import_jsonl_file(conn, report_file, skip_existing=not args.no_skip, batch_size=args.batch_size)
print(f" Imported: {stats['imported']}, Skipped: {stats['skipped']}, Errors: {stats['errors']}") logger.info(" Imported: %d, Skipped: %d, Errors: %d", stats['imported'], stats['skipped'], stats['errors'])
for key in total_stats: for key in total_stats:
total_stats[key] += stats[key] total_stats[key] += stats[key]
# Print summary # Print summary
print("\n" + "=" * 50) logger.info("=" * 50)
print("Import Complete") logger.info("Import Complete")
print("=" * 50) logger.info("=" * 50)
print(f"Total imported: {total_stats['imported']}") logger.info("Total imported: %d", total_stats['imported'])
print(f"Total skipped: {total_stats['skipped']}") logger.info("Total skipped: %d", total_stats['skipped'])
print(f"Total errors: {total_stats['errors']}") logger.info("Total errors: %d", total_stats['errors'])
# Quick stats from database # Quick stats from database
with conn.cursor() as cursor: with conn.cursor() as cursor:
@@ -288,11 +295,11 @@ def main():
conn.close() conn.close()
print(f"\nDatabase Stats:") logger.info("Database Stats:")
print(f" Documents: {total_docs} ({success_docs} successful)") logger.info(" Documents: %d (%d successful)", total_docs, success_docs)
print(f" Field results: {total_fields} ({matched_fields} matched)") logger.info(" Field results: %d (%d matched)", total_fields, matched_fields)
if total_fields > 0: if total_fields > 0:
print(f" Match rate: {matched_fields / total_fields * 100:.2f}%") logger.info(" Match rate: %.2f%%", matched_fields / total_fields * 100)
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -7,6 +7,7 @@ CSV values, and source CSV filename in a new table.
import argparse import argparse
import json import json
import glob import glob
import logging
import os import os
import sys import sys
import time import time
@@ -20,6 +21,9 @@ from shared.config import DEFAULT_DPI
from shared.data.db import DocumentDB from shared.data.db import DocumentDB
from shared.data.csv_loader import CSVLoader from shared.data.csv_loader import CSVLoader
from shared.normalize.normalizer import normalize_field from shared.normalize.normalizer import normalize_field
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
def create_failed_match_table(db: DocumentDB): def create_failed_match_table(db: DocumentDB):
@@ -57,7 +61,7 @@ def create_failed_match_table(db: DocumentDB):
CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched); CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
""") """)
conn.commit() conn.commit()
print("Created table: failed_match_details") logger.info("Created table: failed_match_details")
def get_failed_documents(db: DocumentDB) -> list: def get_failed_documents(db: DocumentDB) -> list:
@@ -332,14 +336,17 @@ def main():
parser.add_argument('--limit', type=int, help='Limit number of documents to process') parser.add_argument('--limit', type=int, help='Limit number of documents to process')
args = parser.parse_args() args = parser.parse_args()
# Configure logging for CLI
setup_cli_logging()
# Expand CSV glob # Expand CSV glob
csv_files = sorted(glob.glob(args.csv)) csv_files = sorted(glob.glob(args.csv))
print(f"Found {len(csv_files)} CSV files") logger.info("Found %d CSV files", len(csv_files))
# Build CSV cache # Build CSV cache
print("Building CSV filename cache...") logger.info("Building CSV filename cache...")
build_csv_cache(csv_files) build_csv_cache(csv_files)
print(f"Cached {len(_csv_cache)} document IDs") logger.info("Cached %d document IDs", len(_csv_cache))
# Connect to database # Connect to database
db = DocumentDB() db = DocumentDB()
@@ -349,13 +356,13 @@ def main():
create_failed_match_table(db) create_failed_match_table(db)
# Get all failed documents # Get all failed documents
print("Fetching failed documents...") logger.info("Fetching failed documents...")
failed_docs = get_failed_documents(db) failed_docs = get_failed_documents(db)
print(f"Found {len(failed_docs)} documents with failed matches") logger.info("Found %d documents with failed matches", len(failed_docs))
if args.limit: if args.limit:
failed_docs = failed_docs[:args.limit] failed_docs = failed_docs[:args.limit]
print(f"Limited to {len(failed_docs)} documents") logger.info("Limited to %d documents", len(failed_docs))
# Prepare tasks # Prepare tasks
tasks = [] tasks = []
@@ -365,7 +372,7 @@ def main():
if failed_fields: if failed_fields:
tasks.append((doc, failed_fields, csv_filename)) tasks.append((doc, failed_fields, csv_filename))
print(f"Processing {len(tasks)} documents with {args.workers} workers...") logger.info("Processing %d documents with %d workers...", len(tasks), args.workers)
# Process with multiprocessing # Process with multiprocessing
total_results = 0 total_results = 0
@@ -389,15 +396,15 @@ def main():
batch_results = [] batch_results = []
except TimeoutError: except TimeoutError:
print(f"\nTimeout processing {doc_id}") logger.warning("Timeout processing %s", doc_id)
except Exception as e: except Exception as e:
print(f"\nError processing {doc_id}: {e}") logger.error("Error processing %s: %s", doc_id, e)
# Save remaining results # Save remaining results
if batch_results: if batch_results:
save_results_batch(db, batch_results) save_results_batch(db, batch_results)
print(f"\nDone! Saved {total_results} failed match records to failed_match_details table") logger.info("Done! Saved %d failed match records to failed_match_details table", total_results)
# Show summary # Show summary
conn = db.connect() conn = db.connect()
@@ -410,12 +417,12 @@ def main():
GROUP BY field_name GROUP BY field_name
ORDER BY total DESC ORDER BY total DESC
""") """)
print("\nSummary by field:") logger.info("Summary by field:")
print("-" * 70) logger.info("-" * 70)
print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}") logger.info("%-35s %8s %10s %12s", 'Field', 'Total', 'Has OCR', 'Avg Score')
print("-" * 70) logger.info("-" * 70)
for row in cursor.fetchall(): for row in cursor.fetchall():
print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}") logger.info("%-35s %8d %10d %12.2f", row[0], row[1], row[2], row[3])
db.close() db.close()

View File

@@ -7,10 +7,14 @@ Images are read from filesystem, labels are dynamically generated from DB.
""" """
import argparse import argparse
import logging
import sys import sys
from pathlib import Path from pathlib import Path
from shared.config import DEFAULT_DPI, PATHS from shared.config import DEFAULT_DPI, PATHS
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
def main(): def main():
@@ -119,47 +123,50 @@ def main():
args = parser.parse_args() args = parser.parse_args()
# Configure logging for CLI
setup_cli_logging()
# Apply low-memory mode if specified # Apply low-memory mode if specified
if args.low_memory: if args.low_memory:
print("🔧 Low memory mode enabled") logger.info("Low memory mode enabled")
args.batch = min(args.batch, 8) # Reduce from 16 to 8 args.batch = min(args.batch, 8) # Reduce from 16 to 8
args.workers = min(args.workers, 4) # Reduce from 8 to 4 args.workers = min(args.workers, 4) # Reduce from 8 to 4
args.cache = False args.cache = False
print(f" Batch size: {args.batch}") logger.info(" Batch size: %d", args.batch)
print(f" Workers: {args.workers}") logger.info(" Workers: %d", args.workers)
print(f" Cache: disabled") logger.info(" Cache: disabled")
# Validate dataset directory # Validate dataset directory
dataset_dir = Path(args.dataset_dir) dataset_dir = Path(args.dataset_dir)
temp_dir = dataset_dir / 'temp' temp_dir = dataset_dir / 'temp'
if not temp_dir.exists(): if not temp_dir.exists():
print(f"Error: Temp directory not found: {temp_dir}") logger.error("Temp directory not found: %s", temp_dir)
print("Run autolabel first to generate images.") logger.error("Run autolabel first to generate images.")
sys.exit(1) sys.exit(1)
print("=" * 60) logger.info("=" * 60)
print("YOLO Training with Database Labels") logger.info("YOLO Training with Database Labels")
print("=" * 60) logger.info("=" * 60)
print(f"Dataset dir: {dataset_dir}") logger.info("Dataset dir: %s", dataset_dir)
print(f"Model: {args.model}") logger.info("Model: %s", args.model)
print(f"Epochs: {args.epochs}") logger.info("Epochs: %d", args.epochs)
print(f"Batch size: {args.batch}") logger.info("Batch size: %d", args.batch)
print(f"Image size: {args.imgsz}") logger.info("Image size: %d", args.imgsz)
print(f"Split ratio: {args.train_ratio}/{args.val_ratio}/{1-args.train_ratio-args.val_ratio:.1f}") logger.info("Split ratio: %s/%s/%.1f", args.train_ratio, args.val_ratio, 1-args.train_ratio-args.val_ratio)
if args.limit: if args.limit:
print(f"Document limit: {args.limit}") logger.info("Document limit: %d", args.limit)
# Connect to database # Connect to database
from shared.data.db import DocumentDB from shared.data.db import DocumentDB
print("\nConnecting to database...") logger.info("Connecting to database...")
db = DocumentDB() db = DocumentDB()
db.connect() db.connect()
# Create datasets from database # Create datasets from database
from training.yolo.db_dataset import create_datasets from training.yolo.db_dataset import create_datasets
print("Loading dataset from database...") logger.info("Loading dataset from database...")
datasets = create_datasets( datasets = create_datasets(
images_dir=dataset_dir, images_dir=dataset_dir,
db=db, db=db,
@@ -170,39 +177,39 @@ def main():
limit=args.limit limit=args.limit
) )
print(f"\nDataset splits:") logger.info("Dataset splits:")
print(f" Train: {len(datasets['train'])} items") logger.info(" Train: %d items", len(datasets['train']))
print(f" Val: {len(datasets['val'])} items") logger.info(" Val: %d items", len(datasets['val']))
print(f" Test: {len(datasets['test'])} items") logger.info(" Test: %d items", len(datasets['test']))
if len(datasets['train']) == 0: if len(datasets['train']) == 0:
print("\nError: No training data found!") logger.error("No training data found!")
print("Make sure autolabel has been run and images exist in temp directory.") logger.error("Make sure autolabel has been run and images exist in temp directory.")
db.close() db.close()
sys.exit(1) sys.exit(1)
# Export to YOLO format (required for Ultralytics training) # Export to YOLO format (required for Ultralytics training)
print("\nExporting dataset to YOLO format...") logger.info("Exporting dataset to YOLO format...")
for split_name, dataset in datasets.items(): for split_name, dataset in datasets.items():
count = dataset.export_to_yolo_format(dataset_dir, split_name) count = dataset.export_to_yolo_format(dataset_dir, split_name)
print(f" {split_name}: {count} items exported") logger.info(" %s: %d items exported", split_name, count)
# Generate YOLO config files # Generate YOLO config files
from training.yolo.annotation_generator import AnnotationGenerator from training.yolo.annotation_generator import AnnotationGenerator
AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt') AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt')
AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml') AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml')
print(f"\nGenerated dataset.yaml at: {dataset_dir / 'dataset.yaml'}") logger.info("Generated dataset.yaml at: %s", dataset_dir / 'dataset.yaml')
if args.export_only: if args.export_only:
print("\nExport complete (--export-only specified, skipping training)") logger.info("Export complete (--export-only specified, skipping training)")
db.close() db.close()
return return
# Start training using shared trainer # Start training using shared trainer
print("\n" + "=" * 60) logger.info("=" * 60)
print("Starting YOLO Training") logger.info("Starting YOLO Training")
print("=" * 60) logger.info("=" * 60)
from shared.training import YOLOTrainer, TrainingConfig from shared.training import YOLOTrainer, TrainingConfig
@@ -232,30 +239,30 @@ def main():
result = trainer.train() result = trainer.train()
if not result.success: if not result.success:
print(f"\nError: Training failed - {result.error}") logger.error("Training failed - %s", result.error)
db.close() db.close()
sys.exit(1) sys.exit(1)
# Print results # Print results
print("\n" + "=" * 60) logger.info("=" * 60)
print("Training Complete") logger.info("Training Complete")
print("=" * 60) logger.info("=" * 60)
print(f"Best model: {result.model_path}") logger.info("Best model: %s", result.model_path)
print(f"Save directory: {result.save_dir}") logger.info("Save directory: %s", result.save_dir)
if result.metrics: if result.metrics:
print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}") logger.info("mAP@0.5: %s", result.metrics.get('mAP50', 'N/A'))
print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}") logger.info("mAP@0.5-0.95: %s", result.metrics.get('mAP50-95', 'N/A'))
# Validate on test set # Validate on test set
print("\nRunning validation on test set...") logger.info("Running validation on test set...")
if result.model_path: if result.model_path:
config.model_path = result.model_path config.model_path = result.model_path
config.data_yaml = str(data_yaml) config.data_yaml = str(data_yaml)
test_trainer = YOLOTrainer(config=config) test_trainer = YOLOTrainer(config=config)
test_metrics = test_trainer.validate(split='test') test_metrics = test_trainer.validate(split='test')
if test_metrics: if test_metrics:
print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}") logger.info("mAP50: %.4f", test_metrics.get('mAP50', 0))
print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}") logger.info("mAP50-95: %.4f", test_metrics.get('mAP50-95', 0))
# Close database # Close database
db.close() db.close()

View File

@@ -7,9 +7,14 @@ and comparing the extraction results.
""" """
import argparse import argparse
import logging
import sys import sys
from pathlib import Path from pathlib import Path
from shared.logging_config import setup_cli_logging
logger = logging.getLogger(__name__)
def main(): def main():
@@ -73,6 +78,9 @@ def main():
parser.print_help() parser.print_help()
return return
# Configure logging for CLI
setup_cli_logging()
from backend.validation import LLMValidator from backend.validation import LLMValidator
validator = LLMValidator() validator = LLMValidator()
@@ -104,60 +112,58 @@ def show_stats(validator):
"""Show statistics about failed matches.""" """Show statistics about failed matches."""
stats = validator.get_failed_match_stats() stats = validator.get_failed_match_stats()
print("\n" + "=" * 50) logger.info("=" * 50)
print("Failed Match Statistics") logger.info("Failed Match Statistics")
print("=" * 50) logger.info("=" * 50)
print(f"\nDocuments with failures: {stats['documents_with_failures']}") logger.info("Documents with failures: %d", stats['documents_with_failures'])
print(f"Already validated: {stats['already_validated']}") logger.info("Already validated: %d", stats['already_validated'])
print(f"Remaining to validate: {stats['remaining']}") logger.info("Remaining to validate: %d", stats['remaining'])
print("\nFailures by field:") logger.info("Failures by field:")
for field, count in sorted(stats['failures_by_field'].items(), key=lambda x: -x[1]): for field, count in sorted(stats['failures_by_field'].items(), key=lambda x: -x[1]):
print(f" {field}: {count}") logger.info(" %s: %d", field, count)
def validate_single(validator, doc_id: str, provider: str, model: str): def validate_single(validator, doc_id: str, provider: str, model: str):
"""Validate a single document.""" """Validate a single document."""
print(f"\nValidating document: {doc_id}") logger.info("Validating document: %s", doc_id)
print(f"Provider: {provider}, Model: {model or 'default'}") logger.info("Provider: %s, Model: %s", provider, model or 'default')
print()
result = validator.validate_document(doc_id, provider, model) result = validator.validate_document(doc_id, provider, model)
if result.error: if result.error:
print(f"ERROR: {result.error}") logger.error("ERROR: %s", result.error)
return return
print(f"Processing time: {result.processing_time_ms:.0f}ms") logger.info("Processing time: %.0fms", result.processing_time_ms)
print(f"Model used: {result.model_used}") logger.info("Model used: %s", result.model_used)
print("\nExtracted fields:") logger.info("Extracted fields:")
print(f" Invoice Number: {result.invoice_number}") logger.info(" Invoice Number: %s", result.invoice_number)
print(f" Invoice Date: {result.invoice_date}") logger.info(" Invoice Date: %s", result.invoice_date)
print(f" Due Date: {result.invoice_due_date}") logger.info(" Due Date: %s", result.invoice_due_date)
print(f" OCR: {result.ocr_number}") logger.info(" OCR: %s", result.ocr_number)
print(f" Bankgiro: {result.bankgiro}") logger.info(" Bankgiro: %s", result.bankgiro)
print(f" Plusgiro: {result.plusgiro}") logger.info(" Plusgiro: %s", result.plusgiro)
print(f" Amount: {result.amount}") logger.info(" Amount: %s", result.amount)
print(f" Org Number: {result.supplier_organisation_number}") logger.info(" Org Number: %s", result.supplier_organisation_number)
# Show comparison # Show comparison
print("\n" + "-" * 50) logger.info("-" * 50)
print("Comparison with autolabel:") logger.info("Comparison with autolabel:")
comparison = validator.compare_results(doc_id) comparison = validator.compare_results(doc_id)
for field, data in comparison.items(): for field, data in comparison.items():
if data.get('csv_value'): if data.get('csv_value'):
status = "" if data['agreement'] else "" status = "[OK]" if data['agreement'] else "[FAIL]"
auto_status = "matched" if data['autolabel_matched'] else "FAILED" auto_status = "matched" if data['autolabel_matched'] else "FAILED"
print(f" {status} {field}:") logger.info(" %s %s:", status, field)
print(f" CSV: {data['csv_value']}") logger.info(" CSV: %s", data['csv_value'])
print(f" Autolabel: {data['autolabel_text']} ({auto_status})") logger.info(" Autolabel: %s (%s)", data['autolabel_text'], auto_status)
print(f" LLM: {data['llm_value']}") logger.info(" LLM: %s", data['llm_value'])
def validate_batch(validator, limit: int, provider: str, model: str): def validate_batch(validator, limit: int, provider: str, model: str):
"""Validate a batch of documents.""" """Validate a batch of documents."""
print(f"\nValidating up to {limit} documents with failed matches") logger.info("Validating up to %d documents with failed matches", limit)
print(f"Provider: {provider}, Model: {model or 'default'}") logger.info("Provider: %s, Model: %s", provider, model or 'default')
print()
results = validator.validate_batch( results = validator.validate_batch(
limit=limit, limit=limit,
@@ -171,15 +177,15 @@ def validate_batch(validator, limit: int, provider: str, model: str):
failed = len(results) - success failed = len(results) - success
total_time = sum(r.processing_time_ms or 0 for r in results) total_time = sum(r.processing_time_ms or 0 for r in results)
print("\n" + "=" * 50) logger.info("=" * 50)
print("Validation Complete") logger.info("Validation Complete")
print("=" * 50) logger.info("=" * 50)
print(f"Total: {len(results)}") logger.info("Total: %d", len(results))
print(f"Success: {success}") logger.info("Success: %d", success)
print(f"Failed: {failed}") logger.info("Failed: %d", failed)
print(f"Total time: {total_time/1000:.1f}s") logger.info("Total time: %.1fs", total_time/1000)
if success > 0: if success > 0:
print(f"Avg time: {total_time/success:.0f}ms per document") logger.info("Avg time: %.0fms per document", total_time/success)
def compare_single(validator, doc_id: str): def compare_single(validator, doc_id: str):
@@ -187,23 +193,23 @@ def compare_single(validator, doc_id: str):
comparison = validator.compare_results(doc_id) comparison = validator.compare_results(doc_id)
if 'error' in comparison: if 'error' in comparison:
print(f"Error: {comparison['error']}") logger.error("Error: %s", comparison['error'])
return return
print(f"\nComparison for document: {doc_id}") logger.info("Comparison for document: %s", doc_id)
print("=" * 60) logger.info("=" * 60)
for field, data in comparison.items(): for field, data in comparison.items():
if data.get('csv_value') is None: if data.get('csv_value') is None:
continue continue
status = "" if data['agreement'] else "" status = "[OK]" if data['agreement'] else "[FAIL]"
auto_status = "matched" if data['autolabel_matched'] else "FAILED" auto_status = "matched" if data['autolabel_matched'] else "FAILED"
print(f"\n{status} {field}:") logger.info("%s %s:", status, field)
print(f" CSV value: {data['csv_value']}") logger.info(" CSV value: %s", data['csv_value'])
print(f" Autolabel: {data['autolabel_text']} ({auto_status})") logger.info(" Autolabel: %s (%s)", data['autolabel_text'], auto_status)
print(f" LLM extracted: {data['llm_value']}") logger.info(" LLM extracted: %s", data['llm_value'])
def compare_all(validator, limit: int): def compare_all(validator, limit: int):
@@ -220,11 +226,11 @@ def compare_all(validator, limit: int):
doc_ids = [row[0] for row in cursor.fetchall()] doc_ids = [row[0] for row in cursor.fetchall()]
if not doc_ids: if not doc_ids:
print("No validated documents found.") logger.info("No validated documents found.")
return return
print(f"\nComparison Summary ({len(doc_ids)} documents)") logger.info("Comparison Summary (%d documents)", len(doc_ids))
print("=" * 80) logger.info("=" * 80)
# Aggregate stats # Aggregate stats
field_stats = {} field_stats = {}
@@ -259,12 +265,12 @@ def compare_all(validator, limit: int):
if not data['autolabel_matched'] and data['agreement']: if not data['autolabel_matched'] and data['agreement']:
stats['llm_correct_auto_wrong'] += 1 stats['llm_correct_auto_wrong'] += 1
print(f"\n{'Field':<30} {'Total':>6} {'Auto OK':>8} {'LLM Agrees':>10} {'LLM Found':>10}") logger.info("%-30s %6s %8s %10s %10s", 'Field', 'Total', 'Auto OK', 'LLM Agrees', 'LLM Found')
print("-" * 80) logger.info("-" * 80)
for field, stats in sorted(field_stats.items()): for field, stats in sorted(field_stats.items()):
print(f"{field:<30} {stats['total']:>6} {stats['autolabel_matched']:>8} " logger.info("%-30s %6d %8d %10d %10d", field, stats['total'], stats['autolabel_matched'],
f"{stats['llm_agrees']:>10} {stats['llm_correct_auto_wrong']:>10}") stats['llm_agrees'], stats['llm_correct_auto_wrong'])
def generate_report(validator, output_path: str): def generate_report(validator, output_path: str):
@@ -328,8 +334,8 @@ def generate_report(validator, output_path: str):
with open(output_path, 'w', encoding='utf-8') as f: with open(output_path, 'w', encoding='utf-8') as f:
json.dump(report, f, indent=2, ensure_ascii=False) json.dump(report, f, indent=2, ensure_ascii=False)
print(f"\nReport generated: {output_path}") logger.info("Report generated: %s", output_path)
print(f"Total validations: {len(validations)}") logger.info("Total validations: %d", len(validations))
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -18,7 +18,8 @@ import numpy as np
from PIL import Image from PIL import Image
from shared.config import DEFAULT_DPI from shared.config import DEFAULT_DPI
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES, CLASS_NAMES
from shared.bbox import expand_bbox
from .annotation_generator import YOLOAnnotation from .annotation_generator import YOLOAnnotation
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -156,7 +157,7 @@ class DBYOLODataset:
# Split items for this split # Split items for this split
instance.items = instance._split_dataset_from_cache() instance.items = instance._split_dataset_from_cache()
print(f"Split '{split}': {len(instance.items)} items") logger.info("Split '%s': %d items", split, len(instance.items))
return instance return instance
@@ -165,7 +166,7 @@ class DBYOLODataset:
# Find all document directories # Find all document directories
temp_dir = self.images_dir / 'temp' temp_dir = self.images_dir / 'temp'
if not temp_dir.exists(): if not temp_dir.exists():
print(f"Temp directory not found: {temp_dir}") logger.warning("Temp directory not found: %s", temp_dir)
return return
# Collect all document IDs with images # Collect all document IDs with images
@@ -182,13 +183,13 @@ class DBYOLODataset:
if images: if images:
doc_image_map[doc_dir.name] = sorted(images) doc_image_map[doc_dir.name] = sorted(images)
print(f"Found {len(doc_image_map)} documents with images") logger.info("Found %d documents with images", len(doc_image_map))
# Query database for all document labels # Query database for all document labels
doc_ids = list(doc_image_map.keys()) doc_ids = list(doc_image_map.keys())
doc_labels = self._load_labels_from_db(doc_ids) doc_labels = self._load_labels_from_db(doc_ids)
print(f"Loaded labels for {len(doc_labels)} documents from database") logger.info("Loaded labels for %d documents from database", len(doc_labels))
# Build dataset items # Build dataset items
all_items: list[DatasetItem] = [] all_items: list[DatasetItem] = []
@@ -227,19 +228,19 @@ class DBYOLODataset:
else: else:
skipped_no_labels += 1 skipped_no_labels += 1
print(f"Total images found: {total_images}") logger.info("Total images found: %d", total_images)
print(f"Images with labels: {len(all_items)}") logger.info("Images with labels: %d", len(all_items))
if skipped_no_db_record > 0: if skipped_no_db_record > 0:
print(f"Skipped {skipped_no_db_record} images (document not in database)") logger.info("Skipped %d images (document not in database)", skipped_no_db_record)
if skipped_no_labels > 0: if skipped_no_labels > 0:
print(f"Skipped {skipped_no_labels} images (no labels for page)") logger.info("Skipped %d images (no labels for page)", skipped_no_labels)
# Cache all items for sharing with other splits # Cache all items for sharing with other splits
self._all_items = all_items self._all_items = all_items
# Split dataset # Split dataset
self.items, self._doc_ids_ordered = self._split_dataset(all_items) self.items, self._doc_ids_ordered = self._split_dataset(all_items)
print(f"Split '{self.split}': {len(self.items)} items") logger.info("Split '%s': %d items", self.split, len(self.items))
def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]: def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]:
""" """
@@ -374,7 +375,7 @@ class DBYOLODataset:
if has_csv_splits: if has_csv_splits:
# Use CSV-defined splits # Use CSV-defined splits
print("Using CSV-defined split field for train/val/test assignment") logger.info("Using CSV-defined split field for train/val/test assignment")
# Map split values: 'train' -> train, 'test' -> test, None -> train (fallback) # Map split values: 'train' -> train, 'test' -> test, None -> train (fallback)
# 'val' is taken from train set using val_ratio # 'val' is taken from train set using val_ratio
@@ -411,11 +412,11 @@ class DBYOLODataset:
# Apply limit if specified # Apply limit if specified
if self.limit is not None and self.limit < len(split_doc_ids): if self.limit is not None and self.limit < len(split_doc_ids):
split_doc_ids = split_doc_ids[:self.limit] split_doc_ids = split_doc_ids[:self.limit]
print(f"Limited to {self.limit} documents") logger.info("Limited to %d documents", self.limit)
else: else:
# Fall back to random splitting (original behavior) # Fall back to random splitting (original behavior)
print("No CSV split field found, using random splitting") logger.info("No CSV split field found, using random splitting")
random.seed(self.seed) random.seed(self.seed)
random.shuffle(doc_ids) random.shuffle(doc_ids)
@@ -423,7 +424,7 @@ class DBYOLODataset:
# Apply limit if specified (before splitting) # Apply limit if specified (before splitting)
if self.limit is not None and self.limit < len(doc_ids): if self.limit is not None and self.limit < len(doc_ids):
doc_ids = doc_ids[:self.limit] doc_ids = doc_ids[:self.limit]
print(f"Limited to {self.limit} documents") logger.info("Limited to %d documents", self.limit)
# Calculate split indices # Calculate split indices
n_total = len(doc_ids) n_total = len(doc_ids)
@@ -549,6 +550,8 @@ class DBYOLODataset:
""" """
Convert annotations to normalized YOLO format. Convert annotations to normalized YOLO format.
Uses field-specific bbox expansion strategies via expand_bbox.
Args: Args:
annotations: List of annotations annotations: List of annotations
img_width: Actual image width in pixels img_width: Actual image width in pixels
@@ -568,26 +571,43 @@ class DBYOLODataset:
labels = [] labels = []
for ann in annotations: for ann in annotations:
# Convert to pixels (if needed) # Convert center+size to corner coords in PDF points
x_center_px = ann.x_center * scale half_w = ann.width / 2
y_center_px = ann.y_center * scale half_h = ann.height / 2
width_px = ann.width * scale x0_pdf = ann.x_center - half_w
height_px = ann.height * scale y0_pdf = ann.y_center - half_h
x1_pdf = ann.x_center + half_w
y1_pdf = ann.y_center + half_h
# Add padding # Convert to pixels
pad = self.bbox_padding_px x0_px = x0_pdf * scale
width_px += 2 * pad y0_px = y0_pdf * scale
height_px += 2 * pad x1_px = x1_pdf * scale
y1_px = y1_pdf * scale
# Get class name for field-specific expansion
class_name = CLASS_NAMES[ann.class_id]
# Apply field-specific bbox expansion
x0, y0, x1, y1 = expand_bbox(
bbox=(x0_px, y0_px, x1_px, y1_px),
image_width=img_width,
image_height=img_height,
field_type=class_name,
)
# Ensure minimum height # Ensure minimum height
height_px = y1 - y0
if height_px < self.min_bbox_height_px: if height_px < self.min_bbox_height_px:
height_px = self.min_bbox_height_px extra = (self.min_bbox_height_px - height_px) / 2
y0 = max(0, int(y0 - extra))
y1 = min(img_height, int(y1 + extra))
# Normalize to 0-1 # Convert to YOLO format (normalized center + size)
x_center = x_center_px / img_width x_center = (x0 + x1) / 2 / img_width
y_center = y_center_px / img_height y_center = (y0 + y1) / 2 / img_height
width = width_px / img_width width = (x1 - x0) / img_width
height = height_px / img_height height = (y1 - y0) / img_height
# Clamp to valid range # Clamp to valid range
x_center = max(0, min(1, x_center)) x_center = max(0, min(1, x_center))
@@ -675,7 +695,7 @@ class DBYOLODataset:
count += 1 count += 1
print(f"Exported {count} items to {output_dir / split_name}") logger.info("Exported %d items to %s", count, output_dir / split_name)
return count return count
@@ -706,7 +726,7 @@ def create_datasets(
Dict with 'train', 'val', 'test' datasets Dict with 'train', 'val', 'test' datasets
""" """
# Create first dataset which loads all data # Create first dataset which loads all data
print("Loading dataset (this may take a few minutes for large datasets)...") logger.info("Loading dataset (this may take a few minutes for large datasets)...")
first_dataset = DBYOLODataset( first_dataset = DBYOLODataset(
images_dir=images_dir, images_dir=images_dir,
db=db, db=db,

View File

@@ -0,0 +1,251 @@
"""Tests for db_dataset.py expand_bbox integration."""
import numpy as np
import pytest
from unittest.mock import MagicMock, patch
from pathlib import Path
from training.yolo.db_dataset import DBYOLODataset
from training.yolo.annotation_generator import YOLOAnnotation
from shared.bbox import FIELD_SCALE_STRATEGIES, DEFAULT_STRATEGY
from shared.fields import CLASS_NAMES
class TestConvertLabelsWithExpandBbox:
"""Tests for _convert_labels using expand_bbox instead of fixed padding."""
def test_convert_labels_uses_expand_bbox(self):
"""Verify _convert_labels calls expand_bbox for field-specific expansion."""
# Create a mock dataset without loading from DB
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 30
# Create annotation for bankgiro (has extra_left_ratio)
# bbox in PDF points: x0=100, y0=200, x1=200, y1=250
# center: (150, 225), width: 100, height: 50
annotations = [
YOLOAnnotation(
class_id=4, # bankgiro
x_center=150, # in PDF points
y_center=225,
width=100,
height=50,
confidence=0.9
)
]
# Image size in pixels (at 300 DPI)
img_width = 2480 # A4 width at 300 DPI
img_height = 3508 # A4 height at 300 DPI
# Convert labels
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
# Should have one label
assert labels.shape == (1, 5)
# Check class_id
assert labels[0, 0] == 4
# The bbox should be expanded using bankgiro strategy (extra_left_ratio=0.80)
# Original bbox at 300 DPI:
# x0 = 100 * (300/72) = 416.67
# y0 = 200 * (300/72) = 833.33
# x1 = 200 * (300/72) = 833.33
# y1 = 250 * (300/72) = 1041.67
# width_px = 416.67, height_px = 208.33
# After expand_bbox with bankgiro strategy:
# scale_x=1.45, scale_y=1.35, extra_left_ratio=0.80
# The x_center should shift left due to extra_left_ratio
x_center = labels[0, 1]
y_center = labels[0, 2]
width = labels[0, 3]
height = labels[0, 4]
# Verify normalized values are in valid range
assert 0 <= x_center <= 1
assert 0 <= y_center <= 1
assert 0 < width <= 1
assert 0 < height <= 1
# Width should be larger than original due to scaling and extra_left
# Original normalized width: 416.67 / 2480 = 0.168
# After bankgiro expansion it should be wider
assert width > 0.168
def test_convert_labels_different_field_types(self):
"""Verify different field types use their specific strategies."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 30
img_width = 2480
img_height = 3508
# Same bbox for different field types
base_annotation = {
'x_center': 150,
'y_center': 225,
'width': 100,
'height': 50,
'confidence': 0.9
}
# OCR number (class_id=3) - has extra_top_ratio=0.60
ocr_annotations = [YOLOAnnotation(class_id=3, **base_annotation)]
ocr_labels = dataset._convert_labels(ocr_annotations, img_width, img_height, is_scanned=False)
# Bankgiro (class_id=4) - has extra_left_ratio=0.80
bankgiro_annotations = [YOLOAnnotation(class_id=4, **base_annotation)]
bankgiro_labels = dataset._convert_labels(bankgiro_annotations, img_width, img_height, is_scanned=False)
# Amount (class_id=6) - has extra_right_ratio=0.30
amount_annotations = [YOLOAnnotation(class_id=6, **base_annotation)]
amount_labels = dataset._convert_labels(amount_annotations, img_width, img_height, is_scanned=False)
# Each field type should have different expansion
# OCR should expand more vertically (extra_top)
# Bankgiro should expand more to the left
# Amount should expand more to the right
# OCR: extra_top shifts y_center up
# Bankgiro: extra_left shifts x_center left
# So bankgiro x_center < OCR x_center
assert bankgiro_labels[0, 1] < ocr_labels[0, 1]
# OCR has higher scale_y (1.80) than amount (1.35)
assert ocr_labels[0, 4] > amount_labels[0, 4]
def test_convert_labels_clamps_to_image_bounds(self):
"""Verify labels are clamped to image boundaries."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 30
# Annotation near edge of image (in PDF points)
annotations = [
YOLOAnnotation(
class_id=4, # bankgiro - will expand left
x_center=30, # Very close to left edge
y_center=50,
width=40,
height=30,
confidence=0.9
)
]
img_width = 2480
img_height = 3508
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
# All values should be in valid range
assert 0 <= labels[0, 1] <= 1 # x_center
assert 0 <= labels[0, 2] <= 1 # y_center
assert 0 < labels[0, 3] <= 1 # width
assert 0 < labels[0, 4] <= 1 # height
def test_convert_labels_empty_annotations(self):
"""Verify empty annotations return empty array."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 30
labels = dataset._convert_labels([], 2480, 3508, is_scanned=False)
assert labels.shape == (0, 5)
assert labels.dtype == np.float32
def test_convert_labels_minimum_height(self):
"""Verify minimum height is enforced after expansion."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 50 # Higher minimum
# Very small annotation
annotations = [
YOLOAnnotation(
class_id=9, # payment_line - minimal expansion
x_center=100,
y_center=100,
width=200,
height=5, # Very small height
confidence=0.9
)
]
labels = dataset._convert_labels(annotations, 2480, 3508, is_scanned=False)
# Height should be at least min_bbox_height_px / img_height
min_normalized_height = 50 / 3508
assert labels[0, 4] >= min_normalized_height
class TestCreateAnnotationWithClassName:
"""Tests for _create_annotation storing class_name for expand_bbox lookup."""
def test_create_annotation_stores_class_name(self):
"""Verify _create_annotation stores class_name for later use."""
dataset = object.__new__(DBYOLODataset)
# Create annotation for invoice_number
annotation = dataset._create_annotation(
field_name="InvoiceNumber",
bbox=[100, 200, 200, 250],
score=0.9
)
assert annotation.class_id == 0 # invoice_number class_id
class TestLoadLabelsFromDbWithClassName:
"""Tests for _load_labels_from_db preserving field_name for expansion."""
def test_load_labels_maps_field_names_correctly(self):
"""Verify field names are mapped correctly for expand_bbox."""
dataset = object.__new__(DBYOLODataset)
dataset.min_confidence = 0.7
# Mock database
mock_db = MagicMock()
mock_db.get_documents_batch.return_value = {
'doc1': {
'success': True,
'pdf_type': 'text',
'split': 'train',
'field_results': [
{
'matched': True,
'field_name': 'Bankgiro',
'score': 0.9,
'bbox': [100, 200, 200, 250],
'page_no': 0
},
{
'matched': True,
'field_name': 'supplier_accounts(Plusgiro)',
'score': 0.85,
'bbox': [300, 400, 400, 450],
'page_no': 0
}
]
}
}
dataset.db = mock_db
result = dataset._load_labels_from_db(['doc1'])
assert 'doc1' in result
page_labels, is_scanned, csv_split = result['doc1']
# Should have 2 annotations on page 0
assert 0 in page_labels
assert len(page_labels[0]) == 2
# First annotation: Bankgiro (class_id=4)
assert page_labels[0][0].class_id == 4
# Second annotation: Plusgiro mapped from supplier_accounts(Plusgiro) (class_id=5)
assert page_labels[0][1].class_id == 5

View File

@@ -0,0 +1,367 @@
"""
Tests for Training Export with expand_bbox integration.
Tests the export endpoint's integration with field-specific bbox expansion.
"""
import pytest
from unittest.mock import MagicMock, patch
from uuid import uuid4
from shared.bbox import expand_bbox
from shared.fields import CLASS_NAMES, FIELD_CLASS_IDS
class TestExpandBboxForExport:
"""Tests for expand_bbox integration in export workflow."""
def test_expand_bbox_converts_normalized_to_pixel_and_back(self):
"""Verify expand_bbox works with pixel-to-normalized conversion."""
# Annotation stored as normalized coords
x_center_norm = 0.5
y_center_norm = 0.5
width_norm = 0.1
height_norm = 0.05
# Image dimensions
img_width = 2480 # A4 at 300 DPI
img_height = 3508
# Convert to pixel coords
x_center_px = x_center_norm * img_width
y_center_px = y_center_norm * img_height
width_px = width_norm * img_width
height_px = height_norm * img_height
# Convert to corner coords
x0 = x_center_px - width_px / 2
y0 = y_center_px - height_px / 2
x1 = x_center_px + width_px / 2
y1 = y_center_px + height_px / 2
# Apply expansion
class_name = "invoice_number"
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=class_name,
)
# Verify expanded bbox is larger
assert ex0 < x0 # Left expanded
assert ey0 < y0 # Top expanded
assert ex1 > x1 # Right expanded
assert ey1 > y1 # Bottom expanded
# Convert back to normalized
new_x_center = (ex0 + ex1) / 2 / img_width
new_y_center = (ey0 + ey1) / 2 / img_height
new_width = (ex1 - ex0) / img_width
new_height = (ey1 - ey0) / img_height
# Verify valid normalized coords
assert 0 <= new_x_center <= 1
assert 0 <= new_y_center <= 1
assert 0 <= new_width <= 1
assert 0 <= new_height <= 1
def test_expand_bbox_manual_mode_minimal_expansion(self):
"""Verify manual annotations use minimal expansion."""
# Small bbox
bbox = (100, 100, 200, 150)
img_width = 2480
img_height = 3508
# Auto mode (field-specific expansion)
auto_result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="invoice_number",
manual_mode=False,
)
# Manual mode (minimal expansion)
manual_result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="invoice_number",
manual_mode=True,
)
# Auto expansion should be larger than manual
auto_width = auto_result[2] - auto_result[0]
manual_width = manual_result[2] - manual_result[0]
assert auto_width > manual_width
auto_height = auto_result[3] - auto_result[1]
manual_height = manual_result[3] - manual_result[1]
assert auto_height > manual_height
def test_expand_bbox_different_sources_use_correct_mode(self):
"""Verify different annotation sources use correct expansion mode."""
bbox = (100, 100, 200, 150)
img_width = 2480
img_height = 3508
# Define source to manual_mode mapping
source_mode_mapping = {
"manual": True, # Manual annotations -> minimal expansion
"auto": False, # Auto-labeled -> field-specific expansion
"imported": True, # Imported (from CSV) -> minimal expansion
}
results = {}
for source, manual_mode in source_mode_mapping.items():
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="ocr_number",
manual_mode=manual_mode,
)
results[source] = result
# Auto should have largest expansion
auto_area = (results["auto"][2] - results["auto"][0]) * \
(results["auto"][3] - results["auto"][1])
manual_area = (results["manual"][2] - results["manual"][0]) * \
(results["manual"][3] - results["manual"][1])
imported_area = (results["imported"][2] - results["imported"][0]) * \
(results["imported"][3] - results["imported"][1])
assert auto_area > manual_area
assert auto_area > imported_area
# Manual and imported should be the same (both use minimal mode)
assert manual_area == imported_area
def test_expand_bbox_all_field_types_work(self):
"""Verify expand_bbox works for all field types."""
bbox = (100, 100, 200, 150)
img_width = 2480
img_height = 3508
for class_name in CLASS_NAMES:
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type=class_name,
)
# Verify result is a valid bbox
assert len(result) == 4
x0, y0, x1, y1 = result
assert x0 >= 0
assert y0 >= 0
assert x1 <= img_width
assert y1 <= img_height
assert x1 > x0
assert y1 > y0
class TestExportAnnotationExpansion:
"""Tests for annotation expansion in export workflow."""
def test_annotation_bbox_conversion_workflow(self):
"""Test full annotation bbox conversion workflow."""
# Simulate stored annotation (normalized coords)
class MockAnnotation:
class_id = FIELD_CLASS_IDS["invoice_number"]
class_name = "invoice_number"
x_center = 0.3
y_center = 0.2
width = 0.15
height = 0.03
source = "auto"
ann = MockAnnotation()
img_width = 2480
img_height = 3508
# Step 1: Convert normalized to pixel corner coords
half_w = (ann.width * img_width) / 2
half_h = (ann.height * img_height) / 2
x0 = ann.x_center * img_width - half_w
y0 = ann.y_center * img_height - half_h
x1 = ann.x_center * img_width + half_w
y1 = ann.y_center * img_height + half_h
# Step 2: Determine manual_mode based on source
manual_mode = ann.source in ("manual", "imported")
# Step 3: Apply expand_bbox
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=ann.class_name,
manual_mode=manual_mode,
)
# Step 4: Convert back to normalized
new_x_center = (ex0 + ex1) / 2 / img_width
new_y_center = (ey0 + ey1) / 2 / img_height
new_width = (ex1 - ex0) / img_width
new_height = (ey1 - ey0) / img_height
# Verify expansion happened (auto mode)
assert new_width > ann.width
assert new_height > ann.height
# Verify valid YOLO format
assert 0 <= new_x_center <= 1
assert 0 <= new_y_center <= 1
assert 0 < new_width <= 1
assert 0 < new_height <= 1
def test_export_applies_expansion_to_each_annotation(self):
"""Test that export applies expansion to each annotation."""
# Simulate multiple annotations with different sources
# Use smaller bboxes so manual mode padding has visible effect
annotations = [
{"class_name": "invoice_number", "source": "auto", "x_center": 0.3, "y_center": 0.2, "width": 0.05, "height": 0.02},
{"class_name": "ocr_number", "source": "manual", "x_center": 0.5, "y_center": 0.8, "width": 0.05, "height": 0.02},
{"class_name": "amount", "source": "imported", "x_center": 0.7, "y_center": 0.5, "width": 0.05, "height": 0.02},
]
img_width = 2480
img_height = 3508
expanded_annotations = []
for ann in annotations:
# Convert to pixel coords
half_w = (ann["width"] * img_width) / 2
half_h = (ann["height"] * img_height) / 2
x0 = ann["x_center"] * img_width - half_w
y0 = ann["y_center"] * img_height - half_h
x1 = ann["x_center"] * img_width + half_w
y1 = ann["y_center"] * img_height + half_h
# Determine manual_mode
manual_mode = ann["source"] in ("manual", "imported")
# Apply expansion
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=ann["class_name"],
manual_mode=manual_mode,
)
# Convert back to normalized
expanded_annotations.append({
"class_name": ann["class_name"],
"source": ann["source"],
"x_center": (ex0 + ex1) / 2 / img_width,
"y_center": (ey0 + ey1) / 2 / img_height,
"width": (ex1 - ex0) / img_width,
"height": (ey1 - ey0) / img_height,
})
# Verify auto-labeled annotation expanded more than manual/imported
auto_ann = next(a for a in expanded_annotations if a["source"] == "auto")
manual_ann = next(a for a in expanded_annotations if a["source"] == "manual")
# Auto mode should expand more than manual mode
# (auto has larger scale factors and max_pad)
assert auto_ann["width"] > manual_ann["width"]
assert auto_ann["height"] > manual_ann["height"]
# All annotations should be expanded (at least slightly for manual mode)
# Allow small precision loss (< 1%) due to integer conversion in expand_bbox
for i, (orig, exp) in enumerate(zip(annotations, expanded_annotations)):
# Width and height should be >= original (expansion or equal, with small tolerance)
tolerance = 0.01 # 1% tolerance for integer rounding
assert exp["width"] >= orig["width"] * (1 - tolerance), \
f"Annotation {i} width unexpectedly smaller: {exp['width']} < {orig['width']}"
assert exp["height"] >= orig["height"] * (1 - tolerance), \
f"Annotation {i} height unexpectedly smaller: {exp['height']} < {orig['height']}"
class TestExpandBboxEdgeCases:
"""Tests for edge cases in export bbox expansion."""
def test_bbox_at_image_edge_left(self):
"""Test bbox at left edge of image."""
bbox = (0, 100, 50, 150)
img_width = 2480
img_height = 3508
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="invoice_number",
)
# Left edge should be clamped to 0
assert result[0] >= 0
def test_bbox_at_image_edge_right(self):
"""Test bbox at right edge of image."""
bbox = (2400, 100, 2480, 150)
img_width = 2480
img_height = 3508
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="invoice_number",
)
# Right edge should be clamped to image width
assert result[2] <= img_width
def test_bbox_at_image_edge_top(self):
"""Test bbox at top edge of image."""
bbox = (100, 0, 200, 50)
img_width = 2480
img_height = 3508
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="invoice_number",
)
# Top edge should be clamped to 0
assert result[1] >= 0
def test_bbox_at_image_edge_bottom(self):
"""Test bbox at bottom edge of image."""
bbox = (100, 3400, 200, 3508)
img_width = 2480
img_height = 3508
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="invoice_number",
)
# Bottom edge should be clamped to image height
assert result[3] <= img_height
def test_very_small_bbox(self):
"""Test very small bbox gets expanded."""
bbox = (100, 100, 105, 105) # 5x5 pixel bbox
img_width = 2480
img_height = 3508
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="invoice_number",
)
# Should still produce a valid expanded bbox
assert result[2] > result[0]
assert result[3] > result[1]