Initial commit: Invoice field extraction system using YOLO + OCR
Features:
- Auto-labeling pipeline: CSV values -> PDF search -> YOLO annotations
- Flexible date matching: year-month match, nearby date tolerance
- PDF text extraction with PyMuPDF
- OCR support for scanned documents (PaddleOCR)
- YOLO training and inference pipeline
- 7 field types: InvoiceNumber, InvoiceDate, InvoiceDueDate, OCR, Bankgiro, Plusgiro, Amount

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
.gitignore (new file, vendored, 71 lines)
@@ -0,0 +1,71 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Virtual environments
venv/
ENV/
env/
.venv/

# IDE
.idea/
.vscode/
*.swp
*.swo
*~

# Data files (large files)
data/raw_pdfs/
data/dataset/train/images/
data/dataset/val/images/
data/dataset/test/images/
data/dataset/train/labels/
data/dataset/val/labels/
data/dataset/test/labels/
*.pdf
*.png
*.jpg
*.jpeg

# Model weights
models/weights/
runs/
*.pt
*.onnx

# Reports and logs
reports/*.jsonl
logs/
*.log

# Jupyter
.ipynb_checkpoints/

# OS
.DS_Store
Thumbs.db

# Credentials
.env
*.key
*.pem
credentials.json
README.md (new file, 226 lines)
@@ -0,0 +1,226 @@
# Invoice Master POC v2

Automatic invoice information extraction system: extracts structured data from PDF invoices using YOLO + OCR.

## Runtime Environment

> **Important**: This project must be run inside **WSL (Windows Subsystem for Linux)**.

### System Requirements

- WSL 2 (Ubuntu 22.04 recommended)
- Python 3.10+
- **NVIDIA GPU + CUDA 12.x (strongly recommended)**: GPU training is 10-50x faster than CPU

## Features

- **Dual-mode PDF processing**: handles both text-layer and scanned-image PDFs
- **Auto-labeling**: generates YOLO training data automatically from existing structured CSV data
- **Field detection**: locates invoice field regions with YOLOv8
- **OCR**: extracts text from detected regions with PaddleOCR
- **Smart matching**: multiple format normalizations plus context-keyword boosting (see the sketch below)
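The smart-matching bullet is the heart of the auto-labeling step. A minimal sketch of the idea, with hypothetical names and scores (the real implementation lives in `src/normalize` and `src/matcher`):

```python
# Sketch only: expand a CSV value into normalized candidates, look for an exact
# token hit, and boost the score when a context keyword sits nearby.
from dataclasses import dataclass

@dataclass
class Token:
    text: str
    x: float
    y: float

def amount_candidates(value: str) -> set[str]:
    # "114" -> {"114", "114,00", "114.00"}; Swedish invoices often use a comma
    base = value.strip()
    return {base, f"{base},00", f"{base}.00"}

def match_amount(tokens: list[Token], csv_value: str,
                 keywords=("att betala", "summa", "belopp")) -> tuple[Token, float] | None:
    cands = amount_candidates(csv_value)
    best = None
    for tok in tokens:
        if tok.text.replace(" ", "") not in cands:
            continue
        score = 0.7  # base score for an exact value hit (illustrative number)
        for other in tokens:
            # crude "same line" check; the real matcher uses a context radius
            if other.text.lower() in keywords and abs(other.y - tok.y) < 20:
                score = min(1.0, score + 0.2)
        if best is None or score > best[1]:
            best = (tok, score)
    return best

if __name__ == "__main__":
    toks = [Token("Att betala", 40, 700), Token("114,00", 180, 702), Token("Faktura", 40, 60)]
    print(match_amount(toks, "114"))
```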
## Supported Fields

| Field | Description |
|-------|-------------|
| InvoiceNumber | Invoice number |
| InvoiceDate | Invoice date |
| InvoiceDueDate | Due date |
| OCR | OCR reference number (Sweden) |
| Bankgiro | Bankgiro number |
| Plusgiro | Plusgiro number |
| Amount | Amount due |

## Installation (WSL)

### 1. Enter the WSL environment

```bash
# Enter WSL from a Windows terminal
wsl

# Change into the project directory (Windows paths are mapped under /mnt/)
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
```

### 2. Install system dependencies

```bash
# Update the system
sudo apt update && sudo apt upgrade -y

# Install Python and required tools
sudo apt install -y python3.10 python3.10-venv python3-pip

# Install OpenCV dependencies
sudo apt install -y libgl1-mesa-glx libglib2.0-0 libsm6 libxrender1 libxext6
```

### 3. Create a virtual environment and install dependencies

```bash
# Create the virtual environment
python3 -m venv venv
source venv/bin/activate

# Upgrade pip
pip install --upgrade pip

# Install dependencies
pip install -r requirements.txt

# Or install in editable (development) mode
pip install -e .
```

### GPU support (optional)

```bash
# Make sure CUDA is configured in WSL
nvidia-smi  # check that the GPU is visible

# Install the GPU build of PaddlePaddle
pip install paddlepaddle-gpu

# Or pin a specific CUDA build
pip install paddlepaddle-gpu==2.5.2.post118 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
```

## Quick Start

### 1. Prepare the data

```
data/
├── raw_pdfs/
│   ├── {DocumentId}.pdf
│   └── ...
└── structured_data/
    └── invoices.csv
```

CSV format:
```csv
DocumentId,InvoiceDate,InvoiceNumber,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount
3be53fd7-...,2025-12-13,100017500321,2026-01-03,100017500321,53939484,,114
```
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m src.cli.autolabel \
|
||||||
|
--csv data/structured_data/invoices.csv \
|
||||||
|
--pdf-dir data/raw_pdfs \
|
||||||
|
--output data/dataset \
|
||||||
|
--report reports/autolabel_report.jsonl
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 训练模型
|
||||||
|
|
||||||
|
> **重要**: 务必使用 GPU 进行训练!CPU 训练速度非常慢。
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# GPU 训练 (强烈推荐)
|
||||||
|
python -m src.cli.train \
|
||||||
|
--data data/dataset/dataset.yaml \
|
||||||
|
--model yolo11n.pt \
|
||||||
|
--epochs 100 \
|
||||||
|
--batch 16 \
|
||||||
|
--device 0 # 使用 GPU
|
||||||
|
|
||||||
|
# 验证 GPU 可用
|
||||||
|
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}')"
|
||||||
|
```
|
||||||
|
|
||||||
|
GPU vs CPU 训练时间对比 (100 epochs, 77 训练图片):
|
||||||
|
- **GPU (RTX 5080)**: ~2 分钟
|
||||||
|
- **CPU**: 30+ 分钟
|
||||||
|
|
||||||
|
### 4. 推理
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python -m src.cli.infer \
|
||||||
|
--model runs/train/invoice_fields/weights/best.pt \
|
||||||
|
--input path/to/invoice.pdf \
|
||||||
|
--output result.json
|
||||||
|
```
|
||||||
|
|
||||||
|
## 输出示例
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"DocumentId": "3be53fd7-d5ea-458c-a229-8d360b8ba6a9",
|
||||||
|
"InvoiceNumber": "100017500321",
|
||||||
|
"InvoiceDate": "2025-12-13",
|
||||||
|
"InvoiceDueDate": "2026-01-03",
|
||||||
|
"OCR": "100017500321",
|
||||||
|
"Bankgiro": "5393-9484",
|
||||||
|
"Plusgiro": null,
|
||||||
|
"Amount": "114.00",
|
||||||
|
"confidence": {
|
||||||
|
"InvoiceNumber": 0.96,
|
||||||
|
"InvoiceDate": 0.92,
|
||||||
|
"Amount": 0.93
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
invoice-master-poc-v2/
|
||||||
|
├── src/
|
||||||
|
│ ├── pdf/ # PDF 处理模块
|
||||||
|
│ ├── ocr/ # OCR 提取模块
|
||||||
|
│ ├── normalize/ # 字段规范化模块
|
||||||
|
│ ├── matcher/ # 字段匹配模块
|
||||||
|
│ ├── yolo/ # YOLO 标注生成
|
||||||
|
│ ├── inference/ # 推理管道
|
||||||
|
│ ├── data/ # 数据加载模块
|
||||||
|
│ └── cli/ # 命令行工具
|
||||||
|
├── configs/ # 配置文件
|
||||||
|
├── data/ # 数据目录
|
||||||
|
└── requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 开发优先级
|
||||||
|
|
||||||
|
1. ✅ 文本层 PDF 自动标注
|
||||||
|
2. ✅ 扫描图 OCR 自动标注
|
||||||
|
3. 🔄 金额 / OCR / Bankgiro 三字段稳定
|
||||||
|
4. ⏳ 日期、Plusgiro 扩展
|
||||||
|
5. ⏳ 表格 items 处理
|
||||||
|
|
||||||
|
## 配置
|
||||||
|
|
||||||
|
编辑 `configs/default.yaml` 自定义:
|
||||||
|
- PDF 渲染 DPI
|
||||||
|
- OCR 语言
|
||||||
|
- 匹配置信度阈值
|
||||||
|
- 上下文关键词
|
||||||
|
- 数据增强参数
|
||||||
|
|
||||||
|
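A minimal sketch of reading these settings with PyYAML (assumed usage; the project's own config plumbing may wire this up differently):

```python
# Load configs/default.yaml and pull a few settings.
import yaml

with open("configs/default.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["pdf"]["dpi"])                        # 300
print(cfg["matching"]["min_score_threshold"])   # 0.7
print(cfg["matching"]["context_keywords"]["Amount"])
```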
## API Usage

```python
from src.inference import InferencePipeline

# Initialize
pipeline = InferencePipeline(
    model_path='models/best.pt',
    confidence_threshold=0.5,
    ocr_lang='en'
)

# Process a PDF
result = pipeline.process_pdf('invoice.pdf')

# Access the extracted fields
print(result.fields)
print(result.confidence)
```

## License

MIT License
configs/default.yaml (new file, 129 lines)
@@ -0,0 +1,129 @@
# Invoice Master POC v2 Configuration
# Default configuration for invoice field extraction system

# PDF Processing
pdf:
  dpi: 300                # Resolution for rendering PDFs to images
  min_text_chars: 30      # Minimum chars to consider PDF as text-based

# OCR Settings
ocr:
  engine: paddleocr       # OCR engine to use
  lang: en                # Language code (en, sv, ch, etc.)
  use_gpu: false          # Enable GPU acceleration

# Field Normalization
normalize:
  # Bankgiro formats
  bankgiro:
    format: "XXXX-XXXX"   # Standard 8-digit format
    alternatives:
      - "XXX-XXXX"        # 7-digit format

  # Plusgiro formats
  plusgiro:
    format: "XXXXXXX-X"   # Standard format with check digit

  # Amount formats
  amount:
    decimal_separator: ","    # Swedish uses comma
    thousand_separator: " "   # Space for thousands
    currency_symbols:
      - "SEK"
      - "kr"

  # Date formats
  date:
    output_format: "%Y-%m-%d"
    input_formats:
      - "%Y-%m-%d"
      - "%Y-%m-%d %H:%M:%S"
      - "%d/%m/%Y"
      - "%d.%m.%Y"

# Field Matching
matching:
  min_score_threshold: 0.7    # Minimum score to accept match
  context_radius: 100         # Pixels to search for context keywords

  # Context keywords for each field (Swedish)
  context_keywords:
    InvoiceNumber:
      - "fakturanr"
      - "fakturanummer"
      - "invoice"
    InvoiceDate:
      - "fakturadatum"
      - "datum"
    InvoiceDueDate:
      - "förfallodatum"
      - "förfaller"
      - "betalas senast"
    OCR:
      - "ocr"
      - "referens"
    Bankgiro:
      - "bankgiro"
      - "bg"
    Plusgiro:
      - "plusgiro"
      - "pg"
    Amount:
      - "att betala"
      - "summa"
      - "total"
      - "belopp"

# YOLO Training
yolo:
  model: yolov8s        # Model architecture (yolov8n/s/m/l/x)
  epochs: 100
  batch_size: 16
  img_size: 1280        # Image size for training

  # Data augmentation
  augmentation:
    rotation: 5         # Max rotation degrees
    scale: 0.2          # Scale variation
    mosaic: 0.0         # Disable mosaic for documents
    hsv_h: 0.0          # No hue variation
    hsv_s: 0.1          # Slight saturation variation
    hsv_v: 0.2          # Brightness variation

# Class definitions
classes:
  0: invoice_number
  1: invoice_date
  2: invoice_due_date
  3: ocr_number
  4: bankgiro
  5: plusgiro
  6: amount

# Auto-labeling
autolabel:
  min_confidence: 0.7   # Minimum score to include in training
  bbox_padding: 0.02    # Padding around bboxes (fraction of image)

# Dataset Split
dataset:
  train_ratio: 0.8
  val_ratio: 0.1
  test_ratio: 0.1
  random_seed: 42

# Inference
inference:
  confidence_threshold: 0.5   # Detection confidence threshold
  iou_threshold: 0.45         # NMS IOU threshold
  enable_fallback: true       # Enable regex fallback if YOLO fails
  fallback_min_missing: 2     # Min missing fields to trigger fallback

# Paths (relative to project root)
paths:
  raw_pdfs: data/raw_pdfs
  images: data/images
  labels: data/labels
  structured_data: data/structured_data
  models: models
  reports: reports
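For reference, the `classes` mapping above corresponds to YOLO's plain-text label format: one line per box, `class_id x_center y_center width height`, all normalized to the image size. A hedged sketch of producing such a line (hypothetical helper; the project's own writer is `src/yolo/annotation_generator.py`):

```python
# Sketch: convert a pixel bbox to a YOLO label line using the class ids above.
def yolo_line(class_id: int, bbox: tuple[float, float, float, float],
              img_w: int, img_h: int) -> str:
    x0, y0, x1, y1 = bbox
    xc = (x0 + x1) / 2 / img_w
    yc = (y0 + y1) / 2 / img_h
    w = (x1 - x0) / img_w
    h = (y1 - y0) / img_h
    return f"{class_id} {xc:.6f} {yc:.6f} {w:.6f} {h:.6f}"

# amount (class 6) detected at pixels (1400, 2900)-(1650, 2960) on a 2480x3508 page
print(yolo_line(6, (1400, 2900, 1650, 2960), 2480, 3508))
```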
configs/training.yaml (new file, 59 lines)
@@ -0,0 +1,59 @@
# YOLO Training Configuration
# Use with: yolo train data=dataset.yaml cfg=training.yaml

# Model
model: yolov8s.pt

# Training hyperparameters
epochs: 100
patience: 20        # Early stopping patience
batch: 16
imgsz: 1280

# Optimizer
optimizer: AdamW
lr0: 0.001          # Initial learning rate
lrf: 0.01           # Final learning rate factor
momentum: 0.937
weight_decay: 0.0005

# Warmup
warmup_epochs: 3
warmup_momentum: 0.8
warmup_bias_lr: 0.1

# Loss weights
box: 7.5            # Box loss gain
cls: 0.5            # Class loss gain
dfl: 1.5            # DFL loss gain

# Augmentation
# Keep minimal for document images
hsv_h: 0.0          # No hue augmentation
hsv_s: 0.1          # Slight saturation
hsv_v: 0.2          # Brightness variation
degrees: 5.0        # Rotation ±5°
translate: 0.05     # Translation
scale: 0.2          # Scale ±20%
shear: 0.0          # No shear
perspective: 0.0    # No perspective
flipud: 0.0         # No vertical flip
fliplr: 0.0         # No horizontal flip
mosaic: 0.0         # Disable mosaic (not suitable for documents)
mixup: 0.0          # Disable mixup
copy_paste: 0.0     # Disable copy-paste

# Validation
val: true
save: true
save_period: 10
cache: true

# Other
device: 0           # GPU device (0, 1, etc.) or 'cpu'
workers: 8
project: runs/train
name: invoice_fields
exist_ok: true
pretrained: true
verbose: true
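The header comment shows the CLI form; a rough Python equivalent using Ultralytics, with assumed paths (the project's `src/cli/train.py` builds its own argument dict instead of passing this file):

```python
# Sketch: train with the overrides from configs/training.yaml.
from ultralytics import YOLO

model = YOLO("yolov8s.pt")
model.train(data="data/dataset/dataset.yaml", cfg="configs/training.yaml")
```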
pyproject.toml (new file, 78 lines)
@@ -0,0 +1,78 @@
[build-system]
requires = ["setuptools>=68.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "invoice-master"
version = "2.0.0"
description = "Automatic invoice information extraction using YOLO + OCR"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "MIT"}
authors = [
    {name = "Invoice Master Team"}
]
keywords = ["invoice", "ocr", "yolo", "document-processing", "pdf"]
classifiers = [
    "Development Status :: 4 - Beta",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
]

dependencies = [
    "PyMuPDF>=1.23.0",
    "paddlepaddle>=2.5.0",
    "paddleocr>=2.7.0",
    "ultralytics>=8.1.0",
    "Pillow>=10.0.0",
    "numpy>=1.24.0",
    "opencv-python>=4.8.0",
    "pyyaml>=6.0",
    "tqdm>=4.65.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-cov>=4.0.0",
    "black>=23.0.0",
    "ruff>=0.1.0",
    "mypy>=1.0.0",
]
gpu = [
    "paddlepaddle-gpu>=2.5.0",
]

[project.scripts]
invoice-autolabel = "src.cli.autolabel:main"
invoice-train = "src.cli.train:main"
invoice-infer = "src.cli.infer:main"

[tool.setuptools.packages.find]
where = ["."]
include = ["src*"]

[tool.black]
line-length = 100
target-version = ["py310", "py311", "py312"]

[tool.ruff]
line-length = 100
target-version = "py310"
select = ["E", "F", "W", "I", "N", "D", "UP", "B", "C4", "SIM"]
ignore = ["D100", "D104"]

[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_ignores = true
disallow_untyped_defs = true

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
addopts = "-v --cov=src --cov-report=term-missing"
requirements.txt (new file, 22 lines)
@@ -0,0 +1,22 @@
# Invoice Master POC v2 - Dependencies

# PDF Processing
PyMuPDF>=1.23.0        # PDF rendering and text extraction

# OCR
paddlepaddle>=2.5.0    # PaddlePaddle framework
paddleocr>=2.7.0       # PaddleOCR

# YOLO
ultralytics>=8.1.0     # YOLOv8/v11

# Image Processing
Pillow>=10.0.0         # Image handling
numpy>=1.24.0          # Array operations
opencv-python>=4.8.0   # Image processing

# Data Processing
pyyaml>=6.0            # YAML config files

# Utilities
tqdm>=4.65.0           # Progress bars
run_autolabel.py (new file, 10 lines)
@@ -0,0 +1,10 @@
#!/usr/bin/env python3
"""
Auto-labeling entry script; delegates to the CLI module.
Run inside WSL: python run_autolabel.py
"""

from src.cli.autolabel import main

if __name__ == '__main__':
    main()
scripts/run_autolabel.sh (new file, 55 lines)
@@ -0,0 +1,55 @@
#!/bin/bash
# Auto-labeling run script
# Usage: bash scripts/run_autolabel.sh

set -e

# Project root directory
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$PROJECT_DIR"

# Activate the virtual environment
if [ -f "venv/bin/activate" ]; then
    source venv/bin/activate
else
    echo "Error: virtual environment not found, run setup_wsl.sh first"
    exit 1
fi

# Default parameters
CSV_FILE="${CSV_FILE:-data/structured_data/invoices.csv}"
PDF_DIR="${PDF_DIR:-data/raw_pdfs}"
OUTPUT_DIR="${OUTPUT_DIR:-data/dataset}"
REPORT_FILE="${REPORT_FILE:-reports/autolabel_report.jsonl}"
DPI="${DPI:-300}"
MIN_CONFIDENCE="${MIN_CONFIDENCE:-0.7}"

# Show configuration
echo "=========================================="
echo "Auto-labeling configuration"
echo "=========================================="
echo "CSV file: $CSV_FILE"
echo "PDF directory: $PDF_DIR"
echo "Output directory: $OUTPUT_DIR"
echo "Report file: $REPORT_FILE"
echo "DPI: $DPI"
echo "Min confidence: $MIN_CONFIDENCE"
echo "=========================================="
echo ""

# Create required directories
mkdir -p "$(dirname "$REPORT_FILE")"
mkdir -p "$OUTPUT_DIR"

# Run auto-labeling
python -m src.cli.autolabel \
    --csv "$CSV_FILE" \
    --pdf-dir "$PDF_DIR" \
    --output "$OUTPUT_DIR" \
    --report "$REPORT_FILE" \
    --dpi "$DPI" \
    --min-confidence "$MIN_CONFIDENCE" \
    --verbose

echo ""
echo "Done! Dataset written to: $OUTPUT_DIR"
scripts/run_train.sh (new file, 67 lines)
@@ -0,0 +1,67 @@
#!/bin/bash
# Training run script
# Usage: bash scripts/run_train.sh

set -e

# Project root directory
PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$PROJECT_DIR"

# Activate the virtual environment
if [ -f "venv/bin/activate" ]; then
    source venv/bin/activate
else
    echo "Error: virtual environment not found, run setup_wsl.sh first"
    exit 1
fi

# Default parameters
DATA_YAML="${DATA_YAML:-data/dataset/dataset.yaml}"
MODEL="${MODEL:-yolov8s.pt}"
EPOCHS="${EPOCHS:-100}"
BATCH_SIZE="${BATCH_SIZE:-16}"
IMG_SIZE="${IMG_SIZE:-1280}"
DEVICE="${DEVICE:-0}"

# Check that the dataset exists
if [ ! -f "$DATA_YAML" ]; then
    echo "Error: dataset config file not found: $DATA_YAML"
    echo "Run auto-labeling first: bash scripts/run_autolabel.sh"
    exit 1
fi

# Show configuration
echo "=========================================="
echo "Training configuration"
echo "=========================================="
echo "Dataset: $DATA_YAML"
echo "Base model: $MODEL"
echo "Epochs: $EPOCHS"
echo "Batch size: $BATCH_SIZE"
echo "Image size: $IMG_SIZE"
echo "Device: $DEVICE"
echo "=========================================="
echo ""

# Check the GPU
if command -v nvidia-smi &> /dev/null; then
    echo "GPU status:"
    nvidia-smi --query-gpu=name,memory.used,memory.total --format=csv,noheader
    echo ""
else
    echo "Warning: no GPU detected, training will run on the CPU (slow)"
    DEVICE="cpu"
fi

# Run training
python -m src.cli.train \
    --data "$DATA_YAML" \
    --model "$MODEL" \
    --epochs "$EPOCHS" \
    --batch "$BATCH_SIZE" \
    --imgsz "$IMG_SIZE" \
    --device "$DEVICE"

echo ""
echo "Training complete! Model saved under: runs/train/invoice_fields/weights/"
scripts/setup_wsl.sh (new file, 80 lines)
@@ -0,0 +1,80 @@
#!/bin/bash
# WSL environment setup script
# Usage: bash scripts/setup_wsl.sh

set -e

echo "=========================================="
echo "Invoice Master POC v2 - WSL setup script"
echo "=========================================="

# Check that we are running inside WSL
if ! grep -qi microsoft /proc/version 2>/dev/null; then
    echo "Warning: WSL environment not detected, run this script inside WSL"
    echo "Hint: type 'wsl' in a Windows terminal to enter WSL"
    exit 1
fi

echo ""
echo "[1/5] Updating system packages..."
sudo apt update

echo ""
echo "[2/5] Installing system dependencies..."
sudo apt install -y \
    python3.10 \
    python3.10-venv \
    python3-pip \
    libgl1-mesa-glx \
    libglib2.0-0 \
    libsm6 \
    libxrender1 \
    libxext6 \
    libgomp1

echo ""
echo "[3/5] Creating the Python virtual environment..."
if [ -d "venv" ]; then
    echo "Virtual environment already exists, skipping"
else
    python3 -m venv venv
fi

echo ""
echo "[4/5] Activating the virtual environment and installing dependencies..."
source venv/bin/activate
pip install --upgrade pip

echo ""
echo "Installing Python packages..."
pip install -r requirements.txt

echo ""
echo "[5/5] Verifying the installation..."
python3 -c "import fitz; print(f'PyMuPDF: {fitz.version}')"
python3 -c "from ultralytics import YOLO; print('Ultralytics: OK')"
python3 -c "from paddleocr import PaddleOCR; print('PaddleOCR: OK')"

echo ""
echo "=========================================="
echo "Setup complete!"
echo "=========================================="
echo ""
echo "Usage:"
echo "  1. Activate the virtual environment: source venv/bin/activate"
echo "  2. Run auto-labeling: python -m src.cli.autolabel --help"
echo "  3. Train the model: python -m src.cli.train --help"
echo "  4. Run inference: python -m src.cli.infer --help"
echo ""

# Check the GPU
echo "Checking GPU support..."
if command -v nvidia-smi &> /dev/null; then
    echo "NVIDIA GPU detected:"
    nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
    echo ""
    echo "Hint: run the following to enable GPU acceleration:"
    echo "  pip install paddlepaddle-gpu"
else
    echo "No GPU detected, CPU mode will be used"
fi
src/__init__.py (new file, 2 lines)
@@ -0,0 +1,2 @@
# Invoice Master POC v2
# Automatic invoice information extraction system using YOLO + OCR
src/cli/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# CLI modules for Invoice Master
src/cli/autolabel.py (new file, 458 lines)
@@ -0,0 +1,458 @@
#!/usr/bin/env python3
"""
Auto-labeling CLI

Generates YOLO training data from PDFs and structured CSV data.
"""

import argparse
import sys
import time
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing

# Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None


def _init_worker():
    """Initialize worker process with OCR engine (called once per worker)."""
    global _worker_ocr_engine
    # OCR engine will be lazily initialized on first use
    _worker_ocr_engine = None


def _get_ocr_engine():
    """Get or create OCR engine for current worker."""
    global _worker_ocr_engine
    if _worker_ocr_engine is None:
        from ..ocr import OCREngine
        _worker_ocr_engine = OCREngine()
    return _worker_ocr_engine


def process_single_document(args_tuple):
    """
    Process a single document (worker function for parallel processing).

    Args:
        args_tuple: (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr)

    Returns:
        dict with results
    """
    row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple

    # Import inside worker to avoid pickling issues
    from ..data import AutoLabelReport, FieldMatchResult
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    start_time = time.time()
    pdf_path = Path(pdf_path_str)
    output_dir = Path(output_dir_str)
    doc_id = row_dict['DocumentId']

    report = AutoLabelReport(document_id=doc_id)
    report.pdf_path = str(pdf_path)

    result = {
        'doc_id': doc_id,
        'success': False,
        'pages': [],
        'report': None,
        'stats': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    try:
        # Check PDF type
        use_ocr = not is_text_pdf(pdf_path)
        report.pdf_type = "scanned" if use_ocr else "text"

        # Skip OCR if requested
        if use_ocr and skip_ocr:
            report.errors.append("Skipped (scanned PDF)")
            report.processing_time_ms = (time.time() - start_time) * 1000
            result['report'] = report.to_dict()
            return result

        # Get OCR engine from worker cache (only created once per worker)
        ocr_engine = None
        if use_ocr:
            ocr_engine = _get_ocr_engine()

        generator = AnnotationGenerator(min_confidence=min_confidence)
        matcher = FieldMatcher()

        # Process each page
        page_annotations = []

        for page_no, image_path in render_pdf_to_images(
            pdf_path,
            output_dir / 'temp' / doc_id / 'images',
            dpi=dpi
        ):
            report.total_pages += 1
            img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)

            # Extract tokens
            if use_ocr:
                tokens = ocr_engine.extract_from_image(str(image_path), page_no)
            else:
                tokens = list(extract_text_tokens(pdf_path, page_no))

            # Match fields
            matches = {}
            for field_name in FIELD_CLASSES.keys():
                value = row_dict.get(field_name)
                if not value:
                    continue

                normalized = normalize_field(field_name, str(value))
                field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)

                # Record result
                if field_matches:
                    best = field_matches[0]
                    matches[field_name] = field_matches
                    report.add_field_result(FieldMatchResult(
                        field_name=field_name,
                        csv_value=str(value),
                        matched=True,
                        score=best.score,
                        matched_text=best.matched_text,
                        candidate_used=best.value,
                        bbox=best.bbox,
                        page_no=page_no,
                        context_keywords=best.context_keywords
                    ))
                else:
                    report.add_field_result(FieldMatchResult(
                        field_name=field_name,
                        csv_value=str(value),
                        matched=False,
                        page_no=page_no
                    ))

            # Generate annotations
            annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)

            if annotations:
                label_path = output_dir / 'temp' / doc_id / 'labels' / f"{image_path.stem}.txt"
                generator.save_annotations(annotations, label_path)
                page_annotations.append({
                    'image_path': str(image_path),
                    'label_path': str(label_path),
                    'count': len(annotations)
                })

                report.annotations_generated += len(annotations)
                for ann in annotations:
                    class_name = list(FIELD_CLASSES.keys())[ann.class_id]
                    result['stats'][class_name] += 1

        if page_annotations:
            result['pages'] = page_annotations
            result['success'] = True
            report.success = True
        else:
            report.errors.append("No annotations generated")

    except Exception as e:
        report.errors.append(str(e))

    report.processing_time_ms = (time.time() - start_time) * 1000
    result['report'] = report.to_dict()

    return result


def main():
    parser = argparse.ArgumentParser(
        description='Generate YOLO annotations from PDFs and CSV data'
    )
    parser.add_argument(
        '--csv', '-c',
        default='data/structured_data/document_export_20260109_212743.csv',
        help='Path to structured data CSV file'
    )
    parser.add_argument(
        '--pdf-dir', '-p',
        default='data/raw_pdfs',
        help='Directory containing PDF files'
    )
    parser.add_argument(
        '--output', '-o',
        default='data/dataset',
        help='Output directory for dataset'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=300,
        help='DPI for PDF rendering (default: 300)'
    )
    parser.add_argument(
        '--min-confidence',
        type=float,
        default=0.7,
        help='Minimum match confidence (default: 0.7)'
    )
    parser.add_argument(
        '--train-ratio',
        type=float,
        default=0.8,
        help='Training set ratio (default: 0.8)'
    )
    parser.add_argument(
        '--val-ratio',
        type=float,
        default=0.1,
        help='Validation set ratio (default: 0.1)'
    )
    parser.add_argument(
        '--report',
        default='reports/autolabel_report.jsonl',
        help='Path for auto-label report (JSONL)'
    )
    parser.add_argument(
        '--single',
        help='Process single document ID only'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )
    parser.add_argument(
        '--workers', '-w',
        type=int,
        default=4,
        help='Number of parallel workers (default: 4)'
    )
    parser.add_argument(
        '--skip-ocr',
        action='store_true',
        help='Skip scanned PDFs (text-layer only)'
    )

    args = parser.parse_args()

    # Import here to avoid slow startup
    from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
    from ..data.autolabel_report import ReportWriter
    from ..yolo import DatasetBuilder
    from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
    from ..pdf.renderer import get_render_dimensions
    from ..ocr import OCREngine
    from ..matcher import FieldMatcher
    from ..normalize import normalize_field
    from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES

    print(f"Loading CSV data from: {args.csv}")
    loader = CSVLoader(args.csv, args.pdf_dir)

    # Validate data
    issues = loader.validate()
    if issues:
        print(f"Warning: Found {len(issues)} validation issues")
        if args.verbose:
            for issue in issues[:10]:
                print(f"  - {issue}")

    rows = loader.load_all()
    print(f"Loaded {len(rows)} invoice records")

    # Filter to single document if specified
    if args.single:
        rows = [r for r in rows if r.DocumentId == args.single]
        if not rows:
            print(f"Error: Document {args.single} not found")
            sys.exit(1)
        print(f"Processing single document: {args.single}")

    # Setup output directories
    output_dir = Path(args.output)
    for split in ['train', 'val', 'test']:
        (output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
        (output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)

    # Generate YOLO config files
    AnnotationGenerator.generate_classes_file(output_dir / 'classes.txt')
    AnnotationGenerator.generate_yaml_config(output_dir / 'dataset.yaml')

    # Report writer
    report_path = Path(args.report)
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_writer = ReportWriter(args.report)

    # Stats
    stats = {
        'total': len(rows),
        'successful': 0,
        'failed': 0,
        'skipped': 0,
        'annotations': 0,
        'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
    }

    # Prepare tasks
    tasks = []
    for row in rows:
        pdf_path = loader.get_pdf_path(row)
        if not pdf_path:
            # Write report for missing PDF
            report = AutoLabelReport(document_id=row.DocumentId)
            report.errors.append("PDF not found")
            report_writer.write(report)
            stats['failed'] += 1
            continue

        # Convert row to dict for pickling
        row_dict = {
            'DocumentId': row.DocumentId,
            'InvoiceNumber': row.InvoiceNumber,
            'InvoiceDate': row.InvoiceDate,
            'InvoiceDueDate': row.InvoiceDueDate,
            'OCR': row.OCR,
            'Bankgiro': row.Bankgiro,
            'Plusgiro': row.Plusgiro,
            'Amount': row.Amount,
        }

        tasks.append((
            row_dict,
            str(pdf_path),
            str(output_dir),
            args.dpi,
            args.min_confidence,
            args.skip_ocr
        ))

    print(f"Processing {len(tasks)} documents with {args.workers} workers...")

    # Process documents in parallel
    processed_items = []

    # Use single process for debugging or when workers=1
    if args.workers == 1:
        for task in tqdm(tasks, desc="Processing"):
            result = process_single_document(task)

            # Write report
            if result['report']:
                report_writer.write_dict(result['report'])

            if result['success']:
                processed_items.append({
                    'doc_id': result['doc_id'],
                    'pages': result['pages']
                })
                stats['successful'] += 1
                for field, count in result['stats'].items():
                    stats['by_field'][field] += count
                    stats['annotations'] += count
            elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
                stats['skipped'] += 1
            else:
                stats['failed'] += 1
    else:
        # Parallel processing with worker initialization
        # Each worker initializes OCR engine once and reuses it
        with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor:
            futures = {executor.submit(process_single_document, task): task[0]['DocumentId']
                       for task in tasks}

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"):
                doc_id = futures[future]
                try:
                    result = future.result()

                    # Write report
                    if result['report']:
                        report_writer.write_dict(result['report'])

                    if result['success']:
                        processed_items.append({
                            'doc_id': result['doc_id'],
                            'pages': result['pages']
                        })
                        stats['successful'] += 1
                        for field, count in result['stats'].items():
                            stats['by_field'][field] += count
                            stats['annotations'] += count
                    elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
                        stats['skipped'] += 1
                    else:
                        stats['failed'] += 1

                except Exception as e:
                    stats['failed'] += 1
                    # Write error report for failed documents
                    error_report = {
                        'document_id': doc_id,
                        'success': False,
                        'errors': [f"Worker error: {str(e)}"]
                    }
                    report_writer.write_dict(error_report)
                    if args.verbose:
                        print(f"Error processing {doc_id}: {e}")

    # Split and move files
    import random
    random.seed(42)
    random.shuffle(processed_items)

    n_train = int(len(processed_items) * args.train_ratio)
    n_val = int(len(processed_items) * args.val_ratio)

    splits = {
        'train': processed_items[:n_train],
        'val': processed_items[n_train:n_train + n_val],
        'test': processed_items[n_train + n_val:]
    }

    import shutil
    for split_name, items in splits.items():
        for item in items:
            for page in item['pages']:
                # Move image
                image_path = Path(page['image_path'])
                label_path = Path(page['label_path'])
                dest_img = output_dir / split_name / 'images' / image_path.name
                shutil.move(str(image_path), str(dest_img))

                # Move label
                dest_label = output_dir / split_name / 'labels' / label_path.name
                shutil.move(str(label_path), str(dest_label))

    # Cleanup temp
    shutil.rmtree(output_dir / 'temp', ignore_errors=True)

    # Print summary
    print("\n" + "=" * 50)
    print("Auto-labeling Complete")
    print("=" * 50)
    print(f"Total documents: {stats['total']}")
    print(f"Successful: {stats['successful']}")
    print(f"Failed: {stats['failed']}")
    print(f"Skipped (OCR): {stats['skipped']}")
    print(f"Total annotations: {stats['annotations']}")
    print(f"\nDataset split:")
    print(f"  Train: {len(splits['train'])} documents")
    print(f"  Val: {len(splits['val'])} documents")
    print(f"  Test: {len(splits['test'])} documents")
    print(f"\nAnnotations by field:")
    for field, count in stats['by_field'].items():
        print(f"  {field}: {count}")
    print(f"\nOutput: {output_dir}")
    print(f"Report: {args.report}")


if __name__ == '__main__':
    main()
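For debugging a single invoice outside the argparse layer, the worker contract above can be exercised directly. A hedged sketch with placeholder paths (the CLI offers the same via `--single <DocumentId> --workers 1`):

```python
# Sketch: call the worker function for one document.
from src.cli.autolabel import process_single_document

row = {
    'DocumentId': '3be53fd7-d5ea-458c-a229-8d360b8ba6a9',
    'InvoiceNumber': '100017500321',
    'InvoiceDate': '2025-12-13',
    'InvoiceDueDate': '2026-01-03',
    'OCR': '100017500321',
    'Bankgiro': '53939484',
    'Plusgiro': '',
    'Amount': '114',
}
# (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr)
task = (row, 'data/raw_pdfs/3be53fd7-d5ea-458c-a229-8d360b8ba6a9.pdf',
        'data/dataset', 300, 0.7, False)
result = process_single_document(task)
print(result['success'], result['stats'])
```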
src/cli/infer.py (new file, 139 lines)
@@ -0,0 +1,139 @@
#!/usr/bin/env python3
"""
Inference CLI

Runs inference on new PDFs to extract invoice data.
"""

import argparse
import json
import sys
from pathlib import Path


def main():
    parser = argparse.ArgumentParser(
        description='Extract invoice data from PDFs using trained model'
    )
    parser.add_argument(
        '--model', '-m',
        required=True,
        help='Path to trained YOLO model (.pt file)'
    )
    parser.add_argument(
        '--input', '-i',
        required=True,
        help='Input PDF file or directory'
    )
    parser.add_argument(
        '--output', '-o',
        help='Output JSON file (default: stdout)'
    )
    parser.add_argument(
        '--confidence',
        type=float,
        default=0.5,
        help='Detection confidence threshold (default: 0.5)'
    )
    parser.add_argument(
        '--dpi',
        type=int,
        default=300,
        help='DPI for PDF rendering (default: 300)'
    )
    parser.add_argument(
        '--no-fallback',
        action='store_true',
        help='Disable fallback OCR'
    )
    parser.add_argument(
        '--lang',
        default='en',
        help='OCR language (default: en)'
    )
    parser.add_argument(
        '--gpu',
        action='store_true',
        help='Use GPU'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='Verbose output'
    )

    args = parser.parse_args()

    # Validate model
    model_path = Path(args.model)
    if not model_path.exists():
        print(f"Error: Model not found: {model_path}", file=sys.stderr)
        sys.exit(1)

    # Get input files
    input_path = Path(args.input)
    if input_path.is_file():
        pdf_files = [input_path]
    elif input_path.is_dir():
        pdf_files = list(input_path.glob('*.pdf'))
    else:
        print(f"Error: Input not found: {input_path}", file=sys.stderr)
        sys.exit(1)

    if not pdf_files:
        print("Error: No PDF files found", file=sys.stderr)
        sys.exit(1)

    if args.verbose:
        print(f"Processing {len(pdf_files)} PDF file(s)")
        print(f"Model: {model_path}")

    from ..inference import InferencePipeline

    # Initialize pipeline
    pipeline = InferencePipeline(
        model_path=model_path,
        confidence_threshold=args.confidence,
        ocr_lang=args.lang,
        use_gpu=args.gpu,
        dpi=args.dpi,
        enable_fallback=not args.no_fallback
    )

    # Process files
    results = []

    for pdf_path in pdf_files:
        if args.verbose:
            print(f"Processing: {pdf_path.name}")

        result = pipeline.process_pdf(pdf_path)
        results.append(result.to_json())

        if args.verbose:
            print(f"  Success: {result.success}")
            print(f"  Fields: {len(result.fields)}")
            if result.fallback_used:
                print(f"  Fallback used: Yes")
            if result.errors:
                print(f"  Errors: {result.errors}")

    # Output results
    if len(results) == 1:
        output = results[0]
    else:
        output = results

    json_output = json.dumps(output, indent=2, ensure_ascii=False)

    if args.output:
        with open(args.output, 'w', encoding='utf-8') as f:
            f.write(json_output)
        if args.verbose:
            print(f"\nResults written to: {args.output}")
    else:
        print(json_output)


if __name__ == '__main__':
    main()
src/cli/train.py (new file, 138 lines)
@@ -0,0 +1,138 @@
#!/usr/bin/env python3
"""
Training CLI

Trains YOLO model on generated dataset.
"""

import argparse
import sys
from pathlib import Path


def main():
    parser = argparse.ArgumentParser(
        description='Train YOLO model for invoice field detection'
    )
    parser.add_argument(
        '--data', '-d',
        required=True,
        help='Path to dataset.yaml file'
    )
    parser.add_argument(
        '--model', '-m',
        default='yolov8s.pt',
        help='Base model (default: yolov8s.pt)'
    )
    parser.add_argument(
        '--epochs', '-e',
        type=int,
        default=100,
        help='Number of epochs (default: 100)'
    )
    parser.add_argument(
        '--batch', '-b',
        type=int,
        default=16,
        help='Batch size (default: 16)'
    )
    parser.add_argument(
        '--imgsz',
        type=int,
        default=1280,
        help='Image size (default: 1280)'
    )
    parser.add_argument(
        '--project',
        default='runs/train',
        help='Project directory (default: runs/train)'
    )
    parser.add_argument(
        '--name',
        default='invoice_fields',
        help='Run name (default: invoice_fields)'
    )
    parser.add_argument(
        '--device',
        default='0',
        help='Device (0, 1, cpu, mps)'
    )
    parser.add_argument(
        '--resume',
        help='Resume from checkpoint'
    )
    parser.add_argument(
        '--config',
        help='Path to training config YAML'
    )

    args = parser.parse_args()

    # Validate data file
    data_path = Path(args.data)
    if not data_path.exists():
        print(f"Error: Dataset file not found: {data_path}")
        sys.exit(1)

    print(f"Training YOLO model for invoice field detection")
    print(f"Dataset: {args.data}")
    print(f"Model: {args.model}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.batch}")
    print(f"Image size: {args.imgsz}")

    from ultralytics import YOLO

    # Load model
    if args.resume:
        print(f"Resuming from: {args.resume}")
        model = YOLO(args.resume)
    else:
        model = YOLO(args.model)

    # Training arguments
    train_args = {
        'data': str(data_path.absolute()),
        'epochs': args.epochs,
        'batch': args.batch,
        'imgsz': args.imgsz,
        'project': args.project,
        'name': args.name,
        'device': args.device,
        'exist_ok': True,
        'pretrained': True,
        'verbose': True,
        # Document-specific augmentation settings
        'degrees': 5.0,
        'translate': 0.05,
        'scale': 0.2,
        'shear': 0.0,
        'perspective': 0.0,
        'flipud': 0.0,
        'fliplr': 0.0,
        'mosaic': 0.0,
        'mixup': 0.0,
        'hsv_h': 0.0,
        'hsv_s': 0.1,
        'hsv_v': 0.2,
    }

    # Train
    results = model.train(**train_args)

    # Print results
    print("\n" + "=" * 50)
    print("Training Complete")
    print("=" * 50)
    print(f"Best model: {args.project}/{args.name}/weights/best.pt")
    print(f"Last model: {args.project}/{args.name}/weights/last.pt")

    # Validate on test set
    print("\nRunning validation...")
    metrics = model.val()
    print(f"mAP50: {metrics.box.map50:.4f}")
    print(f"mAP50-95: {metrics.box.map:.4f}")


if __name__ == '__main__':
    main()
src/data/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .csv_loader import CSVLoader, InvoiceRow
from .autolabel_report import AutoLabelReport, FieldMatchResult

__all__ = ['CSVLoader', 'InvoiceRow', 'AutoLabelReport', 'FieldMatchResult']
src/data/autolabel_report.py (new file, 252 lines)
@@ -0,0 +1,252 @@
"""
Auto-Label Report Generator

Generates quality control reports for auto-labeling process.
"""

import json
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Any


@dataclass
class FieldMatchResult:
    """Result of matching a single field."""
    field_name: str
    csv_value: str | None
    matched: bool
    score: float = 0.0
    matched_text: str | None = None
    candidate_used: str | None = None  # Which normalized variant matched
    bbox: tuple[float, float, float, float] | None = None
    page_no: int = 0
    context_keywords: list[str] = field(default_factory=list)
    error: str | None = None

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        # Convert bbox to native Python floats to avoid numpy serialization issues
        bbox_list = None
        if self.bbox:
            bbox_list = [float(x) for x in self.bbox]

        return {
            'field_name': self.field_name,
            'csv_value': self.csv_value,
            'matched': self.matched,
            'score': float(self.score) if self.score else 0.0,
            'matched_text': self.matched_text,
            'candidate_used': self.candidate_used,
            'bbox': bbox_list,
            'page_no': int(self.page_no) if self.page_no else 0,
            'context_keywords': self.context_keywords,
            'error': self.error
        }


@dataclass
class AutoLabelReport:
    """Report for a single document's auto-labeling process."""
    document_id: str
    pdf_path: str | None = None
    pdf_type: str | None = None  # 'text' | 'scanned' | 'mixed'
    success: bool = False
    total_pages: int = 0
    fields_matched: int = 0
    fields_total: int = 0
    field_results: list[FieldMatchResult] = field(default_factory=list)
    annotations_generated: int = 0
    image_paths: list[str] = field(default_factory=list)
    label_paths: list[str] = field(default_factory=list)
    processing_time_ms: float = 0.0
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    errors: list[str] = field(default_factory=list)

    def add_field_result(self, result: FieldMatchResult) -> None:
        """Add a field matching result."""
        self.field_results.append(result)
        self.fields_total += 1
        if result.matched:
            self.fields_matched += 1

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            'document_id': self.document_id,
            'pdf_path': self.pdf_path,
            'pdf_type': self.pdf_type,
            'success': self.success,
            'total_pages': self.total_pages,
            'fields_matched': self.fields_matched,
            'fields_total': self.fields_total,
            'field_results': [r.to_dict() for r in self.field_results],
            'annotations_generated': self.annotations_generated,
            'image_paths': self.image_paths,
            'label_paths': self.label_paths,
            'processing_time_ms': self.processing_time_ms,
            'timestamp': self.timestamp,
            'errors': self.errors
        }

    def to_json(self, indent: int | None = None) -> str:
        """Convert to JSON string."""
        return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False)

    @property
    def match_rate(self) -> float:
        """Calculate field match rate."""
        if self.fields_total == 0:
            return 0.0
        return self.fields_matched / self.fields_total

    def get_summary(self) -> dict:
        """Get a summary of the report."""
        return {
            'document_id': self.document_id,
            'success': self.success,
            'match_rate': f"{self.match_rate:.1%}",
            'fields': f"{self.fields_matched}/{self.fields_total}",
            'annotations': self.annotations_generated,
            'errors': len(self.errors)
        }


class ReportWriter:
    """Writes auto-label reports to file."""

    def __init__(self, output_path: str | Path):
        """
        Initialize report writer.

        Args:
            output_path: Path to output JSONL file
        """
        self.output_path = Path(output_path)
        self.output_path.parent.mkdir(parents=True, exist_ok=True)

    def write(self, report: AutoLabelReport) -> None:
        """Append a report to the output file."""
        with open(self.output_path, 'a', encoding='utf-8') as f:
            f.write(report.to_json() + '\n')

    def write_dict(self, report_dict: dict) -> None:
        """Append a report dict to the output file (for parallel processing)."""
        import json
        with open(self.output_path, 'a', encoding='utf-8') as f:
            f.write(json.dumps(report_dict, ensure_ascii=False) + '\n')
            f.flush()

    def write_batch(self, reports: list[AutoLabelReport]) -> None:
        """Write multiple reports."""
        with open(self.output_path, 'a', encoding='utf-8') as f:
            for report in reports:
                f.write(report.to_json() + '\n')


class ReportReader:
    """Reads auto-label reports from file."""

    def __init__(self, input_path: str | Path):
        """
        Initialize report reader.

        Args:
            input_path: Path to input JSONL file
        """
        self.input_path = Path(input_path)

    def read_all(self) -> list[AutoLabelReport]:
        """Read all reports from file."""
        reports = []

        if not self.input_path.exists():
            return reports

        with open(self.input_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue

                data = json.loads(line)
                report = self._dict_to_report(data)
                reports.append(report)

        return reports

    def _dict_to_report(self, data: dict) -> AutoLabelReport:
        """Convert dictionary to AutoLabelReport."""
        field_results = []
        for fr_data in data.get('field_results', []):
            bbox = tuple(fr_data['bbox']) if fr_data.get('bbox') else None
            field_results.append(FieldMatchResult(
                field_name=fr_data['field_name'],
                csv_value=fr_data.get('csv_value'),
                matched=fr_data.get('matched', False),
                score=fr_data.get('score', 0.0),
                matched_text=fr_data.get('matched_text'),
                candidate_used=fr_data.get('candidate_used'),
                bbox=bbox,
                page_no=fr_data.get('page_no', 0),
                context_keywords=fr_data.get('context_keywords', []),
                error=fr_data.get('error')
            ))

        return AutoLabelReport(
            document_id=data['document_id'],
            pdf_path=data.get('pdf_path'),
            pdf_type=data.get('pdf_type'),
            success=data.get('success', False),
            total_pages=data.get('total_pages', 0),
            fields_matched=data.get('fields_matched', 0),
            fields_total=data.get('fields_total', 0),
            field_results=field_results,
            annotations_generated=data.get('annotations_generated', 0),
            image_paths=data.get('image_paths', []),
            label_paths=data.get('label_paths', []),
            processing_time_ms=data.get('processing_time_ms', 0.0),
            timestamp=data.get('timestamp', ''),
            errors=data.get('errors', [])
        )

    def get_statistics(self) -> dict:
        """Calculate statistics from all reports."""
        reports = self.read_all()

        if not reports:
            return {'total': 0}

        successful = sum(1 for r in reports if r.success)
        total_fields_matched = sum(r.fields_matched for r in reports)
        total_fields = sum(r.fields_total for r in reports)
        total_annotations = sum(r.annotations_generated for r in reports)

        # Per-field statistics
        field_stats = {}
        for report in reports:
            for fr in report.field_results:
                if fr.field_name not in field_stats:
                    field_stats[fr.field_name] = {'matched': 0, 'total': 0, 'avg_score': 0.0}
                field_stats[fr.field_name]['total'] += 1
                if fr.matched:
                    field_stats[fr.field_name]['matched'] += 1
                    field_stats[fr.field_name]['avg_score'] += fr.score

        # Calculate averages
        for field_name, stats in field_stats.items():
            if stats['matched'] > 0:
                stats['avg_score'] /= stats['matched']
            stats['match_rate'] = stats['matched'] / stats['total'] if stats['total'] > 0 else 0

        return {
|
||||||
|
'total': len(reports),
|
||||||
|
'successful': successful,
|
||||||
|
'success_rate': successful / len(reports),
|
||||||
|
'total_fields_matched': total_fields_matched,
|
||||||
|
'total_fields': total_fields,
|
||||||
|
'overall_match_rate': total_fields_matched / total_fields if total_fields > 0 else 0,
|
||||||
|
'total_annotations': total_annotations,
|
||||||
|
'field_statistics': field_stats
|
||||||
|
}
|
||||||
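Usage note (not part of this commit): `ReportWriter.write` appends one JSONL line per document, and `ReportReader.get_statistics` summarises a finished run. A minimal sketch, assuming reports were written to a `reports/auto_label.jsonl` path (the path itself is an assumption):

```python
reader = ReportReader('reports/auto_label.jsonl')
stats = reader.get_statistics()

print(f"documents processed: {stats['total']}")
if stats['total']:
    print(f"overall match rate: {stats['overall_match_rate']:.1%}")
    for name, s in stats['field_statistics'].items():
        print(f"  {name}: {s['matched']}/{s['total']} ({s['match_rate']:.1%})")
```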
306
src/data/csv_loader.py
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
"""
|
||||||
|
CSV Data Loader
|
||||||
|
|
||||||
|
Loads and parses structured invoice data from CSV files.
|
||||||
|
Follows the CSV specification for invoice data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, date
|
||||||
|
from decimal import Decimal, InvalidOperation
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Iterator
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InvoiceRow:
|
||||||
|
"""Parsed invoice data row."""
|
||||||
|
DocumentId: str
|
||||||
|
InvoiceDate: date | None = None
|
||||||
|
InvoiceNumber: str | None = None
|
||||||
|
InvoiceDueDate: date | None = None
|
||||||
|
OCR: str | None = None
|
||||||
|
Message: str | None = None
|
||||||
|
Bankgiro: str | None = None
|
||||||
|
Plusgiro: str | None = None
|
||||||
|
Amount: Decimal | None = None
|
||||||
|
|
||||||
|
# Raw values for reference
|
||||||
|
raw_data: dict = field(default_factory=dict)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""Convert to dictionary for matching."""
|
||||||
|
return {
|
||||||
|
'DocumentId': self.DocumentId,
|
||||||
|
'InvoiceDate': self.InvoiceDate.isoformat() if self.InvoiceDate else None,
|
||||||
|
'InvoiceNumber': self.InvoiceNumber,
|
||||||
|
'InvoiceDueDate': self.InvoiceDueDate.isoformat() if self.InvoiceDueDate else None,
|
||||||
|
'OCR': self.OCR,
|
||||||
|
'Bankgiro': self.Bankgiro,
|
||||||
|
'Plusgiro': self.Plusgiro,
|
||||||
|
'Amount': str(self.Amount) if self.Amount is not None else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_field_value(self, field_name: str) -> str | None:
|
||||||
|
"""Get field value as string for matching."""
|
||||||
|
value = getattr(self, field_name, None)
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, date):
|
||||||
|
return value.isoformat()
|
||||||
|
if isinstance(value, Decimal):
|
||||||
|
return str(value)
|
||||||
|
return str(value) if value else None
|
||||||
|
|
||||||
|
|
||||||
|
class CSVLoader:
|
||||||
|
"""Loads invoice data from CSV files."""
|
||||||
|
|
||||||
|
# Expected field mappings (CSV header -> InvoiceRow attribute)
|
||||||
|
FIELD_MAPPINGS = {
|
||||||
|
'DocumentId': 'DocumentId',
|
||||||
|
'InvoiceDate': 'InvoiceDate',
|
||||||
|
'InvoiceNumber': 'InvoiceNumber',
|
||||||
|
'InvoiceDueDate': 'InvoiceDueDate',
|
||||||
|
'OCR': 'OCR',
|
||||||
|
'Message': 'Message',
|
||||||
|
'Bankgiro': 'Bankgiro',
|
||||||
|
'Plusgiro': 'Plusgiro',
|
||||||
|
'Amount': 'Amount',
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
csv_path: str | Path,
|
||||||
|
pdf_dir: str | Path | None = None,
|
||||||
|
doc_map_path: str | Path | None = None,
|
||||||
|
encoding: str = 'utf-8'
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize CSV loader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
csv_path: Path to the CSV file
|
||||||
|
pdf_dir: Directory containing PDF files (default: data/raw_pdfs)
|
||||||
|
doc_map_path: Optional path to document mapping CSV
|
||||||
|
encoding: CSV file encoding (default: utf-8)
|
||||||
|
"""
|
||||||
|
self.csv_path = Path(csv_path)
|
||||||
|
self.pdf_dir = Path(pdf_dir) if pdf_dir else self.csv_path.parent.parent / 'raw_pdfs'
|
||||||
|
self.doc_map_path = Path(doc_map_path) if doc_map_path else None
|
||||||
|
self.encoding = encoding
|
||||||
|
|
||||||
|
# Load document mapping if provided
|
||||||
|
self.doc_map = self._load_doc_map() if self.doc_map_path else {}
|
||||||
|
|
||||||
|
def _load_doc_map(self) -> dict[str, str]:
|
||||||
|
"""Load document ID to filename mapping."""
|
||||||
|
mapping = {}
|
||||||
|
if self.doc_map_path and self.doc_map_path.exists():
|
||||||
|
with open(self.doc_map_path, 'r', encoding=self.encoding) as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
doc_id = row.get('DocumentId', '').strip()
|
||||||
|
filename = row.get('FileName', '').strip()
|
||||||
|
if doc_id and filename:
|
||||||
|
mapping[doc_id] = filename
|
||||||
|
return mapping
|
||||||
|
|
||||||
|
def _parse_date(self, value: str | None) -> date | None:
|
||||||
|
"""Parse date from various formats."""
|
||||||
|
if not value or not value.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
# Try different date formats
|
||||||
|
formats = [
|
||||||
|
'%Y-%m-%d',
|
||||||
|
'%Y-%m-%d %H:%M:%S',
|
||||||
|
'%Y-%m-%d %H:%M:%S.%f',
|
||||||
|
'%d/%m/%Y',
|
||||||
|
'%d.%m.%Y',
|
||||||
|
'%d-%m-%Y',
|
||||||
|
'%Y%m%d',
|
||||||
|
]
|
||||||
|
|
||||||
|
for fmt in formats:
|
||||||
|
try:
|
||||||
|
return datetime.strptime(value, fmt).date()
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _parse_amount(self, value: str | None) -> Decimal | None:
|
||||||
|
"""Parse monetary amount from various formats."""
|
||||||
|
if not value or not value.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
# Remove currency symbols and common suffixes
|
||||||
|
value = value.replace('SEK', '').replace('kr', '').replace(':-', '')
|
||||||
|
value = value.strip()
|
||||||
|
|
||||||
|
# Remove spaces (thousand separators)
|
||||||
|
value = value.replace(' ', '').replace('\xa0', '')
|
||||||
|
|
||||||
|
# Handle comma as decimal separator (European format)
|
||||||
|
if ',' in value and '.' not in value:
|
||||||
|
value = value.replace(',', '.')
|
||||||
|
elif ',' in value and '.' in value:
|
||||||
|
# Assume comma is thousands separator, dot is decimal
|
||||||
|
value = value.replace(',', '')
|
||||||
|
|
||||||
|
try:
|
||||||
|
return Decimal(value)
|
||||||
|
except InvalidOperation:
|
||||||
|
return None
|
||||||
|
|
||||||
|
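Illustration (not part of this commit) of the amount-parsing rules above: currency markers are stripped, spaces act as thousand separators, and a lone comma is read as the decimal point. A small sketch, assuming the loader module is importable; the CSV path is an assumption and is never read by `_parse_amount`:

```python
from decimal import Decimal

# Illustrative inputs -> expected Decimal values under the rules above.
examples = {
    '1 234,56 kr': Decimal('1234.56'),   # space as thousand separator, comma as decimal
    '1,234.56 SEK': Decimal('1234.56'),  # comma as thousand separator, dot as decimal
    '999:-': Decimal('999'),             # Swedish ":-" suffix
}

loader = CSVLoader('data/structured/invoices.csv')  # path is an assumption; not read here
for raw, expected in examples.items():
    assert loader._parse_amount(raw) == expected
```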
def _parse_string(self, value: str | None) -> str | None:
|
||||||
|
"""Parse string field with cleanup."""
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
value = value.strip()
|
||||||
|
return value if value else None
|
||||||
|
|
||||||
|
def _parse_row(self, row: dict) -> InvoiceRow | None:
|
||||||
|
"""Parse a single CSV row into InvoiceRow."""
|
||||||
|
doc_id = self._parse_string(row.get('DocumentId'))
|
||||||
|
if not doc_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return InvoiceRow(
|
||||||
|
DocumentId=doc_id,
|
||||||
|
InvoiceDate=self._parse_date(row.get('InvoiceDate')),
|
||||||
|
InvoiceNumber=self._parse_string(row.get('InvoiceNumber')),
|
||||||
|
InvoiceDueDate=self._parse_date(row.get('InvoiceDueDate')),
|
||||||
|
OCR=self._parse_string(row.get('OCR')),
|
||||||
|
Message=self._parse_string(row.get('Message')),
|
||||||
|
Bankgiro=self._parse_string(row.get('Bankgiro')),
|
||||||
|
Plusgiro=self._parse_string(row.get('Plusgiro')),
|
||||||
|
Amount=self._parse_amount(row.get('Amount')),
|
||||||
|
raw_data=dict(row)
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_all(self) -> list[InvoiceRow]:
|
||||||
|
"""Load all rows from CSV."""
|
||||||
|
rows = []
|
||||||
|
for row in self.iter_rows():
|
||||||
|
rows.append(row)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
def iter_rows(self) -> Iterator[InvoiceRow]:
|
||||||
|
"""Iterate over CSV rows."""
|
||||||
|
# Handle BOM - try utf-8-sig first to handle BOM correctly
|
||||||
|
encodings = ['utf-8-sig', self.encoding, 'latin-1']
|
||||||
|
|
||||||
|
for enc in encodings:
|
||||||
|
try:
|
||||||
|
with open(self.csv_path, 'r', encoding=enc) as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
parsed = self._parse_row(row)
|
||||||
|
if parsed:
|
||||||
|
yield parsed
|
||||||
|
return
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ValueError(f"Could not read CSV file with any supported encoding")
|
||||||
|
|
||||||
|
def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None:
|
||||||
|
"""
|
||||||
|
Get PDF path for an invoice row.
|
||||||
|
|
||||||
|
Uses document mapping if available, otherwise assumes
|
||||||
|
DocumentId.pdf naming convention.
|
||||||
|
"""
|
||||||
|
doc_id = invoice_row.DocumentId
|
||||||
|
|
||||||
|
# Check document mapping first
|
||||||
|
if doc_id in self.doc_map:
|
||||||
|
filename = self.doc_map[doc_id]
|
||||||
|
pdf_path = self.pdf_dir / filename
|
||||||
|
if pdf_path.exists():
|
||||||
|
return pdf_path
|
||||||
|
|
||||||
|
# Try default naming patterns
|
||||||
|
patterns = [
|
||||||
|
f"{doc_id}.pdf",
|
||||||
|
f"{doc_id.lower()}.pdf",
|
||||||
|
f"{doc_id.upper()}.pdf",
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in patterns:
|
||||||
|
pdf_path = self.pdf_dir / pattern
|
||||||
|
if pdf_path.exists():
|
||||||
|
return pdf_path
|
||||||
|
|
||||||
|
# Try glob patterns for partial matches
|
||||||
|
for pdf_file in self.pdf_dir.glob(f"*{doc_id}*.pdf"):
|
||||||
|
return pdf_file
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_row_by_id(self, doc_id: str) -> InvoiceRow | None:
|
||||||
|
"""Get a specific row by DocumentId."""
|
||||||
|
for row in self.iter_rows():
|
||||||
|
if row.DocumentId == doc_id:
|
||||||
|
return row
|
||||||
|
return None
|
||||||
|
|
||||||
|
def validate(self) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Validate CSV data and return issues.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of validation issues
|
||||||
|
"""
|
||||||
|
issues = []
|
||||||
|
|
||||||
|
for i, row in enumerate(self.iter_rows(), start=2): # Start at 2 (header is row 1)
|
||||||
|
# Check required DocumentId
|
||||||
|
if not row.DocumentId:
|
||||||
|
issues.append({
|
||||||
|
'row': i,
|
||||||
|
'field': 'DocumentId',
|
||||||
|
'issue': 'Missing required DocumentId'
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if PDF exists
|
||||||
|
pdf_path = self.get_pdf_path(row)
|
||||||
|
if not pdf_path:
|
||||||
|
issues.append({
|
||||||
|
'row': i,
|
||||||
|
'doc_id': row.DocumentId,
|
||||||
|
'field': 'PDF',
|
||||||
|
'issue': 'PDF file not found'
|
||||||
|
})
|
||||||
|
|
||||||
|
# Check for at least one matchable field
|
||||||
|
matchable_fields = [
|
||||||
|
row.InvoiceNumber,
|
||||||
|
row.OCR,
|
||||||
|
row.Bankgiro,
|
||||||
|
row.Plusgiro,
|
||||||
|
row.Amount
|
||||||
|
]
|
||||||
|
if not any(matchable_fields):
|
||||||
|
issues.append({
|
||||||
|
'row': i,
|
||||||
|
'doc_id': row.DocumentId,
|
||||||
|
'field': 'All',
|
||||||
|
'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount)'
|
||||||
|
})
|
||||||
|
|
||||||
|
return issues
|
||||||
|
|
||||||
|
|
||||||
|
def load_invoice_csv(csv_path: str | Path, pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
|
||||||
|
"""Convenience function to load invoice CSV."""
|
||||||
|
loader = CSVLoader(csv_path, pdf_dir)
|
||||||
|
return loader.load_all()
|
||||||
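Usage sketch (not part of this commit) tying the loader together: load rows, surface validation issues, and resolve each DocumentId to its PDF. The paths are assumptions:

```python
loader = CSVLoader('data/structured/invoices.csv', pdf_dir='data/raw_pdfs')

for issue in loader.validate():
    print('validation issue:', issue)

for row in loader.iter_rows():
    pdf_path = loader.get_pdf_path(row)
    if pdf_path is not None:
        print(row.DocumentId, row.get_field_value('Amount'), '->', pdf_path.name)
```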
5
src/inference/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .pipeline import InferencePipeline, InferenceResult
from .yolo_detector import YOLODetector, Detection
from .field_extractor import FieldExtractor

__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor']
382
src/inference/field_extractor.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
"""
|
||||||
|
Field Extractor Module
|
||||||
|
|
||||||
|
Extracts and validates field values from detected regions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
import re
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from .yolo_detector import Detection, CLASS_TO_FIELD
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExtractedField:
|
||||||
|
"""Represents an extracted field value."""
|
||||||
|
field_name: str
|
||||||
|
raw_text: str
|
||||||
|
normalized_value: str | None
|
||||||
|
confidence: float
|
||||||
|
detection_confidence: float
|
||||||
|
ocr_confidence: float
|
||||||
|
bbox: tuple[float, float, float, float]
|
||||||
|
page_no: int
|
||||||
|
is_valid: bool = True
|
||||||
|
validation_error: str | None = None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
"""Convert to dictionary."""
|
||||||
|
return {
|
||||||
|
'field_name': self.field_name,
|
||||||
|
'value': self.normalized_value,
|
||||||
|
'raw_text': self.raw_text,
|
||||||
|
'confidence': self.confidence,
|
||||||
|
'bbox': list(self.bbox),
|
||||||
|
'page_no': self.page_no,
|
||||||
|
'is_valid': self.is_valid,
|
||||||
|
'validation_error': self.validation_error
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class FieldExtractor:
|
||||||
|
"""Extracts field values from detected regions using OCR or PDF text."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
ocr_lang: str = 'en',
|
||||||
|
use_gpu: bool = False,
|
||||||
|
bbox_padding: float = 0.1,
|
||||||
|
dpi: int = 300
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize field extractor.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ocr_lang: Language for OCR
|
||||||
|
use_gpu: Whether to use GPU for OCR
|
||||||
|
bbox_padding: Padding to add around bboxes (as fraction)
|
||||||
|
dpi: DPI used for rendering (for coordinate conversion)
|
||||||
|
"""
|
||||||
|
self.ocr_lang = ocr_lang
|
||||||
|
self.use_gpu = use_gpu
|
||||||
|
self.bbox_padding = bbox_padding
|
||||||
|
self.dpi = dpi
|
||||||
|
self._ocr_engine = None # Lazy init
|
||||||
|
|
||||||
|
@property
|
||||||
|
def ocr_engine(self):
|
||||||
|
"""Lazy-load OCR engine only when needed."""
|
||||||
|
if self._ocr_engine is None:
|
||||||
|
from ..ocr import OCREngine
|
||||||
|
self._ocr_engine = OCREngine(lang=self.ocr_lang, use_gpu=self.use_gpu)
|
||||||
|
return self._ocr_engine
|
||||||
|
|
||||||
|
def extract_from_detection_with_pdf(
|
||||||
|
self,
|
||||||
|
detection: Detection,
|
||||||
|
pdf_tokens: list,
|
||||||
|
image_width: int,
|
||||||
|
image_height: int
|
||||||
|
) -> ExtractedField:
|
||||||
|
"""
|
||||||
|
Extract field value using PDF text tokens (faster and more accurate for text PDFs).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detection: Detection object with bbox in pixel coordinates
|
||||||
|
pdf_tokens: List of Token objects from PDF text extraction
|
||||||
|
image_width: Width of rendered image in pixels
|
||||||
|
image_height: Height of rendered image in pixels
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExtractedField object
|
||||||
|
"""
|
||||||
|
# Convert detection bbox from pixels to PDF points
|
||||||
|
scale = 72 / self.dpi # points per pixel
|
||||||
|
x0_pdf = detection.bbox[0] * scale
|
||||||
|
y0_pdf = detection.bbox[1] * scale
|
||||||
|
x1_pdf = detection.bbox[2] * scale
|
||||||
|
y1_pdf = detection.bbox[3] * scale
|
||||||
|
|
||||||
|
# Add padding in points
|
||||||
|
pad = 3 # Small padding in points
|
||||||
|
|
||||||
|
# Find tokens that overlap with detection bbox
|
||||||
|
matching_tokens = []
|
||||||
|
for token in pdf_tokens:
|
||||||
|
if token.page_no != detection.page_no:
|
||||||
|
continue
|
||||||
|
tx0, ty0, tx1, ty1 = token.bbox
|
||||||
|
# Check overlap
|
||||||
|
if (tx0 < x1_pdf + pad and tx1 > x0_pdf - pad and
|
||||||
|
ty0 < y1_pdf + pad and ty1 > y0_pdf - pad):
|
||||||
|
# Calculate overlap ratio to prioritize better matches
|
||||||
|
overlap_x = min(tx1, x1_pdf) - max(tx0, x0_pdf)
|
||||||
|
overlap_y = min(ty1, y1_pdf) - max(ty0, y0_pdf)
|
||||||
|
if overlap_x > 0 and overlap_y > 0:
|
||||||
|
token_area = (tx1 - tx0) * (ty1 - ty0)
|
||||||
|
overlap_area = overlap_x * overlap_y
|
||||||
|
overlap_ratio = overlap_area / token_area if token_area > 0 else 0
|
||||||
|
matching_tokens.append((token, overlap_ratio))
|
||||||
|
|
||||||
|
# Sort by overlap ratio and combine text
|
||||||
|
matching_tokens.sort(key=lambda x: -x[1])
|
||||||
|
raw_text = ' '.join(t[0].text for t in matching_tokens)
|
||||||
|
|
||||||
|
# Get field name
|
||||||
|
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
|
||||||
|
|
||||||
|
# Normalize and validate
|
||||||
|
normalized_value, is_valid, validation_error = self._normalize_and_validate(
|
||||||
|
field_name, raw_text
|
||||||
|
)
|
||||||
|
|
||||||
|
return ExtractedField(
|
||||||
|
field_name=field_name,
|
||||||
|
raw_text=raw_text,
|
||||||
|
normalized_value=normalized_value,
|
||||||
|
confidence=detection.confidence if normalized_value else detection.confidence * 0.5,
|
||||||
|
detection_confidence=detection.confidence,
|
||||||
|
ocr_confidence=1.0, # PDF text is always accurate
|
||||||
|
bbox=detection.bbox,
|
||||||
|
page_no=detection.page_no,
|
||||||
|
is_valid=is_valid,
|
||||||
|
validation_error=validation_error
|
||||||
|
)
|
||||||
|
|
||||||
|
def extract_from_detection(
|
||||||
|
self,
|
||||||
|
detection: Detection,
|
||||||
|
image: np.ndarray | Image.Image
|
||||||
|
) -> ExtractedField:
|
||||||
|
"""
|
||||||
|
Extract field value from a detection region using OCR.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detection: Detection object
|
||||||
|
image: Full page image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExtractedField object
|
||||||
|
"""
|
||||||
|
if isinstance(image, Image.Image):
|
||||||
|
image = np.array(image)
|
||||||
|
|
||||||
|
# Get padded bbox
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
bbox = detection.get_padded_bbox(self.bbox_padding, w, h)
|
||||||
|
|
||||||
|
# Crop region
|
||||||
|
x0, y0, x1, y1 = [int(v) for v in bbox]
|
||||||
|
region = image[y0:y1, x0:x1]
|
||||||
|
|
||||||
|
# Run OCR on region
|
||||||
|
ocr_tokens = self.ocr_engine.extract_from_image(region)
|
||||||
|
|
||||||
|
# Combine all OCR text
|
||||||
|
raw_text = ' '.join(t.text for t in ocr_tokens)
|
||||||
|
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
|
||||||
|
|
||||||
|
# Get field name
|
||||||
|
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
|
||||||
|
|
||||||
|
# Normalize and validate
|
||||||
|
normalized_value, is_valid, validation_error = self._normalize_and_validate(
|
||||||
|
field_name, raw_text
|
||||||
|
)
|
||||||
|
|
||||||
|
# Combined confidence
|
||||||
|
confidence = (detection.confidence + ocr_confidence) / 2 if ocr_tokens else detection.confidence * 0.5
|
||||||
|
|
||||||
|
return ExtractedField(
|
||||||
|
field_name=field_name,
|
||||||
|
raw_text=raw_text,
|
||||||
|
normalized_value=normalized_value,
|
||||||
|
confidence=confidence,
|
||||||
|
detection_confidence=detection.confidence,
|
||||||
|
ocr_confidence=ocr_confidence,
|
||||||
|
bbox=detection.bbox,
|
||||||
|
page_no=detection.page_no,
|
||||||
|
is_valid=is_valid,
|
||||||
|
validation_error=validation_error
|
||||||
|
)
|
||||||
|
|
||||||
|
def _normalize_and_validate(
|
||||||
|
self,
|
||||||
|
field_name: str,
|
||||||
|
raw_text: str
|
||||||
|
) -> tuple[str | None, bool, str | None]:
|
||||||
|
"""
|
||||||
|
Normalize and validate extracted text for a field.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(normalized_value, is_valid, validation_error)
|
||||||
|
"""
|
||||||
|
text = raw_text.strip()
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return None, False, "Empty text"
|
||||||
|
|
||||||
|
if field_name == 'InvoiceNumber':
|
||||||
|
return self._normalize_invoice_number(text)
|
||||||
|
|
||||||
|
elif field_name == 'OCR':
|
||||||
|
return self._normalize_ocr_number(text)
|
||||||
|
|
||||||
|
elif field_name == 'Bankgiro':
|
||||||
|
return self._normalize_bankgiro(text)
|
||||||
|
|
||||||
|
elif field_name == 'Plusgiro':
|
||||||
|
return self._normalize_plusgiro(text)
|
||||||
|
|
||||||
|
elif field_name == 'Amount':
|
||||||
|
return self._normalize_amount(text)
|
||||||
|
|
||||||
|
elif field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||||
|
return self._normalize_date(text)
|
||||||
|
|
||||||
|
else:
|
||||||
|
return text, True, None
|
||||||
|
|
||||||
|
def _normalize_invoice_number(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||||
|
"""Normalize invoice number."""
|
||||||
|
# Extract digits only
|
||||||
|
digits = re.sub(r'\D', '', text)
|
||||||
|
|
||||||
|
if len(digits) < 3:
|
||||||
|
return None, False, f"Too few digits: {len(digits)}"
|
||||||
|
|
||||||
|
return digits, True, None
|
||||||
|
|
||||||
|
def _normalize_ocr_number(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||||
|
"""Normalize OCR number."""
|
||||||
|
digits = re.sub(r'\D', '', text)
|
||||||
|
|
||||||
|
if len(digits) < 5:
|
||||||
|
return None, False, f"Too few digits for OCR: {len(digits)}"
|
||||||
|
|
||||||
|
return digits, True, None
|
||||||
|
|
||||||
|
def _normalize_bankgiro(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||||
|
"""Normalize Bankgiro number."""
|
||||||
|
digits = re.sub(r'\D', '', text)
|
||||||
|
|
||||||
|
if len(digits) == 8:
|
||||||
|
# Format as XXXX-XXXX
|
||||||
|
formatted = f"{digits[:4]}-{digits[4:]}"
|
||||||
|
return formatted, True, None
|
||||||
|
elif len(digits) == 7:
|
||||||
|
# Format as XXX-XXXX
|
||||||
|
formatted = f"{digits[:3]}-{digits[3:]}"
|
||||||
|
return formatted, True, None
|
||||||
|
elif 6 <= len(digits) <= 9:
|
||||||
|
return digits, True, None
|
||||||
|
else:
|
||||||
|
return None, False, f"Invalid Bankgiro length: {len(digits)}"
|
||||||
|
|
||||||
|
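Illustration (not part of this commit) of the Bankgiro normalisation rules above, using made-up inputs; the OCR engine is lazy, so constructing the extractor stays cheap:

```python
extractor = FieldExtractor()  # OCR engine is lazy-loaded, never touched here

assert extractor._normalize_bankgiro('BG 5393-9484')[0] == '5393-9484'  # 8 digits -> XXXX-XXXX
assert extractor._normalize_bankgiro('123 4567')[0] == '123-4567'       # 7 digits -> XXX-XXXX
assert extractor._normalize_bankgiro('12345')[0] is None                # too short -> invalid
```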
def _normalize_plusgiro(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||||
|
"""Normalize Plusgiro number."""
|
||||||
|
digits = re.sub(r'\D', '', text)
|
||||||
|
|
||||||
|
if len(digits) >= 6:
|
||||||
|
# Format as XXXXXXX-X
|
||||||
|
formatted = f"{digits[:-1]}-{digits[-1]}"
|
||||||
|
return formatted, True, None
|
||||||
|
else:
|
||||||
|
return None, False, f"Invalid Plusgiro length: {len(digits)}"
|
||||||
|
|
||||||
|
def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||||
|
"""Normalize monetary amount."""
|
||||||
|
# Remove currency and common suffixes
|
||||||
|
text = re.sub(r'(?:SEK|kr|:-)+', '', text, flags=re.IGNORECASE)
|
||||||
|
text = text.replace(' ', '').replace('\xa0', '')
|
||||||
|
|
||||||
|
# Handle comma as decimal separator
|
||||||
|
if ',' in text and '.' not in text:
|
||||||
|
text = text.replace(',', '.')
|
||||||
|
|
||||||
|
# Try to parse as float
|
||||||
|
try:
|
||||||
|
amount = float(text)
|
||||||
|
return f"{amount:.2f}", True, None
|
||||||
|
except ValueError:
|
||||||
|
return None, False, f"Cannot parse amount: {text}"
|
||||||
|
|
||||||
|
def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||||
|
"""Normalize date."""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Common date patterns
|
||||||
|
patterns = [
|
||||||
|
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m[1]}-{int(m[2]):02d}-{int(m[3]):02d}"),
|
||||||
|
(r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"),
|
||||||
|
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"),
|
||||||
|
(r'(\d{4})(\d{2})(\d{2})', lambda m: f"{m[1]}-{m[2]}-{m[3]}"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern, formatter in patterns:
|
||||||
|
match = re.search(pattern, text)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
date_str = formatter(match)
|
||||||
|
# Validate date
|
||||||
|
datetime.strptime(date_str, '%Y-%m-%d')
|
||||||
|
return date_str, True, None
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
return None, False, f"Cannot parse date: {text}"
|
||||||
|
|
||||||
|
def extract_all_fields(
|
||||||
|
self,
|
||||||
|
detections: list[Detection],
|
||||||
|
image: np.ndarray | Image.Image
|
||||||
|
) -> list[ExtractedField]:
|
||||||
|
"""
|
||||||
|
Extract fields from all detections.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detections: List of detections
|
||||||
|
image: Full page image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ExtractedField objects
|
||||||
|
"""
|
||||||
|
fields = []
|
||||||
|
|
||||||
|
for detection in detections:
|
||||||
|
field = self.extract_from_detection(detection, image)
|
||||||
|
fields.append(field)
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
Infer OCR field from InvoiceNumber if not detected.
|
||||||
|
|
||||||
|
In Swedish invoices, OCR reference number is often identical to InvoiceNumber.
|
||||||
|
When OCR is not detected but InvoiceNumber is, we can infer OCR value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fields: Dict of field_name -> normalized_value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Updated fields dict with inferred OCR if applicable
|
||||||
|
"""
|
||||||
|
# If OCR already exists, no need to infer
|
||||||
|
if fields.get('OCR'):
|
||||||
|
return fields
|
||||||
|
|
||||||
|
# If InvoiceNumber exists and is numeric, use it as OCR
|
||||||
|
invoice_number = fields.get('InvoiceNumber')
|
||||||
|
if invoice_number:
|
||||||
|
# Check if it's mostly digits (valid OCR reference)
|
||||||
|
digits_only = re.sub(r'\D', '', invoice_number)
|
||||||
|
if len(digits_only) >= 5 and len(digits_only) == len(invoice_number):
|
||||||
|
fields['OCR'] = invoice_number
|
||||||
|
|
||||||
|
return fields
|
||||||
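Illustration (not part of this commit) of the OCR-inference helper above: a purely numeric InvoiceNumber is copied into the missing OCR field, while a mixed alphanumeric value is left alone. Values are made up:

```python
fields = {'InvoiceNumber': '1234567', 'Amount': '599.00'}
fields = FieldExtractor.infer_ocr_from_invoice_number(fields)
assert fields['OCR'] == '1234567'   # all-digit invoice number is reused as OCR

fields = FieldExtractor.infer_ocr_from_invoice_number({'InvoiceNumber': 'INV-2024-17'})
assert 'OCR' not in fields          # mixed alphanumeric value is not inferred
```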
297
src/inference/pipeline.py
Normal file
@@ -0,0 +1,297 @@
|
|||||||
|
"""
|
||||||
|
Inference Pipeline
|
||||||
|
|
||||||
|
Complete pipeline for extracting invoice data from PDFs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
|
||||||
|
from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD
|
||||||
|
from .field_extractor import FieldExtractor, ExtractedField
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceResult:
|
||||||
|
"""Result of invoice processing."""
|
||||||
|
document_id: str | None = None
|
||||||
|
success: bool = False
|
||||||
|
fields: dict[str, Any] = field(default_factory=dict)
|
||||||
|
confidence: dict[str, float] = field(default_factory=dict)
|
||||||
|
raw_detections: list[Detection] = field(default_factory=list)
|
||||||
|
extracted_fields: list[ExtractedField] = field(default_factory=list)
|
||||||
|
processing_time_ms: float = 0.0
|
||||||
|
errors: list[str] = field(default_factory=list)
|
||||||
|
fallback_used: bool = False
|
||||||
|
|
||||||
|
def to_json(self) -> dict:
|
||||||
|
"""Convert to JSON-serializable dictionary."""
|
||||||
|
return {
|
||||||
|
'DocumentId': self.document_id,
|
||||||
|
'InvoiceNumber': self.fields.get('InvoiceNumber'),
|
||||||
|
'InvoiceDate': self.fields.get('InvoiceDate'),
|
||||||
|
'InvoiceDueDate': self.fields.get('InvoiceDueDate'),
|
||||||
|
'OCR': self.fields.get('OCR'),
|
||||||
|
'Bankgiro': self.fields.get('Bankgiro'),
|
||||||
|
'Plusgiro': self.fields.get('Plusgiro'),
|
||||||
|
'Amount': self.fields.get('Amount'),
|
||||||
|
'confidence': self.confidence,
|
||||||
|
'success': self.success,
|
||||||
|
'fallback_used': self.fallback_used
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_field(self, field_name: str) -> tuple[Any, float]:
|
||||||
|
"""Get field value and confidence."""
|
||||||
|
return self.fields.get(field_name), self.confidence.get(field_name, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
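Sketch (not part of this commit) of how a caller might consume an InferenceResult via `get_field` and `to_json`; the helper name is hypothetical:

```python
# print_result is a hypothetical helper, not part of this module.
def print_result(result: InferenceResult) -> None:
    amount, confidence = result.get_field('Amount')
    print(f"Amount={amount} (confidence {confidence:.2f})")
    print(result.to_json())  # JSON-serialisable dict with all seven field keys plus metadata
```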
class InferencePipeline:
|
||||||
|
"""
|
||||||
|
Complete inference pipeline for invoice data extraction.
|
||||||
|
|
||||||
|
Pipeline flow:
|
||||||
|
1. PDF -> Image rendering
|
||||||
|
2. YOLO detection of field regions
|
||||||
|
3. OCR extraction from detected regions
|
||||||
|
4. Field normalization and validation
|
||||||
|
5. Fallback to full-page OCR if YOLO fails
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str | Path,
|
||||||
|
confidence_threshold: float = 0.5,
|
||||||
|
ocr_lang: str = 'en',
|
||||||
|
use_gpu: bool = False,
|
||||||
|
dpi: int = 300,
|
||||||
|
enable_fallback: bool = True
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize inference pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to trained YOLO model
|
||||||
|
confidence_threshold: Detection confidence threshold
|
||||||
|
ocr_lang: Language for OCR
|
||||||
|
use_gpu: Whether to use GPU
|
||||||
|
dpi: Resolution for PDF rendering
|
||||||
|
enable_fallback: Enable fallback to full-page OCR
|
||||||
|
"""
|
||||||
|
self.detector = YOLODetector(
|
||||||
|
model_path,
|
||||||
|
confidence_threshold=confidence_threshold,
|
||||||
|
device='cuda' if use_gpu else 'cpu'
|
||||||
|
)
|
||||||
|
self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu)
|
||||||
|
self.dpi = dpi
|
||||||
|
self.enable_fallback = enable_fallback
|
||||||
|
|
||||||
|
def process_pdf(
|
||||||
|
self,
|
||||||
|
pdf_path: str | Path,
|
||||||
|
document_id: str | None = None
|
||||||
|
) -> InferenceResult:
|
||||||
|
"""
|
||||||
|
Process a PDF and extract invoice fields.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_path: Path to PDF file
|
||||||
|
document_id: Optional document ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InferenceResult with extracted fields
|
||||||
|
"""
|
||||||
|
from ..pdf.renderer import render_pdf_to_images
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
result = InferenceResult(
|
||||||
|
document_id=document_id or Path(pdf_path).stem
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
all_detections = []
|
||||||
|
all_extracted = []
|
||||||
|
|
||||||
|
# Process each page
|
||||||
|
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
||||||
|
# Convert to numpy array
|
||||||
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
|
image_array = np.array(image)
|
||||||
|
|
||||||
|
# Run YOLO detection
|
||||||
|
detections = self.detector.detect(image_array, page_no=page_no)
|
||||||
|
all_detections.extend(detections)
|
||||||
|
|
||||||
|
# Extract fields from detections
|
||||||
|
for detection in detections:
|
||||||
|
extracted = self.extractor.extract_from_detection(detection, image_array)
|
||||||
|
all_extracted.append(extracted)
|
||||||
|
|
||||||
|
result.raw_detections = all_detections
|
||||||
|
result.extracted_fields = all_extracted
|
||||||
|
|
||||||
|
# Merge extracted fields (prefer highest confidence)
|
||||||
|
self._merge_fields(result)
|
||||||
|
|
||||||
|
# Fallback if key fields are missing
|
||||||
|
if self.enable_fallback and self._needs_fallback(result):
|
||||||
|
self._run_fallback(pdf_path, result)
|
||||||
|
|
||||||
|
result.success = len(result.fields) > 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(str(e))
|
||||||
|
result.success = False
|
||||||
|
|
||||||
|
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _merge_fields(self, result: InferenceResult) -> None:
|
||||||
|
"""Merge extracted fields, keeping highest confidence for each field."""
|
||||||
|
field_candidates: dict[str, list[ExtractedField]] = {}
|
||||||
|
|
||||||
|
for extracted in result.extracted_fields:
|
||||||
|
if not extracted.is_valid or not extracted.normalized_value:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if extracted.field_name not in field_candidates:
|
||||||
|
field_candidates[extracted.field_name] = []
|
||||||
|
field_candidates[extracted.field_name].append(extracted)
|
||||||
|
|
||||||
|
# Select best candidate for each field
|
||||||
|
for field_name, candidates in field_candidates.items():
|
||||||
|
best = max(candidates, key=lambda x: x.confidence)
|
||||||
|
result.fields[field_name] = best.normalized_value
|
||||||
|
result.confidence[field_name] = best.confidence
|
||||||
|
|
||||||
|
def _needs_fallback(self, result: InferenceResult) -> bool:
|
||||||
|
"""Check if fallback OCR is needed."""
|
||||||
|
# Check for key fields
|
||||||
|
key_fields = ['Amount', 'InvoiceNumber', 'OCR']
|
||||||
|
missing = sum(1 for f in key_fields if f not in result.fields)
|
||||||
|
return missing >= 2 # Fallback if 2+ key fields missing
|
||||||
|
|
||||||
|
def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None:
|
||||||
|
"""Run full-page OCR fallback."""
|
||||||
|
from ..pdf.renderer import render_pdf_to_images
|
||||||
|
from ..ocr import OCREngine
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
result.fallback_used = True
|
||||||
|
ocr_engine = OCREngine()
|
||||||
|
|
||||||
|
try:
|
||||||
|
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
||||||
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
|
image_array = np.array(image)
|
||||||
|
|
||||||
|
# Full page OCR
|
||||||
|
tokens = ocr_engine.extract_from_image(image_array, page_no)
|
||||||
|
full_text = ' '.join(t.text for t in tokens)
|
||||||
|
|
||||||
|
# Try to extract missing fields with regex patterns
|
||||||
|
self._extract_with_patterns(full_text, result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(f"Fallback OCR error: {e}")
|
||||||
|
|
||||||
|
def _extract_with_patterns(self, text: str, result: InferenceResult) -> None:
|
||||||
|
"""Extract fields using regex patterns (fallback)."""
|
||||||
|
patterns = {
|
||||||
|
'Amount': [
|
||||||
|
r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?',
|
||||||
|
r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$',
|
||||||
|
],
|
||||||
|
'Bankgiro': [
|
||||||
|
r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})',
|
||||||
|
r'(\d{4}[-\s]\d{4})\s*(?=\s|$)',
|
||||||
|
],
|
||||||
|
'OCR': [
|
||||||
|
r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})',
|
||||||
|
],
|
||||||
|
'InvoiceNumber': [
|
||||||
|
r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)',
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
for field_name, field_patterns in patterns.items():
|
||||||
|
if field_name in result.fields:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for pattern in field_patterns:
|
||||||
|
match = re.search(pattern, text, re.IGNORECASE)
|
||||||
|
if match:
|
||||||
|
value = match.group(1).strip()
|
||||||
|
|
||||||
|
# Normalize the value
|
||||||
|
if field_name == 'Amount':
|
||||||
|
value = value.replace(' ', '').replace(',', '.')
|
||||||
|
try:
|
||||||
|
value = f"{float(value):.2f}"
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
elif field_name == 'Bankgiro':
|
||||||
|
digits = re.sub(r'\D', '', value)
|
||||||
|
if len(digits) == 8:
|
||||||
|
value = f"{digits[:4]}-{digits[4:]}"
|
||||||
|
|
||||||
|
result.fields[field_name] = value
|
||||||
|
result.confidence[field_name] = 0.5 # Lower confidence for regex
|
||||||
|
break
|
||||||
|
|
||||||
|
def process_image(
|
||||||
|
self,
|
||||||
|
image_path: str | Path,
|
||||||
|
document_id: str | None = None
|
||||||
|
) -> InferenceResult:
|
||||||
|
"""
|
||||||
|
Process a single image (for pre-rendered pages).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: Path to image file
|
||||||
|
document_id: Optional document ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
InferenceResult with extracted fields
|
||||||
|
"""
|
||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
result = InferenceResult(
|
||||||
|
document_id=document_id or Path(image_path).stem
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
image = Image.open(image_path)
|
||||||
|
image_array = np.array(image)
|
||||||
|
|
||||||
|
# Run detection
|
||||||
|
detections = self.detector.detect(image_array, page_no=0)
|
||||||
|
result.raw_detections = detections
|
||||||
|
|
||||||
|
# Extract fields
|
||||||
|
for detection in detections:
|
||||||
|
extracted = self.extractor.extract_from_detection(detection, image_array)
|
||||||
|
result.extracted_fields.append(extracted)
|
||||||
|
|
||||||
|
# Merge fields
|
||||||
|
self._merge_fields(result)
|
||||||
|
result.success = len(result.fields) > 0
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(str(e))
|
||||||
|
result.success = False
|
||||||
|
|
||||||
|
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||||
|
return result
|
||||||
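End-to-end usage sketch (not part of this commit); the model and PDF paths are assumptions:

```python
pipeline = InferencePipeline(
    model_path='models/weights/best.pt',
    confidence_threshold=0.5,
    use_gpu=True,
    dpi=300,
)
result = pipeline.process_pdf('data/raw_pdfs/example.pdf')
if result.success:
    print(result.to_json())
else:
    print('extraction failed:', result.errors)
```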
204
src/inference/yolo_detector.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""
|
||||||
|
YOLO Detection Module
|
||||||
|
|
||||||
|
Runs YOLO model inference for field detection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Detection:
|
||||||
|
"""Represents a single YOLO detection."""
|
||||||
|
class_id: int
|
||||||
|
class_name: str
|
||||||
|
confidence: float
|
||||||
|
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) in pixels
|
||||||
|
page_no: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def x0(self) -> float:
|
||||||
|
return self.bbox[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def y0(self) -> float:
|
||||||
|
return self.bbox[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def x1(self) -> float:
|
||||||
|
return self.bbox[2]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def y1(self) -> float:
|
||||||
|
return self.bbox[3]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def center(self) -> tuple[float, float]:
|
||||||
|
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def width(self) -> float:
|
||||||
|
return self.x1 - self.x0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self) -> float:
|
||||||
|
return self.y1 - self.y0
|
||||||
|
|
||||||
|
def get_padded_bbox(
|
||||||
|
self,
|
||||||
|
padding: float = 0.1,
|
||||||
|
image_width: float | None = None,
|
||||||
|
image_height: float | None = None
|
||||||
|
) -> tuple[float, float, float, float]:
|
||||||
|
"""Get bbox with padding for OCR extraction."""
|
||||||
|
pad_x = self.width * padding
|
||||||
|
pad_y = self.height * padding
|
||||||
|
|
||||||
|
x0 = self.x0 - pad_x
|
||||||
|
y0 = self.y0 - pad_y
|
||||||
|
x1 = self.x1 + pad_x
|
||||||
|
y1 = self.y1 + pad_y
|
||||||
|
|
||||||
|
if image_width:
|
||||||
|
x0 = max(0, x0)
|
||||||
|
x1 = min(image_width, x1)
|
||||||
|
if image_height:
|
||||||
|
y0 = max(0, y0)
|
||||||
|
y1 = min(image_height, y1)
|
||||||
|
|
||||||
|
return (x0, y0, x1, y1)
|
||||||
|
|
||||||
|
# Class names (must match training configuration)
CLASS_NAMES = [
    'invoice_number',
    'invoice_date',
    'invoice_due_date',
    'ocr_number',
    'bankgiro',
    'plusgiro',
    'amount',
]

# Mapping from class name to field name
CLASS_TO_FIELD = {
    'invoice_number': 'InvoiceNumber',
    'invoice_date': 'InvoiceDate',
    'invoice_due_date': 'InvoiceDueDate',
    'ocr_number': 'OCR',
    'bankgiro': 'Bankgiro',
    'plusgiro': 'Plusgiro',
    'amount': 'Amount',
}
|
||||||
|
|
||||||
|
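Illustration (not part of this commit) of a Detection and its padded crop box; all numbers are made up, and the page size assumes A4 rendered at 300 DPI:

```python
det = Detection(
    class_id=6,
    class_name='amount',
    confidence=0.91,
    bbox=(1200.0, 2100.0, 1500.0, 2160.0),
)
print(CLASS_TO_FIELD[det.class_name])                                  # -> 'Amount'
print(det.get_padded_bbox(0.1, image_width=2480, image_height=3508))  # -> (1170.0, 2094.0, 1530.0, 2166.0)
```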
class YOLODetector:
|
||||||
|
"""YOLO model wrapper for field detection."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_path: str | Path,
|
||||||
|
confidence_threshold: float = 0.5,
|
||||||
|
iou_threshold: float = 0.45,
|
||||||
|
device: str = 'auto'
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize YOLO detector.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: Path to trained YOLO model (.pt file)
|
||||||
|
confidence_threshold: Minimum confidence for detections
|
||||||
|
iou_threshold: IOU threshold for NMS
|
||||||
|
device: Device to run on ('auto', 'cpu', 'cuda', 'mps')
|
||||||
|
"""
|
||||||
|
from ultralytics import YOLO
|
||||||
|
|
||||||
|
self.model = YOLO(model_path)
|
||||||
|
self.confidence_threshold = confidence_threshold
|
||||||
|
self.iou_threshold = iou_threshold
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
def detect(
|
||||||
|
self,
|
||||||
|
image: str | Path | np.ndarray,
|
||||||
|
page_no: int = 0
|
||||||
|
) -> list[Detection]:
|
||||||
|
"""
|
||||||
|
Run detection on an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image path or numpy array
|
||||||
|
page_no: Page number for reference
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Detection objects
|
||||||
|
"""
|
||||||
|
results = self.model.predict(
|
||||||
|
source=image,
|
||||||
|
conf=self.confidence_threshold,
|
||||||
|
iou=self.iou_threshold,
|
||||||
|
device=self.device,
|
||||||
|
verbose=False
|
||||||
|
)
|
||||||
|
|
||||||
|
detections = []
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
boxes = result.boxes
|
||||||
|
if boxes is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for i in range(len(boxes)):
|
||||||
|
class_id = int(boxes.cls[i])
|
||||||
|
confidence = float(boxes.conf[i])
|
||||||
|
bbox = boxes.xyxy[i].tolist() # [x0, y0, x1, y1]
|
||||||
|
|
||||||
|
class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"class_{class_id}"
|
||||||
|
|
||||||
|
detections.append(Detection(
|
||||||
|
class_id=class_id,
|
||||||
|
class_name=class_name,
|
||||||
|
confidence=confidence,
|
||||||
|
bbox=tuple(bbox),
|
||||||
|
page_no=page_no
|
||||||
|
))
|
||||||
|
|
||||||
|
return detections
|
||||||
|
|
||||||
|
def detect_pdf(
|
||||||
|
self,
|
||||||
|
pdf_path: str | Path,
|
||||||
|
dpi: int = 300
|
||||||
|
) -> dict[int, list[Detection]]:
|
||||||
|
"""
|
||||||
|
Run detection on all pages of a PDF.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_path: Path to PDF file
|
||||||
|
dpi: Resolution for rendering
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict mapping page number to list of detections
|
||||||
|
"""
|
||||||
|
from ..pdf.renderer import render_pdf_to_images
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
|
||||||
|
# Convert bytes to numpy array
|
||||||
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
|
image_array = np.array(image)
|
||||||
|
|
||||||
|
detections = self.detect(image_array, page_no=page_no)
|
||||||
|
results[page_no] = detections
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def get_field_name(self, class_name: str) -> str:
|
||||||
|
"""Convert class name to field name."""
|
||||||
|
return CLASS_TO_FIELD.get(class_name, class_name)
|
||||||
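Detector-only usage sketch (not part of this commit); the model and PDF paths are assumptions:

```python
detector = YOLODetector('models/weights/best.pt', confidence_threshold=0.5)
per_page = detector.detect_pdf('data/raw_pdfs/example.pdf', dpi=300)
for page_no, detections in per_page.items():
    for det in detections:
        print(page_no, detector.get_field_name(det.class_name), f"{det.confidence:.2f}", det.bbox)
```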
3
src/matcher/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .field_matcher import FieldMatcher, Match, find_field_matches

__all__ = ['FieldMatcher', 'Match', 'find_field_matches']
618
src/matcher/field_matcher.py
Normal file
@@ -0,0 +1,618 @@
|
|||||||
|
"""
|
||||||
|
Field Matching Module
|
||||||
|
|
||||||
|
Matches normalized field values to tokens extracted from documents.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Protocol
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class TokenLike(Protocol):
|
||||||
|
"""Protocol for token objects."""
|
||||||
|
text: str
|
||||||
|
bbox: tuple[float, float, float, float]
|
||||||
|
page_no: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Match:
|
||||||
|
"""Represents a matched field in the document."""
|
||||||
|
field: str
|
||||||
|
value: str
|
||||||
|
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
|
||||||
|
page_no: int
|
||||||
|
score: float # 0-1 confidence score
|
||||||
|
matched_text: str # Actual text that matched
|
||||||
|
context_keywords: list[str] # Nearby keywords that boosted confidence
|
||||||
|
|
||||||
|
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
|
||||||
|
"""Convert to YOLO annotation format."""
|
||||||
|
x0, y0, x1, y1 = self.bbox
|
||||||
|
|
||||||
|
x_center = (x0 + x1) / 2 / image_width
|
||||||
|
y_center = (y0 + y1) / 2 / image_height
|
||||||
|
width = (x1 - x0) / image_width
|
||||||
|
height = (y1 - y0) / image_height
|
||||||
|
|
||||||
|
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
||||||
|
|
||||||
|
|
||||||
|
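Worked example (not part of this commit) of the YOLO conversion above for an illustrative 1000x1000 px page; the bbox and score are made up:

```python
m = Match(
    field='Amount',
    value='599.00',
    bbox=(100.0, 200.0, 300.0, 250.0),
    page_no=0,
    score=0.95,
    matched_text='599,00',
    context_keywords=['att betala'],
)
print(m.to_yolo_format(1000, 1000, class_id=6))
# -> "6 0.200000 0.225000 0.200000 0.050000"
```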
# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
    'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
    'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
    'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
                       'förfallodag', 'oss tillhanda senast', 'senast'],
    'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
    'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
    'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
    'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
}
|
||||||
|
|
||||||
|
|
||||||
|
class FieldMatcher:
|
||||||
|
"""Matches field values to document tokens."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
context_radius: float = 100.0, # pixels
|
||||||
|
min_score_threshold: float = 0.5
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the matcher.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context_radius: Distance to search for context keywords
|
||||||
|
min_score_threshold: Minimum score to consider a match valid
|
||||||
|
"""
|
||||||
|
self.context_radius = context_radius
|
||||||
|
self.min_score_threshold = min_score_threshold
|
||||||
|
|
||||||
|
def find_matches(
|
||||||
|
self,
|
||||||
|
tokens: list[TokenLike],
|
||||||
|
field_name: str,
|
||||||
|
normalized_values: list[str],
|
||||||
|
page_no: int = 0
|
||||||
|
) -> list[Match]:
|
||||||
|
"""
|
||||||
|
Find all matches for a field in the token list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: List of tokens from the document
|
||||||
|
field_name: Name of the field to match
|
||||||
|
normalized_values: List of normalized value variants to search for
|
||||||
|
page_no: Page number to filter tokens
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of Match objects sorted by score (descending)
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
page_tokens = [t for t in tokens if t.page_no == page_no]
|
||||||
|
|
||||||
|
for value in normalized_values:
|
||||||
|
# Strategy 1: Exact token match
|
||||||
|
exact_matches = self._find_exact_matches(page_tokens, value, field_name)
|
||||||
|
matches.extend(exact_matches)
|
||||||
|
|
||||||
|
# Strategy 2: Multi-token concatenation
|
||||||
|
concat_matches = self._find_concatenated_matches(page_tokens, value, field_name)
|
||||||
|
matches.extend(concat_matches)
|
||||||
|
|
||||||
|
# Strategy 3: Fuzzy match (for amounts and dates only)
|
||||||
|
if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
|
||||||
|
fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name)
|
||||||
|
matches.extend(fuzzy_matches)
|
||||||
|
|
||||||
|
# Strategy 4: Substring match (for dates embedded in longer text)
|
||||||
|
if field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||||
|
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
|
||||||
|
matches.extend(substring_matches)
|
||||||
|
|
||||||
|
# Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
|
||||||
|
# Only if no exact matches found for date fields
|
||||||
|
if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
|
||||||
|
flexible_matches = self._find_flexible_date_matches(
|
||||||
|
page_tokens, normalized_values, field_name
|
||||||
|
)
|
||||||
|
matches.extend(flexible_matches)
|
||||||
|
|
||||||
|
# Deduplicate and sort by score
|
||||||
|
matches = self._deduplicate_matches(matches)
|
||||||
|
matches.sort(key=lambda m: m.score, reverse=True)
|
||||||
|
|
||||||
|
return [m for m in matches if m.score >= self.min_score_threshold]
|
||||||
|
|
||||||
|
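Usage sketch (not part of this commit): any object exposing `text`, `bbox` and `page_no` satisfies TokenLike, so a tiny dataclass is enough to drive `find_matches`. The token values are made up, and the printed score depends on the context-keyword boost:

```python
from dataclasses import dataclass

# _Tok is a hypothetical stand-in for the PDF/OCR token classes.
@dataclass
class _Tok:
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int = 0

tokens = [
    _Tok('Fakturanr:', (50.0, 100.0, 120.0, 112.0)),
    _Tok('12345678', (130.0, 100.0, 200.0, 112.0)),
]
matcher = FieldMatcher()
matches = matcher.find_matches(tokens, 'InvoiceNumber', ['12345678'])
print(matches[0].matched_text, f"{matches[0].score:.2f}")
```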
def _find_exact_matches(
|
||||||
|
self,
|
||||||
|
tokens: list[TokenLike],
|
||||||
|
value: str,
|
||||||
|
field_name: str
|
||||||
|
) -> list[Match]:
|
||||||
|
"""Find tokens that exactly match the value."""
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
token_text = token.text.strip()
|
||||||
|
|
||||||
|
# Exact match
|
||||||
|
if token_text == value:
|
||||||
|
score = 1.0
|
||||||
|
# Case-insensitive match
|
||||||
|
elif token_text.lower() == value.lower():
|
||||||
|
score = 0.95
|
||||||
|
# Digits-only match for numeric fields
|
||||||
|
elif field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'):
|
||||||
|
token_digits = re.sub(r'\D', '', token_text)
|
||||||
|
value_digits = re.sub(r'\D', '', value)
|
||||||
|
if token_digits and token_digits == value_digits:
|
||||||
|
score = 0.9
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Boost score if context keywords are nearby
|
||||||
|
context_keywords, context_boost = self._find_context_keywords(
|
||||||
|
tokens, token, field_name
|
||||||
|
)
|
||||||
|
score = min(1.0, score + context_boost)
|
||||||
|
|
||||||
|
matches.append(Match(
|
||||||
|
field=field_name,
|
||||||
|
value=value,
|
||||||
|
bbox=token.bbox,
|
||||||
|
page_no=token.page_no,
|
||||||
|
score=score,
|
||||||
|
matched_text=token_text,
|
||||||
|
context_keywords=context_keywords
|
||||||
|
))
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def _find_concatenated_matches(
|
||||||
|
self,
|
||||||
|
tokens: list[TokenLike],
|
||||||
|
value: str,
|
||||||
|
field_name: str
|
||||||
|
) -> list[Match]:
|
||||||
|
"""Find value by concatenating adjacent tokens."""
|
||||||
|
matches = []
|
||||||
|
value_clean = re.sub(r'\s+', '', value)
|
||||||
|
|
||||||
|
# Sort tokens by position (top-to-bottom, left-to-right)
|
||||||
|
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
|
||||||
|
|
||||||
|
for i, start_token in enumerate(sorted_tokens):
|
||||||
|
# Try to build the value by concatenating nearby tokens
|
||||||
|
concat_text = start_token.text.strip()
|
||||||
|
concat_bbox = list(start_token.bbox)
|
||||||
|
used_tokens = [start_token]
|
||||||
|
|
||||||
|
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
|
||||||
|
next_token = sorted_tokens[j]
|
||||||
|
|
||||||
|
# Check if tokens are on the same line (y overlap)
|
||||||
|
if not self._tokens_on_same_line(start_token, next_token):
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check horizontal proximity
|
||||||
|
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
|
||||||
|
break
|
||||||
|
|
||||||
|
concat_text += next_token.text.strip()
|
||||||
|
used_tokens.append(next_token)
|
||||||
|
|
||||||
|
# Update bounding box
|
||||||
|
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
|
||||||
|
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
|
||||||
|
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
|
||||||
|
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
|
||||||
|
|
||||||
|
# Check for match
|
||||||
|
concat_clean = re.sub(r'\s+', '', concat_text)
|
||||||
|
if concat_clean == value_clean:
|
||||||
|
context_keywords, context_boost = self._find_context_keywords(
|
||||||
|
tokens, start_token, field_name
|
||||||
|
)
|
||||||
|
|
||||||
|
matches.append(Match(
|
||||||
|
field=field_name,
|
||||||
|
value=value,
|
||||||
|
bbox=tuple(concat_bbox),
|
||||||
|
page_no=start_token.page_no,
|
||||||
|
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
|
||||||
|
matched_text=concat_text,
|
||||||
|
context_keywords=context_keywords
|
||||||
|
))
|
||||||
|
break
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def _find_substring_matches(
|
||||||
|
self,
|
||||||
|
tokens: list[TokenLike],
|
||||||
|
value: str,
|
||||||
|
field_name: str
|
||||||
|
) -> list[Match]:
|
||||||
|
"""
|
||||||
|
Find value as a substring within longer tokens.
|
||||||
|
|
||||||
|
Handles cases like 'Fakturadatum: 2026-01-09' where the date
|
||||||
|
is embedded in a longer text string.
|
||||||
|
|
||||||
|
Uses lower score (0.75) than exact match to prefer exact matches.
|
||||||
|
Only matches if the value appears as a distinct segment (not part of a number).
|
||||||
|
"""
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
# Only use for date fields - other fields risk false positives
|
||||||
|
if field_name not in ('InvoiceDate', 'InvoiceDueDate'):
|
||||||
|
return matches
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
token_text = token.text.strip()
|
||||||
|
|
||||||
|
# Skip if token is the same length as value (would be exact match)
|
||||||
|
if len(token_text) <= len(value):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if value appears as substring
|
||||||
|
if value in token_text:
|
||||||
|
# Verify it's a proper boundary match (not part of a larger number)
|
||||||
|
idx = token_text.find(value)
|
||||||
|
|
||||||
|
# Check character before (if exists)
|
||||||
|
if idx > 0:
|
||||||
|
char_before = token_text[idx - 1]
|
||||||
|
# Must be non-digit (allow : space - etc)
|
||||||
|
if char_before.isdigit():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check character after (if exists)
|
||||||
|
end_idx = idx + len(value)
|
||||||
|
if end_idx < len(token_text):
|
||||||
|
char_after = token_text[end_idx]
|
||||||
|
# Must be non-digit
|
||||||
|
if char_after.isdigit():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Found valid substring match
|
||||||
|
context_keywords, context_boost = self._find_context_keywords(
|
||||||
|
tokens, token, field_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if context keyword is in the same token (like "Fakturadatum:")
|
||||||
|
token_lower = token_text.lower()
|
||||||
|
inline_context = []
|
||||||
|
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
|
||||||
|
if keyword in token_lower:
|
||||||
|
inline_context.append(keyword)
|
||||||
|
|
||||||
|
# Boost score if keyword is inline
|
||||||
|
inline_boost = 0.1 if inline_context else 0
|
||||||
|
|
||||||
|
matches.append(Match(
|
||||||
|
field=field_name,
|
||||||
|
value=value,
|
||||||
|
bbox=token.bbox, # Use full token bbox
|
||||||
|
page_no=token.page_no,
|
||||||
|
score=min(1.0, 0.75 + context_boost + inline_boost), # Lower than exact match
|
||||||
|
matched_text=token_text,
|
||||||
|
context_keywords=context_keywords + inline_context
|
||||||
|
))
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def _find_fuzzy_matches(
|
||||||
|
self,
|
||||||
|
tokens: list[TokenLike],
|
||||||
|
value: str,
|
||||||
|
field_name: str
|
||||||
|
) -> list[Match]:
|
||||||
|
"""Find approximate matches for amounts and dates."""
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
token_text = token.text.strip()
|
||||||
|
|
||||||
|
if field_name == 'Amount':
|
||||||
|
# Try to parse both as numbers
|
||||||
|
try:
|
||||||
|
token_num = self._parse_amount(token_text)
|
||||||
|
value_num = self._parse_amount(value)
|
||||||
|
|
||||||
|
if token_num is not None and value_num is not None:
|
||||||
|
if abs(token_num - value_num) < 0.01: # Within 1 cent
|
||||||
|
context_keywords, context_boost = self._find_context_keywords(
|
||||||
|
tokens, token, field_name
|
||||||
|
)
|
||||||
|
|
||||||
|
matches.append(Match(
|
||||||
|
field=field_name,
|
||||||
|
value=value,
|
||||||
|
bbox=token.bbox,
|
||||||
|
page_no=token.page_no,
|
||||||
|
score=min(1.0, 0.8 + context_boost),
|
||||||
|
matched_text=token_text,
|
||||||
|
context_keywords=context_keywords
|
||||||
|
))
|
||||||
|
except (ValueError, TypeError):
pass
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
def _find_flexible_date_matches(
|
||||||
|
self,
|
||||||
|
tokens: list[TokenLike],
|
||||||
|
normalized_values: list[str],
|
||||||
|
field_name: str
|
||||||
|
) -> list[Match]:
|
||||||
|
"""
|
||||||
|
Flexible date matching when exact match fails.
|
||||||
|
|
||||||
|
Strategies:
|
||||||
|
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
|
||||||
|
2. Nearby date match: Match dates within 7 days of CSV value
|
||||||
|
3. Heuristic selection: Use context keywords to select the best date
|
||||||
|
|
||||||
|
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
|
||||||
|
but we can still find a reasonable date to label.
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
# Parse the target date from normalized values
|
||||||
|
target_date = None
|
||||||
|
for value in normalized_values:
|
||||||
|
# Try to parse YYYY-MM-DD format
|
||||||
|
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
|
||||||
|
if date_match:
|
||||||
|
try:
|
||||||
|
target_date = datetime(
|
||||||
|
int(date_match.group(1)),
|
||||||
|
int(date_match.group(2)),
|
||||||
|
int(date_match.group(3))
|
||||||
|
)
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not target_date:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Find all date-like tokens in the document
|
||||||
|
date_candidates = []
|
||||||
|
date_pattern = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
token_text = token.text.strip()
|
||||||
|
|
||||||
|
# Search for date pattern in token
|
||||||
|
for match in date_pattern.finditer(token_text):
|
||||||
|
try:
|
||||||
|
found_date = datetime(
|
||||||
|
int(match.group(1)),
|
||||||
|
int(match.group(2)),
|
||||||
|
int(match.group(3))
|
||||||
|
)
|
||||||
|
date_str = match.group(0)
|
||||||
|
|
||||||
|
# Calculate date difference
|
||||||
|
days_diff = abs((found_date - target_date).days)
|
||||||
|
|
||||||
|
# Check for context keywords
|
||||||
|
context_keywords, context_boost = self._find_context_keywords(
|
||||||
|
tokens, token, field_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if keyword is in the same token
|
||||||
|
token_lower = token_text.lower()
|
||||||
|
inline_keywords = []
|
||||||
|
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
|
||||||
|
if keyword in token_lower:
|
||||||
|
inline_keywords.append(keyword)
|
||||||
|
|
||||||
|
date_candidates.append({
|
||||||
|
'token': token,
|
||||||
|
'date': found_date,
|
||||||
|
'date_str': date_str,
|
||||||
|
'matched_text': token_text,
|
||||||
|
'days_diff': days_diff,
|
||||||
|
'context_keywords': context_keywords + inline_keywords,
|
||||||
|
'context_boost': context_boost + (0.1 if inline_keywords else 0),
|
||||||
|
'same_year_month': (found_date.year == target_date.year and
|
||||||
|
found_date.month == target_date.month),
|
||||||
|
})
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not date_candidates:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
# Score and rank candidates
|
||||||
|
for candidate in date_candidates:
|
||||||
|
score = 0.0
|
||||||
|
|
||||||
|
# Strategy 1: Same year-month gets higher score
|
||||||
|
if candidate['same_year_month']:
|
||||||
|
score = 0.7
|
||||||
|
# Bonus if day is close
|
||||||
|
if candidate['days_diff'] <= 7:
|
||||||
|
score = 0.75
|
||||||
|
if candidate['days_diff'] <= 3:
|
||||||
|
score = 0.8
|
||||||
|
# Strategy 2: Nearby dates (within 14 days)
|
||||||
|
elif candidate['days_diff'] <= 14:
|
||||||
|
score = 0.6
|
||||||
|
elif candidate['days_diff'] <= 30:
|
||||||
|
score = 0.55
|
||||||
|
else:
|
||||||
|
# Too far apart, skip unless has strong context
|
||||||
|
if not candidate['context_keywords']:
|
||||||
|
continue
|
||||||
|
score = 0.5
|
||||||
|
|
||||||
|
# Strategy 3: Boost with context keywords
|
||||||
|
score = min(1.0, score + candidate['context_boost'])
|
||||||
|
|
||||||
|
# For InvoiceDate, prefer dates that appear near invoice-related keywords
|
||||||
|
# For InvoiceDueDate, prefer dates near due-date keywords
|
||||||
|
if candidate['context_keywords']:
|
||||||
|
score = min(1.0, score + 0.05)
|
||||||
|
|
||||||
|
if score >= self.min_score_threshold:
|
||||||
|
matches.append(Match(
|
||||||
|
field=field_name,
|
||||||
|
value=candidate['date_str'],
|
||||||
|
bbox=candidate['token'].bbox,
|
||||||
|
page_no=candidate['token'].page_no,
|
||||||
|
score=score,
|
||||||
|
matched_text=candidate['matched_text'],
|
||||||
|
context_keywords=candidate['context_keywords']
|
||||||
|
))
|
||||||
|
|
||||||
|
# Sort by score and return best matches
|
||||||
|
matches.sort(key=lambda m: m.score, reverse=True)
|
||||||
|
|
||||||
|
# Only return the best match to avoid multiple labels for same field
|
||||||
|
return matches[:1] if matches else []
|
||||||
|
|
||||||
|
def _find_context_keywords(
|
||||||
|
self,
|
||||||
|
tokens: list[TokenLike],
|
||||||
|
target_token: TokenLike,
|
||||||
|
field_name: str
|
||||||
|
) -> tuple[list[str], float]:
|
||||||
|
"""Find context keywords near the target token."""
|
||||||
|
keywords = CONTEXT_KEYWORDS.get(field_name, [])
|
||||||
|
found_keywords = []
|
||||||
|
|
||||||
|
target_center = (
|
||||||
|
(target_token.bbox[0] + target_token.bbox[2]) / 2,
|
||||||
|
(target_token.bbox[1] + target_token.bbox[3]) / 2
|
||||||
|
)
|
||||||
|
|
||||||
|
for token in tokens:
|
||||||
|
if token is target_token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
token_center = (
|
||||||
|
(token.bbox[0] + token.bbox[2]) / 2,
|
||||||
|
(token.bbox[1] + token.bbox[3]) / 2
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate distance
|
||||||
|
distance = (
|
||||||
|
(target_center[0] - token_center[0]) ** 2 +
|
||||||
|
(target_center[1] - token_center[1]) ** 2
|
||||||
|
) ** 0.5
|
||||||
|
|
||||||
|
if distance <= self.context_radius:
|
||||||
|
token_lower = token.text.lower()
|
||||||
|
for keyword in keywords:
|
||||||
|
if keyword in token_lower:
|
||||||
|
found_keywords.append(keyword)
|
||||||
|
|
||||||
|
# Calculate boost based on keywords found
|
||||||
|
boost = min(0.15, len(found_keywords) * 0.05)
|
||||||
|
return found_keywords, boost
|
||||||
|
|
||||||
|
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
|
||||||
|
"""Check if two tokens are on the same line."""
|
||||||
|
# Check vertical overlap
|
||||||
|
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
|
||||||
|
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
|
||||||
|
return y_overlap > min_height * 0.5
|
||||||
|
|
||||||
|
def _parse_amount(self, text: str) -> float | None:
"""Try to parse text as a monetary amount."""
# Remove currency markers (SEK, kr, ':-') and spaces
text = re.sub(r'(?:SEK|kr|:-)', '', text, flags=re.IGNORECASE)
text = text.replace(' ', '').replace('\xa0', '')

# Try comma as decimal separator
if ',' in text and '.' not in text:
text = text.replace(',', '.')

try:
return float(text)
except ValueError:
return None
|
||||||
|
|
||||||
|
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
|
||||||
|
"""Remove duplicate matches based on bbox overlap."""
|
||||||
|
if not matches:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Sort by score descending
|
||||||
|
matches.sort(key=lambda m: m.score, reverse=True)
|
||||||
|
unique = []
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
is_duplicate = False
|
||||||
|
for existing in unique:
|
||||||
|
if self._bbox_overlap(match.bbox, existing.bbox) > 0.7:
|
||||||
|
is_duplicate = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not is_duplicate:
|
||||||
|
unique.append(match)
|
||||||
|
|
||||||
|
return unique
|
||||||
|
|
||||||
|
def _bbox_overlap(
|
||||||
|
self,
|
||||||
|
bbox1: tuple[float, float, float, float],
|
||||||
|
bbox2: tuple[float, float, float, float]
|
||||||
|
) -> float:
|
||||||
|
"""Calculate IoU (Intersection over Union) of two bounding boxes."""
|
||||||
|
x1 = max(bbox1[0], bbox2[0])
|
||||||
|
y1 = max(bbox1[1], bbox2[1])
|
||||||
|
x2 = min(bbox1[2], bbox2[2])
|
||||||
|
y2 = min(bbox1[3], bbox2[3])
|
||||||
|
|
||||||
|
if x2 <= x1 or y2 <= y1:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
intersection = (x2 - x1) * (y2 - y1)
|
||||||
|
area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
|
||||||
|
area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
|
||||||
|
union = area1 + area2 - intersection
|
||||||
|
|
||||||
|
return intersection / union if union > 0 else 0.0
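# Worked example (illustrative sketch, not part of the module): two 10x10 boxes,
# (0, 0, 10, 10) and (0, 5, 10, 15), overlap over a 10x5 strip, so
# IoU = 50 / (100 + 100 - 50) ≈ 0.33 — below the 0.7 threshold used by
# _deduplicate_matches, and both matches would therefore be kept.
def _iou_worked_example() -> float:
    intersection = (10 - 0) * (10 - 5)           # 50
    union = 10 * 10 + 10 * 10 - intersection     # 150
    return intersection / union                  # 0.333...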
|
||||||
|
|
||||||
|
|
||||||
|
def find_field_matches(
|
||||||
|
tokens: list[TokenLike],
|
||||||
|
field_values: dict[str, str],
|
||||||
|
page_no: int = 0
|
||||||
|
) -> dict[str, list[Match]]:
|
||||||
|
"""
|
||||||
|
Convenience function to find matches for multiple fields.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokens: List of tokens from the document
|
||||||
|
field_values: Dict of field_name -> value to search for
|
||||||
|
page_no: Page number
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict of field_name -> list of matches
|
||||||
|
"""
|
||||||
|
from ..normalize import normalize_field
|
||||||
|
|
||||||
|
matcher = FieldMatcher()
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for field_name, value in field_values.items():
|
||||||
|
if value is None or str(value).strip() == '':
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized_values = normalize_field(field_name, str(value))
|
||||||
|
matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
|
||||||
|
results[field_name] = matches
|
||||||
|
|
||||||
|
return results
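# Minimal usage sketch (illustrative; the PDF path and field values are made up,
# and extract_text_tokens is the text-layer extractor added in src/pdf/extractor.py):
def _find_field_matches_example(pdf_path: str = "invoice.pdf") -> dict:
    from ..pdf.extractor import extract_text_tokens  # same relative-import style as above
    tokens = list(extract_text_tokens(pdf_path, page_no=0))
    results = find_field_matches(tokens, {"InvoiceNumber": "100017500321", "Amount": "1 234,56"})
    # Keep only the best-scoring match per field
    return {field: matches[0] for field, matches in results.items() if matches}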
|
||||||
3
src/normalize/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .normalizer import normalize_field, FieldNormalizer

__all__ = ['normalize_field', 'FieldNormalizer']
290
src/normalize/normalizer.py
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
"""
|
||||||
|
Field Normalization Module
|
||||||
|
|
||||||
|
Normalizes field values to generate multiple candidate forms for matching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class NormalizedValue:
|
||||||
|
"""Represents a normalized value with its variants."""
|
||||||
|
original: str
|
||||||
|
variants: list[str]
|
||||||
|
field_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class FieldNormalizer:
|
||||||
|
"""Handles normalization of different invoice field types."""
|
||||||
|
|
||||||
|
# Common Swedish month names for date parsing
|
||||||
|
SWEDISH_MONTHS = {
|
||||||
|
'januari': '01', 'jan': '01',
|
||||||
|
'februari': '02', 'feb': '02',
|
||||||
|
'mars': '03', 'mar': '03',
|
||||||
|
'april': '04', 'apr': '04',
|
||||||
|
'maj': '05',
|
||||||
|
'juni': '06', 'jun': '06',
|
||||||
|
'juli': '07', 'jul': '07',
|
||||||
|
'augusti': '08', 'aug': '08',
|
||||||
|
'september': '09', 'sep': '09', 'sept': '09',
|
||||||
|
'oktober': '10', 'okt': '10',
|
||||||
|
'november': '11', 'nov': '11',
|
||||||
|
'december': '12', 'dec': '12'
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def clean_text(text: str) -> str:
|
||||||
|
"""Remove invisible characters and normalize whitespace."""
|
||||||
|
# Remove zero-width characters
|
||||||
|
text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text)
|
||||||
|
# Normalize whitespace
|
||||||
|
text = ' '.join(text.split())
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_invoice_number(value: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Normalize invoice number.
|
||||||
|
Keeps only digits for matching.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
'100017500321' -> ['100017500321']
|
||||||
|
'INV-100017500321' -> ['100017500321', 'INV-100017500321']
|
||||||
|
"""
|
||||||
|
value = FieldNormalizer.clean_text(value)
|
||||||
|
digits_only = re.sub(r'\D', '', value)
|
||||||
|
|
||||||
|
variants = [value]
|
||||||
|
if digits_only and digits_only != value:
|
||||||
|
variants.append(digits_only)
|
||||||
|
|
||||||
|
return list(set(v for v in variants if v))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_ocr_number(value: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Normalize OCR number (Swedish payment reference).
|
||||||
|
Similar to invoice number - digits only.
|
||||||
|
"""
|
||||||
|
return FieldNormalizer.normalize_invoice_number(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_bankgiro(value: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Normalize Bankgiro number.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
'5393-9484' -> ['5393-9484', '53939484']
|
||||||
|
'53939484' -> ['53939484', '5393-9484']
|
||||||
|
"""
|
||||||
|
value = FieldNormalizer.clean_text(value)
|
||||||
|
digits_only = re.sub(r'\D', '', value)
|
||||||
|
|
||||||
|
variants = [value]
|
||||||
|
|
||||||
|
if digits_only:
|
||||||
|
# Add without dash
|
||||||
|
variants.append(digits_only)
|
||||||
|
|
||||||
|
# Add with dash (format: XXXX-XXXX for 8 digits)
|
||||||
|
if len(digits_only) == 8:
|
||||||
|
with_dash = f"{digits_only[:4]}-{digits_only[4:]}"
|
||||||
|
variants.append(with_dash)
|
||||||
|
elif len(digits_only) == 7:
|
||||||
|
# Some bankgiro numbers are 7 digits: XXX-XXXX
|
||||||
|
with_dash = f"{digits_only[:3]}-{digits_only[3:]}"
|
||||||
|
variants.append(with_dash)
|
||||||
|
|
||||||
|
return list(set(v for v in variants if v))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_plusgiro(value: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Normalize Plusgiro number.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
'1234567-8' -> ['1234567-8', '12345678']
|
||||||
|
'12345678' -> ['12345678', '1234567-8']
|
||||||
|
"""
|
||||||
|
value = FieldNormalizer.clean_text(value)
|
||||||
|
digits_only = re.sub(r'\D', '', value)
|
||||||
|
|
||||||
|
variants = [value]
|
||||||
|
|
||||||
|
if digits_only:
|
||||||
|
variants.append(digits_only)
|
||||||
|
|
||||||
|
# Plusgiro format: XXXXXXX-X (7 digits + check digit)
|
||||||
|
if len(digits_only) == 8:
|
||||||
|
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
|
||||||
|
variants.append(with_dash)
|
||||||
|
# Also handle 6+1 format
|
||||||
|
elif len(digits_only) == 7:
|
||||||
|
with_dash = f"{digits_only[:-1]}-{digits_only[-1]}"
|
||||||
|
variants.append(with_dash)
|
||||||
|
|
||||||
|
return list(set(v for v in variants if v))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_amount(value: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Normalize monetary amount.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
'114' -> ['114', '114,00', '114.00']
|
||||||
|
'114,00' -> ['114,00', '114.00', '114']
|
||||||
|
'1 234,56' -> ['1234,56', '1234.56', '1 234,56']
|
||||||
|
"""
|
||||||
|
value = FieldNormalizer.clean_text(value)
|
||||||
|
|
||||||
|
# Remove currency symbols and common suffixes
|
||||||
|
value = re.sub(r'(?:SEK|kr|:-)+$', '', value, flags=re.IGNORECASE).strip()
|
||||||
|
|
||||||
|
# Remove spaces (thousand separators)
|
||||||
|
no_space = value.replace(' ', '').replace('\xa0', '')
|
||||||
|
|
||||||
|
variants = [value]
|
||||||
|
|
||||||
|
# Normalize decimal separator
|
||||||
|
if ',' in no_space:
|
||||||
|
dot_version = no_space.replace(',', '.')
|
||||||
|
variants.append(no_space)
|
||||||
|
variants.append(dot_version)
|
||||||
|
elif '.' in no_space:
|
||||||
|
comma_version = no_space.replace('.', ',')
|
||||||
|
variants.append(no_space)
|
||||||
|
variants.append(comma_version)
|
||||||
|
else:
|
||||||
|
# Integer amount - add decimal versions
|
||||||
|
variants.append(no_space)
|
||||||
|
variants.append(f"{no_space},00")
|
||||||
|
variants.append(f"{no_space}.00")
|
||||||
|
|
||||||
|
# Try to parse and get clean numeric value
|
||||||
|
try:
|
||||||
|
# Parse as float
|
||||||
|
clean = no_space.replace(',', '.')
|
||||||
|
num = float(clean)
|
||||||
|
|
||||||
|
# Integer if no decimals
|
||||||
|
if num == int(num):
|
||||||
|
variants.append(str(int(num)))
|
||||||
|
variants.append(f"{int(num)},00")
|
||||||
|
variants.append(f"{int(num)}.00")
|
||||||
|
else:
|
||||||
|
variants.append(f"{num:.2f}")
|
||||||
|
variants.append(f"{num:.2f}".replace('.', ','))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return list(set(v for v in variants if v))
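# Quick check of the variant generation above (standalone sketch):
def _amount_variants_example() -> None:
    variants = set(FieldNormalizer.normalize_amount('1 234,56'))
    assert variants == {'1 234,56', '1234,56', '1234.56'}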
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_date(value: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Normalize date to YYYY-MM-DD and generate variants.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
'2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
|
||||||
|
'13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
|
||||||
|
'13 december 2025' -> ['2025-12-13', ...]
|
||||||
|
"""
|
||||||
|
value = FieldNormalizer.clean_text(value)
|
||||||
|
variants = [value]
|
||||||
|
|
||||||
|
parsed_date = None
|
||||||
|
|
||||||
|
# Try different date formats
|
||||||
|
date_patterns = [
|
||||||
|
# ISO format with optional time (e.g., 2026-01-09 00:00:00)
|
||||||
|
(r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
|
||||||
|
# European format with /
|
||||||
|
(r'^(\d{1,2})/(\d{1,2})/(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
|
||||||
|
# European format with .
|
||||||
|
(r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
|
||||||
|
# European format with -
|
||||||
|
(r'^(\d{1,2})-(\d{1,2})-(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
|
||||||
|
# Swedish format: YYMMDD
|
||||||
|
(r'^(\d{2})(\d{2})(\d{2})$', lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))),
|
||||||
|
# Swedish format: YYYYMMDD
|
||||||
|
(r'^(\d{4})(\d{2})(\d{2})$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern, extractor in date_patterns:
|
||||||
|
match = re.match(pattern, value)
|
||||||
|
if match:
|
||||||
|
try:
|
||||||
|
year, month, day = extractor(match)
|
||||||
|
parsed_date = datetime(year, month, day)
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try Swedish month names
|
||||||
|
if not parsed_date:
|
||||||
|
for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items():
|
||||||
|
if month_name in value.lower():
|
||||||
|
# Extract day and year
|
||||||
|
numbers = re.findall(r'\d+', value)
|
||||||
|
if len(numbers) >= 2:
|
||||||
|
day = int(numbers[0])
|
||||||
|
year = int(numbers[-1])
|
||||||
|
if year < 100:
|
||||||
|
year = 2000 + year if year < 50 else 1900 + year
|
||||||
|
try:
|
||||||
|
parsed_date = datetime(year, int(month_num), day)
|
||||||
|
break
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if parsed_date:
|
||||||
|
# Generate different formats
|
||||||
|
iso = parsed_date.strftime('%Y-%m-%d')
|
||||||
|
eu_slash = parsed_date.strftime('%d/%m/%Y')
|
||||||
|
eu_dot = parsed_date.strftime('%d.%m.%Y')
|
||||||
|
compact = parsed_date.strftime('%Y%m%d')
|
||||||
|
|
||||||
|
variants.extend([iso, eu_slash, eu_dot, compact])
|
||||||
|
|
||||||
|
return list(set(v for v in variants if v))
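# Example (sketch): Swedish month names resolve via SWEDISH_MONTHS, so a written-out
# date yields the ISO form plus the other supported layouts:
def _date_variants_example() -> None:
    variants = set(FieldNormalizer.normalize_date('13 december 2025'))
    assert variants == {'13 december 2025', '2025-12-13', '13/12/2025', '13.12.2025', '20251213'}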
|
||||||
|
|
||||||
|
|
||||||
|
# Field type to normalizer mapping
|
||||||
|
NORMALIZERS: dict[str, Callable[[str], list[str]]] = {
|
||||||
|
'InvoiceNumber': FieldNormalizer.normalize_invoice_number,
|
||||||
|
'OCR': FieldNormalizer.normalize_ocr_number,
|
||||||
|
'Bankgiro': FieldNormalizer.normalize_bankgiro,
|
||||||
|
'Plusgiro': FieldNormalizer.normalize_plusgiro,
|
||||||
|
'Amount': FieldNormalizer.normalize_amount,
|
||||||
|
'InvoiceDate': FieldNormalizer.normalize_date,
|
||||||
|
'InvoiceDueDate': FieldNormalizer.normalize_date,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_field(field_name: str, value: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Normalize a field value based on its type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_name: Name of the field (e.g., 'InvoiceNumber', 'Amount')
|
||||||
|
value: Raw value to normalize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of normalized variants
|
||||||
|
"""
|
||||||
|
if value is None or (isinstance(value, str) and not value.strip()):
|
||||||
|
return []
|
||||||
|
|
||||||
|
value = str(value)
|
||||||
|
normalizer = NORMALIZERS.get(field_name)
|
||||||
|
|
||||||
|
if normalizer:
|
||||||
|
return normalizer(value)
|
||||||
|
|
||||||
|
# Default: just clean the text
|
||||||
|
return [FieldNormalizer.clean_text(value)]
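# Dispatch sketch: known field names use their dedicated normalizer, anything else
# falls back to plain text cleaning.
def _normalize_field_example() -> None:
    assert '5393-9484' in normalize_field('Bankgiro', '53939484')
    assert normalize_field('SomeOtherField', '  hello   world ') == ['hello world']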
|
||||||
3
src/ocr/__init__.py
Normal file
@@ -0,0 +1,3 @@
from .paddle_ocr import OCREngine, extract_ocr_tokens

__all__ = ['OCREngine', 'extract_ocr_tokens']
188
src/ocr/paddle_ocr.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""
|
||||||
|
OCR Extraction Module using PaddleOCR
|
||||||
|
|
||||||
|
Extracts text tokens with bounding boxes from scanned PDFs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OCRToken:
|
||||||
|
"""Represents an OCR-extracted text token with its bounding box."""
|
||||||
|
text: str
|
||||||
|
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
|
||||||
|
confidence: float
|
||||||
|
page_no: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def x0(self) -> float:
|
||||||
|
return self.bbox[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def y0(self) -> float:
|
||||||
|
return self.bbox[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def x1(self) -> float:
|
||||||
|
return self.bbox[2]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def y1(self) -> float:
|
||||||
|
return self.bbox[3]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def center(self) -> tuple[float, float]:
|
||||||
|
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
class OCREngine:
|
||||||
|
"""PaddleOCR wrapper for text extraction."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lang: str = "en",
|
||||||
|
use_gpu: bool = False,
|
||||||
|
det_model_dir: str | None = None,
|
||||||
|
rec_model_dir: str | None = None
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize OCR engine.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lang: Language code ('en', 'sv', 'ch', etc.)
|
||||||
|
use_gpu: Whether to use GPU acceleration
|
||||||
|
det_model_dir: Custom detection model directory
|
||||||
|
rec_model_dir: Custom recognition model directory
|
||||||
|
"""
|
||||||
|
from paddleocr import PaddleOCR
|
||||||
|
|
||||||
|
# PaddleOCR 3.x API - simplified init (note: use_gpu is accepted for compatibility but not forwarded here)
|
||||||
|
init_params = {'lang': lang}
|
||||||
|
if det_model_dir:
|
||||||
|
init_params['text_detection_model_dir'] = det_model_dir
|
||||||
|
if rec_model_dir:
|
||||||
|
init_params['text_recognition_model_dir'] = rec_model_dir
|
||||||
|
|
||||||
|
self.ocr = PaddleOCR(**init_params)
|
||||||
|
|
||||||
|
def extract_from_image(
|
||||||
|
self,
|
||||||
|
image: str | Path | np.ndarray,
|
||||||
|
page_no: int = 0
|
||||||
|
) -> list[OCRToken]:
|
||||||
|
"""
|
||||||
|
Extract text tokens from an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image: Image path or numpy array
|
||||||
|
page_no: Page number for reference
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of OCRToken objects
|
||||||
|
"""
|
||||||
|
if isinstance(image, (str, Path)):
|
||||||
|
image = str(image)
|
||||||
|
|
||||||
|
# PaddleOCR 3.x uses predict() method instead of ocr()
|
||||||
|
result = self.ocr.predict(image)
|
||||||
|
|
||||||
|
tokens = []
|
||||||
|
if result:
|
||||||
|
for item in result:
|
||||||
|
# PaddleOCR 3.x returns list of dicts with 'rec_texts', 'rec_scores', 'dt_polys'
|
||||||
|
if isinstance(item, dict):
|
||||||
|
rec_texts = item.get('rec_texts', [])
|
||||||
|
rec_scores = item.get('rec_scores', [])
|
||||||
|
dt_polys = item.get('dt_polys', [])
|
||||||
|
|
||||||
|
for text, score, poly in zip(rec_texts, rec_scores, dt_polys):
|
||||||
|
# poly is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||||
|
x_coords = [p[0] for p in poly]
|
||||||
|
y_coords = [p[1] for p in poly]
|
||||||
|
|
||||||
|
bbox = (
|
||||||
|
min(x_coords),
|
||||||
|
min(y_coords),
|
||||||
|
max(x_coords),
|
||||||
|
max(y_coords)
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens.append(OCRToken(
|
||||||
|
text=text,
|
||||||
|
bbox=bbox,
|
||||||
|
confidence=float(score),
|
||||||
|
page_no=page_no
|
||||||
|
))
|
||||||
|
elif isinstance(item, (list, tuple)) and len(item) == 2:
|
||||||
|
# Legacy format: [[bbox_points], (text, confidence)]
|
||||||
|
bbox_points, (text, confidence) = item
|
||||||
|
|
||||||
|
x_coords = [p[0] for p in bbox_points]
|
||||||
|
y_coords = [p[1] for p in bbox_points]
|
||||||
|
|
||||||
|
bbox = (
|
||||||
|
min(x_coords),
|
||||||
|
min(y_coords),
|
||||||
|
max(x_coords),
|
||||||
|
max(y_coords)
|
||||||
|
)
|
||||||
|
|
||||||
|
tokens.append(OCRToken(
|
||||||
|
text=text,
|
||||||
|
bbox=bbox,
|
||||||
|
confidence=confidence,
|
||||||
|
page_no=page_no
|
||||||
|
))
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
def extract_from_pdf(
|
||||||
|
self,
|
||||||
|
pdf_path: str | Path,
|
||||||
|
dpi: int = 300
|
||||||
|
) -> Generator[list[OCRToken], None, None]:
|
||||||
|
"""
|
||||||
|
Extract text from all pages of a scanned PDF.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_path: Path to the PDF file
|
||||||
|
dpi: Resolution for rendering
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
List of OCRToken for each page
|
||||||
|
"""
|
||||||
|
from ..pdf.renderer import render_pdf_to_images
|
||||||
|
import io
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
|
||||||
|
# Convert bytes to numpy array
|
||||||
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
|
image_array = np.array(image)
|
||||||
|
|
||||||
|
tokens = self.extract_from_image(image_array, page_no=page_no)
|
||||||
|
yield tokens
|
||||||
|
|
||||||
|
|
||||||
|
def extract_ocr_tokens(
|
||||||
|
image_path: str | Path,
|
||||||
|
lang: str = "en",
|
||||||
|
page_no: int = 0
|
||||||
|
) -> list[OCRToken]:
|
||||||
|
"""
|
||||||
|
Convenience function to extract OCR tokens from an image.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: Path to the image file
|
||||||
|
lang: Language code
|
||||||
|
page_no: Page number for reference
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of OCRToken objects
|
||||||
|
"""
|
||||||
|
engine = OCREngine(lang=lang)
|
||||||
|
return engine.extract_from_image(image_path, page_no=page_no)
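# Usage sketch (the image path is hypothetical; PaddleOCR models must be available):
def _ocr_usage_example(image_path: str = "scanned_invoice_page_000.png") -> None:
    for token in extract_ocr_tokens(image_path, lang="en"):
        print(f"{token.text!r} conf={token.confidence:.2f} bbox={token.bbox}")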
|
||||||
5
src/pdf/__init__.py
Normal file
@@ -0,0 +1,5 @@
from .detector import is_text_pdf, get_pdf_type
from .renderer import render_pdf_to_images
from .extractor import extract_text_tokens

__all__ = ['is_text_pdf', 'get_pdf_type', 'render_pdf_to_images', 'extract_text_tokens']
98
src/pdf/detector.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
PDF Type Detection Module
|
||||||
|
|
||||||
|
Automatically distinguishes between:
|
||||||
|
- Text-based PDFs (digitally generated)
|
||||||
|
- Scanned image PDFs
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
import fitz # PyMuPDF
|
||||||
|
|
||||||
|
|
||||||
|
PDFType = Literal["text", "scanned", "mixed"]
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_first_page(pdf_path: str | Path) -> str:
|
||||||
|
"""Extract text from the first page of a PDF."""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
if len(doc) == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
first_page = doc[0]
|
||||||
|
text = first_page.get_text()
|
||||||
|
doc.close()
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def is_text_pdf(pdf_path: str | Path, min_chars: int = 30) -> bool:
|
||||||
|
"""
|
||||||
|
Check if PDF has extractable text layer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_path: Path to the PDF file
|
||||||
|
min_chars: Minimum characters to consider it a text PDF
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if PDF has text layer, False if scanned
|
||||||
|
"""
|
||||||
|
text = extract_text_first_page(pdf_path)
|
||||||
|
return len(text.strip()) > min_chars
|
||||||
|
|
||||||
|
|
||||||
|
def get_pdf_type(pdf_path: str | Path) -> PDFType:
|
||||||
|
"""
|
||||||
|
Determine the PDF type.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
'text' - Has extractable text layer
|
||||||
|
'scanned' - Image-based, needs OCR
|
||||||
|
'mixed' - Some pages have text, some don't
|
||||||
|
"""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
|
||||||
|
if len(doc) == 0:
|
||||||
|
doc.close()
|
||||||
|
return "scanned"
|
||||||
|
|
||||||
|
text_pages = 0
|
||||||
|
for page in doc:
|
||||||
|
text = page.get_text().strip()
|
||||||
|
if len(text) > 30:
|
||||||
|
text_pages += 1
|
||||||
|
|
||||||
|
total_pages = len(doc)
doc.close()
|
||||||
|
if text_pages == total_pages:
|
||||||
|
return "text"
|
||||||
|
elif text_pages == 0:
|
||||||
|
return "scanned"
|
||||||
|
else:
|
||||||
|
return "mixed"
|
||||||
|
|
||||||
|
|
||||||
|
def get_page_info(pdf_path: str | Path) -> list[dict]:
|
||||||
|
"""
|
||||||
|
Get information about each page in the PDF.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts with page info (number, width, height, has_text)
|
||||||
|
"""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
pages = []
|
||||||
|
|
||||||
|
for i, page in enumerate(doc):
|
||||||
|
text = page.get_text().strip()
|
||||||
|
rect = page.rect
|
||||||
|
pages.append({
|
||||||
|
"page_no": i,
|
||||||
|
"width": rect.width,
|
||||||
|
"height": rect.height,
|
||||||
|
"has_text": len(text) > 30,
|
||||||
|
"char_count": len(text)
|
||||||
|
})
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
return pages
|
||||||
176
src/pdf/extractor.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
"""
|
||||||
|
PDF Text Extraction Module
|
||||||
|
|
||||||
|
Extracts text tokens with bounding boxes from text-layer PDFs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
import fitz # PyMuPDF
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Token:
|
||||||
|
"""Represents a text token with its bounding box."""
|
||||||
|
text: str
|
||||||
|
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
|
||||||
|
page_no: int
|
||||||
|
|
||||||
|
@property
|
||||||
|
def x0(self) -> float:
|
||||||
|
return self.bbox[0]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def y0(self) -> float:
|
||||||
|
return self.bbox[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def x1(self) -> float:
|
||||||
|
return self.bbox[2]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def y1(self) -> float:
|
||||||
|
return self.bbox[3]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def width(self) -> float:
|
||||||
|
return self.x1 - self.x0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self) -> float:
|
||||||
|
return self.y1 - self.y0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def center(self) -> tuple[float, float]:
|
||||||
|
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_text_tokens(
|
||||||
|
pdf_path: str | Path,
|
||||||
|
page_no: int | None = None
|
||||||
|
) -> Generator[Token, None, None]:
|
||||||
|
"""
|
||||||
|
Extract text tokens with bounding boxes from PDF.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_path: Path to the PDF file
|
||||||
|
page_no: Specific page to extract (None for all pages)
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Token objects with text and bbox
|
||||||
|
"""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
|
||||||
|
pages_to_process = [page_no] if page_no is not None else range(len(doc))
|
||||||
|
|
||||||
|
for pg_no in pages_to_process:
|
||||||
|
page = doc[pg_no]
|
||||||
|
|
||||||
|
# Get text with position info using "dict" mode
|
||||||
|
text_dict = page.get_text("dict")
|
||||||
|
|
||||||
|
for block in text_dict.get("blocks", []):
|
||||||
|
if block.get("type") != 0: # Skip non-text blocks
|
||||||
|
continue
|
||||||
|
|
||||||
|
for line in block.get("lines", []):
|
||||||
|
for span in line.get("spans", []):
|
||||||
|
text = span.get("text", "").strip()
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
bbox = span.get("bbox")
|
||||||
|
if bbox:
|
||||||
|
yield Token(
|
||||||
|
text=text,
|
||||||
|
bbox=tuple(bbox),
|
||||||
|
page_no=pg_no
|
||||||
|
)
|
||||||
|
|
||||||
|
doc.close()
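# Usage sketch (hypothetical path): collect the tokens of the first page only.
def _extract_tokens_example(pdf_path: str = "invoice.pdf") -> list[Token]:
    return list(extract_text_tokens(pdf_path, page_no=0))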
|
||||||
|
|
||||||
|
|
||||||
|
def extract_words(
|
||||||
|
pdf_path: str | Path,
|
||||||
|
page_no: int | None = None
|
||||||
|
) -> Generator[Token, None, None]:
|
||||||
|
"""
|
||||||
|
Extract individual words with bounding boxes.
|
||||||
|
|
||||||
|
Uses PyMuPDF's word extraction which splits text into words.
|
||||||
|
"""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
|
||||||
|
pages_to_process = [page_no] if page_no is not None else range(len(doc))
|
||||||
|
|
||||||
|
for pg_no in pages_to_process:
|
||||||
|
page = doc[pg_no]
|
||||||
|
|
||||||
|
# get_text("words") returns list of (x0, y0, x1, y1, word, block_no, line_no, word_no)
|
||||||
|
words = page.get_text("words")
|
||||||
|
|
||||||
|
for word_info in words:
|
||||||
|
x0, y0, x1, y1, text, *_ = word_info
|
||||||
|
text = text.strip()
|
||||||
|
if text:
|
||||||
|
yield Token(
|
||||||
|
text=text,
|
||||||
|
bbox=(x0, y0, x1, y1),
|
||||||
|
page_no=pg_no
|
||||||
|
)
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_lines(
|
||||||
|
pdf_path: str | Path,
|
||||||
|
page_no: int | None = None
|
||||||
|
) -> Generator[Token, None, None]:
|
||||||
|
"""
|
||||||
|
Extract text lines with bounding boxes.
|
||||||
|
"""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
|
||||||
|
pages_to_process = [page_no] if page_no is not None else range(len(doc))
|
||||||
|
|
||||||
|
for pg_no in pages_to_process:
|
||||||
|
page = doc[pg_no]
|
||||||
|
text_dict = page.get_text("dict")
|
||||||
|
|
||||||
|
for block in text_dict.get("blocks", []):
|
||||||
|
if block.get("type") != 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for line in block.get("lines", []):
|
||||||
|
spans = line.get("spans", [])
|
||||||
|
if not spans:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Combine all spans in the line
|
||||||
|
line_text = " ".join(s.get("text", "") for s in spans).strip()
|
||||||
|
if not line_text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Get line bbox from all spans
|
||||||
|
x0 = min(s["bbox"][0] for s in spans)
|
||||||
|
y0 = min(s["bbox"][1] for s in spans)
|
||||||
|
x1 = max(s["bbox"][2] for s in spans)
|
||||||
|
y1 = max(s["bbox"][3] for s in spans)
|
||||||
|
|
||||||
|
yield Token(
|
||||||
|
text=line_text,
|
||||||
|
bbox=(x0, y0, x1, y1),
|
||||||
|
page_no=pg_no
|
||||||
|
)
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
|
||||||
|
|
||||||
|
def get_page_dimensions(pdf_path: str | Path, page_no: int = 0) -> tuple[float, float]:
|
||||||
|
"""Get the dimensions of a PDF page in points."""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
page = doc[page_no]
|
||||||
|
rect = page.rect
|
||||||
|
doc.close()
|
||||||
|
return rect.width, rect.height
|
||||||
117
src/pdf/renderer.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
"""
|
||||||
|
PDF Rendering Module
|
||||||
|
|
||||||
|
Converts PDF pages to images for YOLO training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
import fitz # PyMuPDF
|
||||||
|
|
||||||
|
|
||||||
|
def render_pdf_to_images(
|
||||||
|
pdf_path: str | Path,
|
||||||
|
output_dir: str | Path | None = None,
|
||||||
|
dpi: int = 300,
|
||||||
|
image_format: str = "png"
|
||||||
|
) -> Generator[tuple[int, Path | bytes], None, None]:
|
||||||
|
"""
|
||||||
|
Render PDF pages to images.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_path: Path to the PDF file
|
||||||
|
output_dir: Directory to save images (if None, returns bytes)
|
||||||
|
dpi: Resolution for rendering (default 300)
|
||||||
|
image_format: Output format ('png' or 'jpg')
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Tuple of (page_number, image_path or image_bytes)
|
||||||
|
"""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
|
||||||
|
# Calculate zoom factor for desired DPI (72 is base DPI for PDF)
|
||||||
|
zoom = dpi / 72
|
||||||
|
matrix = fitz.Matrix(zoom, zoom)
|
||||||
|
|
||||||
|
if output_dir:
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
pdf_name = Path(pdf_path).stem
|
||||||
|
|
||||||
|
for page_no, page in enumerate(doc):
|
||||||
|
# Render page to pixmap
|
||||||
|
pix = page.get_pixmap(matrix=matrix)
|
||||||
|
|
||||||
|
if output_dir:
|
||||||
|
# Save to file
|
||||||
|
ext = "jpg" if image_format.lower() in ("jpg", "jpeg") else "png"
|
||||||
|
image_path = output_dir / f"{pdf_name}_page_{page_no:03d}.{ext}"
|
||||||
|
|
||||||
|
if ext == "jpg":
|
||||||
|
pix.save(str(image_path), "jpeg")
|
||||||
|
else:
|
||||||
|
pix.save(str(image_path))
|
||||||
|
|
||||||
|
yield page_no, image_path
|
||||||
|
else:
|
||||||
|
# Return bytes
|
||||||
|
if image_format.lower() in ("jpg", "jpeg"):
|
||||||
|
yield page_no, pix.tobytes("jpeg")
|
||||||
|
else:
|
||||||
|
yield page_no, pix.tobytes("png")
|
||||||
|
|
||||||
|
doc.close()
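# Usage sketch (hypothetical paths): write all pages of one PDF to disk as 300 DPI PNGs.
def _render_example(pdf_path: str = "invoice.pdf") -> list[Path]:
    return [image_path for _, image_path in
            render_pdf_to_images(pdf_path, output_dir="out/images", dpi=300)]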
|
||||||
|
|
||||||
|
|
||||||
|
def render_page_to_image(
|
||||||
|
pdf_path: str | Path,
|
||||||
|
page_no: int,
|
||||||
|
dpi: int = 300
|
||||||
|
) -> bytes:
|
||||||
|
"""
|
||||||
|
Render a single page to image bytes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_path: Path to the PDF file
|
||||||
|
page_no: Page number (0-indexed)
|
||||||
|
dpi: Resolution for rendering
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PNG image bytes
|
||||||
|
"""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
|
||||||
|
if page_no >= len(doc):
|
||||||
|
doc.close()
|
||||||
|
raise ValueError(f"Page {page_no} does not exist (PDF has {len(doc)} pages)")
|
||||||
|
|
||||||
|
zoom = dpi / 72
|
||||||
|
matrix = fitz.Matrix(zoom, zoom)
|
||||||
|
|
||||||
|
page = doc[page_no]
|
||||||
|
pix = page.get_pixmap(matrix=matrix)
|
||||||
|
image_bytes = pix.tobytes("png")
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
return image_bytes
|
||||||
|
|
||||||
|
|
||||||
|
def get_render_dimensions(pdf_path: str | Path, page_no: int = 0, dpi: int = 300) -> tuple[int, int]:
|
||||||
|
"""
|
||||||
|
Get the dimensions of a rendered page.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(width, height) in pixels
|
||||||
|
"""
|
||||||
|
doc = fitz.open(pdf_path)
|
||||||
|
page = doc[page_no]
|
||||||
|
|
||||||
|
zoom = dpi / 72
|
||||||
|
rect = page.rect
|
||||||
|
|
||||||
|
width = int(rect.width * zoom)
|
||||||
|
height = int(rect.height * zoom)
|
||||||
|
|
||||||
|
doc.close()
|
||||||
|
return width, height
|
||||||
4
src/yolo/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from .annotation_generator import AnnotationGenerator, generate_annotations
|
||||||
|
from .dataset_builder import DatasetBuilder
|
||||||
|
|
||||||
|
__all__ = ['AnnotationGenerator', 'generate_annotations', 'DatasetBuilder']
|
||||||
281
src/yolo/annotation_generator.py
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
"""
|
||||||
|
YOLO Annotation Generator
|
||||||
|
|
||||||
|
Generates YOLO format annotations from matched fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
import csv
|
||||||
|
|
||||||
|
|
||||||
|
# Field class mapping for YOLO
|
||||||
|
FIELD_CLASSES = {
|
||||||
|
'InvoiceNumber': 0,
|
||||||
|
'InvoiceDate': 1,
|
||||||
|
'InvoiceDueDate': 2,
|
||||||
|
'OCR': 3,
|
||||||
|
'Bankgiro': 4,
|
||||||
|
'Plusgiro': 5,
|
||||||
|
'Amount': 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
CLASS_NAMES = [
|
||||||
|
'invoice_number',
|
||||||
|
'invoice_date',
|
||||||
|
'invoice_due_date',
|
||||||
|
'ocr_number',
|
||||||
|
'bankgiro',
|
||||||
|
'plusgiro',
|
||||||
|
'amount',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class YOLOAnnotation:
|
||||||
|
"""Represents a single YOLO annotation."""
|
||||||
|
class_id: int
|
||||||
|
x_center: float # normalized 0-1
|
||||||
|
y_center: float # normalized 0-1
|
||||||
|
width: float # normalized 0-1
|
||||||
|
height: float # normalized 0-1
|
||||||
|
confidence: float = 1.0
|
||||||
|
|
||||||
|
def to_string(self) -> str:
|
||||||
|
"""Convert to YOLO format string."""
|
||||||
|
return f"{self.class_id} {self.x_center:.6f} {self.y_center:.6f} {self.width:.6f} {self.height:.6f}"
|
||||||
|
|
||||||
|
|
||||||
|
class AnnotationGenerator:
|
||||||
|
"""Generates YOLO annotations from document matches."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_confidence: float = 0.7,
|
||||||
|
bbox_padding_px: int = 20, # Absolute padding in pixels
|
||||||
|
min_bbox_height_px: int = 30 # Minimum bbox height
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize annotation generator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_confidence: Minimum match score to include in training
|
||||||
|
bbox_padding_px: Absolute padding in pixels to add around bboxes
|
||||||
|
min_bbox_height_px: Minimum bbox height in pixels
|
||||||
|
"""
|
||||||
|
self.min_confidence = min_confidence
|
||||||
|
self.bbox_padding_px = bbox_padding_px
|
||||||
|
self.min_bbox_height_px = min_bbox_height_px
|
||||||
|
|
||||||
|
def generate_from_matches(
|
||||||
|
self,
|
||||||
|
matches: dict[str, list[Any]], # field_name -> list of Match
|
||||||
|
image_width: float,
|
||||||
|
image_height: float,
|
||||||
|
dpi: int = 300
|
||||||
|
) -> list[YOLOAnnotation]:
|
||||||
|
"""
|
||||||
|
Generate YOLO annotations from field matches.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
matches: Dict of field_name -> list of Match objects
|
||||||
|
image_width: Width of the rendered image in pixels
|
||||||
|
image_height: Height of the rendered image in pixels
|
||||||
|
dpi: DPI used for rendering (needed to convert PDF coords to pixels)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of YOLOAnnotation objects
|
||||||
|
"""
|
||||||
|
annotations = []
|
||||||
|
|
||||||
|
# Scale factor to convert PDF points (72 DPI) to rendered pixels
|
||||||
|
scale = dpi / 72.0
|
||||||
|
|
||||||
|
for field_name, field_matches in matches.items():
|
||||||
|
if field_name not in FIELD_CLASSES:
|
||||||
|
continue
|
||||||
|
|
||||||
|
class_id = FIELD_CLASSES[field_name]
|
||||||
|
|
||||||
|
# Take only the best match per field
|
||||||
|
if field_matches:
|
||||||
|
best_match = field_matches[0] # Already sorted by score
|
||||||
|
|
||||||
|
if best_match.score < self.min_confidence:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# best_match.bbox is in PDF points (72 DPI), convert to pixels
|
||||||
|
x0, y0, x1, y1 = best_match.bbox
|
||||||
|
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
||||||
|
|
||||||
|
# Add absolute padding
|
||||||
|
pad = self.bbox_padding_px
|
||||||
|
x0 = max(0, x0 - pad)
|
||||||
|
y0 = max(0, y0 - pad)
|
||||||
|
x1 = min(image_width, x1 + pad)
|
||||||
|
y1 = min(image_height, y1 + pad)
|
||||||
|
|
||||||
|
# Ensure minimum height
|
||||||
|
current_height = y1 - y0
|
||||||
|
if current_height < self.min_bbox_height_px:
|
||||||
|
extra = (self.min_bbox_height_px - current_height) / 2
|
||||||
|
y0 = max(0, y0 - extra)
|
||||||
|
y1 = min(image_height, y1 + extra)
|
||||||
|
|
||||||
|
# Convert to YOLO format (normalized center + size)
|
||||||
|
x_center = (x0 + x1) / 2 / image_width
|
||||||
|
y_center = (y0 + y1) / 2 / image_height
|
||||||
|
width = (x1 - x0) / image_width
|
||||||
|
height = (y1 - y0) / image_height
|
||||||
|
|
||||||
|
# Clamp values to 0-1
|
||||||
|
x_center = max(0, min(1, x_center))
|
||||||
|
y_center = max(0, min(1, y_center))
|
||||||
|
width = max(0, min(1, width))
|
||||||
|
height = max(0, min(1, height))
|
||||||
|
|
||||||
|
annotations.append(YOLOAnnotation(
|
||||||
|
class_id=class_id,
|
||||||
|
x_center=x_center,
|
||||||
|
y_center=y_center,
|
||||||
|
width=width,
|
||||||
|
height=height,
|
||||||
|
confidence=best_match.score
|
||||||
|
))
|
||||||
|
|
||||||
|
return annotations
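# Worked example of the coordinate conversion above (sketch): a match bbox of
# (100, 200, 180, 215) in PDF points, rendered at 300 DPI, scales by 300/72 ≈ 4.17
# to roughly (417, 833, 750, 896) px before padding and normalization.
def _scale_example() -> tuple[float, float, float, float]:
    scale = 300 / 72.0
    x0, y0, x1, y1 = 100, 200, 180, 215      # PDF points
    return x0 * scale, y0 * scale, x1 * scale, y1 * scale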
|
||||||
|
|
||||||
|
def save_annotations(
|
||||||
|
self,
|
||||||
|
annotations: list[YOLOAnnotation],
|
||||||
|
output_path: str | Path
|
||||||
|
) -> None:
|
||||||
|
"""Save annotations to a YOLO format text file."""
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(output_path, 'w') as f:
|
||||||
|
for ann in annotations:
|
||||||
|
f.write(ann.to_string() + '\n')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_classes_file(output_path: str | Path) -> None:
|
||||||
|
"""Generate the classes.txt file for YOLO."""
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
with open(output_path, 'w') as f:
|
||||||
|
for class_name in CLASS_NAMES:
|
||||||
|
f.write(class_name + '\n')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def generate_yaml_config(
|
||||||
|
output_path: str | Path,
|
||||||
|
train_path: str = 'train/images',
|
||||||
|
val_path: str = 'val/images',
|
||||||
|
test_path: str = 'test/images'
|
||||||
|
) -> None:
|
||||||
|
"""Generate YOLO dataset YAML configuration."""
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Use absolute path for WSL compatibility
|
||||||
|
dataset_dir = output_path.parent.absolute()
|
||||||
|
# Convert Windows path to WSL path if needed
|
||||||
|
dataset_path_str = str(dataset_dir).replace('\\', '/')
|
||||||
|
if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
|
||||||
|
# Windows path like C:/... -> /mnt/c/...
|
||||||
|
drive = dataset_path_str[0].lower()
|
||||||
|
dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
|
||||||
|
|
||||||
|
config = f"""# Invoice Field Detection Dataset
|
||||||
|
path: {dataset_path_str}
|
||||||
|
train: {train_path}
|
||||||
|
val: {val_path}
|
||||||
|
test: {test_path}
|
||||||
|
|
||||||
|
# Classes
|
||||||
|
names:
|
||||||
|
"""
|
||||||
|
for i, name in enumerate(CLASS_NAMES):
|
||||||
|
config += f" {i}: {name}\n"
|
||||||
|
|
||||||
|
with open(output_path, 'w') as f:
|
||||||
|
f.write(config)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_annotations(
|
||||||
|
pdf_path: str | Path,
|
||||||
|
structured_data: dict[str, Any],
|
||||||
|
output_dir: str | Path,
|
||||||
|
dpi: int = 300
|
||||||
|
) -> list[Path]:
|
||||||
|
"""
|
||||||
|
Generate YOLO annotations for a PDF using structured data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_path: Path to the PDF file
|
||||||
|
structured_data: Dict with field values from CSV
|
||||||
|
output_dir: Directory to save images and labels
|
||||||
|
dpi: Resolution for rendering
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of paths to generated annotation files
|
||||||
|
"""
|
||||||
|
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
|
||||||
|
from ..pdf.renderer import get_render_dimensions
|
||||||
|
from ..ocr import OCREngine
|
||||||
|
from ..matcher import FieldMatcher
|
||||||
|
from ..normalize import normalize_field
|
||||||
|
|
||||||
|
output_dir = Path(output_dir)
|
||||||
|
images_dir = output_dir / 'images'
|
||||||
|
labels_dir = output_dir / 'labels'
|
||||||
|
images_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
labels_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
generator = AnnotationGenerator()
|
||||||
|
matcher = FieldMatcher()
|
||||||
|
annotation_files = []
|
||||||
|
|
||||||
|
# Check PDF type
|
||||||
|
use_ocr = not is_text_pdf(pdf_path)
|
||||||
|
|
||||||
|
# Initialize OCR if needed
|
||||||
|
ocr_engine = OCREngine() if use_ocr else None
|
||||||
|
|
||||||
|
# Process each page
|
||||||
|
for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi):
|
||||||
|
# Get image dimensions
|
||||||
|
img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)
|
||||||
|
|
||||||
|
# Extract tokens
|
||||||
|
if use_ocr:
|
||||||
|
tokens = ocr_engine.extract_from_image(str(image_path), page_no)
|
||||||
|
else:
|
||||||
|
tokens = list(extract_text_tokens(pdf_path, page_no))
|
||||||
|
|
||||||
|
# Match fields
|
||||||
|
matches = {}
|
||||||
|
for field_name in FIELD_CLASSES.keys():
|
||||||
|
value = structured_data.get(field_name)
|
||||||
|
if value is None or str(value).strip() == '':
|
||||||
|
continue
|
||||||
|
|
||||||
|
normalized = normalize_field(field_name, str(value))
|
||||||
|
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
|
||||||
|
if field_matches:
|
||||||
|
matches[field_name] = field_matches
|
||||||
|
|
||||||
|
# Generate annotations (pass DPI for coordinate conversion)
|
||||||
|
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
|
||||||
|
|
||||||
|
# Save annotations
|
||||||
|
if annotations:
|
||||||
|
label_path = labels_dir / f"{image_path.stem}.txt"
|
||||||
|
generator.save_annotations(annotations, label_path)
|
||||||
|
annotation_files.append(label_path)
|
||||||
|
|
||||||
|
return annotation_files
|
||||||
249
src/yolo/dataset_builder.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""
|
||||||
|
YOLO Dataset Builder
|
||||||
|
|
||||||
|
Builds training dataset from PDFs and structured CSV data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import shutil
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetStats:
|
||||||
|
"""Statistics about the generated dataset."""
|
||||||
|
total_documents: int
|
||||||
|
successful: int
|
||||||
|
failed: int
|
||||||
|
total_annotations: int
|
||||||
|
annotations_per_class: dict[str, int]
|
||||||
|
train_count: int
|
||||||
|
val_count: int
|
||||||
|
test_count: int
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetBuilder:
|
||||||
|
"""Builds YOLO training dataset from PDFs and CSV data."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pdf_dir: str | Path,
|
||||||
|
csv_path: str | Path,
|
||||||
|
output_dir: str | Path,
|
||||||
|
document_id_column: str = 'DocumentId',
|
||||||
|
dpi: int = 300,
|
||||||
|
train_ratio: float = 0.8,
|
||||||
|
val_ratio: float = 0.1,
|
||||||
|
test_ratio: float = 0.1
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize dataset builder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pdf_dir: Directory containing PDF files
|
||||||
|
csv_path: Path to structured data CSV
|
||||||
|
output_dir: Output directory for dataset
|
||||||
|
document_id_column: Column name for document ID
|
||||||
|
dpi: Resolution for rendering
|
||||||
|
train_ratio: Fraction for training set
|
||||||
|
val_ratio: Fraction for validation set
|
||||||
|
test_ratio: Fraction for test set
|
||||||
|
"""
|
||||||
|
self.pdf_dir = Path(pdf_dir)
|
||||||
|
self.csv_path = Path(csv_path)
|
||||||
|
self.output_dir = Path(output_dir)
|
||||||
|
self.document_id_column = document_id_column
|
||||||
|
self.dpi = dpi
|
||||||
|
self.train_ratio = train_ratio
|
||||||
|
self.val_ratio = val_ratio
|
||||||
|
self.test_ratio = test_ratio
|
||||||
|
|
||||||
|
def load_structured_data(self) -> dict[str, dict]:
|
||||||
|
"""Load structured data from CSV."""
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
with open(self.csv_path, 'r', encoding='utf-8') as f:
|
||||||
|
reader = csv.DictReader(f)
|
||||||
|
for row in reader:
|
||||||
|
doc_id = row.get(self.document_id_column)
|
||||||
|
if doc_id:
|
||||||
|
data[doc_id] = row
|
||||||
|
|
||||||
|
return data
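# Expected CSV layout (illustrative values; the column set follows FIELD_CLASSES in
# annotation_generator.py plus the configured document id column):
#
#   DocumentId,InvoiceNumber,InvoiceDate,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount
#   INV-0001,100017500321,2026-01-09,2026-02-08,1234567890,5393-9484,,114.00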
|
||||||
|
|
||||||
|
def find_pdf_for_document(self, doc_id: str) -> Path | None:
|
||||||
|
"""Find PDF file for a document ID."""
|
||||||
|
# Try common naming patterns
|
||||||
|
patterns = [
|
||||||
|
f"{doc_id}.pdf",
|
||||||
|
f"{doc_id.lower()}.pdf",
|
||||||
|
f"{doc_id.upper()}.pdf",
|
||||||
|
f"*{doc_id}*.pdf",
|
||||||
|
]
|
||||||
|
|
||||||
|
for pattern in patterns:
|
||||||
|
matches = list(self.pdf_dir.glob(pattern))
|
||||||
|
if matches:
|
||||||
|
return matches[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|

    def build(self, seed: int = 42) -> DatasetStats:
        """
        Build the complete dataset.

        Args:
            seed: Random seed for train/val/test split

        Returns:
            DatasetStats with build results
        """
        from .annotation_generator import AnnotationGenerator, CLASS_NAMES

        random.seed(seed)

        # Setup output directories
        for split in ['train', 'val', 'test']:
            (self.output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
            (self.output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)

        # Generate config files
        AnnotationGenerator.generate_classes_file(self.output_dir / 'classes.txt')
        AnnotationGenerator.generate_yaml_config(self.output_dir / 'dataset.yaml')

        # Load structured data
        structured_data = self.load_structured_data()

        # Stats tracking
        stats = {
            'total': 0,
            'successful': 0,
            'failed': 0,
            'annotations': 0,
            'per_class': {name: 0 for name in CLASS_NAMES},
            'splits': {'train': 0, 'val': 0, 'test': 0}
        }

        # Process each document
        processed_items = []

        for doc_id, data in structured_data.items():
            stats['total'] += 1

            pdf_path = self.find_pdf_for_document(doc_id)
            if not pdf_path:
                print(f"Warning: PDF not found for document {doc_id}")
                stats['failed'] += 1
                continue

            try:
                # Generate to temp dir first
                temp_dir = self.output_dir / 'temp' / doc_id
                temp_dir.mkdir(parents=True, exist_ok=True)

                from .annotation_generator import generate_annotations
                annotation_files = generate_annotations(
                    pdf_path, data, temp_dir, self.dpi
                )

                if annotation_files:
                    processed_items.append({
                        'doc_id': doc_id,
                        'temp_dir': temp_dir,
                        'annotation_files': annotation_files
                    })
                    stats['successful'] += 1
                else:
                    print(f"Warning: No annotations generated for {doc_id}")
                    stats['failed'] += 1

            except Exception as e:
                print(f"Error processing {doc_id}: {e}")
                stats['failed'] += 1

        # Shuffle and split
        random.shuffle(processed_items)

        n_train = int(len(processed_items) * self.train_ratio)
        n_val = int(len(processed_items) * self.val_ratio)

        splits = {
            'train': processed_items[:n_train],
            'val': processed_items[n_train:n_train + n_val],
            'test': processed_items[n_train + n_val:]
        }

        # Move files to final locations
        for split_name, items in splits.items():
            for item in items:
                temp_dir = item['temp_dir']

                # Move images
                for img in (temp_dir / 'images').glob('*'):
                    dest = self.output_dir / split_name / 'images' / img.name
                    shutil.move(str(img), str(dest))

                # Move labels and count annotations
                for label in (temp_dir / 'labels').glob('*.txt'):
                    dest = self.output_dir / split_name / 'labels' / label.name
                    shutil.move(str(label), str(dest))

                    # Count annotations per class
                    with open(dest, 'r') as f:
                        for line in f:
                            class_id = int(line.strip().split()[0])
                            if 0 <= class_id < len(CLASS_NAMES):
                                stats['per_class'][CLASS_NAMES[class_id]] += 1
                            stats['annotations'] += 1

                stats['splits'][split_name] += 1

        # Cleanup temp dir
        shutil.rmtree(self.output_dir / 'temp', ignore_errors=True)

        return DatasetStats(
            total_documents=stats['total'],
            successful=stats['successful'],
            failed=stats['failed'],
            total_annotations=stats['annotations'],
            annotations_per_class=stats['per_class'],
            train_count=stats['splits']['train'],
            val_count=stats['splits']['val'],
            test_count=stats['splits']['test']
        )
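
    # Note on the split in build() (illustrative numbers): with 100 successfully processed
    # documents and the default ratios 0.8/0.1/0.1, n_train = 80 and n_val = 10, so the
    # slices give 80 train items, 10 val items, and the remaining 10 documents go to test.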

    def process_single_document(
        self,
        doc_id: str,
        data: dict,
        split: str = 'train'
    ) -> bool:
        """
        Process a single document (for incremental building).

        Args:
            doc_id: Document ID
            data: Structured data dict
            split: Which split to add to

        Returns:
            True if successful
        """
        from .annotation_generator import generate_annotations

        pdf_path = self.find_pdf_for_document(doc_id)
        if not pdf_path:
            return False

        try:
            output_subdir = self.output_dir / split
            annotation_files = generate_annotations(
                pdf_path, data, output_subdir, self.dpi
            )
            return len(annotation_files) > 0
        except Exception as e:
            print(f"Error processing {doc_id}: {e}")
            return False
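
A minimal usage sketch for the builder above; the import path assumes the repository root is on `PYTHONPATH` with `src` as a package, and the data paths are placeholders:

```python
from src.yolo.dataset_builder import DatasetBuilder

# Placeholder paths; adjust to the local layout.
builder = DatasetBuilder(
    pdf_dir="data/raw_pdfs",
    csv_path="data/structured_data.csv",
    output_dir="data/dataset",
    dpi=300,
)

stats = builder.build(seed=42)
print(f"Documents: {stats.total_documents} (ok={stats.successful}, failed={stats.failed})")
print(f"Annotations: {stats.total_annotations}")
print(f"Splits: train={stats.train_count}, val={stats.val_count}, test={stats.test_count}")

# Incremental additions append a single document to an existing split:
# builder.process_single_document("12345", row_dict, split="train")  # row_dict: one CSV row as a dict
```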