"""
Auto-Label Report Generator

Generates quality control reports for the auto-labeling process.
"""
|
|
|
|
import json
|
|
from dataclasses import dataclass, field, asdict
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
@dataclass
|
|
class FieldMatchResult:
|
|
"""Result of matching a single field."""
|
|
field_name: str
|
|
csv_value: str | None
|
|
matched: bool
|
|
score: float = 0.0
|
|
matched_text: str | None = None
|
|
candidate_used: str | None = None # Which normalized variant matched
|
|
bbox: tuple[float, float, float, float] | None = None
|
|
page_no: int = 0
|
|
context_keywords: list[str] = field(default_factory=list)
|
|
error: str | None = None
|
|
|
|
def to_dict(self) -> dict:
|
|
"""Convert to dictionary."""
|
|
# Convert bbox to native Python floats to avoid numpy serialization issues
|
|
bbox_list = None
|
|
if self.bbox:
|
|
bbox_list = [float(x) for x in self.bbox]
|
|
|
|
return {
|
|
'field_name': self.field_name,
|
|
'csv_value': self.csv_value,
|
|
'matched': self.matched,
|
|
'score': float(self.score) if self.score else 0.0,
|
|
'matched_text': self.matched_text,
|
|
'candidate_used': self.candidate_used,
|
|
'bbox': bbox_list,
|
|
'page_no': int(self.page_no) if self.page_no else 0,
|
|
'context_keywords': self.context_keywords,
|
|
'error': self.error
|
|
}
|
|
|
|
|
|
@dataclass
class AutoLabelReport:
    """Report for a single document's auto-labeling process.

    Collects per-field match results together with counters, output paths
    and errors, and can render itself as a dict, a JSON string or a short
    summary.
    """

    document_id: str
    pdf_path: str | None = None
    pdf_type: str | None = None  # 'text' | 'scanned' | 'mixed'
    success: bool = False
    total_pages: int = 0
    fields_matched: int = 0
    fields_total: int = 0
    field_results: list[FieldMatchResult] = field(default_factory=list)
    annotations_generated: int = 0
    image_paths: list[str] = field(default_factory=list)
    label_paths: list[str] = field(default_factory=list)
    processing_time_ms: float = 0.0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    errors: list[str] = field(default_factory=list)

    def add_field_result(self, result: FieldMatchResult) -> None:
        """Record one field result and update the matched/total counters."""
        self.field_results.append(result)
        self.fields_total += 1
        if not result.matched:
            return
        self.fields_matched += 1

    def to_dict(self) -> dict:
        """Return the report as a dictionary.

        Field results are expanded through their own ``to_dict``; every other
        attribute is emitted as stored, in declaration order.
        """
        serialized_fields = [entry.to_dict() for entry in self.field_results]
        return dict(
            document_id=self.document_id,
            pdf_path=self.pdf_path,
            pdf_type=self.pdf_type,
            success=self.success,
            total_pages=self.total_pages,
            fields_matched=self.fields_matched,
            fields_total=self.fields_total,
            field_results=serialized_fields,
            annotations_generated=self.annotations_generated,
            image_paths=self.image_paths,
            label_paths=self.label_paths,
            processing_time_ms=self.processing_time_ms,
            timestamp=self.timestamp,
            errors=self.errors,
        )

    def to_json(self, indent: int | None = None) -> str:
        """Return the report as a JSON string (non-ASCII characters kept as-is).

        Args:
            indent: Forwarded to ``json.dumps``; ``None`` gives a single line.
        """
        payload = self.to_dict()
        return json.dumps(payload, ensure_ascii=False, indent=indent)

    @property
    def match_rate(self) -> float:
        """Fraction of recorded fields that matched; 0.0 when none recorded."""
        total = self.fields_total
        return 0.0 if total == 0 else self.fields_matched / total

    def get_summary(self) -> dict:
        """Return a compact overview: id, success flag, rates and counts."""
        rate_text = f"{self.match_rate:.1%}"
        fields_text = f"{self.fields_matched}/{self.fields_total}"
        error_count = len(self.errors)
        return {
            'document_id': self.document_id,
            'success': self.success,
            'match_rate': rate_text,
            'fields': fields_text,
            'annotations': self.annotations_generated,
            'errors': error_count,
        }
|
|
|
|
|
|
class ReportWriter:
    """Writes auto-label reports to file with optional sharding."""

    def __init__(self, output_path: str | Path, max_records_per_file: int = 0):
        """
        Set up the writer and make sure the output directory exists.

        Args:
            output_path: Path to output JSONL file (base name if sharding)
            max_records_per_file: Max records per file (0 = no limit, single file)
        """
        self.output_path = Path(output_path)
        self.output_path.parent.mkdir(parents=True, exist_ok=True)
        self.max_records_per_file = max_records_per_file

        # Sharding bookkeeping: current shard index, records written to it,
        # and every distinct path handed out so far.
        self._shard_files: list[Path] = []
        self._records_in_current_shard = 0
        self._current_shard = 0

    def _get_shard_path(self) -> Path:
        """Return the file to write to now, registering it on first use."""
        sharding_enabled = self.max_records_per_file > 0
        if sharding_enabled:
            stem = self.output_path.stem
            extension = self.output_path.suffix
            shard_name = f"{stem}_part{self._current_shard:03d}{extension}"
            target = self.output_path.parent / shard_name
        else:
            target = self.output_path

        if target not in self._shard_files:
            self._shard_files.append(target)
        return target

    def _check_shard_rotation(self) -> None:
        """Advance to the next shard once the current one has reached the limit."""
        limit = self.max_records_per_file
        if limit > 0 and self._records_in_current_shard >= limit:
            self._current_shard += 1
            self._records_in_current_shard = 0

    def write(self, report: AutoLabelReport) -> None:
        """Append one report, as a JSON line, to the current output file."""
        self._check_shard_rotation()
        destination = self._get_shard_path()
        with destination.open('a', encoding='utf-8') as out:
            out.write(report.to_json() + '\n')
        self._records_in_current_shard += 1

    def write_dict(self, report_dict: dict) -> None:
        """Append a report dict to the output file (for parallel processing)."""
        self._check_shard_rotation()
        destination = self._get_shard_path()
        with destination.open('a', encoding='utf-8') as out:
            out.write(json.dumps(report_dict, ensure_ascii=False) + '\n')
            out.flush()
        self._records_in_current_shard += 1

    def write_batch(self, reports: list[AutoLabelReport]) -> None:
        """Write each report in order via ``write``."""
        for single_report in reports:
            self.write(single_report)

    def get_shard_files(self) -> list[Path]:
        """Return a copy of the list of files this writer has targeted."""
        return list(self._shard_files)
