commit 8938661850d648328f663dcaf312c1da0d9e0cb3
Author: Yaojia Wang
Date:   Sat Jan 10 17:44:14 2026 +0100

    Initial commit: Invoice field extraction system using YOLO + OCR

    Features:
    - Auto-labeling pipeline: CSV values -> PDF search -> YOLO annotations
    - Flexible date matching: year-month match, nearby date tolerance
    - PDF text extraction with PyMuPDF
    - OCR support for scanned documents (PaddleOCR)
    - YOLO training and inference pipeline
    - 7 field types: InvoiceNumber, InvoiceDate, InvoiceDueDate, OCR, Bankgiro, Plusgiro, Amount

    Co-Authored-By: Claude Opus 4.5

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..aea562c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,71 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+venv/
+ENV/
+env/
+.venv/
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+*~
+
+# Data files (large files)
+data/raw_pdfs/
+data/dataset/train/images/
+data/dataset/val/images/
+data/dataset/test/images/
+data/dataset/train/labels/
+data/dataset/val/labels/
+data/dataset/test/labels/
+*.pdf
+*.png
+*.jpg
+*.jpeg
+
+# Model weights
+models/weights/
+runs/
+*.pt
+*.onnx
+
+# Reports and logs
+reports/*.jsonl
+logs/
+*.log
+
+# Jupyter
+.ipynb_checkpoints/
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Credentials
+.env
+*.key
+*.pem
+credentials.json
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..59f7f11
--- /dev/null
+++ b/README.md
@@ -0,0 +1,226 @@
+# Invoice Master POC v2
+
+An automatic invoice information extraction system - extracts structured data from PDF invoices using YOLO + OCR.
+
+## Runtime Environment
+
+> **Important**: This project must be run inside **WSL (Windows Subsystem for Linux)**.
+
+### System Requirements
+
+- WSL 2 (Ubuntu 22.04 recommended)
+- Python 3.10+
+- **NVIDIA GPU + CUDA 12.x (strongly recommended)** - GPU training is 10-50x faster than CPU
+
+## Features
+
+- **Dual-mode PDF processing**: handles both text-layer PDFs and scanned PDFs
+- **Auto-labeling**: generates YOLO training data automatically from existing structured CSV data
+- **Field detection**: detects invoice field regions with YOLOv8
+- **OCR**: extracts text from detected regions with PaddleOCR
+- **Smart matching**: multiple format normalizations plus context-keyword boosting
+
+## Supported Fields
+
+| Field | Description |
+|------|------|
+| InvoiceNumber | Invoice number |
+| InvoiceDate | Invoice date |
+| InvoiceDueDate | Due date |
+| OCR | OCR reference number (Swedish) |
+| Bankgiro | Bankgiro number |
+| Plusgiro | Plusgiro number |
+| Amount | Amount |
+
+## Installation (WSL)
+
+### 1. Enter the WSL environment
+
+```bash
+# Enter WSL from a Windows terminal
+wsl
+
+# Change into the project directory (Windows paths are mounted under /mnt/)
+cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
+```
+
+### 2. Install system dependencies
+
+```bash
+# Update the system
+sudo apt update && sudo apt upgrade -y
+
+# Install Python and required tools
+sudo apt install -y python3.10 python3.10-venv python3-pip
+
+# Install OpenCV dependencies
+sudo apt install -y libgl1-mesa-glx libglib2.0-0 libsm6 libxrender1 libxext6
+```
+
+### 3. Create a virtual environment and install dependencies
+
+```bash
+# Create the virtual environment
+python3 -m venv venv
+source venv/bin/activate
+
+# Upgrade pip
+pip install --upgrade pip
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Or install in editable (development) mode
+pip install -e .
+```
+
+### GPU support (optional)
+
+```bash
+# Make sure CUDA is configured for WSL
+nvidia-smi  # check that the GPU is visible
+
+# Install the GPU build of PaddlePaddle
+pip install paddlepaddle-gpu
+
+# Or pin a specific CUDA build
+pip install paddlepaddle-gpu==2.5.2.post118 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
+```
+
+## Quick Start
+
+### 1. Prepare the data
+
+```
+data/
+├── raw_pdfs/
+│   ├── {DocumentId}.pdf
+│   └── ...
+└── structured_data/
+    └── invoices.csv
+```
+
+CSV format:
+```csv
+DocumentId,InvoiceDate,InvoiceNumber,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount
+3be53fd7-...,2025-12-13,100017500321,2026-01-03,100017500321,53939484,,114
+```
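+
+Before labeling, you can sanity-check that the CSV parses and that each row resolves to a PDF, using the project's own loader (a minimal sketch; the paths follow the example layout above):
+
+```python
+from src.data.csv_loader import CSVLoader
+
+loader = CSVLoader('data/structured_data/invoices.csv', pdf_dir='data/raw_pdfs')
+rows = loader.load_all()
+print(f"{len(rows)} records loaded")
+print(rows[0].to_dict())              # parsed fields of the first invoice
+print(loader.get_pdf_path(rows[0]))   # resolved PDF path, or None if missing
+```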
+
+### 2. Auto-labeling
+
+```bash
+python -m src.cli.autolabel \
+    --csv data/structured_data/invoices.csv \
+    --pdf-dir data/raw_pdfs \
+    --output data/dataset \
+    --report reports/autolabel_report.jsonl
+```
+
+### 3. Train the model
+
+> **Important**: Always train on a GPU! CPU training is very slow.
+
+```bash
+# GPU training (strongly recommended)
+python -m src.cli.train \
+    --data data/dataset/dataset.yaml \
+    --model yolo11n.pt \
+    --epochs 100 \
+    --batch 16 \
+    --device 0  # use the GPU
+
+# Verify that the GPU is available
+python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}')"
+```
+
+GPU vs CPU training time (100 epochs, 77 training images):
+- **GPU (RTX 5080)**: ~2 minutes
+- **CPU**: 30+ minutes
+
+### 4. Inference
+
+```bash
+python -m src.cli.infer \
+    --model runs/train/invoice_fields/weights/best.pt \
+    --input path/to/invoice.pdf \
+    --output result.json
+```
+
+## Example Output
+
+```json
+{
+  "DocumentId": "3be53fd7-d5ea-458c-a229-8d360b8ba6a9",
+  "InvoiceNumber": "100017500321",
+  "InvoiceDate": "2025-12-13",
+  "InvoiceDueDate": "2026-01-03",
+  "OCR": "100017500321",
+  "Bankgiro": "5393-9484",
+  "Plusgiro": null,
+  "Amount": "114.00",
+  "confidence": {
+    "InvoiceNumber": 0.96,
+    "InvoiceDate": 0.92,
+    "Amount": 0.93
+  }
+}
+```
+
+## Project Structure
+
+```
+invoice-master-poc-v2/
+├── src/
+│   ├── pdf/          # PDF processing
+│   ├── ocr/          # OCR extraction
+│   ├── normalize/    # Field normalization
+│   ├── matcher/      # Field matching
+│   ├── yolo/         # YOLO annotation generation
+│   ├── inference/    # Inference pipeline
+│   ├── data/         # Data loading
+│   └── cli/          # Command-line tools
+├── configs/          # Configuration files
+├── data/             # Data directory
+└── requirements.txt
+```
+
+## Development Priorities
+
+1. ✅ Auto-labeling for text-layer PDFs
+2. ✅ OCR auto-labeling for scanned PDFs
+3. 🔄 Stabilize the Amount / OCR / Bankgiro fields
+4. ⏳ Extend to dates and Plusgiro
+5. ⏳ Handle table line items
+
+## Configuration
+
+Edit `configs/default.yaml` to customize:
+- PDF rendering DPI
+- OCR language
+- Match confidence threshold
+- Context keywords
+- Data augmentation parameters
+
+## API Usage
+
+```python
+from src.inference import InferencePipeline
+
+# Initialize
+pipeline = InferencePipeline(
+    model_path='models/best.pt',
+    confidence_threshold=0.5,
+    ocr_lang='en'
+)
+
+# Process a PDF
+result = pipeline.process_pdf('invoice.pdf')
+
+# Access the extracted fields
+print(result.fields)
+print(result.confidence)
+```
+
+## License
+
+MIT License
diff --git a/configs/default.yaml b/configs/default.yaml
new file mode 100644
index 0000000..b2c81ff
--- /dev/null
+++ b/configs/default.yaml
@@ -0,0 +1,129 @@
+# Invoice Master POC v2 Configuration
+# Default configuration for invoice field extraction system
+
+# PDF Processing
+pdf:
+  dpi: 300              # Resolution for rendering PDFs to images
+  min_text_chars: 30    # Minimum chars to consider PDF as text-based
+
+# OCR Settings
+ocr:
+  engine: paddleocr     # OCR engine to use
+  lang: en              # Language code (en, sv, ch, etc.)
+ use_gpu: false # Enable GPU acceleration + +# Field Normalization +normalize: + # Bankgiro formats + bankgiro: + format: "XXXX-XXXX" # Standard 8-digit format + alternatives: + - "XXX-XXXX" # 7-digit format + + # Plusgiro formats + plusgiro: + format: "XXXXXXX-X" # Standard format with check digit + + # Amount formats + amount: + decimal_separator: "," # Swedish uses comma + thousand_separator: " " # Space for thousands + currency_symbols: + - "SEK" + - "kr" + + # Date formats + date: + output_format: "%Y-%m-%d" + input_formats: + - "%Y-%m-%d" + - "%Y-%m-%d %H:%M:%S" + - "%d/%m/%Y" + - "%d.%m.%Y" + +# Field Matching +matching: + min_score_threshold: 0.7 # Minimum score to accept match + context_radius: 100 # Pixels to search for context keywords + + # Context keywords for each field (Swedish) + context_keywords: + InvoiceNumber: + - "fakturanr" + - "fakturanummer" + - "invoice" + InvoiceDate: + - "fakturadatum" + - "datum" + InvoiceDueDate: + - "förfallodatum" + - "förfaller" + - "betalas senast" + OCR: + - "ocr" + - "referens" + Bankgiro: + - "bankgiro" + - "bg" + Plusgiro: + - "plusgiro" + - "pg" + Amount: + - "att betala" + - "summa" + - "total" + - "belopp" + +# YOLO Training +yolo: + model: yolov8s # Model architecture (yolov8n/s/m/l/x) + epochs: 100 + batch_size: 16 + img_size: 1280 # Image size for training + + # Data augmentation + augmentation: + rotation: 5 # Max rotation degrees + scale: 0.2 # Scale variation + mosaic: 0.0 # Disable mosaic for documents + hsv_h: 0.0 # No hue variation + hsv_s: 0.1 # Slight saturation variation + hsv_v: 0.2 # Brightness variation + + # Class definitions + classes: + 0: invoice_number + 1: invoice_date + 2: invoice_due_date + 3: ocr_number + 4: bankgiro + 5: plusgiro + 6: amount + +# Auto-labeling +autolabel: + min_confidence: 0.7 # Minimum score to include in training + bbox_padding: 0.02 # Padding around bboxes (fraction of image) + +# Dataset Split +dataset: + train_ratio: 0.8 + val_ratio: 0.1 + test_ratio: 0.1 + random_seed: 42 + +# Inference +inference: + confidence_threshold: 0.5 # Detection confidence threshold + iou_threshold: 0.45 # NMS IOU threshold + enable_fallback: true # Enable regex fallback if YOLO fails + fallback_min_missing: 2 # Min missing fields to trigger fallback + +# Paths (relative to project root) +paths: + raw_pdfs: data/raw_pdfs + images: data/images + labels: data/labels + structured_data: data/structured_data + models: models + reports: reports diff --git a/configs/training.yaml b/configs/training.yaml new file mode 100644 index 0000000..4eb2a89 --- /dev/null +++ b/configs/training.yaml @@ -0,0 +1,59 @@ +# YOLO Training Configuration +# Use with: yolo train data=dataset.yaml cfg=training.yaml + +# Model +model: yolov8s.pt + +# Training hyperparameters +epochs: 100 +patience: 20 # Early stopping patience +batch: 16 +imgsz: 1280 + +# Optimizer +optimizer: AdamW +lr0: 0.001 # Initial learning rate +lrf: 0.01 # Final learning rate factor +momentum: 0.937 +weight_decay: 0.0005 + +# Warmup +warmup_epochs: 3 +warmup_momentum: 0.8 +warmup_bias_lr: 0.1 + +# Loss weights +box: 7.5 # Box loss gain +cls: 0.5 # Class loss gain +dfl: 1.5 # DFL loss gain + +# Augmentation +# Keep minimal for document images +hsv_h: 0.0 # No hue augmentation +hsv_s: 0.1 # Slight saturation +hsv_v: 0.2 # Brightness variation +degrees: 5.0 # Rotation ±5° +translate: 0.05 # Translation +scale: 0.2 # Scale ±20% +shear: 0.0 # No shear +perspective: 0.0 # No perspective +flipud: 0.0 # No vertical flip +fliplr: 0.0 # No horizontal flip +mosaic: 0.0 # 
Disable mosaic (not suitable for documents)
+mixup: 0.0         # Disable mixup
+copy_paste: 0.0    # Disable copy-paste
+
+# Validation
+val: true
+save: true
+save_period: 10
+cache: true
+
+# Other
+device: 0          # GPU device (0, 1, etc.) or 'cpu'
+workers: 8
+project: runs/train
+name: invoice_fields
+exist_ok: true
+pretrained: true
+verbose: true
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..2165d00
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,78 @@
+[build-system]
+requires = ["setuptools>=68.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "invoice-master"
+version = "2.0.0"
+description = "Automatic invoice information extraction using YOLO + OCR"
+readme = "README.md"
+requires-python = ">=3.10"
+license = {text = "MIT"}
+authors = [
+    {name = "Invoice Master Team"}
+]
+keywords = ["invoice", "ocr", "yolo", "document-processing", "pdf"]
+classifiers = [
+    "Development Status :: 4 - Beta",
+    "Intended Audience :: Developers",
+    "License :: OSI Approved :: MIT License",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+]
+
+dependencies = [
+    "PyMuPDF>=1.23.0",
+    "paddlepaddle>=2.5.0",
+    "paddleocr>=2.7.0",
+    "ultralytics>=8.1.0",
+    "Pillow>=10.0.0",
+    "numpy>=1.24.0",
+    "opencv-python>=4.8.0",
+    "pyyaml>=6.0",
+    "tqdm>=4.65.0",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.0.0",
+    "pytest-cov>=4.0.0",
+    "black>=23.0.0",
+    "ruff>=0.1.0",
+    "mypy>=1.0.0",
+]
+gpu = [
+    "paddlepaddle-gpu>=2.5.0",
+]
+
+[project.scripts]
+invoice-autolabel = "src.cli.autolabel:main"
+invoice-train = "src.cli.train:main"
+invoice-infer = "src.cli.infer:main"
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["src*"]
+
+[tool.black]
+line-length = 100
+target-version = ["py310", "py311", "py312"]
+
+[tool.ruff]
+line-length = 100
+target-version = "py310"
+select = ["E", "F", "W", "I", "N", "D", "UP", "B", "C4", "SIM"]
+ignore = ["D100", "D104"]
+
+[tool.mypy]
+python_version = "3.10"
+warn_return_any = true
+warn_unused_ignores = true
+disallow_untyped_defs = true
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = ["test_*.py"]
+addopts = "-v --cov=src --cov-report=term-missing"
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..2d95467
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,22 @@
+# Invoice Master POC v2 - Dependencies
+
+# PDF Processing
+PyMuPDF>=1.23.0        # PDF rendering and text extraction
+
+# OCR
+paddlepaddle>=2.5.0    # PaddlePaddle framework
+paddleocr>=2.7.0       # PaddleOCR
+
+# YOLO
+ultralytics>=8.1.0     # YOLOv8/v11
+
+# Image Processing
+Pillow>=10.0.0         # Image handling
+numpy>=1.24.0          # Array operations
+opencv-python>=4.8.0   # Image processing
+
+# Data Processing
+pyyaml>=6.0            # YAML config files
+
+# Utilities
+tqdm>=4.65.0           # Progress bars
diff --git a/run_autolabel.py b/run_autolabel.py
new file mode 100644
index 0000000..40cc97a
--- /dev/null
+++ b/run_autolabel.py
@@ -0,0 +1,10 @@
+#!/usr/bin/env python3
+"""
+Auto-labeling entry point - thin wrapper around the CLI module.
+Run inside WSL: python run_autolabel.py
+"""
+
+from src.cli.autolabel import main
+
+if __name__ == '__main__':
+    main()
diff --git a/scripts/run_autolabel.sh b/scripts/run_autolabel.sh
new file mode 100644
index 0000000..bcde5a1
--- /dev/null
+++ b/scripts/run_autolabel.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+# Auto-labeling runner script
+# Usage: bash scripts/run_autolabel.sh
+
+set -e
+
+# Project root directory
+PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
+cd "$PROJECT_DIR"
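+
+# All parameters below can be overridden via environment variables
+# (the ${VAR:-default} pattern), e.g.:
+#   DPI=200 MIN_CONFIDENCE=0.8 bash scripts/run_autolabel.sh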
"${BASH_SOURCE[0]}")/.." && pwd)" +cd "$PROJECT_DIR" + +# 激活虚拟环境 +if [ -f "venv/bin/activate" ]; then + source venv/bin/activate +else + echo "错误: 虚拟环境不存在,请先运行 setup_wsl.sh" + exit 1 +fi + +# 默认参数 +CSV_FILE="${CSV_FILE:-data/structured_data/invoices.csv}" +PDF_DIR="${PDF_DIR:-data/raw_pdfs}" +OUTPUT_DIR="${OUTPUT_DIR:-data/dataset}" +REPORT_FILE="${REPORT_FILE:-reports/autolabel_report.jsonl}" +DPI="${DPI:-300}" +MIN_CONFIDENCE="${MIN_CONFIDENCE:-0.7}" + +# 显示配置 +echo "==========================================" +echo "自动标注配置" +echo "==========================================" +echo "CSV 文件: $CSV_FILE" +echo "PDF 目录: $PDF_DIR" +echo "输出目录: $OUTPUT_DIR" +echo "报告文件: $REPORT_FILE" +echo "DPI: $DPI" +echo "最小置信度: $MIN_CONFIDENCE" +echo "==========================================" +echo "" + +# 创建必要目录 +mkdir -p "$(dirname "$REPORT_FILE")" +mkdir -p "$OUTPUT_DIR" + +# 运行自动标注 +python -m src.cli.autolabel \ + --csv "$CSV_FILE" \ + --pdf-dir "$PDF_DIR" \ + --output "$OUTPUT_DIR" \ + --report "$REPORT_FILE" \ + --dpi "$DPI" \ + --min-confidence "$MIN_CONFIDENCE" \ + --verbose + +echo "" +echo "完成! 数据集已生成到: $OUTPUT_DIR" diff --git a/scripts/run_train.sh b/scripts/run_train.sh new file mode 100644 index 0000000..22a778d --- /dev/null +++ b/scripts/run_train.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# 训练运行脚本 +# 使用方法: bash scripts/run_train.sh + +set -e + +# 项目根目录 +PROJECT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" +cd "$PROJECT_DIR" + +# 激活虚拟环境 +if [ -f "venv/bin/activate" ]; then + source venv/bin/activate +else + echo "错误: 虚拟环境不存在,请先运行 setup_wsl.sh" + exit 1 +fi + +# 默认参数 +DATA_YAML="${DATA_YAML:-data/dataset/dataset.yaml}" +MODEL="${MODEL:-yolov8s.pt}" +EPOCHS="${EPOCHS:-100}" +BATCH_SIZE="${BATCH_SIZE:-16}" +IMG_SIZE="${IMG_SIZE:-1280}" +DEVICE="${DEVICE:-0}" + +# 检查数据集是否存在 +if [ ! -f "$DATA_YAML" ]; then + echo "错误: 数据集配置文件不存在: $DATA_YAML" + echo "请先运行自动标注: bash scripts/run_autolabel.sh" + exit 1 +fi + +# 显示配置 +echo "==========================================" +echo "训练配置" +echo "==========================================" +echo "数据集: $DATA_YAML" +echo "基础模型: $MODEL" +echo "Epochs: $EPOCHS" +echo "Batch Size: $BATCH_SIZE" +echo "图像尺寸: $IMG_SIZE" +echo "设备: $DEVICE" +echo "==========================================" +echo "" + +# 检查 GPU +if command -v nvidia-smi &> /dev/null; then + echo "GPU 状态:" + nvidia-smi --query-gpu=name,memory.used,memory.total --format=csv,noheader + echo "" +else + echo "警告: 未检测到 GPU,将使用 CPU 训练 (较慢)" + DEVICE="cpu" +fi + +# 运行训练 +python -m src.cli.train \ + --data "$DATA_YAML" \ + --model "$MODEL" \ + --epochs "$EPOCHS" \ + --batch "$BATCH_SIZE" \ + --imgsz "$IMG_SIZE" \ + --device "$DEVICE" + +echo "" +echo "训练完成! 模型保存在: runs/train/invoice_fields/weights/" diff --git a/scripts/setup_wsl.sh b/scripts/setup_wsl.sh new file mode 100644 index 0000000..9bd9e24 --- /dev/null +++ b/scripts/setup_wsl.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# WSL 环境安装脚本 +# 使用方法: bash scripts/setup_wsl.sh + +set -e + +echo "==========================================" +echo "Invoice Master POC v2 - WSL 安装脚本" +echo "==========================================" + +# 检查是否在 WSL 中运行 +if ! grep -qi microsoft /proc/version 2>/dev/null; then + echo "警告: 未检测到 WSL 环境,请在 WSL 中运行此脚本" + echo "提示: 在 Windows 终端中输入 'wsl' 进入 WSL" + exit 1 +fi + +echo "" +echo "[1/5] 更新系统包..." +sudo apt update + +echo "" +echo "[2/5] 安装系统依赖..." 
+sudo apt install -y \
+    python3.10 \
+    python3.10-venv \
+    python3-pip \
+    libgl1-mesa-glx \
+    libglib2.0-0 \
+    libsm6 \
+    libxrender1 \
+    libxext6 \
+    libgomp1
+
+echo ""
+echo "[3/5] Creating the Python virtual environment..."
+if [ -d "venv" ]; then
+    echo "Virtual environment already exists, skipping creation"
+else
+    python3 -m venv venv
+fi
+
+echo ""
+echo "[4/5] Activating the virtual environment and installing dependencies..."
+source venv/bin/activate
+pip install --upgrade pip
+
+echo ""
+echo "Installing Python packages..."
+pip install -r requirements.txt
+
+echo ""
+echo "[5/5] Verifying the installation..."
+python3 -c "import fitz; print(f'PyMuPDF: {fitz.version}')"
+python3 -c "from ultralytics import YOLO; print('Ultralytics: OK')"
+python3 -c "from paddleocr import PaddleOCR; print('PaddleOCR: OK')"
+
+echo ""
+echo "=========================================="
+echo "Installation complete!"
+echo "=========================================="
+echo ""
+echo "Usage:"
+echo "  1. Activate the virtual environment: source venv/bin/activate"
+echo "  2. Run auto-labeling: python -m src.cli.autolabel --help"
+echo "  3. Train the model:   python -m src.cli.train --help"
+echo "  4. Run inference:     python -m src.cli.infer --help"
+echo ""
+
+# Check for a GPU
+echo "Checking GPU support..."
+if command -v nvidia-smi &> /dev/null; then
+    echo "NVIDIA GPU detected:"
+    nvidia-smi --query-gpu=name,memory.total --format=csv,noheader
+    echo ""
+    echo "Tip: run the following to enable GPU acceleration:"
+    echo "  pip install paddlepaddle-gpu"
+else
+    echo "No GPU detected, running in CPU mode"
+fi
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..2929ae9
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,2 @@
+# Invoice Master POC v2
+# Automatic invoice information extraction system using YOLO + OCR
diff --git a/src/cli/__init__.py b/src/cli/__init__.py
new file mode 100644
index 0000000..a24f998
--- /dev/null
+++ b/src/cli/__init__.py
@@ -0,0 +1 @@
+# CLI modules for Invoice Master
diff --git a/src/cli/autolabel.py b/src/cli/autolabel.py
new file mode 100644
index 0000000..cdad5cf
--- /dev/null
+++ b/src/cli/autolabel.py
@@ -0,0 +1,458 @@
+#!/usr/bin/env python3
+"""
+Auto-labeling CLI
+
+Generates YOLO training data from PDFs and structured CSV data.
+"""
+
+import argparse
+import sys
+import time
+from pathlib import Path
+from tqdm import tqdm
+from concurrent.futures import ProcessPoolExecutor, as_completed
+import multiprocessing
+
+# Global OCR engine for worker processes (initialized once per worker)
+_worker_ocr_engine = None
+
+
+def _init_worker():
+    """Initialize worker process with OCR engine (called once per worker)."""
+    global _worker_ocr_engine
+    # OCR engine will be lazily initialized on first use
+    _worker_ocr_engine = None
+
+
+def _get_ocr_engine():
+    """Get or create OCR engine for current worker."""
+    global _worker_ocr_engine
+    if _worker_ocr_engine is None:
+        from ..ocr import OCREngine
+        _worker_ocr_engine = OCREngine()
+    return _worker_ocr_engine
+
+
+def process_single_document(args_tuple):
+    """
+    Process a single document (worker function for parallel processing).
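+
+    Runs inside a ProcessPoolExecutor worker, so everything in args_tuple
+    must be picklable; this is why the CSV row is passed as a plain dict.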
+ + Args: + args_tuple: (row_dict, pdf_path, output_dir, dpi, min_confidence, skip_ocr) + + Returns: + dict with results + """ + row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = args_tuple + + # Import inside worker to avoid pickling issues + from ..data import AutoLabelReport, FieldMatchResult + from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens + from ..pdf.renderer import get_render_dimensions + from ..matcher import FieldMatcher + from ..normalize import normalize_field + from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES + + start_time = time.time() + pdf_path = Path(pdf_path_str) + output_dir = Path(output_dir_str) + doc_id = row_dict['DocumentId'] + + report = AutoLabelReport(document_id=doc_id) + report.pdf_path = str(pdf_path) + + result = { + 'doc_id': doc_id, + 'success': False, + 'pages': [], + 'report': None, + 'stats': {name: 0 for name in FIELD_CLASSES.keys()} + } + + try: + # Check PDF type + use_ocr = not is_text_pdf(pdf_path) + report.pdf_type = "scanned" if use_ocr else "text" + + # Skip OCR if requested + if use_ocr and skip_ocr: + report.errors.append("Skipped (scanned PDF)") + report.processing_time_ms = (time.time() - start_time) * 1000 + result['report'] = report.to_dict() + return result + + # Get OCR engine from worker cache (only created once per worker) + ocr_engine = None + if use_ocr: + ocr_engine = _get_ocr_engine() + + generator = AnnotationGenerator(min_confidence=min_confidence) + matcher = FieldMatcher() + + # Process each page + page_annotations = [] + + for page_no, image_path in render_pdf_to_images( + pdf_path, + output_dir / 'temp' / doc_id / 'images', + dpi=dpi + ): + report.total_pages += 1 + img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi) + + # Extract tokens + if use_ocr: + tokens = ocr_engine.extract_from_image(str(image_path), page_no) + else: + tokens = list(extract_text_tokens(pdf_path, page_no)) + + # Match fields + matches = {} + for field_name in FIELD_CLASSES.keys(): + value = row_dict.get(field_name) + if not value: + continue + + normalized = normalize_field(field_name, str(value)) + field_matches = matcher.find_matches(tokens, field_name, normalized, page_no) + + # Record result + if field_matches: + best = field_matches[0] + matches[field_name] = field_matches + report.add_field_result(FieldMatchResult( + field_name=field_name, + csv_value=str(value), + matched=True, + score=best.score, + matched_text=best.matched_text, + candidate_used=best.value, + bbox=best.bbox, + page_no=page_no, + context_keywords=best.context_keywords + )) + else: + report.add_field_result(FieldMatchResult( + field_name=field_name, + csv_value=str(value), + matched=False, + page_no=page_no + )) + + # Generate annotations + annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi) + + if annotations: + label_path = output_dir / 'temp' / doc_id / 'labels' / f"{image_path.stem}.txt" + generator.save_annotations(annotations, label_path) + page_annotations.append({ + 'image_path': str(image_path), + 'label_path': str(label_path), + 'count': len(annotations) + }) + + report.annotations_generated += len(annotations) + for ann in annotations: + class_name = list(FIELD_CLASSES.keys())[ann.class_id] + result['stats'][class_name] += 1 + + if page_annotations: + result['pages'] = page_annotations + result['success'] = True + report.success = True + else: + report.errors.append("No annotations generated") + + except Exception as e: + 
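+        # Record the failure instead of raising, so one bad document does
+        # not abort the whole batch; a report entry is still written below.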
report.errors.append(str(e)) + + report.processing_time_ms = (time.time() - start_time) * 1000 + result['report'] = report.to_dict() + + return result + + +def main(): + parser = argparse.ArgumentParser( + description='Generate YOLO annotations from PDFs and CSV data' + ) + parser.add_argument( + '--csv', '-c', + default='data/structured_data/document_export_20260109_212743.csv', + help='Path to structured data CSV file' + ) + parser.add_argument( + '--pdf-dir', '-p', + default='data/raw_pdfs', + help='Directory containing PDF files' + ) + parser.add_argument( + '--output', '-o', + default='data/dataset', + help='Output directory for dataset' + ) + parser.add_argument( + '--dpi', + type=int, + default=300, + help='DPI for PDF rendering (default: 300)' + ) + parser.add_argument( + '--min-confidence', + type=float, + default=0.7, + help='Minimum match confidence (default: 0.7)' + ) + parser.add_argument( + '--train-ratio', + type=float, + default=0.8, + help='Training set ratio (default: 0.8)' + ) + parser.add_argument( + '--val-ratio', + type=float, + default=0.1, + help='Validation set ratio (default: 0.1)' + ) + parser.add_argument( + '--report', + default='reports/autolabel_report.jsonl', + help='Path for auto-label report (JSONL)' + ) + parser.add_argument( + '--single', + help='Process single document ID only' + ) + parser.add_argument( + '--verbose', '-v', + action='store_true', + help='Verbose output' + ) + parser.add_argument( + '--workers', '-w', + type=int, + default=4, + help='Number of parallel workers (default: 4)' + ) + parser.add_argument( + '--skip-ocr', + action='store_true', + help='Skip scanned PDFs (text-layer only)' + ) + + args = parser.parse_args() + + # Import here to avoid slow startup + from ..data import CSVLoader, AutoLabelReport, FieldMatchResult + from ..data.autolabel_report import ReportWriter + from ..yolo import DatasetBuilder + from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens + from ..pdf.renderer import get_render_dimensions + from ..ocr import OCREngine + from ..matcher import FieldMatcher + from ..normalize import normalize_field + from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES + + print(f"Loading CSV data from: {args.csv}") + loader = CSVLoader(args.csv, args.pdf_dir) + + # Validate data + issues = loader.validate() + if issues: + print(f"Warning: Found {len(issues)} validation issues") + if args.verbose: + for issue in issues[:10]: + print(f" - {issue}") + + rows = loader.load_all() + print(f"Loaded {len(rows)} invoice records") + + # Filter to single document if specified + if args.single: + rows = [r for r in rows if r.DocumentId == args.single] + if not rows: + print(f"Error: Document {args.single} not found") + sys.exit(1) + print(f"Processing single document: {args.single}") + + # Setup output directories + output_dir = Path(args.output) + for split in ['train', 'val', 'test']: + (output_dir / split / 'images').mkdir(parents=True, exist_ok=True) + (output_dir / split / 'labels').mkdir(parents=True, exist_ok=True) + + # Generate YOLO config files + AnnotationGenerator.generate_classes_file(output_dir / 'classes.txt') + AnnotationGenerator.generate_yaml_config(output_dir / 'dataset.yaml') + + # Report writer + report_path = Path(args.report) + report_path.parent.mkdir(parents=True, exist_ok=True) + report_writer = ReportWriter(args.report) + + # Stats + stats = { + 'total': len(rows), + 'successful': 0, + 'failed': 0, + 'skipped': 0, + 'annotations': 0, + 'by_field': {name: 0 for name in 
FIELD_CLASSES.keys()} + } + + # Prepare tasks + tasks = [] + for row in rows: + pdf_path = loader.get_pdf_path(row) + if not pdf_path: + # Write report for missing PDF + report = AutoLabelReport(document_id=row.DocumentId) + report.errors.append("PDF not found") + report_writer.write(report) + stats['failed'] += 1 + continue + + # Convert row to dict for pickling + row_dict = { + 'DocumentId': row.DocumentId, + 'InvoiceNumber': row.InvoiceNumber, + 'InvoiceDate': row.InvoiceDate, + 'InvoiceDueDate': row.InvoiceDueDate, + 'OCR': row.OCR, + 'Bankgiro': row.Bankgiro, + 'Plusgiro': row.Plusgiro, + 'Amount': row.Amount, + } + + tasks.append(( + row_dict, + str(pdf_path), + str(output_dir), + args.dpi, + args.min_confidence, + args.skip_ocr + )) + + print(f"Processing {len(tasks)} documents with {args.workers} workers...") + + # Process documents in parallel + processed_items = [] + + # Use single process for debugging or when workers=1 + if args.workers == 1: + for task in tqdm(tasks, desc="Processing"): + result = process_single_document(task) + + # Write report + if result['report']: + report_writer.write_dict(result['report']) + + if result['success']: + processed_items.append({ + 'doc_id': result['doc_id'], + 'pages': result['pages'] + }) + stats['successful'] += 1 + for field, count in result['stats'].items(): + stats['by_field'][field] += count + stats['annotations'] += count + elif 'Skipped' in str(result.get('report', {}).get('errors', [])): + stats['skipped'] += 1 + else: + stats['failed'] += 1 + else: + # Parallel processing with worker initialization + # Each worker initializes OCR engine once and reuses it + with ProcessPoolExecutor(max_workers=args.workers, initializer=_init_worker) as executor: + futures = {executor.submit(process_single_document, task): task[0]['DocumentId'] + for task in tasks} + + for future in tqdm(as_completed(futures), total=len(futures), desc="Processing"): + doc_id = futures[future] + try: + result = future.result() + + # Write report + if result['report']: + report_writer.write_dict(result['report']) + + if result['success']: + processed_items.append({ + 'doc_id': result['doc_id'], + 'pages': result['pages'] + }) + stats['successful'] += 1 + for field, count in result['stats'].items(): + stats['by_field'][field] += count + stats['annotations'] += count + elif 'Skipped' in str(result.get('report', {}).get('errors', [])): + stats['skipped'] += 1 + else: + stats['failed'] += 1 + + except Exception as e: + stats['failed'] += 1 + # Write error report for failed documents + error_report = { + 'document_id': doc_id, + 'success': False, + 'errors': [f"Worker error: {str(e)}"] + } + report_writer.write_dict(error_report) + if args.verbose: + print(f"Error processing {doc_id}: {e}") + + # Split and move files + import random + random.seed(42) + random.shuffle(processed_items) + + n_train = int(len(processed_items) * args.train_ratio) + n_val = int(len(processed_items) * args.val_ratio) + + splits = { + 'train': processed_items[:n_train], + 'val': processed_items[n_train:n_train + n_val], + 'test': processed_items[n_train + n_val:] + } + + import shutil + for split_name, items in splits.items(): + for item in items: + for page in item['pages']: + # Move image + image_path = Path(page['image_path']) + label_path = Path(page['label_path']) + dest_img = output_dir / split_name / 'images' / image_path.name + shutil.move(str(image_path), str(dest_img)) + + # Move label + dest_label = output_dir / split_name / 'labels' / label_path.name + shutil.move(str(label_path), 
str(dest_label)) + + # Cleanup temp + shutil.rmtree(output_dir / 'temp', ignore_errors=True) + + # Print summary + print("\n" + "=" * 50) + print("Auto-labeling Complete") + print("=" * 50) + print(f"Total documents: {stats['total']}") + print(f"Successful: {stats['successful']}") + print(f"Failed: {stats['failed']}") + print(f"Skipped (OCR): {stats['skipped']}") + print(f"Total annotations: {stats['annotations']}") + print(f"\nDataset split:") + print(f" Train: {len(splits['train'])} documents") + print(f" Val: {len(splits['val'])} documents") + print(f" Test: {len(splits['test'])} documents") + print(f"\nAnnotations by field:") + for field, count in stats['by_field'].items(): + print(f" {field}: {count}") + print(f"\nOutput: {output_dir}") + print(f"Report: {args.report}") + + +if __name__ == '__main__': + main() diff --git a/src/cli/infer.py b/src/cli/infer.py new file mode 100644 index 0000000..7ec3396 --- /dev/null +++ b/src/cli/infer.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python3 +""" +Inference CLI + +Runs inference on new PDFs to extract invoice data. +""" + +import argparse +import json +import sys +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser( + description='Extract invoice data from PDFs using trained model' + ) + parser.add_argument( + '--model', '-m', + required=True, + help='Path to trained YOLO model (.pt file)' + ) + parser.add_argument( + '--input', '-i', + required=True, + help='Input PDF file or directory' + ) + parser.add_argument( + '--output', '-o', + help='Output JSON file (default: stdout)' + ) + parser.add_argument( + '--confidence', + type=float, + default=0.5, + help='Detection confidence threshold (default: 0.5)' + ) + parser.add_argument( + '--dpi', + type=int, + default=300, + help='DPI for PDF rendering (default: 300)' + ) + parser.add_argument( + '--no-fallback', + action='store_true', + help='Disable fallback OCR' + ) + parser.add_argument( + '--lang', + default='en', + help='OCR language (default: en)' + ) + parser.add_argument( + '--gpu', + action='store_true', + help='Use GPU' + ) + parser.add_argument( + '--verbose', '-v', + action='store_true', + help='Verbose output' + ) + + args = parser.parse_args() + + # Validate model + model_path = Path(args.model) + if not model_path.exists(): + print(f"Error: Model not found: {model_path}", file=sys.stderr) + sys.exit(1) + + # Get input files + input_path = Path(args.input) + if input_path.is_file(): + pdf_files = [input_path] + elif input_path.is_dir(): + pdf_files = list(input_path.glob('*.pdf')) + else: + print(f"Error: Input not found: {input_path}", file=sys.stderr) + sys.exit(1) + + if not pdf_files: + print("Error: No PDF files found", file=sys.stderr) + sys.exit(1) + + if args.verbose: + print(f"Processing {len(pdf_files)} PDF file(s)") + print(f"Model: {model_path}") + + from ..inference import InferencePipeline + + # Initialize pipeline + pipeline = InferencePipeline( + model_path=model_path, + confidence_threshold=args.confidence, + ocr_lang=args.lang, + use_gpu=args.gpu, + dpi=args.dpi, + enable_fallback=not args.no_fallback + ) + + # Process files + results = [] + + for pdf_path in pdf_files: + if args.verbose: + print(f"Processing: {pdf_path.name}") + + result = pipeline.process_pdf(pdf_path) + results.append(result.to_json()) + + if args.verbose: + print(f" Success: {result.success}") + print(f" Fields: {len(result.fields)}") + if result.fallback_used: + print(f" Fallback used: Yes") + if result.errors: + print(f" Errors: {result.errors}") + + # Output results + if 
len(results) == 1: + output = results[0] + else: + output = results + + json_output = json.dumps(output, indent=2, ensure_ascii=False) + + if args.output: + with open(args.output, 'w', encoding='utf-8') as f: + f.write(json_output) + if args.verbose: + print(f"\nResults written to: {args.output}") + else: + print(json_output) + + +if __name__ == '__main__': + main() diff --git a/src/cli/train.py b/src/cli/train.py new file mode 100644 index 0000000..85528cd --- /dev/null +++ b/src/cli/train.py @@ -0,0 +1,138 @@ +#!/usr/bin/env python3 +""" +Training CLI + +Trains YOLO model on generated dataset. +""" + +import argparse +import sys +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser( + description='Train YOLO model for invoice field detection' + ) + parser.add_argument( + '--data', '-d', + required=True, + help='Path to dataset.yaml file' + ) + parser.add_argument( + '--model', '-m', + default='yolov8s.pt', + help='Base model (default: yolov8s.pt)' + ) + parser.add_argument( + '--epochs', '-e', + type=int, + default=100, + help='Number of epochs (default: 100)' + ) + parser.add_argument( + '--batch', '-b', + type=int, + default=16, + help='Batch size (default: 16)' + ) + parser.add_argument( + '--imgsz', + type=int, + default=1280, + help='Image size (default: 1280)' + ) + parser.add_argument( + '--project', + default='runs/train', + help='Project directory (default: runs/train)' + ) + parser.add_argument( + '--name', + default='invoice_fields', + help='Run name (default: invoice_fields)' + ) + parser.add_argument( + '--device', + default='0', + help='Device (0, 1, cpu, mps)' + ) + parser.add_argument( + '--resume', + help='Resume from checkpoint' + ) + parser.add_argument( + '--config', + help='Path to training config YAML' + ) + + args = parser.parse_args() + + # Validate data file + data_path = Path(args.data) + if not data_path.exists(): + print(f"Error: Dataset file not found: {data_path}") + sys.exit(1) + + print(f"Training YOLO model for invoice field detection") + print(f"Dataset: {args.data}") + print(f"Model: {args.model}") + print(f"Epochs: {args.epochs}") + print(f"Batch size: {args.batch}") + print(f"Image size: {args.imgsz}") + + from ultralytics import YOLO + + # Load model + if args.resume: + print(f"Resuming from: {args.resume}") + model = YOLO(args.resume) + else: + model = YOLO(args.model) + + # Training arguments + train_args = { + 'data': str(data_path.absolute()), + 'epochs': args.epochs, + 'batch': args.batch, + 'imgsz': args.imgsz, + 'project': args.project, + 'name': args.name, + 'device': args.device, + 'exist_ok': True, + 'pretrained': True, + 'verbose': True, + # Document-specific augmentation settings + 'degrees': 5.0, + 'translate': 0.05, + 'scale': 0.2, + 'shear': 0.0, + 'perspective': 0.0, + 'flipud': 0.0, + 'fliplr': 0.0, + 'mosaic': 0.0, + 'mixup': 0.0, + 'hsv_h': 0.0, + 'hsv_s': 0.1, + 'hsv_v': 0.2, + } + + # Train + results = model.train(**train_args) + + # Print results + print("\n" + "=" * 50) + print("Training Complete") + print("=" * 50) + print(f"Best model: {args.project}/{args.name}/weights/best.pt") + print(f"Last model: {args.project}/{args.name}/weights/last.pt") + + # Validate on test set + print("\nRunning validation...") + metrics = model.val() + print(f"mAP50: {metrics.box.map50:.4f}") + print(f"mAP50-95: {metrics.box.map:.4f}") + + +if __name__ == '__main__': + main() diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..454510e --- /dev/null +++ b/src/data/__init__.py @@ -0,0 
+1,4 @@ +from .csv_loader import CSVLoader, InvoiceRow +from .autolabel_report import AutoLabelReport, FieldMatchResult + +__all__ = ['CSVLoader', 'InvoiceRow', 'AutoLabelReport', 'FieldMatchResult'] diff --git a/src/data/autolabel_report.py b/src/data/autolabel_report.py new file mode 100644 index 0000000..ebd7c3a --- /dev/null +++ b/src/data/autolabel_report.py @@ -0,0 +1,252 @@ +""" +Auto-Label Report Generator + +Generates quality control reports for auto-labeling process. +""" + +import json +from dataclasses import dataclass, field, asdict +from datetime import datetime +from pathlib import Path +from typing import Any + + +@dataclass +class FieldMatchResult: + """Result of matching a single field.""" + field_name: str + csv_value: str | None + matched: bool + score: float = 0.0 + matched_text: str | None = None + candidate_used: str | None = None # Which normalized variant matched + bbox: tuple[float, float, float, float] | None = None + page_no: int = 0 + context_keywords: list[str] = field(default_factory=list) + error: str | None = None + + def to_dict(self) -> dict: + """Convert to dictionary.""" + # Convert bbox to native Python floats to avoid numpy serialization issues + bbox_list = None + if self.bbox: + bbox_list = [float(x) for x in self.bbox] + + return { + 'field_name': self.field_name, + 'csv_value': self.csv_value, + 'matched': self.matched, + 'score': float(self.score) if self.score else 0.0, + 'matched_text': self.matched_text, + 'candidate_used': self.candidate_used, + 'bbox': bbox_list, + 'page_no': int(self.page_no) if self.page_no else 0, + 'context_keywords': self.context_keywords, + 'error': self.error + } + + +@dataclass +class AutoLabelReport: + """Report for a single document's auto-labeling process.""" + document_id: str + pdf_path: str | None = None + pdf_type: str | None = None # 'text' | 'scanned' | 'mixed' + success: bool = False + total_pages: int = 0 + fields_matched: int = 0 + fields_total: int = 0 + field_results: list[FieldMatchResult] = field(default_factory=list) + annotations_generated: int = 0 + image_paths: list[str] = field(default_factory=list) + label_paths: list[str] = field(default_factory=list) + processing_time_ms: float = 0.0 + timestamp: str = field(default_factory=lambda: datetime.now().isoformat()) + errors: list[str] = field(default_factory=list) + + def add_field_result(self, result: FieldMatchResult) -> None: + """Add a field matching result.""" + self.field_results.append(result) + self.fields_total += 1 + if result.matched: + self.fields_matched += 1 + + def to_dict(self) -> dict: + """Convert to dictionary.""" + return { + 'document_id': self.document_id, + 'pdf_path': self.pdf_path, + 'pdf_type': self.pdf_type, + 'success': self.success, + 'total_pages': self.total_pages, + 'fields_matched': self.fields_matched, + 'fields_total': self.fields_total, + 'field_results': [r.to_dict() for r in self.field_results], + 'annotations_generated': self.annotations_generated, + 'image_paths': self.image_paths, + 'label_paths': self.label_paths, + 'processing_time_ms': self.processing_time_ms, + 'timestamp': self.timestamp, + 'errors': self.errors + } + + def to_json(self, indent: int | None = None) -> str: + """Convert to JSON string.""" + return json.dumps(self.to_dict(), indent=indent, ensure_ascii=False) + + @property + def match_rate(self) -> float: + """Calculate field match rate.""" + if self.fields_total == 0: + return 0.0 + return self.fields_matched / self.fields_total + + def get_summary(self) -> dict: + """Get a summary of the 
report.""" + return { + 'document_id': self.document_id, + 'success': self.success, + 'match_rate': f"{self.match_rate:.1%}", + 'fields': f"{self.fields_matched}/{self.fields_total}", + 'annotations': self.annotations_generated, + 'errors': len(self.errors) + } + + +class ReportWriter: + """Writes auto-label reports to file.""" + + def __init__(self, output_path: str | Path): + """ + Initialize report writer. + + Args: + output_path: Path to output JSONL file + """ + self.output_path = Path(output_path) + self.output_path.parent.mkdir(parents=True, exist_ok=True) + + def write(self, report: AutoLabelReport) -> None: + """Append a report to the output file.""" + with open(self.output_path, 'a', encoding='utf-8') as f: + f.write(report.to_json() + '\n') + + def write_dict(self, report_dict: dict) -> None: + """Append a report dict to the output file (for parallel processing).""" + import json + with open(self.output_path, 'a', encoding='utf-8') as f: + f.write(json.dumps(report_dict, ensure_ascii=False) + '\n') + f.flush() + + def write_batch(self, reports: list[AutoLabelReport]) -> None: + """Write multiple reports.""" + with open(self.output_path, 'a', encoding='utf-8') as f: + for report in reports: + f.write(report.to_json() + '\n') + + +class ReportReader: + """Reads auto-label reports from file.""" + + def __init__(self, input_path: str | Path): + """ + Initialize report reader. + + Args: + input_path: Path to input JSONL file + """ + self.input_path = Path(input_path) + + def read_all(self) -> list[AutoLabelReport]: + """Read all reports from file.""" + reports = [] + + if not self.input_path.exists(): + return reports + + with open(self.input_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if not line: + continue + + data = json.loads(line) + report = self._dict_to_report(data) + reports.append(report) + + return reports + + def _dict_to_report(self, data: dict) -> AutoLabelReport: + """Convert dictionary to AutoLabelReport.""" + field_results = [] + for fr_data in data.get('field_results', []): + bbox = tuple(fr_data['bbox']) if fr_data.get('bbox') else None + field_results.append(FieldMatchResult( + field_name=fr_data['field_name'], + csv_value=fr_data.get('csv_value'), + matched=fr_data.get('matched', False), + score=fr_data.get('score', 0.0), + matched_text=fr_data.get('matched_text'), + candidate_used=fr_data.get('candidate_used'), + bbox=bbox, + page_no=fr_data.get('page_no', 0), + context_keywords=fr_data.get('context_keywords', []), + error=fr_data.get('error') + )) + + return AutoLabelReport( + document_id=data['document_id'], + pdf_path=data.get('pdf_path'), + pdf_type=data.get('pdf_type'), + success=data.get('success', False), + total_pages=data.get('total_pages', 0), + fields_matched=data.get('fields_matched', 0), + fields_total=data.get('fields_total', 0), + field_results=field_results, + annotations_generated=data.get('annotations_generated', 0), + image_paths=data.get('image_paths', []), + label_paths=data.get('label_paths', []), + processing_time_ms=data.get('processing_time_ms', 0.0), + timestamp=data.get('timestamp', ''), + errors=data.get('errors', []) + ) + + def get_statistics(self) -> dict: + """Calculate statistics from all reports.""" + reports = self.read_all() + + if not reports: + return {'total': 0} + + successful = sum(1 for r in reports if r.success) + total_fields_matched = sum(r.fields_matched for r in reports) + total_fields = sum(r.fields_total for r in reports) + total_annotations = sum(r.annotations_generated for r in 
reports) + + # Per-field statistics + field_stats = {} + for report in reports: + for fr in report.field_results: + if fr.field_name not in field_stats: + field_stats[fr.field_name] = {'matched': 0, 'total': 0, 'avg_score': 0.0} + field_stats[fr.field_name]['total'] += 1 + if fr.matched: + field_stats[fr.field_name]['matched'] += 1 + field_stats[fr.field_name]['avg_score'] += fr.score + + # Calculate averages + for field_name, stats in field_stats.items(): + if stats['matched'] > 0: + stats['avg_score'] /= stats['matched'] + stats['match_rate'] = stats['matched'] / stats['total'] if stats['total'] > 0 else 0 + + return { + 'total': len(reports), + 'successful': successful, + 'success_rate': successful / len(reports), + 'total_fields_matched': total_fields_matched, + 'total_fields': total_fields, + 'overall_match_rate': total_fields_matched / total_fields if total_fields > 0 else 0, + 'total_annotations': total_annotations, + 'field_statistics': field_stats + } diff --git a/src/data/csv_loader.py b/src/data/csv_loader.py new file mode 100644 index 0000000..0c690f1 --- /dev/null +++ b/src/data/csv_loader.py @@ -0,0 +1,306 @@ +""" +CSV Data Loader + +Loads and parses structured invoice data from CSV files. +Follows the CSV specification for invoice data. +""" + +import csv +from dataclasses import dataclass, field +from datetime import datetime, date +from decimal import Decimal, InvalidOperation +from pathlib import Path +from typing import Any, Iterator + + +@dataclass +class InvoiceRow: + """Parsed invoice data row.""" + DocumentId: str + InvoiceDate: date | None = None + InvoiceNumber: str | None = None + InvoiceDueDate: date | None = None + OCR: str | None = None + Message: str | None = None + Bankgiro: str | None = None + Plusgiro: str | None = None + Amount: Decimal | None = None + + # Raw values for reference + raw_data: dict = field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for matching.""" + return { + 'DocumentId': self.DocumentId, + 'InvoiceDate': self.InvoiceDate.isoformat() if self.InvoiceDate else None, + 'InvoiceNumber': self.InvoiceNumber, + 'InvoiceDueDate': self.InvoiceDueDate.isoformat() if self.InvoiceDueDate else None, + 'OCR': self.OCR, + 'Bankgiro': self.Bankgiro, + 'Plusgiro': self.Plusgiro, + 'Amount': str(self.Amount) if self.Amount else None, + } + + def get_field_value(self, field_name: str) -> str | None: + """Get field value as string for matching.""" + value = getattr(self, field_name, None) + if value is None: + return None + if isinstance(value, date): + return value.isoformat() + if isinstance(value, Decimal): + return str(value) + return str(value) if value else None + + +class CSVLoader: + """Loads invoice data from CSV files.""" + + # Expected field mappings (CSV header -> InvoiceRow attribute) + FIELD_MAPPINGS = { + 'DocumentId': 'DocumentId', + 'InvoiceDate': 'InvoiceDate', + 'InvoiceNumber': 'InvoiceNumber', + 'InvoiceDueDate': 'InvoiceDueDate', + 'OCR': 'OCR', + 'Message': 'Message', + 'Bankgiro': 'Bankgiro', + 'Plusgiro': 'Plusgiro', + 'Amount': 'Amount', + } + + def __init__( + self, + csv_path: str | Path, + pdf_dir: str | Path | None = None, + doc_map_path: str | Path | None = None, + encoding: str = 'utf-8' + ): + """ + Initialize CSV loader. 
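+
+        Example (paths follow the default data layout; adjust as needed):
+            loader = CSVLoader('data/structured_data/invoices.csv',
+                               pdf_dir='data/raw_pdfs')
+            rows = loader.load_all()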
+ + Args: + csv_path: Path to the CSV file + pdf_dir: Directory containing PDF files (default: data/raw_pdfs) + doc_map_path: Optional path to document mapping CSV + encoding: CSV file encoding (default: utf-8) + """ + self.csv_path = Path(csv_path) + self.pdf_dir = Path(pdf_dir) if pdf_dir else self.csv_path.parent.parent / 'raw_pdfs' + self.doc_map_path = Path(doc_map_path) if doc_map_path else None + self.encoding = encoding + + # Load document mapping if provided + self.doc_map = self._load_doc_map() if self.doc_map_path else {} + + def _load_doc_map(self) -> dict[str, str]: + """Load document ID to filename mapping.""" + mapping = {} + if self.doc_map_path and self.doc_map_path.exists(): + with open(self.doc_map_path, 'r', encoding=self.encoding) as f: + reader = csv.DictReader(f) + for row in reader: + doc_id = row.get('DocumentId', '').strip() + filename = row.get('FileName', '').strip() + if doc_id and filename: + mapping[doc_id] = filename + return mapping + + def _parse_date(self, value: str | None) -> date | None: + """Parse date from various formats.""" + if not value or not value.strip(): + return None + + value = value.strip() + + # Try different date formats + formats = [ + '%Y-%m-%d', + '%Y-%m-%d %H:%M:%S', + '%Y-%m-%d %H:%M:%S.%f', + '%d/%m/%Y', + '%d.%m.%Y', + '%d-%m-%Y', + '%Y%m%d', + ] + + for fmt in formats: + try: + return datetime.strptime(value, fmt).date() + except ValueError: + continue + + return None + + def _parse_amount(self, value: str | None) -> Decimal | None: + """Parse monetary amount from various formats.""" + if not value or not value.strip(): + return None + + value = value.strip() + + # Remove currency symbols and common suffixes + value = value.replace('SEK', '').replace('kr', '').replace(':-', '') + value = value.strip() + + # Remove spaces (thousand separators) + value = value.replace(' ', '').replace('\xa0', '') + + # Handle comma as decimal separator (European format) + if ',' in value and '.' not in value: + value = value.replace(',', '.') + elif ',' in value and '.' 
in value: + # Assume comma is thousands separator, dot is decimal + value = value.replace(',', '') + + try: + return Decimal(value) + except InvalidOperation: + return None + + def _parse_string(self, value: str | None) -> str | None: + """Parse string field with cleanup.""" + if value is None: + return None + value = value.strip() + return value if value else None + + def _parse_row(self, row: dict) -> InvoiceRow | None: + """Parse a single CSV row into InvoiceRow.""" + doc_id = self._parse_string(row.get('DocumentId')) + if not doc_id: + return None + + return InvoiceRow( + DocumentId=doc_id, + InvoiceDate=self._parse_date(row.get('InvoiceDate')), + InvoiceNumber=self._parse_string(row.get('InvoiceNumber')), + InvoiceDueDate=self._parse_date(row.get('InvoiceDueDate')), + OCR=self._parse_string(row.get('OCR')), + Message=self._parse_string(row.get('Message')), + Bankgiro=self._parse_string(row.get('Bankgiro')), + Plusgiro=self._parse_string(row.get('Plusgiro')), + Amount=self._parse_amount(row.get('Amount')), + raw_data=dict(row) + ) + + def load_all(self) -> list[InvoiceRow]: + """Load all rows from CSV.""" + rows = [] + for row in self.iter_rows(): + rows.append(row) + return rows + + def iter_rows(self) -> Iterator[InvoiceRow]: + """Iterate over CSV rows.""" + # Handle BOM - try utf-8-sig first to handle BOM correctly + encodings = ['utf-8-sig', self.encoding, 'latin-1'] + + for enc in encodings: + try: + with open(self.csv_path, 'r', encoding=enc) as f: + reader = csv.DictReader(f) + for row in reader: + parsed = self._parse_row(row) + if parsed: + yield parsed + return + except UnicodeDecodeError: + continue + + raise ValueError(f"Could not read CSV file with any supported encoding") + + def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None: + """ + Get PDF path for an invoice row. + + Uses document mapping if available, otherwise assumes + DocumentId.pdf naming convention. + """ + doc_id = invoice_row.DocumentId + + # Check document mapping first + if doc_id in self.doc_map: + filename = self.doc_map[doc_id] + pdf_path = self.pdf_dir / filename + if pdf_path.exists(): + return pdf_path + + # Try default naming patterns + patterns = [ + f"{doc_id}.pdf", + f"{doc_id.lower()}.pdf", + f"{doc_id.upper()}.pdf", + ] + + for pattern in patterns: + pdf_path = self.pdf_dir / pattern + if pdf_path.exists(): + return pdf_path + + # Try glob patterns for partial matches + for pdf_file in self.pdf_dir.glob(f"*{doc_id}*.pdf"): + return pdf_file + + return None + + def get_row_by_id(self, doc_id: str) -> InvoiceRow | None: + """Get a specific row by DocumentId.""" + for row in self.iter_rows(): + if row.DocumentId == doc_id: + return row + return None + + def validate(self) -> list[dict]: + """ + Validate CSV data and return issues. 
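+
+        Example:
+            issues = CSVLoader('data/structured_data/invoices.csv').validate()
+            for issue in issues[:5]:
+                print(issue['row'], issue.get('doc_id', '?'), issue['issue'])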
+ + Returns: + List of validation issues + """ + issues = [] + + for i, row in enumerate(self.iter_rows(), start=2): # Start at 2 (header is row 1) + # Check required DocumentId + if not row.DocumentId: + issues.append({ + 'row': i, + 'field': 'DocumentId', + 'issue': 'Missing required DocumentId' + }) + continue + + # Check if PDF exists + pdf_path = self.get_pdf_path(row) + if not pdf_path: + issues.append({ + 'row': i, + 'doc_id': row.DocumentId, + 'field': 'PDF', + 'issue': 'PDF file not found' + }) + + # Check for at least one matchable field + matchable_fields = [ + row.InvoiceNumber, + row.OCR, + row.Bankgiro, + row.Plusgiro, + row.Amount + ] + if not any(matchable_fields): + issues.append({ + 'row': i, + 'doc_id': row.DocumentId, + 'field': 'All', + 'issue': 'No matchable fields (InvoiceNumber/OCR/Bankgiro/Plusgiro/Amount)' + }) + + return issues + + +def load_invoice_csv(csv_path: str | Path, pdf_dir: str | Path | None = None) -> list[InvoiceRow]: + """Convenience function to load invoice CSV.""" + loader = CSVLoader(csv_path, pdf_dir) + return loader.load_all() diff --git a/src/inference/__init__.py b/src/inference/__init__.py new file mode 100644 index 0000000..cb32852 --- /dev/null +++ b/src/inference/__init__.py @@ -0,0 +1,5 @@ +from .pipeline import InferencePipeline, InferenceResult +from .yolo_detector import YOLODetector, Detection +from .field_extractor import FieldExtractor + +__all__ = ['InferencePipeline', 'InferenceResult', 'YOLODetector', 'Detection', 'FieldExtractor'] diff --git a/src/inference/field_extractor.py b/src/inference/field_extractor.py new file mode 100644 index 0000000..00b7395 --- /dev/null +++ b/src/inference/field_extractor.py @@ -0,0 +1,382 @@ +""" +Field Extractor Module + +Extracts and validates field values from detected regions. +""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Any +import re +import numpy as np +from PIL import Image + +from .yolo_detector import Detection, CLASS_TO_FIELD + + +@dataclass +class ExtractedField: + """Represents an extracted field value.""" + field_name: str + raw_text: str + normalized_value: str | None + confidence: float + detection_confidence: float + ocr_confidence: float + bbox: tuple[float, float, float, float] + page_no: int + is_valid: bool = True + validation_error: str | None = None + + def to_dict(self) -> dict: + """Convert to dictionary.""" + return { + 'field_name': self.field_name, + 'value': self.normalized_value, + 'raw_text': self.raw_text, + 'confidence': self.confidence, + 'bbox': list(self.bbox), + 'page_no': self.page_no, + 'is_valid': self.is_valid, + 'validation_error': self.validation_error + } + + +class FieldExtractor: + """Extracts field values from detected regions using OCR or PDF text.""" + + def __init__( + self, + ocr_lang: str = 'en', + use_gpu: bool = False, + bbox_padding: float = 0.1, + dpi: int = 300 + ): + """ + Initialize field extractor. 
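+
+        Example (values mirror the defaults):
+            extractor = FieldExtractor(ocr_lang='en', use_gpu=False,
+                                       bbox_padding=0.1, dpi=300)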
+ + Args: + ocr_lang: Language for OCR + use_gpu: Whether to use GPU for OCR + bbox_padding: Padding to add around bboxes (as fraction) + dpi: DPI used for rendering (for coordinate conversion) + """ + self.ocr_lang = ocr_lang + self.use_gpu = use_gpu + self.bbox_padding = bbox_padding + self.dpi = dpi + self._ocr_engine = None # Lazy init + + @property + def ocr_engine(self): + """Lazy-load OCR engine only when needed.""" + if self._ocr_engine is None: + from ..ocr import OCREngine + self._ocr_engine = OCREngine(lang=self.ocr_lang, use_gpu=self.use_gpu) + return self._ocr_engine + + def extract_from_detection_with_pdf( + self, + detection: Detection, + pdf_tokens: list, + image_width: int, + image_height: int + ) -> ExtractedField: + """ + Extract field value using PDF text tokens (faster and more accurate for text PDFs). + + Args: + detection: Detection object with bbox in pixel coordinates + pdf_tokens: List of Token objects from PDF text extraction + image_width: Width of rendered image in pixels + image_height: Height of rendered image in pixels + + Returns: + ExtractedField object + """ + # Convert detection bbox from pixels to PDF points + scale = 72 / self.dpi # points per pixel + x0_pdf = detection.bbox[0] * scale + y0_pdf = detection.bbox[1] * scale + x1_pdf = detection.bbox[2] * scale + y1_pdf = detection.bbox[3] * scale + + # Add padding in points + pad = 3 # Small padding in points + + # Find tokens that overlap with detection bbox + matching_tokens = [] + for token in pdf_tokens: + if token.page_no != detection.page_no: + continue + tx0, ty0, tx1, ty1 = token.bbox + # Check overlap + if (tx0 < x1_pdf + pad and tx1 > x0_pdf - pad and + ty0 < y1_pdf + pad and ty1 > y0_pdf - pad): + # Calculate overlap ratio to prioritize better matches + overlap_x = min(tx1, x1_pdf) - max(tx0, x0_pdf) + overlap_y = min(ty1, y1_pdf) - max(ty0, y0_pdf) + if overlap_x > 0 and overlap_y > 0: + token_area = (tx1 - tx0) * (ty1 - ty0) + overlap_area = overlap_x * overlap_y + overlap_ratio = overlap_area / token_area if token_area > 0 else 0 + matching_tokens.append((token, overlap_ratio)) + + # Sort by overlap ratio and combine text + matching_tokens.sort(key=lambda x: -x[1]) + raw_text = ' '.join(t[0].text for t in matching_tokens) + + # Get field name + field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name) + + # Normalize and validate + normalized_value, is_valid, validation_error = self._normalize_and_validate( + field_name, raw_text + ) + + return ExtractedField( + field_name=field_name, + raw_text=raw_text, + normalized_value=normalized_value, + confidence=detection.confidence if normalized_value else detection.confidence * 0.5, + detection_confidence=detection.confidence, + ocr_confidence=1.0, # PDF text is always accurate + bbox=detection.bbox, + page_no=detection.page_no, + is_valid=is_valid, + validation_error=validation_error + ) + + def extract_from_detection( + self, + detection: Detection, + image: np.ndarray | Image.Image + ) -> ExtractedField: + """ + Extract field value from a detection region using OCR. 
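+
+        Example:
+            field = extractor.extract_from_detection(detection, page_image)
+            if field.is_valid:
+                print(field.field_name, field.normalized_value, field.confidence)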
+ + Args: + detection: Detection object + image: Full page image + + Returns: + ExtractedField object + """ + if isinstance(image, Image.Image): + image = np.array(image) + + # Get padded bbox + h, w = image.shape[:2] + bbox = detection.get_padded_bbox(self.bbox_padding, w, h) + + # Crop region + x0, y0, x1, y1 = [int(v) for v in bbox] + region = image[y0:y1, x0:x1] + + # Run OCR on region + ocr_tokens = self.ocr_engine.extract_from_image(region) + + # Combine all OCR text + raw_text = ' '.join(t.text for t in ocr_tokens) + ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0 + + # Get field name + field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name) + + # Normalize and validate + normalized_value, is_valid, validation_error = self._normalize_and_validate( + field_name, raw_text + ) + + # Combined confidence + confidence = (detection.confidence + ocr_confidence) / 2 if ocr_tokens else detection.confidence * 0.5 + + return ExtractedField( + field_name=field_name, + raw_text=raw_text, + normalized_value=normalized_value, + confidence=confidence, + detection_confidence=detection.confidence, + ocr_confidence=ocr_confidence, + bbox=detection.bbox, + page_no=detection.page_no, + is_valid=is_valid, + validation_error=validation_error + ) + + def _normalize_and_validate( + self, + field_name: str, + raw_text: str + ) -> tuple[str | None, bool, str | None]: + """ + Normalize and validate extracted text for a field. + + Returns: + (normalized_value, is_valid, validation_error) + """ + text = raw_text.strip() + + if not text: + return None, False, "Empty text" + + if field_name == 'InvoiceNumber': + return self._normalize_invoice_number(text) + + elif field_name == 'OCR': + return self._normalize_ocr_number(text) + + elif field_name == 'Bankgiro': + return self._normalize_bankgiro(text) + + elif field_name == 'Plusgiro': + return self._normalize_plusgiro(text) + + elif field_name == 'Amount': + return self._normalize_amount(text) + + elif field_name in ('InvoiceDate', 'InvoiceDueDate'): + return self._normalize_date(text) + + else: + return text, True, None + + def _normalize_invoice_number(self, text: str) -> tuple[str | None, bool, str | None]: + """Normalize invoice number.""" + # Extract digits only + digits = re.sub(r'\D', '', text) + + if len(digits) < 3: + return None, False, f"Too few digits: {len(digits)}" + + return digits, True, None + + def _normalize_ocr_number(self, text: str) -> tuple[str | None, bool, str | None]: + """Normalize OCR number.""" + digits = re.sub(r'\D', '', text) + + if len(digits) < 5: + return None, False, f"Too few digits for OCR: {len(digits)}" + + return digits, True, None + + def _normalize_bankgiro(self, text: str) -> tuple[str | None, bool, str | None]: + """Normalize Bankgiro number.""" + digits = re.sub(r'\D', '', text) + + if len(digits) == 8: + # Format as XXXX-XXXX + formatted = f"{digits[:4]}-{digits[4:]}" + return formatted, True, None + elif len(digits) == 7: + # Format as XXX-XXXX + formatted = f"{digits[:3]}-{digits[3:]}" + return formatted, True, None + elif 6 <= len(digits) <= 9: + return digits, True, None + else: + return None, False, f"Invalid Bankgiro length: {len(digits)}" + + def _normalize_plusgiro(self, text: str) -> tuple[str | None, bool, str | None]: + """Normalize Plusgiro number.""" + digits = re.sub(r'\D', '', text) + + if len(digits) >= 6: + # Format as XXXXXXX-X + formatted = f"{digits[:-1]}-{digits[-1]}" + return formatted, True, None + else: + return None, False, 
f"Invalid Plusgiro length: {len(digits)}" + + def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]: + """Normalize monetary amount.""" + # Remove currency and common suffixes + text = re.sub(r'[SEK|kr|:-]+', '', text, flags=re.IGNORECASE) + text = text.replace(' ', '').replace('\xa0', '') + + # Handle comma as decimal separator + if ',' in text and '.' not in text: + text = text.replace(',', '.') + + # Try to parse as float + try: + amount = float(text) + return f"{amount:.2f}", True, None + except ValueError: + return None, False, f"Cannot parse amount: {text}" + + def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]: + """Normalize date.""" + from datetime import datetime + + # Common date patterns + patterns = [ + (r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m[1]}-{int(m[2]):02d}-{int(m[3]):02d}"), + (r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"), + (r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"), + (r'(\d{4})(\d{2})(\d{2})', lambda m: f"{m[1]}-{m[2]}-{m[3]}"), + ] + + for pattern, formatter in patterns: + match = re.search(pattern, text) + if match: + try: + date_str = formatter(match) + # Validate date + datetime.strptime(date_str, '%Y-%m-%d') + return date_str, True, None + except ValueError: + continue + + return None, False, f"Cannot parse date: {text}" + + def extract_all_fields( + self, + detections: list[Detection], + image: np.ndarray | Image.Image + ) -> list[ExtractedField]: + """ + Extract fields from all detections. + + Args: + detections: List of detections + image: Full page image + + Returns: + List of ExtractedField objects + """ + fields = [] + + for detection in detections: + field = self.extract_from_detection(detection, image) + fields.append(field) + + return fields + + @staticmethod + def infer_ocr_from_invoice_number(fields: dict[str, str]) -> dict[str, str]: + """ + Infer OCR field from InvoiceNumber if not detected. + + In Swedish invoices, OCR reference number is often identical to InvoiceNumber. + When OCR is not detected but InvoiceNumber is, we can infer OCR value. + + Args: + fields: Dict of field_name -> normalized_value + + Returns: + Updated fields dict with inferred OCR if applicable + """ + # If OCR already exists, no need to infer + if fields.get('OCR'): + return fields + + # If InvoiceNumber exists and is numeric, use it as OCR + invoice_number = fields.get('InvoiceNumber') + if invoice_number: + # Check if it's mostly digits (valid OCR reference) + digits_only = re.sub(r'\D', '', invoice_number) + if len(digits_only) >= 5 and len(digits_only) == len(invoice_number): + fields['OCR'] = invoice_number + + return fields diff --git a/src/inference/pipeline.py b/src/inference/pipeline.py new file mode 100644 index 0000000..136ed3f --- /dev/null +++ b/src/inference/pipeline.py @@ -0,0 +1,297 @@ +""" +Inference Pipeline + +Complete pipeline for extracting invoice data from PDFs. 
+""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any +import time +import re + +from .yolo_detector import YOLODetector, Detection, CLASS_TO_FIELD +from .field_extractor import FieldExtractor, ExtractedField + + +@dataclass +class InferenceResult: + """Result of invoice processing.""" + document_id: str | None = None + success: bool = False + fields: dict[str, Any] = field(default_factory=dict) + confidence: dict[str, float] = field(default_factory=dict) + raw_detections: list[Detection] = field(default_factory=list) + extracted_fields: list[ExtractedField] = field(default_factory=list) + processing_time_ms: float = 0.0 + errors: list[str] = field(default_factory=list) + fallback_used: bool = False + + def to_json(self) -> dict: + """Convert to JSON-serializable dictionary.""" + return { + 'DocumentId': self.document_id, + 'InvoiceNumber': self.fields.get('InvoiceNumber'), + 'InvoiceDate': self.fields.get('InvoiceDate'), + 'InvoiceDueDate': self.fields.get('InvoiceDueDate'), + 'OCR': self.fields.get('OCR'), + 'Bankgiro': self.fields.get('Bankgiro'), + 'Plusgiro': self.fields.get('Plusgiro'), + 'Amount': self.fields.get('Amount'), + 'confidence': self.confidence, + 'success': self.success, + 'fallback_used': self.fallback_used + } + + def get_field(self, field_name: str) -> tuple[Any, float]: + """Get field value and confidence.""" + return self.fields.get(field_name), self.confidence.get(field_name, 0.0) + + +class InferencePipeline: + """ + Complete inference pipeline for invoice data extraction. + + Pipeline flow: + 1. PDF -> Image rendering + 2. YOLO detection of field regions + 3. OCR extraction from detected regions + 4. Field normalization and validation + 5. Fallback to full-page OCR if YOLO fails + """ + + def __init__( + self, + model_path: str | Path, + confidence_threshold: float = 0.5, + ocr_lang: str = 'en', + use_gpu: bool = False, + dpi: int = 300, + enable_fallback: bool = True + ): + """ + Initialize inference pipeline. + + Args: + model_path: Path to trained YOLO model + confidence_threshold: Detection confidence threshold + ocr_lang: Language for OCR + use_gpu: Whether to use GPU + dpi: Resolution for PDF rendering + enable_fallback: Enable fallback to full-page OCR + """ + self.detector = YOLODetector( + model_path, + confidence_threshold=confidence_threshold, + device='cuda' if use_gpu else 'cpu' + ) + self.extractor = FieldExtractor(ocr_lang=ocr_lang, use_gpu=use_gpu) + self.dpi = dpi + self.enable_fallback = enable_fallback + + def process_pdf( + self, + pdf_path: str | Path, + document_id: str | None = None + ) -> InferenceResult: + """ + Process a PDF and extract invoice fields. 
+ + Args: + pdf_path: Path to PDF file + document_id: Optional document ID + + Returns: + InferenceResult with extracted fields + """ + from ..pdf.renderer import render_pdf_to_images + from PIL import Image + import io + import numpy as np + + start_time = time.time() + + result = InferenceResult( + document_id=document_id or Path(pdf_path).stem + ) + + try: + all_detections = [] + all_extracted = [] + + # Process each page + for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi): + # Convert to numpy array + image = Image.open(io.BytesIO(image_bytes)) + image_array = np.array(image) + + # Run YOLO detection + detections = self.detector.detect(image_array, page_no=page_no) + all_detections.extend(detections) + + # Extract fields from detections + for detection in detections: + extracted = self.extractor.extract_from_detection(detection, image_array) + all_extracted.append(extracted) + + result.raw_detections = all_detections + result.extracted_fields = all_extracted + + # Merge extracted fields (prefer highest confidence) + self._merge_fields(result) + + # Fallback if key fields are missing + if self.enable_fallback and self._needs_fallback(result): + self._run_fallback(pdf_path, result) + + result.success = len(result.fields) > 0 + + except Exception as e: + result.errors.append(str(e)) + result.success = False + + result.processing_time_ms = (time.time() - start_time) * 1000 + return result + + def _merge_fields(self, result: InferenceResult) -> None: + """Merge extracted fields, keeping highest confidence for each field.""" + field_candidates: dict[str, list[ExtractedField]] = {} + + for extracted in result.extracted_fields: + if not extracted.is_valid or not extracted.normalized_value: + continue + + if extracted.field_name not in field_candidates: + field_candidates[extracted.field_name] = [] + field_candidates[extracted.field_name].append(extracted) + + # Select best candidate for each field + for field_name, candidates in field_candidates.items(): + best = max(candidates, key=lambda x: x.confidence) + result.fields[field_name] = best.normalized_value + result.confidence[field_name] = best.confidence + + def _needs_fallback(self, result: InferenceResult) -> bool: + """Check if fallback OCR is needed.""" + # Check for key fields + key_fields = ['Amount', 'InvoiceNumber', 'OCR'] + missing = sum(1 for f in key_fields if f not in result.fields) + return missing >= 2 # Fallback if 2+ key fields missing + + def _run_fallback(self, pdf_path: str | Path, result: InferenceResult) -> None: + """Run full-page OCR fallback.""" + from ..pdf.renderer import render_pdf_to_images + from ..ocr import OCREngine + from PIL import Image + import io + import numpy as np + + result.fallback_used = True + ocr_engine = OCREngine() + + try: + for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi): + image = Image.open(io.BytesIO(image_bytes)) + image_array = np.array(image) + + # Full page OCR + tokens = ocr_engine.extract_from_image(image_array, page_no) + full_text = ' '.join(t.text for t in tokens) + + # Try to extract missing fields with regex patterns + self._extract_with_patterns(full_text, result) + + except Exception as e: + result.errors.append(f"Fallback OCR error: {e}") + + def _extract_with_patterns(self, text: str, result: InferenceResult) -> None: + """Extract fields using regex patterns (fallback).""" + patterns = { + 'Amount': [ + r'(?:att\s*betala|summa|total|belopp)\s*[:.]?\s*([\d\s,\.]+)\s*(?:SEK|kr)?', + r'([\d\s,\.]+)\s*(?:SEK|kr)\s*$', + ], + 
'Bankgiro': [ + r'(?:bankgiro|bg)\s*[:.]?\s*(\d{3,4}[-\s]?\d{4})', + r'(\d{4}[-\s]\d{4})\s*(?=\s|$)', + ], + 'OCR': [ + r'(?:ocr|referens)\s*[:.]?\s*(\d{10,25})', + ], + 'InvoiceNumber': [ + r'(?:fakturanr|fakturanummer|invoice)\s*[:.]?\s*(\d+)', + ], + } + + for field_name, field_patterns in patterns.items(): + if field_name in result.fields: + continue + + for pattern in field_patterns: + match = re.search(pattern, text, re.IGNORECASE) + if match: + value = match.group(1).strip() + + # Normalize the value + if field_name == 'Amount': + value = value.replace(' ', '').replace(',', '.') + try: + value = f"{float(value):.2f}" + except ValueError: + continue + elif field_name == 'Bankgiro': + digits = re.sub(r'\D', '', value) + if len(digits) == 8: + value = f"{digits[:4]}-{digits[4:]}" + + result.fields[field_name] = value + result.confidence[field_name] = 0.5 # Lower confidence for regex + break + + def process_image( + self, + image_path: str | Path, + document_id: str | None = None + ) -> InferenceResult: + """ + Process a single image (for pre-rendered pages). + + Args: + image_path: Path to image file + document_id: Optional document ID + + Returns: + InferenceResult with extracted fields + """ + from PIL import Image + import numpy as np + + start_time = time.time() + + result = InferenceResult( + document_id=document_id or Path(image_path).stem + ) + + try: + image = Image.open(image_path) + image_array = np.array(image) + + # Run detection + detections = self.detector.detect(image_array, page_no=0) + result.raw_detections = detections + + # Extract fields + for detection in detections: + extracted = self.extractor.extract_from_detection(detection, image_array) + result.extracted_fields.append(extracted) + + # Merge fields + self._merge_fields(result) + result.success = len(result.fields) > 0 + + except Exception as e: + result.errors.append(str(e)) + result.success = False + + result.processing_time_ms = (time.time() - start_time) * 1000 + return result diff --git a/src/inference/yolo_detector.py b/src/inference/yolo_detector.py new file mode 100644 index 0000000..ed9dd0b --- /dev/null +++ b/src/inference/yolo_detector.py @@ -0,0 +1,204 @@ +""" +YOLO Detection Module + +Runs YOLO model inference for field detection. 
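+
+Example (a minimal sketch; weights and image paths are hypothetical):
+
+    >>> detector = YOLODetector('models/best.pt', confidence_threshold=0.5)
+    >>> detections = detector.detect('page_000.png')
+    >>> [(d.class_name, round(d.confidence, 2)) for d in detections]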
+""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Any +import numpy as np + + +@dataclass +class Detection: + """Represents a single YOLO detection.""" + class_id: int + class_name: str + confidence: float + bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) in pixels + page_no: int = 0 + + @property + def x0(self) -> float: + return self.bbox[0] + + @property + def y0(self) -> float: + return self.bbox[1] + + @property + def x1(self) -> float: + return self.bbox[2] + + @property + def y1(self) -> float: + return self.bbox[3] + + @property + def center(self) -> tuple[float, float]: + return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2) + + @property + def width(self) -> float: + return self.x1 - self.x0 + + @property + def height(self) -> float: + return self.y1 - self.y0 + + def get_padded_bbox( + self, + padding: float = 0.1, + image_width: float | None = None, + image_height: float | None = None + ) -> tuple[float, float, float, float]: + """Get bbox with padding for OCR extraction.""" + pad_x = self.width * padding + pad_y = self.height * padding + + x0 = self.x0 - pad_x + y0 = self.y0 - pad_y + x1 = self.x1 + pad_x + y1 = self.y1 + pad_y + + if image_width: + x0 = max(0, x0) + x1 = min(image_width, x1) + if image_height: + y0 = max(0, y0) + y1 = min(image_height, y1) + + return (x0, y0, x1, y1) + + +# Class names (must match training configuration) +CLASS_NAMES = [ + 'invoice_number', + 'invoice_date', + 'invoice_due_date', + 'ocr_number', + 'bankgiro', + 'plusgiro', + 'amount', +] + +# Mapping from class name to field name +CLASS_TO_FIELD = { + 'invoice_number': 'InvoiceNumber', + 'invoice_date': 'InvoiceDate', + 'invoice_due_date': 'InvoiceDueDate', + 'ocr_number': 'OCR', + 'bankgiro': 'Bankgiro', + 'plusgiro': 'Plusgiro', + 'amount': 'Amount', +} + + +class YOLODetector: + """YOLO model wrapper for field detection.""" + + def __init__( + self, + model_path: str | Path, + confidence_threshold: float = 0.5, + iou_threshold: float = 0.45, + device: str = 'auto' + ): + """ + Initialize YOLO detector. + + Args: + model_path: Path to trained YOLO model (.pt file) + confidence_threshold: Minimum confidence for detections + iou_threshold: IOU threshold for NMS + device: Device to run on ('auto', 'cpu', 'cuda', 'mps') + """ + from ultralytics import YOLO + + self.model = YOLO(model_path) + self.confidence_threshold = confidence_threshold + self.iou_threshold = iou_threshold + self.device = device + + def detect( + self, + image: str | Path | np.ndarray, + page_no: int = 0 + ) -> list[Detection]: + """ + Run detection on an image. 
+ + Args: + image: Image path or numpy array + page_no: Page number for reference + + Returns: + List of Detection objects + """ + results = self.model.predict( + source=image, + conf=self.confidence_threshold, + iou=self.iou_threshold, + device=self.device, + verbose=False + ) + + detections = [] + + for result in results: + boxes = result.boxes + if boxes is None: + continue + + for i in range(len(boxes)): + class_id = int(boxes.cls[i]) + confidence = float(boxes.conf[i]) + bbox = boxes.xyxy[i].tolist() # [x0, y0, x1, y1] + + class_name = CLASS_NAMES[class_id] if class_id < len(CLASS_NAMES) else f"class_{class_id}" + + detections.append(Detection( + class_id=class_id, + class_name=class_name, + confidence=confidence, + bbox=tuple(bbox), + page_no=page_no + )) + + return detections + + def detect_pdf( + self, + pdf_path: str | Path, + dpi: int = 300 + ) -> dict[int, list[Detection]]: + """ + Run detection on all pages of a PDF. + + Args: + pdf_path: Path to PDF file + dpi: Resolution for rendering + + Returns: + Dict mapping page number to list of detections + """ + from ..pdf.renderer import render_pdf_to_images + from PIL import Image + import io + + results = {} + + for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi): + # Convert bytes to numpy array + image = Image.open(io.BytesIO(image_bytes)) + image_array = np.array(image) + + detections = self.detect(image_array, page_no=page_no) + results[page_no] = detections + + return results + + def get_field_name(self, class_name: str) -> str: + """Convert class name to field name.""" + return CLASS_TO_FIELD.get(class_name, class_name) diff --git a/src/matcher/__init__.py b/src/matcher/__init__.py new file mode 100644 index 0000000..eced8fa --- /dev/null +++ b/src/matcher/__init__.py @@ -0,0 +1,3 @@ +from .field_matcher import FieldMatcher, Match, find_field_matches + +__all__ = ['FieldMatcher', 'Match', 'find_field_matches'] diff --git a/src/matcher/field_matcher.py b/src/matcher/field_matcher.py new file mode 100644 index 0000000..c7b2531 --- /dev/null +++ b/src/matcher/field_matcher.py @@ -0,0 +1,618 @@ +""" +Field Matching Module + +Matches normalized field values to tokens extracted from documents. 
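+
+Example (a minimal sketch; assumes a text-layer PDF readable by
+src.pdf.extract_text_tokens):
+
+    >>> from src.pdf import extract_text_tokens
+    >>> tokens = list(extract_text_tokens('invoice.pdf'))
+    >>> matches = find_field_matches(tokens, {'Bankgiro': '5393-9484'})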
+""" + +from dataclasses import dataclass +from typing import Protocol +import re + + +class TokenLike(Protocol): + """Protocol for token objects.""" + text: str + bbox: tuple[float, float, float, float] + page_no: int + + +@dataclass +class Match: + """Represents a matched field in the document.""" + field: str + value: str + bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) + page_no: int + score: float # 0-1 confidence score + matched_text: str # Actual text that matched + context_keywords: list[str] # Nearby keywords that boosted confidence + + def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str: + """Convert to YOLO annotation format.""" + x0, y0, x1, y1 = self.bbox + + x_center = (x0 + x1) / 2 / image_width + y_center = (y0 + y1) / 2 / image_height + width = (x1 - x0) / image_width + height = (y1 - y0) / image_height + + return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}" + + +# Context keywords for each field type (Swedish invoice terms) +CONTEXT_KEYWORDS = { + 'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'], + 'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'], + 'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast', + 'förfallodag', 'oss tillhanda senast', 'senast'], + 'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'], + 'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'], + 'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'], + 'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'], +} + + +class FieldMatcher: + """Matches field values to document tokens.""" + + def __init__( + self, + context_radius: float = 100.0, # pixels + min_score_threshold: float = 0.5 + ): + """ + Initialize the matcher. + + Args: + context_radius: Distance to search for context keywords + min_score_threshold: Minimum score to consider a match valid + """ + self.context_radius = context_radius + self.min_score_threshold = min_score_threshold + + def find_matches( + self, + tokens: list[TokenLike], + field_name: str, + normalized_values: list[str], + page_no: int = 0 + ) -> list[Match]: + """ + Find all matches for a field in the token list. 
+ + Args: + tokens: List of tokens from the document + field_name: Name of the field to match + normalized_values: List of normalized value variants to search for + page_no: Page number to filter tokens + + Returns: + List of Match objects sorted by score (descending) + """ + matches = [] + page_tokens = [t for t in tokens if t.page_no == page_no] + + for value in normalized_values: + # Strategy 1: Exact token match + exact_matches = self._find_exact_matches(page_tokens, value, field_name) + matches.extend(exact_matches) + + # Strategy 2: Multi-token concatenation + concat_matches = self._find_concatenated_matches(page_tokens, value, field_name) + matches.extend(concat_matches) + + # Strategy 3: Fuzzy match (for amounts and dates only) + if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'): + fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name) + matches.extend(fuzzy_matches) + + # Strategy 4: Substring match (for dates embedded in longer text) + if field_name in ('InvoiceDate', 'InvoiceDueDate'): + substring_matches = self._find_substring_matches(page_tokens, value, field_name) + matches.extend(substring_matches) + + # Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection) + # Only if no exact matches found for date fields + if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches: + flexible_matches = self._find_flexible_date_matches( + page_tokens, normalized_values, field_name + ) + matches.extend(flexible_matches) + + # Deduplicate and sort by score + matches = self._deduplicate_matches(matches) + matches.sort(key=lambda m: m.score, reverse=True) + + return [m for m in matches if m.score >= self.min_score_threshold] + + def _find_exact_matches( + self, + tokens: list[TokenLike], + value: str, + field_name: str + ) -> list[Match]: + """Find tokens that exactly match the value.""" + matches = [] + + for token in tokens: + token_text = token.text.strip() + + # Exact match + if token_text == value: + score = 1.0 + # Case-insensitive match + elif token_text.lower() == value.lower(): + score = 0.95 + # Digits-only match for numeric fields + elif field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'): + token_digits = re.sub(r'\D', '', token_text) + value_digits = re.sub(r'\D', '', value) + if token_digits and token_digits == value_digits: + score = 0.9 + else: + continue + else: + continue + + # Boost score if context keywords are nearby + context_keywords, context_boost = self._find_context_keywords( + tokens, token, field_name + ) + score = min(1.0, score + context_boost) + + matches.append(Match( + field=field_name, + value=value, + bbox=token.bbox, + page_no=token.page_no, + score=score, + matched_text=token_text, + context_keywords=context_keywords + )) + + return matches + + def _find_concatenated_matches( + self, + tokens: list[TokenLike], + value: str, + field_name: str + ) -> list[Match]: + """Find value by concatenating adjacent tokens.""" + matches = [] + value_clean = re.sub(r'\s+', '', value) + + # Sort tokens by position (top-to-bottom, left-to-right) + sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0])) + + for i, start_token in enumerate(sorted_tokens): + # Try to build the value by concatenating nearby tokens + concat_text = start_token.text.strip() + concat_bbox = list(start_token.bbox) + used_tokens = [start_token] + + for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens + next_token = sorted_tokens[j] + + # Check if tokens are on the same line (y overlap) + 
if not self._tokens_on_same_line(start_token, next_token): + break + + # Check horizontal proximity + if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap + break + + concat_text += next_token.text.strip() + used_tokens.append(next_token) + + # Update bounding box + concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0]) + concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1]) + concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2]) + concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3]) + + # Check for match + concat_clean = re.sub(r'\s+', '', concat_text) + if concat_clean == value_clean: + context_keywords, context_boost = self._find_context_keywords( + tokens, start_token, field_name + ) + + matches.append(Match( + field=field_name, + value=value, + bbox=tuple(concat_bbox), + page_no=start_token.page_no, + score=min(1.0, 0.85 + context_boost), # Slightly lower base score + matched_text=concat_text, + context_keywords=context_keywords + )) + break + + return matches + + def _find_substring_matches( + self, + tokens: list[TokenLike], + value: str, + field_name: str + ) -> list[Match]: + """ + Find value as a substring within longer tokens. + + Handles cases like 'Fakturadatum: 2026-01-09' where the date + is embedded in a longer text string. + + Uses lower score (0.75) than exact match to prefer exact matches. + Only matches if the value appears as a distinct segment (not part of a number). + """ + matches = [] + + # Only use for date fields - other fields risk false positives + if field_name not in ('InvoiceDate', 'InvoiceDueDate'): + return matches + + for token in tokens: + token_text = token.text.strip() + + # Skip if token is the same length as value (would be exact match) + if len(token_text) <= len(value): + continue + + # Check if value appears as substring + if value in token_text: + # Verify it's a proper boundary match (not part of a larger number) + idx = token_text.find(value) + + # Check character before (if exists) + if idx > 0: + char_before = token_text[idx - 1] + # Must be non-digit (allow : space - etc) + if char_before.isdigit(): + continue + + # Check character after (if exists) + end_idx = idx + len(value) + if end_idx < len(token_text): + char_after = token_text[end_idx] + # Must be non-digit + if char_after.isdigit(): + continue + + # Found valid substring match + context_keywords, context_boost = self._find_context_keywords( + tokens, token, field_name + ) + + # Check if context keyword is in the same token (like "Fakturadatum:") + token_lower = token_text.lower() + inline_context = [] + for keyword in CONTEXT_KEYWORDS.get(field_name, []): + if keyword in token_lower: + inline_context.append(keyword) + + # Boost score if keyword is inline + inline_boost = 0.1 if inline_context else 0 + + matches.append(Match( + field=field_name, + value=value, + bbox=token.bbox, # Use full token bbox + page_no=token.page_no, + score=min(1.0, 0.75 + context_boost + inline_boost), # Lower than exact match + matched_text=token_text, + context_keywords=context_keywords + inline_context + )) + + return matches + + def _find_fuzzy_matches( + self, + tokens: list[TokenLike], + value: str, + field_name: str + ) -> list[Match]: + """Find approximate matches for amounts and dates.""" + matches = [] + + for token in tokens: + token_text = token.text.strip() + + if field_name == 'Amount': + # Try to parse both as numbers + try: + token_num = self._parse_amount(token_text) + value_num = self._parse_amount(value) + + if token_num is not None and value_num is not None: + if 
abs(token_num - value_num) < 0.01: # Within 1 cent + context_keywords, context_boost = self._find_context_keywords( + tokens, token, field_name + ) + + matches.append(Match( + field=field_name, + value=value, + bbox=token.bbox, + page_no=token.page_no, + score=min(1.0, 0.8 + context_boost), + matched_text=token_text, + context_keywords=context_keywords + )) + except: + pass + + return matches + + def _find_flexible_date_matches( + self, + tokens: list[TokenLike], + normalized_values: list[str], + field_name: str + ) -> list[Match]: + """ + Flexible date matching when exact match fails. + + Strategies: + 1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date + 2. Nearby date match: Match dates within 7 days of CSV value + 3. Heuristic selection: Use context keywords to select the best date + + This handles cases where CSV InvoiceDate doesn't exactly match PDF, + but we can still find a reasonable date to label. + """ + from datetime import datetime, timedelta + + matches = [] + + # Parse the target date from normalized values + target_date = None + for value in normalized_values: + # Try to parse YYYY-MM-DD format + date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value) + if date_match: + try: + target_date = datetime( + int(date_match.group(1)), + int(date_match.group(2)), + int(date_match.group(3)) + ) + break + except ValueError: + continue + + if not target_date: + return matches + + # Find all date-like tokens in the document + date_candidates = [] + date_pattern = re.compile(r'(\d{4})-(\d{2})-(\d{2})') + + for token in tokens: + token_text = token.text.strip() + + # Search for date pattern in token + for match in date_pattern.finditer(token_text): + try: + found_date = datetime( + int(match.group(1)), + int(match.group(2)), + int(match.group(3)) + ) + date_str = match.group(0) + + # Calculate date difference + days_diff = abs((found_date - target_date).days) + + # Check for context keywords + context_keywords, context_boost = self._find_context_keywords( + tokens, token, field_name + ) + + # Check if keyword is in the same token + token_lower = token_text.lower() + inline_keywords = [] + for keyword in CONTEXT_KEYWORDS.get(field_name, []): + if keyword in token_lower: + inline_keywords.append(keyword) + + date_candidates.append({ + 'token': token, + 'date': found_date, + 'date_str': date_str, + 'matched_text': token_text, + 'days_diff': days_diff, + 'context_keywords': context_keywords + inline_keywords, + 'context_boost': context_boost + (0.1 if inline_keywords else 0), + 'same_year_month': (found_date.year == target_date.year and + found_date.month == target_date.month), + }) + except ValueError: + continue + + if not date_candidates: + return matches + + # Score and rank candidates + for candidate in date_candidates: + score = 0.0 + + # Strategy 1: Same year-month gets higher score + if candidate['same_year_month']: + score = 0.7 + # Bonus if day is close + if candidate['days_diff'] <= 7: + score = 0.75 + if candidate['days_diff'] <= 3: + score = 0.8 + # Strategy 2: Nearby dates (within 14 days) + elif candidate['days_diff'] <= 14: + score = 0.6 + elif candidate['days_diff'] <= 30: + score = 0.55 + else: + # Too far apart, skip unless has strong context + if not candidate['context_keywords']: + continue + score = 0.5 + + # Strategy 3: Boost with context keywords + score = min(1.0, score + candidate['context_boost']) + + # For InvoiceDate, prefer dates that appear near invoice-related keywords + # For InvoiceDueDate, prefer dates near due-date keywords + if 
candidate['context_keywords']:
+                score = min(1.0, score + 0.05)
+
+            if score >= self.min_score_threshold:
+                matches.append(Match(
+                    field=field_name,
+                    value=candidate['date_str'],
+                    bbox=candidate['token'].bbox,
+                    page_no=candidate['token'].page_no,
+                    score=score,
+                    matched_text=candidate['matched_text'],
+                    context_keywords=candidate['context_keywords']
+                ))
+
+        # Sort by score and return best matches
+        matches.sort(key=lambda m: m.score, reverse=True)
+
+        # Only return the best match to avoid multiple labels for same field
+        return matches[:1] if matches else []
+
+    def _find_context_keywords(
+        self,
+        tokens: list[TokenLike],
+        target_token: TokenLike,
+        field_name: str
+    ) -> tuple[list[str], float]:
+        """Find context keywords near the target token."""
+        keywords = CONTEXT_KEYWORDS.get(field_name, [])
+        found_keywords = []
+
+        target_center = (
+            (target_token.bbox[0] + target_token.bbox[2]) / 2,
+            (target_token.bbox[1] + target_token.bbox[3]) / 2
+        )
+
+        for token in tokens:
+            if token is target_token:
+                continue
+
+            token_center = (
+                (token.bbox[0] + token.bbox[2]) / 2,
+                (token.bbox[1] + token.bbox[3]) / 2
+            )
+
+            # Calculate distance
+            distance = (
+                (target_center[0] - token_center[0]) ** 2 +
+                (target_center[1] - token_center[1]) ** 2
+            ) ** 0.5
+
+            if distance <= self.context_radius:
+                token_lower = token.text.lower()
+                for keyword in keywords:
+                    if keyword in token_lower:
+                        found_keywords.append(keyword)
+
+        # Calculate boost based on keywords found
+        boost = min(0.15, len(found_keywords) * 0.05)
+        return found_keywords, boost
+
+    def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
+        """Check if two tokens are on the same line."""
+        # Check vertical overlap
+        y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
+        min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
+        return y_overlap > min_height * 0.5
+
+    def _parse_amount(self, text: str) -> float | None:
+        """Try to parse text as a monetary amount."""
+        # Remove currency markers and spaces (alternation, not a character class)
+        text = re.sub(r'(?:SEK|kr|[:|-])', '', text, flags=re.IGNORECASE)
+        text = text.replace(' ', '').replace('\xa0', '')
+
+        # Try comma as decimal separator
+        if ',' in text and '.'
not in text: + text = text.replace(',', '.') + + try: + return float(text) + except ValueError: + return None + + def _deduplicate_matches(self, matches: list[Match]) -> list[Match]: + """Remove duplicate matches based on bbox overlap.""" + if not matches: + return [] + + # Sort by score descending + matches.sort(key=lambda m: m.score, reverse=True) + unique = [] + + for match in matches: + is_duplicate = False + for existing in unique: + if self._bbox_overlap(match.bbox, existing.bbox) > 0.7: + is_duplicate = True + break + + if not is_duplicate: + unique.append(match) + + return unique + + def _bbox_overlap( + self, + bbox1: tuple[float, float, float, float], + bbox2: tuple[float, float, float, float] + ) -> float: + """Calculate IoU (Intersection over Union) of two bounding boxes.""" + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[2], bbox2[2]) + y2 = min(bbox1[3], bbox2[3]) + + if x2 <= x1 or y2 <= y1: + return 0.0 + + intersection = (x2 - x1) * (y2 - y1) + area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) + union = area1 + area2 - intersection + + return intersection / union if union > 0 else 0.0 + + +def find_field_matches( + tokens: list[TokenLike], + field_values: dict[str, str], + page_no: int = 0 +) -> dict[str, list[Match]]: + """ + Convenience function to find matches for multiple fields. + + Args: + tokens: List of tokens from the document + field_values: Dict of field_name -> value to search for + page_no: Page number + + Returns: + Dict of field_name -> list of matches + """ + from ..normalize import normalize_field + + matcher = FieldMatcher() + results = {} + + for field_name, value in field_values.items(): + if value is None or str(value).strip() == '': + continue + + normalized_values = normalize_field(field_name, str(value)) + matches = matcher.find_matches(tokens, field_name, normalized_values, page_no) + results[field_name] = matches + + return results diff --git a/src/normalize/__init__.py b/src/normalize/__init__.py new file mode 100644 index 0000000..4cfb1ec --- /dev/null +++ b/src/normalize/__init__.py @@ -0,0 +1,3 @@ +from .normalizer import normalize_field, FieldNormalizer + +__all__ = ['normalize_field', 'FieldNormalizer'] diff --git a/src/normalize/normalizer.py b/src/normalize/normalizer.py new file mode 100644 index 0000000..62d3e8f --- /dev/null +++ b/src/normalize/normalizer.py @@ -0,0 +1,290 @@ +""" +Field Normalization Module + +Normalizes field values to generate multiple candidate forms for matching. 
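+
+Example (illustrative of the variant expansion; set-based deduplication
+means the order of the returned variants may differ):
+
+    >>> sorted(normalize_field('Bankgiro', '53939484'))
+    ['5393-9484', '53939484']
+    >>> sorted(normalize_field('Amount', '114'))
+    ['114', '114,00', '114.00']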
+""" + +import re +from dataclasses import dataclass +from datetime import datetime +from typing import Callable + + +@dataclass +class NormalizedValue: + """Represents a normalized value with its variants.""" + original: str + variants: list[str] + field_type: str + + +class FieldNormalizer: + """Handles normalization of different invoice field types.""" + + # Common Swedish month names for date parsing + SWEDISH_MONTHS = { + 'januari': '01', 'jan': '01', + 'februari': '02', 'feb': '02', + 'mars': '03', 'mar': '03', + 'april': '04', 'apr': '04', + 'maj': '05', + 'juni': '06', 'jun': '06', + 'juli': '07', 'jul': '07', + 'augusti': '08', 'aug': '08', + 'september': '09', 'sep': '09', 'sept': '09', + 'oktober': '10', 'okt': '10', + 'november': '11', 'nov': '11', + 'december': '12', 'dec': '12' + } + + @staticmethod + def clean_text(text: str) -> str: + """Remove invisible characters and normalize whitespace.""" + # Remove zero-width characters + text = re.sub(r'[\u200b\u200c\u200d\ufeff]', '', text) + # Normalize whitespace + text = ' '.join(text.split()) + return text.strip() + + @staticmethod + def normalize_invoice_number(value: str) -> list[str]: + """ + Normalize invoice number. + Keeps only digits for matching. + + Examples: + '100017500321' -> ['100017500321'] + 'INV-100017500321' -> ['100017500321', 'INV-100017500321'] + """ + value = FieldNormalizer.clean_text(value) + digits_only = re.sub(r'\D', '', value) + + variants = [value] + if digits_only and digits_only != value: + variants.append(digits_only) + + return list(set(v for v in variants if v)) + + @staticmethod + def normalize_ocr_number(value: str) -> list[str]: + """ + Normalize OCR number (Swedish payment reference). + Similar to invoice number - digits only. + """ + return FieldNormalizer.normalize_invoice_number(value) + + @staticmethod + def normalize_bankgiro(value: str) -> list[str]: + """ + Normalize Bankgiro number. + + Examples: + '5393-9484' -> ['5393-9484', '53939484'] + '53939484' -> ['53939484', '5393-9484'] + """ + value = FieldNormalizer.clean_text(value) + digits_only = re.sub(r'\D', '', value) + + variants = [value] + + if digits_only: + # Add without dash + variants.append(digits_only) + + # Add with dash (format: XXXX-XXXX for 8 digits) + if len(digits_only) == 8: + with_dash = f"{digits_only[:4]}-{digits_only[4:]}" + variants.append(with_dash) + elif len(digits_only) == 7: + # Some bankgiro numbers are 7 digits: XXX-XXXX + with_dash = f"{digits_only[:3]}-{digits_only[3:]}" + variants.append(with_dash) + + return list(set(v for v in variants if v)) + + @staticmethod + def normalize_plusgiro(value: str) -> list[str]: + """ + Normalize Plusgiro number. + + Examples: + '1234567-8' -> ['1234567-8', '12345678'] + '12345678' -> ['12345678', '1234567-8'] + """ + value = FieldNormalizer.clean_text(value) + digits_only = re.sub(r'\D', '', value) + + variants = [value] + + if digits_only: + variants.append(digits_only) + + # Plusgiro format: XXXXXXX-X (7 digits + check digit) + if len(digits_only) == 8: + with_dash = f"{digits_only[:-1]}-{digits_only[-1]}" + variants.append(with_dash) + # Also handle 6+1 format + elif len(digits_only) == 7: + with_dash = f"{digits_only[:-1]}-{digits_only[-1]}" + variants.append(with_dash) + + return list(set(v for v in variants if v)) + + @staticmethod + def normalize_amount(value: str) -> list[str]: + """ + Normalize monetary amount. 
+
+        Examples:
+            '114' -> ['114', '114,00', '114.00']
+            '114,00' -> ['114,00', '114.00', '114']
+            '1 234,56' -> ['1234,56', '1234.56', '1 234,56']
+        """
+        value = FieldNormalizer.clean_text(value)
+
+        # Remove trailing currency symbols and common suffixes (alternation,
+        # not a character class, so digits and letters inside the amount are safe)
+        value = re.sub(r'(?:SEK|kr|[:|-])+$', '', value, flags=re.IGNORECASE).strip()
+
+        # Remove spaces (thousand separators)
+        no_space = value.replace(' ', '').replace('\xa0', '')
+
+        variants = [value]
+
+        # Normalize decimal separator
+        if ',' in no_space:
+            dot_version = no_space.replace(',', '.')
+            variants.append(no_space)
+            variants.append(dot_version)
+        elif '.' in no_space:
+            comma_version = no_space.replace('.', ',')
+            variants.append(no_space)
+            variants.append(comma_version)
+        else:
+            # Integer amount - add decimal versions
+            variants.append(no_space)
+            variants.append(f"{no_space},00")
+            variants.append(f"{no_space}.00")
+
+        # Try to parse and get clean numeric value
+        try:
+            # Parse as float
+            clean = no_space.replace(',', '.')
+            num = float(clean)
+
+            # Integer if no decimals
+            if num == int(num):
+                variants.append(str(int(num)))
+                variants.append(f"{int(num)},00")
+                variants.append(f"{int(num)}.00")
+            else:
+                variants.append(f"{num:.2f}")
+                variants.append(f"{num:.2f}".replace('.', ','))
+        except ValueError:
+            pass
+
+        return list(set(v for v in variants if v))
+
+    @staticmethod
+    def normalize_date(value: str) -> list[str]:
+        """
+        Normalize date to YYYY-MM-DD and generate variants.
+
+        Handles:
+            '2025-12-13' -> ['2025-12-13', '13/12/2025', '13.12.2025']
+            '13/12/2025' -> ['2025-12-13', '13/12/2025', ...]
+            '13 december 2025' -> ['2025-12-13', ...]
+        """
+        value = FieldNormalizer.clean_text(value)
+        variants = [value]
+
+        parsed_date = None
+
+        # Try different date formats
+        date_patterns = [
+            # ISO format with optional time (e.g., 2026-01-09 00:00:00)
+            (r'^(\d{4})-(\d{1,2})-(\d{1,2})(?:\s+\d{1,2}:\d{2}:\d{2})?$', lambda m: (int(m[1]), int(m[2]), int(m[3]))),
+            # European format with /
+            (r'^(\d{1,2})/(\d{1,2})/(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))),
+            # European format with .
+ (r'^(\d{1,2})\.(\d{1,2})\.(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))), + # European format with - + (r'^(\d{1,2})-(\d{1,2})-(\d{4})$', lambda m: (int(m[3]), int(m[2]), int(m[1]))), + # Swedish format: YYMMDD + (r'^(\d{2})(\d{2})(\d{2})$', lambda m: (2000 + int(m[1]) if int(m[1]) < 50 else 1900 + int(m[1]), int(m[2]), int(m[3]))), + # Swedish format: YYYYMMDD + (r'^(\d{4})(\d{2})(\d{2})$', lambda m: (int(m[1]), int(m[2]), int(m[3]))), + ] + + for pattern, extractor in date_patterns: + match = re.match(pattern, value) + if match: + try: + year, month, day = extractor(match) + parsed_date = datetime(year, month, day) + break + except ValueError: + continue + + # Try Swedish month names + if not parsed_date: + for month_name, month_num in FieldNormalizer.SWEDISH_MONTHS.items(): + if month_name in value.lower(): + # Extract day and year + numbers = re.findall(r'\d+', value) + if len(numbers) >= 2: + day = int(numbers[0]) + year = int(numbers[-1]) + if year < 100: + year = 2000 + year if year < 50 else 1900 + year + try: + parsed_date = datetime(year, int(month_num), day) + break + except ValueError: + continue + + if parsed_date: + # Generate different formats + iso = parsed_date.strftime('%Y-%m-%d') + eu_slash = parsed_date.strftime('%d/%m/%Y') + eu_dot = parsed_date.strftime('%d.%m.%Y') + compact = parsed_date.strftime('%Y%m%d') + + variants.extend([iso, eu_slash, eu_dot, compact]) + + return list(set(v for v in variants if v)) + + +# Field type to normalizer mapping +NORMALIZERS: dict[str, Callable[[str], list[str]]] = { + 'InvoiceNumber': FieldNormalizer.normalize_invoice_number, + 'OCR': FieldNormalizer.normalize_ocr_number, + 'Bankgiro': FieldNormalizer.normalize_bankgiro, + 'Plusgiro': FieldNormalizer.normalize_plusgiro, + 'Amount': FieldNormalizer.normalize_amount, + 'InvoiceDate': FieldNormalizer.normalize_date, + 'InvoiceDueDate': FieldNormalizer.normalize_date, +} + + +def normalize_field(field_name: str, value: str) -> list[str]: + """ + Normalize a field value based on its type. + + Args: + field_name: Name of the field (e.g., 'InvoiceNumber', 'Amount') + value: Raw value to normalize + + Returns: + List of normalized variants + """ + if value is None or (isinstance(value, str) and not value.strip()): + return [] + + value = str(value) + normalizer = NORMALIZERS.get(field_name) + + if normalizer: + return normalizer(value) + + # Default: just clean the text + return [FieldNormalizer.clean_text(value)] diff --git a/src/ocr/__init__.py b/src/ocr/__init__.py new file mode 100644 index 0000000..9badd37 --- /dev/null +++ b/src/ocr/__init__.py @@ -0,0 +1,3 @@ +from .paddle_ocr import OCREngine, extract_ocr_tokens + +__all__ = ['OCREngine', 'extract_ocr_tokens'] diff --git a/src/ocr/paddle_ocr.py b/src/ocr/paddle_ocr.py new file mode 100644 index 0000000..1973493 --- /dev/null +++ b/src/ocr/paddle_ocr.py @@ -0,0 +1,188 @@ +""" +OCR Extraction Module using PaddleOCR + +Extracts text tokens with bounding boxes from scanned PDFs. 
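+
+Example (a minimal sketch; the image path is hypothetical):
+
+    >>> engine = OCREngine(lang='en')
+    >>> tokens = engine.extract_from_image('scanned_page.png')
+    >>> [(t.text, round(t.confidence, 2)) for t in tokens[:3]]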
+""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Generator +import numpy as np + + +@dataclass +class OCRToken: + """Represents an OCR-extracted text token with its bounding box.""" + text: str + bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) + confidence: float + page_no: int = 0 + + @property + def x0(self) -> float: + return self.bbox[0] + + @property + def y0(self) -> float: + return self.bbox[1] + + @property + def x1(self) -> float: + return self.bbox[2] + + @property + def y1(self) -> float: + return self.bbox[3] + + @property + def center(self) -> tuple[float, float]: + return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2) + + +class OCREngine: + """PaddleOCR wrapper for text extraction.""" + + def __init__( + self, + lang: str = "en", + use_gpu: bool = False, + det_model_dir: str | None = None, + rec_model_dir: str | None = None + ): + """ + Initialize OCR engine. + + Args: + lang: Language code ('en', 'sv', 'ch', etc.) + use_gpu: Whether to use GPU acceleration + det_model_dir: Custom detection model directory + rec_model_dir: Custom recognition model directory + """ + from paddleocr import PaddleOCR + + # PaddleOCR 3.x API - simplified init + init_params = {'lang': lang} + if det_model_dir: + init_params['text_detection_model_dir'] = det_model_dir + if rec_model_dir: + init_params['text_recognition_model_dir'] = rec_model_dir + + self.ocr = PaddleOCR(**init_params) + + def extract_from_image( + self, + image: str | Path | np.ndarray, + page_no: int = 0 + ) -> list[OCRToken]: + """ + Extract text tokens from an image. + + Args: + image: Image path or numpy array + page_no: Page number for reference + + Returns: + List of OCRToken objects + """ + if isinstance(image, (str, Path)): + image = str(image) + + # PaddleOCR 3.x uses predict() method instead of ocr() + result = self.ocr.predict(image) + + tokens = [] + if result: + for item in result: + # PaddleOCR 3.x returns list of dicts with 'rec_texts', 'rec_scores', 'dt_polys' + if isinstance(item, dict): + rec_texts = item.get('rec_texts', []) + rec_scores = item.get('rec_scores', []) + dt_polys = item.get('dt_polys', []) + + for i, (text, score, poly) in enumerate(zip(rec_texts, rec_scores, dt_polys)): + # poly is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] + x_coords = [p[0] for p in poly] + y_coords = [p[1] for p in poly] + + bbox = ( + min(x_coords), + min(y_coords), + max(x_coords), + max(y_coords) + ) + + tokens.append(OCRToken( + text=text, + bbox=bbox, + confidence=float(score), + page_no=page_no + )) + elif isinstance(item, (list, tuple)) and len(item) == 2: + # Legacy format: [[bbox_points], (text, confidence)] + bbox_points, (text, confidence) = item + + x_coords = [p[0] for p in bbox_points] + y_coords = [p[1] for p in bbox_points] + + bbox = ( + min(x_coords), + min(y_coords), + max(x_coords), + max(y_coords) + ) + + tokens.append(OCRToken( + text=text, + bbox=bbox, + confidence=confidence, + page_no=page_no + )) + + return tokens + + def extract_from_pdf( + self, + pdf_path: str | Path, + dpi: int = 300 + ) -> Generator[list[OCRToken], None, None]: + """ + Extract text from all pages of a scanned PDF. 
+
+        Args:
+            pdf_path: Path to the PDF file
+            dpi: Resolution for rendering
+
+        Yields:
+            List of OCRToken for each page
+        """
+        from ..pdf.renderer import render_pdf_to_images
+        import io
+        from PIL import Image
+
+        for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=dpi):
+            # Convert bytes to numpy array
+            image = Image.open(io.BytesIO(image_bytes))
+            image_array = np.array(image)
+
+            tokens = self.extract_from_image(image_array, page_no=page_no)
+            yield tokens


def extract_ocr_tokens(
+    image_path: str | Path,
+    lang: str = "en",
+    page_no: int = 0
+) -> list[OCRToken]:
+    """
+    Convenience function to extract OCR tokens from an image.
+
+    Args:
+        image_path: Path to the image file
+        lang: Language code
+        page_no: Page number for reference
+
+    Returns:
+        List of OCRToken objects
+    """
+    engine = OCREngine(lang=lang)
+    return engine.extract_from_image(image_path, page_no=page_no)
diff --git a/src/pdf/__init__.py b/src/pdf/__init__.py
new file mode 100644
index 0000000..e9053ee
--- /dev/null
+++ b/src/pdf/__init__.py
@@ -0,0 +1,5 @@
+from .detector import is_text_pdf, get_pdf_type
+from .renderer import render_pdf_to_images
+from .extractor import extract_text_tokens
+
+__all__ = ['is_text_pdf', 'get_pdf_type', 'render_pdf_to_images', 'extract_text_tokens']
diff --git a/src/pdf/detector.py b/src/pdf/detector.py
new file mode 100644
index 0000000..4b9ec99
--- /dev/null
+++ b/src/pdf/detector.py
@@ -0,0 +1,98 @@
+"""
+PDF Type Detection Module
+
+Automatically distinguishes between:
+- Text-based PDFs (digitally generated)
+- Scanned image PDFs
+"""
+
+from pathlib import Path
+from typing import Literal
+import fitz  # PyMuPDF
+
+
+PDFType = Literal["text", "scanned", "mixed"]
+
+
+def extract_text_first_page(pdf_path: str | Path) -> str:
+    """Extract text from the first page of a PDF."""
+    doc = fitz.open(pdf_path)
+    if len(doc) == 0:
+        doc.close()
+        return ""
+
+    first_page = doc[0]
+    text = first_page.get_text()
+    doc.close()
+    return text
+
+
+def is_text_pdf(pdf_path: str | Path, min_chars: int = 30) -> bool:
+    """
+    Check if a PDF has an extractable text layer.
+
+    Args:
+        pdf_path: Path to the PDF file
+        min_chars: Minimum characters to consider it a text PDF
+
+    Returns:
+        True if the PDF has a text layer, False if scanned
+    """
+    text = extract_text_first_page(pdf_path)
+    return len(text.strip()) > min_chars
+
+
+def get_pdf_type(pdf_path: str | Path) -> PDFType:
+    """
+    Determine the PDF type.
+
+    Returns:
+        'text' - Has extractable text layer
+        'scanned' - Image-based, needs OCR
+        'mixed' - Some pages have text, some don't
+    """
+    doc = fitz.open(pdf_path)
+
+    if len(doc) == 0:
+        doc.close()
+        return "scanned"
+
+    text_pages = 0
+    for page in doc:
+        text = page.get_text().strip()
+        if len(text) > 30:
+            text_pages += 1
+
+    # Read the page count before closing; len() raises on a closed document
+    total_pages = len(doc)
+    doc.close()
+
+    if text_pages == total_pages:
+        return "text"
+    elif text_pages == 0:
+        return "scanned"
+    else:
+        return "mixed"
+
+
+def get_page_info(pdf_path: str | Path) -> list[dict]:
+    """
+    Get information about each page in the PDF.
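+
+    Example (illustrative output for a hypothetical single-page A4 text PDF;
+    595 x 842 points is the A4 page size):
+
+        >>> get_page_info('invoice.pdf')
+        [{'page_no': 0, 'width': 595.0, 'height': 842.0, 'has_text': True, 'char_count': 1234}]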
+ + Returns: + List of dicts with page info (number, width, height, has_text) + """ + doc = fitz.open(pdf_path) + pages = [] + + for i, page in enumerate(doc): + text = page.get_text().strip() + rect = page.rect + pages.append({ + "page_no": i, + "width": rect.width, + "height": rect.height, + "has_text": len(text) > 30, + "char_count": len(text) + }) + + doc.close() + return pages diff --git a/src/pdf/extractor.py b/src/pdf/extractor.py new file mode 100644 index 0000000..df88850 --- /dev/null +++ b/src/pdf/extractor.py @@ -0,0 +1,176 @@ +""" +PDF Text Extraction Module + +Extracts text tokens with bounding boxes from text-layer PDFs. +""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Generator +import fitz # PyMuPDF + + +@dataclass +class Token: + """Represents a text token with its bounding box.""" + text: str + bbox: tuple[float, float, float, float] # (x0, y0, x1, y1) + page_no: int + + @property + def x0(self) -> float: + return self.bbox[0] + + @property + def y0(self) -> float: + return self.bbox[1] + + @property + def x1(self) -> float: + return self.bbox[2] + + @property + def y1(self) -> float: + return self.bbox[3] + + @property + def width(self) -> float: + return self.x1 - self.x0 + + @property + def height(self) -> float: + return self.y1 - self.y0 + + @property + def center(self) -> tuple[float, float]: + return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2) + + +def extract_text_tokens( + pdf_path: str | Path, + page_no: int | None = None +) -> Generator[Token, None, None]: + """ + Extract text tokens with bounding boxes from PDF. + + Args: + pdf_path: Path to the PDF file + page_no: Specific page to extract (None for all pages) + + Yields: + Token objects with text and bbox + """ + doc = fitz.open(pdf_path) + + pages_to_process = [page_no] if page_no is not None else range(len(doc)) + + for pg_no in pages_to_process: + page = doc[pg_no] + + # Get text with position info using "dict" mode + text_dict = page.get_text("dict") + + for block in text_dict.get("blocks", []): + if block.get("type") != 0: # Skip non-text blocks + continue + + for line in block.get("lines", []): + for span in line.get("spans", []): + text = span.get("text", "").strip() + if not text: + continue + + bbox = span.get("bbox") + if bbox: + yield Token( + text=text, + bbox=tuple(bbox), + page_no=pg_no + ) + + doc.close() + + +def extract_words( + pdf_path: str | Path, + page_no: int | None = None +) -> Generator[Token, None, None]: + """ + Extract individual words with bounding boxes. + + Uses PyMuPDF's word extraction which splits text into words. + """ + doc = fitz.open(pdf_path) + + pages_to_process = [page_no] if page_no is not None else range(len(doc)) + + for pg_no in pages_to_process: + page = doc[pg_no] + + # get_text("words") returns list of (x0, y0, x1, y1, word, block_no, line_no, word_no) + words = page.get_text("words") + + for word_info in words: + x0, y0, x1, y1, text, *_ = word_info + text = text.strip() + if text: + yield Token( + text=text, + bbox=(x0, y0, x1, y1), + page_no=pg_no + ) + + doc.close() + + +def extract_lines( + pdf_path: str | Path, + page_no: int | None = None +) -> Generator[Token, None, None]: + """ + Extract text lines with bounding boxes. 
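+
+    Spans that share a line are merged into a single Token whose bbox is the
+    union of the span bboxes. Example (sketch; the PDF path is hypothetical):
+
+        >>> for line in extract_lines('invoice.pdf', page_no=0):
+        ...     print(line.text, line.bbox)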
+ """ + doc = fitz.open(pdf_path) + + pages_to_process = [page_no] if page_no is not None else range(len(doc)) + + for pg_no in pages_to_process: + page = doc[pg_no] + text_dict = page.get_text("dict") + + for block in text_dict.get("blocks", []): + if block.get("type") != 0: + continue + + for line in block.get("lines", []): + spans = line.get("spans", []) + if not spans: + continue + + # Combine all spans in the line + line_text = " ".join(s.get("text", "") for s in spans).strip() + if not line_text: + continue + + # Get line bbox from all spans + x0 = min(s["bbox"][0] for s in spans) + y0 = min(s["bbox"][1] for s in spans) + x1 = max(s["bbox"][2] for s in spans) + y1 = max(s["bbox"][3] for s in spans) + + yield Token( + text=line_text, + bbox=(x0, y0, x1, y1), + page_no=pg_no + ) + + doc.close() + + +def get_page_dimensions(pdf_path: str | Path, page_no: int = 0) -> tuple[float, float]: + """Get the dimensions of a PDF page in points.""" + doc = fitz.open(pdf_path) + page = doc[page_no] + rect = page.rect + doc.close() + return rect.width, rect.height diff --git a/src/pdf/renderer.py b/src/pdf/renderer.py new file mode 100644 index 0000000..e7bd0aa --- /dev/null +++ b/src/pdf/renderer.py @@ -0,0 +1,117 @@ +""" +PDF Rendering Module + +Converts PDF pages to images for YOLO training. +""" + +from pathlib import Path +from typing import Generator +import fitz # PyMuPDF + + +def render_pdf_to_images( + pdf_path: str | Path, + output_dir: str | Path | None = None, + dpi: int = 300, + image_format: str = "png" +) -> Generator[tuple[int, Path | bytes], None, None]: + """ + Render PDF pages to images. + + Args: + pdf_path: Path to the PDF file + output_dir: Directory to save images (if None, returns bytes) + dpi: Resolution for rendering (default 300) + image_format: Output format ('png' or 'jpg') + + Yields: + Tuple of (page_number, image_path or image_bytes) + """ + doc = fitz.open(pdf_path) + + # Calculate zoom factor for desired DPI (72 is base DPI for PDF) + zoom = dpi / 72 + matrix = fitz.Matrix(zoom, zoom) + + if output_dir: + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + pdf_name = Path(pdf_path).stem + + for page_no, page in enumerate(doc): + # Render page to pixmap + pix = page.get_pixmap(matrix=matrix) + + if output_dir: + # Save to file + ext = "jpg" if image_format.lower() in ("jpg", "jpeg") else "png" + image_path = output_dir / f"{pdf_name}_page_{page_no:03d}.{ext}" + + if ext == "jpg": + pix.save(str(image_path), "jpeg") + else: + pix.save(str(image_path)) + + yield page_no, image_path + else: + # Return bytes + if image_format.lower() in ("jpg", "jpeg"): + yield page_no, pix.tobytes("jpeg") + else: + yield page_no, pix.tobytes("png") + + doc.close() + + +def render_page_to_image( + pdf_path: str | Path, + page_no: int, + dpi: int = 300 +) -> bytes: + """ + Render a single page to image bytes. + + Args: + pdf_path: Path to the PDF file + page_no: Page number (0-indexed) + dpi: Resolution for rendering + + Returns: + PNG image bytes + """ + doc = fitz.open(pdf_path) + + if page_no >= len(doc): + doc.close() + raise ValueError(f"Page {page_no} does not exist (PDF has {len(doc)} pages)") + + zoom = dpi / 72 + matrix = fitz.Matrix(zoom, zoom) + + page = doc[page_no] + pix = page.get_pixmap(matrix=matrix) + image_bytes = pix.tobytes("png") + + doc.close() + return image_bytes + + +def get_render_dimensions(pdf_path: str | Path, page_no: int = 0, dpi: int = 300) -> tuple[int, int]: + """ + Get the dimensions of a rendered page. 
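+
+    Example: at dpi=300 the zoom factor is 300 / 72 ≈ 4.17, so an A4 page
+    (595 x 842 points) renders to about 2479 x 3508 pixels:
+
+        >>> get_render_dimensions('invoice.pdf', dpi=300)  # A4 page assumed
+        (2479, 3508)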
+ + Returns: + (width, height) in pixels + """ + doc = fitz.open(pdf_path) + page = doc[page_no] + + zoom = dpi / 72 + rect = page.rect + + width = int(rect.width * zoom) + height = int(rect.height * zoom) + + doc.close() + return width, height diff --git a/src/yolo/__init__.py b/src/yolo/__init__.py new file mode 100644 index 0000000..5ba9419 --- /dev/null +++ b/src/yolo/__init__.py @@ -0,0 +1,4 @@ +from .annotation_generator import AnnotationGenerator, generate_annotations +from .dataset_builder import DatasetBuilder + +__all__ = ['AnnotationGenerator', 'generate_annotations', 'DatasetBuilder'] diff --git a/src/yolo/annotation_generator.py b/src/yolo/annotation_generator.py new file mode 100644 index 0000000..47798a6 --- /dev/null +++ b/src/yolo/annotation_generator.py @@ -0,0 +1,281 @@ +""" +YOLO Annotation Generator + +Generates YOLO format annotations from matched fields. +""" + +from dataclasses import dataclass +from pathlib import Path +from typing import Any +import csv + + +# Field class mapping for YOLO +FIELD_CLASSES = { + 'InvoiceNumber': 0, + 'InvoiceDate': 1, + 'InvoiceDueDate': 2, + 'OCR': 3, + 'Bankgiro': 4, + 'Plusgiro': 5, + 'Amount': 6, +} + +CLASS_NAMES = [ + 'invoice_number', + 'invoice_date', + 'invoice_due_date', + 'ocr_number', + 'bankgiro', + 'plusgiro', + 'amount', +] + + +@dataclass +class YOLOAnnotation: + """Represents a single YOLO annotation.""" + class_id: int + x_center: float # normalized 0-1 + y_center: float # normalized 0-1 + width: float # normalized 0-1 + height: float # normalized 0-1 + confidence: float = 1.0 + + def to_string(self) -> str: + """Convert to YOLO format string.""" + return f"{self.class_id} {self.x_center:.6f} {self.y_center:.6f} {self.width:.6f} {self.height:.6f}" + + +class AnnotationGenerator: + """Generates YOLO annotations from document matches.""" + + def __init__( + self, + min_confidence: float = 0.7, + bbox_padding_px: int = 20, # Absolute padding in pixels + min_bbox_height_px: int = 30 # Minimum bbox height + ): + """ + Initialize annotation generator. + + Args: + min_confidence: Minimum match score to include in training + bbox_padding_px: Absolute padding in pixels to add around bboxes + min_bbox_height_px: Minimum bbox height in pixels + """ + self.min_confidence = min_confidence + self.bbox_padding_px = bbox_padding_px + self.min_bbox_height_px = min_bbox_height_px + + def generate_from_matches( + self, + matches: dict[str, list[Any]], # field_name -> list of Match + image_width: float, + image_height: float, + dpi: int = 300 + ) -> list[YOLOAnnotation]: + """ + Generate YOLO annotations from field matches. 
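+
+        Match bboxes are expected in PDF points (72 DPI) and are scaled by
+        dpi/72 to pixels before padding and normalization. E.g. at dpi=300
+        a point-space coordinate x0=100 maps to 100 * 300/72 ≈ 416.7 px.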
+ + Args: + matches: Dict of field_name -> list of Match objects + image_width: Width of the rendered image in pixels + image_height: Height of the rendered image in pixels + dpi: DPI used for rendering (needed to convert PDF coords to pixels) + + Returns: + List of YOLOAnnotation objects + """ + annotations = [] + + # Scale factor to convert PDF points (72 DPI) to rendered pixels + scale = dpi / 72.0 + + for field_name, field_matches in matches.items(): + if field_name not in FIELD_CLASSES: + continue + + class_id = FIELD_CLASSES[field_name] + + # Take only the best match per field + if field_matches: + best_match = field_matches[0] # Already sorted by score + + if best_match.score < self.min_confidence: + continue + + # best_match.bbox is in PDF points (72 DPI), convert to pixels + x0, y0, x1, y1 = best_match.bbox + x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale + + # Add absolute padding + pad = self.bbox_padding_px + x0 = max(0, x0 - pad) + y0 = max(0, y0 - pad) + x1 = min(image_width, x1 + pad) + y1 = min(image_height, y1 + pad) + + # Ensure minimum height + current_height = y1 - y0 + if current_height < self.min_bbox_height_px: + extra = (self.min_bbox_height_px - current_height) / 2 + y0 = max(0, y0 - extra) + y1 = min(image_height, y1 + extra) + + # Convert to YOLO format (normalized center + size) + x_center = (x0 + x1) / 2 / image_width + y_center = (y0 + y1) / 2 / image_height + width = (x1 - x0) / image_width + height = (y1 - y0) / image_height + + # Clamp values to 0-1 + x_center = max(0, min(1, x_center)) + y_center = max(0, min(1, y_center)) + width = max(0, min(1, width)) + height = max(0, min(1, height)) + + annotations.append(YOLOAnnotation( + class_id=class_id, + x_center=x_center, + y_center=y_center, + width=width, + height=height, + confidence=best_match.score + )) + + return annotations + + def save_annotations( + self, + annotations: list[YOLOAnnotation], + output_path: str | Path + ) -> None: + """Save annotations to a YOLO format text file.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w') as f: + for ann in annotations: + f.write(ann.to_string() + '\n') + + @staticmethod + def generate_classes_file(output_path: str | Path) -> None: + """Generate the classes.txt file for YOLO.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + with open(output_path, 'w') as f: + for class_name in CLASS_NAMES: + f.write(class_name + '\n') + + @staticmethod + def generate_yaml_config( + output_path: str | Path, + train_path: str = 'train/images', + val_path: str = 'val/images', + test_path: str = 'test/images' + ) -> None: + """Generate YOLO dataset YAML configuration.""" + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Use absolute path for WSL compatibility + dataset_dir = output_path.parent.absolute() + # Convert Windows path to WSL path if needed + dataset_path_str = str(dataset_dir).replace('\\', '/') + if dataset_path_str[1] == ':': + # Windows path like C:/... -> /mnt/c/... 
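+            # e.g. "C:/Users/me/dataset" -> "/mnt/c/Users/me/dataset"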
+ drive = dataset_path_str[0].lower() + dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}" + + config = f"""# Invoice Field Detection Dataset +path: {dataset_path_str} +train: {train_path} +val: {val_path} +test: {test_path} + +# Classes +names: +""" + for i, name in enumerate(CLASS_NAMES): + config += f" {i}: {name}\n" + + with open(output_path, 'w') as f: + f.write(config) + + +def generate_annotations( + pdf_path: str | Path, + structured_data: dict[str, Any], + output_dir: str | Path, + dpi: int = 300 +) -> list[Path]: + """ + Generate YOLO annotations for a PDF using structured data. + + Args: + pdf_path: Path to the PDF file + structured_data: Dict with field values from CSV + output_dir: Directory to save images and labels + dpi: Resolution for rendering + + Returns: + List of paths to generated annotation files + """ + from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens + from ..pdf.renderer import get_render_dimensions + from ..ocr import OCREngine + from ..matcher import FieldMatcher + from ..normalize import normalize_field + + output_dir = Path(output_dir) + images_dir = output_dir / 'images' + labels_dir = output_dir / 'labels' + images_dir.mkdir(parents=True, exist_ok=True) + labels_dir.mkdir(parents=True, exist_ok=True) + + generator = AnnotationGenerator() + matcher = FieldMatcher() + annotation_files = [] + + # Check PDF type + use_ocr = not is_text_pdf(pdf_path) + + # Initialize OCR if needed + ocr_engine = OCREngine() if use_ocr else None + + # Process each page + for page_no, image_path in render_pdf_to_images(pdf_path, images_dir, dpi=dpi): + # Get image dimensions + img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi) + + # Extract tokens + if use_ocr: + from PIL import Image + tokens = ocr_engine.extract_from_image(str(image_path), page_no) + else: + tokens = list(extract_text_tokens(pdf_path, page_no)) + + # Match fields + matches = {} + for field_name in FIELD_CLASSES.keys(): + value = structured_data.get(field_name) + if value is None or str(value).strip() == '': + continue + + normalized = normalize_field(field_name, str(value)) + field_matches = matcher.find_matches(tokens, field_name, normalized, page_no) + if field_matches: + matches[field_name] = field_matches + + # Generate annotations (pass DPI for coordinate conversion) + annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi) + + # Save annotations + if annotations: + label_path = labels_dir / f"{image_path.stem}.txt" + generator.save_annotations(annotations, label_path) + annotation_files.append(label_path) + + return annotation_files diff --git a/src/yolo/dataset_builder.py b/src/yolo/dataset_builder.py new file mode 100644 index 0000000..97bbf8d --- /dev/null +++ b/src/yolo/dataset_builder.py @@ -0,0 +1,249 @@ +""" +YOLO Dataset Builder + +Builds training dataset from PDFs and structured CSV data. 
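+
+Pipeline: load CSV rows keyed by DocumentId, locate the matching
+{DocumentId}.pdf, auto-label it via generate_annotations, then shuffle
+the documents into train/val/test splits and emit classes.txt and
+dataset.yaml alongside them.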
+""" + +import csv +import shutil +from dataclasses import dataclass +from pathlib import Path +from typing import Generator +import random + + +@dataclass +class DatasetStats: + """Statistics about the generated dataset.""" + total_documents: int + successful: int + failed: int + total_annotations: int + annotations_per_class: dict[str, int] + train_count: int + val_count: int + test_count: int + + +class DatasetBuilder: + """Builds YOLO training dataset from PDFs and CSV data.""" + + def __init__( + self, + pdf_dir: str | Path, + csv_path: str | Path, + output_dir: str | Path, + document_id_column: str = 'DocumentId', + dpi: int = 300, + train_ratio: float = 0.8, + val_ratio: float = 0.1, + test_ratio: float = 0.1 + ): + """ + Initialize dataset builder. + + Args: + pdf_dir: Directory containing PDF files + csv_path: Path to structured data CSV + output_dir: Output directory for dataset + document_id_column: Column name for document ID + dpi: Resolution for rendering + train_ratio: Fraction for training set + val_ratio: Fraction for validation set + test_ratio: Fraction for test set + """ + self.pdf_dir = Path(pdf_dir) + self.csv_path = Path(csv_path) + self.output_dir = Path(output_dir) + self.document_id_column = document_id_column + self.dpi = dpi + self.train_ratio = train_ratio + self.val_ratio = val_ratio + self.test_ratio = test_ratio + + def load_structured_data(self) -> dict[str, dict]: + """Load structured data from CSV.""" + data = {} + + with open(self.csv_path, 'r', encoding='utf-8') as f: + reader = csv.DictReader(f) + for row in reader: + doc_id = row.get(self.document_id_column) + if doc_id: + data[doc_id] = row + + return data + + def find_pdf_for_document(self, doc_id: str) -> Path | None: + """Find PDF file for a document ID.""" + # Try common naming patterns + patterns = [ + f"{doc_id}.pdf", + f"{doc_id.lower()}.pdf", + f"{doc_id.upper()}.pdf", + f"*{doc_id}*.pdf", + ] + + for pattern in patterns: + matches = list(self.pdf_dir.glob(pattern)) + if matches: + return matches[0] + + return None + + def build(self, seed: int = 42) -> DatasetStats: + """ + Build the complete dataset. 
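+
+        Documents are labeled into a temp directory first, then shuffled
+        with the given seed and split per document, so all pages of a
+        document land in the same split. E.g. with 100 documents and the
+        default 0.8/0.1/0.1 ratios: 80 train, 10 val, 10 test.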
+ + Args: + seed: Random seed for train/val/test split + + Returns: + DatasetStats with build results + """ + from .annotation_generator import AnnotationGenerator, CLASS_NAMES + + random.seed(seed) + + # Setup output directories + for split in ['train', 'val', 'test']: + (self.output_dir / split / 'images').mkdir(parents=True, exist_ok=True) + (self.output_dir / split / 'labels').mkdir(parents=True, exist_ok=True) + + # Generate config files + AnnotationGenerator.generate_classes_file(self.output_dir / 'classes.txt') + AnnotationGenerator.generate_yaml_config(self.output_dir / 'dataset.yaml') + + # Load structured data + structured_data = self.load_structured_data() + + # Stats tracking + stats = { + 'total': 0, + 'successful': 0, + 'failed': 0, + 'annotations': 0, + 'per_class': {name: 0 for name in CLASS_NAMES}, + 'splits': {'train': 0, 'val': 0, 'test': 0} + } + + # Process each document + processed_items = [] + + for doc_id, data in structured_data.items(): + stats['total'] += 1 + + pdf_path = self.find_pdf_for_document(doc_id) + if not pdf_path: + print(f"Warning: PDF not found for document {doc_id}") + stats['failed'] += 1 + continue + + try: + # Generate to temp dir first + temp_dir = self.output_dir / 'temp' / doc_id + temp_dir.mkdir(parents=True, exist_ok=True) + + from .annotation_generator import generate_annotations + annotation_files = generate_annotations( + pdf_path, data, temp_dir, self.dpi + ) + + if annotation_files: + processed_items.append({ + 'doc_id': doc_id, + 'temp_dir': temp_dir, + 'annotation_files': annotation_files + }) + stats['successful'] += 1 + else: + print(f"Warning: No annotations generated for {doc_id}") + stats['failed'] += 1 + + except Exception as e: + print(f"Error processing {doc_id}: {e}") + stats['failed'] += 1 + + # Shuffle and split + random.shuffle(processed_items) + + n_train = int(len(processed_items) * self.train_ratio) + n_val = int(len(processed_items) * self.val_ratio) + + splits = { + 'train': processed_items[:n_train], + 'val': processed_items[n_train:n_train + n_val], + 'test': processed_items[n_train + n_val:] + } + + # Move files to final locations + for split_name, items in splits.items(): + for item in items: + temp_dir = item['temp_dir'] + + # Move images + for img in (temp_dir / 'images').glob('*'): + dest = self.output_dir / split_name / 'images' / img.name + shutil.move(str(img), str(dest)) + + # Move labels and count annotations + for label in (temp_dir / 'labels').glob('*.txt'): + dest = self.output_dir / split_name / 'labels' / label.name + shutil.move(str(label), str(dest)) + + # Count annotations per class + with open(dest, 'r') as f: + for line in f: + class_id = int(line.strip().split()[0]) + if 0 <= class_id < len(CLASS_NAMES): + stats['per_class'][CLASS_NAMES[class_id]] += 1 + stats['annotations'] += 1 + + stats['splits'][split_name] += 1 + + # Cleanup temp dir + shutil.rmtree(self.output_dir / 'temp', ignore_errors=True) + + return DatasetStats( + total_documents=stats['total'], + successful=stats['successful'], + failed=stats['failed'], + total_annotations=stats['annotations'], + annotations_per_class=stats['per_class'], + train_count=stats['splits']['train'], + val_count=stats['splits']['val'], + test_count=stats['splits']['test'] + ) + + def process_single_document( + self, + doc_id: str, + data: dict, + split: str = 'train' + ) -> bool: + """ + Process a single document (for incremental building). 
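+
+        Unlike build(), this writes straight into the chosen split
+        directory and performs no shuffling or temp staging.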
+
+        Args:
+            doc_id: Document ID
+            data: Structured data dict (one CSV row)
+            split: Dataset split to write into ('train', 'val' or 'test')
+
+        Returns:
+            True if at least one annotation file was generated
+        """
+        from .annotation_generator import generate_annotations
+
+        pdf_path = self.find_pdf_for_document(doc_id)
+        if not pdf_path:
+            return False
+
+        try:
+            output_subdir = self.output_dir / split
+            annotation_files = generate_annotations(
+                pdf_path, data, output_subdir, self.dpi
+            )
+            return len(annotation_files) > 0
+        except Exception as e:
+            print(f"Error processing {doc_id}: {e}")
+            return False
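+
+
+# --- Usage sketch (illustrative only; the paths are assumptions based on
+# the repository layout, and this is not a wired-in entry point) ---
+#
+#   from src.yolo.dataset_builder import DatasetBuilder
+#
+#   builder = DatasetBuilder(
+#       pdf_dir='data/raw_pdfs',
+#       csv_path='data/structured_data/invoices.csv',
+#       output_dir='data/dataset',
+#   )
+#   stats = builder.build(seed=42)
+#   print(f"{stats.successful}/{stats.total_documents} documents labeled, "
+#         f"{stats.total_annotations} boxes "
+#         f"(train={stats.train_count}, val={stats.val_count}, "
+#         f"test={stats.test_count})")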