WOP
This commit is contained in:
@@ -114,65 +114,114 @@ class AutoLabelReport:
|
||||
|
||||
|
||||
class ReportWriter:
    """Writes auto-label reports to file with optional sharding.

    Records are appended as JSON Lines. When ``max_records_per_file`` is
    greater than zero, output goes to numbered shard files named
    ``<stem>_partNNN<suffix>`` next to ``output_path``; otherwise everything
    is appended to ``output_path`` itself.

    Note: files are opened in append mode and the shard counters start at
    zero for every new writer, so the per-file limit only counts records
    written through this instance.
    """

    def __init__(
        self,
        output_path: str | Path,
        max_records_per_file: int = 0
    ):
        """
        Initialize report writer.

        Creates the parent directory of ``output_path`` if it does not exist.

        Args:
            output_path: Path to output JSONL file (base name if sharding)
            max_records_per_file: Max records per file (0 = no limit, single file).
                Values below zero are treated the same as 0.
        """
        self.output_path = Path(output_path)
        self.output_path.parent.mkdir(parents=True, exist_ok=True)
        self.max_records_per_file = max_records_per_file

        # Sharding state
        self._current_shard = 0
        self._records_in_current_shard = 0
        self._shard_files: list[Path] = []

    def _get_shard_path(self) -> Path:
        """Get the path for current shard and remember it in the shard list."""
        if self.max_records_per_file > 0:
            base = self.output_path.stem
            suffix = self.output_path.suffix
            shard_path = self.output_path.parent / f"{base}_part{self._current_shard:03d}{suffix}"
        else:
            shard_path = self.output_path

        # Record each distinct path once, in the order it was first used.
        if shard_path not in self._shard_files:
            self._shard_files.append(shard_path)

        return shard_path

    def _check_shard_rotation(self) -> None:
        """Check if we need to rotate to a new shard file."""
        if self.max_records_per_file > 0:
            if self._records_in_current_shard >= self.max_records_per_file:
                self._current_shard += 1
                self._records_in_current_shard = 0

    def _append_line(self, line: str) -> None:
        """Append one serialized record to the current shard and count it.

        Rotation is checked before the write, so a full shard is only
        replaced once another record actually arrives.
        """
        self._check_shard_rotation()
        shard_path = self._get_shard_path()
        # Closing the file at the end of the ``with`` block flushes the record.
        with open(shard_path, 'a', encoding='utf-8') as f:
            f.write(line + '\n')
        self._records_in_current_shard += 1

    def write(self, report: AutoLabelReport) -> None:
        """Append a report to the output file."""
        self._append_line(report.to_json())

    def write_dict(self, report_dict: dict) -> None:
        """Append a report dict to the output file (for parallel processing)."""
        import json

        self._append_line(json.dumps(report_dict, ensure_ascii=False))

    def write_batch(self, reports: list[AutoLabelReport]) -> None:
        """Write multiple reports, rotating shards between records as needed."""
        for report in reports:
            self.write(report)

    def get_shard_files(self) -> list[Path]:
        """Get list of all shard files created."""
        return self._shard_files.copy()
|
||||
|
||||
|
||||
class ReportReader:
    """Reads auto-label reports from file(s)."""

    def __init__(self, input_path: str | Path):
        """
        Initialize report reader.

        Args:
            input_path: Path to input JSONL file or glob pattern (e.g., 'reports/*.jsonl').
                A path containing '*' or '?' is treated as a pattern; wildcards
                may appear in directory components as well as in the file name.
        """
        self.input_path = Path(input_path)

        # Handle glob pattern
        if '*' in str(input_path) or '?' in str(input_path):
            parts = self.input_path.parts
            # Never treat the drive/root as a pattern component.
            first_candidate = 1 if self.input_path.anchor else 0
            # Split at the first wildcard component so that the directory we
            # glob from is a literal path; fall back to parent/name.
            split_at = next(
                (
                    i for i in range(first_candidate, len(parts))
                    if '*' in parts[i] or '?' in parts[i]
                ),
                len(parts) - 1,
            )
            base_dir = Path(*parts[:split_at])
            pattern = str(Path(*parts[split_at:]))
            self.input_paths = sorted(base_dir.glob(pattern))
        else:
            self.input_paths = [self.input_path]

    def read_all(self) -> list[AutoLabelReport]:
        """Read all reports from file(s).

        Files are read in the order of ``self.input_paths``; paths that do
        not exist and blank lines are skipped. Each remaining line is parsed
        as JSON and converted with ``self._dict_to_report`` (defined on this
        class outside the part of the file shown here).

        Returns:
            All reports from all input files, in file order then line order.
        """
        reports = []

        for input_path in self.input_paths:
            if not input_path.exists():
                continue

            with open(input_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue

                    data = json.loads(line)
                    report = self._dict_to_report(data)
                    reports.append(report)

        return reports
|
||||
|
||||
|
||||
Reference in New Issue
Block a user