|
|
|
|
|
|
class ReportReader:
    """Reads auto-label reports from file(s)."""

    def __init__(self, input_path: str | Path):
        """
        Resolve the file(s) to read.

        Args:
            input_path: Path to input JSONL file or glob pattern (e.g., 'reports/*.jsonl')
        """
        self.input_path = Path(input_path)

        # A '*' or '?' anywhere in the given path switches to glob mode; the
        # final path component is then used as the pattern inside its parent.
        raw_path = str(input_path)
        if any(wildcard in raw_path for wildcard in ('*', '?')):
            directory = self.input_path.parent
            self.input_paths = sorted(directory.glob(self.input_path.name))
        else:
            self.input_paths = [self.input_path]

    def read_all(self) -> list[AutoLabelReport]:
        """Load every report from the resolved files.

        Files that do not exist are skipped, as are blank lines; each
        remaining line is parsed as one JSON document.
        """
        collected = []
        for source in self.input_paths:
            if not source.exists():
                continue
            with source.open('r', encoding='utf-8') as handle:
                stripped_lines = (raw_line.strip() for raw_line in handle)
                collected.extend(
                    self._dict_to_report(json.loads(text))
                    for text in stripped_lines
                    if text
                )
        return collected

    def _dict_to_report(self, data: dict) -> AutoLabelReport:
        """Rebuild an AutoLabelReport (including field results) from a dict."""

        def build_field_result(entry):
            # bbox is stored as a list in JSON; restore it to a tuple.
            bbox_value = tuple(entry['bbox']) if entry.get('bbox') else None
            return FieldMatchResult(
                field_name=entry['field_name'],
                csv_value=entry.get('csv_value'),
                matched=entry.get('matched', False),
                score=entry.get('score', 0.0),
                matched_text=entry.get('matched_text'),
                candidate_used=entry.get('candidate_used'),
                bbox=bbox_value,
                page_no=entry.get('page_no', 0),
                context_keywords=entry.get('context_keywords', []),
                error=entry.get('error'),
            )

        rebuilt_fields = []
        for entry in data.get('field_results', []):
            rebuilt_fields.append(build_field_result(entry))

        return AutoLabelReport(
            document_id=data['document_id'],
            pdf_path=data.get('pdf_path'),
            pdf_type=data.get('pdf_type'),
            success=data.get('success', False),
            total_pages=data.get('total_pages', 0),
            fields_matched=data.get('fields_matched', 0),
            fields_total=data.get('fields_total', 0),
            field_results=rebuilt_fields,
            annotations_generated=data.get('annotations_generated', 0),
            image_paths=data.get('image_paths', []),
            label_paths=data.get('label_paths', []),
            processing_time_ms=data.get('processing_time_ms', 0.0),
            timestamp=data.get('timestamp', ''),
            errors=data.get('errors', []),
        )

    def get_statistics(self) -> dict:
        """Aggregate counts and match rates over all reports.

        Returns ``{'total': 0}`` when there are no reports; otherwise overall
        totals plus a per-field breakdown under ``'field_statistics'``.
        """
        reports = self.read_all()
        if not reports:
            return {'total': 0}

        report_count = len(reports)
        successful = sum(1 for rep in reports if rep.success)
        total_fields_matched = sum(rep.fields_matched for rep in reports)
        total_fields = sum(rep.fields_total for rep in reports)
        total_annotations = sum(rep.annotations_generated for rep in reports)

        # Per-field statistics; 'avg_score' holds a running sum of matched
        # scores until it is divided by the matched count below.
        field_stats = {}
        for rep in reports:
            for result in rep.field_results:
                bucket = field_stats.setdefault(
                    result.field_name,
                    {'matched': 0, 'total': 0, 'avg_score': 0.0},
                )
                bucket['total'] += 1
                if result.matched:
                    bucket['matched'] += 1
                    bucket['avg_score'] += result.score

        # Calculate averages
        for bucket in field_stats.values():
            if bucket['matched'] > 0:
                bucket['avg_score'] /= bucket['matched']
            bucket['match_rate'] = bucket['matched'] / bucket['total'] if bucket['total'] > 0 else 0

        overall_match_rate = total_fields_matched / total_fields if total_fields > 0 else 0
        return {
            'total': report_count,
            'successful': successful,
            'success_rate': successful / report_count,
            'total_fields_matched': total_fields_matched,
            'total_fields': total_fields,
            'overall_match_rate': overall_match_rate,
            'total_annotations': total_annotations,
            'field_statistics': field_stats,
        }
|