This commit is contained in:
Yaojia Wang
2026-01-13 00:10:27 +01:00
parent 1b7c61cdd8
commit b26fd61852
43 changed files with 7751 additions and 578 deletions

View File

@@ -114,65 +114,114 @@ class AutoLabelReport:
class ReportWriter:
"""Writes auto-label reports to file."""
"""Writes auto-label reports to file with optional sharding."""
def __init__(self, output_path: str | Path):
def __init__(
self,
output_path: str | Path,
max_records_per_file: int = 0
):
"""
Initialize report writer.
Args:
output_path: Path to output JSONL file
output_path: Path to output JSONL file (base name if sharding)
max_records_per_file: Max records per file (0 = no limit, single file)
"""
self.output_path = Path(output_path)
self.output_path.parent.mkdir(parents=True, exist_ok=True)
self.max_records_per_file = max_records_per_file
# Sharding state
self._current_shard = 0
self._records_in_current_shard = 0
self._shard_files: list[Path] = []
def _get_shard_path(self) -> Path:
"""Get the path for current shard."""
if self.max_records_per_file > 0:
base = self.output_path.stem
suffix = self.output_path.suffix
shard_path = self.output_path.parent / f"{base}_part{self._current_shard:03d}{suffix}"
else:
shard_path = self.output_path
if shard_path not in self._shard_files:
self._shard_files.append(shard_path)
return shard_path
def _check_shard_rotation(self) -> None:
"""Check if we need to rotate to a new shard file."""
if self.max_records_per_file > 0:
if self._records_in_current_shard >= self.max_records_per_file:
self._current_shard += 1
self._records_in_current_shard = 0
def write(self, report: AutoLabelReport) -> None:
"""Append a report to the output file."""
with open(self.output_path, 'a', encoding='utf-8') as f:
self._check_shard_rotation()
shard_path = self._get_shard_path()
with open(shard_path, 'a', encoding='utf-8') as f:
f.write(report.to_json() + '\n')
self._records_in_current_shard += 1
def write_dict(self, report_dict: dict) -> None:
"""Append a report dict to the output file (for parallel processing)."""
import json
with open(self.output_path, 'a', encoding='utf-8') as f:
self._check_shard_rotation()
shard_path = self._get_shard_path()
with open(shard_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(report_dict, ensure_ascii=False) + '\n')
f.flush()
self._records_in_current_shard += 1
def write_batch(self, reports: list[AutoLabelReport]) -> None:
"""Write multiple reports."""
with open(self.output_path, 'a', encoding='utf-8') as f:
for report in reports:
f.write(report.to_json() + '\n')
for report in reports:
self.write(report)
def get_shard_files(self) -> list[Path]:
"""Get list of all shard files created."""
return self._shard_files.copy()
class ReportReader:
"""Reads auto-label reports from file."""
"""Reads auto-label reports from file(s)."""
def __init__(self, input_path: str | Path):
"""
Initialize report reader.
Args:
input_path: Path to input JSONL file
input_path: Path to input JSONL file or glob pattern (e.g., 'reports/*.jsonl')
"""
self.input_path = Path(input_path)
# Handle glob pattern
if '*' in str(input_path) or '?' in str(input_path):
parent = self.input_path.parent
pattern = self.input_path.name
self.input_paths = sorted(parent.glob(pattern))
else:
self.input_paths = [self.input_path]
def read_all(self) -> list[AutoLabelReport]:
"""Read all reports from file."""
"""Read all reports from file(s)."""
reports = []
if not self.input_path.exists():
return reports
for input_path in self.input_paths:
if not input_path.exists():
continue
with open(self.input_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
continue
with open(input_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
continue
data = json.loads(line)
report = self._dict_to_report(data)
reports.append(report)
data = json.loads(line)
report = self._dict_to_report(data)
reports.append(report)
return reports