This commit is contained in:
Yaojia Wang
2026-01-13 00:10:27 +01:00
parent 1b7c61cdd8
commit b26fd61852
43 changed files with 7751 additions and 578 deletions

View File

@@ -72,7 +72,7 @@ class CSVLoader:
def __init__(
self,
csv_path: str | Path,
csv_path: str | Path | list[str | Path],
pdf_dir: str | Path | None = None,
doc_map_path: str | Path | None = None,
encoding: str = 'utf-8'
@@ -81,13 +81,31 @@ class CSVLoader:
Initialize CSV loader.
Args:
csv_path: Path to the CSV file
csv_path: Path to CSV file(s). Can be:
- Single path: 'data/file.csv'
- List of paths: ['data/file1.csv', 'data/file2.csv']
- Glob pattern: 'data/*.csv' or 'data/export_*.csv'
pdf_dir: Directory containing PDF files (default: data/raw_pdfs)
doc_map_path: Optional path to document mapping CSV
encoding: CSV file encoding (default: utf-8)
"""
self.csv_path = Path(csv_path)
self.pdf_dir = Path(pdf_dir) if pdf_dir else self.csv_path.parent.parent / 'raw_pdfs'
# Handle multiple CSV files
if isinstance(csv_path, list):
self.csv_paths = [Path(p) for p in csv_path]
else:
csv_path = Path(csv_path)
# Check if it's a glob pattern (contains * or ?)
if '*' in str(csv_path) or '?' in str(csv_path):
parent = csv_path.parent
pattern = csv_path.name
self.csv_paths = sorted(parent.glob(pattern))
else:
self.csv_paths = [csv_path]
# For backward compatibility
self.csv_path = self.csv_paths[0] if self.csv_paths else None
self.pdf_dir = Path(pdf_dir) if pdf_dir else (self.csv_path.parent.parent / 'raw_pdfs' if self.csv_path else Path('data/raw_pdfs'))
self.doc_map_path = Path(doc_map_path) if doc_map_path else None
self.encoding = encoding
@@ -185,21 +203,14 @@ class CSVLoader:
raw_data=dict(row)
)
def load_all(self) -> list[InvoiceRow]:
    """Load all rows from CSV.

    Eager counterpart of ``iter_rows``: the iterator is drained completely
    and its items are returned in the order they were produced.

    Returns:
        A list containing every row yielded by ``self.iter_rows()``.
    """
    # list() consumes the iterator directly; a manual append loop adds nothing.
    return list(self.iter_rows())
def iter_rows(self) -> Iterator[InvoiceRow]:
"""Iterate over CSV rows."""
def _iter_single_csv(self, csv_path: Path) -> Iterator[InvoiceRow]:
"""Iterate over rows from a single CSV file."""
# Handle BOM - try utf-8-sig first to handle BOM correctly
encodings = ['utf-8-sig', self.encoding, 'latin-1']
for enc in encodings:
try:
with open(self.csv_path, 'r', encoding=enc) as f:
with open(csv_path, 'r', encoding=enc) as f:
reader = csv.DictReader(f)
for row in reader:
parsed = self._parse_row(row)
@@ -209,7 +220,27 @@ class CSVLoader:
except UnicodeDecodeError:
continue
raise ValueError(f"Could not read CSV file with any supported encoding")
raise ValueError(f"Could not read CSV file {csv_path} with any supported encoding")
def load_all(self) -> list[InvoiceRow]:
    """Load all rows from CSV(s).

    Eager counterpart of ``iter_rows``: the iterator is drained completely
    and its items are returned in the order they were produced.

    Returns:
        A list containing every row yielded by ``self.iter_rows()``.
    """
    # list() consumes the iterator directly; a manual append loop adds nothing.
    return list(self.iter_rows())
def iter_rows(self) -> Iterator[InvoiceRow]:
    """Iterate over CSV rows from all CSV files.

    The paths in ``self.csv_paths`` are visited in order and each one is
    read through ``self._iter_single_csv``. A row is yielded only the first
    time its ``DocumentId`` is seen, so later duplicates (inside one file or
    across files) are dropped.

    Paths that do not exist are skipped rather than raising, but a warning
    is logged for each one, and another when there are no paths at all, so
    that a mistyped path or a glob with no matches is not mistaken for an
    empty dataset.

    Yields:
        Rows in file order, at most one per distinct ``DocumentId``.
    """
    import logging  # function-scope import keeps the block self-contained

    logger = logging.getLogger(__name__)

    if not self.csv_paths:
        logger.warning("No CSV files to read")
        return

    seen_doc_ids = set()
    for csv_path in self.csv_paths:
        if not csv_path.exists():
            # Best-effort: keep going, but leave a trace of the skip.
            logger.warning("CSV file not found, skipping: %s", csv_path)
            continue
        for row in self._iter_single_csv(csv_path):
            doc_id = row.DocumentId
            # Deduplicate by DocumentId: the first occurrence wins.
            if doc_id in seen_doc_ids:
                continue
            seen_doc_ids.add(doc_id)
            yield row
def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None:
"""
@@ -300,7 +331,7 @@ class CSVLoader:
return issues
def load_invoice_csv(csv_path: str | Path, pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
"""Convenience function to load invoice CSV."""
def load_invoice_csv(csv_path: str | Path | list[str | Path], pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
    """Load invoice rows from one or more CSV files in a single call.

    Args:
        csv_path: Passed unchanged as the first argument to ``CSVLoader``.
        pdf_dir: Passed unchanged as the second argument to ``CSVLoader``.

    Returns:
        The result of calling ``load_all()`` on that loader.
    """
    return CSVLoader(csv_path, pdf_dir).load_all()