WIP
This commit is contained in:
@@ -109,7 +109,9 @@
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")",
|
||||
"Skill(tdd)",
|
||||
"Skill(tdd:*)"
|
||||
"Skill(tdd:*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m training.cli.train --model runs/train/invoice_fields/weights/best.pt --device 0 --epochs 100\")",
|
||||
"Bash(git commit -m \"$\\(cat <<''EOF''\nfeat: add field-specific bbox expansion strategies for YOLO training\n\nImplement center-point based bbox scaling with directional compensation\nto capture field labels that typically appear above or to the left of\nfield values. This improves YOLO training data quality by including\ncontextual information around field values.\n\nKey changes:\n- Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function\n- Define field-specific strategies \\(ocr_number, bankgiro, invoice_date, etc.\\)\n- Support manual_mode for minimal padding \\(no scaling\\)\n- Integrate expand_bbox into AnnotationGenerator\n- Add FIELD_TO_CLASS mapping for field_name to class_name lookup\n- Comprehensive tests with 100% coverage \\(45 tests\\)\n\nCo-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>\nEOF\n\\)\")"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": [],
|
||||
|
||||
@@ -7,10 +7,14 @@ Runs inference on new PDFs to extract invoice data.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -66,10 +70,13 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Validate model
|
||||
model_path = Path(args.model)
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model not found: {model_path}", file=sys.stderr)
|
||||
logger.error("Model not found: %s", model_path)
|
||||
sys.exit(1)
|
||||
|
||||
# Get input files
|
||||
@@ -79,16 +86,16 @@ def main():
|
||||
elif input_path.is_dir():
|
||||
pdf_files = list(input_path.glob('*.pdf'))
|
||||
else:
|
||||
print(f"Error: Input not found: {input_path}", file=sys.stderr)
|
||||
logger.error("Input not found: %s", input_path)
|
||||
sys.exit(1)
|
||||
|
||||
if not pdf_files:
|
||||
print("Error: No PDF files found", file=sys.stderr)
|
||||
logger.error("No PDF files found")
|
||||
sys.exit(1)
|
||||
|
||||
if args.verbose:
|
||||
print(f"Processing {len(pdf_files)} PDF file(s)")
|
||||
print(f"Model: {model_path}")
|
||||
logger.info("Processing %d PDF file(s)", len(pdf_files))
|
||||
logger.info("Model: %s", model_path)
|
||||
|
||||
from backend.pipeline import InferencePipeline
|
||||
|
||||
@@ -107,18 +114,18 @@ def main():
|
||||
|
||||
for pdf_path in pdf_files:
|
||||
if args.verbose:
|
||||
print(f"Processing: {pdf_path.name}")
|
||||
logger.info("Processing: %s", pdf_path.name)
|
||||
|
||||
result = pipeline.process_pdf(pdf_path)
|
||||
results.append(result.to_json())
|
||||
|
||||
if args.verbose:
|
||||
print(f" Success: {result.success}")
|
||||
print(f" Fields: {len(result.fields)}")
|
||||
logger.info(" Success: %s", result.success)
|
||||
logger.info(" Fields: %d", len(result.fields))
|
||||
if result.fallback_used:
|
||||
print(f" Fallback used: Yes")
|
||||
logger.info(" Fallback used: Yes")
|
||||
if result.errors:
|
||||
print(f" Errors: {result.errors}")
|
||||
logger.info(" Errors: %s", result.errors)
|
||||
|
||||
# Output results
|
||||
if len(results) == 1:
|
||||
@@ -132,9 +139,11 @@ def main():
|
||||
with open(args.output, 'w', encoding='utf-8') as f:
|
||||
f.write(json_output)
|
||||
if args.verbose:
|
||||
print(f"\nResults written to: {args.output}")
|
||||
logger.info("Results written to: %s", args.output)
|
||||
else:
|
||||
print(json_output)
|
||||
# Output JSON to stdout (not logged)
|
||||
sys.stdout.write(json_output)
|
||||
sys.stdout.write('\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -417,7 +417,12 @@ class InferencePipeline:
|
||||
result.errors.append(f"Business feature extraction error: {error_detail}")
|
||||
|
||||
def _merge_fields(self, result: InferenceResult) -> None:
|
||||
"""Merge extracted fields, keeping highest confidence for each field."""
|
||||
"""Merge extracted fields, keeping best candidate for each field.
|
||||
|
||||
Selection priority:
|
||||
1. Prefer candidates without validation errors
|
||||
2. Among equal validity, prefer higher confidence
|
||||
"""
|
||||
field_candidates: dict[str, list[ExtractedField]] = {}
|
||||
|
||||
for extracted in result.extracted_fields:
|
||||
@@ -430,7 +435,12 @@ class InferencePipeline:
|
||||
|
||||
# Select best candidate for each field
|
||||
for field_name, candidates in field_candidates.items():
|
||||
best = max(candidates, key=lambda x: x.confidence)
|
||||
# Sort by: (no validation error, confidence) - descending
|
||||
# This prefers candidates without errors, then by confidence
|
||||
best = max(
|
||||
candidates,
|
||||
key=lambda x: (x.validation_error is None, x.confidence)
|
||||
)
|
||||
result.fields[field_name] = best.normalized_value
|
||||
result.confidence[field_name] = best.confidence
|
||||
# Store bbox for each field (useful for payment_line and other fields)
|
||||
|
||||
@@ -7,6 +7,7 @@ the autolabel results to identify potential errors.
|
||||
|
||||
import json
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
@@ -14,6 +15,8 @@ from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
|
||||
import psycopg2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
@@ -648,7 +651,7 @@ Return ONLY the JSON object, no other text."""
|
||||
docs = self.get_documents_with_failed_matches(limit=limit)
|
||||
|
||||
if verbose:
|
||||
print(f"Found {len(docs)} documents with failed matches to validate")
|
||||
logger.info("Found %d documents with failed matches to validate", len(docs))
|
||||
|
||||
results = []
|
||||
for i, doc in enumerate(docs):
|
||||
@@ -656,16 +659,16 @@ Return ONLY the JSON object, no other text."""
|
||||
|
||||
if verbose:
|
||||
failed_fields = [f['field'] for f in doc['failed_fields']]
|
||||
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")
|
||||
logger.info("[%d/%d] Validating %s... (failed: %s)", i+1, len(docs), doc_id[:8], ', '.join(failed_fields))
|
||||
|
||||
result = self.validate_document(doc_id, provider, model)
|
||||
results.append(result)
|
||||
|
||||
if verbose:
|
||||
if result.error:
|
||||
print(f" ERROR: {result.error}")
|
||||
logger.error(" ERROR: %s", result.error)
|
||||
else:
|
||||
print(f" OK ({result.processing_time_ms:.0f}ms)")
|
||||
logger.info(" OK (%.0fms)", result.processing_time_ms)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ from backend.web.schemas.admin import (
|
||||
ExportResponse,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
from shared.bbox import expand_bbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -102,12 +103,52 @@ def register_export_routes(router: APIRouter) -> None:
|
||||
dst_image.write_bytes(image_content)
|
||||
total_images += 1
|
||||
|
||||
# Get image dimensions for bbox expansion
|
||||
img_dims = storage.get_admin_image_dimensions(doc_id, page_num)
|
||||
if img_dims is None:
|
||||
# Fall back to standard A4 at 300 DPI if dimensions unavailable
|
||||
img_width, img_height = 2480, 3508
|
||||
else:
|
||||
img_width, img_height = img_dims
|
||||
|
||||
label_name = f"{doc.document_id}_page{page_num}.txt"
|
||||
label_path = export_dir / "labels" / split / label_name
|
||||
|
||||
with open(label_path, "w") as f:
|
||||
for ann in page_annotations:
|
||||
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
|
||||
# Convert normalized coords to pixel coords
|
||||
half_w = (ann.width * img_width) / 2
|
||||
half_h = (ann.height * img_height) / 2
|
||||
x0 = ann.x_center * img_width - half_w
|
||||
y0 = ann.y_center * img_height - half_h
|
||||
x1 = ann.x_center * img_width + half_w
|
||||
y1 = ann.y_center * img_height + half_h
|
||||
|
||||
# Use manual_mode for manual/imported annotations
|
||||
manual_mode = ann.source in ("manual", "imported")
|
||||
|
||||
# Apply field-specific bbox expansion
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann.class_name,
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Convert back to normalized YOLO format
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Clamp to valid range
|
||||
new_x_center = max(0, min(1, new_x_center))
|
||||
new_y_center = max(0, min(1, new_y_center))
|
||||
new_width = max(0, min(1, new_width))
|
||||
new_height = max(0, min(1, new_height))
|
||||
|
||||
line = f"{ann.class_id} {new_x_center:.6f} {new_y_center:.6f} {new_width:.6f} {new_height:.6f}\n"
|
||||
f.write(line)
|
||||
total_annotations += 1
|
||||
|
||||
|
||||
62
packages/shared/shared/logging_config.py
Normal file
62
packages/shared/shared/logging_config.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Logging Configuration
|
||||
|
||||
Provides consistent logging setup for CLI tools and modules.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def setup_cli_logging(
|
||||
level: int = logging.INFO,
|
||||
name: Optional[str] = None,
|
||||
format_string: Optional[str] = None,
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Configure logging for CLI applications.
|
||||
|
||||
Args:
|
||||
level: Logging level (default: INFO)
|
||||
name: Logger name (default: root logger)
|
||||
format_string: Custom format string (default: simple CLI format)
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
if format_string is None:
|
||||
format_string = "%(message)s"
|
||||
|
||||
# Configure root logger or specific logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
|
||||
# Remove existing handlers to avoid duplicates
|
||||
logger.handlers.clear()
|
||||
|
||||
# Create console handler
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(logging.Formatter(format_string))
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def setup_verbose_logging(
|
||||
level: int = logging.DEBUG,
|
||||
name: Optional[str] = None,
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Configure verbose logging with timestamps and module info.
|
||||
|
||||
Args:
|
||||
level: Logging level (default: DEBUG)
|
||||
name: Logger name (default: root logger)
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
format_string = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
return setup_cli_logging(level=level, name=name, format_string=format_string)
|
||||
@@ -9,6 +9,7 @@ Now reads from PostgreSQL database instead of JSONL files.
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
@@ -16,6 +17,9 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from shared.config import get_db_connection_string
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from shared.normalize import normalize_field
|
||||
from shared.matcher import FieldMatcher
|
||||
@@ -104,7 +108,7 @@ class LabelAnalyzer:
|
||||
for row in reader:
|
||||
doc_id = row['DocumentId']
|
||||
self.csv_data[doc_id] = row
|
||||
print(f"Loaded {len(self.csv_data)} records from CSV")
|
||||
logger.info("Loaded %d records from CSV", len(self.csv_data))
|
||||
|
||||
def load_labels(self):
|
||||
"""Load all label files from dataset."""
|
||||
@@ -150,12 +154,12 @@ class LabelAnalyzer:
|
||||
for doc in self.label_data.values()
|
||||
for labels in doc['pages'].values()
|
||||
)
|
||||
print(f"Loaded labels for {total_docs} documents ({total_labels} total labels)")
|
||||
logger.info("Loaded labels for %d documents (%d total labels)", total_docs, total_labels)
|
||||
|
||||
def load_report(self):
|
||||
"""Load autolabel report from database."""
|
||||
if not self.db:
|
||||
print("Database not configured, skipping report loading")
|
||||
logger.info("Database not configured, skipping report loading")
|
||||
return
|
||||
|
||||
# Get document IDs from CSV to query
|
||||
@@ -175,7 +179,7 @@ class LabelAnalyzer:
|
||||
self.report_data[doc_id] = doc
|
||||
loaded += 1
|
||||
|
||||
print(f"Loaded {loaded} autolabel reports from database")
|
||||
logger.info("Loaded %d autolabel reports from database", loaded)
|
||||
|
||||
def analyze_document(self, doc_id: str, skip_missing_pdf: bool = True) -> Optional[DocumentAnalysis]:
|
||||
"""Analyze a single document."""
|
||||
@@ -373,7 +377,7 @@ class LabelAnalyzer:
|
||||
break
|
||||
|
||||
if skipped > 0:
|
||||
print(f"Skipped {skipped} documents without PDF files")
|
||||
logger.info("Skipped %d documents without PDF files", skipped)
|
||||
|
||||
return results
|
||||
|
||||
@@ -447,7 +451,7 @@ class LabelAnalyzer:
|
||||
with open(output, 'w', encoding='utf-8') as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nReport saved to: {output}")
|
||||
logger.info("Report saved to: %s", output)
|
||||
|
||||
return report
|
||||
|
||||
@@ -456,52 +460,52 @@ def print_summary(report: dict):
|
||||
"""Print summary to console."""
|
||||
summary = report['summary']
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("LABEL ANALYSIS SUMMARY")
|
||||
print("=" * 60)
|
||||
logger.info("=" * 60)
|
||||
logger.info("LABEL ANALYSIS SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
|
||||
print(f"\nDocuments:")
|
||||
print(f" Total: {summary['total_documents']}")
|
||||
print(f" With issues: {summary['documents_with_issues']} ({summary['issue_rate']})")
|
||||
logger.info("Documents:")
|
||||
logger.info(" Total: %d", summary['total_documents'])
|
||||
logger.info(" With issues: %d (%s)", summary['documents_with_issues'], summary['issue_rate'])
|
||||
|
||||
print(f"\nFields:")
|
||||
print(f" Expected: {summary['total_expected_fields']}")
|
||||
print(f" Labeled: {summary['total_labeled_fields']} ({summary['label_coverage']})")
|
||||
print(f" Missing: {summary['missing_labels']}")
|
||||
print(f" Extra: {summary['extra_labels']}")
|
||||
logger.info("Fields:")
|
||||
logger.info(" Expected: %d", summary['total_expected_fields'])
|
||||
logger.info(" Labeled: %d (%s)", summary['total_labeled_fields'], summary['label_coverage'])
|
||||
logger.info(" Missing: %d", summary['missing_labels'])
|
||||
logger.info(" Extra: %d", summary['extra_labels'])
|
||||
|
||||
print(f"\nFailure Reasons:")
|
||||
logger.info("Failure Reasons:")
|
||||
for reason, count in sorted(report['failure_reasons'].items(), key=lambda x: -x[1]):
|
||||
print(f" {reason}: {count}")
|
||||
logger.info(" %s: %d", reason, count)
|
||||
|
||||
print(f"\nFailures by Field:")
|
||||
logger.info("Failures by Field:")
|
||||
for field, reasons in report['failures_by_field'].items():
|
||||
total = sum(reasons.values())
|
||||
print(f" {field}: {total}")
|
||||
logger.info(" %s: %d", field, total)
|
||||
for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
|
||||
print(f" - {reason}: {count}")
|
||||
logger.info(" - %s: %d", reason, count)
|
||||
|
||||
# Show sample issues
|
||||
if report['issues']:
|
||||
print(f"\n" + "-" * 60)
|
||||
print("SAMPLE ISSUES (first 10)")
|
||||
print("-" * 60)
|
||||
logger.info("-" * 60)
|
||||
logger.info("SAMPLE ISSUES (first 10)")
|
||||
logger.info("-" * 60)
|
||||
|
||||
for issue in report['issues'][:10]:
|
||||
print(f"\n[{issue['doc_id']}] {issue['field']}")
|
||||
print(f" CSV value: {issue['csv_value']}")
|
||||
print(f" Reason: {issue['reason']}")
|
||||
logger.info("[%s] %s", issue['doc_id'], issue['field'])
|
||||
logger.info(" CSV value: %s", issue['csv_value'])
|
||||
logger.info(" Reason: %s", issue['reason'])
|
||||
|
||||
if issue.get('details'):
|
||||
details = issue['details']
|
||||
if details.get('normalized_candidates'):
|
||||
print(f" Candidates: {details['normalized_candidates'][:5]}")
|
||||
logger.info(" Candidates: %s", details['normalized_candidates'][:5])
|
||||
if details.get('pdf_tokens_sample'):
|
||||
print(f" PDF samples: {details['pdf_tokens_sample'][:5]}")
|
||||
logger.info(" PDF samples: %s", details['pdf_tokens_sample'][:5])
|
||||
if details.get('potential_matches'):
|
||||
print(f" Potential matches:")
|
||||
logger.info(" Potential matches:")
|
||||
for pm in details['potential_matches'][:3]:
|
||||
print(f" - token='{pm['token']}' matches candidate='{pm['candidate']}'")
|
||||
logger.info(" - token='%s' matches candidate='%s'", pm['token'], pm['candidate'])
|
||||
|
||||
|
||||
def main():
|
||||
@@ -551,6 +555,9 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
analyzer = LabelAnalyzer(
|
||||
csv_path=args.csv,
|
||||
pdf_dir=args.pdf_dir,
|
||||
@@ -566,30 +573,30 @@ def main():
|
||||
|
||||
analysis = analyzer.analyze_document(args.single)
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Document: {analysis.doc_id}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"PDF exists: {analysis.pdf_exists}")
|
||||
print(f"PDF type: {analysis.pdf_type}")
|
||||
print(f"Pages: {analysis.total_pages}")
|
||||
print(f"\nFields (CSV: {analysis.csv_fields_count}, Labeled: {analysis.labeled_fields_count}):")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Document: %s", analysis.doc_id)
|
||||
logger.info("=" * 60)
|
||||
logger.info("PDF exists: %s", analysis.pdf_exists)
|
||||
logger.info("PDF type: %s", analysis.pdf_type)
|
||||
logger.info("Pages: %d", analysis.total_pages)
|
||||
logger.info("Fields (CSV: %d, Labeled: %d):", analysis.csv_fields_count, analysis.labeled_fields_count)
|
||||
|
||||
for f in analysis.fields:
|
||||
status = "✓" if f.labeled else ("✗" if f.expected else "-")
|
||||
status = "[OK]" if f.labeled else ("[FAIL]" if f.expected else "[-]")
|
||||
value_str = f.csv_value[:30] if f.csv_value else "(empty)"
|
||||
print(f" [{status}] {f.field_name}: {value_str}")
|
||||
logger.info(" %s %s: %s", status, f.field_name, value_str)
|
||||
|
||||
if f.failure_reason:
|
||||
print(f" Reason: {f.failure_reason}")
|
||||
logger.info(" Reason: %s", f.failure_reason)
|
||||
if f.details.get('normalized_candidates'):
|
||||
print(f" Candidates: {f.details['normalized_candidates']}")
|
||||
logger.info(" Candidates: %s", f.details['normalized_candidates'])
|
||||
if f.details.get('potential_matches'):
|
||||
print(f" Potential matches in PDF:")
|
||||
logger.info(" Potential matches in PDF:")
|
||||
for pm in f.details['potential_matches'][:3]:
|
||||
print(f" - '{pm['token']}'")
|
||||
logger.info(" - '%s'", pm['token'])
|
||||
else:
|
||||
# Full analysis
|
||||
print("Running label analysis...")
|
||||
logger.info("Running label analysis...")
|
||||
results = analyzer.run_analysis(limit=args.limit)
|
||||
report = analyzer.generate_report(results, args.output, verbose=args.verbose)
|
||||
print_summary(report)
|
||||
|
||||
@@ -7,11 +7,15 @@ Generates statistics and insights from database or autolabel_report.jsonl
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import get_db_connection_string
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_reports_from_db() -> dict:
|
||||
@@ -147,9 +151,9 @@ def load_reports_from_file(report_path: str) -> list[dict]:
|
||||
if not report_files:
|
||||
return []
|
||||
|
||||
print(f"Reading {len(report_files)} report file(s):")
|
||||
logger.info("Reading %d report file(s):", len(report_files))
|
||||
for f in report_files:
|
||||
print(f" - {f.name}")
|
||||
logger.info(" - %s", f.name)
|
||||
|
||||
reports = []
|
||||
for report_file in report_files:
|
||||
@@ -231,55 +235,55 @@ def analyze_reports(reports: list[dict]) -> dict:
|
||||
|
||||
def print_report(stats: dict, verbose: bool = False):
|
||||
"""Print analysis report."""
|
||||
print("\n" + "=" * 60)
|
||||
print("AUTO-LABEL REPORT ANALYSIS")
|
||||
print("=" * 60)
|
||||
logger.info("=" * 60)
|
||||
logger.info("AUTO-LABEL REPORT ANALYSIS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Overall stats
|
||||
print(f"\n{'OVERALL STATISTICS':^60}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "OVERALL STATISTICS".center(60))
|
||||
logger.info("-" * 60)
|
||||
total = stats['total']
|
||||
successful = stats['successful']
|
||||
failed = stats['failed']
|
||||
success_rate = successful / total * 100 if total > 0 else 0
|
||||
|
||||
print(f"Total documents: {total:>8}")
|
||||
print(f"Successful: {successful:>8} ({success_rate:.1f}%)")
|
||||
print(f"Failed: {failed:>8} ({100-success_rate:.1f}%)")
|
||||
logger.info("Total documents: %8d", total)
|
||||
logger.info("Successful: %8d (%.1f%%)", successful, success_rate)
|
||||
logger.info("Failed: %8d (%.1f%%)", failed, 100-success_rate)
|
||||
|
||||
# Processing time
|
||||
if 'processing_time_stats' in stats:
|
||||
pts = stats['processing_time_stats']
|
||||
print(f"\nProcessing time (ms):")
|
||||
print(f" Average: {pts['avg_ms']:>8.1f}")
|
||||
print(f" Min: {pts['min_ms']:>8.1f}")
|
||||
print(f" Max: {pts['max_ms']:>8.1f}")
|
||||
logger.info("Processing time (ms):")
|
||||
logger.info(" Average: %8.1f", pts['avg_ms'])
|
||||
logger.info(" Min: %8.1f", pts['min_ms'])
|
||||
logger.info(" Max: %8.1f", pts['max_ms'])
|
||||
elif stats.get('processing_times'):
|
||||
times = stats['processing_times']
|
||||
avg_time = sum(times) / len(times)
|
||||
min_time = min(times)
|
||||
max_time = max(times)
|
||||
print(f"\nProcessing time (ms):")
|
||||
print(f" Average: {avg_time:>8.1f}")
|
||||
print(f" Min: {min_time:>8.1f}")
|
||||
print(f" Max: {max_time:>8.1f}")
|
||||
logger.info("Processing time (ms):")
|
||||
logger.info(" Average: %8.1f", avg_time)
|
||||
logger.info(" Min: %8.1f", min_time)
|
||||
logger.info(" Max: %8.1f", max_time)
|
||||
|
||||
# By PDF type
|
||||
print(f"\n{'BY PDF TYPE':^60}")
|
||||
print("-" * 60)
|
||||
print(f"{'Type':<15} {'Total':>10} {'Success':>10} {'Rate':>10}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "BY PDF TYPE".center(60))
|
||||
logger.info("-" * 60)
|
||||
logger.info("%-15s %10s %10s %10s", 'Type', 'Total', 'Success', 'Rate')
|
||||
logger.info("-" * 60)
|
||||
for pdf_type, type_stats in sorted(stats['by_pdf_type'].items()):
|
||||
type_total = type_stats['total']
|
||||
type_success = type_stats['successful']
|
||||
type_rate = type_success / type_total * 100 if type_total > 0 else 0
|
||||
print(f"{pdf_type:<15} {type_total:>10} {type_success:>10} {type_rate:>9.1f}%")
|
||||
logger.info("%-15s %10d %10d %9.1f%%", pdf_type, type_total, type_success, type_rate)
|
||||
|
||||
# By field
|
||||
print(f"\n{'FIELD MATCH STATISTICS':^60}")
|
||||
print("-" * 60)
|
||||
print(f"{'Field':<18} {'Total':>7} {'Match':>7} {'Rate':>7} {'Exact':>7} {'Flex':>7} {'AvgScore':>8}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "FIELD MATCH STATISTICS".center(60))
|
||||
logger.info("-" * 60)
|
||||
logger.info("%-18s %7s %7s %7s %7s %7s %8s", 'Field', 'Total', 'Match', 'Rate', 'Exact', 'Flex', 'AvgScore')
|
||||
logger.info("-" * 60)
|
||||
|
||||
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
|
||||
if field_name not in stats['by_field']:
|
||||
@@ -299,16 +303,16 @@ def print_report(stats: dict, verbose: bool = False):
|
||||
else:
|
||||
avg_score = 0
|
||||
|
||||
print(f"{field_name:<18} {total:>7} {matched:>7} {rate:>6.1f}% {exact:>7} {flex:>7} {avg_score:>8.3f}")
|
||||
logger.info("%-18s %7d %7d %6.1f%% %7d %7d %8.3f", field_name, total, matched, rate, exact, flex, avg_score)
|
||||
|
||||
# Field match by PDF type
|
||||
print(f"\n{'FIELD MATCH BY PDF TYPE':^60}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "FIELD MATCH BY PDF TYPE".center(60))
|
||||
logger.info("-" * 60)
|
||||
|
||||
for pdf_type in sorted(stats['by_pdf_type'].keys()):
|
||||
print(f"\n[{pdf_type.upper()}]")
|
||||
print(f"{'Field':<18} {'Total':>10} {'Matched':>10} {'Rate':>10}")
|
||||
print("-" * 50)
|
||||
logger.info("[%s]", pdf_type.upper())
|
||||
logger.info("%-18s %10s %10s %10s", 'Field', 'Total', 'Matched', 'Rate')
|
||||
logger.info("-" * 50)
|
||||
|
||||
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
|
||||
if field_name not in stats['by_field']:
|
||||
@@ -317,16 +321,16 @@ def print_report(stats: dict, verbose: bool = False):
|
||||
total = type_stats['total']
|
||||
matched = type_stats['matched']
|
||||
rate = matched / total * 100 if total > 0 else 0
|
||||
print(f"{field_name:<18} {total:>10} {matched:>10} {rate:>9.1f}%")
|
||||
logger.info("%-18s %10d %10d %9.1f%%", field_name, total, matched, rate)
|
||||
|
||||
# Errors
|
||||
if stats.get('errors') and verbose:
|
||||
print(f"\n{'ERRORS':^60}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "ERRORS".center(60))
|
||||
logger.info("-" * 60)
|
||||
for error, count in sorted(stats['errors'].items(), key=lambda x: -x[1])[:20]:
|
||||
print(f"{count:>5}x {error[:50]}")
|
||||
logger.info("%5dx %s", count, error[:50])
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
def export_json(stats: dict, output_path: str):
|
||||
@@ -372,7 +376,7 @@ def export_json(stats: dict, output_path: str):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(export_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nStatistics exported to: {output_path}")
|
||||
logger.info("Statistics exported to: %s", output_path)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -401,25 +405,28 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Decide source
|
||||
use_db = not args.from_file and args.report is None
|
||||
|
||||
if use_db:
|
||||
print("Loading statistics from database...")
|
||||
logger.info("Loading statistics from database...")
|
||||
stats = load_reports_from_db()
|
||||
print(f"Loaded stats for {stats['total']} documents")
|
||||
logger.info("Loaded stats for %d documents", stats['total'])
|
||||
else:
|
||||
report_path = args.report or 'reports/autolabel_report.jsonl'
|
||||
path = Path(report_path)
|
||||
|
||||
# Check if file exists (handle glob patterns)
|
||||
if '*' not in str(path) and '?' not in str(path) and not path.exists():
|
||||
print(f"Error: Report file not found: {path}")
|
||||
logger.error("Report file not found: %s", path)
|
||||
return 1
|
||||
|
||||
print(f"Loading reports from: {report_path}")
|
||||
logger.info("Loading reports from: %s", report_path)
|
||||
reports = load_reports_from_file(report_path)
|
||||
print(f"Loaded {len(reports)} reports")
|
||||
logger.info("Loaded %d reports", len(reports))
|
||||
stats = analyze_reports(reports)
|
||||
|
||||
print_report(stats, verbose=args.verbose)
|
||||
|
||||
@@ -6,6 +6,7 @@ Generates YOLO training data from PDFs and structured CSV data.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
@@ -17,6 +18,10 @@ from tqdm import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
|
||||
import multiprocessing
|
||||
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global flag for graceful shutdown
|
||||
_shutdown_requested = False
|
||||
|
||||
@@ -25,8 +30,8 @@ def _signal_handler(signum, frame):
|
||||
"""Handle interrupt signals for graceful shutdown."""
|
||||
global _shutdown_requested
|
||||
_shutdown_requested = True
|
||||
print("\n\nShutdown requested. Finishing current batch and saving progress...")
|
||||
print("(Press Ctrl+C again to force quit)\n")
|
||||
logger.warning("Shutdown requested. Finishing current batch and saving progress...")
|
||||
logger.warning("(Press Ctrl+C again to force quit)")
|
||||
|
||||
# Windows compatibility: use 'spawn' method for multiprocessing
|
||||
# This is required on Windows and is also safer for libraries like PaddleOCR
|
||||
@@ -350,11 +355,14 @@ def main():
|
||||
if ',' in csv_input and '*' not in csv_input:
|
||||
csv_input = [p.strip() for p in csv_input.split(',')]
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Get list of CSV files (don't load all data at once)
|
||||
temp_loader = CSVLoader(csv_input, args.pdf_dir)
|
||||
csv_files = temp_loader.csv_paths
|
||||
pdf_dir = temp_loader.pdf_dir
|
||||
print(f"Found {len(csv_files)} CSV file(s) to process")
|
||||
logger.info("Found %d CSV file(s) to process", len(csv_files))
|
||||
|
||||
# Setup output directories
|
||||
output_dir = Path(args.output)
|
||||
@@ -371,7 +379,7 @@ def main():
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
db.create_tables() # Ensure tables exist
|
||||
print("Connected to database for status checking")
|
||||
logger.info("Connected to database for status checking")
|
||||
|
||||
# Global stats
|
||||
stats = {
|
||||
@@ -443,7 +451,7 @@ def main():
|
||||
db.save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
if args.verbose:
|
||||
print(f"Error processing {doc_id}: {error}")
|
||||
logger.error("Error processing %s: %s", doc_id, error)
|
||||
|
||||
# Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
|
||||
dual_pool_coordinator = None
|
||||
@@ -453,7 +461,7 @@ def main():
|
||||
from training.processing import DualPoolCoordinator
|
||||
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
|
||||
|
||||
print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
|
||||
logger.info("Starting dual-pool mode: %d CPU + %d GPU workers", args.cpu_workers, args.gpu_workers)
|
||||
dual_pool_coordinator = DualPoolCoordinator(
|
||||
cpu_workers=args.cpu_workers,
|
||||
gpu_workers=args.gpu_workers,
|
||||
@@ -467,10 +475,10 @@ def main():
|
||||
for csv_idx, csv_file in enumerate(csv_files):
|
||||
# Check for shutdown request
|
||||
if _shutdown_requested:
|
||||
print("\nShutdown requested. Stopping after current batch...")
|
||||
logger.warning("Shutdown requested. Stopping after current batch...")
|
||||
break
|
||||
|
||||
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
|
||||
logger.info("[%d/%d] Processing: %s", csv_idx + 1, len(csv_files), csv_file.name)
|
||||
|
||||
# Load only this CSV file
|
||||
single_loader = CSVLoader(str(csv_file), str(pdf_dir))
|
||||
@@ -488,7 +496,7 @@ def main():
|
||||
seen_doc_ids.add(r.DocumentId)
|
||||
|
||||
if not rows:
|
||||
print(f" Skipping CSV (no new documents)")
|
||||
logger.info(" Skipping CSV (no new documents)")
|
||||
continue
|
||||
|
||||
# Batch query database for all document IDs in this CSV
|
||||
@@ -500,13 +508,13 @@ def main():
|
||||
|
||||
# Skip entire CSV if all documents are already processed
|
||||
if already_processed == len(rows):
|
||||
print(f" Skipping CSV (all {len(rows)} documents already processed)")
|
||||
logger.info(" Skipping CSV (all %d documents already processed)", len(rows))
|
||||
stats['skipped_db'] += len(rows)
|
||||
continue
|
||||
|
||||
# Count how many new documents need processing in this CSV
|
||||
new_to_process = len(rows) - already_processed
|
||||
print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")
|
||||
logger.info(" Found %d new documents to process (%d already in DB)", new_to_process, already_processed)
|
||||
|
||||
stats['total'] += len(rows)
|
||||
|
||||
@@ -520,7 +528,7 @@ def main():
|
||||
if args.limit:
|
||||
remaining_limit = args.limit - stats.get('tasks_submitted', 0)
|
||||
if remaining_limit <= 0:
|
||||
print(f" Reached limit of {args.limit} new documents, stopping.")
|
||||
logger.info(" Reached limit of %d new documents, stopping.", args.limit)
|
||||
break
|
||||
else:
|
||||
remaining_limit = float('inf')
|
||||
@@ -583,7 +591,7 @@ def main():
|
||||
))
|
||||
|
||||
if skipped_in_csv > 0 or retry_in_csv > 0:
|
||||
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
|
||||
logger.info(" Skipped %d (already in DB), retrying %d failed", skipped_in_csv, retry_in_csv)
|
||||
|
||||
# Clean up retry documents: delete from database and remove temp folders
|
||||
if retry_doc_ids:
|
||||
@@ -599,7 +607,7 @@ def main():
|
||||
temp_doc_dir = output_dir / 'temp' / doc_id
|
||||
if temp_doc_dir.exists():
|
||||
shutil.rmtree(temp_doc_dir, ignore_errors=True)
|
||||
print(f" Cleaned up {len(retry_doc_ids)} retry documents (DB + temp folders)")
|
||||
logger.info(" Cleaned up %d retry documents (DB + temp folders)", len(retry_doc_ids))
|
||||
|
||||
if not tasks:
|
||||
continue
|
||||
@@ -636,7 +644,7 @@ def main():
|
||||
# Count task types
|
||||
text_count = sum(1 for d in documents if not d["is_scanned"])
|
||||
scan_count = len(documents) - text_count
|
||||
print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")
|
||||
logger.info(" Text PDFs: %d, Scanned PDFs: %d", text_count, scan_count)
|
||||
|
||||
# Progress tracking with tqdm
|
||||
pbar = tqdm(total=len(documents), desc="Processing")
|
||||
@@ -667,11 +675,11 @@ def main():
|
||||
# Log summary
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = len(results) - successful
|
||||
print(f" Batch complete: {successful} successful, {failed} failed")
|
||||
logger.info(" Batch complete: %d successful, %d failed", successful, failed)
|
||||
|
||||
else:
|
||||
# Single-pool mode (original behavior)
|
||||
print(f" Processing {len(tasks)} documents with {args.workers} workers...")
|
||||
logger.info(" Processing %d documents with %d workers...", len(tasks), args.workers)
|
||||
|
||||
# Process documents in parallel (inside CSV loop for streaming)
|
||||
# Use single process for debugging or when workers=1
|
||||
@@ -725,28 +733,28 @@ def main():
|
||||
db.close()
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Auto-labeling Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total documents: {stats['total']}")
|
||||
print(f"Successful: {stats['successful']}")
|
||||
print(f"Failed: {stats['failed']}")
|
||||
print(f"Skipped (no PDF): {stats['skipped']}")
|
||||
print(f"Skipped (in DB): {stats['skipped_db']}")
|
||||
print(f"Retried (failed): {stats['retried']}")
|
||||
print(f"Total annotations: {stats['annotations']}")
|
||||
print(f"\nImages saved to: {output_dir / 'temp'}")
|
||||
print(f"Labels stored in: PostgreSQL database")
|
||||
print(f"\nAnnotations by field:")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Auto-labeling Complete")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Total documents: %d", stats['total'])
|
||||
logger.info("Successful: %d", stats['successful'])
|
||||
logger.info("Failed: %d", stats['failed'])
|
||||
logger.info("Skipped (no PDF): %d", stats['skipped'])
|
||||
logger.info("Skipped (in DB): %d", stats['skipped_db'])
|
||||
logger.info("Retried (failed): %d", stats['retried'])
|
||||
logger.info("Total annotations: %d", stats['annotations'])
|
||||
logger.info("Images saved to: %s", output_dir / 'temp')
|
||||
logger.info("Labels stored in: PostgreSQL database")
|
||||
logger.info("Annotations by field:")
|
||||
for field, count in stats['by_field'].items():
|
||||
print(f" {field}: {count}")
|
||||
logger.info(" %s: %d", field, count)
|
||||
shard_files = report_writer.get_shard_files()
|
||||
if len(shard_files) > 1:
|
||||
print(f"\nReport files ({len(shard_files)}):")
|
||||
logger.info("Report files (%d):", len(shard_files))
|
||||
for sf in shard_files:
|
||||
print(f" - {sf}")
|
||||
logger.info(" - %s", sf)
|
||||
else:
|
||||
print(f"\nReport: {shard_files[0] if shard_files else args.report}")
|
||||
logger.info("Report: %s", shard_files[0] if shard_files else args.report)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -8,6 +8,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -16,6 +17,9 @@ from psycopg2.extras import execute_values
|
||||
|
||||
# Add project root to path
|
||||
from shared.config import get_db_connection_string, PATHS
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_tables(conn):
|
||||
@@ -150,7 +154,7 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
|
||||
try:
|
||||
record = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f" Warning: Line {line_no} - JSON parse error: {e}")
|
||||
logger.warning("Line %d - JSON parse error: %s", line_no, e)
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
@@ -211,7 +215,7 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
|
||||
# Flush batch if needed
|
||||
if len(doc_batch) >= batch_size:
|
||||
flush_batches()
|
||||
print(f" Processed {stats['imported'] + stats['skipped']} records...")
|
||||
logger.info(" Processed %d records...", stats['imported'] + stats['skipped'])
|
||||
|
||||
# Final flush
|
||||
flush_batches()
|
||||
@@ -243,11 +247,14 @@ def main():
|
||||
else:
|
||||
report_files = [report_path] if report_path.exists() else []
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
if not report_files:
|
||||
print(f"No report files found: {args.report}")
|
||||
logger.error("No report files found: %s", args.report)
|
||||
return
|
||||
|
||||
print(f"Found {len(report_files)} report file(s)")
|
||||
logger.info("Found %d report file(s)", len(report_files))
|
||||
|
||||
# Connect to database
|
||||
conn = psycopg2.connect(db_connection)
|
||||
@@ -257,20 +264,20 @@ def main():
|
||||
total_stats = {'imported': 0, 'skipped': 0, 'errors': 0}
|
||||
|
||||
for report_file in report_files:
|
||||
print(f"\nImporting: {report_file.name}")
|
||||
logger.info("Importing: %s", report_file.name)
|
||||
stats = import_jsonl_file(conn, report_file, skip_existing=not args.no_skip, batch_size=args.batch_size)
|
||||
print(f" Imported: {stats['imported']}, Skipped: {stats['skipped']}, Errors: {stats['errors']}")
|
||||
logger.info(" Imported: %d, Skipped: %d, Errors: %d", stats['imported'], stats['skipped'], stats['errors'])
|
||||
|
||||
for key in total_stats:
|
||||
total_stats[key] += stats[key]
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Import Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total imported: {total_stats['imported']}")
|
||||
print(f"Total skipped: {total_stats['skipped']}")
|
||||
print(f"Total errors: {total_stats['errors']}")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Import Complete")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Total imported: %d", total_stats['imported'])
|
||||
logger.info("Total skipped: %d", total_stats['skipped'])
|
||||
logger.info("Total errors: %d", total_stats['errors'])
|
||||
|
||||
# Quick stats from database
|
||||
with conn.cursor() as cursor:
|
||||
@@ -288,11 +295,11 @@ def main():
|
||||
|
||||
conn.close()
|
||||
|
||||
print(f"\nDatabase Stats:")
|
||||
print(f" Documents: {total_docs} ({success_docs} successful)")
|
||||
print(f" Field results: {total_fields} ({matched_fields} matched)")
|
||||
logger.info("Database Stats:")
|
||||
logger.info(" Documents: %d (%d successful)", total_docs, success_docs)
|
||||
logger.info(" Field results: %d (%d matched)", total_fields, matched_fields)
|
||||
if total_fields > 0:
|
||||
print(f" Match rate: {matched_fields / total_fields * 100:.2f}%")
|
||||
logger.info(" Match rate: %.2f%%", matched_fields / total_fields * 100)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -7,6 +7,7 @@ CSV values, and source CSV filename in a new table.
|
||||
import argparse
|
||||
import json
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -20,6 +21,9 @@ from shared.config import DEFAULT_DPI
|
||||
from shared.data.db import DocumentDB
|
||||
from shared.data.csv_loader import CSVLoader
|
||||
from shared.normalize.normalizer import normalize_field
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_failed_match_table(db: DocumentDB):
|
||||
@@ -57,7 +61,7 @@ def create_failed_match_table(db: DocumentDB):
|
||||
CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
|
||||
""")
|
||||
conn.commit()
|
||||
print("Created table: failed_match_details")
|
||||
logger.info("Created table: failed_match_details")
|
||||
|
||||
|
||||
def get_failed_documents(db: DocumentDB) -> list:
|
||||
@@ -332,14 +336,17 @@ def main():
|
||||
parser.add_argument('--limit', type=int, help='Limit number of documents to process')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Expand CSV glob
|
||||
csv_files = sorted(glob.glob(args.csv))
|
||||
print(f"Found {len(csv_files)} CSV files")
|
||||
logger.info("Found %d CSV files", len(csv_files))
|
||||
|
||||
# Build CSV cache
|
||||
print("Building CSV filename cache...")
|
||||
logger.info("Building CSV filename cache...")
|
||||
build_csv_cache(csv_files)
|
||||
print(f"Cached {len(_csv_cache)} document IDs")
|
||||
logger.info("Cached %d document IDs", len(_csv_cache))
|
||||
|
||||
# Connect to database
|
||||
db = DocumentDB()
|
||||
@@ -349,13 +356,13 @@ def main():
|
||||
create_failed_match_table(db)
|
||||
|
||||
# Get all failed documents
|
||||
print("Fetching failed documents...")
|
||||
logger.info("Fetching failed documents...")
|
||||
failed_docs = get_failed_documents(db)
|
||||
print(f"Found {len(failed_docs)} documents with failed matches")
|
||||
logger.info("Found %d documents with failed matches", len(failed_docs))
|
||||
|
||||
if args.limit:
|
||||
failed_docs = failed_docs[:args.limit]
|
||||
print(f"Limited to {len(failed_docs)} documents")
|
||||
logger.info("Limited to %d documents", len(failed_docs))
|
||||
|
||||
# Prepare tasks
|
||||
tasks = []
|
||||
@@ -365,7 +372,7 @@ def main():
|
||||
if failed_fields:
|
||||
tasks.append((doc, failed_fields, csv_filename))
|
||||
|
||||
print(f"Processing {len(tasks)} documents with {args.workers} workers...")
|
||||
logger.info("Processing %d documents with %d workers...", len(tasks), args.workers)
|
||||
|
||||
# Process with multiprocessing
|
||||
total_results = 0
|
||||
@@ -389,15 +396,15 @@ def main():
|
||||
batch_results = []
|
||||
|
||||
except TimeoutError:
|
||||
print(f"\nTimeout processing {doc_id}")
|
||||
logger.warning("Timeout processing %s", doc_id)
|
||||
except Exception as e:
|
||||
print(f"\nError processing {doc_id}: {e}")
|
||||
logger.error("Error processing %s: %s", doc_id, e)
|
||||
|
||||
# Save remaining results
|
||||
if batch_results:
|
||||
save_results_batch(db, batch_results)
|
||||
|
||||
print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")
|
||||
logger.info("Done! Saved %d failed match records to failed_match_details table", total_results)
|
||||
|
||||
# Show summary
|
||||
conn = db.connect()
|
||||
@@ -410,12 +417,12 @@ def main():
|
||||
GROUP BY field_name
|
||||
ORDER BY total DESC
|
||||
""")
|
||||
print("\nSummary by field:")
|
||||
print("-" * 70)
|
||||
print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
|
||||
print("-" * 70)
|
||||
logger.info("Summary by field:")
|
||||
logger.info("-" * 70)
|
||||
logger.info("%-35s %8s %10s %12s", 'Field', 'Total', 'Has OCR', 'Avg Score')
|
||||
logger.info("-" * 70)
|
||||
for row in cursor.fetchall():
|
||||
print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")
|
||||
logger.info("%-35s %8d %10d %12.2f", row[0], row[1], row[2], row[3])
|
||||
|
||||
db.close()
|
||||
|
||||
|
||||
@@ -7,10 +7,14 @@ Images are read from filesystem, labels are dynamically generated from DB.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import DEFAULT_DPI, PATHS
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -119,47 +123,50 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Apply low-memory mode if specified
|
||||
if args.low_memory:
|
||||
print("🔧 Low memory mode enabled")
|
||||
logger.info("Low memory mode enabled")
|
||||
args.batch = min(args.batch, 8) # Reduce from 16 to 8
|
||||
args.workers = min(args.workers, 4) # Reduce from 8 to 4
|
||||
args.cache = False
|
||||
print(f" Batch size: {args.batch}")
|
||||
print(f" Workers: {args.workers}")
|
||||
print(f" Cache: disabled")
|
||||
logger.info(" Batch size: %d", args.batch)
|
||||
logger.info(" Workers: %d", args.workers)
|
||||
logger.info(" Cache: disabled")
|
||||
|
||||
# Validate dataset directory
|
||||
dataset_dir = Path(args.dataset_dir)
|
||||
temp_dir = dataset_dir / 'temp'
|
||||
if not temp_dir.exists():
|
||||
print(f"Error: Temp directory not found: {temp_dir}")
|
||||
print("Run autolabel first to generate images.")
|
||||
logger.error("Temp directory not found: %s", temp_dir)
|
||||
logger.error("Run autolabel first to generate images.")
|
||||
sys.exit(1)
|
||||
|
||||
print("=" * 60)
|
||||
print("YOLO Training with Database Labels")
|
||||
print("=" * 60)
|
||||
print(f"Dataset dir: {dataset_dir}")
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Epochs: {args.epochs}")
|
||||
print(f"Batch size: {args.batch}")
|
||||
print(f"Image size: {args.imgsz}")
|
||||
print(f"Split ratio: {args.train_ratio}/{args.val_ratio}/{1-args.train_ratio-args.val_ratio:.1f}")
|
||||
logger.info("=" * 60)
|
||||
logger.info("YOLO Training with Database Labels")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Dataset dir: %s", dataset_dir)
|
||||
logger.info("Model: %s", args.model)
|
||||
logger.info("Epochs: %d", args.epochs)
|
||||
logger.info("Batch size: %d", args.batch)
|
||||
logger.info("Image size: %d", args.imgsz)
|
||||
logger.info("Split ratio: %s/%s/%.1f", args.train_ratio, args.val_ratio, 1-args.train_ratio-args.val_ratio)
|
||||
if args.limit:
|
||||
print(f"Document limit: {args.limit}")
|
||||
logger.info("Document limit: %d", args.limit)
|
||||
|
||||
# Connect to database
|
||||
from shared.data.db import DocumentDB
|
||||
|
||||
print("\nConnecting to database...")
|
||||
logger.info("Connecting to database...")
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
|
||||
# Create datasets from database
|
||||
from training.yolo.db_dataset import create_datasets
|
||||
|
||||
print("Loading dataset from database...")
|
||||
logger.info("Loading dataset from database...")
|
||||
datasets = create_datasets(
|
||||
images_dir=dataset_dir,
|
||||
db=db,
|
||||
@@ -170,39 +177,39 @@ def main():
|
||||
limit=args.limit
|
||||
)
|
||||
|
||||
print(f"\nDataset splits:")
|
||||
print(f" Train: {len(datasets['train'])} items")
|
||||
print(f" Val: {len(datasets['val'])} items")
|
||||
print(f" Test: {len(datasets['test'])} items")
|
||||
logger.info("Dataset splits:")
|
||||
logger.info(" Train: %d items", len(datasets['train']))
|
||||
logger.info(" Val: %d items", len(datasets['val']))
|
||||
logger.info(" Test: %d items", len(datasets['test']))
|
||||
|
||||
if len(datasets['train']) == 0:
|
||||
print("\nError: No training data found!")
|
||||
print("Make sure autolabel has been run and images exist in temp directory.")
|
||||
logger.error("No training data found!")
|
||||
logger.error("Make sure autolabel has been run and images exist in temp directory.")
|
||||
db.close()
|
||||
sys.exit(1)
|
||||
|
||||
# Export to YOLO format (required for Ultralytics training)
|
||||
print("\nExporting dataset to YOLO format...")
|
||||
logger.info("Exporting dataset to YOLO format...")
|
||||
for split_name, dataset in datasets.items():
|
||||
count = dataset.export_to_yolo_format(dataset_dir, split_name)
|
||||
print(f" {split_name}: {count} items exported")
|
||||
logger.info(" %s: %d items exported", split_name, count)
|
||||
|
||||
# Generate YOLO config files
|
||||
from training.yolo.annotation_generator import AnnotationGenerator
|
||||
|
||||
AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt')
|
||||
AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml')
|
||||
print(f"\nGenerated dataset.yaml at: {dataset_dir / 'dataset.yaml'}")
|
||||
logger.info("Generated dataset.yaml at: %s", dataset_dir / 'dataset.yaml')
|
||||
|
||||
if args.export_only:
|
||||
print("\nExport complete (--export-only specified, skipping training)")
|
||||
logger.info("Export complete (--export-only specified, skipping training)")
|
||||
db.close()
|
||||
return
|
||||
|
||||
# Start training using shared trainer
|
||||
print("\n" + "=" * 60)
|
||||
print("Starting YOLO Training")
|
||||
print("=" * 60)
|
||||
logger.info("=" * 60)
|
||||
logger.info("Starting YOLO Training")
|
||||
logger.info("=" * 60)
|
||||
|
||||
from shared.training import YOLOTrainer, TrainingConfig
|
||||
|
||||
@@ -232,30 +239,30 @@ def main():
|
||||
result = trainer.train()
|
||||
|
||||
if not result.success:
|
||||
print(f"\nError: Training failed - {result.error}")
|
||||
logger.error("Training failed - %s", result.error)
|
||||
db.close()
|
||||
sys.exit(1)
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 60)
|
||||
print("Training Complete")
|
||||
print("=" * 60)
|
||||
print(f"Best model: {result.model_path}")
|
||||
print(f"Save directory: {result.save_dir}")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Training Complete")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Best model: %s", result.model_path)
|
||||
logger.info("Save directory: %s", result.save_dir)
|
||||
if result.metrics:
|
||||
print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
|
||||
print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")
|
||||
logger.info("mAP@0.5: %s", result.metrics.get('mAP50', 'N/A'))
|
||||
logger.info("mAP@0.5-0.95: %s", result.metrics.get('mAP50-95', 'N/A'))
|
||||
|
||||
# Validate on test set
|
||||
print("\nRunning validation on test set...")
|
||||
logger.info("Running validation on test set...")
|
||||
if result.model_path:
|
||||
config.model_path = result.model_path
|
||||
config.data_yaml = str(data_yaml)
|
||||
test_trainer = YOLOTrainer(config=config)
|
||||
test_metrics = test_trainer.validate(split='test')
|
||||
if test_metrics:
|
||||
print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
|
||||
print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")
|
||||
logger.info("mAP50: %.4f", test_metrics.get('mAP50', 0))
|
||||
logger.info("mAP50-95: %.4f", test_metrics.get('mAP50-95', 0))
|
||||
|
||||
# Close database
|
||||
db.close()
|
||||
|
||||
@@ -7,9 +7,14 @@ and comparing the extraction results.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
@@ -73,6 +78,9 @@ def main():
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
from backend.validation import LLMValidator
|
||||
|
||||
validator = LLMValidator()
|
||||
@@ -104,60 +112,58 @@ def show_stats(validator):
|
||||
"""Show statistics about failed matches."""
|
||||
stats = validator.get_failed_match_stats()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Failed Match Statistics")
|
||||
print("=" * 50)
|
||||
print(f"\nDocuments with failures: {stats['documents_with_failures']}")
|
||||
print(f"Already validated: {stats['already_validated']}")
|
||||
print(f"Remaining to validate: {stats['remaining']}")
|
||||
print("\nFailures by field:")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Failed Match Statistics")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Documents with failures: %d", stats['documents_with_failures'])
|
||||
logger.info("Already validated: %d", stats['already_validated'])
|
||||
logger.info("Remaining to validate: %d", stats['remaining'])
|
||||
logger.info("Failures by field:")
|
||||
for field, count in sorted(stats['failures_by_field'].items(), key=lambda x: -x[1]):
|
||||
print(f" {field}: {count}")
|
||||
logger.info(" %s: %d", field, count)
|
||||
|
||||
|
||||
def validate_single(validator, doc_id: str, provider: str, model: str):
|
||||
"""Validate a single document."""
|
||||
print(f"\nValidating document: {doc_id}")
|
||||
print(f"Provider: {provider}, Model: {model or 'default'}")
|
||||
print()
|
||||
logger.info("Validating document: %s", doc_id)
|
||||
logger.info("Provider: %s, Model: %s", provider, model or 'default')
|
||||
|
||||
result = validator.validate_document(doc_id, provider, model)
|
||||
|
||||
if result.error:
|
||||
print(f"ERROR: {result.error}")
|
||||
logger.error("ERROR: %s", result.error)
|
||||
return
|
||||
|
||||
print(f"Processing time: {result.processing_time_ms:.0f}ms")
|
||||
print(f"Model used: {result.model_used}")
|
||||
print("\nExtracted fields:")
|
||||
print(f" Invoice Number: {result.invoice_number}")
|
||||
print(f" Invoice Date: {result.invoice_date}")
|
||||
print(f" Due Date: {result.invoice_due_date}")
|
||||
print(f" OCR: {result.ocr_number}")
|
||||
print(f" Bankgiro: {result.bankgiro}")
|
||||
print(f" Plusgiro: {result.plusgiro}")
|
||||
print(f" Amount: {result.amount}")
|
||||
print(f" Org Number: {result.supplier_organisation_number}")
|
||||
logger.info("Processing time: %.0fms", result.processing_time_ms)
|
||||
logger.info("Model used: %s", result.model_used)
|
||||
logger.info("Extracted fields:")
|
||||
logger.info(" Invoice Number: %s", result.invoice_number)
|
||||
logger.info(" Invoice Date: %s", result.invoice_date)
|
||||
logger.info(" Due Date: %s", result.invoice_due_date)
|
||||
logger.info(" OCR: %s", result.ocr_number)
|
||||
logger.info(" Bankgiro: %s", result.bankgiro)
|
||||
logger.info(" Plusgiro: %s", result.plusgiro)
|
||||
logger.info(" Amount: %s", result.amount)
|
||||
logger.info(" Org Number: %s", result.supplier_organisation_number)
|
||||
|
||||
# Show comparison
|
||||
print("\n" + "-" * 50)
|
||||
print("Comparison with autolabel:")
|
||||
logger.info("-" * 50)
|
||||
logger.info("Comparison with autolabel:")
|
||||
comparison = validator.compare_results(doc_id)
|
||||
for field, data in comparison.items():
|
||||
if data.get('csv_value'):
|
||||
status = "✓" if data['agreement'] else "✗"
|
||||
status = "[OK]" if data['agreement'] else "[FAIL]"
|
||||
auto_status = "matched" if data['autolabel_matched'] else "FAILED"
|
||||
print(f" {status} {field}:")
|
||||
print(f" CSV: {data['csv_value']}")
|
||||
print(f" Autolabel: {data['autolabel_text']} ({auto_status})")
|
||||
print(f" LLM: {data['llm_value']}")
|
||||
logger.info(" %s %s:", status, field)
|
||||
logger.info(" CSV: %s", data['csv_value'])
|
||||
logger.info(" Autolabel: %s (%s)", data['autolabel_text'], auto_status)
|
||||
logger.info(" LLM: %s", data['llm_value'])
|
||||
|
||||
|
||||
def validate_batch(validator, limit: int, provider: str, model: str):
|
||||
"""Validate a batch of documents."""
|
||||
print(f"\nValidating up to {limit} documents with failed matches")
|
||||
print(f"Provider: {provider}, Model: {model or 'default'}")
|
||||
print()
|
||||
logger.info("Validating up to %d documents with failed matches", limit)
|
||||
logger.info("Provider: %s, Model: %s", provider, model or 'default')
|
||||
|
||||
results = validator.validate_batch(
|
||||
limit=limit,
|
||||
@@ -171,15 +177,15 @@ def validate_batch(validator, limit: int, provider: str, model: str):
|
||||
failed = len(results) - success
|
||||
total_time = sum(r.processing_time_ms or 0 for r in results)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Validation Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total: {len(results)}")
|
||||
print(f"Success: {success}")
|
||||
print(f"Failed: {failed}")
|
||||
print(f"Total time: {total_time/1000:.1f}s")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Validation Complete")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Total: %d", len(results))
|
||||
logger.info("Success: %d", success)
|
||||
logger.info("Failed: %d", failed)
|
||||
logger.info("Total time: %.1fs", total_time/1000)
|
||||
if success > 0:
|
||||
print(f"Avg time: {total_time/success:.0f}ms per document")
|
||||
logger.info("Avg time: %.0fms per document", total_time/success)
|
||||
|
||||
|
||||
def compare_single(validator, doc_id: str):
|
||||
@@ -187,23 +193,23 @@ def compare_single(validator, doc_id: str):
|
||||
comparison = validator.compare_results(doc_id)
|
||||
|
||||
if 'error' in comparison:
|
||||
print(f"Error: {comparison['error']}")
|
||||
logger.error("Error: %s", comparison['error'])
|
||||
return
|
||||
|
||||
print(f"\nComparison for document: {doc_id}")
|
||||
print("=" * 60)
|
||||
logger.info("Comparison for document: %s", doc_id)
|
||||
logger.info("=" * 60)
|
||||
|
||||
for field, data in comparison.items():
|
||||
if data.get('csv_value') is None:
|
||||
continue
|
||||
|
||||
status = "✓" if data['agreement'] else "✗"
|
||||
status = "[OK]" if data['agreement'] else "[FAIL]"
|
||||
auto_status = "matched" if data['autolabel_matched'] else "FAILED"
|
||||
|
||||
print(f"\n{status} {field}:")
|
||||
print(f" CSV value: {data['csv_value']}")
|
||||
print(f" Autolabel: {data['autolabel_text']} ({auto_status})")
|
||||
print(f" LLM extracted: {data['llm_value']}")
|
||||
logger.info("%s %s:", status, field)
|
||||
logger.info(" CSV value: %s", data['csv_value'])
|
||||
logger.info(" Autolabel: %s (%s)", data['autolabel_text'], auto_status)
|
||||
logger.info(" LLM extracted: %s", data['llm_value'])
|
||||
|
||||
|
||||
def compare_all(validator, limit: int):
|
||||
@@ -220,11 +226,11 @@ def compare_all(validator, limit: int):
|
||||
doc_ids = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
if not doc_ids:
|
||||
print("No validated documents found.")
|
||||
logger.info("No validated documents found.")
|
||||
return
|
||||
|
||||
print(f"\nComparison Summary ({len(doc_ids)} documents)")
|
||||
print("=" * 80)
|
||||
logger.info("Comparison Summary (%d documents)", len(doc_ids))
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Aggregate stats
|
||||
field_stats = {}
|
||||
@@ -259,12 +265,12 @@ def compare_all(validator, limit: int):
|
||||
if not data['autolabel_matched'] and data['agreement']:
|
||||
stats['llm_correct_auto_wrong'] += 1
|
||||
|
||||
print(f"\n{'Field':<30} {'Total':>6} {'Auto OK':>8} {'LLM Agrees':>10} {'LLM Found':>10}")
|
||||
print("-" * 80)
|
||||
logger.info("%-30s %6s %8s %10s %10s", 'Field', 'Total', 'Auto OK', 'LLM Agrees', 'LLM Found')
|
||||
logger.info("-" * 80)
|
||||
|
||||
for field, stats in sorted(field_stats.items()):
|
||||
print(f"{field:<30} {stats['total']:>6} {stats['autolabel_matched']:>8} "
|
||||
f"{stats['llm_agrees']:>10} {stats['llm_correct_auto_wrong']:>10}")
|
||||
logger.info("%-30s %6d %8d %10d %10d", field, stats['total'], stats['autolabel_matched'],
|
||||
stats['llm_agrees'], stats['llm_correct_auto_wrong'])
|
||||
|
||||
|
||||
def generate_report(validator, output_path: str):
|
||||
@@ -328,8 +334,8 @@ def generate_report(validator, output_path: str):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nReport generated: {output_path}")
|
||||
print(f"Total validations: {len(validations)}")
|
||||
logger.info("Report generated: %s", output_path)
|
||||
logger.info("Total validations: %d", len(validations))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -18,7 +18,8 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
|
||||
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES, CLASS_NAMES
|
||||
from shared.bbox import expand_bbox
|
||||
from .annotation_generator import YOLOAnnotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -156,7 +157,7 @@ class DBYOLODataset:
|
||||
|
||||
# Split items for this split
|
||||
instance.items = instance._split_dataset_from_cache()
|
||||
print(f"Split '{split}': {len(instance.items)} items")
|
||||
logger.info("Split '%s': %d items", split, len(instance.items))
|
||||
|
||||
return instance
|
||||
|
||||
@@ -165,7 +166,7 @@ class DBYOLODataset:
|
||||
# Find all document directories
|
||||
temp_dir = self.images_dir / 'temp'
|
||||
if not temp_dir.exists():
|
||||
print(f"Temp directory not found: {temp_dir}")
|
||||
logger.warning("Temp directory not found: %s", temp_dir)
|
||||
return
|
||||
|
||||
# Collect all document IDs with images
|
||||
@@ -182,13 +183,13 @@ class DBYOLODataset:
|
||||
if images:
|
||||
doc_image_map[doc_dir.name] = sorted(images)
|
||||
|
||||
print(f"Found {len(doc_image_map)} documents with images")
|
||||
logger.info("Found %d documents with images", len(doc_image_map))
|
||||
|
||||
# Query database for all document labels
|
||||
doc_ids = list(doc_image_map.keys())
|
||||
doc_labels = self._load_labels_from_db(doc_ids)
|
||||
|
||||
print(f"Loaded labels for {len(doc_labels)} documents from database")
|
||||
logger.info("Loaded labels for %d documents from database", len(doc_labels))
|
||||
|
||||
# Build dataset items
|
||||
all_items: list[DatasetItem] = []
|
||||
@@ -227,19 +228,19 @@ class DBYOLODataset:
|
||||
else:
|
||||
skipped_no_labels += 1
|
||||
|
||||
print(f"Total images found: {total_images}")
|
||||
print(f"Images with labels: {len(all_items)}")
|
||||
logger.info("Total images found: %d", total_images)
|
||||
logger.info("Images with labels: %d", len(all_items))
|
||||
if skipped_no_db_record > 0:
|
||||
print(f"Skipped {skipped_no_db_record} images (document not in database)")
|
||||
logger.info("Skipped %d images (document not in database)", skipped_no_db_record)
|
||||
if skipped_no_labels > 0:
|
||||
print(f"Skipped {skipped_no_labels} images (no labels for page)")
|
||||
logger.info("Skipped %d images (no labels for page)", skipped_no_labels)
|
||||
|
||||
# Cache all items for sharing with other splits
|
||||
self._all_items = all_items
|
||||
|
||||
# Split dataset
|
||||
self.items, self._doc_ids_ordered = self._split_dataset(all_items)
|
||||
print(f"Split '{self.split}': {len(self.items)} items")
|
||||
logger.info("Split '%s': %d items", self.split, len(self.items))
|
||||
|
||||
def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]:
|
||||
"""
|
||||
@@ -374,7 +375,7 @@ class DBYOLODataset:
|
||||
|
||||
if has_csv_splits:
|
||||
# Use CSV-defined splits
|
||||
print("Using CSV-defined split field for train/val/test assignment")
|
||||
logger.info("Using CSV-defined split field for train/val/test assignment")
|
||||
|
||||
# Map split values: 'train' -> train, 'test' -> test, None -> train (fallback)
|
||||
# 'val' is taken from train set using val_ratio
|
||||
@@ -411,11 +412,11 @@ class DBYOLODataset:
|
||||
# Apply limit if specified
|
||||
if self.limit is not None and self.limit < len(split_doc_ids):
|
||||
split_doc_ids = split_doc_ids[:self.limit]
|
||||
print(f"Limited to {self.limit} documents")
|
||||
logger.info("Limited to %d documents", self.limit)
|
||||
|
||||
else:
|
||||
# Fall back to random splitting (original behavior)
|
||||
print("No CSV split field found, using random splitting")
|
||||
logger.info("No CSV split field found, using random splitting")
|
||||
|
||||
random.seed(self.seed)
|
||||
random.shuffle(doc_ids)
|
||||
@@ -423,7 +424,7 @@ class DBYOLODataset:
|
||||
# Apply limit if specified (before splitting)
|
||||
if self.limit is not None and self.limit < len(doc_ids):
|
||||
doc_ids = doc_ids[:self.limit]
|
||||
print(f"Limited to {self.limit} documents")
|
||||
logger.info("Limited to %d documents", self.limit)
|
||||
|
||||
# Calculate split indices
|
||||
n_total = len(doc_ids)
|
||||
@@ -549,6 +550,8 @@ class DBYOLODataset:
|
||||
"""
|
||||
Convert annotations to normalized YOLO format.
|
||||
|
||||
Uses field-specific bbox expansion strategies via expand_bbox.
|
||||
|
||||
Args:
|
||||
annotations: List of annotations
|
||||
img_width: Actual image width in pixels
|
||||
@@ -568,26 +571,43 @@ class DBYOLODataset:
|
||||
|
||||
labels = []
|
||||
for ann in annotations:
|
||||
# Convert to pixels (if needed)
|
||||
x_center_px = ann.x_center * scale
|
||||
y_center_px = ann.y_center * scale
|
||||
width_px = ann.width * scale
|
||||
height_px = ann.height * scale
|
||||
# Convert center+size to corner coords in PDF points
|
||||
half_w = ann.width / 2
|
||||
half_h = ann.height / 2
|
||||
x0_pdf = ann.x_center - half_w
|
||||
y0_pdf = ann.y_center - half_h
|
||||
x1_pdf = ann.x_center + half_w
|
||||
y1_pdf = ann.y_center + half_h
|
||||
|
||||
# Add padding
|
||||
pad = self.bbox_padding_px
|
||||
width_px += 2 * pad
|
||||
height_px += 2 * pad
|
||||
# Convert to pixels
|
||||
x0_px = x0_pdf * scale
|
||||
y0_px = y0_pdf * scale
|
||||
x1_px = x1_pdf * scale
|
||||
y1_px = y1_pdf * scale
|
||||
|
||||
# Get class name for field-specific expansion
|
||||
class_name = CLASS_NAMES[ann.class_id]
|
||||
|
||||
# Apply field-specific bbox expansion
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0_px, y0_px, x1_px, y1_px),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Ensure minimum height
|
||||
height_px = y1 - y0
|
||||
if height_px < self.min_bbox_height_px:
|
||||
height_px = self.min_bbox_height_px
|
||||
extra = (self.min_bbox_height_px - height_px) / 2
|
||||
y0 = max(0, int(y0 - extra))
|
||||
y1 = min(img_height, int(y1 + extra))
|
||||
|
||||
# Normalize to 0-1
|
||||
x_center = x_center_px / img_width
|
||||
y_center = y_center_px / img_height
|
||||
width = width_px / img_width
|
||||
height = height_px / img_height
|
||||
# Convert to YOLO format (normalized center + size)
|
||||
x_center = (x0 + x1) / 2 / img_width
|
||||
y_center = (y0 + y1) / 2 / img_height
|
||||
width = (x1 - x0) / img_width
|
||||
height = (y1 - y0) / img_height
|
||||
|
||||
# Clamp to valid range
|
||||
x_center = max(0, min(1, x_center))
|
||||
@@ -675,7 +695,7 @@ class DBYOLODataset:
|
||||
|
||||
count += 1
|
||||
|
||||
print(f"Exported {count} items to {output_dir / split_name}")
|
||||
logger.info("Exported %d items to %s", count, output_dir / split_name)
|
||||
return count
|
||||
|
||||
|
||||
@@ -706,7 +726,7 @@ def create_datasets(
|
||||
Dict with 'train', 'val', 'test' datasets
|
||||
"""
|
||||
# Create first dataset which loads all data
|
||||
print("Loading dataset (this may take a few minutes for large datasets)...")
|
||||
logger.info("Loading dataset (this may take a few minutes for large datasets)...")
|
||||
first_dataset = DBYOLODataset(
|
||||
images_dir=images_dir,
|
||||
db=db,
|
||||
|
||||
251
tests/training/yolo/test_db_dataset.py
Normal file
251
tests/training/yolo/test_db_dataset.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Tests for db_dataset.py expand_bbox integration."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from pathlib import Path
|
||||
|
||||
from training.yolo.db_dataset import DBYOLODataset
|
||||
from training.yolo.annotation_generator import YOLOAnnotation
|
||||
from shared.bbox import FIELD_SCALE_STRATEGIES, DEFAULT_STRATEGY
|
||||
from shared.fields import CLASS_NAMES
|
||||
|
||||
|
||||
class TestConvertLabelsWithExpandBbox:
|
||||
"""Tests for _convert_labels using expand_bbox instead of fixed padding."""
|
||||
|
||||
def test_convert_labels_uses_expand_bbox(self):
|
||||
"""Verify _convert_labels calls expand_bbox for field-specific expansion."""
|
||||
# Create a mock dataset without loading from DB
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
# Create annotation for bankgiro (has extra_left_ratio)
|
||||
# bbox in PDF points: x0=100, y0=200, x1=200, y1=250
|
||||
# center: (150, 225), width: 100, height: 50
|
||||
annotations = [
|
||||
YOLOAnnotation(
|
||||
class_id=4, # bankgiro
|
||||
x_center=150, # in PDF points
|
||||
y_center=225,
|
||||
width=100,
|
||||
height=50,
|
||||
confidence=0.9
|
||||
)
|
||||
]
|
||||
|
||||
# Image size in pixels (at 300 DPI)
|
||||
img_width = 2480 # A4 width at 300 DPI
|
||||
img_height = 3508 # A4 height at 300 DPI
|
||||
|
||||
# Convert labels
|
||||
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Should have one label
|
||||
assert labels.shape == (1, 5)
|
||||
|
||||
# Check class_id
|
||||
assert labels[0, 0] == 4
|
||||
|
||||
# The bbox should be expanded using bankgiro strategy (extra_left_ratio=0.80)
|
||||
# Original bbox at 300 DPI:
|
||||
# x0 = 100 * (300/72) = 416.67
|
||||
# y0 = 200 * (300/72) = 833.33
|
||||
# x1 = 200 * (300/72) = 833.33
|
||||
# y1 = 250 * (300/72) = 1041.67
|
||||
# width_px = 416.67, height_px = 208.33
|
||||
|
||||
# After expand_bbox with bankgiro strategy:
|
||||
# scale_x=1.45, scale_y=1.35, extra_left_ratio=0.80
|
||||
# The x_center should shift left due to extra_left_ratio
|
||||
x_center = labels[0, 1]
|
||||
y_center = labels[0, 2]
|
||||
width = labels[0, 3]
|
||||
height = labels[0, 4]
|
||||
|
||||
# Verify normalized values are in valid range
|
||||
assert 0 <= x_center <= 1
|
||||
assert 0 <= y_center <= 1
|
||||
assert 0 < width <= 1
|
||||
assert 0 < height <= 1
|
||||
|
||||
# Width should be larger than original due to scaling and extra_left
|
||||
# Original normalized width: 416.67 / 2480 = 0.168
|
||||
# After bankgiro expansion it should be wider
|
||||
assert width > 0.168
|
||||
|
||||
def test_convert_labels_different_field_types(self):
|
||||
"""Verify different field types use their specific strategies."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Same bbox for different field types
|
||||
base_annotation = {
|
||||
'x_center': 150,
|
||||
'y_center': 225,
|
||||
'width': 100,
|
||||
'height': 50,
|
||||
'confidence': 0.9
|
||||
}
|
||||
|
||||
# OCR number (class_id=3) - has extra_top_ratio=0.60
|
||||
ocr_annotations = [YOLOAnnotation(class_id=3, **base_annotation)]
|
||||
ocr_labels = dataset._convert_labels(ocr_annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Bankgiro (class_id=4) - has extra_left_ratio=0.80
|
||||
bankgiro_annotations = [YOLOAnnotation(class_id=4, **base_annotation)]
|
||||
bankgiro_labels = dataset._convert_labels(bankgiro_annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Amount (class_id=6) - has extra_right_ratio=0.30
|
||||
amount_annotations = [YOLOAnnotation(class_id=6, **base_annotation)]
|
||||
amount_labels = dataset._convert_labels(amount_annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Each field type should have different expansion
|
||||
# OCR should expand more vertically (extra_top)
|
||||
# Bankgiro should expand more to the left
|
||||
# Amount should expand more to the right
|
||||
|
||||
# OCR: extra_top shifts y_center up
|
||||
# Bankgiro: extra_left shifts x_center left
|
||||
# So bankgiro x_center < OCR x_center
|
||||
assert bankgiro_labels[0, 1] < ocr_labels[0, 1]
|
||||
|
||||
# OCR has higher scale_y (1.80) than amount (1.35)
|
||||
assert ocr_labels[0, 4] > amount_labels[0, 4]
|
||||
|
||||
def test_convert_labels_clamps_to_image_bounds(self):
|
||||
"""Verify labels are clamped to image boundaries."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
# Annotation near edge of image (in PDF points)
|
||||
annotations = [
|
||||
YOLOAnnotation(
|
||||
class_id=4, # bankgiro - will expand left
|
||||
x_center=30, # Very close to left edge
|
||||
y_center=50,
|
||||
width=40,
|
||||
height=30,
|
||||
confidence=0.9
|
||||
)
|
||||
]
|
||||
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# All values should be in valid range
|
||||
assert 0 <= labels[0, 1] <= 1 # x_center
|
||||
assert 0 <= labels[0, 2] <= 1 # y_center
|
||||
assert 0 < labels[0, 3] <= 1 # width
|
||||
assert 0 < labels[0, 4] <= 1 # height
|
||||
|
||||
def test_convert_labels_empty_annotations(self):
|
||||
"""Verify empty annotations return empty array."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
labels = dataset._convert_labels([], 2480, 3508, is_scanned=False)
|
||||
|
||||
assert labels.shape == (0, 5)
|
||||
assert labels.dtype == np.float32
|
||||
|
||||
def test_convert_labels_minimum_height(self):
|
||||
"""Verify minimum height is enforced after expansion."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 50 # Higher minimum
|
||||
|
||||
# Very small annotation
|
||||
annotations = [
|
||||
YOLOAnnotation(
|
||||
class_id=9, # payment_line - minimal expansion
|
||||
x_center=100,
|
||||
y_center=100,
|
||||
width=200,
|
||||
height=5, # Very small height
|
||||
confidence=0.9
|
||||
)
|
||||
]
|
||||
|
||||
labels = dataset._convert_labels(annotations, 2480, 3508, is_scanned=False)
|
||||
|
||||
# Height should be at least min_bbox_height_px / img_height
|
||||
min_normalized_height = 50 / 3508
|
||||
assert labels[0, 4] >= min_normalized_height
|
||||
|
||||
|
||||
class TestCreateAnnotationWithClassName:
|
||||
"""Tests for _create_annotation storing class_name for expand_bbox lookup."""
|
||||
|
||||
def test_create_annotation_stores_class_name(self):
|
||||
"""Verify _create_annotation stores class_name for later use."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
|
||||
# Create annotation for invoice_number
|
||||
annotation = dataset._create_annotation(
|
||||
field_name="InvoiceNumber",
|
||||
bbox=[100, 200, 200, 250],
|
||||
score=0.9
|
||||
)
|
||||
|
||||
assert annotation.class_id == 0 # invoice_number class_id
|
||||
|
||||
|
||||
class TestLoadLabelsFromDbWithClassName:
|
||||
"""Tests for _load_labels_from_db preserving field_name for expansion."""
|
||||
|
||||
def test_load_labels_maps_field_names_correctly(self):
|
||||
"""Verify field names are mapped correctly for expand_bbox."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.min_confidence = 0.7
|
||||
|
||||
# Mock database
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_documents_batch.return_value = {
|
||||
'doc1': {
|
||||
'success': True,
|
||||
'pdf_type': 'text',
|
||||
'split': 'train',
|
||||
'field_results': [
|
||||
{
|
||||
'matched': True,
|
||||
'field_name': 'Bankgiro',
|
||||
'score': 0.9,
|
||||
'bbox': [100, 200, 200, 250],
|
||||
'page_no': 0
|
||||
},
|
||||
{
|
||||
'matched': True,
|
||||
'field_name': 'supplier_accounts(Plusgiro)',
|
||||
'score': 0.85,
|
||||
'bbox': [300, 400, 400, 450],
|
||||
'page_no': 0
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
dataset.db = mock_db
|
||||
|
||||
result = dataset._load_labels_from_db(['doc1'])
|
||||
|
||||
assert 'doc1' in result
|
||||
page_labels, is_scanned, csv_split = result['doc1']
|
||||
|
||||
# Should have 2 annotations on page 0
|
||||
assert 0 in page_labels
|
||||
assert len(page_labels[0]) == 2
|
||||
|
||||
# First annotation: Bankgiro (class_id=4)
|
||||
assert page_labels[0][0].class_id == 4
|
||||
|
||||
# Second annotation: Plusgiro mapped from supplier_accounts(Plusgiro) (class_id=5)
|
||||
assert page_labels[0][1].class_id == 5
|
||||
367
tests/web/test_training_export.py
Normal file
367
tests/web/test_training_export.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Tests for Training Export with expand_bbox integration.
|
||||
|
||||
Tests the export endpoint's integration with field-specific bbox expansion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.bbox import expand_bbox
|
||||
from shared.fields import CLASS_NAMES, FIELD_CLASS_IDS
|
||||
|
||||
|
||||
class TestExpandBboxForExport:
|
||||
"""Tests for expand_bbox integration in export workflow."""
|
||||
|
||||
def test_expand_bbox_converts_normalized_to_pixel_and_back(self):
|
||||
"""Verify expand_bbox works with pixel-to-normalized conversion."""
|
||||
# Annotation stored as normalized coords
|
||||
x_center_norm = 0.5
|
||||
y_center_norm = 0.5
|
||||
width_norm = 0.1
|
||||
height_norm = 0.05
|
||||
|
||||
# Image dimensions
|
||||
img_width = 2480 # A4 at 300 DPI
|
||||
img_height = 3508
|
||||
|
||||
# Convert to pixel coords
|
||||
x_center_px = x_center_norm * img_width
|
||||
y_center_px = y_center_norm * img_height
|
||||
width_px = width_norm * img_width
|
||||
height_px = height_norm * img_height
|
||||
|
||||
# Convert to corner coords
|
||||
x0 = x_center_px - width_px / 2
|
||||
y0 = y_center_px - height_px / 2
|
||||
x1 = x_center_px + width_px / 2
|
||||
y1 = y_center_px + height_px / 2
|
||||
|
||||
# Apply expansion
|
||||
class_name = "invoice_number"
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Verify expanded bbox is larger
|
||||
assert ex0 < x0 # Left expanded
|
||||
assert ey0 < y0 # Top expanded
|
||||
assert ex1 > x1 # Right expanded
|
||||
assert ey1 > y1 # Bottom expanded
|
||||
|
||||
# Convert back to normalized
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Verify valid normalized coords
|
||||
assert 0 <= new_x_center <= 1
|
||||
assert 0 <= new_y_center <= 1
|
||||
assert 0 <= new_width <= 1
|
||||
assert 0 <= new_height <= 1
|
||||
|
||||
def test_expand_bbox_manual_mode_minimal_expansion(self):
|
||||
"""Verify manual annotations use minimal expansion."""
|
||||
# Small bbox
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Auto mode (field-specific expansion)
|
||||
auto_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
manual_mode=False,
|
||||
)
|
||||
|
||||
# Manual mode (minimal expansion)
|
||||
manual_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# Auto expansion should be larger than manual
|
||||
auto_width = auto_result[2] - auto_result[0]
|
||||
manual_width = manual_result[2] - manual_result[0]
|
||||
assert auto_width > manual_width
|
||||
|
||||
auto_height = auto_result[3] - auto_result[1]
|
||||
manual_height = manual_result[3] - manual_result[1]
|
||||
assert auto_height > manual_height
|
||||
|
||||
def test_expand_bbox_different_sources_use_correct_mode(self):
|
||||
"""Verify different annotation sources use correct expansion mode."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Define source to manual_mode mapping
|
||||
source_mode_mapping = {
|
||||
"manual": True, # Manual annotations -> minimal expansion
|
||||
"auto": False, # Auto-labeled -> field-specific expansion
|
||||
"imported": True, # Imported (from CSV) -> minimal expansion
|
||||
}
|
||||
|
||||
results = {}
|
||||
for source, manual_mode in source_mode_mapping.items():
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="ocr_number",
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
results[source] = result
|
||||
|
||||
# Auto should have largest expansion
|
||||
auto_area = (results["auto"][2] - results["auto"][0]) * \
|
||||
(results["auto"][3] - results["auto"][1])
|
||||
manual_area = (results["manual"][2] - results["manual"][0]) * \
|
||||
(results["manual"][3] - results["manual"][1])
|
||||
imported_area = (results["imported"][2] - results["imported"][0]) * \
|
||||
(results["imported"][3] - results["imported"][1])
|
||||
|
||||
assert auto_area > manual_area
|
||||
assert auto_area > imported_area
|
||||
# Manual and imported should be the same (both use minimal mode)
|
||||
assert manual_area == imported_area
|
||||
|
||||
def test_expand_bbox_all_field_types_work(self):
|
||||
"""Verify expand_bbox works for all field types."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
for class_name in CLASS_NAMES:
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Verify result is a valid bbox
|
||||
assert len(result) == 4
|
||||
x0, y0, x1, y1 = result
|
||||
assert x0 >= 0
|
||||
assert y0 >= 0
|
||||
assert x1 <= img_width
|
||||
assert y1 <= img_height
|
||||
assert x1 > x0
|
||||
assert y1 > y0
|
||||
|
||||
|
||||
class TestExportAnnotationExpansion:
|
||||
"""Tests for annotation expansion in export workflow."""
|
||||
|
||||
def test_annotation_bbox_conversion_workflow(self):
|
||||
"""Test full annotation bbox conversion workflow."""
|
||||
# Simulate stored annotation (normalized coords)
|
||||
class MockAnnotation:
|
||||
class_id = FIELD_CLASS_IDS["invoice_number"]
|
||||
class_name = "invoice_number"
|
||||
x_center = 0.3
|
||||
y_center = 0.2
|
||||
width = 0.15
|
||||
height = 0.03
|
||||
source = "auto"
|
||||
|
||||
ann = MockAnnotation()
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Step 1: Convert normalized to pixel corner coords
|
||||
half_w = (ann.width * img_width) / 2
|
||||
half_h = (ann.height * img_height) / 2
|
||||
x0 = ann.x_center * img_width - half_w
|
||||
y0 = ann.y_center * img_height - half_h
|
||||
x1 = ann.x_center * img_width + half_w
|
||||
y1 = ann.y_center * img_height + half_h
|
||||
|
||||
# Step 2: Determine manual_mode based on source
|
||||
manual_mode = ann.source in ("manual", "imported")
|
||||
|
||||
# Step 3: Apply expand_bbox
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann.class_name,
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Step 4: Convert back to normalized
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Verify expansion happened (auto mode)
|
||||
assert new_width > ann.width
|
||||
assert new_height > ann.height
|
||||
|
||||
# Verify valid YOLO format
|
||||
assert 0 <= new_x_center <= 1
|
||||
assert 0 <= new_y_center <= 1
|
||||
assert 0 < new_width <= 1
|
||||
assert 0 < new_height <= 1
|
||||
|
||||
def test_export_applies_expansion_to_each_annotation(self):
|
||||
"""Test that export applies expansion to each annotation."""
|
||||
# Simulate multiple annotations with different sources
|
||||
# Use smaller bboxes so manual mode padding has visible effect
|
||||
annotations = [
|
||||
{"class_name": "invoice_number", "source": "auto", "x_center": 0.3, "y_center": 0.2, "width": 0.05, "height": 0.02},
|
||||
{"class_name": "ocr_number", "source": "manual", "x_center": 0.5, "y_center": 0.8, "width": 0.05, "height": 0.02},
|
||||
{"class_name": "amount", "source": "imported", "x_center": 0.7, "y_center": 0.5, "width": 0.05, "height": 0.02},
|
||||
]
|
||||
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
expanded_annotations = []
|
||||
for ann in annotations:
|
||||
# Convert to pixel coords
|
||||
half_w = (ann["width"] * img_width) / 2
|
||||
half_h = (ann["height"] * img_height) / 2
|
||||
x0 = ann["x_center"] * img_width - half_w
|
||||
y0 = ann["y_center"] * img_height - half_h
|
||||
x1 = ann["x_center"] * img_width + half_w
|
||||
y1 = ann["y_center"] * img_height + half_h
|
||||
|
||||
# Determine manual_mode
|
||||
manual_mode = ann["source"] in ("manual", "imported")
|
||||
|
||||
# Apply expansion
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann["class_name"],
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Convert back to normalized
|
||||
expanded_annotations.append({
|
||||
"class_name": ann["class_name"],
|
||||
"source": ann["source"],
|
||||
"x_center": (ex0 + ex1) / 2 / img_width,
|
||||
"y_center": (ey0 + ey1) / 2 / img_height,
|
||||
"width": (ex1 - ex0) / img_width,
|
||||
"height": (ey1 - ey0) / img_height,
|
||||
})
|
||||
|
||||
# Verify auto-labeled annotation expanded more than manual/imported
|
||||
auto_ann = next(a for a in expanded_annotations if a["source"] == "auto")
|
||||
manual_ann = next(a for a in expanded_annotations if a["source"] == "manual")
|
||||
|
||||
# Auto mode should expand more than manual mode
|
||||
# (auto has larger scale factors and max_pad)
|
||||
assert auto_ann["width"] > manual_ann["width"]
|
||||
assert auto_ann["height"] > manual_ann["height"]
|
||||
|
||||
# All annotations should be expanded (at least slightly for manual mode)
|
||||
# Allow small precision loss (< 1%) due to integer conversion in expand_bbox
|
||||
for i, (orig, exp) in enumerate(zip(annotations, expanded_annotations)):
|
||||
# Width and height should be >= original (expansion or equal, with small tolerance)
|
||||
tolerance = 0.01 # 1% tolerance for integer rounding
|
||||
assert exp["width"] >= orig["width"] * (1 - tolerance), \
|
||||
f"Annotation {i} width unexpectedly smaller: {exp['width']} < {orig['width']}"
|
||||
assert exp["height"] >= orig["height"] * (1 - tolerance), \
|
||||
f"Annotation {i} height unexpectedly smaller: {exp['height']} < {orig['height']}"
|
||||
|
||||
|
||||
class TestExpandBboxEdgeCases:
|
||||
"""Tests for edge cases in export bbox expansion."""
|
||||
|
||||
def test_bbox_at_image_edge_left(self):
|
||||
"""Test bbox at left edge of image."""
|
||||
bbox = (0, 100, 50, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Left edge should be clamped to 0
|
||||
assert result[0] >= 0
|
||||
|
||||
def test_bbox_at_image_edge_right(self):
|
||||
"""Test bbox at right edge of image."""
|
||||
bbox = (2400, 100, 2480, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Right edge should be clamped to image width
|
||||
assert result[2] <= img_width
|
||||
|
||||
def test_bbox_at_image_edge_top(self):
|
||||
"""Test bbox at top edge of image."""
|
||||
bbox = (100, 0, 200, 50)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Top edge should be clamped to 0
|
||||
assert result[1] >= 0
|
||||
|
||||
def test_bbox_at_image_edge_bottom(self):
|
||||
"""Test bbox at bottom edge of image."""
|
||||
bbox = (100, 3400, 200, 3508)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Bottom edge should be clamped to image height
|
||||
assert result[3] <= img_height
|
||||
|
||||
def test_very_small_bbox(self):
|
||||
"""Test very small bbox gets expanded."""
|
||||
bbox = (100, 100, 105, 105) # 5x5 pixel bbox
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Should still produce a valid expanded bbox
|
||||
assert result[2] > result[0]
|
||||
assert result[3] > result[1]
|
||||
Reference in New Issue
Block a user