Initial commit: Invoice field extraction system using YOLO + OCR

Features:
- Auto-labeling pipeline: CSV values -> PDF search -> YOLO annotations
- Flexible date matching: year-month match, nearby date tolerance
- PDF text extraction with PyMuPDF
- OCR support for scanned documents (PaddleOCR)
- YOLO training and inference pipeline
- 7 field types: InvoiceNumber, InvoiceDate, InvoiceDueDate, OCR, Bankgiro, Plusgiro, Amount

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Author: Yaojia Wang
Date: 2026-01-10 17:44:14 +01:00
Commit: 8938661850
35 changed files with 5020 additions and 0 deletions

.gitignore

@@ -0,0 +1,71 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
# Virtual environments
venv/
ENV/
env/
.venv/
# IDE
.idea/
.vscode/
*.swp
*.swo
*~
# Data files (large files)
data/raw_pdfs/
data/dataset/train/images/
data/dataset/val/images/
data/dataset/test/images/
data/dataset/train/labels/
data/dataset/val/labels/
data/dataset/test/labels/
*.pdf
*.png
*.jpg
*.jpeg
# Model weights
models/weights/
runs/
*.pt
*.onnx
# Reports and logs
reports/*.jsonl
logs/
*.log
# Jupyter
.ipynb_checkpoints/
# OS
.DS_Store
Thumbs.db
# Credentials
.env
*.key
*.pem
credentials.json

README.md

@@ -0,0 +1,226 @@
# Invoice Master POC v2
Automatic invoice information extraction system - extracts structured data from PDF invoices using YOLO + OCR.
## Runtime Environment
> **Important**: This project must be run inside **WSL (Windows Subsystem for Linux)**.
### System Requirements
- WSL 2 (Ubuntu 22.04 recommended)
- Python 3.10+
- **NVIDIA GPU + CUDA 12.x (strongly recommended)** - GPU training is 10-50x faster than CPU
## Features
- **Dual-mode PDF processing**: handles both text-layer PDFs and scanned PDFs
- **Auto-labeling**: generates YOLO training data automatically from existing structured CSV data (see the sketch after this list)
- **Field detection**: uses YOLOv8 to detect invoice field regions
- **OCR extraction**: uses PaddleOCR to read text from the detected regions
- **Smart matching**: multiple format normalizations plus context-keyword boosting
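A minimal sketch of the auto-labeling idea (illustrative only: the `Token` dataclass, the `to_yolo_line` helper, and the hard-coded boxes and image size below are assumptions, not the project's actual API):
```python
# Sketch: find a known CSV value among a page's text tokens and turn the
# matching token's pixel bbox into a YOLO "class cx cy w h" label line.
from dataclasses import dataclass


@dataclass
class Token:
    text: str
    bbox: tuple[float, float, float, float]  # x0, y0, x1, y1 in pixels


def to_yolo_line(class_id: int, bbox: tuple[float, float, float, float],
                 img_w: int, img_h: int) -> str:
    """Convert an absolute pixel bbox into a normalized YOLO label line."""
    x0, y0, x1, y1 = bbox
    cx, cy = (x0 + x1) / 2 / img_w, (y0 + y1) / 2 / img_h
    w, h = (x1 - x0) / img_w, (y1 - y0) / img_h
    return f"{class_id} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}"


tokens = [Token("Fakturanr:", (100, 80, 180, 95)),
          Token("100017500321", (190, 80, 320, 95))]
csv_value = "100017500321"  # InvoiceNumber taken from the structured CSV
labels = [to_yolo_line(0, t.bbox, 2480, 3508)  # class 0 = invoice_number, A4 at 300 DPI
          for t in tokens if t.text == csv_value]
print(labels)  # ['0 0.102823 0.024943 0.052419 0.004276']
```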
## Supported Fields
| Field | Description |
|------|------|
| InvoiceNumber | Invoice number |
| InvoiceDate | Invoice date |
| InvoiceDueDate | Due date |
| OCR | OCR reference number (Sweden) |
| Bankgiro | Bankgiro number |
| Plusgiro | Plusgiro number |
| Amount | Invoice amount |
## Installation (WSL)
### 1. Enter the WSL environment
```bash
# Enter WSL from a Windows terminal
wsl
# Change to the project directory (Windows paths are mapped under /mnt/)
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
```
### 2. Install system dependencies
```bash
# Update the system
sudo apt update && sudo apt upgrade -y
# Install Python and required tools
sudo apt install -y python3.10 python3.10-venv python3-pip
# Install OpenCV dependencies
sudo apt install -y libgl1-mesa-glx libglib2.0-0 libsm6 libxrender1 libxext6
```
### 3. Create a virtual environment and install dependencies
```bash
# Create a virtual environment
python3 -m venv venv
source venv/bin/activate
# Upgrade pip
pip install --upgrade pip
# Install dependencies
pip install -r requirements.txt
# Or install in development mode
pip install -e .
```
### GPU Support (optional)
```bash
# Make sure CUDA is configured in WSL
nvidia-smi  # Check that the GPU is available
# Install the GPU build of PaddlePaddle
pip install paddlepaddle-gpu
# Or pin a specific CUDA version
pip install paddlepaddle-gpu==2.5.2.post118 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```
## Quick Start
### 1. Prepare the data
```
data/
├── raw_pdfs/
│ ├── {DocumentId}.pdf
│ └── ...
└── structured_data/
└── invoices.csv
```
CSV format:
```csv
DocumentId,InvoiceDate,InvoiceNumber,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount
3be53fd7-...,2025-12-13,100017500321,2026-01-03,100017500321,53939484,,114
```
### 2. Auto-labeling
```bash
python -m src.cli.autolabel \
--csv data/structured_data/invoices.csv \
--pdf-dir data/raw_pdfs \
--output data/dataset \
--report reports/autolabel_report.jsonl
```
### 3. Train the model
> **Important**: Always train on a GPU; CPU training is extremely slow.
```bash
# GPU training (strongly recommended)
python -m src.cli.train \
--data data/dataset/dataset.yaml \
--model yolo11n.pt \
--epochs 100 \
--batch 16 \
  --device 0  # Use the GPU
# Verify that the GPU is available
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}')"
```
GPU vs CPU training time comparison (100 epochs, 77 training images):
- **GPU (RTX 5080)**: ~2 minutes
- **CPU**: 30+ minutes
### 4. Inference
```bash
python -m src.cli.infer \
--model runs/train/invoice_fields/weights/best.pt \
--input path/to/invoice.pdf \
--output result.json
```
## Example Output
```json
{
"DocumentId": "3be53fd7-d5ea-458c-a229-8d360b8ba6a9",
"InvoiceNumber": "100017500321",
"InvoiceDate": "2025-12-13",
"InvoiceDueDate": "2026-01-03",
"OCR": "100017500321",
"Bankgiro": "5393-9484",
"Plusgiro": null,
"Amount": "114.00",
"confidence": {
"InvoiceNumber": 0.96,
"InvoiceDate": 0.92,
"Amount": 0.93
}
}
```
## Project Structure
```
invoice-master-poc-v2/
├── src/
│   ├── pdf/          # PDF processing
│   ├── ocr/          # OCR extraction
│   ├── normalize/    # Field normalization
│   ├── matcher/      # Field matching
│   ├── yolo/         # YOLO annotation generation
│   ├── inference/    # Inference pipeline
│   ├── data/         # Data loading
│   └── cli/          # Command-line tools
├── configs/          # Configuration files
├── data/             # Data directory
└── requirements.txt
```
## Development Priorities
1. ✅ Auto-labeling for text-layer PDFs
2. ✅ Auto-labeling for scanned PDFs via OCR
3. 🔄 Stabilize the Amount / OCR / Bankgiro fields
4. ⏳ Extend to dates and Plusgiro
5. ⏳ Handle table line items
## Configuration
Edit `configs/default.yaml` to customize (see the example after this list):
- PDF rendering DPI
- OCR language
- Match confidence threshold
- Context keywords
- Data augmentation parameters
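For example, a minimal sketch of reading these settings (assuming only PyYAML from `requirements.txt`; the keys and values shown mirror `configs/default.yaml` in this commit):
```python
# Minimal sketch: load configs/default.yaml and read a couple of settings.
import yaml

with open("configs/default.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["pdf"]["dpi"])                       # 300 in the default config
print(cfg["matching"]["min_score_threshold"])  # 0.7 in the default config
```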
## API Usage
```python
from src.inference import InferencePipeline
# Initialize
pipeline = InferencePipeline(
model_path='models/best.pt',
confidence_threshold=0.5,
ocr_lang='en'
)
# Process a PDF
result = pipeline.process_pdf('invoice.pdf')
# Read the extracted fields
print(result.fields)
print(result.confidence)
```
## License
MIT License

configs/default.yaml

@@ -0,0 +1,129 @@
# Invoice Master POC v2 Configuration
# Default configuration for invoice field extraction system
# PDF Processing
pdf:
dpi: 300 # Resolution for rendering PDFs to images
min_text_chars: 30 # Minimum chars to consider PDF as text-based
# OCR Settings
ocr:
engine: paddleocr # OCR engine to use
lang: en # Language code (en, sv, ch, etc.)
use_gpu: false # Enable GPU acceleration
# Field Normalization
normalize:
# Bankgiro formats
bankgiro:
format: "XXXX-XXXX" # Standard 8-digit format
alternatives:
- "XXX-XXXX" # 7-digit format
# Plusgiro formats
plusgiro:
format: "XXXXXXX-X" # Standard format with check digit
# Amount formats
amount:
decimal_separator: "," # Swedish uses comma
thousand_separator: " " # Space for thousands
currency_symbols:
- "SEK"
- "kr"
# Date formats
date:
output_format: "%Y-%m-%d"
input_formats:
- "%Y-%m-%d"
- "%Y-%m-%d %H:%M:%S"
- "%d/%m/%Y"
- "%d.%m.%Y"
# Field Matching
matching:
min_score_threshold: 0.7 # Minimum score to accept match
context_radius: 100 # Pixels to search for context keywords
# Context keywords for each field (Swedish)
context_keywords:
InvoiceNumber:
- "fakturanr"
- "fakturanummer"
- "invoice"
InvoiceDate:
- "fakturadatum"
- "datum"
InvoiceDueDate:
- "förfallodatum"
- "förfaller"
- "betalas senast"
OCR:
- "ocr"
- "referens"
Bankgiro:
- "bankgiro"
- "bg"
Plusgiro:
- "plusgiro"
- "pg"
Amount:
- "att betala"
- "summa"
- "total"
- "belopp"
# YOLO Training
yolo:
model: yolov8s # Model architecture (yolov8n/s/m/l/x)
epochs: 100
batch_size: 16
img_size: 1280 # Image size for training
# Data augmentation
augmentation:
rotation: 5 # Max rotation degrees
scale: 0.2 # Scale variation
mosaic: 0.0 # Disable mosaic for documents
hsv_h: 0.0 # No hue variation
hsv_s: 0.1 # Slight saturation variation
hsv_v: 0.2 # Brightness variation
# Class definitions
classes:
0: invoice_number
1: invoice_date
2: invoice_due_date
3: ocr_number
4: bankgiro
5: plusgiro
6: amount
# Auto-labeling
autolabel:
min_confidence: 0.7 # Minimum score to include in training
bbox_padding: 0.02 # Padding around bboxes (fraction of image)
# Dataset Split
dataset:
train_ratio: 0.8
val_ratio: 0.1
test_ratio: 0.1
random_seed: 42
# Inference
inference:
confidence_threshold: 0.5 # Detection confidence threshold
iou_threshold: 0.45 # NMS IOU threshold
enable_fallback: true # Enable regex fallback if YOLO fails
fallback_min_missing: 2 # Min missing fields to trigger fallback
# Paths (relative to project root)
paths:
raw_pdfs: data/raw_pdfs
images: data/images
labels: data/labels
structured_data: data/structured_data
models: models
reports: reports

configs/training.yaml

@@ -0,0 +1,59 @@
# YOLO Training Configuration
# Use with: yolo train data=dataset.yaml cfg=training.yaml
# Model
model: yolov8s.pt
# Training hyperparameters
epochs: 100
patience: 20 # Early stopping patience
batch: 16
imgsz: 1280
# Optimizer
optimizer: AdamW
lr0: 0.001 # Initial learning rate
lrf: 0.01 # Final learning rate factor
momentum: 0.937
weight_decay: 0.0005
# Warmup
warmup_epochs: 3
warmup_momentum: 0.8
warmup_bias_lr: 0.1
# Loss weights
box: 7.5 # Box loss gain
cls: 0.5 # Class loss gain
dfl: 1.5 # DFL loss gain
# Augmentation
# Keep minimal for document images
hsv_h: 0.0 # No hue augmentation
hsv_s: 0.1 # Slight saturation
hsv_v: 0.2 # Brightness variation
degrees: 5.0 # Rotation ±5°
translate: 0.05 # Translation
scale: 0.2 # Scale ±20%
shear: 0.0 # No shear
perspective: 0.0 # No perspective
flipud: 0.0 # No vertical flip
fliplr: 0.0 # No horizontal flip
mosaic: 0.0 # Disable mosaic (not suitable for documents)
mixup: 0.0 # Disable mixup
copy_paste: 0.0 # Disable copy-paste
# Validation
val: true
save: true
save_period: 10
cache: true
# Other
device: 0 # GPU device (0, 1, etc.) or 'cpu'
workers: 8
project: runs/train
name: invoice_fields
exist_ok: true
pretrained: true
verbose: true

pyproject.toml

@@ -0,0 +1,78 @@
[build-system]
requires = ["setuptools>=68.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "invoice-master"
version = "2.0.0"
description = "Automatic invoice information extraction using YOLO + OCR"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "MIT"}
authors = [
{name = "Invoice Master Team"}
]
keywords = ["invoice", "ocr", "yolo", "document-processing", "pdf"]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = [
"PyMuPDF>=1.23.0",
"paddlepaddle>=2.5.0",
"paddleocr>=2.7.0",
"ultralytics>=8.1.0",
"Pillow>=10.0.0",
"numpy>=1.24.0",
"opencv-python>=4.8.0",
"pyyaml>=6.0",
"tqdm>=4.65.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.0.0",
"pytest-cov>=4.0.0",
"black>=23.0.0",
"ruff>=0.1.0",
"mypy>=1.0.0",
]
gpu = [
"paddlepaddle-gpu>=2.5.0",
]
[project.scripts]
invoice-autolabel = "src.cli.autolabel:main"
invoice-train = "src.cli.train:main"
invoice-infer = "src.cli.infer:main"
[tool.setuptools.packages.find]
where = ["."]
include = ["src*"]
[tool.black]
line-length = 100
target-version = ["py310", "py311", "py312"]
[tool.ruff]
line-length = 100
target-version = "py310"
select = ["E", "F", "W", "I", "N", "D", "UP", "B", "C4", "SIM"]
ignore = ["D100", "D104"]
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_ignores = true
disallow_untyped_defs = true
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
addopts = "-v --cov=src --cov-report=term-missing"

requirements.txt

@@ -0,0 +1,22 @@
# Invoice Master POC v2 - Dependencies
# PDF Processing
PyMuPDF>=1.23.0 # PDF rendering and text extraction
# OCR
paddlepaddle>=2.5.0 # PaddlePaddle framework
paddleocr>=2.7.0 # PaddleOCR
# YOLO
ultralytics>=8.1.0 # YOLOv8/v11
# Image Processing
Pillow>=10.0.0 # Image handling
numpy>=1.24.0 # Array operations
opencv-python>=4.8.0 # Image processing
# Data Processing
pyyaml>=6.0 # YAML config files
# Utilities
tqdm>=4.65.0 # Progress bars

run_autolabel.py

@@ -0,0 +1,10 @@
#!/usr/bin/env python3
"""
Auto-labeling script - delegates to the CLI module.
Run inside WSL: python run_autolabel.py
"""
from src.cli.autolabel import main
if __name__ == '__main__':
main()

scripts/run_autolabel.sh

@@ -0,0 +1,55 @@
#!/bin/bash
# Auto-labeling run script
# Usage: bash scripts/run_autolabel.sh
set -e
# Project root directory
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$PROJECT_DIR"
# Activate the virtual environment
if [ -f "venv/bin/activate" ]; then
source venv/bin/activate
else
echo "错误: 虚拟环境不存在,请先运行 setup_wsl.sh"
exit 1
fi
# Default parameters
CSV_FILE="${CSV_FILE:-data/structured_data/invoices.csv}"
PDF_DIR="${PDF_DIR:-data/raw_pdfs}"
OUTPUT_DIR="${OUTPUT_DIR:-data/dataset}"
REPORT_FILE="${REPORT_FILE:-reports/autolabel_report.jsonl}"
DPI="${DPI:-300}"
MIN_CONFIDENCE="${MIN_CONFIDENCE:-0.7}"
# Show configuration
echo "=========================================="
echo "Auto-labeling configuration"
echo "=========================================="
echo "CSV file: $CSV_FILE"
echo "PDF directory: $PDF_DIR"
echo "Output directory: $OUTPUT_DIR"
echo "Report file: $REPORT_FILE"
echo "DPI: $DPI"
echo "Minimum confidence: $MIN_CONFIDENCE"
echo "=========================================="
echo ""
# Create required directories
mkdir -p "$(dirname "$REPORT_FILE")"
mkdir -p "$OUTPUT_DIR"
# Run auto-labeling
python -m src.cli.autolabel \
--csv "$CSV_FILE" \
--pdf-dir "$PDF_DIR" \
--output "$OUTPUT_DIR" \
--report "$REPORT_FILE" \
--dpi "$DPI" \
--min-confidence "$MIN_CONFIDENCE" \
--verbose
echo ""
echo "完成! 数据集已生成到: $OUTPUT_DIR"

scripts/run_train.sh

@@ -0,0 +1,67 @@
#!/bin/bash
# Training run script
# Usage: bash scripts/run_train.sh
set -e
# Project root directory
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$PROJECT_DIR"
# Activate the virtual environment
if [ -f "venv/bin/activate" ]; then
source venv/bin/activate
else
echo "错误: 虚拟环境不存在,请先运行 setup_wsl.sh"
exit 1
fi
# Default parameters
DATA_YAML="${DATA_YAML:-data/dataset/dataset.yaml}"
MODEL="${MODEL:-yolov8s.pt}"
EPOCHS="${EPOCHS:-100}"
BATCH_SIZE="${BATCH_SIZE:-16}"
IMG_SIZE="${IMG_SIZE:-1280}"
DEVICE="${DEVICE:-0}"
# Check that the dataset exists
if [ ! -f "$DATA_YAML" ]; then
    echo "Error: dataset config file not found: $DATA_YAML"
    echo "Run auto-labeling first: bash scripts/run_autolabel.sh"
exit 1
fi
# Show configuration
echo "=========================================="
echo "Training configuration"
echo "=========================================="
echo "Dataset: $DATA_YAML"
echo "Base model: $MODEL"
echo "Epochs: $EPOCHS"
echo "Batch Size: $BATCH_SIZE"
echo "Image size: $IMG_SIZE"
echo "Device: $DEVICE"
echo "=========================================="
echo ""
# Check for a GPU
if command -v nvidia-smi &> /dev/null; then
    echo "GPU status:"
nvidia-smi --query-gpu=name,memory.used,memory.total --format=csv,noheader
echo ""
else
echo "警告: 未检测到 GPU将使用 CPU 训练 (较慢)"
DEVICE="cpu"
fi
# Run training
python -m src.cli.train \
--data "$DATA_YAML" \
--model "$MODEL" \
--epochs "$EPOCHS" \
--batch "$BATCH_SIZE" \
--imgsz "$IMG_SIZE" \
--device "$DEVICE"
echo ""
echo "训练完成! 模型保存在: runs/train/invoice_fields/weights/"

scripts/setup_wsl.sh

@@ -0,0 +1,80 @@
#!/bin/bash
# WSL environment setup script
# Usage: bash scripts/setup_wsl.sh
set -e
echo "=========================================="
echo "Invoice Master POC v2 - WSL 安装脚本"
echo "=========================================="
# Check that we are running inside WSL
if ! grep -qi microsoft /proc/version 2>/dev/null; then
    echo "Warning: WSL environment not detected, please run this script inside WSL"
    echo "Hint: type 'wsl' in a Windows terminal to enter WSL"
exit 1
fi
echo ""
echo "[1/5] 更新系统包..."
sudo apt update
echo ""
echo "[2/5] 安装系统依赖..."
sudo apt install -y \
python3.10 \
python3.10-venv \
python3-pip \
libgl1-mesa-glx \
libglib2.0-0 \
libsm6 \
libxrender1 \
libxext6 \
libgomp1
echo ""
echo "[3/5] 创建 Python 虚拟环境..."
if [ -d "venv" ]; then
echo "虚拟环境已存在,跳过创建"
else
python3 -m venv venv
fi
echo ""
echo "[4/5] 激活虚拟环境并安装依赖..."
source venv/bin/activate
pip install --upgrade pip
echo ""
echo "安装 Python 依赖包..."
pip install -r requirements.txt
echo ""
echo "[5/5] 验证安装..."
python3 -c "import fitz; print(f'PyMuPDF: {fitz.version}')"
python3 -c "from ultralytics import YOLO; print('Ultralytics: OK')"
python3 -c "from paddleocr import PaddleOCR; print('PaddleOCR: OK')"
echo ""
echo "=========================================="
echo "安装完成!"
echo "=========================================="
echo ""
echo "使用方法:"
echo " 1. 激活虚拟环境: source venv/bin/activate"
echo " 2. 运行自动标注: python -m src.cli.autolabel --help"
echo " 3. 训练模型: python -m src.cli.train --help"
echo " 4. 推理: python -m src.cli.infer --help"
echo ""
# Check for a GPU
echo "Checking GPU support..."
if command -v nvidia-smi &> /dev/null; then
echo "检测到 NVIDIA GPU:"
nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
echo ""
echo "提示: 运行以下命令启用 GPU 加速:"
echo " pip install paddlepaddle-gpu"
else
echo "未检测到 GPU将使用 CPU 模式"
fi

src/__init__.py

@@ -0,0 +1,2 @@
# Invoice Master POC v2
# Automatic invoice information extraction system using YOLO + OCR

src/cli/__init__.py

@@ -0,0 +1 @@
# CLI modules for Invoice Master

src/cli/autolabel.py

@@ -0,0 +1,458 @@
#!/usr/bin/env python3
"""
Auto-labeling CLI
Generates YOLO training data from PDFs and structured CSV data.
"""
import argparse
import sys
import time
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing
# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None
def _init_worker():
"""Initialize worker process with OCR engine (called once per worker)."""
global _worker_ocr_engine
# OCR engine will be lazily initialized on first use
_worker_ocr_engine = None
def _get_ocr_engine():
"""Get or create OCR engine for current worker."""
global _worker_ocr_engine
if _worker_ocr_engine is None:
from ..ocr import OCREngine
_worker_ocr_engine = OCREngine()
return _worker_ocr_engine
def process_single_document(args_tuple):
"""
Process a single document (worker function for parallel processing).
Args:
args_tuple: (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr)
Returns:
dict with results
"""
row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple
# Import inside worker to avoid pickling issues
from ..data import AutoLabelReport, FieldMatchResult
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
from ..pdf.renderer import get_render_dimensions
from ..matcher import FieldMatcher
from ..normalize import normalize_field
from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
start_time = time.time()
pdf_path = Path(pdf_path_str)
output_dir = Path(output_dir_str)
doc_id = row_dict['DocumentId']
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
result = {
'doc_id': doc_id,
'success': False,
'pages': [],
'report': None,
'stats': {name: 0 for name in FIELD_CLASSES.keys()}
}
try:
# Check PDF type
use_ocr = not is_text_pdf(pdf_path)
report.pdf_type = "scanned" if use_ocr else "text"
# Skip OCR if requested
if use_ocr and skip_ocr:
report.errors.append("Skipped (scanned PDF)")
report.processing_time_ms = (time.time() - start_time) * 1000
result['report'] = report.to_dict()
return result
# Get OCR engine from worker cache (only created once per worker)
ocr_engine = None
if use_ocr:
ocr_engine = _get_ocr_engine()
generator = AnnotationGenerator(min_confidence=min_confidence)
matcher = FieldMatcher()
# Process each page
page_annotations = []
for page_no, image_path in render_pdf_to_images(
pdf_path,
output_dir / 'temp' / doc_id / 'images',
dpi=dpi
):
report.total_pages += 1
img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)
# Extract tokens
if use_ocr:
tokens = ocr_engine.extract_from_image(str(image_path), page_no)
else:
tokens = list(extract_text_tokens(pdf_path, page_no))
# Match fields
matches = {}
for field_name in FIELD_CLASSES.keys():
value = row_dict.get(field_name)
if not value:
continue
normalized = normalize_field(field_name, str(value))
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
# Record result
if field_matches:
best = field_matches[0]
matches[field_name] = field_matches
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=True,
score=best.score,
matched_text=best.matched_text,
candidate_used=best.value,
bbox=best.bbox,
page_no=page_no,
context_keywords=best.context_keywords
))
else:
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=False,
page_no=page_no
))
# Generate annotations
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
if annotations:
label_path = output_dir / 'temp' / doc_id / 'labels' / f"{image_path.stem}.txt"
generator.save_annotations(annotations, label_path)
page_annotations.append({
'image_path': str(image_path),
'label_path': str(label_path),
'count': len(annotations)
})
report.annotations_generated += len(annotations)
for ann in annotations:
class_name = list(FIELD_CLASSES.keys())[ann.class_id]
result['stats'][class_name] += 1
if page_annotations:
result['pages'] = page_annotations
result['success'] = True
report.success = True
else:
report.errors.append("No annotations generated")
except Exception as e:
report.errors.append(str(e))
report.processing_time_ms = (time.time() - start_time) * 1000
result['report'] = report.to_dict()
return result
def main():
parser = argparse.ArgumentParser(
description='Generate YOLO annotations from PDFs and CSV data'
)
parser.add_argument(
'--csv', '-c',
default='data/structured_data/document_export_20260109_212743.csv',
help='Path to structured data CSV file'
)
parser.add_argument(
'--pdf-dir', '-p',
default='data/raw_pdfs',
help='Directory containing PDF files'
)
parser.add_argument(
'--output', '-o',
default='data/dataset',
help='Output directory for dataset'
)
parser.add_argument(
'--dpi',
type=int,
default=300,
help='DPI for PDF rendering (default: 300)'
)
parser.add_argument(
'--min-confidence',
type=float,
default=0.7,
help='Minimum match confidence (default: 0.7)'
)
parser.add_argument(
'--train-ratio',
type=float,
default=0.8,
help='Training set ratio (default: 0.8)'
)
parser.add_argument(
'--val-ratio',
type=float,
default=0.1,
help='Validation set ratio (default: 0.1)'
)
parser.add_argument(
'--report',
default='reports/autolabel_report.jsonl',
help='Path for auto-label report (JSONL)'
)
parser.add_argument(
'--single',
help='Process single document ID only'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Verbose output'
)
parser.add_argument(
'--workers', '-w',
type=int,
default=4,
help='Number of parallel workers (default: 4)'
)
parser.add_argument(
'--skip-ocr',
action='store_true',
help='Skip scanned PDFs (text-layer only)'
)
args = parser.parse_args()
# Import here to avoid slow startup
from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
from ..data.autolabel_report import ReportWriter
from ..yolo import DatasetBuilder
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
from ..pdf.renderer import get_render_dimensions
from ..ocr import OCREngine
from ..matcher import FieldMatcher
from ..normalize import normalize_field
from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
print(f"Loading CSV data from: {args.csv}")
loader = CSVLoader(args.csv, args.pdf_dir)
# Validate data
issues = loader.validate()
if issues:
print(f"Warning: Found {len(issues)} validation issues")
if args.verbose:
for issue in issues[:10]:
print(f" - {issue}")
rows = loader.load_all()
print(f"Loaded {len(rows)} invoice records")
# Filter to single document if specified
if args.single:
rows = [r for r in rows if r.DocumentId == args.single]
if not rows:
print(f"Error: Document {args.single} not found")
sys.exit(1)
print(f"Processing single document: {args.single}")
# Setup output directories
output_dir = Path(args.output)
for split in ['train', 'val', 'test']:
(output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
(output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)
# Generate YOLO config files
AnnotationGenerator.generate_classes_file(output_dir / 'classes.txt')
AnnotationGenerator.generate_yaml_config(output_dir / 'dataset.yaml')
# Report writer
report_path = Path(args.report)
report_path.parent.mkdir(parents=True, exist_ok=True)
report_writer = ReportWriter(args.report)
# Stats
stats = {
'total': len(rows),
'successful': 0,
'failed': 0,
'skipped': 0,
'annotations': 0,
'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
}
# Prepare tasks
tasks = []
for row in rows:
pdf_path = loader.get_pdf_path(row)
if not pdf_path:
# Write report for missing PDF
report = AutoLabelReport(document_id=row.DocumentId)
report.errors.append("PDF not found")
report_writer.write(report)
stats['failed'] += 1
continue
# Convert row to dict for pickling
row_dict = {
'DocumentId': row.DocumentId,
'InvoiceNumber': row.InvoiceNumber,
'InvoiceDate': row.InvoiceDate,
'InvoiceDueDate': row.InvoiceDueDate,
'OCR': row.OCR,
'Bankgiro': row.Bankgiro,
'Plusgiro': row.Plusgiro,
'Amount': row.Amount,
}
tasks.append((
row_dict,
str(pdf_path),
str(output_dir),
args.dpi,
args.min_confidence,
args.skip_ocr
))
print(f"Processing {len(tasks)} documents with {args.workers} workers...")
# Process documents in parallel
processed_items = []
# Use single process for debugging or when workers=1
if args.workers == 1:
for task in tqdm(tasks, desc="Processing"):
result = process_single_document(task)
# Write report
if result['report']:
report_writer.write_dict(result['report'])
if result['success']:
processed_items.append({
'doc_id': result['doc_id'],
'pages': result['pages']
})
stats['successful'] += 1
for field, count in result['stats'].items():
stats['by_field'][field] += count
stats['annotations'] += count
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
stats['skipped'] += 1
else:
stats['failed'] += 1
else:
# Parallel processing with worker initialization
# Each worker initializes OCR engine once and reuses it
with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
for task in tasks}
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
doc_id = futures[future]
try:
result = future.result()
# Write report
if result['report']:
report_writer.write_dict(result['report'])
if result['success']:
processed_items.append({
'doc_id': result['doc_id'],
'pages': result['pages']
})
stats['successful'] += 1
for field, count in result['stats'].items():
stats['by_field'][field] += count
stats['annotations'] += count
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
stats['skipped'] += 1
else:
stats['failed'] += 1
except Exception as e:
stats['failed'] += 1
# Write error report for failed documents
error_report = {
'document_id': doc_id,
'success': False,
'errors': [f"Worker error: {str(e)}"]
}
report_writer.write_dict(error_report)
if args.verbose:
print(f"Error processing {doc_id}: {e}")
# Split and move files
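    # Split at the document level (not per page) so pages from the same invoice
    # never end up in different splits; the seed is fixed for reproducibility.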
import random
random.seed(42)
random.shuffle(processed_items)
n_train = int(len(processed_items) * args.train_ratio)
n_val = int(len(processed_items) * args.val_ratio)
splits = {
'train': processed_items[:n_train],
'val': processed_items[n_train:n_train + n_val],
'test': processed_items[n_train + n_val:]
}
import shutil
for split_name, items in splits.items():
for item in items:
for page in item['pages']:
# Move image
image_path = Path(page['image_path'])
label_path = Path(page['label_path'])
dest_img = output_dir / split_name / 'images' / image_path.name
shutil.move(str(image_path), str(dest_img))
# Move label
dest_label = output_dir / split_name / 'labels' / label_path.name
shutil.move(str(label_path), str(dest_label))
# Cleanup temp
shutil.rmtree(output_dir / 'temp', ignore_errors=True)
# Print summary
print("\n" + "=" * 50)
print("Auto-labeling Complete")
print("=" * 50)
print(f"Total documents: {stats['total']}")
print(f"Successful: {stats['successful']}")
print(f"Failed: {stats['failed']}")
print(f"Skipped (OCR): {stats['skipped']}")
print(f"Total annotations: {stats['annotations']}")
print(f"\nDataset split:")
print(f" Train: {len(splits['train'])} documents")
print(f" Val: {len(splits['val'])} documents")
print(f" Test: {len(splits['test'])} documents")
print(f"\nAnnotations by field:")
for field, count in stats['by_field'].items():
print(f" {field}: {count}")
print(f"\nOutput: {output_dir}")
print(f"Report: {args.report}")
if __name__ == '__main__':
main()

src/cli/infer.py

@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
Inference CLI
Runs inference on new PDFs to extract invoice data.
"""
import argparse
import json
import sys
from pathlib import Path
def main():
parser = argparse.ArgumentParser(
description='Extract invoice data from PDFs using trained model'
)
parser.add_argument(
'--model', '-m',
required=True,
help='Path to trained YOLO model (.pt file)'
)
parser.add_argument(
'--input', '-i',
required=True,
help='Input PDF file or directory'
)
parser.add_argument(
'--output', '-o',
help='Output JSON file (default: stdout)'
)
parser.add_argument(
'--confidence',
type=float,
default=0.5,
help='Detection confidence threshold (default: 0.5)'
)
parser.add_argument(
'--dpi',
type=int,
default=300,
help='DPI for PDF rendering (default: 300)'
)
parser.add_argument(
'--no-fallback',
action='store_true',
help='Disable fallback OCR'
)
parser.add_argument(
'--lang',
default='en',
help='OCR language (default: en)'
)
parser.add_argument(
'--gpu',
action='store_true',
help='Use GPU'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Verbose output'
)
args = parser.parse_args()
# Validate model
model_path = Path(args.model)
if not model_path.exists():
print(f"Error: Model not found: {model_path}", file=sys.stderr)
sys.exit(1)
# Get input files
input_path = Path(args.input)
if input_path.is_file():
pdf_files = [input_path]
elif input_path.is_dir():
pdf_files = list(input_path.glob('*.pdf'))
else:
print(f"Error: Input not found: {input_path}", file=sys.stderr)
sys.exit(1)
if not pdf_files:
print("Error: No PDF files found", file=sys.stderr)
sys.exit(1)
if args.verbose:
print(f"Processing {len(pdf_files)} PDF file(s)")
print(f"Model: {model_path}")
from ..inference import InferencePipeline
# Initialize pipeline
pipeline = InferencePipeline(
model_path=model_path,
confidence_threshold=args.confidence,
ocr_lang=args.lang,
use_gpu=args.gpu,
dpi=args.dpi,
enable_fallback=not args.no_fallback
)
# Process files
results = []
for pdf_path in pdf_files:
if args.verbose:
print(f"Processing: {pdf_path.name}")
result = pipeline.process_pdf(pdf_path)
results.append(result.to_json())
if args.verbose:
print(f" Success: {result.success}")
print(f" Fields: {len(result.fields)}")
if result.fallback_used:
print(f" Fallback used: Yes")
if result.errors:
print(f" Errors: {result.errors}")
# Output results
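    # A single input PDF yields one JSON object; a directory of PDFs yields a JSON array.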
if len(results) == 1:
output = results[0]
else:
output = results
json_output = json.dumps(output, indent=2, ensure_ascii=False)
if args.output:
with open(args.output, 'w', encoding='utf-8') as f:
f.write(json_output)
if args.verbose:
print(f"\nResults written to: {args.output}")
else:
print(json_output)
if __name__ == '__main__':
main()

src/cli/train.py

@@ -0,0 +1,138 @@
#!/usr/bin/env python3
"""
Training CLI
Trains YOLO model on generated dataset.
"""
import argparse
import sys
from pathlib import Path
def main():
parser = argparse.ArgumentParser(
description='Train YOLO model for invoice field detection'
)
parser.add_argument(
'--data', '-d',
required=True,
help='Path to dataset.yaml file'
)
parser.add_argument(
'--model', '-m',
default='yolov8s.pt',
help='Base model (default: yolov8s.pt)'
)
parser.add_argument(
'--epochs', '-e',
type=int,
default=100,
help='Number of epochs (default: 100)'
)
parser.add_argument(
'--batch', '-b',
type=int,
default=16,
help='Batch size (default: 16)'
)
parser.add_argument(
'--imgsz',
type=int,
default=1280,
help='Image size (default: 1280)'
)
parser.add_argument(
'--project',
default='runs/train',
help='Project directory (default: runs/train)'
)
parser.add_argument(
'--name',
default='invoice_fields',
help='Run name (default: invoice_fields)'
)
parser.add_argument(
'--device',
default='0',
help='Device (0, 1, cpu, mps)'
)
parser.add_argument(
'--resume',
help='Resume from checkpoint'
)
parser.add_argument(
'--config',
help='Path to training config YAML'
)
args = parser.parse_args()
# Validate data file
data_path = Path(args.data)
if not data_path.exists():
print(f"Error: Dataset file not found: {data_path}")
sys.exit(1)
print(f"Training YOLO model for invoice field detection")
print(f"Dataset: {args.data}")
print(f"Model: {args.model}")
print(f"Epochs: {args.epochs}")
print(f"Batch size: {args.batch}")
print(f"Image size: {args.imgsz}")
from ultralytics import YOLO
# Load model
if args.resume:
print(f"Resuming from: {args.resume}")
model = YOLO(args.resume)
else:
model = YOLO(args.model)
# Training arguments
train_args = {
'data': str(data_path.absolute()),
'epochs': args.epochs,
'batch': args.batch,
'imgsz': args.imgsz,
'project': args.project,
'name': args.name,
'device': args.device,
'exist_ok': True,
'pretrained': True,
'verbose': True,
# Document-specific augmentation settings
'degrees': 5.0,
'translate': 0.05,
'scale': 0.2,
'shear': 0.0,
'perspective': 0.0,
'flipud': 0.0,
'fliplr': 0.0,
'mosaic': 0.0,
'mixup': 0.0,
'hsv_h': 0.0,
'hsv_s': 0.1,
'hsv_v': 0.2,
}
# Train
results = model.train(**train_args)
# Print results
print("\n" + "=" * 50)
print("Training Complete")
print("=" * 50)
print(f"Best model: {args.project}/{args.name}/weights/best.pt")
print(f"Last model: {args.project}/{args.name}/weights/last.pt")
# Validate on test set
print("\nRunning validation...")
metrics = model.val()
print(f"mAP50: {metrics.box.map50:.4f}")
print(f"mAP50-95: {metrics.box.map:.4f}")
if __name__ == '__main__':
main()

src/data/__init__.py

@@ -0,0 +1,4 @@
from .csv_loader import CSVLoader, InvoiceRow
from .autolabel_report import AutoLabelReport, FieldMatchResult
__all__ = ['CSVLoader', 'InvoiceRow', 'AutoLabelReport', 'FieldMatchResult']

src/data/autolabel_report.py

@@ -0,0 +1,252 @@
"""
Auto-Label Report Generator
Generates quality control reports for auto-labeling process.
"""
import json
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Any
@dataclass
class FieldMatchResult:
"""Result of matching a single field."""
field_name: str
csv_value: str | None
matched: bool
score: float = 0.0
matched_text: str | None = None
candidate_used: str | None = None # Which normalized variant matched
bbox: tuple[float, float, float, float] | None = None
page_no: int = 0
context_keywords: list[str] = field(default_factory=list)
error: str | None = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
# Convert bbox to native Python floats to avoid numpy serialization issues
bbox_list = None
if self.bbox:
bbox_list = [float(x) for x in self.bbox]
return {
'field_name': self.field_name,
'csv_value': self.csv_value,
'matched': self.matched,
'score': float(self.score) if self.score else 0.0,
'matched_text': self.matched_text,
'candidate_used': self.candidate_used,
'bbox': bbox_list,
'page_no': int(self.page_no) if self.page_no else 0,
'context_keywords': self.context_keywords,
'error': self.error
}
@dataclass
class AutoLabelReport:
"""Report for a single document's auto-labeling process."""
document_id: str
pdf_path: str | None = None
pdf_type: str | None = None # 'text' | 'scanned' | 'mixed'
success: bool = False
total_pages: int = 0
fields_matched: int = 0
fields_total: int = 0
field_results: list[FieldMatchResult] = field(default_factory=list)
annotations_generated: int = 0
image_paths: list[str] = field(default_factory=list)
label_paths: list[str] = field(default_factory=list)
processing_time_ms: float = 0.0
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
errors: list[str] = field(default_factory=list)
def add_field_result(self, result: FieldMatchResult) -> None:
"""Add a field matching result."""
self.field_results.append(result)
self.fields_total += 1
if result.matched:
self.fields_matched += 1
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
'document_id': self.document_id,
'pdf_path': self.pdf_path,
'pdf_type': self.pdf_type,
'success': self.success,
'total_pages': self.total_pages,
'fields_matched': self.fields_matched,
'fields_total': self.fields_total,
'field_results': [r.to_dict() for r in self.field_results],
'annotations_generated': self.annotations_generated,
'image_paths': self.image_paths,
'label_paths': self.label_paths,
'processing_time_ms': self.processing_time_ms,
'timestamp': self.timestamp,
'errors': self.errors
}
def to_json(self, indent: int | None = None) -> str:
"""Convert to JSON string."""
return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)
@property
def match_rate(self) -> float:
"""Calculate field match rate."""
if self.fields_total == 0:
return 0.0
return self.fields_matched / self.fields_total
def get_summary(self) -> dict:
"""Get a summary of the report."""
return {
'document_id': self.document_id,
'success': self.success,
'match_rate': f"{self.match_rate:.1%}",
'fields': f"{self.fields_matched}/{self.fields_total}",
'annotations': self.annotations_generated,
'errors': len(self.errors)
}
class ReportWriter:
"""Writes auto-label reports to file."""
def __init__(self, output_path: str | Path):
"""
Initialize report writer.
Args:
output_path: Path to output JSONL file
"""
self.output_path = Path(output_path)
self.output_path.parent.mkdir(parents=True, exist_ok=True)
def write(self, report: AutoLabelReport) -> None:
"""Append a report to the output file."""
with open(self.output_path, 'a', encoding='utf-8') as f:
f.write(report.to_json() + '\n')
def write_dict(self, report_dict: dict) -> None:
"""Append a report dict to the output file (for parallel processing)."""
import json
with open(self.output_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(report_dict, ensure_ascii=False) + '\n')
f.flush()
def write_batch(self, reports: list[AutoLabelReport]) -> None:
"""Write multiple reports."""
with open(self.output_path, 'a', encoding='utf-8') as f:
for report in reports:
f.write(report.to_json() + '\n')
class ReportReader:
"""Reads auto-label reports from file."""
def __init__(self, input_path: str | Path):
"""
Initialize report reader.
Args:
input_path: Path to input JSONL file
"""
self.input_path = Path(input_path)
def read_all(self) -> list[AutoLabelReport]:
"""Read all reports from file."""
reports = []
if not self.input_path.exists():
return reports
with open(self.input_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
continue
data = json.loads(line)
report = self._dict_to_report(data)
reports.append(report)
return reports
def _dict_to_report(self, data: dict) -> AutoLabelReport:
"""Convert dictionary to AutoLabelReport."""
field_results = []
for fr_data in data.get('field_results', []):
bbox = tuple(fr_data['bbox']) if fr_data.get('bbox') else None
field_results.append(FieldMatchResult(
field_name=fr_data['field_name'],
csv_value=fr_data.get('csv_value'),
matched=fr_data.get('matched', False),
score=fr_data.get('score', 0.0),
matched_text=fr_data.get('matched_text'),
candidate_used=fr_data.get('candidate_used'),
bbox=bbox,
page_no=fr_data.get('page_no', 0),
context_keywords=fr_data.get('context_keywords', []),
error=fr_data.get('error')
))
return AutoLabelReport(
document_id=data['document_id'],
pdf_path=data.get('pdf_path'),
pdf_type=data.get('pdf_type'),
success=data.get('success', False),
total_pages=data.get('total_pages', 0),
fields_matched=data.get('fields_matched', 0),
fields_total=data.get('fields_total', 0),
field_results=field_results,
annotations_generated=data.get('annotations_generated', 0),
image_paths=data.get('image_paths', []),
label_paths=data.get('label_paths', []),
processing_time_ms=data.get('processing_time_ms', 0.0),
timestamp=data.get('timestamp', ''),
errors=data.get('errors', [])
)
def get_statistics(self) -> dict:
"""Calculate statistics from all reports."""
reports = self.read_all()
if not reports:
return {'total': 0}
successful = sum(1 for r in reports if r.success)
total_fields_matched = sum(r.fields_matched for r in reports)
total_fields = sum(r.fields_total for r in reports)
total_annotations = sum(r.annotations_generated for r in reports)
# Per-field statistics
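        # Note: avg_score accumulates scores for matched results only and is
        # divided by the match count below.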
field_stats = {}
for report in reports:
for fr in report.field_results:
if fr.field_name not in field_stats:
field_stats[fr.field_name] = {'matched': 0, 'total': 0, 'avg_score': 0.0}
field_stats[fr.field_name]['total'] += 1
if fr.matched:
field_stats[fr.field_name]['matched'] += 1
field_stats[fr.field_name]['avg_score'] += fr.score
# Calculate averages
for field_name, stats in field_stats.items():
if stats['matched'] > 0:
stats['avg_score'] /= stats['matched']
stats['match_rate'] = stats['matched'] / stats['total'] if stats['total'] > 0 else 0
return {
'total': len(reports),
'successful': successful,
'success_rate': successful / len(reports),
'total_fields_matched': total_fields_matched,
'total_fields': total_fields,
'overall_match_rate': total_fields_matched / total_fields if total_fields > 0 else 0,
'total_annotations': total_annotations,
'field_statistics': field_stats
}

src/data/csv_loader.py

@@ -0,0 +1,306 @@
"""
CSV Data Loader
Loads and parses structured invoice data from CSV files.
Follows the CSV specification for invoice data.
"""
import csv
from dataclasses import dataclass, field
from datetime import datetime, date
from decimal import Decimal, InvalidOperation
from pathlib import Path
from typing import Any, Iterator
@dataclass
class InvoiceRow:
"""Parsed invoice data row."""
DocumentId: str
InvoiceDate: date | None = None
InvoiceNumber: str | None = None
InvoiceDueDate: date | None = None
OCR: str | None = None
Message: str | None = None
Bankgiro: str | None = None
Plusgiro: str | None = None
Amount: Decimal | None = None
# Raw values for reference
raw_data: dict = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for matching."""
return {
'DocumentId': self.DocumentId,
'InvoiceDate': self.InvoiceDate.isoformat() if self.InvoiceDate else None,
'InvoiceNumber': self.InvoiceNumber,
'InvoiceDueDate': self.InvoiceDueDate.isoformat() if self.InvoiceDueDate else None,
'OCR': self.OCR,
'Bankgiro': self.Bankgiro,
'Plusgiro': self.Plusgiro,
'Amount': str(self.Amount) if self.Amount else None,
}
def get_field_value(self, field_name: str) -> str | None:
"""Get field value as string for matching."""
value = getattr(self, field_name, None)
if value is None:
return None
if isinstance(value, date):
return value.isoformat()
if isinstance(value, Decimal):
return str(value)
return str(value) if value else None
class CSVLoader:
"""Loads invoice data from CSV files."""
# Expected field mappings (CSV header -> InvoiceRow attribute)
FIELD_MAPPINGS = {
'DocumentId': 'DocumentId',
'InvoiceDate': 'InvoiceDate',
'InvoiceNumber': 'InvoiceNumber',
'InvoiceDueDate': 'InvoiceDueDate',
'OCR': 'OCR',
'Message': 'Message',
'Bankgiro': 'Bankgiro',
'Plusgiro': 'Plusgiro',
'Amount': 'Amount',
}
def __init__(
self,
csv_path: str | Path,
pdf_dir: str | Path | None = None,
doc_map_path: str | Path | None = None,
encoding: str = 'utf-8'
):
"""
Initialize CSV loader.
Args:
csv_path: Path to the CSV file
pdf_dir: Directory containing PDF files (default: data/raw_pdfs)
doc_map_path: Optional path to document mapping CSV
encoding: CSV file encoding (default: utf-8)
"""
self.csv_path = Path(csv_path)
self.pdf_dir = Path(pdf_dir) if pdf_dir else self.csv_path.parent.parent / 'raw_pdfs'
self.doc_map_path = Path(doc_map_path) if doc_map_path else None
self.encoding = encoding
# Load document mapping if provided
self.doc_map = self._load_doc_map() if self.doc_map_path else {}
def _load_doc_map(self) -> dict[str, str]:
"""Load document ID to filename mapping."""
mapping = {}
if self.doc_map_path and self.doc_map_path.exists():
with open(self.doc_map_path, 'r', encoding=self.encoding) as f:
reader = csv.DictReader(f)
for row in reader:
doc_id = row.get('DocumentId', '').strip()
filename = row.get('FileName', '').strip()
if doc_id and filename:
mapping[doc_id] = filename
return mapping
def _parse_date(self, value: str | None) -> date | None:
"""Parse date from various formats."""
if not value or not value.strip():
return None
value = value.strip()
# Try different date formats
formats = [
'%Y-%m-%d',
'%Y-%m-%d %H:%M:%S',
'%Y-%m-%d %H:%M:%S.%f',
'%d/%m/%Y',
'%d.%m.%Y',
'%d-%m-%Y',
'%Y%m%d',
]
for fmt in formats:
try:
return datetime.strptime(value, fmt).date()
except ValueError:
continue
return None
def _parse_amount(self, value: str | None) -> Decimal | None:
"""Parse monetary amount from various formats."""
if not value or not value.strip():
return None
value = value.strip()
# Remove currency symbols and common suffixes
value = value.replace('SEK', '').replace('kr', '').replace(':-', '')
value = value.strip()
# Remove spaces (thousand separators)
value = value.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator (European format)
if ',' in value and '.' not in value:
value = value.replace(',', '.')
elif ',' in value and '.' in value:
# Assume comma is thousands separator, dot is decimal
value = value.replace(',', '')
try:
return Decimal(value)
except InvalidOperation:
return None
def _parse_string(self, value: str | None) -> str | None:
"""Parse string field with cleanup."""
if value is None:
return None
value = value.strip()
return value if value else None
def _parse_row(self, row: dict) -> InvoiceRow | None:
"""Parse a single CSV row into InvoiceRow."""
doc_id = self._parse_string(row.get('DocumentId'))
if not doc_id:
return None
return InvoiceRow(
DocumentId=doc_id,
InvoiceDate=self._parse_date(row.get('InvoiceDate')),
InvoiceNumber=self._parse_string(row.get('InvoiceNumber')),
InvoiceDueDate=self._parse_date(row.get('InvoiceDueDate')),
OCR=self._parse_string(row.get('OCR')),
Message=self._parse_string(row.get('Message')),
Bankgiro=self._parse_string(row.get('Bankgiro')),
Plusgiro=self._parse_string(row.get('Plusgiro')),
Amount=self._parse_amount(row.get('Amount')),
raw_data=dict(row)
)
def load_all(self) -> list[InvoiceRow]:
"""Load all rows from CSV."""
rows = []
for row in self.iter_rows():
rows.append(row)
return rows
def iter_rows(self) -> Iterator[InvoiceRow]:
"""Iterate over CSV rows."""
# Handle BOM - try utf-8-sig first to handle BOM correctly
encodings = ['utf-8-sig', self.encoding, 'latin-1']
for enc in encodings:
try:
with open(self.csv_path, 'r', encoding=enc) as f:
reader = csv.DictReader(f)
for row in reader:
parsed = self._parse_row(row)
if parsed:
yield parsed
return
except UnicodeDecodeError:
continue
raise ValueError(f"Could not read CSV file with any supported encoding")
def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None:
"""
Get PDF path for an invoice row.
Uses document mapping if available, otherwise assumes
DocumentId.pdf naming convention.
"""
doc_id = invoice_row.DocumentId
# Check document mapping first
if doc_id in self.doc_map:
filename = self.doc_map[doc_id]
pdf_path = self.pdf_dir / filename
if pdf_path.exists():
return pdf_path
# Try default naming patterns
patterns = [
f"{doc_id}.pdf",
f"{doc_id.lower()}.pdf",
f"{doc_id.upper()}.pdf",
]
for pattern in patterns:
pdf_path = self.pdf_dir / pattern
if pdf_path.exists():
return pdf_path
# Try glob patterns for partial matches
for pdf_file in self.pdf_dir.glob(f"*{doc_id}*.pdf"):
return pdf_file
return None
def get_row_by_id(self, doc_id: str) -> InvoiceRow | None:
"""Get a specific row by DocumentId."""
for row in self.iter_rows():
if row.DocumentId == doc_id:
return row
return None
def validate(self) -> list[dict]:
"""
Validate CSV data and return issues.
Returns:
List of validation issues
"""
issues = []
for i, row in enumerate(self.iter_rows(), start=2): # Start at 2 (header is row 1)
# Check required DocumentId
if not row.DocumentId:
issues.append({
'row': i,
'field': 'DocumentId',
'issue': 'Missing required DocumentId'
})
continue
# Check if PDF exists
pdf_path = self.get_pdf_path(row)
if not pdf_path:
issues.append({
'row': i,
'doc_id': row.DocumentId,
'field': 'PDF',
'issue': 'PDF file not found'
})
# Check for at least one matchable field
matchable_fields = [
row.InvoiceNumber,
row.OCR,
row.Bankgiro,
row.Plusgiro,
row.Amount
]
if not any(matchable_fields):
issues.append({
'row': i,
'doc_id': row.DocumentId,
'field': 'All',
'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount)'
})
return issues
def load_invoice_csv(csv_path: str | Path, pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
"""Convenience function to load invoice CSV."""
loader = CSVLoader(csv_path, pdf_dir)
return loader.load_all()

src/inference/__init__.py

@@ -0,0 +1,5 @@
from .pipeline import InferencePipeline, InferenceResult
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor
__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']

src/inference/field_extractor.py

@@ -0,0 +1,382 @@
"""
Field Extractor Module
Extracts and validates field values from detected regions.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import re
import numpy as np
from PIL import Image
from .yolo_detector import Detection, CLASS_TO_FIELD
@dataclass
class ExtractedField:
"""Represents an extracted field value."""
field_name: str
raw_text: str
normalized_value: str | None
confidence: float
detection_confidence: float
ocr_confidence: float
bbox: tuple[float, float, float, float]
page_no: int
is_valid: bool = True
validation_error: str | None = None
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
'field_name': self.field_name,
'value': self.normalized_value,
'raw_text': self.raw_text,
'confidence': self.confidence,
'bbox': list(self.bbox),
'page_no': self.page_no,
'is_valid': self.is_valid,
'validation_error': self.validation_error
}
class FieldExtractor:
"""Extracts field values from detected regions using OCR or PDF text."""
def __init__(
self,
ocr_lang: str = 'en',
use_gpu: bool = False,
bbox_padding: float = 0.1,
dpi: int = 300
):
"""
Initialize field extractor.
Args:
ocr_lang: Language for OCR
use_gpu: Whether to use GPU for OCR
bbox_padding: Padding to add around bboxes (as fraction)
dpi: DPI used for rendering (for coordinate conversion)
"""
self.ocr_lang = ocr_lang
self.use_gpu = use_gpu
self.bbox_padding = bbox_padding
self.dpi = dpi
self._ocr_engine = None # Lazy init
@property
def ocr_engine(self):
"""Lazy-load OCR engine only when needed."""
if self._ocr_engine is None:
from ..ocr import OCREngine
self._ocr_engine = OCREngine(lang=self.ocr_lang, use_gpu=self.use_gpu)
return self._ocr_engine
def extract_from_detection_with_pdf(
self,
detection: Detection,
pdf_tokens: list,
image_width: int,
image_height: int
) -> ExtractedField:
"""
Extract field value using PDF text tokens (faster and more accurate for text PDFs).
Args:
detection: Detection object with bbox in pixel coordinates
pdf_tokens: List of Token objects from PDF text extraction
image_width: Width of rendered image in pixels
image_height: Height of rendered image in pixels
Returns:
ExtractedField object
"""
# Convert detection bbox from pixels to PDF points
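        # PDF text coordinates are in points (72 per inch), while the detection
        # bbox is in pixels rendered at `self.dpi`; hence the 72 / dpi scale.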
scale = 72 / self.dpi # points per pixel
x0_pdf = detection.bbox[0] * scale
y0_pdf = detection.bbox[1] * scale
x1_pdf = detection.bbox[2] * scale
y1_pdf = detection.bbox[3] * scale
# Add padding in points
pad = 3 # Small padding in points
# Find tokens that overlap with detection bbox
matching_tokens = []
for token in pdf_tokens:
if token.page_no != detection.page_no:
continue
tx0, ty0, tx1, ty1 = token.bbox
# Check overlap
if (tx0 < x1_pdf + pad and tx1 > x0_pdf - pad and
ty0 < y1_pdf + pad and ty1 > y0_pdf - pad):
# Calculate overlap ratio to prioritize better matches
overlap_x = min(tx1, x1_pdf) - max(tx0, x0_pdf)
overlap_y = min(ty1, y1_pdf) - max(ty0, y0_pdf)
if overlap_x > 0 and overlap_y > 0:
token_area = (tx1 - tx0) * (ty1 - ty0)
overlap_area = overlap_x * overlap_y
overlap_ratio = overlap_area / token_area if token_area > 0 else 0
matching_tokens.append((token, overlap_ratio))
# Sort by overlap ratio and combine text
matching_tokens.sort(key=lambda x: -x[1])
raw_text = ' '.join(t[0].text for t in matching_tokens)
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text
)
return ExtractedField(
field_name=field_name,
raw_text=raw_text,
normalized_value=normalized_value,
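            # Keep the detector's confidence when normalization succeeded; halve it otherwise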
confidence=detection.confidence if normalized_value else detection.confidence * 0.5,
detection_confidence=detection.confidence,
ocr_confidence=1.0, # PDF text is always accurate
bbox=detection.bbox,
page_no=detection.page_no,
is_valid=is_valid,
validation_error=validation_error
)
def extract_from_detection(
self,
detection: Detection,
image: np.ndarray | Image.Image
) -> ExtractedField:
"""
Extract field value from a detection region using OCR.
Args:
detection: Detection object
image: Full page image
Returns:
ExtractedField object
"""
if isinstance(image, Image.Image):
image = np.array(image)
# Get padded bbox
h, w = image.shape[:2]
bbox = detection.get_padded_bbox(self.bbox_padding, w, h)
# Crop region
x0, y0, x1, y1 = [int(v) for v in bbox]
region = image[y0:y1, x0:x1]
# Run OCR on region
ocr_tokens = self.ocr_engine.extract_from_image(region)
# Combine all OCR text
raw_text = ' '.join(t.text for t in ocr_tokens)
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text
)
# Combined confidence
confidence = (detection.confidence + ocr_confidence) / 2 if ocr_tokens else detection.confidence * 0.5
return ExtractedField(
field_name=field_name,
raw_text=raw_text,
normalized_value=normalized_value,
confidence=confidence,
detection_confidence=detection.confidence,
ocr_confidence=ocr_confidence,
bbox=detection.bbox,
page_no=detection.page_no,
is_valid=is_valid,
validation_error=validation_error
)
def _normalize_and_validate(
self,
field_name: str,
raw_text: str
) -> tuple[str | None, bool, str | None]:
"""
Normalize and validate extracted text for a field.
Returns:
(normalized_value, is_valid, validation_error)
"""
text = raw_text.strip()
if not text:
return None, False, "Empty text"
if field_name == 'InvoiceNumber':
return self._normalize_invoice_number(text)
elif field_name == 'OCR':
return self._normalize_ocr_number(text)
elif field_name == 'Bankgiro':
return self._normalize_bankgiro(text)
elif field_name == 'Plusgiro':
return self._normalize_plusgiro(text)
elif field_name == 'Amount':
return self._normalize_amount(text)
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
return self._normalize_date(text)
else:
return text, True, None
def _normalize_invoice_number(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize invoice number."""
# Extract digits only
digits = re.sub(r'\D', '', text)
if len(digits) < 3:
return None, False, f"Too few digits: {len(digits)}"
return digits, True, None
def _normalize_ocr_number(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize OCR number."""
digits = re.sub(r'\D', '', text)
if len(digits) < 5:
return None, False, f"Too few digits for OCR: {len(digits)}"
return digits, True, None
def _normalize_bankgiro(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize Bankgiro number."""
digits = re.sub(r'\D', '', text)
if len(digits) == 8:
# Format as XXXX-XXXX
formatted = f"{digits[:4]}-{digits[4:]}"
return formatted, True, None
elif len(digits) == 7:
# Format as XXX-XXXX
formatted = f"{digits[:3]}-{digits[3:]}"
return formatted, True, None
elif 6 <= len(digits) <= 9:
return digits, True, None
else:
return None, False, f"Invalid Bankgiro length: {len(digits)}"
def _normalize_plusgiro(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize Plusgiro number."""
digits = re.sub(r'\D', '', text)
if len(digits) >= 6:
# Format as XXXXXXX-X
formatted = f"{digits[:-1]}-{digits[-1]}"
return formatted, True, None
else:
return None, False, f"Invalid Plusgiro length: {len(digits)}"
def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize monetary amount."""
# Remove currency and common suffixes
        text = re.sub(r'SEK|kr|:-', '', text, flags=re.IGNORECASE)
text = text.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator
if ',' in text and '.' not in text:
text = text.replace(',', '.')
# Try to parse as float
try:
amount = float(text)
return f"{amount:.2f}", True, None
except ValueError:
return None, False, f"Cannot parse amount: {text}"
def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize date."""
from datetime import datetime
# Common date patterns
patterns = [
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m[1]}-{int(m[2]):02d}-{int(m[3]):02d}"),
(r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"),
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"),
(r'(\d{4})(\d{2})(\d{2})', lambda m: f"{m[1]}-{m[2]}-{m[3]}"),
]
for pattern, formatter in patterns:
match = re.search(pattern, text)
if match:
try:
date_str = formatter(match)
# Validate date
datetime.strptime(date_str, '%Y-%m-%d')
return date_str, True, None
except ValueError:
continue
return None, False, f"Cannot parse date: {text}"
def extract_all_fields(
self,
detections: list[Detection],
image: np.ndarray | Image.Image
) -> list[ExtractedField]:
"""
Extract fields from all detections.
Args:
detections: List of detections
image: Full page image
Returns:
List of ExtractedField objects
"""
fields = []
for detection in detections:
field = self.extract_from_detection(detection, image)
fields.append(field)
return fields
@staticmethod
def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]:
"""
Infer OCR field from InvoiceNumber if not detected.
In Swedish invoices, OCR reference number is often identical to InvoiceNumber.
When OCR is not detected but InvoiceNumber is, we can infer OCR value.
Args:
fields: Dict of field_name -> normalized_value
Returns:
Updated fields dict with inferred OCR if applicable
"""
# If OCR already exists, no need to infer
if fields.get('OCR'):
return fields
# If InvoiceNumber exists and is numeric, use it as OCR
invoice_number = fields.get('InvoiceNumber')
if invoice_number:
# Check if it's mostly digits (valid OCR reference)
digits_only = re.sub(r'\D', '', invoice_number)
if len(digits_only) >= 5 and len(digits_only) == len(invoice_number):
fields['OCR'] = invoice_number
return fields
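
The OCR-inference rule above is easiest to see on concrete values. The snippet below restates it as a standalone function (a sketch that mirrors `infer_ocr_from_invoice_number` rather than importing the class, so it runs without the OCR dependencies); the example values are the ones used in the normalizer docstrings.

```python
import re

def infer_ocr(fields: dict[str, str]) -> dict[str, str]:
    """Mirror of FieldExtractor.infer_ocr_from_invoice_number, for illustration."""
    if fields.get('OCR'):
        return fields
    invoice_number = fields.get('InvoiceNumber')
    if invoice_number:
        digits_only = re.sub(r'\D', '', invoice_number)
        # Only purely numeric invoice numbers with at least 5 digits qualify
        if len(digits_only) >= 5 and len(digits_only) == len(invoice_number):
            fields['OCR'] = invoice_number
    return fields

print(infer_ocr({'InvoiceNumber': '100017500321'}))
# -> {'InvoiceNumber': '100017500321', 'OCR': '100017500321'}
print(infer_ocr({'InvoiceNumber': 'INV-100017500321'}))
# -> unchanged: the value is not purely numeric, so no OCR is inferred
```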

297
src/inference/pipeline.py Normal file
View File

@@ -0,0 +1,297 @@
"""
Inference Pipeline
Complete pipeline for extracting invoice data from PDFs.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import time
import re
from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
from .field_extractor import FieldExtractor, ExtractedField
@dataclass
class InferenceResult:
"""Result of invoice processing."""
document_id: str | None = None
success: bool = False
fields: dict[str, Any] = field(default_factory=dict)
confidence: dict[str, float] = field(default_factory=dict)
raw_detections: list[Detection] = field(default_factory=list)
extracted_fields: list[ExtractedField] = field(default_factory=list)
processing_time_ms: float = 0.0
errors: list[str] = field(default_factory=list)
fallback_used: bool = False
def to_json(self) -> dict:
"""Convert to JSON-serializable dictionary."""
return {
'DocumentId': self.document_id,
'InvoiceNumber': self.fields.get('InvoiceNumber'),
'InvoiceDate': self.fields.get('InvoiceDate'),
'InvoiceDueDate': self.fields.get('InvoiceDueDate'),
'OCR': self.fields.get('OCR'),
'Bankgiro': self.fields.get('Bankgiro'),
'Plusgiro': self.fields.get('Plusgiro'),
'Amount': self.fields.get('Amount'),
'confidence': self.confidence,
'success': self.success,
'fallback_used': self.fallback_used
}
def get_field(self, field_name: str) -> tuple[Any, float]:
"""Get field value and confidence."""
return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
class InferencePipeline:
"""
Complete inference pipeline for invoice data extraction.
Pipeline flow:
1. PDF -> Image rendering
2. YOLO detection of field regions
3. OCR extraction from detected regions
4. Field normalization and validation
5. Fallback to full-page OCR if YOLO fails
"""
def __init__(
self,
model_path: str | Path,
confidence_threshold: float = 0.5,
ocr_lang: str = 'en',
use_gpu: bool = False,
dpi: int = 300,
enable_fallback: bool = True
):
"""
Initialize inference pipeline.
Args:
model_path: Path to trained YOLO model
confidence_threshold: Detection confidence threshold
ocr_lang: Language for OCR
use_gpu: Whether to use GPU
dpi: Resolution for PDF rendering
enable_fallback: Enable fallback to full-page OCR
"""
self.detector = YOLODetector(
model_path,
confidence_threshold=confidence_threshold,
device='cuda' if use_gpu else 'cpu'
)
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
self.dpi = dpi
self.enable_fallback = enable_fallback
def process_pdf(
self,
pdf_path: str | Path,
document_id: str | None = None
) -> InferenceResult:
"""
Process a PDF and extract invoice fields.
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
Returns:
InferenceResult with extracted fields
"""
from ..pdf.renderer import render_pdf_to_images
from PIL import Image
import io
import numpy as np
start_time = time.time()
result = InferenceResult(
document_id=document_id or Path(pdf_path).stem
)
try:
all_detections = []
all_extracted = []
# Process each page
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
# Convert to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Run YOLO detection
detections = self.detector.detect(image_array, page_no=page_no)
all_detections.extend(detections)
# Extract fields from detections
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
all_extracted.append(extracted)
result.raw_detections = all_detections
result.extracted_fields = all_extracted
# Merge extracted fields (prefer highest confidence)
self._merge_fields(result)
# Fallback if key fields are missing
if self.enable_fallback and self._needs_fallback(result):
self._run_fallback(pdf_path, result)
result.success = len(result.fields) > 0
except Exception as e:
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _merge_fields(self, result: InferenceResult) -> None:
"""Merge extracted fields, keeping highest confidence for each field."""
field_candidates: dict[str, list[ExtractedField]] = {}
for extracted in result.extracted_fields:
if not extracted.is_valid or not extracted.normalized_value:
continue
if extracted.field_name not in field_candidates:
field_candidates[extracted.field_name] = []
field_candidates[extracted.field_name].append(extracted)
# Select best candidate for each field
for field_name, candidates in field_candidates.items():
best = max(candidates, key=lambda x: x.confidence)
result.fields[field_name] = best.normalized_value
result.confidence[field_name] = best.confidence
def _needs_fallback(self, result: InferenceResult) -> bool:
"""Check if fallback OCR is needed."""
# Check for key fields
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
missing = sum(1 for f in key_fields if f not in result.fields)
return missing >= 2 # Fallback if 2+ key fields missing
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
"""Run full-page OCR fallback."""
from ..pdf.renderer import render_pdf_to_images
from ..ocr import OCREngine
from PIL import Image
import io
import numpy as np
result.fallback_used = True
ocr_engine = OCREngine()
try:
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
# Full page OCR
tokens = ocr_engine.extract_from_image(image_array, page_no)
full_text = ' '.join(t.text for t in tokens)
# Try to extract missing fields with regex patterns
self._extract_with_patterns(full_text, result)
except Exception as e:
result.errors.append(f"Fallback OCR error: {e}")
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
"""Extract fields using regex patterns (fallback)."""
patterns = {
'Amount': [
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
],
'Bankgiro': [
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
],
'OCR': [
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
],
'InvoiceNumber': [
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
],
}
for field_name, field_patterns in patterns.items():
if field_name in result.fields:
continue
for pattern in field_patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
value = match.group(1).strip()
# Normalize the value
if field_name == 'Amount':
value = value.replace(' ', '').replace(',', '.')
try:
value = f"{float(value):.2f}"
except ValueError:
continue
elif field_name == 'Bankgiro':
digits = re.sub(r'\D', '', value)
if len(digits) == 8:
value = f"{digits[:4]}-{digits[4:]}"
result.fields[field_name] = value
result.confidence[field_name] = 0.5 # Lower confidence for regex
break
def process_image(
self,
image_path: str | Path,
document_id: str | None = None
) -> InferenceResult:
"""
Process a single image (for pre-rendered pages).
Args:
image_path: Path to image file
document_id: Optional document ID
Returns:
InferenceResult with extracted fields
"""
from PIL import Image
import numpy as np
start_time = time.time()
result = InferenceResult(
document_id=document_id or Path(image_path).stem
)
try:
image = Image.open(image_path)
image_array = np.array(image)
# Run detection
detections = self.detector.detect(image_array, page_no=0)
result.raw_detections = detections
# Extract fields
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
result.extracted_fields.append(extracted)
# Merge fields
self._merge_fields(result)
result.success = len(result.fields) > 0
except Exception as e:
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
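
A minimal end-to-end usage sketch for the pipeline follows; the model and PDF paths are placeholders, not files shipped with this commit.

```python
import json

from src.inference.pipeline import InferencePipeline

pipeline = InferencePipeline(
    model_path="models/weights/best.pt",   # hypothetical path to a trained model
    confidence_threshold=0.5,
    use_gpu=True,
    dpi=300,
)

result = pipeline.process_pdf("data/raw_pdfs/example.pdf")  # hypothetical input
print(json.dumps(result.to_json(), indent=2, ensure_ascii=False))

value, conf = result.get_field("Amount")
print(f"Amount={value} confidence={conf:.2f} fallback={result.fallback_used}")
```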

204
src/inference/yolo_detector.py Normal file
View File

@@ -0,0 +1,204 @@
"""
YOLO Detection Module
Runs YOLO model inference for field detection.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np
@dataclass
class Detection:
"""Represents a single YOLO detection."""
class_id: int
class_name: str
confidence: float
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) in pixels
page_no: int = 0
@property
def x0(self) -> float:
return self.bbox[0]
@property
def y0(self) -> float:
return self.bbox[1]
@property
def x1(self) -> float:
return self.bbox[2]
@property
def y1(self) -> float:
return self.bbox[3]
@property
def center(self) -> tuple[float, float]:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
@property
def width(self) -> float:
return self.x1 - self.x0
@property
def height(self) -> float:
return self.y1 - self.y0
def get_padded_bbox(
self,
padding: float = 0.1,
image_width: float | None = None,
image_height: float | None = None
) -> tuple[float, float, float, float]:
"""Get bbox with padding for OCR extraction."""
pad_x = self.width * padding
pad_y = self.height * padding
x0 = self.x0 - pad_x
y0 = self.y0 - pad_y
x1 = self.x1 + pad_x
y1 = self.y1 + pad_y
if image_width:
x0 = max(0, x0)
x1 = min(image_width, x1)
if image_height:
y0 = max(0, y0)
y1 = min(image_height, y1)
return (x0, y0, x1, y1)
# Class names (must match training configuration)
CLASS_NAMES = [
'invoice_number',
'invoice_date',
'invoice_due_date',
'ocr_number',
'bankgiro',
'plusgiro',
'amount',
]
# Mapping from class name to field name
CLASS_TO_FIELD = {
'invoice_number': 'InvoiceNumber',
'invoice_date': 'InvoiceDate',
'invoice_due_date': 'InvoiceDueDate',
'ocr_number': 'OCR',
'bankgiro': 'Bankgiro',
'plusgiro': 'Plusgiro',
'amount': 'Amount',
}
class YOLODetector:
"""YOLO model wrapper for field detection."""
def __init__(
self,
model_path: str | Path,
confidence_threshold: float = 0.5,
iou_threshold: float = 0.45,
device: str = 'auto'
):
"""
Initialize YOLO detector.
Args:
model_path: Path to trained YOLO model (.pt file)
confidence_threshold: Minimum confidence for detections
iou_threshold: IOU threshold for NMS
device: Device to run on ('auto', 'cpu', 'cuda', 'mps')
"""
from ultralytics import YOLO
self.model = YOLO(model_path)
self.confidence_threshold = confidence_threshold
self.iou_threshold = iou_threshold
self.device = device
def detect(
self,
image: str | Path | np.ndarray,
page_no: int = 0
) -> list[Detection]:
"""
Run detection on an image.
Args:
image: Image path or numpy array
page_no: Page number for reference
Returns:
List of Detection objects
"""
results = self.model.predict(
source=image,
conf=self.confidence_threshold,
iou=self.iou_threshold,
device=self.device,
verbose=False
)
detections = []
for result in results:
boxes = result.boxes
if boxes is None:
continue
for i in range(len(boxes)):
class_id = int(boxes.cls[i])
confidence = float(boxes.conf[i])
bbox = boxes.xyxy[i].tolist() # [x0, y0, x1, y1]
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"class_{class_id}"
detections.append(Detection(
class_id=class_id,
class_name=class_name,
confidence=confidence,
bbox=tuple(bbox),
page_no=page_no
))
return detections
def detect_pdf(
self,
pdf_path: str | Path,
dpi: int = 300
) -> dict[int, list[Detection]]:
"""
Run detection on all pages of a PDF.
Args:
pdf_path: Path to PDF file
dpi: Resolution for rendering
Returns:
Dict mapping page number to list of detections
"""
from ..pdf.renderer import render_pdf_to_images
from PIL import Image
import io
results = {}
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
# Convert bytes to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
detections = self.detect(image_array, page_no=page_no)
results[page_no] = detections
return results
def get_field_name(self, class_name: str) -> str:
"""Convert class name to field name."""
return CLASS_TO_FIELD.get(class_name, class_name)
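
The detector can also be exercised on its own, for example to inspect raw detections on a single rendered page (model and image paths below are placeholders):

```python
from src.inference.yolo_detector import YOLODetector, CLASS_TO_FIELD

detector = YOLODetector("models/weights/best.pt", confidence_threshold=0.5)  # hypothetical model
detections = detector.detect("page_000.png")  # hypothetical rendered page

for det in detections:
    field = CLASS_TO_FIELD.get(det.class_name, det.class_name)
    x0, y0, x1, y1 = det.bbox
    print(f"{field}: conf={det.confidence:.2f} bbox=({x0:.0f}, {y0:.0f}, {x1:.0f}, {y1:.0f})")
```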

3
src/matcher/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .field_matcher import FieldMatcher, Match, find_field_matches
__all__ = ['FieldMatcher', 'Match', 'find_field_matches']

618
src/matcher/field_matcher.py Normal file
View File

@@ -0,0 +1,618 @@
"""
Field Matching Module
Matches normalized field values to tokens extracted from documents.
"""
from dataclasses import dataclass
from typing import Protocol
import re
class TokenLike(Protocol):
"""Protocol for token objects."""
text: str
bbox: tuple[float, float, float, float]
page_no: int
@dataclass
class Match:
"""Represents a matched field in the document."""
field: str
value: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
score: float # 0-1 confidence score
matched_text: str # Actual text that matched
context_keywords: list[str] # Nearby keywords that boosted confidence
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
"""Convert to YOLO annotation format."""
x0, y0, x1, y1 = self.bbox
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
'förfallodag', 'oss tillhanda senast', 'senast'],
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
}
class FieldMatcher:
"""Matches field values to document tokens."""
def __init__(
self,
context_radius: float = 100.0, # pixels
min_score_threshold: float = 0.5
):
"""
Initialize the matcher.
Args:
context_radius: Distance to search for context keywords
min_score_threshold: Minimum score to consider a match valid
"""
self.context_radius = context_radius
self.min_score_threshold = min_score_threshold
def find_matches(
self,
tokens: list[TokenLike],
field_name: str,
normalized_values: list[str],
page_no: int = 0
) -> list[Match]:
"""
Find all matches for a field in the token list.
Args:
tokens: List of tokens from the document
field_name: Name of the field to match
normalized_values: List of normalized value variants to search for
page_no: Page number to filter tokens
Returns:
List of Match objects sorted by score (descending)
"""
matches = []
page_tokens = [t for t in tokens if t.page_no == page_no]
for value in normalized_values:
# Strategy 1: Exact token match
exact_matches = self._find_exact_matches(page_tokens, value, field_name)
matches.extend(exact_matches)
# Strategy 2: Multi-token concatenation
concat_matches = self._find_concatenated_matches(page_tokens, value, field_name)
matches.extend(concat_matches)
# Strategy 3: Fuzzy match (for amounts and dates only)
if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name)
matches.extend(fuzzy_matches)
# Strategy 4: Substring match (for dates embedded in longer text)
if field_name in ('InvoiceDate', 'InvoiceDueDate'):
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
matches.extend(substring_matches)
# Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
# Only if no exact matches found for date fields
if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
flexible_matches = self._find_flexible_date_matches(
page_tokens, normalized_values, field_name
)
matches.extend(flexible_matches)
# Deduplicate and sort by score
matches = self._deduplicate_matches(matches)
matches.sort(key=lambda m: m.score, reverse=True)
return [m for m in matches if m.score >= self.min_score_threshold]
def _find_exact_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find tokens that exactly match the value."""
matches = []
for token in tokens:
token_text = token.text.strip()
# Exact match
if token_text == value:
score = 1.0
# Case-insensitive match
elif token_text.lower() == value.lower():
score = 0.95
# Digits-only match for numeric fields
elif field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'):
token_digits = re.sub(r'\D', '', token_text)
value_digits = re.sub(r'\D', '', value)
if token_digits and token_digits == value_digits:
score = 0.9
else:
continue
else:
continue
# Boost score if context keywords are nearby
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
score = min(1.0, score + context_boost)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=score,
matched_text=token_text,
context_keywords=context_keywords
))
return matches
def _find_concatenated_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find value by concatenating adjacent tokens."""
matches = []
value_clean = re.sub(r'\s+', '', value)
# Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
for i, start_token in enumerate(sorted_tokens):
# Try to build the value by concatenating nearby tokens
concat_text = start_token.text.strip()
concat_bbox = list(start_token.bbox)
used_tokens = [start_token]
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
next_token = sorted_tokens[j]
# Check if tokens are on the same line (y overlap)
if not self._tokens_on_same_line(start_token, next_token):
break
# Check horizontal proximity
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
break
concat_text += next_token.text.strip()
used_tokens.append(next_token)
# Update bounding box
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match
concat_clean = re.sub(r'\s+', '', concat_text)
if concat_clean == value_clean:
context_keywords, context_boost = self._find_context_keywords(
tokens, start_token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=tuple(concat_bbox),
page_no=start_token.page_no,
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
matched_text=concat_text,
context_keywords=context_keywords
))
break
return matches
def _find_substring_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""
Find value as a substring within longer tokens.
Handles cases like 'Fakturadatum: 2026-01-09' where the date
is embedded in a longer text string.
Uses lower score (0.75) than exact match to prefer exact matches.
Only matches if the value appears as a distinct segment (not part of a number).
"""
matches = []
# Only use for date fields - other fields risk false positives
if field_name not in ('InvoiceDate', 'InvoiceDueDate'):
return matches
for token in tokens:
token_text = token.text.strip()
# Skip if token is the same length as value (would be exact match)
if len(token_text) <= len(value):
continue
# Check if value appears as substring
if value in token_text:
# Verify it's a proper boundary match (not part of a larger number)
idx = token_text.find(value)
# Check character before (if exists)
if idx > 0:
char_before = token_text[idx - 1]
# Must be non-digit (allow : space - etc)
if char_before.isdigit():
continue
# Check character after (if exists)
end_idx = idx + len(value)
if end_idx < len(token_text):
char_after = token_text[end_idx]
# Must be non-digit
if char_after.isdigit():
continue
# Found valid substring match
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if context keyword is in the same token (like "Fakturadatum:")
token_lower = token_text.lower()
inline_context = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_context.append(keyword)
# Boost score if keyword is inline
inline_boost = 0.1 if inline_context else 0
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox, # Use full token bbox
page_no=token.page_no,
score=min(1.0, 0.75 + context_boost + inline_boost), # Lower than exact match
matched_text=token_text,
context_keywords=context_keywords + inline_context
))
return matches
def _find_fuzzy_matches(
self,
tokens: list[TokenLike],
value: str,
field_name: str
) -> list[Match]:
"""Find approximate matches for amounts and dates."""
matches = []
for token in tokens:
token_text = token.text.strip()
if field_name == 'Amount':
# Try to parse both as numbers
try:
token_num = self._parse_amount(token_text)
value_num = self._parse_amount(value)
if token_num is not None and value_num is not None:
if abs(token_num - value_num) < 0.01: # Within 1 cent
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
matches.append(Match(
field=field_name,
value=value,
bbox=token.bbox,
page_no=token.page_no,
score=min(1.0, 0.8 + context_boost),
matched_text=token_text,
context_keywords=context_keywords
))
except (ValueError, TypeError):
pass
return matches
def _find_flexible_date_matches(
self,
tokens: list[TokenLike],
normalized_values: list[str],
field_name: str
) -> list[Match]:
"""
Flexible date matching when exact match fails.
Strategies:
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
2. Nearby date match: Match dates within 7 days of CSV value
3. Heuristic selection: Use context keywords to select the best date
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
but we can still find a reasonable date to label.
"""
from datetime import datetime, timedelta
matches = []
# Parse the target date from normalized values
target_date = None
for value in normalized_values:
# Try to parse YYYY-MM-DD format
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
if date_match:
try:
target_date = datetime(
int(date_match.group(1)),
int(date_match.group(2)),
int(date_match.group(3))
)
break
except ValueError:
continue
if not target_date:
return matches
# Find all date-like tokens in the document
date_candidates = []
date_pattern = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
for token in tokens:
token_text = token.text.strip()
# Search for date pattern in token
for match in date_pattern.finditer(token_text):
try:
found_date = datetime(
int(match.group(1)),
int(match.group(2)),
int(match.group(3))
)
date_str = match.group(0)
# Calculate date difference
days_diff = abs((found_date - target_date).days)
# Check for context keywords
context_keywords, context_boost = self._find_context_keywords(
tokens, token, field_name
)
# Check if keyword is in the same token
token_lower = token_text.lower()
inline_keywords = []
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
if keyword in token_lower:
inline_keywords.append(keyword)
date_candidates.append({
'token': token,
'date': found_date,
'date_str': date_str,
'matched_text': token_text,
'days_diff': days_diff,
'context_keywords': context_keywords + inline_keywords,
'context_boost': context_boost + (0.1 if inline_keywords else 0),
'same_year_month': (found_date.year == target_date.year and
found_date.month == target_date.month),
})
except ValueError:
continue
if not date_candidates:
return matches
# Score and rank candidates
for candidate in date_candidates:
score = 0.0
# Strategy 1: Same year-month gets higher score
if candidate['same_year_month']:
score = 0.7
# Bonus if day is close
if candidate['days_diff'] <= 7:
score = 0.75
if candidate['days_diff'] <= 3:
score = 0.8
# Strategy 2: Nearby dates (within 14 days)
elif candidate['days_diff'] <= 14:
score = 0.6
elif candidate['days_diff'] <= 30:
score = 0.55
else:
# Too far apart, skip unless has strong context
if not candidate['context_keywords']:
continue
score = 0.5
# Strategy 3: Boost with context keywords
score = min(1.0, score + candidate['context_boost'])
# For InvoiceDate, prefer dates that appear near invoice-related keywords
# For InvoiceDueDate, prefer dates near due-date keywords
if candidate['context_keywords']:
score = min(1.0, score + 0.05)
if score >= self.min_score_threshold:
matches.append(Match(
field=field_name,
value=candidate['date_str'],
bbox=candidate['token'].bbox,
page_no=candidate['token'].page_no,
score=score,
matched_text=candidate['matched_text'],
context_keywords=candidate['context_keywords']
))
# Sort by score and return best matches
matches.sort(key=lambda m: m.score, reverse=True)
# Only return the best match to avoid multiple labels for same field
return matches[:1] if matches else []
def _find_context_keywords(
self,
tokens: list[TokenLike],
target_token: TokenLike,
field_name: str
) -> tuple[list[str], float]:
"""Find context keywords near the target token."""
keywords = CONTEXT_KEYWORDS.get(field_name, [])
found_keywords = []
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
# Calculate distance
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= self.context_radius:
token_lower = token.text.lower()
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
# Calculate boost based on keywords found
boost = min(0.15, len(found_keywords) * 0.05)
return found_keywords, boost
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
"""Check if two tokens are on the same line."""
# Check vertical overlap
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
return y_overlap > min_height * 0.5
def _parse_amount(self, text: str) -> float | None:
"""Try to parse text as a monetary amount."""
# Remove currency and spaces
text = re.sub(r'(?:SEK|kr|:-)+', '', text, flags=re.IGNORECASE)
text = text.replace(' ', '').replace('\xa0', '')
# Try comma as decimal separator
if ',' in text and '.' not in text:
text = text.replace(',', '.')
try:
return float(text)
except ValueError:
return None
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
"""Remove duplicate matches based on bbox overlap."""
if not matches:
return []
# Sort by score descending
matches.sort(key=lambda m: m.score, reverse=True)
unique = []
for match in matches:
is_duplicate = False
for existing in unique:
if self._bbox_overlap(match.bbox, existing.bbox) > 0.7:
is_duplicate = True
break
if not is_duplicate:
unique.append(match)
return unique
def _bbox_overlap(
self,
bbox1: tuple[float, float, float, float],
bbox2: tuple[float, float, float, float]
) -> float:
"""Calculate IoU (Intersection over Union) of two bounding boxes."""
x1 = max(bbox1[0], bbox2[0])
y1 = max(bbox1[1], bbox2[1])
x2 = min(bbox1[2], bbox2[2])
y2 = min(bbox1[3], bbox2[3])
if x2 <= x1 or y2 <= y1:
return 0.0
intersection = (x2 - x1) * (y2 - y1)
area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
def find_field_matches(
tokens: list[TokenLike],
field_values: dict[str, str],
page_no: int = 0
) -> dict[str, list[Match]]:
"""
Convenience function to find matches for multiple fields.
Args:
tokens: List of tokens from the document
field_values: Dict of field_name -> value to search for
page_no: Page number
Returns:
Dict of field_name -> list of matches
"""
from ..normalize import normalize_field
matcher = FieldMatcher()
results = {}
for field_name, value in field_values.items():
if value is None or str(value).strip() == '':
continue
normalized_values = normalize_field(field_name, str(value))
matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
results[field_name] = matches
return results
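
A sketch of how the matcher is driven during auto-labeling: extract tokens from a text-layer PDF, then look up the CSV values for that document. The PDF path is a placeholder; the field values are the examples used in the normalizer docstrings.

```python
from src.pdf.extractor import extract_words
from src.matcher import find_field_matches

tokens = list(extract_words("data/raw_pdfs/example.pdf", page_no=0))  # hypothetical PDF
field_values = {
    "InvoiceNumber": "100017500321",
    "Amount": "114",
    "Bankgiro": "5393-9484",
}

matches = find_field_matches(tokens, field_values, page_no=0)
for field_name, found in matches.items():
    if found:
        best = found[0]  # already sorted by score
        print(f"{field_name}: {best.matched_text!r} score={best.score:.2f} bbox={best.bbox}")
```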

3
src/normalize/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .normalizer import normalize_field, FieldNormalizer
__all__ = ['normalize_field', 'FieldNormalizer']

290
src/normalize/normalizer.py Normal file
View File

@@ -0,0 +1,290 @@
"""
Field Normalization Module
Normalizes field values to generate multiple candidate forms for matching.
"""
import re
from dataclasses import dataclass
from datetime import datetime
from typing import Callable
@dataclass
class NormalizedValue:
"""Represents a normalized value with its variants."""
original: str
variants: list[str]
field_type: str
class FieldNormalizer:
"""Handles normalization of different invoice field types."""
# Common Swedish month names for date parsing
SWEDISH_MONTHS = {
'januari': '01', 'jan': '01',
'februari': '02', 'feb': '02',
'mars': '03', 'mar': '03',
'april': '04', 'apr': '04',
'maj': '05',
'juni': '06', 'jun': '06',
'juli': '07', 'jul': '07',
'augusti': '08', 'aug': '08',
'september': '09', 'sep': '09', 'sept': '09',
'oktober': '10', 'okt': '10',
'november': '11', 'nov': '11',
'december': '12', 'dec': '12'
}
@staticmethod
def clean_text(text: str) -> str:
"""Remove invisible characters and normalize whitespace."""
# Remove zero-width characters
text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text)
# Normalize whitespace
text = ' '.join(text.split())
return text.strip()
@staticmethod
def normalize_invoice_number(value: str) -> list[str]:
"""
Normalize invoice number.
Keeps only digits for matching.
Examples:
'100017500321' -> ['100017500321']
'INV-100017500321' -> ['100017500321', 'INV-100017500321']
"""
value = FieldNormalizer.clean_text(value)
digits_only = re.sub(r'\D', '', value)
variants = [value]
if digits_only and digits_only != value:
variants.append(digits_only)
return list(set(v for v in variants if v))
@staticmethod
def normalize_ocr_number(value: str) -> list[str]:
"""
Normalize OCR number (Swedish payment reference).
Similar to invoice number - digits only.
"""
return FieldNormalizer.normalize_invoice_number(value)
@staticmethod
def normalize_bankgiro(value: str) -> list[str]:
"""
Normalize Bankgiro number.
Examples:
'5393-9484' -> ['5393-9484', '53939484']
'53939484' -> ['53939484', '5393-9484']
"""
value = FieldNormalizer.clean_text(value)
digits_only = re.sub(r'\D', '', value)
variants = [value]
if digits_only:
# Add without dash
variants.append(digits_only)
# Add with dash (format: XXXX-XXXX for 8 digits)
if len(digits_only) == 8:
with_dash = f"{digits_only[:4]}-{digits_only[4:]}"
variants.append(with_dash)
elif len(digits_only) == 7:
# Some bankgiro numbers are 7 digits: XXX-XXXX
with_dash = f"{digits_only[:3]}-{digits_only[3:]}"
variants.append(with_dash)
return list(set(v for v in variants if v))
@staticmethod
def normalize_plusgiro(value: str) -> list[str]:
"""
Normalize Plusgiro number.
Examples:
'1234567-8' -> ['1234567-8', '12345678']
'12345678' -> ['12345678', '1234567-8']
"""
value = FieldNormalizer.clean_text(value)
digits_only = re.sub(r'\D', '', value)
variants = [value]
if digits_only:
variants.append(digits_only)
# Plusgiro format: XXXXXXX-X (7 digits + check digit)
if len(digits_only) == 8:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
# Also handle 6+1 format
elif len(digits_only) == 7:
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
variants.append(with_dash)
return list(set(v for v in variants if v))
@staticmethod
def normalize_amount(value: str) -> list[str]:
"""
Normalize monetary amount.
Examples:
'114' -> ['114', '114,00', '114.00']
'114,00' -> ['114,00', '114.00', '114']
'1 234,56' -> ['1234,56', '1234.56', '1 234,56']
"""
value = FieldNormalizer.clean_text(value)
# Remove currency symbols and common suffixes
value = re.sub(r'(?:SEK|kr|:-)+$', '', value, flags=re.IGNORECASE).strip()
# Remove spaces (thousand separators)
no_space = value.replace(' ', '').replace('\xa0', '')
variants = [value]
# Normalize decimal separator
if ',' in no_space:
dot_version = no_space.replace(',', '.')
variants.append(no_space)
variants.append(dot_version)
elif '.' in no_space:
comma_version = no_space.replace('.', ',')
variants.append(no_space)
variants.append(comma_version)
else:
# Integer amount - add decimal versions
variants.append(no_space)
variants.append(f"{no_space},00")
variants.append(f"{no_space}.00")
# Try to parse and get clean numeric value
try:
# Parse as float
clean = no_space.replace(',', '.')
num = float(clean)
# Integer if no decimals
if num == int(num):
variants.append(str(int(num)))
variants.append(f"{int(num)},00")
variants.append(f"{int(num)}.00")
else:
variants.append(f"{num:.2f}")
variants.append(f"{num:.2f}".replace('.', ','))
except ValueError:
pass
return list(set(v for v in variants if v))
@staticmethod
def normalize_date(value: str) -> list[str]:
"""
Normalize date to YYYY-MM-DD and generate variants.
Handles:
'2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
'13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
'13 december 2025' -> ['2025-12-13', ...]
"""
value = FieldNormalizer.clean_text(value)
variants = [value]
parsed_date = None
# Try different date formats
date_patterns = [
# ISO format with optional time (e.g., 2026-01-09 00:00:00)
(r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
# European format with /
(r'^(\d{1,2})/(\d{1,2})/(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
# European format with .
(r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
# European format with -
(r'^(\d{1,2})-(\d{1,2})-(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
# Swedish format: YYMMDD
(r'^(\d{2})(\d{2})(\d{2})$', lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))),
# Swedish format: YYYYMMDD
(r'^(\d{4})(\d{2})(\d{2})$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
]
for pattern, extractor in date_patterns:
match = re.match(pattern, value)
if match:
try:
year, month, day = extractor(match)
parsed_date = datetime(year, month, day)
break
except ValueError:
continue
# Try Swedish month names
if not parsed_date:
for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items():
if month_name in value.lower():
# Extract day and year
numbers = re.findall(r'\d+', value)
if len(numbers) >= 2:
day = int(numbers[0])
year = int(numbers[-1])
if year < 100:
year = 2000 + year if year < 50 else 1900 + year
try:
parsed_date = datetime(year, int(month_num), day)
break
except ValueError:
continue
if parsed_date:
# Generate different formats
iso = parsed_date.strftime('%Y-%m-%d')
eu_slash = parsed_date.strftime('%d/%m/%Y')
eu_dot = parsed_date.strftime('%d.%m.%Y')
compact = parsed_date.strftime('%Y%m%d')
variants.extend([iso, eu_slash, eu_dot, compact])
return list(set(v for v in variants if v))
# Field type to normalizer mapping
NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
'InvoiceNumber': FieldNormalizer.normalize_invoice_number,
'OCR': FieldNormalizer.normalize_ocr_number,
'Bankgiro': FieldNormalizer.normalize_bankgiro,
'Plusgiro': FieldNormalizer.normalize_plusgiro,
'Amount': FieldNormalizer.normalize_amount,
'InvoiceDate': FieldNormalizer.normalize_date,
'InvoiceDueDate': FieldNormalizer.normalize_date,
}
def normalize_field(field_name: str, value: str) -> list[str]:
"""
Normalize a field value based on its type.
Args:
field_name: Name of the field (e.g., 'InvoiceNumber', 'Amount')
value: Raw value to normalize
Returns:
List of normalized variants
"""
if value is None or (isinstance(value, str) and not value.strip()):
return []
value = str(value)
normalizer = NORMALIZERS.get(field_name)
if normalizer:
return normalizer(value)
# Default: just clean the text
return [FieldNormalizer.clean_text(value)]
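
For reference, the kind of variant lists this produces (input values are taken from the docstrings above; the returned list comes from a set, so order varies):

```python
from src.normalize import normalize_field

print(normalize_field("Amount", "1 234,56"))
# -> ['1 234,56', '1234,56', '1234.56'] (order varies)
print(normalize_field("InvoiceDate", "13/12/2025"))
# -> ['13/12/2025', '2025-12-13', '13.12.2025', '20251213'] (order varies)
print(normalize_field("Bankgiro", "53939484"))
# -> ['53939484', '5393-9484'] (order varies)
```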

3
src/ocr/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .paddle_ocr import OCREngine, extract_ocr_tokens
__all__ = ['OCREngine', 'extract_ocr_tokens']

188
src/ocr/paddle_ocr.py Normal file
View File

@@ -0,0 +1,188 @@
"""
OCR Extraction Module using PaddleOCR
Extracts text tokens with bounding boxes from scanned PDFs.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Generator
import numpy as np
@dataclass
class OCRToken:
"""Represents an OCR-extracted text token with its bounding box."""
text: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
confidence: float
page_no: int = 0
@property
def x0(self) -> float:
return self.bbox[0]
@property
def y0(self) -> float:
return self.bbox[1]
@property
def x1(self) -> float:
return self.bbox[2]
@property
def y1(self) -> float:
return self.bbox[3]
@property
def center(self) -> tuple[float, float]:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
class OCREngine:
"""PaddleOCR wrapper for text extraction."""
def __init__(
self,
lang: str = "en",
use_gpu: bool = False,
det_model_dir: str | None = None,
rec_model_dir: str | None = None
):
"""
Initialize OCR engine.
Args:
lang: Language code ('en', 'sv', 'ch', etc.)
use_gpu: Whether to use GPU acceleration
det_model_dir: Custom detection model directory
rec_model_dir: Custom recognition model directory
"""
from paddleocr import PaddleOCR
# PaddleOCR 3.x API - simplified init
init_params = {'lang': lang}
if det_model_dir:
init_params['text_detection_model_dir'] = det_model_dir
if rec_model_dir:
init_params['text_recognition_model_dir'] = rec_model_dir
self.ocr = PaddleOCR(**init_params)
def extract_from_image(
self,
image: str | Path | np.ndarray,
page_no: int = 0
) -> list[OCRToken]:
"""
Extract text tokens from an image.
Args:
image: Image path or numpy array
page_no: Page number for reference
Returns:
List of OCRToken objects
"""
if isinstance(image, (str, Path)):
image = str(image)
# PaddleOCR 3.x uses predict() method instead of ocr()
result = self.ocr.predict(image)
tokens = []
if result:
for item in result:
# PaddleOCR 3.x returns list of dicts with 'rec_texts', 'rec_scores', 'dt_polys'
if isinstance(item, dict):
rec_texts = item.get('rec_texts', [])
rec_scores = item.get('rec_scores', [])
dt_polys = item.get('dt_polys', [])
for i, (text, score, poly) in enumerate(zip(rec_texts, rec_scores, dt_polys)):
# poly is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
x_coords = [p[0] for p in poly]
y_coords = [p[1] for p in poly]
bbox = (
min(x_coords),
min(y_coords),
max(x_coords),
max(y_coords)
)
tokens.append(OCRToken(
text=text,
bbox=bbox,
confidence=float(score),
page_no=page_no
))
elif isinstance(item, (list, tuple)) and len(item) == 2:
# Legacy format: [[bbox_points], (text, confidence)]
bbox_points, (text, confidence) = item
x_coords = [p[0] for p in bbox_points]
y_coords = [p[1] for p in bbox_points]
bbox = (
min(x_coords),
min(y_coords),
max(x_coords),
max(y_coords)
)
tokens.append(OCRToken(
text=text,
bbox=bbox,
confidence=confidence,
page_no=page_no
))
return tokens
def extract_from_pdf(
self,
pdf_path: str | Path,
dpi: int = 300
) -> Generator[list[OCRToken], None, None]:
"""
Extract text from all pages of a scanned PDF.
Args:
pdf_path: Path to the PDF file
dpi: Resolution for rendering
Yields:
List of OCRToken for each page
"""
from ..pdf.renderer import render_pdf_to_images
import io
from PIL import Image
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
# Convert bytes to numpy array
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
tokens = self.extract_from_image(image_array, page_no=page_no)
yield tokens
def extract_ocr_tokens(
image_path: str | Path,
lang: str = "en",
page_no: int = 0
) -> list[OCRToken]:
"""
Convenience function to extract OCR tokens from an image.
Args:
image_path: Path to the image file
lang: Language code
page_no: Page number for reference
Returns:
List of OCRToken objects
"""
engine = OCREngine(lang=lang)
return engine.extract_from_image(image_path, page_no=page_no)
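
Basic usage of the OCR engine on a single rendered page (the image path is a placeholder; PaddleOCR typically downloads its detection and recognition models on first use):

```python
from src.ocr import OCREngine

engine = OCREngine(lang="en")
tokens = engine.extract_from_image("page_000.png")  # hypothetical rendered page

for t in tokens[:10]:
    print(f"{t.text!r} conf={t.confidence:.2f} bbox={t.bbox}")
```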

5
src/pdf/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
from .detector import is_text_pdf, get_pdf_type
from .renderer import render_pdf_to_images
from .extractor import extract_text_tokens
__all__ = ['is_text_pdf', 'get_pdf_type', 'render_pdf_to_images', 'extract_text_tokens']

98
src/pdf/detector.py Normal file
View File

@@ -0,0 +1,98 @@
"""
PDF Type Detection Module
Automatically distinguishes between:
- Text-based PDFs (digitally generated)
- Scanned image PDFs
"""
from pathlib import Path
from typing import Literal
import fitz # PyMuPDF
PDFType = Literal["text", "scanned", "mixed"]
def extract_text_first_page(pdf_path: str | Path) -> str:
"""Extract text from the first page of a PDF."""
doc = fitz.open(pdf_path)
if len(doc) == 0:
return ""
first_page = doc[0]
text = first_page.get_text()
doc.close()
return text
def is_text_pdf(pdf_path: str | Path, min_chars: int = 30) -> bool:
"""
Check if PDF has extractable text layer.
Args:
pdf_path: Path to the PDF file
min_chars: Minimum characters to consider it a text PDF
Returns:
True if PDF has text layer, False if scanned
"""
text = extract_text_first_page(pdf_path)
return len(text.strip()) > min_chars
def get_pdf_type(pdf_path: str | Path) -> PDFType:
"""
Determine the PDF type.
Returns:
'text' - Has extractable text layer
'scanned' - Image-based, needs OCR
'mixed' - Some pages have text, some don't
"""
doc = fitz.open(pdf_path)
if len(doc) == 0:
doc.close()
return "scanned"
text_pages = 0
for page in doc:
text = page.get_text().strip()
if len(text) > 30:
text_pages += 1
total_pages = len(doc)
doc.close()
if text_pages == total_pages:
return "text"
elif text_pages == 0:
return "scanned"
else:
return "mixed"
def get_page_info(pdf_path: str | Path) -> list[dict]:
"""
Get information about each page in the PDF.
Returns:
List of dicts with page info (number, width, height, has_text)
"""
doc = fitz.open(pdf_path)
pages = []
for i, page in enumerate(doc):
text = page.get_text().strip()
rect = page.rect
pages.append({
"page_no": i,
"width": rect.width,
"height": rect.height,
"has_text": len(text) > 30,
"char_count": len(text)
})
doc.close()
return pages
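
A routing sketch that uses the type detection above to pick between the text extractor and OCR (paths are placeholders; mixed PDFs are sent down the OCR path here purely for brevity):

```python
import io

import numpy as np
from PIL import Image

from src.pdf import get_pdf_type, extract_text_tokens, render_pdf_to_images
from src.ocr import OCREngine

pdf_path = "data/raw_pdfs/example.pdf"  # hypothetical input
pdf_type = get_pdf_type(pdf_path)
print(f"{pdf_path}: {pdf_type}")

if pdf_type == "text":
    tokens = list(extract_text_tokens(pdf_path, page_no=0))
else:
    # scanned or mixed: render page 0 and run OCR on the image
    page_no, image_bytes = next(render_pdf_to_images(pdf_path, dpi=300))
    image_array = np.array(Image.open(io.BytesIO(image_bytes)))
    tokens = OCREngine().extract_from_image(image_array, page_no=page_no)

print(f"page 0: {len(tokens)} tokens")
```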

176
src/pdf/extractor.py Normal file
View File

@@ -0,0 +1,176 @@
"""
PDF Text Extraction Module
Extracts text tokens with bounding boxes from text-layer PDFs.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Generator
import fitz # PyMuPDF
@dataclass
class Token:
"""Represents a text token with its bounding box."""
text: str
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
page_no: int
@property
def x0(self) -> float:
return self.bbox[0]
@property
def y0(self) -> float:
return self.bbox[1]
@property
def x1(self) -> float:
return self.bbox[2]
@property
def y1(self) -> float:
return self.bbox[3]
@property
def width(self) -> float:
return self.x1 - self.x0
@property
def height(self) -> float:
return self.y1 - self.y0
@property
def center(self) -> tuple[float, float]:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
def extract_text_tokens(
pdf_path: str | Path,
page_no: int | None = None
) -> Generator[Token, None, None]:
"""
Extract text tokens with bounding boxes from PDF.
Args:
pdf_path: Path to the PDF file
page_no: Specific page to extract (None for all pages)
Yields:
Token objects with text and bbox
"""
doc = fitz.open(pdf_path)
pages_to_process = [page_no] if page_no is not None else range(len(doc))
for pg_no in pages_to_process:
page = doc[pg_no]
# Get text with position info using "dict" mode
text_dict = page.get_text("dict")
for block in text_dict.get("blocks", []):
if block.get("type") != 0: # Skip non-text blocks
continue
for line in block.get("lines", []):
for span in line.get("spans", []):
text = span.get("text", "").strip()
if not text:
continue
bbox = span.get("bbox")
if bbox:
yield Token(
text=text,
bbox=tuple(bbox),
page_no=pg_no
)
doc.close()
def extract_words(
pdf_path: str | Path,
page_no: int | None = None
) -> Generator[Token, None, None]:
"""
Extract individual words with bounding boxes.
Uses PyMuPDF's word extraction which splits text into words.
"""
doc = fitz.open(pdf_path)
pages_to_process = [page_no] if page_no is not None else range(len(doc))
for pg_no in pages_to_process:
page = doc[pg_no]
# get_text("words") returns list of (x0, y0, x1, y1, word, block_no, line_no, word_no)
words = page.get_text("words")
for word_info in words:
x0, y0, x1, y1, text, *_ = word_info
text = text.strip()
if text:
yield Token(
text=text,
bbox=(x0, y0, x1, y1),
page_no=pg_no
)
doc.close()
def extract_lines(
pdf_path: str | Path,
page_no: int | None = None
) -> Generator[Token, None, None]:
"""
Extract text lines with bounding boxes.
"""
doc = fitz.open(pdf_path)
pages_to_process = [page_no] if page_no is not None else range(len(doc))
for pg_no in pages_to_process:
page = doc[pg_no]
text_dict = page.get_text("dict")
for block in text_dict.get("blocks", []):
if block.get("type") != 0:
continue
for line in block.get("lines", []):
spans = line.get("spans", [])
if not spans:
continue
# Combine all spans in the line
line_text = " ".join(s.get("text", "") for s in spans).strip()
if not line_text:
continue
# Get line bbox from all spans
x0 = min(s["bbox"][0] for s in spans)
y0 = min(s["bbox"][1] for s in spans)
x1 = max(s["bbox"][2] for s in spans)
y1 = max(s["bbox"][3] for s in spans)
yield Token(
text=line_text,
bbox=(x0, y0, x1, y1),
page_no=pg_no
)
doc.close()
def get_page_dimensions(pdf_path: str | Path, page_no: int = 0) -> tuple[float, float]:
"""Get the dimensions of a PDF page in points."""
doc = fitz.open(pdf_path)
page = doc[page_no]
rect = page.rect
doc.close()
return rect.width, rect.height
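
A quick comparison of word-level and line-level extraction on one page (path is a placeholder). Note that these bboxes are in PDF points at 72 dpi, which is why the annotation generator later rescales them to rendered pixels:

```python
from src.pdf.extractor import extract_words, extract_lines, get_page_dimensions

pdf_path = "data/raw_pdfs/example.pdf"  # hypothetical input
width, height = get_page_dimensions(pdf_path, page_no=0)
print(f"page 0: {width:.0f} x {height:.0f} pt")

words = list(extract_words(pdf_path, page_no=0))
lines = list(extract_lines(pdf_path, page_no=0))
print(f"{len(words)} words, {len(lines)} lines")
for line in lines[:5]:
    print(f"y={line.y0:.0f}: {line.text[:60]!r}")
```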

117
src/pdf/renderer.py Normal file
View File

@@ -0,0 +1,117 @@
"""
PDF Rendering Module
Converts PDF pages to images for YOLO training.
"""
from pathlib import Path
from typing import Generator
import fitz # PyMuPDF
def render_pdf_to_images(
pdf_path: str | Path,
output_dir: str | Path | None = None,
dpi: int = 300,
image_format: str = "png"
) -> Generator[tuple[int, Path | bytes], None, None]:
"""
Render PDF pages to images.
Args:
pdf_path: Path to the PDF file
output_dir: Directory to save images (if None, returns bytes)
dpi: Resolution for rendering (default 300)
image_format: Output format ('png' or 'jpg')
Yields:
Tuple of (page_number, image_path or image_bytes)
"""
doc = fitz.open(pdf_path)
# Calculate zoom factor for desired DPI (72 is base DPI for PDF)
zoom = dpi / 72
matrix = fitz.Matrix(zoom, zoom)
if output_dir:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
pdf_name = Path(pdf_path).stem
for page_no, page in enumerate(doc):
# Render page to pixmap
pix = page.get_pixmap(matrix=matrix)
if output_dir:
# Save to file
ext = "jpg" if image_format.lower() in ("jpg", "jpeg") else "png"
image_path = output_dir / f"{pdf_name}_page_{page_no:03d}.{ext}"
if ext == "jpg":
pix.save(str(image_path), "jpeg")
else:
pix.save(str(image_path))
yield page_no, image_path
else:
# Return bytes
if image_format.lower() in ("jpg", "jpeg"):
yield page_no, pix.tobytes("jpeg")
else:
yield page_no, pix.tobytes("png")
doc.close()
def render_page_to_image(
pdf_path: str | Path,
page_no: int,
dpi: int = 300
) -> bytes:
"""
Render a single page to image bytes.
Args:
pdf_path: Path to the PDF file
page_no: Page number (0-indexed)
dpi: Resolution for rendering
Returns:
PNG image bytes
"""
doc = fitz.open(pdf_path)
if page_no >= len(doc):
doc.close()
raise ValueError(f"Page {page_no} does not exist (PDF has {len(doc)} pages)")
zoom = dpi / 72
matrix = fitz.Matrix(zoom, zoom)
page = doc[page_no]
pix = page.get_pixmap(matrix=matrix)
image_bytes = pix.tobytes("png")
doc.close()
return image_bytes
def get_render_dimensions(pdf_path: str | Path, page_no: int = 0, dpi: int = 300) -> tuple[int, int]:
"""
Get the dimensions of a rendered page.
Returns:
(width, height) in pixels
"""
doc = fitz.open(pdf_path)
page = doc[page_no]
zoom = dpi / 72
rect = page.rect
width = int(rect.width * zoom)
height = int(rect.height * zoom)
doc.close()
return width, height
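
Rendering a PDF to per-page PNGs on disk looks like this (paths are placeholders):

```python
from src.pdf.renderer import render_pdf_to_images

for page_no, image_path in render_pdf_to_images(
    "data/raw_pdfs/example.pdf",          # hypothetical input
    output_dir="data/dataset/train/images",
    dpi=300,
):
    print(f"page {page_no} -> {image_path}")
```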

4
src/yolo/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .annotation_generator import AnnotationGenerator, generate_annotations
from .dataset_builder import DatasetBuilder
__all__ = ['AnnotationGenerator', 'generate_annotations', 'DatasetBuilder']

281
src/yolo/annotation_generator.py Normal file
View File

@@ -0,0 +1,281 @@
"""
YOLO Annotation Generator
Generates YOLO format annotations from matched fields.
"""
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import csv
# Field class mapping for YOLO
FIELD_CLASSES = {
'InvoiceNumber': 0,
'InvoiceDate': 1,
'InvoiceDueDate': 2,
'OCR': 3,
'Bankgiro': 4,
'Plusgiro': 5,
'Amount': 6,
}
CLASS_NAMES = [
'invoice_number',
'invoice_date',
'invoice_due_date',
'ocr_number',
'bankgiro',
'plusgiro',
'amount',
]
@dataclass
class YOLOAnnotation:
"""Represents a single YOLO annotation."""
class_id: int
x_center: float # normalized 0-1
y_center: float # normalized 0-1
width: float # normalized 0-1
height: float # normalized 0-1
confidence: float = 1.0
def to_string(self) -> str:
"""Convert to YOLO format string."""
return f"{self.class_id} {self.x_center:.6f} {self.y_center:.6f} {self.width:.6f} {self.height:.6f}"
class AnnotationGenerator:
"""Generates YOLO annotations from document matches."""
def __init__(
self,
min_confidence: float = 0.7,
bbox_padding_px: int = 20, # Absolute padding in pixels
min_bbox_height_px: int = 30 # Minimum bbox height
):
"""
Initialize annotation generator.
Args:
min_confidence: Minimum match score to include in training
bbox_padding_px: Absolute padding in pixels to add around bboxes
min_bbox_height_px: Minimum bbox height in pixels
"""
self.min_confidence = min_confidence
self.bbox_padding_px = bbox_padding_px
self.min_bbox_height_px = min_bbox_height_px
def generate_from_matches(
self,
matches: dict[str, list[Any]], # field_name -> list of Match
image_width: float,
image_height: float,
dpi: int = 300
) -> list[YOLOAnnotation]:
"""
Generate YOLO annotations from field matches.
Args:
matches: Dict of field_name -> list of Match objects
image_width: Width of the rendered image in pixels
image_height: Height of the rendered image in pixels
dpi: DPI used for rendering (needed to convert PDF coords to pixels)
Returns:
List of YOLOAnnotation objects
"""
annotations = []
# Scale factor to convert PDF points (72 DPI) to rendered pixels
scale = dpi / 72.0
for field_name, field_matches in matches.items():
if field_name not in FIELD_CLASSES:
continue
class_id = FIELD_CLASSES[field_name]
# Take only the best match per field
if field_matches:
best_match = field_matches[0] # Already sorted by score
if best_match.score < self.min_confidence:
continue
# best_match.bbox is in PDF points (72 DPI), convert to pixels
x0, y0, x1, y1 = best_match.bbox
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
# Add absolute padding
pad = self.bbox_padding_px
x0 = max(0, x0 - pad)
y0 = max(0, y0 - pad)
x1 = min(image_width, x1 + pad)
y1 = min(image_height, y1 + pad)
# Ensure minimum height
current_height = y1 - y0
if current_height < self.min_bbox_height_px:
extra = (self.min_bbox_height_px - current_height) / 2
y0 = max(0, y0 - extra)
y1 = min(image_height, y1 + extra)
# Convert to YOLO format (normalized center + size)
x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
# Clamp values to 0-1
x_center = max(0, min(1, x_center))
y_center = max(0, min(1, y_center))
width = max(0, min(1, width))
height = max(0, min(1, height))
annotations.append(YOLOAnnotation(
class_id=class_id,
x_center=x_center,
y_center=y_center,
width=width,
height=height,
confidence=best_match.score
))
return annotations
def save_annotations(
self,
annotations: list[YOLOAnnotation],
output_path: str | Path
) -> None:
"""Save annotations to a YOLO format text file."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
for ann in annotations:
f.write(ann.to_string() + '\n')
@staticmethod
def generate_classes_file(output_path: str | Path) -> None:
"""Generate the classes.txt file for YOLO."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, 'w') as f:
for class_name in CLASS_NAMES:
f.write(class_name + '\n')
@staticmethod
def generate_yaml_config(
output_path: str | Path,
train_path: str = 'train/images',
val_path: str = 'val/images',
test_path: str = 'test/images'
) -> None:
"""Generate YOLO dataset YAML configuration."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
# Use absolute path for WSL compatibility
dataset_dir = output_path.parent.absolute()
# Convert Windows path to WSL path if needed
dataset_path_str = str(dataset_dir).replace('\\', '/')
if dataset_path_str[1] == ':':
# Windows path like C:/... -> /mnt/c/...
drive = dataset_path_str[0].lower()
dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
config = f"""# Invoice Field Detection Dataset
path: {dataset_path_str}
train: {train_path}
val: {val_path}
test: {test_path}
# Classes
names:
"""
for i, name in enumerate(CLASS_NAMES):
config += f" {i}: {name}\n"
with open(output_path, 'w') as f:
f.write(config)
def generate_annotations(
pdf_path: str | Path,
structured_data: dict[str, Any],
output_dir: str | Path,
dpi: int = 300
) -> list[Path]:
"""
Generate YOLO annotations for a PDF using structured data.
Args:
pdf_path: Path to the PDF file
structured_data: Dict with field values from CSV
output_dir: Directory to save images and labels
dpi: Resolution for rendering
Returns:
List of paths to generated annotation files
"""
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
from ..pdf.renderer import get_render_dimensions
from ..ocr import OCREngine
from ..matcher import FieldMatcher
from ..normalize import normalize_field
output_dir = Path(output_dir)
images_dir = output_dir / 'images'
labels_dir = output_dir / 'labels'
images_dir.mkdir(parents=True, exist_ok=True)
labels_dir.mkdir(parents=True, exist_ok=True)
generator = AnnotationGenerator()
matcher = FieldMatcher()
annotation_files = []
# Check PDF type
use_ocr = not is_text_pdf(pdf_path)
# Initialize OCR if needed
ocr_engine = OCREngine() if use_ocr else None
# Process each page
for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi):
# Get image dimensions
img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)
# Extract tokens
if use_ocr:
from PIL import Image
tokens = ocr_engine.extract_from_image(str(image_path), page_no)
else:
tokens = list(extract_text_tokens(pdf_path, page_no))
# Match fields
matches = {}
for field_name in FIELD_CLASSES.keys():
value = structured_data.get(field_name)
if value is None or str(value).strip() == '':
continue
normalized = normalize_field(field_name, str(value))
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
if field_matches:
matches[field_name] = field_matches
# Generate annotations (pass DPI for coordinate conversion)
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
# Save annotations
if annotations:
label_path = labels_dir / f"{image_path.stem}.txt"
generator.save_annotations(annotations, label_path)
annotation_files.append(label_path)
return annotation_files
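
The coordinate handling in generate_from_matches is easy to get wrong, so here is the same conversion worked through with made-up numbers: a match bbox in PDF points is scaled by dpi/72 into rendered pixels, padded, then normalized to YOLO center/size format. The minimum-height adjustment is skipped here because the padded box is already taller than 30 px.

```python
dpi = 300
scale = dpi / 72.0                       # ~4.17 rendered pixels per PDF point
image_width, image_height = 2480, 3508   # A4 page rendered at 300 dpi

# Match bbox in PDF points (illustrative values for an 'amount' match)
x0, y0, x1, y1 = (v * scale for v in (100.0, 200.0, 180.0, 215.0))

pad = 20  # bbox_padding_px
x0, y0 = max(0, x0 - pad), max(0, y0 - pad)
x1, y1 = min(image_width, x1 + pad), min(image_height, y1 + pad)

x_center = (x0 + x1) / 2 / image_width
y_center = (y0 + y1) / 2 / image_height
width = (x1 - x0) / image_width
height = (y1 - y0) / image_height
print(f"6 {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")  # class 6 = amount
```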

249
src/yolo/dataset_builder.py Normal file
View File

@@ -0,0 +1,249 @@
"""
YOLO Dataset Builder
Builds training dataset from PDFs and structured CSV data.
"""
import csv
import shutil
from dataclasses import dataclass
from pathlib import Path
from typing import Generator
import random
@dataclass
class DatasetStats:
"""Statistics about the generated dataset."""
total_documents: int
successful: int
failed: int
total_annotations: int
annotations_per_class: dict[str, int]
train_count: int
val_count: int
test_count: int
class DatasetBuilder:
"""Builds YOLO training dataset from PDFs and CSV data."""
def __init__(
self,
pdf_dir: str | Path,
csv_path: str | Path,
output_dir: str | Path,
document_id_column: str = 'DocumentId',
dpi: int = 300,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
test_ratio: float = 0.1
):
"""
Initialize dataset builder.
Args:
pdf_dir: Directory containing PDF files
csv_path: Path to structured data CSV
output_dir: Output directory for dataset
document_id_column: Column name for document ID
dpi: Resolution for rendering
train_ratio: Fraction for training set
val_ratio: Fraction for validation set
test_ratio: Fraction for test set
"""
self.pdf_dir = Path(pdf_dir)
self.csv_path = Path(csv_path)
self.output_dir = Path(output_dir)
self.document_id_column = document_id_column
self.dpi = dpi
self.train_ratio = train_ratio
self.val_ratio = val_ratio
self.test_ratio = test_ratio
def load_structured_data(self) -> dict[str, dict]:
"""Load structured data from CSV."""
data = {}
with open(self.csv_path, 'r', encoding='utf-8') as f:
reader = csv.DictReader(f)
for row in reader:
doc_id = row.get(self.document_id_column)
if doc_id:
data[doc_id] = row
return data
def find_pdf_for_document(self, doc_id: str) -> Path | None:
"""Find PDF file for a document ID."""
# Try common naming patterns
patterns = [
f"{doc_id}.pdf",
f"{doc_id.lower()}.pdf",
f"{doc_id.upper()}.pdf",
f"*{doc_id}*.pdf",
]
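# e.g. for a hypothetical doc_id "abc123" this tries "abc123.pdf" (plus case
# variants) first, then falls back to any filename that merely contains the
# id, such as "invoice_abc123_scan.pdf".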
for pattern in patterns:
matches = list(self.pdf_dir.glob(pattern))
if matches:
return matches[0]
return None
def build(self, seed: int = 42) -> DatasetStats:
"""
Build the complete dataset.
Args:
seed: Random seed for train/val/test split
Returns:
DatasetStats with build results
"""
from .annotation_generator import AnnotationGenerator, CLASS_NAMES, generate_annotations
random.seed(seed)
# Setup output directories
for split in ['train', 'val', 'test']:
(self.output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
(self.output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)
# Generate config files
AnnotationGenerator.generate_classes_file(self.output_dir / 'classes.txt')
AnnotationGenerator.generate_yaml_config(self.output_dir / 'dataset.yaml')
# Load structured data
structured_data = self.load_structured_data()
# Stats tracking
stats = {
'total': 0,
'successful': 0,
'failed': 0,
'annotations': 0,
'per_class': {name: 0 for name in CLASS_NAMES},
'splits': {'train': 0, 'val': 0, 'test': 0}
}
# Process each document
processed_items = []
for doc_id, data in structured_data.items():
stats['total'] += 1
pdf_path = self.find_pdf_for_document(doc_id)
if not pdf_path:
print(f"Warning: PDF not found for document {doc_id}")
stats['failed'] += 1
continue
try:
# Generate to temp dir first
temp_dir = self.output_dir / 'temp' / doc_id
temp_dir.mkdir(parents=True, exist_ok=True)
annotation_files = generate_annotations(
pdf_path, data, temp_dir, self.dpi
)
if annotation_files:
processed_items.append({
'doc_id': doc_id,
'temp_dir': temp_dir,
'annotation_files': annotation_files
})
stats['successful'] += 1
else:
print(f"Warning: No annotations generated for {doc_id}")
stats['failed'] += 1
except Exception as e:
print(f"Error processing {doc_id}: {e}")
stats['failed'] += 1
# Shuffle and split
random.shuffle(processed_items)
n_train = int(len(processed_items) * self.train_ratio)
n_val = int(len(processed_items) * self.val_ratio)
splits = {
'train': processed_items[:n_train],
'val': processed_items[n_train:n_train + n_val],
'test': processed_items[n_train + n_val:]
}
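# With the default 0.8/0.1/0.1 ratios and, say, 100 successfully processed
# documents, int() truncation gives 80 train and 10 val items; the test
# split simply takes the remaining 10.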
# Move files to final locations
for split_name, items in splits.items():
for item in items:
temp_dir = item['temp_dir']
# Move images
for img in (temp_dir / 'images').glob('*'):
dest = self.output_dir / split_name / 'images' / img.name
shutil.move(str(img), str(dest))
# Move labels and count annotations
for label in (temp_dir / 'labels').glob('*.txt'):
dest = self.output_dir / split_name / 'labels' / label.name
shutil.move(str(label), str(dest))
# Count annotations per class
with open(dest, 'r') as f:
for line in f:
class_id = int(line.strip().split()[0])
if 0 <= class_id < len(CLASS_NAMES):
stats['per_class'][CLASS_NAMES[class_id]] += 1
stats['annotations'] += 1
stats['splits'][split_name] += 1
# Cleanup temp dir
shutil.rmtree(self.output_dir / 'temp', ignore_errors=True)
return DatasetStats(
total_documents=stats['total'],
successful=stats['successful'],
failed=stats['failed'],
total_annotations=stats['annotations'],
annotations_per_class=stats['per_class'],
train_count=stats['splits']['train'],
val_count=stats['splits']['val'],
test_count=stats['splits']['test']
)
def process_single_document(
self,
doc_id: str,
data: dict,
split: str = 'train'
) -> bool:
"""
Process a single document (for incremental building).
Args:
doc_id: Document ID
data: Structured data dict
split: Which split to add to
Returns:
True if successful
"""
from .annotation_generator import generate_annotations
pdf_path = self.find_pdf_for_document(doc_id)
if not pdf_path:
return False
try:
output_subdir = self.output_dir / split
annotation_files = generate_annotations(
pdf_path, data, output_subdir, self.dpi
)
return len(annotation_files) > 0
except Exception as e:
print(f"Error processing {doc_id}: {e}")
return False
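# Illustrative usage sketch; the paths follow the data/ layout this project
# documents, adjust them to your environment:
if __name__ == "__main__":
    builder = DatasetBuilder(
        pdf_dir="data/raw_pdfs",
        csv_path="data/structured_data/invoices.csv",
        output_dir="data/dataset",
        dpi=300,
    )
    stats = builder.build(seed=42)
    print(
        f"Labelled {stats.successful}/{stats.total_documents} documents, "
        f"{stats.total_annotations} annotations "
        f"(train/val/test: {stats.train_count}/{stats.val_count}/{stats.test_count})"
    )
    # A single extra document can later be appended to one split, e.g.
    # (doc id and row dict are hypothetical):
    #   builder.process_single_document("some-doc-id", row, split="val")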