This commit is contained in:
Yaojia Wang
2026-01-13 00:10:27 +01:00
parent 1b7c61cdd8
commit b26fd61852
43 changed files with 7751 additions and 578 deletions

40
.claude/README.md Normal file
View File

@@ -0,0 +1,40 @@
# Claude Code Configuration
This directory contains Claude Code specific configurations.
## Configuration Files
### Main Controller
- **Location**: `../CLAUDE.md` (project root)
- **Purpose**: Main controller configuration for the Swedish Invoice Extraction System
- **Version**: v1.3.0
### Sub-Agents
Located in `agents/` directory:
- `developer.md` - Development agent
- `code-reviewer.md` - Code review agent
- `tester.md` - Testing agent
- `researcher.md` - Research agent
- `project-manager.md` - Project management agent
### Skills
Located in `skills/` directory:
- `code-generation.md` - High-quality code generation skill
## Important Notes
⚠️ **The main CLAUDE.md file is in the project root**, not in this directory.
This is intentional because:
1. CLAUDE.md is a project-level configuration
2. It should be visible alongside README.md and other important docs
3. It serves as the "constitution" for the entire project
When Claude Code starts, it will read:
1. `../CLAUDE.md` (main controller instructions)
2. Files in `agents/` (when agents are called)
3. Files in `skills/` (when skills are used)
---
For the full main controller configuration, see: [../CLAUDE.md](../CLAUDE.md)

7
.claude/config.toml Normal file
View File

@@ -0,0 +1,7 @@
[permissions]
read = true
write = true
execute = true
[permissions.scope]
paths = ["."]


13
.claude/settings.json Normal file
View File

@@ -0,0 +1,13 @@
{
"permissions": {
"allow": [
"Bash(*)",
"Read(*)",
"Write(*)",
"Edit(*)",
"Glob(*)",
"Grep(*)",
"Task(*)"
]
}
}

View File

@@ -0,0 +1,81 @@
{
"permissions": {
"allow": [
"Bash(*)",
"Bash(wsl*)",
"Bash(wsl -e bash*)",
"Read(*)",
"Write(*)",
"Edit(*)",
"Glob(*)",
"Grep(*)",
"WebFetch(*)",
"WebSearch(*)",
"Task(*)",
"Bash(wsl -e bash -c:*)",
"Bash(powershell -c:*)",
"Bash(dir \"C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2\\\\runs\\\\detect\\\\runs\\\\train\\\\invoice_fields_v3\"\")",
"Bash(timeout:*)",
"Bash(powershell:*)",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && nvidia-smi 2>/dev/null | head -10\")",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && nohup python3 -m src.cli.train --data data/dataset/dataset.yaml --model yolo11s.pt --epochs 100 --batch 8 --device 0 --name invoice_fields_v4 > training.log 2>&1 &\")",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sleep 10 && tail -20 training.log 2>/dev/null\":*)",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && cat training.log 2>/dev/null | head -30\")",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la training.log 2>/dev/null && ps aux | grep python\")",
"Bash(wsl -e bash -c \"ps aux | grep -E ''python|train''\")",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -m src.cli.train --data data/dataset/dataset.yaml --model yolo11s.pt --epochs 100 --batch 8 --device 0 --name invoice_fields_v4 2>&1 | tee training.log &\")",
"Bash(wsl -e bash -c \"sleep 15 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && tail -15 training.log\":*)",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -m src.cli.train --data data/dataset/dataset.yaml --model yolo11s.pt --epochs 100 --batch 8 --device 0 --name invoice_fields_v4\")",
"Bash(wsl -e bash -c \"which screen || sudo apt-get install -y screen 2>/dev/null\")",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\nfrom ultralytics import YOLO\n\n# Load model\nmodel = YOLO\\(''runs/detect/runs/train/invoice_fields_v4/weights/best.pt''\\)\n\n# Run inference on a test image\nresults = model.predict\\(\n ''data/dataset/test/images/36a4fd23-0a66-4428-9149-4f95c93db9cb_page_000.png'',\n conf=0.5,\n save=True,\n project=''results'',\n name=''test_inference''\n\\)\n\n# Print results\nfor r in results:\n print\\(''Image:'', r.path\\)\n print\\(''Boxes:'', len\\(r.boxes\\)\\)\n for box in r.boxes:\n cls = int\\(box.cls[0]\\)\n conf = float\\(box.conf[0]\\)\n name = model.names[cls]\n print\\(f'' - {name}: {conf:.2%}''\\)\n\"\"\")",
"Bash(python:*)",
"Bash(dir:*)",
"Bash(timeout 180 tail:*)",
"Bash(python3:*)",
"Bash(wsl -d Ubuntu-22.04 -- bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source .venv/bin/activate && python -m src.cli.autolabel --csv ''data/structured_data/document_export_20260110_141554_page1.csv,data/structured_data/document_export_20260110_141612_page2.csv'' --report reports/autolabel_test_2csv.jsonl --workers 2\")",
"Bash(wsl -d Ubuntu-22.04 -- bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice && python -m src.cli.autolabel --csv ''data/structured_data/document_export_20260110_141554_page1.csv,data/structured_data/document_export_20260110_141612_page2.csv'' --report reports/autolabel_test_2csv.jsonl --workers 2\")",
"Bash(wsl -d Ubuntu-22.04 -- bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda info --envs\":*)",
"Bash(wsl:*)",
"Bash(for f in reports/autolabel_shard_test_part*.jsonl)",
"Bash(done)",
"Bash(timeout 10 tail:*)",
"Bash(more:*)",
"Bash(cmd /c type \"C:\\\\Users\\\\yaoji\\\\AppData\\\\Local\\\\Temp\\\\claude\\\\c--Users-yaoji-git-ColaCoder-invoice-master-poc-v2\\\\tasks\\\\b4d8070.output\")",
"Bash(cmd /c \"dir C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2\\\\reports\\\\autolabel_report_v4*.jsonl\")",
"Bash(wsl wc:*)",
"Bash(wsl bash:*)",
"Bash(wsl bash -c \"ps aux | grep python | grep -v grep\")",
"Bash(wsl bash -c \"kill -9 130864 130870 414045 414046\")",
"Bash(wsl bash -c \"ps aux | grep python | grep -v grep | grep -v networkd | grep -v unattended\")",
"Bash(wsl bash -c \"kill -9 414046 2>/dev/null; pkill -9 -f autolabel 2>/dev/null; sleep 1; ps aux | grep autolabel | grep -v grep\")",
"Bash(python -m src.cli.import_report_to_db:*)",
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && pip install psycopg2-binary\")",
"Bash(conda activate:*)",
"Bash(python -m src.cli.analyze_report:*)",
"Bash(/c/Users/yaoji/miniconda3/envs/yolo11/python.exe -m src.cli.analyze_report:*)",
"Bash(c:/Users/yaoji/miniconda3/envs/yolo11/python.exe -m src.cli.analyze_report:*)",
"Bash(cmd /c \"cd /d C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2 && C:\\\\Users\\\\yaoji\\\\miniconda3\\\\envs\\\\yolo11\\\\python.exe -m src.cli.analyze_report\")",
"Bash(cmd /c \"cd /d C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2 && C:\\\\Users\\\\yaoji\\\\miniconda3\\\\envs\\\\yolo11\\\\python.exe -m src.cli.analyze_report 2>&1\")",
"Bash(powershell -Command \"cd C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2; C:\\\\Users\\\\yaoji\\\\miniconda3\\\\envs\\\\yolo11\\\\python.exe -m src.cli.analyze_report\":*)",
"Bash(powershell -Command \"ls C:\\\\Users\\\\yaoji\\\\miniconda3\\\\envs\"\")",
"Bash(where:*)",
"Bash(\"C:/Users/yaoji/anaconda3/envs/invoice-master/python.exe\" -c \"import psycopg2; print\\(''psycopg2 OK''\\)\")",
"Bash(\"C:/Users/yaoji/anaconda3/envs/torch-gpu/python.exe\" -c \"import psycopg2; print\\(''psycopg2 OK''\\)\")",
"Bash(\"C:/Users/yaoji/anaconda3/python.exe\" -c \"import psycopg2; print\\(''psycopg2 OK''\\)\")",
"Bash(\"C:/Users/yaoji/anaconda3/envs/invoice-master/python.exe\" -m pip install psycopg2-binary)",
"Bash(\"C:/Users/yaoji/anaconda3/envs/invoice-master/python.exe\" -m src.cli.analyze_report)",
"Bash(wsl -d Ubuntu bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && conda activate invoice-master && python -m src.cli.autolabel --help\":*)",
"Bash(wsl -d Ubuntu-22.04 bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-extract && python -m src.cli.train --export-only 2>&1\")",
"Bash(wsl -d Ubuntu-22.04 bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la data/dataset/\")",
"Bash(wsl -d Ubuntu-22.04 bash -c \"ps aux | grep python | grep -v grep\")",
"Bash(wsl -d Ubuntu-22.04 bash -c:*)",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la\")",
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-master && python -c \"\"\nimport sys\nsys.path.insert\\(0, ''.''\\)\nfrom src.data.db import DocumentDB\nfrom src.yolo.db_dataset import DBYOLODataset\n\n# Connect to database\ndb = DocumentDB\\(\\)\ndb.connect\\(\\)\n\n# Create dataset\ndataset = DBYOLODataset\\(\n images_dir=''data/dataset'',\n db=db,\n split=''train'',\n train_ratio=0.8,\n val_ratio=0.1,\n seed=42,\n dpi=300\n\\)\n\nprint\\(f''Dataset size: {len\\(dataset\\)}''\\)\n\nif len\\(dataset\\) > 0:\n # Check first few items\n for i in range\\(min\\(3, len\\(dataset\\)\\)\\):\n item = dataset.items[i]\n print\\(f''\\\\n--- Item {i} ---''\\)\n print\\(f''Document: {item.document_id}''\\)\n print\\(f''Is scanned: {item.is_scanned}''\\)\n print\\(f''Image: {item.image_path.name}''\\)\n \n # Get YOLO labels\n yolo_labels = dataset.get_labels_for_yolo\\(i\\)\n print\\(f''YOLO labels:''\\)\n for line in yolo_labels.split\\(''\\\\n''\\)[:3]:\n print\\(f'' {line}''\\)\n # Check if values are normalized\n parts = line.split\\(\\)\n if len\\(parts\\) == 5:\n x, y, w, h = float\\(parts[1]\\), float\\(parts[2]\\), float\\(parts[3]\\), float\\(parts[4]\\)\n if x > 1 or y > 1 or w > 1 or h > 1:\n print\\(f'' WARNING: Values not normalized!''\\)\n elif x == 1.0 or y == 1.0:\n print\\(f'' WARNING: Values clamped to 1.0!''\\)\n else:\n print\\(f'' OK: Values properly normalized''\\)\n\ndb.close\\(\\)\n\"\"\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/\")",
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")"
],
"deny": [],
"ask": [],
"defaultMode": "default"
}
}

9
.gitignore vendored
View File

@@ -34,13 +34,8 @@ env/
 *~
 # Data files (large files)
-data/raw_pdfs/
-data/dataset/train/images/
-data/dataset/val/images/
-data/dataset/test/images/
-data/dataset/train/labels/
-data/dataset/val/labels/
-data/dataset/test/labels/
+/data/
+/results/
 *.pdf
 *.png
 *.jpg

452
README.md
View File

@@ -1,90 +1,62 @@
# Invoice Master POC v2
Automated invoice information extraction system - extracts structured data from Swedish PDF invoices using YOLOv11 + PaddleOCR.
## Runtime Environment
This project **must** run in a **WSL + Conda** environment.
### System Requirements
| Component | Requirement |
|------|------|
| **WSL** | WSL 2 + Ubuntu 22.04 |
| **Conda** | Miniconda or Anaconda |
| **Python** | 3.10+ (managed via Conda) |
| **GPU** | NVIDIA GPU + CUDA 12.x (strongly recommended) |
| **Database** | PostgreSQL (stores labeling results) |
## Features
- **Dual-mode PDF processing**: supports both text-layer PDFs and scanned PDFs
- **Auto-labeling**: automatically generates YOLO training data from existing structured CSV data
- **Multi-pool processing architecture**: a CPU pool handles text PDFs, a GPU pool handles scanned PDFs
- **Database storage**: labeling results are stored in PostgreSQL, with support for incremental processing
- **YOLO detection**: uses YOLOv11 to detect invoice field regions
- **OCR recognition**: uses PaddleOCR 3.x to extract text from detected regions
- **Web application**: provides a REST API and a visualization UI
- **Incremental training**: supports continuing training from an already trained model
## Supported Fields
| Class ID | Field | Description |
|---------|--------|------|
| 0 | invoice_number | Invoice number |
| 1 | invoice_date | Invoice date |
| 2 | invoice_due_date | Due date |
| 3 | ocr_number | OCR reference number (Swedish payment system) |
| 4 | bankgiro | Bankgiro number |
| 5 | plusgiro | Plusgiro number |
| 6 | amount | Amount |
## Installation
```bash
# 1. Enter WSL
wsl -d Ubuntu-22.04
# 2. Create the Conda environment
conda create -n invoice-py311 python=3.11 -y
conda activate invoice-py311
# 3. Enter the project directory
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
# 4. Install dependencies
pip install -r requirements.txt
# 5. Install web dependencies
pip install uvicorn fastapi python-multipart pydantic
```
## Quick Start
@@ -92,12 +64,14 @@ pip install paddlepaddle-gpu==2.5.2.post118 -f https://www.paddlepaddle.org.cn/w
### 1. Prepare the Data
```
~/invoice-data/
├── raw_pdfs/
│   ├── {DocumentId}.pdf
│   └── ...
├── structured_data/
│   └── document_export_YYYYMMDD.csv
├── dataset/
└── temp/          (rendered images)
```
CSV format:
@@ -109,118 +83,336 @@ DocumentId,InvoiceDate,InvoiceNumber,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount
### 2. Auto-Labeling
```bash
# Dual-pool mode (CPU + GPU)
python -m src.cli.autolabel \
    --dual-pool \
    --cpu-workers 3 \
    --gpu-workers 1
# Single-pool mode
python -m src.cli.autolabel --workers 4
```
### 3. Train the Model
```bash
# Train starting from a pretrained model
python -m src.cli.train \
    --model yolo11n.pt \
    --epochs 100 \
    --batch 16 \
    --name invoice_yolo11n_full \
    --dpi 150
```
### 4. Incremental Training
After adding new data, training can continue from an already trained model:
```bash
# Continue training from an existing best.pt
python -m src.cli.train \
    --model runs/train/invoice_yolo11n_full/weights/best.pt \
    --epochs 30 \
    --batch 16 \
    --name invoice_yolo11n_v2 \
    --dpi 150
```
**Incremental training recommendations**:
| Scenario | Recommendation |
|------|------|
| Small amount of new data added (<20%) | Continue training for 10-30 epochs |
| Large amount of new data added (>50%) | Continue training for 50-100 epochs |
| Many labeling errors corrected | Train from scratch |
| New field types added | Train from scratch |
### 5. Inference
```bash
# Command-line inference
python -m src.cli.infer \
    --model runs/train/invoice_yolo11n_full/weights/best.pt \
    --input path/to/invoice.pdf \
    --output result.json \
    --gpu
```
### 6. Web Application
```bash
# Start the web server
python run_server.py --port 8000
# Development mode (auto reload)
python run_server.py --debug --reload
# Disable GPU
python run_server.py --no-gpu
```
Visit **http://localhost:8000** to use the web UI.
#### Web API Endpoints
| Method | Endpoint | Description |
|------|------|------|
| GET | `/` | Web UI |
| GET | `/api/v1/health` | Health check |
| POST | `/api/v1/infer` | Upload a file and run inference |
| GET | `/api/v1/results/{filename}` | Fetch the visualization image |
## Training Configuration
### YOLO Training Parameters
```bash
python -m src.cli.train [OPTIONS]
Options:
  --model, -m     Base model (default: yolo11n.pt)
  --epochs, -e    Number of epochs (default: 100)
  --batch, -b     Batch size (default: 16)
  --imgsz         Image size (default: 1280)
  --dpi           PDF render DPI (default: 150)
  --name          Training run name
  --limit         Limit the number of documents (for testing)
  --device        Device (0=GPU, cpu)
```
### Training Best Practices
1. **Disable flip augmentation** (text detection):
```python
fliplr=0.0, flipud=0.0
```
2. **Use early stopping**:
```python
patience=20
```
3. **Enable AMP** (mixed-precision training):
```python
amp=True
```
4. **Save checkpoints**:
```python
save_period=10
```
### Example Training Results
Results after 100 epochs on 15,571 training images:
| Metric | Value |
|------|-----|
| **mAP@0.5** | 98.7% |
| **mAP@0.5-0.95** | 87.4% |
| **Precision** | 97.5% |
| **Recall** | 95.5% |
## Project Structure
```
invoice-master-poc-v2/
├── src/
│   ├── cli/                    # Command-line tools
│   │   ├── autolabel.py        # Auto-labeling
│   │   ├── train.py            # Model training
│   │   ├── infer.py            # Inference
│   │   └── serve.py            # Web server
│   ├── pdf/                    # PDF processing
│   │   ├── extractor.py        # Text extraction
│   │   ├── renderer.py         # Image rendering
│   │   └── detector.py         # Type detection
│   ├── ocr/                    # PaddleOCR wrapper
│   ├── normalize/              # Field normalization
│   ├── matcher/                # Field matching
│   ├── yolo/                   # YOLO-related code
│   │   ├── annotation_generator.py
│   │   └── db_dataset.py
│   ├── inference/              # Inference pipeline
│   │   ├── pipeline.py
│   │   ├── yolo_detector.py
│   │   └── field_extractor.py
│   ├── processing/             # Multi-pool processing architecture
│   │   ├── worker_pool.py
│   │   ├── cpu_pool.py
│   │   ├── gpu_pool.py
│   │   ├── task_dispatcher.py
│   │   └── dual_pool_coordinator.py
│   ├── web/                    # Web application
│   │   ├── app.py              # FastAPI app
│   │   ├── routes.py           # API routes
│   │   ├── services.py         # Business logic
│   │   ├── schemas.py          # Data models
│   │   └── config.py           # Configuration
│   └── data/                   # Data handling
├── config.py                   # Configuration file
├── run_server.py               # Web server launcher
├── runs/                       # Training output
│   └── train/
│       └── invoice_yolo11n_full/
│           └── weights/
│               ├── best.pt
│               └── last.pt
└── requirements.txt
```
## Multi-Pool Processing Architecture
The project uses a dual CPU + GPU pool architecture to handle different PDF types:
```
┌─────────────────────────────────────────────────────┐
│ DualPoolCoordinator │
│ ┌─────────────────┐ ┌─────────────────┐ │
│ │ CPU Pool │ │ GPU Pool │ │
│ │ (3 workers) │ │ (1 worker) │ │
│ │ │ │ │ │
│ │ Text PDFs │ │ Scanned PDFs │ │
│ │ ~50-87 it/s │ │ ~1-2 it/s │ │
│ └─────────────────┘ └─────────────────┘ │
│ │
│ TaskDispatcher: routes tasks by PDF type │
└─────────────────────────────────────────────────────┘
```
### Key Design Points
- **spawn start method**: compatible with CUDA multiprocessing
- **as_completed()**: deadlock-free result collection
- **Process initializer**: each worker loads the model once
- **Coordinator persistence**: worker pools are reused across CSV files
## Configuration Files
### config.py
```python
# Database configuration
DATABASE = {
    'host': '192.168.68.31',
    'port': 5432,
    'database': 'docmaster',
    'user': 'docmaster',
    'password': '******',
}
# Path configuration
PATHS = {
    'csv_dir': '~/invoice-data/structured_data',
    'pdf_dir': '~/invoice-data/raw_pdfs',
    'output_dir': '~/invoice-data/dataset',
}
```
## CLI Command Reference
### autolabel
```bash
python -m src.cli.autolabel [OPTIONS]
Options:
  --csv, -c        CSV file path (glob supported)
  --pdf-dir, -p    Directory containing PDF files
  --output, -o     Output directory
  --workers, -w    Worker count in single-pool mode (default: 4)
  --dual-pool      Enable dual-pool mode
  --cpu-workers    CPU pool worker count (default: 3)
  --gpu-workers    GPU pool worker count (default: 1)
  --dpi            Render DPI (default: 150)
  --limit, -l      Limit the number of documents processed
```
### train
```bash
python -m src.cli.train [OPTIONS]
Options:
  --model, -m      Base model path
  --epochs, -e     Number of epochs (default: 100)
  --batch, -b      Batch size (default: 16)
  --imgsz          Image size (default: 1280)
  --dpi            PDF render DPI (default: 150)
  --name           Training run name
  --limit          Limit the number of documents
```
### infer
```bash
python -m src.cli.infer [OPTIONS]
Options:
  --model, -m      Model path
  --input, -i      Input PDF/image
  --output, -o     Output JSON path
  --confidence     Confidence threshold (default: 0.5)
  --dpi            Render DPI (default: 300)
  --gpu            Use the GPU
```
### serve
```bash
python run_server.py [OPTIONS]
Options:
  --host           Bind address (default: 0.0.0.0)
  --port           Port (default: 8000)
  --model, -m      Model path
  --confidence     Confidence threshold (default: 0.3)
  --dpi            Render DPI (default: 150)
  --no-gpu         Disable the GPU
  --reload         Auto-reload in development mode
  --debug          Debug mode
```
## Python API
```python
from src.inference import InferencePipeline
# Initialize
pipeline = InferencePipeline(
    model_path='runs/train/invoice_yolo11n_full/weights/best.pt',
    confidence_threshold=0.3,
    use_gpu=True,
    dpi=150
)
# Process a PDF
result = pipeline.process_pdf('invoice.pdf')
# Process an image
result = pipeline.process_image('invoice.png')
# Get the results
print(result.fields)      # {'InvoiceNumber': '12345', 'Amount': '1234.56', ...}
print(result.confidence)  # {'InvoiceNumber': 0.95, 'Amount': 0.92, ...}
print(result.to_json())   # JSON output
```
## Development Status
- [x] Text-layer PDF auto-labeling
- [x] Scanned-PDF OCR auto-labeling
- [x] Multi-pool processing architecture (CPU + GPU)
- [x] PostgreSQL database storage
- [x] YOLO training (98.7% mAP@0.5)
- [x] Inference pipeline
- [x] Field normalization and validation
- [x] Web application (FastAPI + frontend UI)
- [x] Incremental training support
- [ ] Table line-item processing
- [ ] Model quantization and deployment
## License
MIT License

216
claude.md Normal file
View File

@@ -0,0 +1,216 @@
# Claude Code Instructions - Invoice Master POC v2
## Environment Requirements
> **IMPORTANT**: This project MUST run in **WSL + Conda** environment.
| Requirement | Details |
|-------------|---------|
| **WSL** | WSL 2 with Ubuntu 22.04+ |
| **Conda** | Miniconda or Anaconda |
| **Python** | 3.10+ (managed by Conda) |
| **GPU** | NVIDIA drivers on Windows + CUDA in WSL |
```bash
# Verify environment before running any commands
uname -a # Should show "Linux"
conda --version # Should show conda version
conda activate <env> # Activate project environment
which python # Should point to conda environment
```
**All commands must be executed in WSL terminal with Conda environment activated.**
---
## Project Overview
**Automated invoice field extraction system** for Swedish PDF invoices:
- **YOLO Object Detection** (YOLOv8/v11) for field region detection
- **PaddleOCR** for text extraction
- **Multi-strategy matching** for field validation
**Stack**: Python 3.10+ | PyTorch | Ultralytics | PaddleOCR | PyMuPDF
**Target Fields**: InvoiceNumber, InvoiceDate, InvoiceDueDate, OCR, Bankgiro, Plusgiro, Amount
---
## Architecture Principles
### SOLID
- **Single Responsibility**: Each module handles one concern
- **Open/Closed**: Extend via new strategies, not modifying existing code
- **Liskov Substitution**: Use Protocol/ABC for interchangeable components
- **Interface Segregation**: Small, focused interfaces
- **Dependency Inversion**: Depend on abstractions, inject dependencies
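To make the Protocol and dependency-injection points concrete, a minimal illustrative sketch follows (the `Matcher`, `ExactMatcher`, and `FieldPipeline` names are hypothetical, not taken from the codebase):
```python
from typing import Protocol


class Matcher(Protocol):
    """Any matching strategy the pipeline can accept."""
    def find(self, tokens: list[str], value: str) -> list[int]: ...


class ExactMatcher:
    def find(self, tokens: list[str], value: str) -> list[int]:
        # Return indices of tokens equal to the target value.
        return [i for i, t in enumerate(tokens) if t == value]


class FieldPipeline:
    def __init__(self, matcher: Matcher) -> None:
        # Depend on the abstraction; the concrete strategy is injected.
        self._matcher = matcher

    def run(self, tokens: list[str], value: str) -> list[int]:
        return self._matcher.find(tokens, value)


pipeline = FieldPipeline(matcher=ExactMatcher())
```
New strategies can then be added as new classes satisfying `Matcher`, without modifying `FieldPipeline`.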
### Project Structure
```
src/
├── cli/ # Entry points only, no business logic
├── pdf/ # PDF processing (extraction, rendering, detection)
├── ocr/ # OCR engines (PaddleOCR wrapper)
├── normalize/ # Field normalization and validation
├── matcher/ # Multi-strategy field matching
├── yolo/ # YOLO annotation and dataset building
├── inference/ # Inference pipeline
└── data/ # Data loading and reporting
```
### Configuration
- `configs/default.yaml` — All tunable parameters
- `config.py` — Sensitive data (credentials, use environment variables)
- Never hardcode magic numbers
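A small sketch of the environment-variable point for sensitive data (the variable names are illustrative):
```python
import os

# Prefer environment variables for credentials; fall back to safe defaults.
DATABASE = {
    "host": os.environ.get("INVOICE_DB_HOST", "localhost"),
    "port": int(os.environ.get("INVOICE_DB_PORT", "5432")),
    "database": os.environ.get("INVOICE_DB_NAME", "docmaster"),
    "user": os.environ.get("INVOICE_DB_USER", "docmaster"),
    "password": os.environ.get("INVOICE_DB_PASSWORD", ""),
}
```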
---
## Python Standards
### Required
- **Type hints** on all public functions (PEP 484/585)
- **Docstrings** in Google style (PEP 257)
- **Dataclasses** for data structures (`frozen=True, slots=True` when immutable)
- **Protocol** for interfaces (PEP 544)
- **Enum** for constants
- **pathlib.Path** instead of string paths
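A short sketch combining several of these requirements (type hints, a frozen dataclass, an Enum, and `pathlib.Path`); all names here are illustrative:
```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from pathlib import Path


class FieldType(Enum):
    INVOICE_NUMBER = "InvoiceNumber"
    AMOUNT = "Amount"


@dataclass(frozen=True, slots=True)
class LabeledBox:
    """One labeled region on a rendered page."""
    field: FieldType
    page: int
    bbox: tuple[float, float, float, float]


def label_path(dataset_dir: Path, doc_id: str, page: int) -> Path:
    """Build the YOLO label path for a document page."""
    return dataset_dir / "labels" / f"{doc_id}_page_{page:03d}.txt"
```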
### Naming Conventions
| Type | Convention | Example |
|------|------------|---------|
| Functions/Variables | snake_case | `extract_tokens`, `page_count` |
| Classes | PascalCase | `FieldMatcher`, `AutoLabelReport` |
| Constants | UPPER_SNAKE | `DEFAULT_DPI`, `FIELD_TYPES` |
| Private | _prefix | `_parse_date`, `_cache` |
### Import Order (isort)
1. `from __future__ import annotations`
2. Standard library
3. Third-party
4. Local modules
5. `if TYPE_CHECKING:` block
### Code Quality Tools
| Tool | Purpose | Config |
|------|---------|--------|
| Black | Formatting | line-length=100 |
| Ruff | Linting | E, F, W, I, N, D, UP, B, C4, SIM, ARG, PTH |
| MyPy | Type checking | strict=true |
| Pytest | Testing | tests/ directory |
---
## Error Handling
- Use **custom exception hierarchy** (base: `InvoiceMasterError`)
- Use **logging** instead of print (`logger = logging.getLogger(__name__)`)
- Implement **graceful degradation** with fallback strategies
- Use **context managers** for resource cleanup
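A minimal sketch of these rules, assuming the `InvoiceMasterError` base named above; the `PdfExtractionError` subclass and `open_pdf` helper are illustrative:
```python
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class InvoiceMasterError(Exception):
    """Base class for all project-specific errors."""


class PdfExtractionError(InvoiceMasterError):
    """Raised when a PDF cannot be parsed."""


@contextmanager
def open_pdf(path):
    """Context manager that guarantees the PDF handle is released."""
    import fitz  # PyMuPDF

    doc = fitz.open(path)
    try:
        yield doc
    except RuntimeError as exc:  # never catch bare `except:`
        logger.error("Failed to read %s: %s", path, exc)
        raise PdfExtractionError(str(exc)) from exc
    finally:
        doc.close()
```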
---
## Machine Learning Standards
### Data Management
- **Immutable raw data**: Never modify `data/raw/`
- **Version datasets**: Track with checksum and metadata
- **Reproducible splits**: Use fixed random seed (42)
- **Split ratios**: 80% train / 10% val / 10% test
### YOLO Training
- **Disable flips** for text detection (`fliplr=0.0, flipud=0.0`)
- **Use early stopping** (`patience=20`)
- **Enable AMP** for faster training (`amp=True`)
- **Save checkpoints** periodically (`save_period=10`); a combined training call is sketched below
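A sketch of how those settings could be combined in a single Ultralytics training call (the dataset YAML path is a placeholder; the keyword arguments are standard Ultralytics train settings):
```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.train(
    data="data/dataset/dataset.yaml",  # placeholder dataset config
    epochs=100,
    imgsz=1280,
    fliplr=0.0,      # no horizontal flips for text detection
    flipud=0.0,      # no vertical flips either
    patience=20,     # early stopping
    amp=True,        # mixed-precision training
    save_period=10,  # periodic checkpoints
    seed=42,
)
```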
### Reproducibility
- Set random seeds: `random`, `numpy`, `torch`
- Enable deterministic mode: `torch.backends.cudnn.deterministic = True`
- Track experiment config: model, epochs, batch_size, learning_rate, dataset_version, git_commit
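As a sketch, the seeding steps above could live in one helper (the `set_seed` name is illustrative):
```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed all RNGs used in the project for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade some speed for determinism on CUDA.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
```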
### Evaluation Metrics
- Precision, Recall, F1 Score
- mAP@0.5, mAP@0.5:0.95
- Per-class AP
---
## Testing Standards
### Structure
```
tests/
├── unit/ # Isolated, fast tests
├── integration/ # Multi-module tests
├── e2e/ # End-to-end workflow tests
├── fixtures/ # Test data
└── conftest.py # Shared fixtures
```
### Practices
- Follow **AAA pattern**: Arrange, Act, Assert
- Use **parametrized tests** for multiple inputs
- Use **fixtures** for shared setup
- Use **mocking** for external dependencies
- Mark slow tests with `@pytest.mark.slow`
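A minimal AAA-style parametrized test as a sketch; `add_vat` is a hypothetical helper used only to show the structure:
```python
import pytest


def add_vat(net: float, rate: float = 0.25) -> float:
    """Hypothetical helper used only to illustrate the test style."""
    return round(net * (1 + rate), 2)


class TestAddVat:
    @pytest.mark.parametrize(
        ("net", "rate", "expected"),
        [(100.0, 0.25, 125.0), (80.0, 0.12, 89.6), (0.0, 0.25, 0.0)],
    )
    def test_known_rates(self, net: float, rate: float, expected: float) -> None:
        # Arrange is handled by parametrize; Act:
        result = add_vat(net, rate)
        # Assert:
        assert result == pytest.approx(expected)
```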
---
## Performance
- **Parallel processing**: Use `ProcessPoolExecutor` with progress tracking
- **Lazy loading**: Use `@cached_property` for expensive resources
- **Generators**: Use for large datasets to save memory
- **Batch processing**: Process items in batches when possible
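A compact sketch of the lazy-loading and generator points (the `OcrEngine` and `iter_pdfs` names are hypothetical):
```python
from functools import cached_property
from pathlib import Path
from typing import Iterator


class OcrEngine:
    @cached_property
    def model(self):
        # Loaded on first access only, then reused for later calls.
        from paddleocr import PaddleOCR
        return PaddleOCR(lang="en")


def iter_pdfs(pdf_dir: Path) -> Iterator[Path]:
    """Yield PDFs one at a time instead of building a full list in memory."""
    yield from sorted(pdf_dir.glob("*.pdf"))
```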
---
## Security
- **Never commit**: credentials, API keys, `.env` files
- **Use environment variables** for sensitive config
- **Validate paths**: Prevent path traversal attacks
- **Validate inputs**: At system boundaries
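For the path-validation point, a common sketch (the `resolve_under` helper is illustrative; it assumes Python 3.9+ for `Path.is_relative_to`):
```python
from pathlib import Path


def resolve_under(base_dir: Path, user_supplied: str) -> Path:
    """Resolve a user-supplied path and reject anything outside base_dir."""
    base = base_dir.resolve()
    candidate = (base / user_supplied).resolve()
    if not candidate.is_relative_to(base):
        raise ValueError(f"Path escapes {base}: {user_supplied}")
    return candidate
```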
---
## Commands
| Task | Command |
|------|---------|
| Run autolabel | `python run_autolabel.py` |
| Train YOLO | `python -m src.cli.train --config configs/training.yaml` |
| Run inference | `python -m src.cli.infer --model models/best.pt` |
| Run tests | `pytest tests/ -v` |
| Coverage | `pytest tests/ --cov=src --cov-report=html` |
| Format | `black src/ tests/` |
| Lint | `ruff check src/ tests/ --fix` |
| Type check | `mypy src/` |
---
## DO NOT
- Hardcode file paths or magic numbers
- Use `print()` for logging
- Skip type hints on public APIs
- Write functions longer than 50 lines
- Mix business logic with I/O
- Commit credentials or `.env` files
- Use `# type: ignore` without explanation
- Use mutable default arguments
- Catch bare `except:`
- Use flip augmentation for text detection
## DO
- Use type hints everywhere
- Write descriptive docstrings
- Log with appropriate levels
- Use dataclasses for data structures
- Use enums for constants
- Use Protocol for interfaces
- Set random seeds for reproducibility
- Track experiment configurations
- Use context managers for resources
- Validate inputs at boundaries

64
config.py Normal file
View File

@@ -0,0 +1,64 @@
"""
Configuration settings for the invoice extraction system.
"""
import os
import platform
def _is_wsl() -> bool:
"""Check if running inside WSL (Windows Subsystem for Linux)."""
if platform.system() != 'Linux':
return False
# Check for WSL-specific indicators
if os.environ.get('WSL_DISTRO_NAME'):
return True
try:
with open('/proc/version', 'r') as f:
return 'microsoft' in f.read().lower()
except (FileNotFoundError, PermissionError):
return False
# PostgreSQL Database Configuration
DATABASE = {
'host': '192.168.68.31',
'port': 5432,
'database': 'docmaster',
'user': 'docmaster',
'password': '0412220',
}
# Connection string for psycopg2
def get_db_connection_string():
return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"
# Paths Configuration - auto-detect WSL vs Windows
if _is_wsl():
# WSL: use native Linux filesystem for better I/O performance
PATHS = {
'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
'output_dir': os.path.expanduser('~/invoice-data/dataset'),
'reports_dir': 'reports', # Keep reports in project directory
}
else:
# Windows or native Linux: use relative paths
PATHS = {
'csv_dir': 'data/structured_data',
'pdf_dir': 'data/raw_pdfs',
'output_dir': 'data/dataset',
'reports_dir': 'reports',
}
# Auto-labeling Configuration
AUTOLABEL = {
'workers': 2,
'dpi': 150,
'min_confidence': 0.5,
'train_ratio': 0.8,
'val_ratio': 0.1,
'test_ratio': 0.1,
'max_records_per_report': 10000,
}

619
docs/multi_pool_design.md Normal file
View File

@@ -0,0 +1,619 @@
# Multi-Pool Processing Architecture Design Document
## 1. Research Summary
### 1.1 Analysis of Current Problems
The dual-pool mode implemented previously has stability problems, mainly for these reasons:
| Problem | Cause | Solution |
|------|------|----------|
| Processing hangs | Mixing threads with ProcessPoolExecutor causes deadlocks | Use asyncio or a pure Queue pattern |
| Queue.get() blocks forever | No timeout mechanism | Add timeouts and sentinel values |
| GPU memory conflicts | Multiple processes access the GPU simultaneously | Limit GPU workers to 1 |
| CUDA fork issue | Linux's default fork start method is incompatible with CUDA | Use the spawn start method |
### 1.2 Recommended Architecture
After research, the approach best suited to our scenario is the **producer-consumer queue pattern**:
```
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│ Main Process │ │ CPU Workers │ │ GPU Worker │
│ │ │ (4 processes) │ │ (1 process) │
│ ┌───────────┐ │ │ │ │ │
│ │ Task │──┼────▶│ Text PDF tasks │ │ Scanned PDF OCR │
│ │ Dispatcher│ │ │ (no OCR needed) │ │ (PaddleOCR) │
│ └───────────┘ │ │ │ │ │
│ ▲ │ │ │ │ │ │ │
│ │ │ │ ▼ │ │ ▼ │
│ ┌───────────┐ │ │ Result Queue │ │ Result Queue │
│ │ Result │◀─┼─────│◀────────────────│─────│◀────────────────│
│ │ Collector │ │ │ │ │ │
│ └───────────┘ │ └─────────────────┘ └─────────────────┘
│ │ │
│ ▼ │
│ ┌───────────┐ │
│ │ Database │ │
│ │ Batch │ │
│ │ Writer │ │
│ └───────────┘ │
└─────────────────┘
```
---
## 2. Core Design Principles
### 2.1 CUDA Compatibility
```python
# Key point: use the spawn start method
import multiprocessing as mp
ctx = mp.get_context("spawn")
# GPU worker 初始化时设置设备
def init_gpu_worker(gpu_id: int = 0):
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
global _ocr
from paddleocr import PaddleOCR
_ocr = PaddleOCR(use_gpu=True, ...)
```
### 2.2 Worker Initialization Pattern
Load the model once per worker via the `initializer` argument, instead of reloading it for every task:
```python
# Module-level global that holds the model
_ocr = None
def init_worker(use_gpu: bool, gpu_id: int = 0):
global _ocr
if use_gpu:
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
else:
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from paddleocr import PaddleOCR
_ocr = PaddleOCR(use_gpu=use_gpu, ...)
# Use the initializer when creating the pool
pool = ProcessPoolExecutor(
max_workers=1,
initializer=init_worker,
initargs=(True, 0), # use_gpu=True, gpu_id=0
mp_context=mp.get_context("spawn")
)
```
### 2.3 Queue Pattern vs as_completed
| Approach | Pros | Cons | Best For |
|------|------|------|----------|
| `as_completed()` | Simple, no queue management | Cannot be used across multiple pools | Single-pool scenarios |
| `multiprocessing.Queue` | High performance, flexible | Manual management, risk of deadlock | Multi-pool pipelines |
| `Manager().Queue()` | Picklable, works across pools | Lower performance | Scenarios that need Pool.map |
**Recommendation**: for the dual-pool scenario, use `as_completed()` to handle each pool and then merge the results.
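A minimal, self-contained sketch of collecting results from two pools with `as_completed()` (the same pattern the coordinator in Phase 2 uses; `work` is a placeholder task):
```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, as_completed


def work(x: int) -> int:
    return x * x  # placeholder task


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=3, mp_context=ctx) as cpu_pool, \
         ProcessPoolExecutor(max_workers=1, mp_context=ctx) as gpu_pool:
        futures = [cpu_pool.submit(work, i) for i in range(6)]
        futures += [gpu_pool.submit(work, i) for i in range(6, 8)]
        for future in as_completed(futures):
            print(future.result())
```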
---
## 3. Detailed Development Plan
### Phase 1: Refactor the Base Architecture (2-3 days)
#### 1.1 Create a WorkerPool Abstract Class
```python
# src/processing/worker_pool.py
from __future__ import annotations
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, Future
from dataclasses import dataclass
from typing import List, Any, Optional, Callable
import multiprocessing as mp
@dataclass
class TaskResult:
"""Container for a single task result."""
task_id: str
success: bool
data: Any
error: Optional[str] = None
processing_time: float = 0.0
class WorkerPool(ABC):
"""Abstract base class for worker pools."""
def __init__(self, max_workers: int, use_gpu: bool = False, gpu_id: int = 0):
self.max_workers = max_workers
self.use_gpu = use_gpu
self.gpu_id = gpu_id
self._executor: Optional[ProcessPoolExecutor] = None
@abstractmethod
def get_initializer(self) -> Callable:
"""Return the worker initializer function."""
pass
@abstractmethod
def get_init_args(self) -> tuple:
"""Return the initializer arguments."""
pass
def start(self):
"""Start the worker pool."""
ctx = mp.get_context("spawn")
self._executor = ProcessPoolExecutor(
max_workers=self.max_workers,
mp_context=ctx,
initializer=self.get_initializer(),
initargs=self.get_init_args()
)
def submit(self, fn: Callable, *args, **kwargs) -> Future:
"""Submit a task."""
if not self._executor:
raise RuntimeError("Pool not started")
return self._executor.submit(fn, *args, **kwargs)
def shutdown(self, wait: bool = True):
"""Shut down the pool."""
if self._executor:
self._executor.shutdown(wait=wait)
self._executor = None
def __enter__(self):
self.start()
return self
def __exit__(self, *args):
self.shutdown()
```
#### 1.2 Implement the CPU and GPU Worker Pools
```python
# src/processing/cpu_pool.py
class CPUWorkerPool(WorkerPool):
"""CPU-only worker pool for text PDF processing"""
def __init__(self, max_workers: int = 4):
super().__init__(max_workers=max_workers, use_gpu=False)
def get_initializer(self) -> Callable:
return init_cpu_worker
def get_init_args(self) -> tuple:
return ()
# src/processing/gpu_pool.py
class GPUWorkerPool(WorkerPool):
"""GPU worker pool for OCR processing"""
def __init__(self, max_workers: int = 1, gpu_id: int = 0):
super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)
def get_initializer(self) -> Callable:
return init_gpu_worker
def get_init_args(self) -> tuple:
return (self.gpu_id,)
```
---
### 阶段 2实现双池协调器 (2-3天)
#### 2.1 任务分发器
```python
# src/processing/task_dispatcher.py
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Tuple
class TaskType(Enum):
CPU = auto() # Text PDF
GPU = auto() # Scanned PDF
@dataclass
class Task:
id: str
task_type: TaskType
data: Any
class TaskDispatcher:
"""Dispatch tasks to the appropriate pool based on PDF type."""
def classify_task(self, doc_info: dict) -> TaskType:
"""Decide whether the document needs OCR."""
# Decide based on PDF characteristics
if self._is_scanned_pdf(doc_info):
return TaskType.GPU
return TaskType.CPU
def _is_scanned_pdf(self, doc_info: dict) -> bool:
"""Detect whether the PDF is a scanned document."""
# 1. Check whether extractable text exists
# 2. Check the ratio of image content
# 3. Check the text density
pass
def partition_tasks(self, tasks: List[Task]) -> Tuple[List[Task], List[Task]]:
"""Partition tasks into CPU and GPU groups."""
cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
return cpu_tasks, gpu_tasks
```
#### 2.2 Dual-Pool Coordinator
```python
# src/processing/dual_pool_coordinator.py
from concurrent.futures import as_completed
from typing import List, Iterator
import logging
logger = logging.getLogger(__name__)
class DualPoolCoordinator:
"""Coordinates the CPU and GPU worker pools."""
def __init__(
self,
cpu_workers: int = 4,
gpu_workers: int = 1,
gpu_id: int = 0
):
self.cpu_pool = CPUWorkerPool(max_workers=cpu_workers)
self.gpu_pool = GPUWorkerPool(max_workers=gpu_workers, gpu_id=gpu_id)
self.dispatcher = TaskDispatcher()
def __enter__(self):
self.cpu_pool.start()
self.gpu_pool.start()
return self
def __exit__(self, *args):
self.cpu_pool.shutdown()
self.gpu_pool.shutdown()
def process_batch(
self,
documents: List[dict],
cpu_task_fn: Callable,
gpu_task_fn: Callable,
on_result: Optional[Callable[[TaskResult], None]] = None,
on_error: Optional[Callable[[str, Exception], None]] = None
) -> List[TaskResult]:
"""
Process a batch of documents, dispatching each one to the CPU or GPU pool automatically.
Args:
documents: List of documents to process
cpu_task_fn: Handler function for CPU tasks
gpu_task_fn: Handler function for GPU tasks
on_result: Result callback (optional)
on_error: Error callback (optional)
Returns:
List of results for all tasks
"""
# Classify the tasks
tasks = [
Task(id=doc['id'], task_type=self.dispatcher.classify_task(doc), data=doc)
for doc in documents
]
cpu_tasks, gpu_tasks = self.dispatcher.partition_tasks(tasks)
logger.info(f"Task partition: {len(cpu_tasks)} CPU, {len(gpu_tasks)} GPU")
# Submit tasks to their respective pools
cpu_futures = {
self.cpu_pool.submit(cpu_task_fn, t.data): t.id
for t in cpu_tasks
}
gpu_futures = {
self.gpu_pool.submit(gpu_task_fn, t.data): t.id
for t in gpu_tasks
}
# Collect results
results = []
all_futures = list(cpu_futures.keys()) + list(gpu_futures.keys())
for future in as_completed(all_futures):
task_id = cpu_futures.get(future) or gpu_futures.get(future)
pool_type = "CPU" if future in cpu_futures else "GPU"
try:
data = future.result(timeout=300) # 5-minute timeout
result = TaskResult(task_id=task_id, success=True, data=data)
if on_result:
on_result(result)
except Exception as e:
logger.error(f"[{pool_type}] Task {task_id} failed: {e}")
result = TaskResult(task_id=task_id, success=False, data=None, error=str(e))
if on_error:
on_error(task_id, e)
results.append(result)
return results
```
---
### 阶段 3集成到 autolabel (1-2天)
#### 3.1 修改 autolabel.py
```python
# src/cli/autolabel.py
def run_autolabel_dual_pool(args):
"""Run auto-labeling in dual-pool mode."""
from src.processing.dual_pool_coordinator import DualPoolCoordinator
# Initialize database batching
db_batch = []
db_batch_size = 100
def on_result(result: TaskResult):
"""Handle a successful result."""
nonlocal db_batch
db_batch.append(result.data)
if len(db_batch) >= db_batch_size:
save_documents_batch(db_batch)
db_batch.clear()
def on_error(task_id: str, error: Exception):
"""Handle an error."""
logger.error(f"Task {task_id} failed: {error}")
# Create the dual-pool coordinator
with DualPoolCoordinator(
cpu_workers=args.cpu_workers or 4,
gpu_workers=args.gpu_workers or 1,
gpu_id=0
) as coordinator:
# Process all CSV files
for csv_file in csv_files:
documents = load_documents_from_csv(csv_file)
results = coordinator.process_batch(
documents=documents,
cpu_task_fn=process_text_pdf,
gpu_task_fn=process_scanned_pdf,
on_result=on_result,
on_error=on_error
)
logger.info(f"CSV {csv_file}: {len(results)} processed")
# Save any remaining batch
if db_batch:
save_documents_batch(db_batch)
```
---
### 阶段 4测试与验证 (1-2天)
#### 4.1 单元测试
```python
# tests/unit/test_dual_pool.py
import pytest
from src.processing.dual_pool_coordinator import DualPoolCoordinator, TaskResult
class TestDualPoolCoordinator:
def test_cpu_only_batch(self):
"""Test a batch of CPU-only tasks."""
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
docs = [{"id": f"doc_{i}", "type": "text"} for i in range(10)]
results = coord.process_batch(docs, cpu_fn, gpu_fn)
assert len(results) == 10
assert all(r.success for r in results)
def test_mixed_batch(self):
"""Test a mixed batch of tasks."""
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
docs = [
{"id": "text_1", "type": "text"},
{"id": "scan_1", "type": "scanned"},
{"id": "text_2", "type": "text"},
]
results = coord.process_batch(docs, cpu_fn, gpu_fn)
assert len(results) == 3
def test_timeout_handling(self):
"""Test timeout handling."""
pass
def test_error_recovery(self):
"""Test error recovery."""
pass
```
#### 4.2 Integration Tests
```python
# tests/integration/test_autolabel_dual_pool.py
def test_autolabel_with_dual_pool():
"""End-to-end test of dual-pool mode."""
# Use a small amount of test data
result = subprocess.run([
"python", "-m", "src.cli.autolabel",
"--cpu-workers", "2",
"--gpu-workers", "1",
"--limit", "50"
], capture_output=True)
assert result.returncode == 0
# Verify the database records
```
---
## 4. Key Technical Points
### 4.1 Strategies for Avoiding Deadlocks
```python
# 1. Use timeouts
try:
result = future.result(timeout=300)
except TimeoutError:
logger.warning(f"Task timed out")
# 2. Use sentinel values
SENTINEL = object()
queue.put(SENTINEL) # Send the end-of-work signal
# 3. Check worker process liveness
if not worker.is_alive():
logger.error("Worker died unexpectedly")
break
# 4. Drain the queue before joining
while not queue.empty():
results.append(queue.get_nowait())
worker.join(timeout=5.0)
```
### 4.2 PaddleOCR-Specific Handling
```python
# PaddleOCR must be initialized inside the worker process
def init_paddle_worker(gpu_id: int):
global _ocr
import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
# Import lazily so the CUDA environment variable takes effect
from paddleocr import PaddleOCR
_ocr = PaddleOCR(
use_angle_cls=True,
lang='en',
use_gpu=True,
show_log=False,
# Important: cap GPU memory usage
gpu_mem=2000 # Limit GPU memory usage (MB)
)
```
### 4.3 Resource Monitoring
```python
import psutil
import GPUtil
def get_resource_usage():
"""Get current system resource usage."""
cpu_percent = psutil.cpu_percent(interval=1)
memory = psutil.virtual_memory()
gpu_info = []
for gpu in GPUtil.getGPUs():
gpu_info.append({
"id": gpu.id,
"memory_used": gpu.memoryUsed,
"memory_total": gpu.memoryTotal,
"utilization": gpu.load * 100
})
return {
"cpu_percent": cpu_percent,
"memory_percent": memory.percent,
"gpu": gpu_info
}
```
---
## 5. Risk Assessment and Mitigation
| Risk | Likelihood | Impact | Mitigation |
|------|--------|------|----------|
| GPU out of memory | Medium | High | Limit GPU workers to 1; set the gpu_mem parameter |
| Hung/zombie processes | Low | High | Add heartbeat checks and restart automatically on timeout |
| Task misclassification | Medium | Medium | Add a fallback: retry on the GPU after a CPU failure |
| Database write bottleneck | Low | Medium | Increase the batch size; write asynchronously |
---
## 6. Alternative Approaches
If the approach above still proves problematic, consider:
### 6.1 Use Ray
```python
import ray
ray.init()
@ray.remote(num_cpus=1)
def cpu_task(data):
return process_text_pdf(data)
@ray.remote(num_gpus=1)
def gpu_task(data):
return process_scanned_pdf(data)
# Automatic resource scheduling
futures = [cpu_task.remote(d) for d in cpu_docs]
futures += [gpu_task.remote(d) for d in gpu_docs]
results = ray.get(futures)
```
### 6.2 Single Pool + Dynamic GPU Scheduling
Keep the single-pool mode, but decide dynamically inside each task whether to use the GPU:
```python
def process_document(doc_data):
if is_scanned_pdf(doc_data):
# Use the GPU (requires a global lock or semaphore to limit concurrency)
with gpu_semaphore:
return process_with_ocr(doc_data)
else:
return process_text_only(doc_data)
```
---
## 7. Timeline Summary
| Phase | Task | Estimated Effort |
|------|------|------------|
| Phase 1 | Base architecture refactor | 2-3 days |
| Phase 2 | Dual-pool coordinator implementation | 2-3 days |
| Phase 3 | Integration into autolabel | 1-2 days |
| Phase 4 | Testing and validation | 1-2 days |
| **Total** | | **6-10 days** |
---
## 8. References
1. [Python concurrent.futures official documentation](https://docs.python.org/3/library/concurrent.futures.html)
2. [PyTorch Multiprocessing Best Practices](https://docs.pytorch.org/docs/stable/notes/multiprocessing.html)
3. [Super Fast Python - complete guide to ProcessPoolExecutor](https://superfastpython.com/processpoolexecutor-in-python/)
4. [PaddleOCR parallel inference documentation](http://www.paddleocr.ai/main/en/version3.x/pipeline_usage/instructions/parallel_inference.html)
5. [AWS - Parallelizing ML inference across CPUs/GPUs](https://aws.amazon.com/blogs/machine-learning/parallelizing-across-multiple-cpu-gpus-to-speed-up-deep-learning-inference-at-the-edge/)
6. [Ray distributed multiprocessing](https://docs.ray.io/en/latest/ray-more-libs/multiprocessing.html)

14
run_server.py Normal file
View File

@@ -0,0 +1,14 @@
#!/usr/bin/env python
"""
Quick start script for the web server.
Usage:
python run_server.py
python run_server.py --port 8080
python run_server.py --debug --reload
"""
from src.cli.serve import main
if __name__ == "__main__":
main()

600
src/cli/analyze_labels.py Normal file
View File

@@ -0,0 +1,600 @@
#!/usr/bin/env python3
"""
Label Analysis CLI
Analyzes auto-generated labels to identify failures and diagnose root causes.
Now reads from PostgreSQL database instead of JSONL files.
"""
import argparse
import csv
import json
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string
from ..normalize import normalize_field
from ..matcher import FieldMatcher
from ..pdf import is_text_pdf, extract_text_tokens
from ..yolo.annotation_generator import FIELD_CLASSES
from ..data.db import DocumentDB
@dataclass
class FieldAnalysis:
"""Analysis result for a single field."""
field_name: str
csv_value: str
expected: bool # True if CSV has value
labeled: bool # True if label file has this field
matched: bool # True if matcher finds it
# Diagnosis
failure_reason: Optional[str] = None
details: dict = field(default_factory=dict)
@dataclass
class DocumentAnalysis:
"""Analysis result for a document."""
doc_id: str
pdf_exists: bool
pdf_type: str # "text" or "scanned"
total_pages: int
# Per-field analysis
fields: list[FieldAnalysis] = field(default_factory=list)
# Summary
csv_fields_count: int = 0 # Fields with values in CSV
labeled_fields_count: int = 0 # Fields in label files
matched_fields_count: int = 0 # Fields matcher can find
@property
def has_issues(self) -> bool:
"""Check if document has any labeling issues."""
return any(
f.expected and not f.labeled
for f in self.fields
)
@property
def missing_labels(self) -> list[FieldAnalysis]:
"""Get fields that should be labeled but aren't."""
return [f for f in self.fields if f.expected and not f.labeled]
class LabelAnalyzer:
"""Analyzes labels and diagnoses failures."""
def __init__(
self,
csv_path: str,
pdf_dir: str,
dataset_dir: str,
use_db: bool = True
):
self.csv_path = Path(csv_path)
self.pdf_dir = Path(pdf_dir)
self.dataset_dir = Path(dataset_dir)
self.use_db = use_db
self.matcher = FieldMatcher()
self.csv_data = {}
self.label_data = {}
self.report_data = {}
# Database connection
self.db = None
if use_db:
self.db = DocumentDB()
self.db.connect()
# Class ID to name mapping
self.class_names = list(FIELD_CLASSES.keys())
def load_csv(self):
"""Load CSV data."""
with open(self.csv_path, 'r', encoding='utf-8-sig') as f:
reader = csv.DictReader(f)
for row in reader:
doc_id = row['DocumentId']
self.csv_data[doc_id] = row
print(f"Loaded {len(self.csv_data)} records from CSV")
def load_labels(self):
"""Load all label files from dataset."""
for split in ['train', 'val', 'test']:
label_dir = self.dataset_dir / split / 'labels'
if not label_dir.exists():
continue
for label_file in label_dir.glob('*.txt'):
# Parse document ID from filename (uuid_page_XXX.txt)
name = label_file.stem
parts = name.rsplit('_page_', 1)
if len(parts) == 2:
doc_id = parts[0]
page_no = int(parts[1])
else:
continue
if doc_id not in self.label_data:
self.label_data[doc_id] = {'pages': {}, 'split': split}
# Parse label file
labels = []
with open(label_file, 'r') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 5:
class_id = int(parts[0])
labels.append({
'class_id': class_id,
'class_name': self.class_names[class_id],
'x_center': float(parts[1]),
'y_center': float(parts[2]),
'width': float(parts[3]),
'height': float(parts[4])
})
self.label_data[doc_id]['pages'][page_no] = labels
total_docs = len(self.label_data)
total_labels = sum(
len(labels)
for doc in self.label_data.values()
for labels in doc['pages'].values()
)
print(f"Loaded labels for {total_docs} documents ({total_labels} total labels)")
def load_report(self):
"""Load autolabel report from database."""
if not self.db:
print("Database not configured, skipping report loading")
return
# Get document IDs from CSV to query
doc_ids = list(self.csv_data.keys())
if not doc_ids:
return
# Query in batches to avoid memory issues
batch_size = 1000
loaded = 0
for i in range(0, len(doc_ids), batch_size):
batch_ids = doc_ids[i:i + batch_size]
for doc_id in batch_ids:
doc = self.db.get_document(doc_id)
if doc:
self.report_data[doc_id] = doc
loaded += 1
print(f"Loaded {loaded} autolabel reports from database")
def analyze_document(self, doc_id: str, skip_missing_pdf: bool = True) -> Optional[DocumentAnalysis]:
"""Analyze a single document."""
csv_row = self.csv_data.get(doc_id, {})
label_info = self.label_data.get(doc_id, {'pages': {}})
report = self.report_data.get(doc_id, {})
# Check PDF
pdf_path = self.pdf_dir / f"{doc_id}.pdf"
pdf_exists = pdf_path.exists()
# Skip documents without PDF if requested
if skip_missing_pdf and not pdf_exists:
return None
pdf_type = "unknown"
total_pages = 0
if pdf_exists:
pdf_type = "scanned" if not is_text_pdf(pdf_path) else "text"
total_pages = len(label_info['pages']) or report.get('total_pages', 0)
analysis = DocumentAnalysis(
doc_id=doc_id,
pdf_exists=pdf_exists,
pdf_type=pdf_type,
total_pages=total_pages
)
# Get labeled classes
labeled_classes = set()
for page_labels in label_info['pages'].values():
for label in page_labels:
labeled_classes.add(label['class_name'])
# Analyze each field
for field_name in FIELD_CLASSES.keys():
csv_value = csv_row.get(field_name, '')
if csv_value is None:
csv_value = ''
csv_value = str(csv_value).strip()
# Handle datetime values (remove time part)
if ' 00:00:00' in csv_value:
csv_value = csv_value.replace(' 00:00:00', '')
expected = bool(csv_value)
labeled = field_name in labeled_classes
field_analysis = FieldAnalysis(
field_name=field_name,
csv_value=csv_value,
expected=expected,
labeled=labeled,
matched=False
)
if expected:
analysis.csv_fields_count += 1
if labeled:
analysis.labeled_fields_count += 1
# Diagnose failures
if expected and not labeled:
field_analysis.failure_reason = self._diagnose_failure(
doc_id, field_name, csv_value, pdf_path, pdf_type, report
)
field_analysis.details = self._get_failure_details(
doc_id, field_name, csv_value, pdf_path, pdf_type
)
elif not expected and labeled:
field_analysis.failure_reason = "EXTRA_LABEL"
field_analysis.details = {'note': 'Labeled but no CSV value'}
analysis.fields.append(field_analysis)
return analysis
def _diagnose_failure(
self,
doc_id: str,
field_name: str,
csv_value: str,
pdf_path: Path,
pdf_type: str,
report: dict
) -> str:
"""Diagnose why a field wasn't labeled."""
if not pdf_path.exists():
return "PDF_NOT_FOUND"
if pdf_type == "scanned":
return "SCANNED_PDF"
# Try to match now with current normalizer (not historical report)
if pdf_path.exists() and pdf_type == "text":
try:
# Check all pages
for page_no in range(10): # Max 10 pages
try:
tokens = list(extract_text_tokens(pdf_path, page_no))
if not tokens:
break
normalized = normalize_field(field_name, csv_value)
matches = self.matcher.find_matches(tokens, field_name, normalized, page_no)
if matches:
return "MATCHER_OK_NOW" # Would match with current normalizer
except Exception:
break
return "VALUE_NOT_IN_PDF"
except Exception as e:
return f"PDF_ERROR: {str(e)[:50]}"
return "UNKNOWN"
def _get_failure_details(
self,
doc_id: str,
field_name: str,
csv_value: str,
pdf_path: Path,
pdf_type: str
) -> dict:
"""Get detailed information about a failure."""
details = {
'csv_value': csv_value,
'normalized_candidates': [],
'pdf_tokens_sample': [],
'potential_matches': []
}
# Get normalized candidates
try:
details['normalized_candidates'] = normalize_field(field_name, csv_value)
except Exception:
pass
# Get PDF tokens if available
if pdf_path.exists() and pdf_type == "text":
try:
tokens = list(extract_text_tokens(pdf_path, 0))[:100]
# Find tokens that might be related
candidates = details['normalized_candidates']
for token in tokens:
text = token.text.strip()
# Check if any candidate is substring or similar
for cand in candidates:
if cand in text or text in cand:
details['potential_matches'].append({
'token': text,
'candidate': cand,
'bbox': token.bbox
})
break
# Also collect date-like or number-like tokens for reference
if field_name in ('InvoiceDate', 'InvoiceDueDate'):
if any(c.isdigit() for c in text) and len(text) >= 6:
details['pdf_tokens_sample'].append(text)
elif field_name == 'Amount':
if any(c.isdigit() for c in text) and (',' in text or '.' in text or len(text) >= 4):
details['pdf_tokens_sample'].append(text)
# Limit samples
details['pdf_tokens_sample'] = details['pdf_tokens_sample'][:10]
details['potential_matches'] = details['potential_matches'][:5]
except Exception:
pass
return details
def run_analysis(self, limit: Optional[int] = None, skip_missing_pdf: bool = True) -> list[DocumentAnalysis]:
"""Run analysis on all documents."""
self.load_csv()
self.load_labels()
self.load_report()
results = []
doc_ids = list(self.csv_data.keys())
skipped = 0
for doc_id in doc_ids:
analysis = self.analyze_document(doc_id, skip_missing_pdf=skip_missing_pdf)
if analysis is None:
skipped += 1
continue
results.append(analysis)
if limit and len(results) >= limit:
break
if skipped > 0:
print(f"Skipped {skipped} documents without PDF files")
return results
def generate_report(
self,
results: list[DocumentAnalysis],
output_path: str,
verbose: bool = False
):
"""Generate analysis report."""
output = Path(output_path)
output.parent.mkdir(parents=True, exist_ok=True)
# Collect statistics
stats = {
'total_documents': len(results),
'documents_with_issues': 0,
'total_expected_fields': 0,
'total_labeled_fields': 0,
'missing_labels': 0,
'extra_labels': 0,
'failure_reasons': defaultdict(int),
'failures_by_field': defaultdict(lambda: defaultdict(int))
}
issues = []
for analysis in results:
stats['total_expected_fields'] += analysis.csv_fields_count
stats['total_labeled_fields'] += analysis.labeled_fields_count
if analysis.has_issues:
stats['documents_with_issues'] += 1
for f in analysis.fields:
if f.expected and not f.labeled:
stats['missing_labels'] += 1
stats['failure_reasons'][f.failure_reason] += 1
stats['failures_by_field'][f.field_name][f.failure_reason] += 1
issues.append({
'doc_id': analysis.doc_id,
'field': f.field_name,
'csv_value': f.csv_value,
'reason': f.failure_reason,
'details': f.details if verbose else {}
})
elif not f.expected and f.labeled:
stats['extra_labels'] += 1
# Write JSON report
report = {
'summary': {
'total_documents': stats['total_documents'],
'documents_with_issues': stats['documents_with_issues'],
'issue_rate': f"{stats['documents_with_issues'] / max(1, stats['total_documents']) * 100:.1f}%",
'total_expected_fields': stats['total_expected_fields'],
'total_labeled_fields': stats['total_labeled_fields'],
'label_coverage': f"{stats['total_labeled_fields'] / max(1, stats['total_expected_fields']) * 100:.1f}%",
'missing_labels': stats['missing_labels'],
'extra_labels': stats['extra_labels']
},
'failure_reasons': dict(stats['failure_reasons']),
'failures_by_field': {
field: dict(reasons)
for field, reasons in stats['failures_by_field'].items()
},
'issues': issues
}
with open(output, 'w', encoding='utf-8') as f:
json.dump(report, f, indent=2, ensure_ascii=False)
print(f"\nReport saved to: {output}")
return report
def print_summary(report: dict):
"""Print summary to console."""
summary = report['summary']
print("\n" + "=" * 60)
print("LABEL ANALYSIS SUMMARY")
print("=" * 60)
print(f"\nDocuments:")
print(f" Total: {summary['total_documents']}")
print(f" With issues: {summary['documents_with_issues']} ({summary['issue_rate']})")
print(f"\nFields:")
print(f" Expected: {summary['total_expected_fields']}")
print(f" Labeled: {summary['total_labeled_fields']} ({summary['label_coverage']})")
print(f" Missing: {summary['missing_labels']}")
print(f" Extra: {summary['extra_labels']}")
print(f"\nFailure Reasons:")
for reason, count in sorted(report['failure_reasons'].items(), key=lambda x: -x[1]):
print(f" {reason}: {count}")
print(f"\nFailures by Field:")
for field, reasons in report['failures_by_field'].items():
total = sum(reasons.values())
print(f" {field}: {total}")
for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
print(f" - {reason}: {count}")
# Show sample issues
if report['issues']:
print(f"\n" + "-" * 60)
print("SAMPLE ISSUES (first 10)")
print("-" * 60)
for issue in report['issues'][:10]:
print(f"\n[{issue['doc_id']}] {issue['field']}")
print(f" CSV value: {issue['csv_value']}")
print(f" Reason: {issue['reason']}")
if issue.get('details'):
details = issue['details']
if details.get('normalized_candidates'):
print(f" Candidates: {details['normalized_candidates'][:5]}")
if details.get('pdf_tokens_sample'):
print(f" PDF samples: {details['pdf_tokens_sample'][:5]}")
if details.get('potential_matches'):
print(f" Potential matches:")
for pm in details['potential_matches'][:3]:
print(f" - token='{pm['token']}' matches candidate='{pm['candidate']}'")
def main():
parser = argparse.ArgumentParser(
description='Analyze auto-generated labels and diagnose failures'
)
parser.add_argument(
'--csv', '-c',
default='data/structured_data/document_export_20260109_220326.csv',
help='Path to structured data CSV file'
)
parser.add_argument(
'--pdf-dir', '-p',
default='data/raw_pdfs',
help='Directory containing PDF files'
)
parser.add_argument(
'--dataset', '-d',
default='data/dataset',
help='Dataset directory with labels'
)
parser.add_argument(
'--output', '-o',
default='reports/label_analysis.json',
help='Output path for analysis report'
)
parser.add_argument(
'--limit', '-l',
type=int,
default=None,
help='Limit number of documents to analyze'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Include detailed failure information'
)
parser.add_argument(
'--single', '-s',
help='Analyze single document ID'
)
parser.add_argument(
'--no-db',
action='store_true',
help='Skip database, only analyze label files'
)
args = parser.parse_args()
analyzer = LabelAnalyzer(
csv_path=args.csv,
pdf_dir=args.pdf_dir,
dataset_dir=args.dataset,
use_db=not args.no_db
)
if args.single:
# Analyze single document
analyzer.load_csv()
analyzer.load_labels()
analyzer.load_report()
analysis = analyzer.analyze_document(args.single)
print(f"\n{'=' * 60}")
print(f"Document: {analysis.doc_id}")
print(f"{'=' * 60}")
print(f"PDF exists: {analysis.pdf_exists}")
print(f"PDF type: {analysis.pdf_type}")
print(f"Pages: {analysis.total_pages}")
print(f"\nFields (CSV: {analysis.csv_fields_count}, Labeled: {analysis.labeled_fields_count}):")
for f in analysis.fields:
status = "✓" if f.labeled else ("✗" if f.expected else "-")
value_str = f.csv_value[:30] if f.csv_value else "(empty)"
print(f" [{status}] {f.field_name}: {value_str}")
if f.failure_reason:
print(f" Reason: {f.failure_reason}")
if f.details.get('normalized_candidates'):
print(f" Candidates: {f.details['normalized_candidates']}")
if f.details.get('potential_matches'):
print(f" Potential matches in PDF:")
for pm in f.details['potential_matches'][:3]:
print(f" - '{pm['token']}'")
else:
# Full analysis
print("Running label analysis...")
results = analyzer.run_analysis(limit=args.limit)
report = analyzer.generate_report(results, args.output, verbose=args.verbose)
print_summary(report)
if __name__ == '__main__':
main()
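The `--single` flow above can also be driven programmatically; a minimal sketch, assuming the same defaults as the CLI (the document ID is a placeholder):
analyzer = LabelAnalyzer(
    csv_path='data/structured_data/document_export_20260109_220326.csv',
    pdf_dir='data/raw_pdfs',
    dataset_dir='data/dataset',
    use_db=False,  # skip the database, as with --no-db
)
analyzer.load_csv()
analyzer.load_labels()
analyzer.load_report()
analysis = analyzer.analyze_document('DOC-0001')  # placeholder document ID
print(analysis.pdf_type, analysis.labeled_fields_count)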

435
src/cli/analyze_report.py Normal file

@@ -0,0 +1,435 @@
#!/usr/bin/env python3
"""
Analyze Auto-Label Report
Generates statistics and insights from database or autolabel_report.jsonl
"""
import argparse
import json
import sys
from collections import defaultdict
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string
def load_reports_from_db() -> dict:
"""Load statistics directly from database using SQL aggregation."""
from ..data.db import DocumentDB
db = DocumentDB()
db.connect()
stats = {
'total': 0,
'successful': 0,
'failed': 0,
'by_pdf_type': defaultdict(lambda: {'total': 0, 'successful': 0}),
'by_field': defaultdict(lambda: {
'total': 0,
'matched': 0,
'exact_match': 0,
'flexible_match': 0,
'scores': [],
'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
}),
'errors': defaultdict(int),
'processing_times': [],
}
conn = db.connect()
with conn.cursor() as cursor:
# Overall stats
cursor.execute("""
SELECT
COUNT(*) as total,
SUM(CASE WHEN success THEN 1 ELSE 0 END) as successful,
SUM(CASE WHEN NOT success THEN 1 ELSE 0 END) as failed
FROM documents
""")
row = cursor.fetchone()
stats['total'] = row[0] or 0
stats['successful'] = row[1] or 0
stats['failed'] = row[2] or 0
# By PDF type
cursor.execute("""
SELECT
pdf_type,
COUNT(*) as total,
SUM(CASE WHEN success THEN 1 ELSE 0 END) as successful
FROM documents
GROUP BY pdf_type
""")
for row in cursor.fetchall():
pdf_type = row[0] or 'unknown'
stats['by_pdf_type'][pdf_type] = {
'total': row[1] or 0,
'successful': row[2] or 0
}
# Processing times
cursor.execute("""
SELECT AVG(processing_time_ms), MIN(processing_time_ms), MAX(processing_time_ms)
FROM documents
WHERE processing_time_ms > 0
""")
row = cursor.fetchone()
if row[0]:
stats['processing_time_stats'] = {
'avg_ms': float(row[0]),
'min_ms': float(row[1]),
'max_ms': float(row[2])
}
# Field stats
cursor.execute("""
SELECT
field_name,
COUNT(*) as total,
SUM(CASE WHEN matched THEN 1 ELSE 0 END) as matched,
SUM(CASE WHEN matched AND score >= 0.99 THEN 1 ELSE 0 END) as exact_match,
SUM(CASE WHEN matched AND score < 0.99 THEN 1 ELSE 0 END) as flexible_match,
AVG(CASE WHEN matched THEN score END) as avg_score
FROM field_results
GROUP BY field_name
ORDER BY field_name
""")
for row in cursor.fetchall():
field_name = row[0]
stats['by_field'][field_name] = {
'total': row[1] or 0,
'matched': row[2] or 0,
'exact_match': row[3] or 0,
'flexible_match': row[4] or 0,
'avg_score': float(row[5]) if row[5] else 0,
'scores': [], # Not loading individual scores for efficiency
'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
}
# Field stats by PDF type
cursor.execute("""
SELECT
fr.field_name,
d.pdf_type,
COUNT(*) as total,
SUM(CASE WHEN fr.matched THEN 1 ELSE 0 END) as matched
FROM field_results fr
JOIN documents d ON fr.document_id = d.document_id
GROUP BY fr.field_name, d.pdf_type
""")
for row in cursor.fetchall():
field_name = row[0]
pdf_type = row[1] or 'unknown'
if field_name in stats['by_field']:
stats['by_field'][field_name]['by_pdf_type'][pdf_type] = {
'total': row[2] or 0,
'matched': row[3] or 0
}
db.close()
return stats
def load_reports_from_file(report_path: str) -> list[dict]:
"""Load all reports from JSONL file(s). Supports glob patterns."""
path = Path(report_path)
# Handle glob pattern
if '*' in str(path) or '?' in str(path):
parent = path.parent
pattern = path.name
report_files = sorted(parent.glob(pattern))
else:
report_files = [path]
if not report_files:
return []
print(f"Reading {len(report_files)} report file(s):")
for f in report_files:
print(f" - {f.name}")
reports = []
for report_file in report_files:
if not report_file.exists():
continue
with open(report_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
reports.append(json.loads(line))
return reports
def analyze_reports(reports: list[dict]) -> dict:
"""Analyze reports and generate statistics."""
stats = {
'total': len(reports),
'successful': 0,
'failed': 0,
'by_pdf_type': defaultdict(lambda: {'total': 0, 'successful': 0}),
'by_field': defaultdict(lambda: {
'total': 0,
'matched': 0,
'exact_match': 0, # score >= 0.99
'flexible_match': 0, # score < 0.99
'scores': [],
'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
}),
'errors': defaultdict(int),
'processing_times': [],
}
for report in reports:
pdf_type = report.get('pdf_type') or 'unknown'
success = report.get('success', False)
# Overall stats
if success:
stats['successful'] += 1
else:
stats['failed'] += 1
# By PDF type
stats['by_pdf_type'][pdf_type]['total'] += 1
if success:
stats['by_pdf_type'][pdf_type]['successful'] += 1
# Processing time
proc_time = report.get('processing_time_ms', 0)
if proc_time > 0:
stats['processing_times'].append(proc_time)
# Errors
for error in report.get('errors', []):
stats['errors'][error] += 1
# Field results
for field_result in report.get('field_results', []):
field_name = field_result['field_name']
matched = field_result.get('matched', False)
score = field_result.get('score', 0.0)
stats['by_field'][field_name]['total'] += 1
stats['by_field'][field_name]['by_pdf_type'][pdf_type]['total'] += 1
if matched:
stats['by_field'][field_name]['matched'] += 1
stats['by_field'][field_name]['scores'].append(score)
stats['by_field'][field_name]['by_pdf_type'][pdf_type]['matched'] += 1
if score >= 0.99:
stats['by_field'][field_name]['exact_match'] += 1
else:
stats['by_field'][field_name]['flexible_match'] += 1
return stats
def print_report(stats: dict, verbose: bool = False):
"""Print analysis report."""
print("\n" + "=" * 60)
print("AUTO-LABEL REPORT ANALYSIS")
print("=" * 60)
# Overall stats
print(f"\n{'OVERALL STATISTICS':^60}")
print("-" * 60)
total = stats['total']
successful = stats['successful']
failed = stats['failed']
success_rate = successful / total * 100 if total > 0 else 0
print(f"Total documents: {total:>8}")
print(f"Successful: {successful:>8} ({success_rate:.1f}%)")
print(f"Failed: {failed:>8} ({100-success_rate:.1f}%)")
# Processing time
if 'processing_time_stats' in stats:
pts = stats['processing_time_stats']
print(f"\nProcessing time (ms):")
print(f" Average: {pts['avg_ms']:>8.1f}")
print(f" Min: {pts['min_ms']:>8.1f}")
print(f" Max: {pts['max_ms']:>8.1f}")
elif stats.get('processing_times'):
times = stats['processing_times']
avg_time = sum(times) / len(times)
min_time = min(times)
max_time = max(times)
print(f"\nProcessing time (ms):")
print(f" Average: {avg_time:>8.1f}")
print(f" Min: {min_time:>8.1f}")
print(f" Max: {max_time:>8.1f}")
# By PDF type
print(f"\n{'BY PDF TYPE':^60}")
print("-" * 60)
print(f"{'Type':<15} {'Total':>10} {'Success':>10} {'Rate':>10}")
print("-" * 60)
for pdf_type, type_stats in sorted(stats['by_pdf_type'].items()):
type_total = type_stats['total']
type_success = type_stats['successful']
type_rate = type_success / type_total * 100 if type_total > 0 else 0
print(f"{pdf_type:<15} {type_total:>10} {type_success:>10} {type_rate:>9.1f}%")
# By field
print(f"\n{'FIELD MATCH STATISTICS':^60}")
print("-" * 60)
print(f"{'Field':<18} {'Total':>7} {'Match':>7} {'Rate':>7} {'Exact':>7} {'Flex':>7} {'AvgScore':>8}")
print("-" * 60)
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
if field_name not in stats['by_field']:
continue
field_stats = stats['by_field'][field_name]
total = field_stats['total']
matched = field_stats['matched']
exact = field_stats['exact_match']
flex = field_stats['flexible_match']
rate = matched / total * 100 if total > 0 else 0
# Handle avg_score from either DB or file analysis
if 'avg_score' in field_stats:
avg_score = field_stats['avg_score']
elif field_stats['scores']:
avg_score = sum(field_stats['scores']) / len(field_stats['scores'])
else:
avg_score = 0
print(f"{field_name:<18} {total:>7} {matched:>7} {rate:>6.1f}% {exact:>7} {flex:>7} {avg_score:>8.3f}")
# Field match by PDF type
print(f"\n{'FIELD MATCH BY PDF TYPE':^60}")
print("-" * 60)
for pdf_type in sorted(stats['by_pdf_type'].keys()):
print(f"\n[{pdf_type.upper()}]")
print(f"{'Field':<18} {'Total':>10} {'Matched':>10} {'Rate':>10}")
print("-" * 50)
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
if field_name not in stats['by_field']:
continue
type_stats = stats['by_field'][field_name]['by_pdf_type'].get(pdf_type, {'total': 0, 'matched': 0})
total = type_stats['total']
matched = type_stats['matched']
rate = matched / total * 100 if total > 0 else 0
print(f"{field_name:<18} {total:>10} {matched:>10} {rate:>9.1f}%")
# Errors
if stats.get('errors') and verbose:
print(f"\n{'ERRORS':^60}")
print("-" * 60)
for error, count in sorted(stats['errors'].items(), key=lambda x: -x[1])[:20]:
print(f"{count:>5}x {error[:50]}")
print("\n" + "=" * 60)
def export_json(stats: dict, output_path: str):
"""Export statistics to JSON file."""
# Convert defaultdicts to regular dicts for JSON serialization
export_data = {
'total': stats['total'],
'successful': stats['successful'],
'failed': stats['failed'],
'by_pdf_type': dict(stats['by_pdf_type']),
'by_field': {},
'errors': dict(stats.get('errors', {})),
}
# Processing time stats
if 'processing_time_stats' in stats:
export_data['processing_time_stats'] = stats['processing_time_stats']
elif stats.get('processing_times'):
times = stats['processing_times']
export_data['processing_time_stats'] = {
'avg_ms': sum(times) / len(times),
'min_ms': min(times),
'max_ms': max(times),
'count': len(times)
}
# Field stats
for field_name, field_stats in stats['by_field'].items():
avg_score = field_stats.get('avg_score', 0)
if not avg_score and field_stats.get('scores'):
avg_score = sum(field_stats['scores']) / len(field_stats['scores'])
export_data['by_field'][field_name] = {
'total': field_stats['total'],
'matched': field_stats['matched'],
'exact_match': field_stats['exact_match'],
'flexible_match': field_stats['flexible_match'],
'match_rate': field_stats['matched'] / field_stats['total'] if field_stats['total'] > 0 else 0,
'avg_score': avg_score,
'by_pdf_type': dict(field_stats['by_pdf_type'])
}
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(export_data, f, indent=2, ensure_ascii=False)
print(f"\nStatistics exported to: {output_path}")
def main():
parser = argparse.ArgumentParser(
description='Analyze auto-label report'
)
parser.add_argument(
'--report', '-r',
default=None,
help='Path to autolabel report JSONL file (uses database if not specified)'
)
parser.add_argument(
'--output', '-o',
help='Export statistics to JSON file'
)
parser.add_argument(
'--verbose', '-v',
action='store_true',
help='Show detailed error messages'
)
parser.add_argument(
'--from-file',
action='store_true',
help='Force reading from JSONL file instead of database'
)
args = parser.parse_args()
# Decide source
use_db = not args.from_file and args.report is None
if use_db:
print("Loading statistics from database...")
stats = load_reports_from_db()
print(f"Loaded stats for {stats['total']} documents")
else:
report_path = args.report or 'reports/autolabel_report.jsonl'
path = Path(report_path)
# Check if file exists (handle glob patterns)
if '*' not in str(path) and '?' not in str(path) and not path.exists():
print(f"Error: Report file not found: {path}")
return 1
print(f"Loading reports from: {report_path}")
reports = load_reports_from_file(report_path)
print(f"Loaded {len(reports)} reports")
stats = analyze_reports(reports)
print_report(stats, verbose=args.verbose)
if args.output:
export_json(stats, args.output)
return 0
if __name__ == '__main__':
exit(main())
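The same statistics can be produced without the CLI; a minimal sketch of the file-based path (the glob pattern and output path are illustrative):
reports = load_reports_from_file('reports/autolabel_report_part*.jsonl')
stats = analyze_reports(reports)
print_report(stats, verbose=True)
export_json(stats, 'reports/autolabel_stats.json')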


@@ -8,31 +8,83 @@ Generates YOLO training data from PDFs and structured CSV data.
import argparse import argparse
import sys import sys
import time import time
import os
import warnings
from pathlib import Path from pathlib import Path
from tqdm import tqdm from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing import multiprocessing
# Windows compatibility: use 'spawn' method for multiprocessing
# This is required on Windows and is also safer for libraries like PaddleOCR
if sys.platform == 'win32':
multiprocessing.set_start_method('spawn', force=True)
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS, AUTOLABEL
# Global OCR engine for worker processes (initialized once per worker) # Global OCR engine for worker processes (initialized once per worker)
_worker_ocr_engine = None _worker_ocr_engine = None
_worker_initialized = False
_worker_type = None # 'cpu' or 'gpu'
def _init_cpu_worker():
"""Initialize CPU worker (no OCR engine needed)."""
global _worker_initialized, _worker_type
_worker_initialized = True
_worker_type = 'cpu'
def _init_gpu_worker():
"""Initialize GPU worker with OCR engine (called once per worker)."""
global _worker_ocr_engine, _worker_initialized, _worker_type
# Suppress PaddlePaddle/PaddleX reinitialization warnings
warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
warnings.filterwarnings('ignore', message='.*reinitialization.*')
# Set environment variable to suppress paddle warnings
os.environ['GLOG_minloglevel'] = '2' # Suppress INFO and WARNING logs
# OCR engine will be lazily initialized on first use
_worker_ocr_engine = None
_worker_initialized = True
_worker_type = 'gpu'
def _init_worker(): def _init_worker():
"""Initialize worker process with OCR engine (called once per worker).""" """Initialize worker process with OCR engine (called once per worker).
global _worker_ocr_engine Legacy function for backwards compatibility.
# OCR engine will be lazily initialized on first use """
_worker_ocr_engine = None _init_gpu_worker()
def _get_ocr_engine(): def _get_ocr_engine():
"""Get or create OCR engine for current worker.""" """Get or create OCR engine for current worker."""
global _worker_ocr_engine global _worker_ocr_engine
if _worker_ocr_engine is None: if _worker_ocr_engine is None:
# Suppress warnings during OCR initialization
with warnings.catch_warnings():
warnings.filterwarnings('ignore')
from ..ocr import OCREngine from ..ocr import OCREngine
_worker_ocr_engine = OCREngine() _worker_ocr_engine = OCREngine()
return _worker_ocr_engine return _worker_ocr_engine
def _save_output_img(output_img, image_path: Path) -> None:
"""Save OCR output_img to replace the original rendered image."""
from PIL import Image as PILImage
# Convert numpy array to PIL Image and save
if output_img is not None:
img = PILImage.fromarray(output_img)
img.save(str(image_path))
# If output_img is None, the original image is already saved
def process_single_document(args_tuple): def process_single_document(args_tuple):
""" """
Process a single document (worker function for parallel processing). Process a single document (worker function for parallel processing).
@@ -47,8 +99,7 @@ def process_single_document(args_tuple):
# Import inside worker to avoid pickling issues # Import inside worker to avoid pickling issues
from ..data import AutoLabelReport, FieldMatchResult from ..data import AutoLabelReport, FieldMatchResult
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens from ..pdf import PDFDocument
from ..pdf.renderer import get_render_dimensions
from ..matcher import FieldMatcher from ..matcher import FieldMatcher
from ..normalize import normalize_field from ..normalize import normalize_field
from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
@@ -70,8 +121,11 @@ def process_single_document(args_tuple):
} }
try: try:
# Check PDF type # Use PDFDocument context manager for efficient PDF handling
use_ocr = not is_text_pdf(pdf_path) # Opens PDF only once, caches dimensions, handles cleanup automatically
with PDFDocument(pdf_path) as pdf_doc:
# Check PDF type (uses cached document)
use_ocr = not pdf_doc.is_text_pdf()
report.pdf_type = "scanned" if use_ocr else "text" report.pdf_type = "scanned" if use_ocr else "text"
# Skip OCR if requested # Skip OCR if requested
@@ -91,20 +145,37 @@ def process_single_document(args_tuple):
# Process each page # Process each page
page_annotations = [] page_annotations = []
matched_fields = set()
for page_no, image_path in render_pdf_to_images( # Render all pages and process (uses cached document handle)
pdf_path, images_dir = output_dir / 'temp' / doc_id / 'images'
output_dir / 'temp' / doc_id / 'images', for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
dpi=dpi
):
report.total_pages += 1 report.total_pages += 1
img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)
# Get dimensions from cache (no additional PDF open)
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
# Extract tokens # Extract tokens
if use_ocr: if use_ocr:
tokens = ocr_engine.extract_from_image(str(image_path), page_no) # Use extract_with_image to get both tokens and preprocessed image
# PaddleOCR coordinates are relative to output_img, not original image
ocr_result = ocr_engine.extract_with_image(
str(image_path),
page_no,
scale_to_pdf_points=72 / dpi
)
tokens = ocr_result.tokens
# Save output_img to replace the original rendered image
# This ensures coordinates match the saved image
_save_output_img(ocr_result.output_img, image_path)
# Update image dimensions to match output_img
if ocr_result.output_img is not None:
img_height, img_width = ocr_result.output_img.shape[:2]
else: else:
tokens = list(extract_text_tokens(pdf_path, page_no)) # Use cached document for text extraction
tokens = list(pdf_doc.extract_text_tokens(page_no))
# Match fields # Match fields
matches = {} matches = {}
@@ -120,6 +191,7 @@ def process_single_document(args_tuple):
if field_matches: if field_matches:
best = field_matches[0] best = field_matches[0]
matches[field_name] = field_matches matches[field_name] = field_matches
matched_fields.add(field_name)
report.add_field_result(FieldMatchResult( report.add_field_result(FieldMatchResult(
field_name=field_name, field_name=field_name,
csv_value=str(value), csv_value=str(value),
@@ -131,23 +203,14 @@ def process_single_document(args_tuple):
page_no=page_no, page_no=page_no,
context_keywords=best.context_keywords context_keywords=best.context_keywords
)) ))
else:
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=False,
page_no=page_no
))
# Generate annotations # Count annotations
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi) annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
if annotations: if annotations:
label_path = output_dir / 'temp' / doc_id / 'labels' / f"{image_path.stem}.txt"
generator.save_annotations(annotations, label_path)
page_annotations.append({ page_annotations.append({
'image_path': str(image_path), 'image_path': str(image_path),
'label_path': str(label_path), 'page_no': page_no,
'count': len(annotations) 'count': len(annotations)
}) })
@@ -156,6 +219,17 @@ def process_single_document(args_tuple):
class_name = list(FIELD_CLASSES.keys())[ann.class_id] class_name = list(FIELD_CLASSES.keys())[ann.class_id]
result['stats'][class_name] += 1 result['stats'][class_name] += 1
# Record unmatched fields
for field_name in FIELD_CLASSES.keys():
value = row_dict.get(field_name)
if value and field_name not in matched_fields:
report.add_field_result(FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=False,
page_no=-1
))
if page_annotations: if page_annotations:
result['pages'] = page_annotations result['pages'] = page_annotations
result['success'] = True result['success'] = True
@@ -178,47 +252,41 @@ def main():
) )
parser.add_argument( parser.add_argument(
'--csv', '-c', '--csv', '-c',
default='data/structured_data/document_export_20260109_212743.csv', default=f"{PATHS['csv_dir']}/*.csv",
help='Path to structured data CSV file' help='Path to CSV file(s). Supports: single file, glob pattern (*.csv), or comma-separated list'
) )
parser.add_argument( parser.add_argument(
'--pdf-dir', '-p', '--pdf-dir', '-p',
default='data/raw_pdfs', default=PATHS['pdf_dir'],
help='Directory containing PDF files' help='Directory containing PDF files'
) )
parser.add_argument( parser.add_argument(
'--output', '-o', '--output', '-o',
default='data/dataset', default=PATHS['output_dir'],
help='Output directory for dataset' help='Output directory for dataset'
) )
parser.add_argument( parser.add_argument(
'--dpi', '--dpi',
type=int, type=int,
default=300, default=AUTOLABEL['dpi'],
help='DPI for PDF rendering (default: 300)' help=f"DPI for PDF rendering (default: {AUTOLABEL['dpi']})"
) )
parser.add_argument( parser.add_argument(
'--min-confidence', '--min-confidence',
type=float, type=float,
default=0.7, default=AUTOLABEL['min_confidence'],
help='Minimum match confidence (default: 0.7)' help=f"Minimum match confidence (default: {AUTOLABEL['min_confidence']})"
)
parser.add_argument(
'--train-ratio',
type=float,
default=0.8,
help='Training set ratio (default: 0.8)'
)
parser.add_argument(
'--val-ratio',
type=float,
default=0.1,
help='Validation set ratio (default: 0.1)'
) )
parser.add_argument( parser.add_argument(
'--report', '--report',
default='reports/autolabel_report.jsonl', default=f"{PATHS['reports_dir']}/autolabel_report.jsonl",
help='Path for auto-label report (JSONL)' help='Path for auto-label report (JSONL). With --max-records, creates report_part000.jsonl, etc.'
)
parser.add_argument(
'--max-records',
type=int,
default=10000,
help='Max records per report file for sharding (default: 10000, 0 = single file)'
) )
parser.add_argument( parser.add_argument(
'--single', '--single',
@@ -233,20 +301,37 @@ def main():
'--workers', '-w', '--workers', '-w',
type=int, type=int,
default=4, default=4,
help='Number of parallel workers (default: 4)' help='Number of parallel workers (default: 4). Use --cpu-workers and --gpu-workers for dual-pool mode.'
)
parser.add_argument(
'--cpu-workers',
type=int,
default=None,
help='Number of CPU workers for text PDFs (enables dual-pool mode)'
)
parser.add_argument(
'--gpu-workers',
type=int,
default=1,
help='Number of GPU workers for scanned PDFs (default: 1, used with --cpu-workers)'
) )
parser.add_argument( parser.add_argument(
'--skip-ocr', '--skip-ocr',
action='store_true', action='store_true',
help='Skip scanned PDFs (text-layer only)' help='Skip scanned PDFs (text-layer only)'
) )
parser.add_argument(
'--limit', '-l',
type=int,
default=None,
help='Limit number of documents to process (for testing)'
)
args = parser.parse_args() args = parser.parse_args()
# Import here to avoid slow startup # Import here to avoid slow startup
from ..data import CSVLoader, AutoLabelReport, FieldMatchResult from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
from ..data.autolabel_report import ReportWriter from ..data.autolabel_report import ReportWriter
from ..yolo import DatasetBuilder
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
from ..pdf.renderer import get_render_dimensions from ..pdf.renderer import get_render_dimensions
from ..ocr import OCREngine from ..ocr import OCREngine
@@ -254,66 +339,206 @@ def main():
from ..normalize import normalize_field from ..normalize import normalize_field
from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
print(f"Loading CSV data from: {args.csv}") # Handle comma-separated CSV paths
loader = CSVLoader(args.csv, args.pdf_dir) csv_input = args.csv
if ',' in csv_input and '*' not in csv_input:
csv_input = [p.strip() for p in csv_input.split(',')]
# Validate data # Get list of CSV files (don't load all data at once)
issues = loader.validate() temp_loader = CSVLoader(csv_input, args.pdf_dir)
if issues: csv_files = temp_loader.csv_paths
print(f"Warning: Found {len(issues)} validation issues") pdf_dir = temp_loader.pdf_dir
print(f"Found {len(csv_files)} CSV file(s) to process")
# Setup output directories
output_dir = Path(args.output)
# Only create temp directory for images (no train/val/test split during labeling)
(output_dir / 'temp').mkdir(parents=True, exist_ok=True)
# Report writer with optional sharding
report_path = Path(args.report)
report_path.parent.mkdir(parents=True, exist_ok=True)
report_writer = ReportWriter(args.report, max_records_per_file=args.max_records)
# Database connection for checking existing documents
from ..data.db import DocumentDB
db = DocumentDB()
db.connect()
print("Connected to database for status checking")
# Global stats
stats = {
'total': 0,
'successful': 0,
'failed': 0,
'skipped': 0,
'skipped_db': 0, # Skipped because already in DB
'retried': 0, # Re-processed failed ones
'annotations': 0,
'tasks_submitted': 0, # Tracks tasks submitted across all CSVs for limit
'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
}
# Track all processed items for final split (write to temp file to save memory)
processed_items_file = output_dir / 'temp' / 'processed_items.jsonl'
processed_items_file.parent.mkdir(parents=True, exist_ok=True)
processed_items_writer = open(processed_items_file, 'w', encoding='utf-8')
processed_count = 0
seen_doc_ids = set()
# Batch for database updates
db_batch = []
DB_BATCH_SIZE = 100
# Helper function to handle result and update database
# Defined outside the loop so nonlocal can properly reference db_batch
def handle_result(result):
nonlocal processed_count, db_batch
# Write report to file
if result['report']:
report_writer.write_dict(result['report'])
# Add to database batch
db_batch.append(result['report'])
if len(db_batch) >= DB_BATCH_SIZE:
db.save_documents_batch(db_batch)
db_batch.clear()
if result['success']:
# Write to temp file instead of memory
import json
processed_items_writer.write(json.dumps({
'doc_id': result['doc_id'],
'pages': result['pages']
}) + '\n')
processed_items_writer.flush()
processed_count += 1
stats['successful'] += 1
for field, count in result['stats'].items():
stats['by_field'][field] += count
stats['annotations'] += count
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
stats['skipped'] += 1
else:
stats['failed'] += 1
def handle_error(doc_id, error):
nonlocal db_batch
stats['failed'] += 1
error_report = {
'document_id': doc_id,
'success': False,
'errors': [f"Worker error: {str(error)}"]
}
report_writer.write_dict(error_report)
db_batch.append(error_report)
if len(db_batch) >= DB_BATCH_SIZE:
db.save_documents_batch(db_batch)
db_batch.clear()
if args.verbose: if args.verbose:
for issue in issues[:10]: print(f"Error processing {doc_id}: {error}")
print(f" - {issue}")
rows = loader.load_all() # Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
print(f"Loaded {len(rows)} invoice records") dual_pool_coordinator = None
use_dual_pool = args.cpu_workers is not None
if use_dual_pool:
from src.processing import DualPoolCoordinator
from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
dual_pool_coordinator = DualPoolCoordinator(
cpu_workers=args.cpu_workers,
gpu_workers=args.gpu_workers,
gpu_id=0,
task_timeout=300.0,
)
dual_pool_coordinator.start()
try:
# Process CSV files one by one (streaming)
for csv_idx, csv_file in enumerate(csv_files):
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
# Load only this CSV file
single_loader = CSVLoader(str(csv_file), str(pdf_dir))
rows = single_loader.load_all()
# Filter to single document if specified # Filter to single document if specified
if args.single: if args.single:
rows = [r for r in rows if r.DocumentId == args.single] rows = [r for r in rows if r.DocumentId == args.single]
if not rows: if not rows:
print(f"Error: Document {args.single} not found")
sys.exit(1)
print(f"Processing single document: {args.single}")
# Setup output directories
output_dir = Path(args.output)
for split in ['train', 'val', 'test']:
(output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
(output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)
# Generate YOLO config files
AnnotationGenerator.generate_classes_file(output_dir / 'classes.txt')
AnnotationGenerator.generate_yaml_config(output_dir / 'dataset.yaml')
# Report writer
report_path = Path(args.report)
report_path.parent.mkdir(parents=True, exist_ok=True)
report_writer = ReportWriter(args.report)
# Stats
stats = {
'total': len(rows),
'successful': 0,
'failed': 0,
'skipped': 0,
'annotations': 0,
'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
}
# Prepare tasks
tasks = []
for row in rows:
pdf_path = loader.get_pdf_path(row)
if not pdf_path:
# Write report for missing PDF
report = AutoLabelReport(document_id=row.DocumentId)
report.errors.append("PDF not found")
report_writer.write(report)
stats['failed'] += 1
continue continue
# Convert row to dict for pickling # Deduplicate across CSV files
rows = [r for r in rows if r.DocumentId not in seen_doc_ids]
for r in rows:
seen_doc_ids.add(r.DocumentId)
if not rows:
print(f" Skipping CSV (no new documents)")
continue
# Batch query database for all document IDs in this CSV
csv_doc_ids = [r.DocumentId for r in rows]
db_status_map = db.check_documents_status_batch(csv_doc_ids)
# Count how many are already processed successfully
already_processed = sum(1 for doc_id in csv_doc_ids if db_status_map.get(doc_id) is True)
# Skip entire CSV if all documents are already processed
if already_processed == len(rows):
print(f" Skipping CSV (all {len(rows)} documents already processed)")
stats['skipped_db'] += len(rows)
continue
# Count how many new documents need processing in this CSV
new_to_process = len(rows) - already_processed
print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")
stats['total'] += len(rows)
# Prepare tasks for this CSV
tasks = []
skipped_in_csv = 0
retry_in_csv = 0
# Calculate how many more we can process if limit is set
# Use tasks_submitted counter which tracks across all CSVs
if args.limit:
remaining_limit = args.limit - stats.get('tasks_submitted', 0)
if remaining_limit <= 0:
print(f" Reached limit of {args.limit} new documents, stopping.")
break
else:
remaining_limit = float('inf')
for row in rows:
# Stop adding tasks if we've reached the limit
if len(tasks) >= remaining_limit:
break
doc_id = row.DocumentId
# Check document status from batch query result
db_status = db_status_map.get(doc_id) # None if not in DB
# Skip if already successful in database
if db_status is True:
stats['skipped_db'] += 1
skipped_in_csv += 1
continue
# Check if this is a retry (was failed before)
if db_status is False:
stats['retried'] += 1
retry_in_csv += 1
pdf_path = single_loader.get_pdf_path(row)
if not pdf_path:
stats['skipped'] += 1
continue
row_dict = { row_dict = {
'DocumentId': row.DocumentId, 'DocumentId': row.DocumentId,
'InvoiceNumber': row.InvoiceNumber, 'InvoiceNumber': row.InvoiceNumber,
@@ -334,33 +559,87 @@ def main():
args.skip_ocr args.skip_ocr
)) ))
if skipped_in_csv > 0 or retry_in_csv > 0:
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
if not tasks:
continue
# Update tasks_submitted counter for limit tracking
stats['tasks_submitted'] += len(tasks)
if use_dual_pool:
# Dual-pool mode using pre-initialized DualPoolCoordinator
# (process_text_pdf, process_scanned_pdf already imported above)
# Convert tasks to new format
documents = []
for task in tasks:
row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = task
# Pre-classify PDF type
try:
is_text = is_text_pdf(pdf_path_str)
except Exception:
is_text = False
documents.append({
"id": row_dict["DocumentId"],
"row_dict": row_dict,
"pdf_path": pdf_path_str,
"output_dir": output_dir_str,
"dpi": dpi,
"min_confidence": min_confidence,
"is_scanned": not is_text,
"has_text": is_text,
"text_length": 1000 if is_text else 0, # Approximate
})
# Count task types
text_count = sum(1 for d in documents if not d["is_scanned"])
scan_count = len(documents) - text_count
print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")
# Progress tracking with tqdm
pbar = tqdm(total=len(documents), desc="Processing")
def on_result(task_result):
"""Handle successful result."""
result = task_result.data
handle_result(result)
pbar.update(1)
def on_error(task_id, error):
"""Handle failed task."""
handle_error(task_id, error)
pbar.update(1)
# Process with pre-initialized coordinator (workers stay alive)
results = dual_pool_coordinator.process_batch(
documents=documents,
cpu_task_fn=process_text_pdf,
gpu_task_fn=process_scanned_pdf,
on_result=on_result,
on_error=on_error,
id_field="id",
)
pbar.close()
# Log summary
successful = sum(1 for r in results if r.success)
failed = len(results) - successful
print(f" Batch complete: {successful} successful, {failed} failed")
else:
# Single-pool mode (original behavior)
print(f" Processing {len(tasks)} documents with {args.workers} workers...") print(f" Processing {len(tasks)} documents with {args.workers} workers...")
# Process documents in parallel # Process documents in parallel (inside CSV loop for streaming)
processed_items = []
# Use single process for debugging or when workers=1 # Use single process for debugging or when workers=1
if args.workers == 1: if args.workers == 1:
for task in tqdm(tasks, desc="Processing"): for task in tqdm(tasks, desc="Processing"):
result = process_single_document(task) result = process_single_document(task)
handle_result(result)
# Write report
if result['report']:
report_writer.write_dict(result['report'])
if result['success']:
processed_items.append({
'doc_id': result['doc_id'],
'pages': result['pages']
})
stats['successful'] += 1
for field, count in result['stats'].items():
stats['by_field'][field] += count
stats['annotations'] += count
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
stats['skipped'] += 1
else:
stats['failed'] += 1
else: else:
# Parallel processing with worker initialization # Parallel processing with worker initialization
# Each worker initializes OCR engine once and reuses it # Each worker initializes OCR engine once and reuses it
@@ -372,67 +651,31 @@ def main():
doc_id = futures[future] doc_id = futures[future]
try: try:
result = future.result() result = future.result()
handle_result(result)
# Write report
if result['report']:
report_writer.write_dict(result['report'])
if result['success']:
processed_items.append({
'doc_id': result['doc_id'],
'pages': result['pages']
})
stats['successful'] += 1
for field, count in result['stats'].items():
stats['by_field'][field] += count
stats['annotations'] += count
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
stats['skipped'] += 1
else:
stats['failed'] += 1
except Exception as e: except Exception as e:
stats['failed'] += 1 handle_error(doc_id, e)
# Write error report for failed documents
error_report = {
'document_id': doc_id,
'success': False,
'errors': [f"Worker error: {str(e)}"]
}
report_writer.write_dict(error_report)
if args.verbose:
print(f"Error processing {doc_id}: {e}")
# Split and move files # Flush remaining database batch after each CSV
import random if db_batch:
random.seed(42) db.save_documents_batch(db_batch)
random.shuffle(processed_items) db_batch.clear()
n_train = int(len(processed_items) * args.train_ratio) finally:
n_val = int(len(processed_items) * args.val_ratio) # Shutdown dual-pool coordinator if it was started
if dual_pool_coordinator is not None:
dual_pool_coordinator.shutdown()
splits = { # Close temp file
'train': processed_items[:n_train], processed_items_writer.close()
'val': processed_items[n_train:n_train + n_val],
'test': processed_items[n_train + n_val:]
}
import shutil # Use the in-memory counter instead of re-reading the file (performance fix)
for split_name, items in splits.items(): # processed_count already tracks the number of successfully processed items
for item in items:
for page in item['pages']:
# Move image
image_path = Path(page['image_path'])
label_path = Path(page['label_path'])
dest_img = output_dir / split_name / 'images' / image_path.name
shutil.move(str(image_path), str(dest_img))
# Move label # Cleanup processed_items temp file (not needed anymore)
dest_label = output_dir / split_name / 'labels' / label_path.name processed_items_file.unlink(missing_ok=True)
shutil.move(str(label_path), str(dest_label))
# Cleanup temp # Close database connection
shutil.rmtree(output_dir / 'temp', ignore_errors=True) db.close()
# Print summary # Print summary
print("\n" + "=" * 50) print("\n" + "=" * 50)
@@ -441,17 +684,22 @@ def main():
print(f"Total documents: {stats['total']}") print(f"Total documents: {stats['total']}")
print(f"Successful: {stats['successful']}") print(f"Successful: {stats['successful']}")
print(f"Failed: {stats['failed']}") print(f"Failed: {stats['failed']}")
print(f"Skipped (OCR): {stats['skipped']}") print(f"Skipped (no PDF): {stats['skipped']}")
print(f"Skipped (in DB): {stats['skipped_db']}")
print(f"Retried (failed): {stats['retried']}")
print(f"Total annotations: {stats['annotations']}") print(f"Total annotations: {stats['annotations']}")
print(f"\nDataset split:") print(f"\nImages saved to: {output_dir / 'temp'}")
print(f" Train: {len(splits['train'])} documents") print(f"Labels stored in: PostgreSQL database")
print(f" Val: {len(splits['val'])} documents")
print(f" Test: {len(splits['test'])} documents")
print(f"\nAnnotations by field:") print(f"\nAnnotations by field:")
for field, count in stats['by_field'].items(): for field, count in stats['by_field'].items():
print(f" {field}: {count}") print(f" {field}: {count}")
print(f"\nOutput: {output_dir}") shard_files = report_writer.get_shard_files()
print(f"Report: {args.report}") if len(shard_files) > 1:
print(f"\nReport files ({len(shard_files)}):")
for sf in shard_files:
print(f" - {sf}")
else:
print(f"\nReport: {shard_files[0] if shard_files else args.report}")
if __name__ == '__main__': if __name__ == '__main__':
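For reuse outside this CLI, the dual-pool path added above can be driven directly; a minimal sketch assuming the same `src.processing` modules, with placeholder document values:
from src.processing import DualPoolCoordinator
from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf

documents = [{
    "id": "DOC-0001",                        # placeholder document id
    "row_dict": {"DocumentId": "DOC-0001"},  # CSV fields as built in main()
    "pdf_path": "data/raw_pdfs/DOC-0001.pdf",
    "output_dir": "data/dataset",
    "dpi": 300,
    "min_confidence": 0.7,
    "is_scanned": False,
    "has_text": True,
    "text_length": 1000,
}]

coordinator = DualPoolCoordinator(cpu_workers=4, gpu_workers=1, gpu_id=0, task_timeout=300.0)
coordinator.start()
try:
    results = coordinator.process_batch(
        documents=documents,
        cpu_task_fn=process_text_pdf,
        gpu_task_fn=process_scanned_pdf,
        on_result=lambda task_result: print(task_result.data["doc_id"]),
        on_error=lambda task_id, error: print(task_id, error),
        id_field="id",
    )
finally:
    coordinator.shutdown()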


@@ -0,0 +1,262 @@
#!/usr/bin/env python3
"""
Import existing JSONL report files into PostgreSQL database.
Usage:
python -m src.cli.import_report_to_db --report "reports/autolabel_report_v4*.jsonl"
"""
import argparse
import json
import sys
from pathlib import Path
import psycopg2
from psycopg2.extras import execute_values
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string, PATHS
def create_tables(conn):
"""Create database tables."""
with conn.cursor() as cursor:
cursor.execute("""
CREATE TABLE IF NOT EXISTS documents (
document_id TEXT PRIMARY KEY,
pdf_path TEXT,
pdf_type TEXT,
success BOOLEAN,
total_pages INTEGER,
fields_matched INTEGER,
fields_total INTEGER,
annotations_generated INTEGER,
processing_time_ms REAL,
timestamp TIMESTAMPTZ,
errors JSONB DEFAULT '[]'
);
CREATE TABLE IF NOT EXISTS field_results (
id SERIAL PRIMARY KEY,
document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
field_name TEXT,
csv_value TEXT,
matched BOOLEAN,
score REAL,
matched_text TEXT,
candidate_used TEXT,
bbox JSONB,
page_no INTEGER,
context_keywords JSONB DEFAULT '[]',
error TEXT
);
CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
""")
conn.commit()
def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_size: int = 1000) -> dict:
"""Import a single JSONL file into database."""
stats = {'imported': 0, 'skipped': 0, 'errors': 0}
# Get existing document IDs if skipping
existing_ids = set()
if skip_existing:
with conn.cursor() as cursor:
cursor.execute("SELECT document_id FROM documents")
existing_ids = {row[0] for row in cursor.fetchall()}
doc_batch = []
field_batch = []
def flush_batches():
nonlocal doc_batch, field_batch
if doc_batch:
with conn.cursor() as cursor:
execute_values(cursor, """
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors)
VALUES %s
ON CONFLICT (document_id) DO UPDATE SET
pdf_path = EXCLUDED.pdf_path,
pdf_type = EXCLUDED.pdf_type,
success = EXCLUDED.success,
total_pages = EXCLUDED.total_pages,
fields_matched = EXCLUDED.fields_matched,
fields_total = EXCLUDED.fields_total,
annotations_generated = EXCLUDED.annotations_generated,
processing_time_ms = EXCLUDED.processing_time_ms,
timestamp = EXCLUDED.timestamp,
errors = EXCLUDED.errors
""", doc_batch)
doc_batch = []
if field_batch:
with conn.cursor() as cursor:
execute_values(cursor, """
INSERT INTO field_results
(document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error)
VALUES %s
""", field_batch)
field_batch = []
conn.commit()
with open(jsonl_path, 'r', encoding='utf-8') as f:
for line_no, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
try:
record = json.loads(line)
except json.JSONDecodeError as e:
print(f" Warning: Line {line_no} - JSON parse error: {e}")
stats['errors'] += 1
continue
doc_id = record.get('document_id')
if not doc_id:
stats['errors'] += 1
continue
# Only import successful documents
if not record.get('success'):
stats['skipped'] += 1
continue
# Check if already exists
if skip_existing and doc_id in existing_ids:
stats['skipped'] += 1
continue
# Add to batch
doc_batch.append((
doc_id,
record.get('pdf_path'),
record.get('pdf_type'),
record.get('success'),
record.get('total_pages'),
record.get('fields_matched'),
record.get('fields_total'),
record.get('annotations_generated'),
record.get('processing_time_ms'),
record.get('timestamp'),
json.dumps(record.get('errors', []))
))
for field in record.get('field_results', []):
field_batch.append((
doc_id,
field.get('field_name'),
field.get('csv_value'),
field.get('matched'),
field.get('score'),
field.get('matched_text'),
field.get('candidate_used'),
json.dumps(field.get('bbox')) if field.get('bbox') else None,
field.get('page_no'),
json.dumps(field.get('context_keywords', [])),
field.get('error')
))
stats['imported'] += 1
existing_ids.add(doc_id)
# Flush batch if needed
if len(doc_batch) >= batch_size:
flush_batches()
print(f" Processed {stats['imported'] + stats['skipped']} records...")
# Final flush
flush_batches()
return stats
def main():
parser = argparse.ArgumentParser(description='Import JSONL reports to PostgreSQL database')
parser.add_argument('--report', type=str, default=f"{PATHS['reports_dir']}/autolabel_report*.jsonl",
help='Report file path or glob pattern')
parser.add_argument('--db', type=str, default=None,
help='PostgreSQL connection string (uses config.py if not specified)')
parser.add_argument('--no-skip', action='store_true',
help='Do not skip existing documents (replace them)')
parser.add_argument('--batch-size', type=int, default=1000,
help='Batch size for bulk inserts')
args = parser.parse_args()
# Use config if db not specified
db_connection = args.db or get_db_connection_string()
# Find report files
report_path = Path(args.report)
if '*' in str(report_path) or '?' in str(report_path):
parent = report_path.parent
pattern = report_path.name
report_files = sorted(parent.glob(pattern))
else:
report_files = [report_path] if report_path.exists() else []
if not report_files:
print(f"No report files found: {args.report}")
return
print(f"Found {len(report_files)} report file(s)")
# Connect to database
conn = psycopg2.connect(db_connection)
create_tables(conn)
# Import each file
total_stats = {'imported': 0, 'skipped': 0, 'errors': 0}
for report_file in report_files:
print(f"\nImporting: {report_file.name}")
stats = import_jsonl_file(conn, report_file, skip_existing=not args.no_skip, batch_size=args.batch_size)
print(f" Imported: {stats['imported']}, Skipped: {stats['skipped']}, Errors: {stats['errors']}")
for key in total_stats:
total_stats[key] += stats[key]
# Print summary
print("\n" + "=" * 50)
print("Import Complete")
print("=" * 50)
print(f"Total imported: {total_stats['imported']}")
print(f"Total skipped: {total_stats['skipped']}")
print(f"Total errors: {total_stats['errors']}")
# Quick stats from database
with conn.cursor() as cursor:
cursor.execute("SELECT COUNT(*) FROM documents")
total_docs = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(*) FROM documents WHERE success = true")
success_docs = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(*) FROM field_results")
total_fields = cursor.fetchone()[0]
cursor.execute("SELECT COUNT(*) FROM field_results WHERE matched = true")
matched_fields = cursor.fetchone()[0]
conn.close()
print(f"\nDatabase Stats:")
print(f" Documents: {total_docs} ({success_docs} successful)")
print(f" Field results: {total_fields} ({matched_fields} matched)")
if total_fields > 0:
print(f" Match rate: {matched_fields / total_fields * 100:.2f}%")
if __name__ == '__main__':
main()
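A quick post-import sanity check against the tables created above; the connection string comes from config.py as in the script, and the query itself is only an illustration:
import psycopg2
from config import get_db_connection_string

with psycopg2.connect(get_db_connection_string()) as conn, conn.cursor() as cur:
    cur.execute("SELECT COUNT(*) FROM documents WHERE success")
    print("successful documents:", cur.fetchone()[0])
    cur.execute("""
        SELECT field_name,
               COUNT(*) FILTER (WHERE matched) AS matched,
               COUNT(*) AS total
        FROM field_results
        GROUP BY field_name
        ORDER BY field_name
    """)
    for name, matched, total in cur.fetchall():
        print(f"  {name}: {matched}/{total} matched")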

158
src/cli/serve.py Normal file

@@ -0,0 +1,158 @@
"""
Web Server CLI
Command-line interface for starting the web server.
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
# Add project root to path
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
def setup_logging(debug: bool = False) -> None:
"""Configure logging."""
level = logging.DEBUG if debug else logging.INFO
logging.basicConfig(
level=level,
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def parse_args() -> argparse.Namespace:
"""Parse command-line arguments."""
parser = argparse.ArgumentParser(
description="Start the Invoice Field Extraction web server",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--host",
type=str,
default="0.0.0.0",
help="Host to bind to",
)
parser.add_argument(
"--port",
type=int,
default=8000,
help="Port to listen on",
)
parser.add_argument(
"--model",
"-m",
type=Path,
default=Path("runs/train/invoice_yolo11n_full/weights/best.pt"),
help="Path to YOLO model weights",
)
parser.add_argument(
"--confidence",
type=float,
default=0.3,
help="Detection confidence threshold",
)
parser.add_argument(
"--dpi",
type=int,
default=150,
help="DPI for PDF rendering",
)
parser.add_argument(
"--no-gpu",
action="store_true",
help="Disable GPU acceleration",
)
parser.add_argument(
"--reload",
action="store_true",
help="Enable auto-reload for development",
)
parser.add_argument(
"--workers",
type=int,
default=1,
help="Number of worker processes",
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode",
)
return parser.parse_args()
def main() -> None:
"""Main entry point."""
args = parse_args()
setup_logging(debug=args.debug)
logger = logging.getLogger(__name__)
# Validate model path
if not args.model.exists():
logger.error(f"Model file not found: {args.model}")
sys.exit(1)
logger.info("=" * 60)
logger.info("Invoice Field Extraction Web Server")
logger.info("=" * 60)
logger.info(f"Model: {args.model}")
logger.info(f"Confidence threshold: {args.confidence}")
logger.info(f"GPU enabled: {not args.no_gpu}")
logger.info(f"Server: http://{args.host}:{args.port}")
logger.info("=" * 60)
# Create config
from src.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
config = AppConfig(
model=ModelConfig(
model_path=args.model,
confidence_threshold=args.confidence,
use_gpu=not args.no_gpu,
dpi=args.dpi,
),
server=ServerConfig(
host=args.host,
port=args.port,
debug=args.debug,
reload=args.reload,
workers=args.workers,
),
storage=StorageConfig(),
)
# Create and run app
import uvicorn
from src.web.app import create_app
app = create_app(config)
uvicorn.run(
app,
host=config.server.host,
port=config.server.port,
reload=config.server.reload,
workers=config.server.workers if not config.server.reload else 1,
log_level="debug" if config.server.debug else "info",
)
if __name__ == "__main__":
main()
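The same server can be started programmatically, bypassing argument parsing; a minimal sketch using the CLI defaults (the model path is the default above and may need adjusting):
from pathlib import Path

import uvicorn

from src.web.app import create_app
from src.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig

config = AppConfig(
    model=ModelConfig(
        model_path=Path("runs/train/invoice_yolo11n_full/weights/best.pt"),
        confidence_threshold=0.3,
        use_gpu=True,
        dpi=150,
    ),
    server=ServerConfig(host="0.0.0.0", port=8000, debug=False, reload=False, workers=1),
    storage=StorageConfig(),
)
uvicorn.run(create_app(config), host=config.server.host, port=config.server.port)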


@@ -2,22 +2,26 @@
""" """
Training CLI Training CLI
Trains YOLO model on generated dataset. Trains YOLO model on dataset with labels from PostgreSQL database.
Images are read from filesystem, labels are dynamically generated from DB.
""" """
import argparse import argparse
import sys import sys
from pathlib import Path from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import PATHS
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description='Train YOLO model for invoice field detection' description='Train YOLO model for invoice field detection'
) )
parser.add_argument( parser.add_argument(
'--data', '-d', '--dataset-dir', '-d',
required=True, default=PATHS['output_dir'],
help='Path to dataset.yaml file' help='Dataset directory containing temp/{doc_id}/images/ (default: data/dataset)'
) )
parser.add_argument( parser.add_argument(
'--model', '-m', '--model', '-m',
@@ -62,24 +66,117 @@ def main():
help='Resume from checkpoint' help='Resume from checkpoint'
) )
parser.add_argument( parser.add_argument(
'--config', '--train-ratio',
help='Path to training config YAML' type=float,
default=0.8,
help='Training set ratio (default: 0.8)'
)
parser.add_argument(
'--val-ratio',
type=float,
default=0.1,
help='Validation set ratio (default: 0.1)'
)
parser.add_argument(
'--seed',
type=int,
default=42,
help='Random seed for split (default: 42)'
)
parser.add_argument(
'--dpi',
type=int,
default=300,
help='DPI used for rendering (default: 300)'
)
parser.add_argument(
'--export-only',
action='store_true',
help='Only export dataset to YOLO format, do not train'
)
parser.add_argument(
'--limit',
type=int,
default=None,
help='Limit number of documents for training (default: all)'
) )
args = parser.parse_args() args = parser.parse_args()
# Validate data file # Validate dataset directory
data_path = Path(args.data) dataset_dir = Path(args.dataset_dir)
if not data_path.exists(): temp_dir = dataset_dir / 'temp'
print(f"Error: Dataset file not found: {data_path}") if not temp_dir.exists():
print(f"Error: Temp directory not found: {temp_dir}")
print("Run autolabel first to generate images.")
sys.exit(1) sys.exit(1)
print(f"Training YOLO model for invoice field detection") print("=" * 60)
print(f"Dataset: {args.data}") print("YOLO Training with Database Labels")
print("=" * 60)
print(f"Dataset dir: {dataset_dir}")
print(f"Model: {args.model}") print(f"Model: {args.model}")
print(f"Epochs: {args.epochs}") print(f"Epochs: {args.epochs}")
print(f"Batch size: {args.batch}") print(f"Batch size: {args.batch}")
print(f"Image size: {args.imgsz}") print(f"Image size: {args.imgsz}")
print(f"Split ratio: {args.train_ratio}/{args.val_ratio}/{1-args.train_ratio-args.val_ratio:.1f}")
if args.limit:
print(f"Document limit: {args.limit}")
# Connect to database
from ..data.db import DocumentDB
print("\nConnecting to database...")
db = DocumentDB()
db.connect()
# Create datasets from database
from ..yolo.db_dataset import create_datasets
print("Loading dataset from database...")
datasets = create_datasets(
images_dir=dataset_dir,
db=db,
train_ratio=args.train_ratio,
val_ratio=args.val_ratio,
seed=args.seed,
dpi=args.dpi,
limit=args.limit
)
print(f"\nDataset splits:")
print(f" Train: {len(datasets['train'])} items")
print(f" Val: {len(datasets['val'])} items")
print(f" Test: {len(datasets['test'])} items")
if len(datasets['train']) == 0:
print("\nError: No training data found!")
print("Make sure autolabel has been run and images exist in temp directory.")
db.close()
sys.exit(1)
# Export to YOLO format (required for Ultralytics training)
print("\nExporting dataset to YOLO format...")
for split_name, dataset in datasets.items():
count = dataset.export_to_yolo_format(dataset_dir, split_name)
print(f" {split_name}: {count} items exported")
# Generate YOLO config files
from ..yolo.annotation_generator import AnnotationGenerator
AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt')
AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml')
print(f"\nGenerated dataset.yaml at: {dataset_dir / 'dataset.yaml'}")
if args.export_only:
print("\nExport complete (--export-only specified, skipping training)")
db.close()
return
# Start training
print("\n" + "=" * 60)
print("Starting YOLO Training")
print("=" * 60)
from ultralytics import YOLO from ultralytics import YOLO
@@ -91,8 +188,9 @@ def main():
model = YOLO(args.model) model = YOLO(args.model)
# Training arguments # Training arguments
data_yaml = dataset_dir / 'dataset.yaml'
train_args = { train_args = {
'data': str(data_path.absolute()), 'data': str(data_yaml.absolute()),
'epochs': args.epochs, 'epochs': args.epochs,
'batch': args.batch, 'batch': args.batch,
'imgsz': args.imgsz, 'imgsz': args.imgsz,
@@ -121,18 +219,21 @@ def main():
results = model.train(**train_args) results = model.train(**train_args)
# Print results # Print results
print("\n" + "=" * 50) print("\n" + "=" * 60)
print("Training Complete") print("Training Complete")
print("=" * 50) print("=" * 60)
print(f"Best model: {args.project}/{args.name}/weights/best.pt") print(f"Best model: {args.project}/{args.name}/weights/best.pt")
print(f"Last model: {args.project}/{args.name}/weights/last.pt") print(f"Last model: {args.project}/{args.name}/weights/last.pt")
# Validate on test set # Validate on test set
print("\nRunning validation...") print("\nRunning validation on test set...")
metrics = model.val() metrics = model.val(split='test')
print(f"mAP50: {metrics.box.map50:.4f}") print(f"mAP50: {metrics.box.map50:.4f}")
print(f"mAP50-95: {metrics.box.map:.4f}") print(f"mAP50-95: {metrics.box.map:.4f}")
# Close database
db.close()
if __name__ == '__main__': if __name__ == '__main__':
main() main()
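The `--export-only` path boils down to a few calls; a minimal sketch assuming absolute `src.` imports resolve from the project root (paths and ratios are the defaults above):
from pathlib import Path

from src.data.db import DocumentDB
from src.yolo.db_dataset import create_datasets
from src.yolo.annotation_generator import AnnotationGenerator

dataset_dir = Path("data/dataset")
db = DocumentDB()
db.connect()
datasets = create_datasets(
    images_dir=dataset_dir, db=db,
    train_ratio=0.8, val_ratio=0.1, seed=42, dpi=300, limit=None,
)
for split_name, dataset in datasets.items():
    dataset.export_to_yolo_format(dataset_dir, split_name)
AnnotationGenerator.generate_classes_file(dataset_dir / "classes.txt")
AnnotationGenerator.generate_yaml_config(dataset_dir / "dataset.yaml")
db.close()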


@@ -114,57 +114,106 @@ class AutoLabelReport:
class ReportWriter: class ReportWriter:
"""Writes auto-label reports to file.""" """Writes auto-label reports to file with optional sharding."""
def __init__(self, output_path: str | Path): def __init__(
self,
output_path: str | Path,
max_records_per_file: int = 0
):
""" """
Initialize report writer. Initialize report writer.
Args: Args:
output_path: Path to output JSONL file output_path: Path to output JSONL file (base name if sharding)
max_records_per_file: Max records per file (0 = no limit, single file)
""" """
self.output_path = Path(output_path) self.output_path = Path(output_path)
self.output_path.parent.mkdir(parents=True, exist_ok=True) self.output_path.parent.mkdir(parents=True, exist_ok=True)
self.max_records_per_file = max_records_per_file
# Sharding state
self._current_shard = 0
self._records_in_current_shard = 0
self._shard_files: list[Path] = []
def _get_shard_path(self) -> Path:
"""Get the path for current shard."""
if self.max_records_per_file > 0:
base = self.output_path.stem
suffix = self.output_path.suffix
shard_path = self.output_path.parent / f"{base}_part{self._current_shard:03d}{suffix}"
else:
shard_path = self.output_path
if shard_path not in self._shard_files:
self._shard_files.append(shard_path)
return shard_path
def _check_shard_rotation(self) -> None:
"""Check if we need to rotate to a new shard file."""
if self.max_records_per_file > 0:
if self._records_in_current_shard >= self.max_records_per_file:
self._current_shard += 1
self._records_in_current_shard = 0
def write(self, report: AutoLabelReport) -> None: def write(self, report: AutoLabelReport) -> None:
"""Append a report to the output file.""" """Append a report to the output file."""
with open(self.output_path, 'a', encoding='utf-8') as f: self._check_shard_rotation()
shard_path = self._get_shard_path()
with open(shard_path, 'a', encoding='utf-8') as f:
f.write(report.to_json() + '\n') f.write(report.to_json() + '\n')
self._records_in_current_shard += 1
def write_dict(self, report_dict: dict) -> None: def write_dict(self, report_dict: dict) -> None:
"""Append a report dict to the output file (for parallel processing).""" """Append a report dict to the output file (for parallel processing)."""
import json self._check_shard_rotation()
with open(self.output_path, 'a', encoding='utf-8') as f: shard_path = self._get_shard_path()
with open(shard_path, 'a', encoding='utf-8') as f:
f.write(json.dumps(report_dict, ensure_ascii=False) + '\n') f.write(json.dumps(report_dict, ensure_ascii=False) + '\n')
f.flush() f.flush()
self._records_in_current_shard += 1
def write_batch(self, reports: list[AutoLabelReport]) -> None: def write_batch(self, reports: list[AutoLabelReport]) -> None:
"""Write multiple reports.""" """Write multiple reports."""
with open(self.output_path, 'a', encoding='utf-8') as f:
for report in reports: for report in reports:
f.write(report.to_json() + '\n') self.write(report)
def get_shard_files(self) -> list[Path]:
"""Get list of all shard files created."""
return self._shard_files.copy()
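A hedged usage sketch of the sharded writer; the import location is an assumption, since the file path is not visible in this view.

```python
from src.data import ReportWriter  # assumed export path

writer = ReportWriter("reports/autolabel.jsonl", max_records_per_file=10_000)
for report in reports:               # `reports`: any iterable of AutoLabelReport
    writer.write(report)             # rotates to autolabel_part000.jsonl, _part001.jsonl, ...
print(writer.get_shard_files())      # every shard file created so far
```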
class ReportReader: class ReportReader:
"""Reads auto-label reports from file.""" """Reads auto-label reports from file(s)."""
def __init__(self, input_path: str | Path): def __init__(self, input_path: str | Path):
""" """
Initialize report reader. Initialize report reader.
Args: Args:
input_path: Path to input JSONL file input_path: Path to input JSONL file or glob pattern (e.g., 'reports/*.jsonl')
""" """
self.input_path = Path(input_path) self.input_path = Path(input_path)
# Handle glob pattern
if '*' in str(input_path) or '?' in str(input_path):
parent = self.input_path.parent
pattern = self.input_path.name
self.input_paths = sorted(parent.glob(pattern))
else:
self.input_paths = [self.input_path]
def read_all(self) -> list[AutoLabelReport]: def read_all(self) -> list[AutoLabelReport]:
"""Read all reports from file.""" """Read all reports from file(s)."""
reports = [] reports = []
if not self.input_path.exists(): for input_path in self.input_paths:
return reports if not input_path.exists():
continue
with open(self.input_path, 'r', encoding='utf-8') as f: with open(input_path, 'r', encoding='utf-8') as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line: if not line:
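And the matching read side, again with an assumed import path: a glob pattern merges all shards and silently skips files that do not exist.

```python
from src.data import ReportReader  # assumed export path

reader = ReportReader("reports/autolabel_part*.jsonl")
all_reports = reader.read_all()    # AutoLabelReport objects from every matching shard
```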
@@ -72,7 +72,7 @@ class CSVLoader:
def __init__( def __init__(
self, self,
csv_path: str | Path, csv_path: str | Path | list[str | Path],
pdf_dir: str | Path | None = None, pdf_dir: str | Path | None = None,
doc_map_path: str | Path | None = None, doc_map_path: str | Path | None = None,
encoding: str = 'utf-8' encoding: str = 'utf-8'
@@ -81,13 +81,31 @@ class CSVLoader:
Initialize CSV loader. Initialize CSV loader.
Args: Args:
csv_path: Path to the CSV file csv_path: Path to CSV file(s). Can be:
- Single path: 'data/file.csv'
- List of paths: ['data/file1.csv', 'data/file2.csv']
- Glob pattern: 'data/*.csv' or 'data/export_*.csv'
pdf_dir: Directory containing PDF files (default: data/raw_pdfs) pdf_dir: Directory containing PDF files (default: data/raw_pdfs)
doc_map_path: Optional path to document mapping CSV doc_map_path: Optional path to document mapping CSV
encoding: CSV file encoding (default: utf-8) encoding: CSV file encoding (default: utf-8)
""" """
self.csv_path = Path(csv_path) # Handle multiple CSV files
self.pdf_dir = Path(pdf_dir) if pdf_dir else self.csv_path.parent.parent / 'raw_pdfs' if isinstance(csv_path, list):
self.csv_paths = [Path(p) for p in csv_path]
else:
csv_path = Path(csv_path)
# Check if it's a glob pattern (contains * or ?)
if '*' in str(csv_path) or '?' in str(csv_path):
parent = csv_path.parent
pattern = csv_path.name
self.csv_paths = sorted(parent.glob(pattern))
else:
self.csv_paths = [csv_path]
# For backward compatibility
self.csv_path = self.csv_paths[0] if self.csv_paths else None
self.pdf_dir = Path(pdf_dir) if pdf_dir else (self.csv_path.parent.parent / 'raw_pdfs' if self.csv_path else Path('data/raw_pdfs'))
self.doc_map_path = Path(doc_map_path) if doc_map_path else None self.doc_map_path = Path(doc_map_path) if doc_map_path else None
self.encoding = encoding self.encoding = encoding
@@ -185,21 +203,14 @@ class CSVLoader:
raw_data=dict(row) raw_data=dict(row)
) )
def load_all(self) -> list[InvoiceRow]: def _iter_single_csv(self, csv_path: Path) -> Iterator[InvoiceRow]:
"""Load all rows from CSV.""" """Iterate over rows from a single CSV file."""
rows = []
for row in self.iter_rows():
rows.append(row)
return rows
def iter_rows(self) -> Iterator[InvoiceRow]:
"""Iterate over CSV rows."""
# Handle BOM - try utf-8-sig first to handle BOM correctly # Handle BOM - try utf-8-sig first to handle BOM correctly
encodings = ['utf-8-sig', self.encoding, 'latin-1'] encodings = ['utf-8-sig', self.encoding, 'latin-1']
for enc in encodings: for enc in encodings:
try: try:
with open(self.csv_path, 'r', encoding=enc) as f: with open(csv_path, 'r', encoding=enc) as f:
reader = csv.DictReader(f) reader = csv.DictReader(f)
for row in reader: for row in reader:
parsed = self._parse_row(row) parsed = self._parse_row(row)
@@ -209,7 +220,27 @@ class CSVLoader:
except UnicodeDecodeError: except UnicodeDecodeError:
continue continue
raise ValueError(f"Could not read CSV file with any supported encoding") raise ValueError(f"Could not read CSV file {csv_path} with any supported encoding")
def load_all(self) -> list[InvoiceRow]:
"""Load all rows from CSV(s)."""
rows = []
for row in self.iter_rows():
rows.append(row)
return rows
def iter_rows(self) -> Iterator[InvoiceRow]:
"""Iterate over CSV rows from all CSV files."""
seen_doc_ids = set()
for csv_path in self.csv_paths:
if not csv_path.exists():
continue
for row in self._iter_single_csv(csv_path):
# Deduplicate by DocumentId
if row.DocumentId not in seen_doc_ids:
seen_doc_ids.add(row.DocumentId)
yield row
def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None: def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None:
""" """
@@ -300,7 +331,7 @@ class CSVLoader:
return issues return issues
def load_invoice_csv(csv_path: str | Path, pdf_dir: str | Path | None = None) -> list[InvoiceRow]: def load_invoice_csv(csv_path: str | Path | list[str | Path], pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
"""Convenience function to load invoice CSV.""" """Convenience function to load invoice CSV(s)."""
loader = CSVLoader(csv_path, pdf_dir) loader = CSVLoader(csv_path, pdf_dir)
return loader.load_all() return loader.load_all()
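A short sketch of the widened loader interface (single path, list, or glob), with the import path assumed; rows are yielded once per DocumentId across all matched files.

```python
from src.data import CSVLoader, load_invoice_csv  # assumed export path

rows = load_invoice_csv("data/export_*.csv")        # glob pattern
loader = CSVLoader(["data/a.csv", "data/b.csv"])    # explicit list also works
for row in loader.iter_rows():                      # each DocumentId appears once
    print(row.DocumentId)
```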
429
src/data/db.py Normal file
@@ -0,0 +1,429 @@
"""
Database utilities for autolabel workflow.
"""
import json
import psycopg2
from psycopg2.extras import execute_values
from typing import Set, Dict, Any, Optional
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from config import get_db_connection_string
class DocumentDB:
"""Database interface for document processing status."""
def __init__(self, connection_string: str = None):
self.connection_string = connection_string or get_db_connection_string()
self.conn = None
def connect(self):
"""Connect to database."""
if self.conn is None:
self.conn = psycopg2.connect(self.connection_string)
return self.conn
def close(self):
"""Close database connection."""
if self.conn:
self.conn.close()
self.conn = None
def __enter__(self):
self.connect()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def get_successful_doc_ids(self) -> Set[str]:
"""Get all document IDs that have been successfully processed."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("SELECT document_id FROM documents WHERE success = true")
return {row[0] for row in cursor.fetchall()}
def get_failed_doc_ids(self) -> Set[str]:
"""Get all document IDs that failed processing."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("SELECT document_id FROM documents WHERE success = false")
return {row[0] for row in cursor.fetchall()}
def check_document_status(self, doc_id: str) -> Optional[bool]:
"""
Check if a document exists and its success status.
Returns:
True if exists and success=true
False if exists and success=false
None if not exists
"""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute(
"SELECT success FROM documents WHERE document_id = %s",
(doc_id,)
)
row = cursor.fetchone()
if row is None:
return None
return row[0]
def check_documents_status_batch(self, doc_ids: list[str]) -> Dict[str, Optional[bool]]:
"""
Batch check document status for multiple IDs.
Returns:
Dict mapping doc_id to status:
True if exists and success=true
False if exists and success=false
(missing from dict if not exists)
"""
if not doc_ids:
return {}
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute(
"SELECT document_id, success FROM documents WHERE document_id = ANY(%s)",
(doc_ids,)
)
return {row[0]: row[1] for row in cursor.fetchall()}
def delete_document(self, doc_id: str):
"""Delete a document and its field results (for re-processing)."""
conn = self.connect()
with conn.cursor() as cursor:
# field_results will be cascade deleted
cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))
conn.commit()
def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
"""Get a single document with its field results."""
conn = self.connect()
with conn.cursor() as cursor:
# Get document
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors
FROM documents WHERE document_id = %s
""", (doc_id,))
row = cursor.fetchone()
if not row:
return None
doc = {
'document_id': row[0],
'pdf_path': row[1],
'pdf_type': row[2],
'success': row[3],
'total_pages': row[4],
'fields_matched': row[5],
'fields_total': row[6],
'annotations_generated': row[7],
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
'field_results': []
}
# Get field results
cursor.execute("""
SELECT field_name, csv_value, matched, score, matched_text,
candidate_used, bbox, page_no, context_keywords, error
FROM field_results WHERE document_id = %s
""", (doc_id,))
for fr in cursor.fetchall():
doc['field_results'].append({
'field_name': fr[0],
'csv_value': fr[1],
'matched': fr[2],
'score': fr[3],
'matched_text': fr[4],
'candidate_used': fr[5],
'bbox': fr[6] if isinstance(fr[6], list) else json.loads(fr[6]) if fr[6] else None,
'page_no': fr[7],
'context_keywords': fr[8] if isinstance(fr[8], list) else json.loads(fr[8] or '[]'),
'error': fr[9]
})
return doc
def get_all_documents_summary(self, success_only: bool = False, limit: int = None) -> list[Dict[str, Any]]:
"""Get summary of all documents (without field_results for efficiency)."""
conn = self.connect()
with conn.cursor() as cursor:
query = """
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total
FROM documents
"""
if success_only:
query += " WHERE success = true"
query += " ORDER BY timestamp DESC"
if limit:
query += f" LIMIT {limit}"
cursor.execute(query)
return [
{
'document_id': row[0],
'pdf_path': row[1],
'pdf_type': row[2],
'success': row[3],
'total_pages': row[4],
'fields_matched': row[5],
'fields_total': row[6]
}
for row in cursor.fetchall()
]
def get_field_stats(self) -> Dict[str, Dict[str, int]]:
"""Get match statistics per field."""
conn = self.connect()
with conn.cursor() as cursor:
cursor.execute("""
SELECT field_name,
COUNT(*) as total,
SUM(CASE WHEN matched THEN 1 ELSE 0 END) as matched
FROM field_results
GROUP BY field_name
ORDER BY field_name
""")
return {
row[0]: {'total': row[1], 'matched': row[2]}
for row in cursor.fetchall()
}
def get_failed_matches(self, field_name: str = None, limit: int = 100) -> list[Dict[str, Any]]:
"""Get field results that failed to match."""
conn = self.connect()
with conn.cursor() as cursor:
query = """
SELECT fr.document_id, fr.field_name, fr.csv_value, fr.error,
d.pdf_type
FROM field_results fr
JOIN documents d ON fr.document_id = d.document_id
WHERE fr.matched = false
"""
params = []
if field_name:
query += " AND fr.field_name = %s"
params.append(field_name)
query += f" LIMIT {limit}"
cursor.execute(query, params)
return [
{
'document_id': row[0],
'field_name': row[1],
'csv_value': row[2],
'error': row[3],
'pdf_type': row[4]
}
for row in cursor.fetchall()
]
def get_documents_batch(self, doc_ids: list[str]) -> Dict[str, Dict[str, Any]]:
"""
Get multiple documents with their field results in a single batch query.
This is much more efficient than calling get_document() in a loop.
Args:
doc_ids: List of document IDs to fetch
Returns:
Dict mapping doc_id to document data (with field_results)
"""
if not doc_ids:
return {}
conn = self.connect()
result: Dict[str, Dict[str, Any]] = {}
with conn.cursor() as cursor:
# Batch fetch all documents
cursor.execute("""
SELECT document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors
FROM documents WHERE document_id = ANY(%s)
""", (doc_ids,))
for row in cursor.fetchall():
result[row[0]] = {
'document_id': row[0],
'pdf_path': row[1],
'pdf_type': row[2],
'success': row[3],
'total_pages': row[4],
'fields_matched': row[5],
'fields_total': row[6],
'annotations_generated': row[7],
'processing_time_ms': row[8],
'timestamp': str(row[9]) if row[9] else None,
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
'field_results': []
}
if not result:
return {}
# Batch fetch all field results for these documents
cursor.execute("""
SELECT document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error
FROM field_results WHERE document_id = ANY(%s)
""", (list(result.keys()),))
for fr in cursor.fetchall():
doc_id = fr[0]
if doc_id in result:
result[doc_id]['field_results'].append({
'field_name': fr[1],
'csv_value': fr[2],
'matched': fr[3],
'score': fr[4],
'matched_text': fr[5],
'candidate_used': fr[6],
'bbox': fr[7] if isinstance(fr[7], list) else json.loads(fr[7]) if fr[7] else None,
'page_no': fr[8],
'context_keywords': fr[9] if isinstance(fr[9], list) else json.loads(fr[9] or '[]'),
'error': fr[10]
})
return result
def save_document(self, report: Dict[str, Any]):
"""Save or update a document and its field results using batch operations."""
conn = self.connect()
doc_id = report.get('document_id')
with conn.cursor() as cursor:
# Delete existing record if any (for update)
cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))
# Insert document
cursor.execute("""
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""", (
doc_id,
report.get('pdf_path'),
report.get('pdf_type'),
report.get('success'),
report.get('total_pages'),
report.get('fields_matched'),
report.get('fields_total'),
report.get('annotations_generated'),
report.get('processing_time_ms'),
report.get('timestamp'),
json.dumps(report.get('errors', []))
))
# Batch insert field results using execute_values
field_results = report.get('field_results', [])
if field_results:
field_values = [
(
doc_id,
field.get('field_name'),
field.get('csv_value'),
field.get('matched'),
field.get('score'),
field.get('matched_text'),
field.get('candidate_used'),
json.dumps(field.get('bbox')) if field.get('bbox') else None,
field.get('page_no'),
json.dumps(field.get('context_keywords', [])),
field.get('error')
)
for field in field_results
]
execute_values(cursor, """
INSERT INTO field_results
(document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error)
VALUES %s
""", field_values)
conn.commit()
def save_documents_batch(self, reports: list[Dict[str, Any]]):
"""Save multiple documents in a batch."""
if not reports:
return
conn = self.connect()
doc_ids = [r['document_id'] for r in reports]
with conn.cursor() as cursor:
# Delete existing records
cursor.execute(
"DELETE FROM documents WHERE document_id = ANY(%s)",
(doc_ids,)
)
# Batch insert documents
doc_values = [
(
r.get('document_id'),
r.get('pdf_path'),
r.get('pdf_type'),
r.get('success'),
r.get('total_pages'),
r.get('fields_matched'),
r.get('fields_total'),
r.get('annotations_generated'),
r.get('processing_time_ms'),
r.get('timestamp'),
json.dumps(r.get('errors', []))
)
for r in reports
]
execute_values(cursor, """
INSERT INTO documents
(document_id, pdf_path, pdf_type, success, total_pages,
fields_matched, fields_total, annotations_generated,
processing_time_ms, timestamp, errors)
VALUES %s
""", doc_values)
# Batch insert field results
field_values = []
for r in reports:
doc_id = r.get('document_id')
for field in r.get('field_results', []):
field_values.append((
doc_id,
field.get('field_name'),
field.get('csv_value'),
field.get('matched'),
field.get('score'),
field.get('matched_text'),
field.get('candidate_used'),
json.dumps(field.get('bbox')) if field.get('bbox') else None,
field.get('page_no'),
json.dumps(field.get('context_keywords', [])),
field.get('error')
))
if field_values:
execute_values(cursor, """
INSERT INTO field_results
(document_id, field_name, csv_value, matched, score,
matched_text, candidate_used, bbox, page_no, context_keywords, error)
VALUES %s
""", field_values)
conn.commit()
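Typical use of the helper above (src/data/db.py), sketched under the schema its queries imply; document IDs and the report list below are placeholders.

```python
from src.data.db import DocumentDB

with DocumentDB() as db:                           # uses get_db_connection_string()
    done = db.get_successful_doc_ids()             # skip already-processed documents
    status = db.check_documents_status_batch(["doc-001", "doc-002"])
    db.save_documents_batch(pending_reports)       # dicts shaped like AutoLabelReport.to_dict()
    stats = db.get_field_stats()                   # {'Amount': {'total': ..., 'matched': ...}, ...}
```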
@@ -72,7 +72,7 @@ class FieldExtractor:
"""Lazy-load OCR engine only when needed.""" """Lazy-load OCR engine only when needed."""
if self._ocr_engine is None: if self._ocr_engine is None:
from ..ocr import OCREngine from ..ocr import OCREngine
self._ocr_engine = OCREngine(lang=self.ocr_lang, use_gpu=self.use_gpu) self._ocr_engine = OCREngine(lang=self.ocr_lang)
return self._ocr_engine return self._ocr_engine
def extract_from_detection_with_pdf( def extract_from_detection_with_pdf(
@@ -290,31 +290,65 @@ class FieldExtractor:
def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]: def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize monetary amount.""" """Normalize monetary amount."""
# Remove currency and common suffixes # Try to extract amount using regex patterns
text = re.sub(r'[SEK|kr|:-]+', '', text, flags=re.IGNORECASE) # Pattern 1: Number with comma as decimal (Swedish format: 1 234,56)
text = text.replace(' ', '').replace('\xa0', '') # Pattern 2: Number with dot as decimal (1234.56)
# Pattern 3: Number followed by currency (275,60 kr or 275.60 SEK)
patterns = [
# Swedish format with space thousand separator: 1 234,56 or 1234,56
r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?',
# Simple decimal: 350.00 or 350,00
r'(\d+[,\.]\d{2})',
# Integer amount
r'(\d{2,})',
]
for pattern in patterns:
matches = re.findall(pattern, text, re.IGNORECASE)
if matches:
# Take the last match (usually the total amount)
amount_str = matches[-1]
# Clean up
amount_str = amount_str.replace(' ', '').replace('\xa0', '')
# Handle comma as decimal separator # Handle comma as decimal separator
if ',' in text and '.' not in text: if ',' in amount_str:
text = text.replace(',', '.') amount_str = amount_str.replace(',', '.')
# Try to parse as float
try: try:
amount = float(text) amount = float(amount_str)
if amount > 0:
return f"{amount:.2f}", True, None return f"{amount:.2f}", True, None
except ValueError: except ValueError:
continue
return None, False, f"Cannot parse amount: {text}" return None, False, f"Cannot parse amount: {text}"
def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]: def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
"""Normalize date.""" """
Normalize date from text that may contain surrounding text.
Handles various date formats found in Swedish invoices:
- 2025-08-29 (ISO format)
- 2025.08.29 (dot separator)
- 29/08/2025 (European format)
- 29.08.2025 (European with dots)
- 20250829 (compact format)
"""
from datetime import datetime from datetime import datetime
# Common date patterns # Common date patterns - order matters, most specific first
patterns = [ patterns = [
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m[1]}-{int(m[2]):02d}-{int(m[3]):02d}"), # ISO format: 2025-08-29
(r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"), (r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"), # Dot format: 2025.08.29 (common in Swedish)
(r'(\d{4})(\d{2})(\d{2})', lambda m: f"{m[1]}-{m[2]}-{m[3]}"), (r'(\d{4})\.(\d{1,2})\.(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
# European slash format: 29/08/2025
(r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
# European dot format: 29.08.2025
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
# Compact format: 20250829
(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', lambda m: f"{m.group(1)}-{m.group(2)}-{m.group(3)}"),
] ]
for pattern, formatter in patterns: for pattern, formatter in patterns:
@@ -323,7 +357,9 @@ class FieldExtractor:
try: try:
date_str = formatter(match) date_str = formatter(match)
# Validate date # Validate date
datetime.strptime(date_str, '%Y-%m-%d') parsed_date = datetime.strptime(date_str, '%Y-%m-%d')
# Sanity check: year should be reasonable (2000-2100)
if 2000 <= parsed_date.year <= 2100:
return date_str, True, None return date_str, True, None
except ValueError: except ValueError:
continue continue
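To make the new amount handling concrete, here is a standalone illustration of the same regex patterns applied to a typical Swedish OCR string; it mirrors the logic above rather than calling the extractor itself.

```python
import re

text = "Att betala: 1 234,56 kr"
patterns = [
    r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?',  # Swedish format with optional currency
    r'(\d+[,\.]\d{2})',                      # simple decimal
    r'(\d{2,})',                             # bare integer fallback
]
for pattern in patterns:
    matches = re.findall(pattern, text, re.IGNORECASE)
    if matches:
        amount = matches[-1].replace(' ', '').replace('\xa0', '').replace(',', '.')
        print(f"{float(amount):.2f}")        # -> 1234.56
        break
```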
@@ -4,9 +4,16 @@ Field Matching Module
Matches normalized field values to tokens extracted from documents. Matches normalized field values to tokens extracted from documents.
""" """
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Protocol from typing import Protocol
import re import re
from functools import cached_property
# Pre-compiled regex patterns (module-level for efficiency)
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
_WHITESPACE_PATTERN = re.compile(r'\s+')
_NON_DIGIT_PATTERN = re.compile(r'\D')
class TokenLike(Protocol): class TokenLike(Protocol):
@@ -16,6 +23,93 @@ class TokenLike(Protocol):
page_no: int page_no: int
class TokenIndex:
"""
Spatial index for tokens to enable fast nearby token lookup.
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
"""
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
"""
Build spatial index from tokens.
Args:
tokens: List of tokens to index
grid_size: Size of grid cells in pixels
"""
self.tokens = tokens
self.grid_size = grid_size
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
self._token_centers: dict[int, tuple[float, float]] = {}
self._token_text_lower: dict[int, str] = {}
# Build index
for i, token in enumerate(tokens):
# Cache center coordinates
center_x = (token.bbox[0] + token.bbox[2]) / 2
center_y = (token.bbox[1] + token.bbox[3]) / 2
self._token_centers[id(token)] = (center_x, center_y)
# Cache lowercased text
self._token_text_lower[id(token)] = token.text.lower()
# Add to grid cell
grid_x = int(center_x / grid_size)
grid_y = int(center_y / grid_size)
key = (grid_x, grid_y)
if key not in self._grid:
self._grid[key] = []
self._grid[key].append(token)
def get_center(self, token: TokenLike) -> tuple[float, float]:
"""Get cached center coordinates for token."""
return self._token_centers.get(id(token), (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
))
def get_text_lower(self, token: TokenLike) -> str:
"""Get cached lowercased text for token."""
return self._token_text_lower.get(id(token), token.text.lower())
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
"""
Find all tokens within radius of the given token.
Uses grid-based lookup for O(1) average case instead of O(n).
"""
center = self.get_center(token)
center_x, center_y = center
# Determine which grid cells to search
cells_to_check = int(radius / self.grid_size) + 1
grid_x = int(center_x / self.grid_size)
grid_y = int(center_y / self.grid_size)
nearby = []
radius_sq = radius * radius
# Check all nearby grid cells
for dx in range(-cells_to_check, cells_to_check + 1):
for dy in range(-cells_to_check, cells_to_check + 1):
key = (grid_x + dx, grid_y + dy)
if key not in self._grid:
continue
for other in self._grid[key]:
if other is token:
continue
other_center = self.get_center(other)
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
if dist_sq <= radius_sq:
nearby.append(other)
return nearby
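The grid hashing idea behind TokenIndex, reduced to a standalone snippet: bucket each point by cell, then scan only the cells that can contain hits for a radius query. This is an illustration of the technique, not the class above.

```python
from collections import defaultdict
from math import ceil

GRID = 100.0
buckets = defaultdict(list)
points = [(120.0, 80.0), (130.0, 90.0), (900.0, 900.0)]
for x, y in points:
    buckets[(int(x // GRID), int(y // GRID))].append((x, y))

def nearby(x, y, radius):
    cells = ceil(radius / GRID)
    cx, cy = int(x // GRID), int(y // GRID)
    hits = []
    for dx in range(-cells, cells + 1):
        for dy in range(-cells, cells + 1):
            for px, py in buckets.get((cx + dx, cy + dy), []):
                if (px - x) ** 2 + (py - y) ** 2 <= radius ** 2:
                    hits.append((px, py))
    return hits

print(nearby(125.0, 85.0, 50.0))  # the two nearby points, not the far one
```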
@dataclass @dataclass
class Match: class Match:
"""Represents a matched field in the document.""" """Represents a matched field in the document."""
@@ -57,18 +151,20 @@ class FieldMatcher:
def __init__( def __init__(
self, self,
context_radius: float = 100.0, # pixels context_radius: float = 200.0, # pixels - increased to handle label-value spacing in scanned PDFs
min_score_threshold: float = 0.5 min_score_threshold: float = 0.5
): ):
""" """
Initialize the matcher. Initialize the matcher.
Args: Args:
context_radius: Distance to search for context keywords context_radius: Distance to search for context keywords (default 200px to handle
typical label-value spacing in scanned invoices at 150 DPI)
min_score_threshold: Minimum score to consider a match valid min_score_threshold: Minimum score to consider a match valid
""" """
self.context_radius = context_radius self.context_radius = context_radius
self.min_score_threshold = min_score_threshold self.min_score_threshold = min_score_threshold
self._token_index: TokenIndex | None = None
def find_matches( def find_matches(
self, self,
@@ -92,6 +188,9 @@ class FieldMatcher:
matches = [] matches = []
page_tokens = [t for t in tokens if t.page_no == page_no] page_tokens = [t for t in tokens if t.page_no == page_no]
# Build spatial index for efficient nearby token lookup (O(n) -> O(1))
self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
for value in normalized_values: for value in normalized_values:
# Strategy 1: Exact token match # Strategy 1: Exact token match
exact_matches = self._find_exact_matches(page_tokens, value, field_name) exact_matches = self._find_exact_matches(page_tokens, value, field_name)
@@ -108,7 +207,7 @@ class FieldMatcher:
# Strategy 4: Substring match (for values embedded in longer text) # Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205" # e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'): if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount'):
substring_matches = self._find_substring_matches(page_tokens, value, field_name) substring_matches = self._find_substring_matches(page_tokens, value, field_name)
matches.extend(substring_matches) matches.extend(substring_matches)
@@ -124,6 +223,9 @@ class FieldMatcher:
matches = self._deduplicate_matches(matches) matches = self._deduplicate_matches(matches)
matches.sort(key=lambda m: m.score, reverse=True) matches.sort(key=lambda m: m.score, reverse=True)
# Clear token index to free memory
self._token_index = None
return [m for m in matches if m.score >= self.min_score_threshold] return [m for m in matches if m.score >= self.min_score_threshold]
def _find_exact_matches( def _find_exact_matches(
@@ -134,6 +236,8 @@ class FieldMatcher:
) -> list[Match]: ) -> list[Match]:
"""Find tokens that exactly match the value.""" """Find tokens that exactly match the value."""
matches = [] matches = []
value_lower = value.lower()
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro') else None
for token in tokens: for token in tokens:
token_text = token.text.strip() token_text = token.text.strip()
@@ -141,13 +245,12 @@ class FieldMatcher:
# Exact match # Exact match
if token_text == value: if token_text == value:
score = 1.0 score = 1.0
# Case-insensitive match # Case-insensitive match (use cached lowercase from index)
elif token_text.lower() == value.lower(): elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
score = 0.95 score = 0.95
# Digits-only match for numeric fields # Digits-only match for numeric fields
elif field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'): elif value_digits is not None:
token_digits = re.sub(r'\D', '', token_text) token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
value_digits = re.sub(r'\D', '', value)
if token_digits and token_digits == value_digits: if token_digits and token_digits == value_digits:
score = 0.9 score = 0.9
else: else:
@@ -181,7 +284,7 @@ class FieldMatcher:
) -> list[Match]: ) -> list[Match]:
"""Find value by concatenating adjacent tokens.""" """Find value by concatenating adjacent tokens."""
matches = [] matches = []
value_clean = re.sub(r'\s+', '', value) value_clean = _WHITESPACE_PATTERN.sub('', value)
# Sort tokens by position (top-to-bottom, left-to-right) # Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0])) sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
@@ -213,7 +316,7 @@ class FieldMatcher:
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3]) concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match # Check for match
concat_clean = re.sub(r'\s+', '', concat_text) concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
if concat_clean == value_clean: if concat_clean == value_clean:
context_keywords, context_boost = self._find_context_keywords( context_keywords, context_boost = self._find_context_keywords(
tokens, start_token, field_name tokens, start_token, field_name
@@ -252,7 +355,7 @@ class FieldMatcher:
matches = [] matches = []
# Supported fields for substring matching # Supported fields for substring matching
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro') supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount')
if field_name not in supported_fields: if field_name not in supported_fields:
return matches return matches
@@ -390,13 +493,12 @@ class FieldMatcher:
# Find all date-like tokens in the document # Find all date-like tokens in the document
date_candidates = [] date_candidates = []
date_pattern = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
for token in tokens: for token in tokens:
token_text = token.text.strip() token_text = token.text.strip()
# Search for date pattern in token # Search for date pattern in token (use pre-compiled pattern)
for match in date_pattern.finditer(token_text): for match in _DATE_PATTERN.finditer(token_text):
try: try:
found_date = datetime( found_date = datetime(
int(match.group(1)), int(match.group(1)),
@@ -491,10 +593,28 @@ class FieldMatcher:
target_token: TokenLike, target_token: TokenLike,
field_name: str field_name: str
) -> tuple[list[str], float]: ) -> tuple[list[str], float]:
"""Find context keywords near the target token.""" """
Find context keywords near the target token.
Uses spatial index for O(1) average lookup instead of O(n) scan.
"""
keywords = CONTEXT_KEYWORDS.get(field_name, []) keywords = CONTEXT_KEYWORDS.get(field_name, [])
if not keywords:
return [], 0.0
found_keywords = [] found_keywords = []
# Use spatial index for efficient nearby token lookup
if self._token_index:
nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
for token in nearby_tokens:
# Use cached lowercase text
token_lower = self._token_index.get_text_lower(token)
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
else:
# Fallback to O(n) scan if no index available
target_center = ( target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2, (target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2 (target_token.bbox[1] + target_token.bbox[3]) / 2
@@ -509,7 +629,6 @@ class FieldMatcher:
(token.bbox[1] + token.bbox[3]) / 2 (token.bbox[1] + token.bbox[3]) / 2
) )
# Calculate distance
distance = ( distance = (
(target_center[0] - token_center[0]) ** 2 + (target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2 (target_center[1] - token_center[1]) ** 2
@@ -522,7 +641,8 @@ class FieldMatcher:
found_keywords.append(keyword) found_keywords.append(keyword)
# Calculate boost based on keywords found # Calculate boost based on keywords found
boost = min(0.15, len(found_keywords) * 0.05) # Increased boost to better differentiate matches with/without context
boost = min(0.25, len(found_keywords) * 0.10)
return found_keywords, boost return found_keywords, boost
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool: def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
@@ -548,23 +668,62 @@ class FieldMatcher:
return None return None
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]: def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
"""Remove duplicate matches based on bbox overlap.""" """
Remove duplicate matches based on bbox overlap.
Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
"""
if not matches: if not matches:
return [] return []
# Sort by score descending # Sort by: 1) score descending, 2) prefer matches with context keywords,
matches.sort(key=lambda m: m.score, reverse=True) # 3) prefer upper positions (smaller y) for same-score matches
# This helps select the "main" occurrence in invoice body rather than footer
matches.sort(key=lambda m: (
-m.score,
-len(m.context_keywords), # More keywords = better
m.bbox[1] # Smaller y (upper position) = better
))
# Use spatial grid for efficient overlap checking
# Grid cell size based on typical bbox size
grid_size = 50.0 # pixels
grid: dict[tuple[int, int], list[Match]] = {}
unique = [] unique = []
for match in matches: for match in matches:
bbox = match.bbox
# Calculate grid cells this bbox touches
min_gx = int(bbox[0] / grid_size)
min_gy = int(bbox[1] / grid_size)
max_gx = int(bbox[2] / grid_size)
max_gy = int(bbox[3] / grid_size)
# Check for overlap only with matches in nearby grid cells
is_duplicate = False is_duplicate = False
for existing in unique: cells_to_check = set()
if self._bbox_overlap(match.bbox, existing.bbox) > 0.7: for gx in range(min_gx - 1, max_gx + 2):
for gy in range(min_gy - 1, max_gy + 2):
cells_to_check.add((gx, gy))
for cell in cells_to_check:
if cell in grid:
for existing in grid[cell]:
if self._bbox_overlap(bbox, existing.bbox) > 0.7:
is_duplicate = True is_duplicate = True
break break
if is_duplicate:
break
if not is_duplicate: if not is_duplicate:
unique.append(match) unique.append(match)
# Add to all grid cells this bbox touches
for gx in range(min_gx, max_gx + 1):
for gy in range(min_gy, max_gy + 1):
key = (gx, gy)
if key not in grid:
grid[key] = []
grid[key].append(match)
return unique return unique
@@ -582,9 +741,9 @@ class FieldMatcher:
         if x2 <= x1 or y2 <= y1:
             return 0.0
-        intersection = (x2 - x1) * (y2 - y1)
-        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
-        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+        intersection = float(x2 - x1) * float(y2 - y1)
+        area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
+        area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
         union = area1 + area2 - intersection
         return intersection / union if union > 0 else 0.0
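A worked example replicating the IoU formula above: two 10x10 boxes overlapping in a 5x5 region give 25 / (100 + 100 - 25) ≈ 0.143, well below the 0.7 threshold used for deduplication.

```python
def bbox_iou(b1, b2):
    x1, y1 = max(b1[0], b2[0]), max(b1[1], b2[1])
    x2, y2 = min(b1[2], b2[2]), min(b1[3], b2[3])
    if x2 <= x1 or y2 <= y1:
        return 0.0
    inter = float(x2 - x1) * float(y2 - y1)
    area1 = float(b1[2] - b1[0]) * float(b1[3] - b1[1])
    area2 = float(b2[2] - b2[0]) * float(b2[3] - b2[1])
    return inter / (area1 + area2 - inter)

print(bbox_iou((0, 0, 10, 10), (5, 5, 15, 15)))  # 0.1428...
```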
@@ -173,12 +173,29 @@ class FieldNormalizer:
# Integer if no decimals # Integer if no decimals
if num == int(num): if num == int(num):
variants.append(str(int(num))) int_val = int(num)
variants.append(f"{int(num)},00") variants.append(str(int_val))
variants.append(f"{int(num)}.00") variants.append(f"{int_val},00")
variants.append(f"{int_val}.00")
# European format with dot as thousand separator (e.g., 20.485,00)
if int_val >= 1000:
# Format: XX.XXX,XX
formatted = f"{int_val:,}".replace(',', '.')
variants.append(formatted) # 20.485
variants.append(f"{formatted},00") # 20.485,00
else: else:
variants.append(f"{num:.2f}") variants.append(f"{num:.2f}")
variants.append(f"{num:.2f}".replace('.', ',')) variants.append(f"{num:.2f}".replace('.', ','))
# European format with dot as thousand separator
if num >= 1000:
# Split integer and decimal parts
int_part = int(num)
dec_part = num - int_part
formatted_int = f"{int_part:,}".replace(',', '.')
variants.append(f"{formatted_int},{round(dec_part * 100):02d}")  # e.g. 20.485,60
except ValueError: except ValueError:
pass pass
@@ -247,9 +264,35 @@ class FieldNormalizer:
iso = parsed_date.strftime('%Y-%m-%d') iso = parsed_date.strftime('%Y-%m-%d')
eu_slash = parsed_date.strftime('%d/%m/%Y') eu_slash = parsed_date.strftime('%d/%m/%Y')
eu_dot = parsed_date.strftime('%d.%m.%Y') eu_dot = parsed_date.strftime('%d.%m.%Y')
compact = parsed_date.strftime('%Y%m%d') compact = parsed_date.strftime('%Y%m%d') # YYYYMMDD
compact_short = parsed_date.strftime('%y%m%d') # YYMMDD (e.g., 260108)
variants.extend([iso, eu_slash, eu_dot, compact]) # Short year with dot separator (e.g., 02.01.26)
eu_dot_short = parsed_date.strftime('%d.%m.%y')
# Spaced formats (e.g., "2026 01 12", "26 01 12")
spaced_full = parsed_date.strftime('%Y %m %d')
spaced_short = parsed_date.strftime('%y %m %d')
# Swedish month name formats (e.g., "9 januari 2026", "9 jan 2026")
swedish_months_full = [
'januari', 'februari', 'mars', 'april', 'maj', 'juni',
'juli', 'augusti', 'september', 'oktober', 'november', 'december'
]
swedish_months_abbrev = [
'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
]
month_full = swedish_months_full[parsed_date.month - 1]
month_abbrev = swedish_months_abbrev[parsed_date.month - 1]
swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}"
swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}"
variants.extend([
iso, eu_slash, eu_dot, compact, compact_short,
eu_dot_short, spaced_full, spaced_short,
swedish_format_full, swedish_format_abbrev
])
return list(set(v for v in variants if v)) return list(set(v for v in variants if v))
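For one concrete date, the new variants come out as below; the strftime codes are standard and the Swedish month names are taken from the lists in this hunk.

```python
from datetime import datetime

d = datetime(2026, 1, 9)
print(d.strftime('%Y-%m-%d'))       # 2026-01-09  (ISO)
print(d.strftime('%d.%m.%Y'))       # 09.01.2026  (European dot)
print(d.strftime('%y%m%d'))         # 260109      (compact short year)
print(d.strftime('%y %m %d'))       # 26 01 09    (spaced short)
print(f"{d.day} januari {d.year}")  # 9 januari 2026
print(f"{d.day} jan {d.year}")      # 9 jan 2026
```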
@@ -1,3 +1,3 @@
-from .paddle_ocr import OCREngine, extract_ocr_tokens
+from .paddle_ocr import OCREngine, OCRResult, OCRToken, extract_ocr_tokens
-__all__ = ['OCREngine', 'extract_ocr_tokens']
+__all__ = ['OCREngine', 'OCRResult', 'OCRToken', 'extract_ocr_tokens']
@@ -4,11 +4,18 @@ OCR Extraction Module using PaddleOCR
Extracts text tokens with bounding boxes from scanned PDFs. Extracts text tokens with bounding boxes from scanned PDFs.
""" """
import os
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Generator from typing import Generator
import numpy as np import numpy as np
# Suppress PaddlePaddle reinitialization warnings
os.environ.setdefault('GLOG_minloglevel', '2')
warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
warnings.filterwarnings('ignore', message='.*reinitialization.*')
@dataclass @dataclass
class OCRToken: class OCRToken:
@@ -39,13 +46,19 @@ class OCRToken:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2) return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
@dataclass
class OCRResult:
"""Result from OCR extraction including tokens and preprocessed image."""
tokens: list[OCRToken]
output_img: np.ndarray | None = None # Preprocessed image from PaddleOCR
class OCREngine: class OCREngine:
"""PaddleOCR wrapper for text extraction.""" """PaddleOCR wrapper for text extraction."""
def __init__( def __init__(
self, self,
lang: str = "en", lang: str = "en",
use_gpu: bool = True, # Default to GPU for better performance
det_model_dir: str | None = None, det_model_dir: str | None = None,
rec_model_dir: str | None = None rec_model_dir: str | None = None
): ):
@@ -54,17 +67,21 @@ class OCREngine:
Args: Args:
lang: Language code ('en', 'sv', 'ch', etc.) lang: Language code ('en', 'sv', 'ch', etc.)
use_gpu: Whether to use GPU acceleration (default: True)
det_model_dir: Custom detection model directory det_model_dir: Custom detection model directory
rec_model_dir: Custom recognition model directory rec_model_dir: Custom recognition model directory
Note:
PaddleOCR 3.x automatically uses GPU if available via PaddlePaddle.
Use `paddle.set_device('gpu')` before initialization to force GPU.
""" """
# Suppress warnings during import and initialization
with warnings.catch_warnings():
warnings.filterwarnings('ignore')
from paddleocr import PaddleOCR from paddleocr import PaddleOCR
# PaddleOCR init with GPU support # PaddleOCR 3.x init (use_gpu removed, device controlled by paddle.set_device)
init_params = { init_params = {
'lang': lang, 'lang': lang,
'use_gpu': use_gpu,
'show_log': False, # Reduce log noise
} }
if det_model_dir: if det_model_dir:
init_params['text_detection_model_dir'] = det_model_dir init_params['text_detection_model_dir'] = det_model_dir
@@ -72,12 +89,13 @@ class OCREngine:
init_params['text_recognition_model_dir'] = rec_model_dir init_params['text_recognition_model_dir'] = rec_model_dir
self.ocr = PaddleOCR(**init_params) self.ocr = PaddleOCR(**init_params)
self.use_gpu = use_gpu
def extract_from_image( def extract_from_image(
self, self,
image: str | Path | np.ndarray, image: str | Path | np.ndarray,
page_no: int = 0 page_no: int = 0,
max_size: int = 2000,
scale_to_pdf_points: float | None = None
) -> list[OCRToken]: ) -> list[OCRToken]:
""" """
Extract text tokens from an image. Extract text tokens from an image.
@@ -85,17 +103,73 @@ class OCREngine:
Args: Args:
image: Image path or numpy array image: Image path or numpy array
page_no: Page number for reference page_no: Page number for reference
max_size: Maximum image dimension. Larger images will be scaled down
to avoid OCR issues with PaddleOCR on large images.
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
to convert from pixel to PDF point coordinates.
Use (72 / dpi) for images rendered at a specific DPI.
Returns: Returns:
List of OCRToken objects List of OCRToken objects with bbox in pixel coords (or PDF points if scale_to_pdf_points is set)
""" """
result = self.extract_with_image(image, page_no, max_size, scale_to_pdf_points)
return result.tokens
def extract_with_image(
self,
image: str | Path | np.ndarray,
page_no: int = 0,
max_size: int = 2000,
scale_to_pdf_points: float | None = None
) -> OCRResult:
"""
Extract text tokens from an image and return the preprocessed image.
PaddleOCR applies document preprocessing (unwarping, rotation, enhancement)
and returns coordinates relative to the preprocessed image (output_img).
This method returns both tokens and output_img so the caller can save
the correct image that matches the coordinates.
Args:
image: Image path or numpy array
page_no: Page number for reference
max_size: Maximum image dimension. Larger images will be scaled down
to avoid OCR issues with PaddleOCR on large images.
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
to convert from pixel to PDF point coordinates.
Use (72 / dpi) for images rendered at a specific DPI.
Returns:
OCRResult with tokens and output_img (preprocessed image from PaddleOCR)
"""
from PIL import Image as PILImage
# Load image if path
if isinstance(image, (str, Path)): if isinstance(image, (str, Path)):
image = str(image) img = PILImage.open(str(image))
img_array = np.array(img)
else:
img_array = image
# Check if image needs scaling for OCR
h, w = img_array.shape[:2]
ocr_scale_factor = 1.0
if max(h, w) > max_size:
ocr_scale_factor = max_size / max(h, w)
new_w = int(w * ocr_scale_factor)
new_h = int(h * ocr_scale_factor)
# Resize image for OCR
img = PILImage.fromarray(img_array)
img = img.resize((new_w, new_h), PILImage.Resampling.LANCZOS)
img_array = np.array(img)
# PaddleOCR 3.x uses predict() method instead of ocr() # PaddleOCR 3.x uses predict() method instead of ocr()
result = self.ocr.predict(image) result = self.ocr.predict(img_array)
tokens = [] tokens = []
output_img = None
if result: if result:
for item in result: for item in result:
# PaddleOCR 3.x returns list of dicts with 'rec_texts', 'rec_scores', 'dt_polys' # PaddleOCR 3.x returns list of dicts with 'rec_texts', 'rec_scores', 'dt_polys'
@@ -104,16 +178,30 @@ class OCREngine:
rec_scores = item.get('rec_scores', []) rec_scores = item.get('rec_scores', [])
dt_polys = item.get('dt_polys', []) dt_polys = item.get('dt_polys', [])
for i, (text, score, poly) in enumerate(zip(rec_texts, rec_scores, dt_polys)): # Get output_img from doc_preprocessor_res
# This is the preprocessed image that coordinates are relative to
doc_preproc = item.get('doc_preprocessor_res', {})
if isinstance(doc_preproc, dict):
output_img = doc_preproc.get('output_img')
# Coordinates are relative to output_img (preprocessed image)
# No rotation compensation needed - just use coordinates directly
for text, score, poly in zip(rec_texts, rec_scores, dt_polys):
# poly is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] # poly is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
x_coords = [p[0] for p in poly] x_coords = [float(p[0]) for p in poly]
y_coords = [p[1] for p in poly] y_coords = [float(p[1]) for p in poly]
# Apply PDF points scale if requested
if scale_to_pdf_points is not None:
final_scale = scale_to_pdf_points
else:
final_scale = 1.0
bbox = ( bbox = (
min(x_coords), min(x_coords) * final_scale,
min(y_coords), min(y_coords) * final_scale,
max(x_coords), max(x_coords) * final_scale,
max(y_coords) max(y_coords) * final_scale
) )
tokens.append(OCRToken( tokens.append(OCRToken(
@@ -129,11 +217,17 @@ class OCREngine:
x_coords = [p[0] for p in bbox_points] x_coords = [p[0] for p in bbox_points]
y_coords = [p[1] for p in bbox_points] y_coords = [p[1] for p in bbox_points]
# Apply PDF points scale if requested
if scale_to_pdf_points is not None:
final_scale = scale_to_pdf_points
else:
final_scale = 1.0
bbox = ( bbox = (
min(x_coords), min(x_coords) * final_scale,
min(y_coords), min(y_coords) * final_scale,
max(x_coords), max(x_coords) * final_scale,
max(y_coords) max(y_coords) * final_scale
) )
tokens.append(OCRToken( tokens.append(OCRToken(
@@ -143,7 +237,11 @@ class OCREngine:
page_no=page_no page_no=page_no
)) ))
return tokens # If no output_img was found, use the original input array
if output_img is None:
output_img = img_array
return OCRResult(tokens=tokens, output_img=output_img)
def extract_from_pdf( def extract_from_pdf(
self, self,
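A sketch of how the autolabel workers call this engine: render a page at 150 DPI, then ask for bboxes scaled back to PDF points with 72/dpi. The image path is a placeholder and the token fields follow the TokenLike protocol used by the matcher.

```python
from src.ocr import OCREngine, OCRResult

engine = OCREngine(lang="en")
result: OCRResult = engine.extract_with_image(
    "page_000.png",                 # placeholder path to a rendered page
    page_no=0,
    scale_to_pdf_points=72 / 150,   # pixel coords -> PDF points for a 150-DPI render
)
for token in result.tokens:
    print(token.text, token.bbox)
# result.output_img is the preprocessed image the coordinates refer to
```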
@@ -1,5 +1,12 @@
 from .detector import is_text_pdf, get_pdf_type
 from .renderer import render_pdf_to_images
-from .extractor import extract_text_tokens
+from .extractor import extract_text_tokens, PDFDocument, Token
-__all__ = ['is_text_pdf', 'get_pdf_type', 'render_pdf_to_images', 'extract_text_tokens']
+__all__ = [
+    'is_text_pdf',
+    'get_pdf_type',
+    'render_pdf_to_images',
+    'extract_text_tokens',
+    'PDFDocument',
+    'Token',
+]
@@ -6,7 +6,7 @@ Extracts text tokens with bounding boxes from text-layer PDFs.
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Generator from typing import Generator, Optional
import fitz # PyMuPDF import fitz # PyMuPDF
@@ -46,6 +46,134 @@ class Token:
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2) return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
class PDFDocument:
"""
Context manager for efficient PDF document handling.
Caches the open document handle to avoid repeated open/close cycles.
Use this when you need to perform multiple operations on the same PDF.
"""
def __init__(self, pdf_path: str | Path):
self.pdf_path = Path(pdf_path)
self._doc: Optional[fitz.Document] = None
self._dimensions_cache: dict[int, tuple[float, float]] = {}
def __enter__(self) -> 'PDFDocument':
self._doc = fitz.open(self.pdf_path)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self._doc:
self._doc.close()
self._doc = None
@property
def doc(self) -> fitz.Document:
if self._doc is None:
raise RuntimeError("PDFDocument must be used within a context manager")
return self._doc
@property
def page_count(self) -> int:
return len(self.doc)
def is_text_pdf(self, min_chars: int = 30) -> bool:
"""Check if PDF has extractable text layer."""
if self.page_count == 0:
return False
first_page = self.doc[0]
text = first_page.get_text()
return len(text.strip()) > min_chars
def get_page_dimensions(self, page_no: int = 0) -> tuple[float, float]:
"""Get page dimensions in points (cached)."""
if page_no not in self._dimensions_cache:
page = self.doc[page_no]
rect = page.rect
self._dimensions_cache[page_no] = (rect.width, rect.height)
return self._dimensions_cache[page_no]
def get_render_dimensions(self, page_no: int = 0, dpi: int = 300) -> tuple[int, int]:
"""Get rendered image dimensions in pixels."""
width, height = self.get_page_dimensions(page_no)
zoom = dpi / 72
return int(width * zoom), int(height * zoom)
def extract_text_tokens(self, page_no: int) -> Generator[Token, None, None]:
"""Extract text tokens from a specific page."""
page = self.doc[page_no]
text_dict = page.get_text("dict")
tokens_found = False
for block in text_dict.get("blocks", []):
if block.get("type") != 0:
continue
for line in block.get("lines", []):
for span in line.get("spans", []):
text = span.get("text", "").strip()
if not text:
continue
bbox = span.get("bbox")
if bbox and all(abs(b) < 1e9 for b in bbox):
tokens_found = True
yield Token(
text=text,
bbox=tuple(bbox),
page_no=page_no
)
# Fallback: if dict mode failed, use words mode
if not tokens_found:
words = page.get_text("words")
for word_info in words:
x0, y0, x1, y1, text, *_ = word_info
text = text.strip()
if text:
yield Token(
text=text,
bbox=(x0, y0, x1, y1),
page_no=page_no
)
def render_page(self, page_no: int, output_path: Path, dpi: int = 300) -> Path:
"""Render a page to an image file."""
zoom = dpi / 72
matrix = fitz.Matrix(zoom, zoom)
page = self.doc[page_no]
pix = page.get_pixmap(matrix=matrix)
output_path.parent.mkdir(parents=True, exist_ok=True)
pix.save(str(output_path))
return output_path
def render_all_pages(
self,
output_dir: Path,
dpi: int = 300
) -> Generator[tuple[int, Path], None, None]:
"""Render all pages to images."""
output_dir.mkdir(parents=True, exist_ok=True)
pdf_name = self.pdf_path.stem
zoom = dpi / 72
matrix = fitz.Matrix(zoom, zoom)
for page_no in range(self.page_count):
page = self.doc[page_no]
pix = page.get_pixmap(matrix=matrix)
image_path = output_dir / f"{pdf_name}_page_{page_no:03d}.png"
pix.save(str(image_path))
yield page_no, image_path
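A usage sketch of the cached-document workflow this class enables; paths are placeholders.

```python
from pathlib import Path
from src.pdf import PDFDocument

with PDFDocument("invoice.pdf") as pdf_doc:
    if pdf_doc.is_text_pdf():
        tokens = list(pdf_doc.extract_text_tokens(page_no=0))
    for page_no, image_path in pdf_doc.render_all_pages(Path("out/images"), dpi=150):
        width, height = pdf_doc.get_render_dimensions(page_no, dpi=150)
        print(page_no, image_path, width, height)
```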
def extract_text_tokens( def extract_text_tokens(
pdf_path: str | Path, pdf_path: str | Path,
page_no: int | None = None page_no: int | None = None
@@ -70,6 +198,7 @@ def extract_text_tokens(
# Get text with position info using "dict" mode # Get text with position info using "dict" mode
text_dict = page.get_text("dict") text_dict = page.get_text("dict")
tokens_found = False
for block in text_dict.get("blocks", []): for block in text_dict.get("blocks", []):
if block.get("type") != 0: # Skip non-text blocks if block.get("type") != 0: # Skip non-text blocks
continue continue
@@ -81,13 +210,28 @@ def extract_text_tokens(
continue continue
bbox = span.get("bbox") bbox = span.get("bbox")
if bbox: # Check for corrupted bbox (overflow values)
if bbox and all(abs(b) < 1e9 for b in bbox):
tokens_found = True
yield Token( yield Token(
text=text, text=text,
bbox=tuple(bbox), bbox=tuple(bbox),
page_no=pg_no page_no=pg_no
) )
# Fallback: if dict mode failed, use words mode
if not tokens_found:
words = page.get_text("words")
for word_info in words:
x0, y0, x1, y1, text, *_ = word_info
text = text.strip()
if text:
yield Token(
text=text,
bbox=(x0, y0, x1, y1),
page_no=pg_no
)
doc.close() doc.close()
@@ -0,0 +1,22 @@
"""
Processing module for multi-pool parallel processing.
This module provides a robust dual-pool architecture for processing
documents with both CPU-bound and GPU-bound tasks.
"""
from src.processing.worker_pool import WorkerPool, TaskResult
from src.processing.cpu_pool import CPUWorkerPool
from src.processing.gpu_pool import GPUWorkerPool
from src.processing.task_dispatcher import TaskDispatcher, TaskType
from src.processing.dual_pool_coordinator import DualPoolCoordinator
__all__ = [
"WorkerPool",
"TaskResult",
"CPUWorkerPool",
"GPUWorkerPool",
"TaskDispatcher",
"TaskType",
"DualPoolCoordinator",
]
@@ -0,0 +1,351 @@
"""
Task functions for autolabel processing in multi-pool architecture.
Provides CPU and GPU task functions that can be called from worker pools.
"""
from __future__ import annotations
import os
import time
import warnings
from pathlib import Path
from typing import Any, Dict, Optional
# Global OCR instance (initialized once per GPU worker process)
_ocr_engine: Optional[Any] = None
def init_cpu_worker() -> None:
"""
Initialize CPU worker process.
Disables GPU access and suppresses unnecessary warnings.
"""
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
def init_gpu_worker(gpu_id: int = 0, gpu_mem: int = 4000) -> None:
"""
Initialize GPU worker process with PaddleOCR.
Args:
gpu_id: GPU device ID.
gpu_mem: Maximum GPU memory in MB.
"""
global _ocr_engine
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
os.environ["GLOG_minloglevel"] = "2"
# Suppress PaddleX warnings
warnings.filterwarnings("ignore", message=".*PDX has already been initialized.*")
warnings.filterwarnings("ignore", message=".*reinitialization.*")
# Lazy initialization - OCR will be created on first use
_ocr_engine = None
def _get_ocr_engine():
"""Get or create OCR engine for current GPU worker."""
global _ocr_engine
if _ocr_engine is None:
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
from src.ocr import OCREngine
_ocr_engine = OCREngine()
return _ocr_engine
def _save_output_img(output_img, image_path: Path) -> None:
"""Save OCR preprocessed image to replace rendered image."""
from PIL import Image as PILImage
if output_img is not None:
img = PILImage.fromarray(output_img)
img.save(str(image_path))
def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process a text PDF (CPU task - no OCR needed).
Args:
task_data: Dictionary with keys:
- row_dict: Document fields from CSV
- pdf_path: Path to PDF file
- output_dir: Output directory
- dpi: Rendering DPI
- min_confidence: Minimum match confidence
Returns:
Result dictionary with success status, annotations, and report.
"""
from src.data import AutoLabelReport, FieldMatchResult
from src.pdf import PDFDocument
from src.matcher import FieldMatcher
from src.normalize import normalize_field
from src.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
row_dict = task_data["row_dict"]
pdf_path = Path(task_data["pdf_path"])
output_dir = Path(task_data["output_dir"])
dpi = task_data.get("dpi", 150)
min_confidence = task_data.get("min_confidence", 0.5)
start_time = time.time()
doc_id = row_dict["DocumentId"]
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
report.pdf_type = "text"
result = {
"doc_id": doc_id,
"success": False,
"pages": [],
"report": None,
"stats": {name: 0 for name in FIELD_CLASSES.keys()},
}
try:
with PDFDocument(pdf_path) as pdf_doc:
generator = AnnotationGenerator(min_confidence=min_confidence)
matcher = FieldMatcher()
page_annotations = []
matched_fields = set()
images_dir = output_dir / "temp" / doc_id / "images"
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
report.total_pages += 1
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
# Text extraction (no OCR)
tokens = list(pdf_doc.extract_text_tokens(page_no))
# Match fields
matches = {}
for field_name in FIELD_CLASSES.keys():
value = row_dict.get(field_name)
if not value:
continue
normalized = normalize_field(field_name, str(value))
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
if field_matches:
best = field_matches[0]
matches[field_name] = field_matches
matched_fields.add(field_name)
report.add_field_result(
FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=True,
score=best.score,
matched_text=best.matched_text,
candidate_used=best.value,
bbox=best.bbox,
page_no=page_no,
context_keywords=best.context_keywords,
)
)
# Generate annotations
annotations = generator.generate_from_matches(
matches, img_width, img_height, dpi=dpi
)
if annotations:
page_annotations.append(
{
"image_path": str(image_path),
"page_no": page_no,
"count": len(annotations),
}
)
report.annotations_generated += len(annotations)
for ann in annotations:
class_name = list(FIELD_CLASSES.keys())[ann.class_id]
result["stats"][class_name] += 1
# Record unmatched fields
for field_name in FIELD_CLASSES.keys():
value = row_dict.get(field_name)
if value and field_name not in matched_fields:
report.add_field_result(
FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=False,
page_no=-1,
)
)
if page_annotations:
result["pages"] = page_annotations
result["success"] = True
report.success = True
else:
report.errors.append("No annotations generated")
except Exception as e:
report.errors.append(str(e))
report.processing_time_ms = (time.time() - start_time) * 1000
result["report"] = report.to_dict()
return result
def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process a scanned PDF (GPU task - requires OCR).
Args:
task_data: Dictionary with keys:
- row_dict: Document fields from CSV
- pdf_path: Path to PDF file
- output_dir: Output directory
- dpi: Rendering DPI
- min_confidence: Minimum match confidence
Returns:
Result dictionary with success status, annotations, and report.
"""
from src.data import AutoLabelReport, FieldMatchResult
from src.pdf import PDFDocument
from src.matcher import FieldMatcher
from src.normalize import normalize_field
from src.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
row_dict = task_data["row_dict"]
pdf_path = Path(task_data["pdf_path"])
output_dir = Path(task_data["output_dir"])
dpi = task_data.get("dpi", 150)
min_confidence = task_data.get("min_confidence", 0.5)
start_time = time.time()
doc_id = row_dict["DocumentId"]
report = AutoLabelReport(document_id=doc_id)
report.pdf_path = str(pdf_path)
report.pdf_type = "scanned"
result = {
"doc_id": doc_id,
"success": False,
"pages": [],
"report": None,
"stats": {name: 0 for name in FIELD_CLASSES.keys()},
}
try:
# Get OCR engine from worker cache
ocr_engine = _get_ocr_engine()
with PDFDocument(pdf_path) as pdf_doc:
generator = AnnotationGenerator(min_confidence=min_confidence)
matcher = FieldMatcher()
page_annotations = []
matched_fields = set()
images_dir = output_dir / "temp" / doc_id / "images"
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
report.total_pages += 1
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
# OCR extraction
ocr_result = ocr_engine.extract_with_image(
str(image_path),
page_no,
scale_to_pdf_points=72 / dpi,
)
tokens = ocr_result.tokens
# Save preprocessed image
_save_output_img(ocr_result.output_img, image_path)
# Update dimensions to match OCR output
if ocr_result.output_img is not None:
img_height, img_width = ocr_result.output_img.shape[:2]
# Match fields
matches = {}
for field_name in FIELD_CLASSES.keys():
value = row_dict.get(field_name)
if not value:
continue
normalized = normalize_field(field_name, str(value))
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
if field_matches:
best = field_matches[0]
matches[field_name] = field_matches
matched_fields.add(field_name)
report.add_field_result(
FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=True,
score=best.score,
matched_text=best.matched_text,
candidate_used=best.value,
bbox=best.bbox,
page_no=page_no,
context_keywords=best.context_keywords,
)
)
# Generate annotations
annotations = generator.generate_from_matches(
matches, img_width, img_height, dpi=dpi
)
if annotations:
page_annotations.append(
{
"image_path": str(image_path),
"page_no": page_no,
"count": len(annotations),
}
)
report.annotations_generated += len(annotations)
for ann in annotations:
class_name = list(FIELD_CLASSES.keys())[ann.class_id]
result["stats"][class_name] += 1
# Record unmatched fields
for field_name in FIELD_CLASSES.keys():
value = row_dict.get(field_name)
if value and field_name not in matched_fields:
report.add_field_result(
FieldMatchResult(
field_name=field_name,
csv_value=str(value),
matched=False,
page_no=-1,
)
)
if page_annotations:
result["pages"] = page_annotations
result["success"] = True
report.success = True
else:
report.errors.append("No annotations generated")
except Exception as e:
report.errors.append(str(e))
report.processing_time_ms = (time.time() - start_time) * 1000
result["report"] = report.to_dict()
return result
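A minimal usage sketch for the two task functions above, assuming the repository packages (src.*) are importable; the paths, CSV values, and the field names InvoiceNumber/Amount are hypothetical placeholders, and with the spawn start method the pool block belongs under an `if __name__ == "__main__":` guard.

    # Illustrative sketch; values and paths are made up.
    import multiprocessing as mp
    from concurrent.futures import ProcessPoolExecutor

    task_data = {
        "row_dict": {"DocumentId": "doc-0001", "InvoiceNumber": "12345", "Amount": "1999.00"},
        "pdf_path": "data/pdfs/doc-0001.pdf",
        "output_dir": "data/autolabel",
        "dpi": 150,
        "min_confidence": 0.5,
    }

    # Text PDFs need no OCR and can run in-process or in a CPU pool.
    text_result = process_text_pdf(task_data)
    print(text_result["success"], text_result["stats"])

    # Scanned PDFs rely on the per-process OCR engine set up by init_gpu_worker,
    # so they are submitted to a spawn-based pool using it as the initializer.
    ctx = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=1, mp_context=ctx,
                             initializer=init_gpu_worker, initargs=(0, 4000)) as pool:
        print(pool.submit(process_scanned_pdf, task_data).result()["success"])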

View File

@@ -0,0 +1,71 @@
"""
CPU Worker Pool for text PDF processing.
This pool handles CPU-bound tasks like text extraction from native PDFs
that don't require OCR.
"""
from __future__ import annotations
import logging
import os
from typing import Callable, Optional
from src.processing.worker_pool import WorkerPool
logger = logging.getLogger(__name__)
# Global resources for CPU workers (initialized once per process)
_cpu_initialized: bool = False
def _init_cpu_worker() -> None:
"""
Initialize a CPU worker process.
Disables GPU access and sets up CPU-only environment.
"""
global _cpu_initialized
# Disable GPU access for CPU workers
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Set threading limits for better CPU utilization
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
_cpu_initialized = True
logger.debug(f"CPU worker initialized in process {os.getpid()}")
class CPUWorkerPool(WorkerPool):
"""
Worker pool for CPU-bound tasks.
Handles text PDF processing that doesn't require OCR.
Each worker is initialized with CUDA disabled to prevent
accidental GPU memory consumption.
Example:
with CPUWorkerPool(max_workers=4) as pool:
future = pool.submit(process_text_pdf, pdf_path)
result = future.result()
"""
def __init__(self, max_workers: int = 4) -> None:
"""
Initialize CPU worker pool.
Args:
max_workers: Number of CPU worker processes.
Defaults to 4 for balanced performance.
"""
super().__init__(max_workers=max_workers, use_gpu=False, gpu_id=-1)
def get_initializer(self) -> Optional[Callable[..., None]]:
"""Return the CPU worker initializer."""
return _init_cpu_worker
def get_init_args(self) -> tuple:
"""Return empty args for CPU initializer."""
return ()

View File

@@ -0,0 +1,339 @@
"""
Dual Pool Coordinator for managing CPU and GPU worker pools.
Coordinates task distribution between CPU and GPU pools, handles result
collection using as_completed(), and provides callbacks for progress tracking.
"""
from __future__ import annotations
import logging
import time
from concurrent.futures import Future, TimeoutError, as_completed
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional
from src.processing.cpu_pool import CPUWorkerPool
from src.processing.gpu_pool import GPUWorkerPool
from src.processing.task_dispatcher import Task, TaskDispatcher, TaskType
from src.processing.worker_pool import TaskResult
logger = logging.getLogger(__name__)
@dataclass
class BatchStats:
"""Statistics for a batch processing run."""
total: int = 0
cpu_submitted: int = 0
gpu_submitted: int = 0
successful: int = 0
failed: int = 0
cpu_time: float = 0.0
gpu_time: float = 0.0
errors: List[str] = field(default_factory=list)
@property
def success_rate(self) -> float:
"""Calculate success rate as percentage."""
if self.total == 0:
return 0.0
return (self.successful / self.total) * 100
class DualPoolCoordinator:
"""
Coordinates CPU and GPU worker pools for parallel document processing.
Uses separate ProcessPoolExecutor instances for CPU and GPU tasks,
with as_completed() for efficient result collection across both pools.
Key features:
- Automatic task classification (CPU vs GPU)
- Parallel submission to both pools
- Unified result collection with timeouts
- Progress callbacks for UI integration
- Proper resource cleanup
Example:
with DualPoolCoordinator(cpu_workers=4, gpu_workers=1) as coord:
results = coord.process_batch(
documents=docs,
cpu_task_fn=process_text_pdf,
gpu_task_fn=process_scanned_pdf,
on_result=lambda r: save_to_db(r),
)
"""
def __init__(
self,
cpu_workers: int = 4,
gpu_workers: int = 1,
gpu_id: int = 0,
task_timeout: float = 300.0,
) -> None:
"""
Initialize the dual pool coordinator.
Args:
cpu_workers: Number of CPU worker processes.
gpu_workers: Number of GPU worker processes (usually 1).
gpu_id: GPU device ID to use.
task_timeout: Timeout in seconds for individual tasks.
"""
self.cpu_workers = cpu_workers
self.gpu_workers = gpu_workers
self.gpu_id = gpu_id
self.task_timeout = task_timeout
self._cpu_pool: Optional[CPUWorkerPool] = None
self._gpu_pool: Optional[GPUWorkerPool] = None
self._dispatcher = TaskDispatcher()
self._started = False
def start(self) -> None:
"""Start both worker pools."""
if self._started:
raise RuntimeError("Coordinator already started")
logger.info(
f"Starting DualPoolCoordinator: "
f"{self.cpu_workers} CPU workers, {self.gpu_workers} GPU workers"
)
self._cpu_pool = CPUWorkerPool(max_workers=self.cpu_workers)
self._gpu_pool = GPUWorkerPool(
max_workers=self.gpu_workers,
gpu_id=self.gpu_id,
)
self._cpu_pool.start()
self._gpu_pool.start()
self._started = True
def shutdown(self, wait: bool = True) -> None:
"""Shutdown both worker pools."""
logger.info("Shutting down DualPoolCoordinator")
if self._cpu_pool is not None:
self._cpu_pool.shutdown(wait=wait)
self._cpu_pool = None
if self._gpu_pool is not None:
self._gpu_pool.shutdown(wait=wait)
self._gpu_pool = None
self._started = False
def __enter__(self) -> "DualPoolCoordinator":
"""Context manager entry."""
self.start()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Context manager exit."""
self.shutdown(wait=True)
def process_batch(
self,
documents: List[dict],
cpu_task_fn: Callable[[dict], Any],
gpu_task_fn: Callable[[dict], Any],
on_result: Optional[Callable[[TaskResult], None]] = None,
on_error: Optional[Callable[[str, Exception], None]] = None,
on_progress: Optional[Callable[[int, int], None]] = None,
id_field: str = "id",
) -> List[TaskResult]:
"""
Process a batch of documents using both CPU and GPU pools.
Documents are automatically classified and routed to the appropriate
pool. Results are collected as they complete.
Args:
documents: List of document info dicts to process.
cpu_task_fn: Function to process text PDFs (called in CPU pool).
gpu_task_fn: Function to process scanned PDFs (called in GPU pool).
on_result: Callback for each successful result.
on_error: Callback for each failed task (task_id, exception).
on_progress: Callback for progress updates (completed, total).
id_field: Field name to use as task ID.
Returns:
List of TaskResult objects for all tasks.
Raises:
RuntimeError: If coordinator is not started.
"""
if not self._started:
raise RuntimeError("Coordinator not started. Use context manager or call start().")
if not documents:
return []
stats = BatchStats(total=len(documents))
# Create and partition tasks
tasks = self._dispatcher.create_tasks(documents, id_field=id_field)
cpu_tasks, gpu_tasks = self._dispatcher.partition_tasks(tasks)
# Submit tasks to pools
futures_map: Dict[Future, Task] = {}
# Submit CPU tasks
cpu_start = time.time()
for task in cpu_tasks:
future = self._cpu_pool.submit(cpu_task_fn, task.data)
futures_map[future] = task
stats.cpu_submitted += 1
# Submit GPU tasks
gpu_start = time.time()
for task in gpu_tasks:
future = self._gpu_pool.submit(gpu_task_fn, task.data)
futures_map[future] = task
stats.gpu_submitted += 1
logger.info(
f"Submitted {stats.cpu_submitted} CPU tasks, {stats.gpu_submitted} GPU tasks"
)
# Collect results as they complete
results: List[TaskResult] = []
completed = 0
for future in as_completed(futures_map.keys(), timeout=self.task_timeout * len(documents)):
task = futures_map[future]
pool_type = "CPU" if task.task_type == TaskType.CPU else "GPU"
start_time = time.time()
try:
data = future.result(timeout=self.task_timeout)
processing_time = time.time() - start_time
result = TaskResult(
task_id=task.id,
success=True,
data=data,
pool_type=pool_type,
processing_time=processing_time,
)
stats.successful += 1
if pool_type == "CPU":
stats.cpu_time += processing_time
else:
stats.gpu_time += processing_time
if on_result is not None:
try:
on_result(result)
except Exception as e:
logger.warning(f"on_result callback failed: {e}")
except TimeoutError:
error_msg = f"Task timed out after {self.task_timeout}s"
logger.error(f"[{pool_type}] Task {task.id}: {error_msg}")
result = TaskResult(
task_id=task.id,
success=False,
data=None,
error=error_msg,
pool_type=pool_type,
)
stats.failed += 1
stats.errors.append(f"{task.id}: {error_msg}")
if on_error is not None:
try:
on_error(task.id, TimeoutError(error_msg))
except Exception as e:
logger.warning(f"on_error callback failed: {e}")
except Exception as e:
error_msg = str(e)
logger.error(f"[{pool_type}] Task {task.id} failed: {error_msg}")
result = TaskResult(
task_id=task.id,
success=False,
data=None,
error=error_msg,
pool_type=pool_type,
)
stats.failed += 1
stats.errors.append(f"{task.id}: {error_msg}")
if on_error is not None:
try:
on_error(task.id, e)
except Exception as callback_error:
logger.warning(f"on_error callback failed: {callback_error}")
results.append(result)
completed += 1
if on_progress is not None:
try:
on_progress(completed, stats.total)
except Exception as e:
logger.warning(f"on_progress callback failed: {e}")
# Log final stats
logger.info(
f"Batch complete: {stats.successful}/{stats.total} successful "
f"({stats.success_rate:.1f}%), {stats.failed} failed"
)
if stats.cpu_submitted > 0:
logger.info(f"CPU: {stats.cpu_submitted} tasks, {stats.cpu_time:.2f}s total")
if stats.gpu_submitted > 0:
logger.info(f"GPU: {stats.gpu_submitted} tasks, {stats.gpu_time:.2f}s total")
return results
def process_single(
self,
document: dict,
cpu_task_fn: Callable[[dict], Any],
gpu_task_fn: Callable[[dict], Any],
id_field: str = "id",
) -> TaskResult:
"""
Process a single document.
Convenience method for processing one document at a time.
Args:
document: Document info dict.
cpu_task_fn: Function for text PDF processing.
gpu_task_fn: Function for scanned PDF processing.
id_field: Field name for task ID.
Returns:
TaskResult for the document.
"""
results = self.process_batch(
documents=[document],
cpu_task_fn=cpu_task_fn,
gpu_task_fn=gpu_task_fn,
id_field=id_field,
)
return results[0] if results else TaskResult(
task_id=str(document.get(id_field, "unknown")),
success=False,
data=None,
error="No result returned",
)
@property
def is_running(self) -> bool:
"""Check if both pools are running."""
return (
self._started
and self._cpu_pool is not None
and self._gpu_pool is not None
and self._cpu_pool.is_running
and self._gpu_pool.is_running
)
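A fuller usage sketch than the docstring example, showing document dicts that carry both the classification hints read by TaskDispatcher (is_scanned, text_length) and the payload consumed by the task functions; the exact dict layout is an assumption pieced together from this module and the task module above.

    # Sketch only; document contents are hypothetical, and with the spawn start
    # method this should run under an `if __name__ == "__main__":` guard.
    docs = [
        {"id": "doc-0001", "is_scanned": False, "text_length": 2400, "page_count": 1,
         "row_dict": {"DocumentId": "doc-0001"}, "pdf_path": "data/pdfs/doc-0001.pdf",
         "output_dir": "data/autolabel"},
        {"id": "doc-0002", "is_scanned": True,
         "row_dict": {"DocumentId": "doc-0002"}, "pdf_path": "data/pdfs/doc-0002.pdf",
         "output_dir": "data/autolabel"},
    ]

    with DualPoolCoordinator(cpu_workers=4, gpu_workers=1, gpu_id=0) as coord:
        results = coord.process_batch(
            documents=docs,
            cpu_task_fn=process_text_pdf,
            gpu_task_fn=process_scanned_pdf,
            on_result=lambda r: print(f"done {r.task_id} via {r.pool_type}"),
            on_error=lambda task_id, exc: print(f"failed {task_id}: {exc}"),
            on_progress=lambda done, total: print(f"{done}/{total}"),
        )
    failed = [r for r in results if not r.success]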

110
src/processing/gpu_pool.py Normal file
View File

@@ -0,0 +1,110 @@
"""
GPU Worker Pool for OCR processing.
This pool handles GPU-bound tasks like PaddleOCR for scanned PDF processing.
"""
from __future__ import annotations
import logging
import os
from typing import Any, Callable, Optional
from src.processing.worker_pool import WorkerPool
logger = logging.getLogger(__name__)
# Global OCR instance for GPU workers (initialized once per process)
_ocr_instance: Optional[Any] = None
_gpu_initialized: bool = False
def _init_gpu_worker(gpu_id: int = 0) -> None:
"""
Initialize a GPU worker process with PaddleOCR.
Args:
gpu_id: GPU device ID to use.
"""
global _ocr_instance, _gpu_initialized
# Set GPU device before importing paddle
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
# Reduce logging noise
os.environ["GLOG_minloglevel"] = "2"
# Suppress PaddleX warnings
import warnings
warnings.filterwarnings("ignore", message=".*PDX has already been initialized.*")
warnings.filterwarnings("ignore", message=".*reinitialization.*")
try:
# Import PaddleOCR after setting environment
# PaddleOCR 3.x uses paddle.set_device() for GPU control, not use_gpu param
import paddle
paddle.set_device(f"gpu:{gpu_id}")
from paddleocr import PaddleOCR
# PaddleOCR 3.x init - minimal params, GPU controlled via paddle.set_device
_ocr_instance = PaddleOCR(lang="en")
_gpu_initialized = True
logger.info(f"GPU worker initialized on GPU {gpu_id} in process {os.getpid()}")
except Exception as e:
logger.error(f"Failed to initialize GPU worker: {e}")
raise
def get_ocr_instance() -> Any:
"""
Get the initialized OCR instance for the current worker.
Returns:
PaddleOCR instance.
Raises:
RuntimeError: If OCR is not initialized.
"""
global _ocr_instance
if _ocr_instance is None:
raise RuntimeError("OCR not initialized. This function must be called from a GPU worker.")
return _ocr_instance
class GPUWorkerPool(WorkerPool):
"""
Worker pool for GPU-bound OCR tasks.
Handles scanned PDF processing using PaddleOCR with GPU acceleration.
Typically limited to 1 worker to avoid GPU memory conflicts.
Example:
with GPUWorkerPool(max_workers=1, gpu_id=0) as pool:
future = pool.submit(process_scanned_pdf, pdf_path)
result = future.result()
"""
def __init__(
self,
max_workers: int = 1,
gpu_id: int = 0,
) -> None:
"""
Initialize GPU worker pool.
Args:
max_workers: Number of GPU worker processes.
Defaults to 1 to avoid GPU memory conflicts.
gpu_id: GPU device ID to use.
"""
super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)
def get_initializer(self) -> Optional[Callable[..., None]]:
"""Return the GPU worker initializer."""
return _init_gpu_worker
def get_init_args(self) -> tuple:
"""Return args for GPU initializer."""
return (self.gpu_id,)

View File

@@ -0,0 +1,174 @@
"""
Task Dispatcher for classifying and routing tasks to appropriate worker pools.
Determines whether a document should be processed by CPU (text PDF) or
GPU (scanned PDF requiring OCR) workers.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import Any, List, Tuple
logger = logging.getLogger(__name__)
class TaskType(Enum):
"""Task type classification."""
CPU = auto() # Text PDF - no OCR needed
GPU = auto() # Scanned PDF - requires OCR
@dataclass
class Task:
"""
Represents a processing task.
Attributes:
id: Unique task identifier.
task_type: Whether task needs CPU or GPU processing.
data: Task payload (document info, paths, etc.).
"""
id: str
task_type: TaskType
data: Any
class TaskDispatcher:
"""
Classifies and partitions tasks for CPU and GPU worker pools.
Uses PDF characteristics to determine if OCR is needed:
- Text PDFs with extractable text -> CPU
- Scanned PDFs / image-based PDFs -> GPU (OCR)
Example:
dispatcher = TaskDispatcher()
tasks = [Task(id="1", task_type=dispatcher.classify(doc), data=doc) for doc in docs]
cpu_tasks, gpu_tasks = dispatcher.partition_tasks(tasks)
"""
def __init__(
self,
text_char_threshold: int = 100,
ocr_ratio_threshold: float = 0.3,
) -> None:
"""
Initialize the task dispatcher.
Args:
text_char_threshold: Minimum characters to consider as text PDF.
ocr_ratio_threshold: If text/expected ratio below this, use OCR.
"""
self.text_char_threshold = text_char_threshold
self.ocr_ratio_threshold = ocr_ratio_threshold
def classify_by_pdf_info(
self,
has_text: bool,
text_length: int,
page_count: int = 1,
) -> TaskType:
"""
Classify task based on PDF text extraction info.
Args:
has_text: Whether PDF has extractable text layer.
text_length: Number of characters extracted.
page_count: Number of pages in PDF.
Returns:
TaskType.CPU for text PDFs, TaskType.GPU for scanned PDFs.
"""
if not has_text:
return TaskType.GPU
# Check if text density is reasonable
avg_chars_per_page = text_length / max(page_count, 1)
if avg_chars_per_page < self.text_char_threshold:
return TaskType.GPU
return TaskType.CPU
def classify_document(self, doc_info: dict) -> TaskType:
"""
Classify a document based on its metadata.
Args:
doc_info: Document information dict with keys like:
- 'is_scanned': bool (if known)
- 'text_length': int
- 'page_count': int
- 'pdf_path': str
Returns:
TaskType for the document.
"""
# If explicitly marked as scanned
if doc_info.get("is_scanned", False):
return TaskType.GPU
# If we have text extraction info
text_length = doc_info.get("text_length", 0)
page_count = doc_info.get("page_count", 1)
has_text = doc_info.get("has_text", text_length > 0)
return self.classify_by_pdf_info(
has_text=has_text,
text_length=text_length,
page_count=page_count,
)
def partition_tasks(
self,
tasks: List[Task],
) -> Tuple[List[Task], List[Task]]:
"""
Partition tasks into CPU and GPU groups.
Args:
tasks: List of Task objects with task_type set.
Returns:
Tuple of (cpu_tasks, gpu_tasks).
"""
cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
logger.info(
f"Task partition: {len(cpu_tasks)} CPU tasks, {len(gpu_tasks)} GPU tasks"
)
return cpu_tasks, gpu_tasks
def create_tasks(
self,
documents: List[dict],
id_field: str = "id",
) -> List[Task]:
"""
Create Task objects from document dicts.
Args:
documents: List of document info dicts.
id_field: Field name to use as task ID.
Returns:
List of Task objects with types classified.
"""
tasks = []
for doc in documents:
task_id = str(doc.get(id_field, id(doc)))
task_type = self.classify_document(doc)
tasks.append(Task(id=task_id, task_type=task_type, data=doc))
cpu_count = sum(1 for t in tasks if t.task_type == TaskType.CPU)
gpu_count = len(tasks) - cpu_count
logger.debug(f"Created {len(tasks)} tasks: {cpu_count} CPU, {gpu_count} GPU")
return tasks
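A short sketch of the classification rules in isolation; the document dicts are hypothetical.

    dispatcher = TaskDispatcher(text_char_threshold=100)

    # A healthy text layer on one page -> CPU
    print(dispatcher.classify_document({"id": "a", "text_length": 1200, "page_count": 1}))  # TaskType.CPU
    # Explicitly flagged as scanned, or a sparse text layer -> GPU (OCR)
    print(dispatcher.classify_document({"id": "b", "is_scanned": True}))                    # TaskType.GPU
    print(dispatcher.classify_document({"id": "c", "text_length": 40, "page_count": 2}))    # TaskType.GPU

    tasks = dispatcher.create_tasks([{"id": "a", "text_length": 1200}, {"id": "b", "is_scanned": True}])
    cpu_tasks, gpu_tasks = dispatcher.partition_tasks(tasks)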

View File

@@ -0,0 +1,182 @@
"""
Abstract base class for worker pools.
Provides a unified interface for CPU and GPU worker pools with proper
initialization, task submission, and resource cleanup.
"""
from __future__ import annotations
import logging
import multiprocessing as mp
from abc import ABC, abstractmethod
from concurrent.futures import Future, ProcessPoolExecutor
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
logger = logging.getLogger(__name__)
@dataclass
class TaskResult:
"""Container for task execution results."""
task_id: str
success: bool
data: Any
error: Optional[str] = None
processing_time: float = 0.0
pool_type: str = ""
extra: dict = field(default_factory=dict)
class WorkerPool(ABC):
"""
Abstract base class for worker pools.
Provides a common interface for ProcessPoolExecutor-based worker pools
with proper initialization using the 'spawn' start method for CUDA
compatibility.
Attributes:
max_workers: Maximum number of worker processes.
use_gpu: Whether this pool uses GPU resources.
gpu_id: GPU device ID (only relevant if use_gpu=True).
"""
def __init__(
self,
max_workers: int,
use_gpu: bool = False,
gpu_id: int = 0,
) -> None:
"""
Initialize the worker pool configuration.
Args:
max_workers: Maximum number of worker processes.
use_gpu: Whether this pool uses GPU resources.
gpu_id: GPU device ID for GPU pools.
"""
self.max_workers = max_workers
self.use_gpu = use_gpu
self.gpu_id = gpu_id
self._executor: Optional[ProcessPoolExecutor] = None
self._started = False
@property
def name(self) -> str:
"""Return the pool name for logging."""
return self.__class__.__name__
@abstractmethod
def get_initializer(self) -> Optional[Callable[..., None]]:
"""
Return the worker initialization function.
This function is called once per worker process when it starts.
Use it to load models, set environment variables, etc.
Returns:
Callable to initialize each worker, or None if no initialization needed.
"""
pass
@abstractmethod
def get_init_args(self) -> tuple:
"""
Return arguments for the initializer function.
Returns:
Tuple of arguments to pass to the initializer.
"""
pass
def start(self) -> None:
"""
Start the worker pool.
Creates a ProcessPoolExecutor with the 'spawn' start method
for CUDA compatibility.
Raises:
RuntimeError: If the pool is already started.
"""
if self._started:
raise RuntimeError(f"{self.name} is already started")
# Use 'spawn' for CUDA compatibility
ctx = mp.get_context("spawn")
initializer = self.get_initializer()
initargs = self.get_init_args()
logger.info(
f"Starting {self.name} with {self.max_workers} workers "
f"(GPU: {self.use_gpu}, GPU ID: {self.gpu_id})"
)
self._executor = ProcessPoolExecutor(
max_workers=self.max_workers,
mp_context=ctx,
initializer=initializer,
initargs=initargs if initializer else (),
)
self._started = True
def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Future:
"""
Submit a task to the worker pool.
Args:
fn: Function to execute.
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
Returns:
Future representing the pending result.
Raises:
RuntimeError: If the pool is not started.
"""
if not self._started or self._executor is None:
raise RuntimeError(f"{self.name} is not started. Call start() first.")
return self._executor.submit(fn, *args, **kwargs)
def shutdown(self, wait: bool = True, cancel_futures: bool = False) -> None:
"""
Shutdown the worker pool.
Args:
wait: If True, wait for all pending futures to complete.
cancel_futures: If True, cancel all pending futures.
"""
if self._executor is not None:
logger.info(f"Shutting down {self.name} (wait={wait})")
self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
self._executor = None
self._started = False
@property
def is_running(self) -> bool:
"""Check if the pool is currently running."""
return self._started and self._executor is not None
def __enter__(self) -> "WorkerPool":
"""Context manager entry - start the pool."""
self.start()
return self
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
"""Context manager exit - shutdown the pool."""
self.shutdown(wait=True)
def __repr__(self) -> str:
status = "running" if self.is_running else "stopped"
return (
f"{self.__class__.__name__}("
f"workers={self.max_workers}, "
f"gpu={self.use_gpu}, "
f"status={status})"
)
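For illustration, a minimal concrete subclass: the two abstract hooks are all a new pool has to provide. IOWorkerPool and _init_io_worker are hypothetical names, not part of the commit.

    def _init_io_worker(tag: str) -> None:
        # Hypothetical per-process setup; runs once in each spawned worker.
        import os
        print(f"[{tag}] worker ready in pid {os.getpid()}")

    class IOWorkerPool(WorkerPool):
        """Example pool for generic CPU-side tasks (illustrative only)."""

        def __init__(self, max_workers: int = 2) -> None:
            super().__init__(max_workers=max_workers, use_gpu=False)

        def get_initializer(self):
            return _init_io_worker

        def get_init_args(self) -> tuple:
            return ("io",)

    if __name__ == "__main__":  # required with the 'spawn' start method
        with IOWorkerPool(max_workers=2) as pool:
            print(pool.submit(sum, [1, 2, 3]).result())  # 6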

9
src/web/__init__.py Normal file
View File

@@ -0,0 +1,9 @@
"""
Web Application Module
Provides REST API and web interface for invoice field extraction.
"""
from .app import create_app
__all__ = ["create_app"]

616
src/web/app.py Normal file
View File

@@ -0,0 +1,616 @@
"""
FastAPI Application Factory
Creates and configures the FastAPI application.
"""
from __future__ import annotations
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
from .config import AppConfig, default_config
from .routes import create_api_router
from .services import InferenceService
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
logger = logging.getLogger(__name__)
def create_app(config: AppConfig | None = None) -> FastAPI:
"""
Create and configure FastAPI application.
Args:
config: Application configuration. Uses default if not provided.
Returns:
Configured FastAPI application
"""
config = config or default_config
# Create inference service
inference_service = InferenceService(
model_config=config.model,
storage_config=config.storage,
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
"""Application lifespan manager."""
logger.info("Starting Invoice Inference API...")
# Initialize inference service on startup
try:
inference_service.initialize()
logger.info("Inference service ready")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
# Continue anyway - service will retry on first request
yield
logger.info("Shutting down Invoice Inference API...")
# Create FastAPI app
app = FastAPI(
title="Invoice Field Extraction API",
description="""
REST API for extracting fields from Swedish invoices.
## Features
- YOLO-based field detection
- OCR text extraction
- Field normalization and validation
- Visualization of detections
## Supported Fields
- InvoiceNumber
- InvoiceDate
- InvoiceDueDate
- OCR (reference number)
- Bankgiro
- Plusgiro
- Amount
""",
version="1.0.0",
lifespan=lifespan,
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Mount static files for results
config.storage.result_dir.mkdir(parents=True, exist_ok=True)
app.mount(
"/static/results",
StaticFiles(directory=str(config.storage.result_dir)),
name="results",
)
# Include API routes
api_router = create_api_router(inference_service, config.storage)
app.include_router(api_router)
# Root endpoint - serve HTML UI
@app.get("/", response_class=HTMLResponse)
async def root() -> str:
"""Serve the web UI."""
return get_html_ui()
return app
def get_html_ui() -> str:
"""Generate HTML UI for the web application."""
return """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Invoice Field Extraction</title>
<style>
* {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
header {
text-align: center;
color: white;
margin-bottom: 30px;
}
header h1 {
font-size: 2.5rem;
margin-bottom: 10px;
}
header p {
opacity: 0.9;
font-size: 1.1rem;
}
.main-content {
display: grid;
grid-template-columns: 1fr 1fr;
gap: 20px;
}
@media (max-width: 900px) {
.main-content {
grid-template-columns: 1fr;
}
}
.card {
background: white;
border-radius: 16px;
padding: 24px;
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
}
.card h2 {
color: #333;
margin-bottom: 20px;
font-size: 1.3rem;
display: flex;
align-items: center;
gap: 10px;
}
.upload-area {
border: 3px dashed #ddd;
border-radius: 12px;
padding: 40px;
text-align: center;
cursor: pointer;
transition: all 0.3s;
background: #fafafa;
}
.upload-area:hover, .upload-area.dragover {
border-color: #667eea;
background: #f0f4ff;
}
.upload-area.has-file {
border-color: #10b981;
background: #ecfdf5;
}
.upload-icon {
font-size: 48px;
margin-bottom: 15px;
}
.upload-area p {
color: #666;
margin-bottom: 10px;
}
.upload-area small {
color: #999;
}
#file-input {
display: none;
}
.file-name {
margin-top: 15px;
padding: 10px 15px;
background: #e0f2fe;
border-radius: 8px;
color: #0369a1;
font-weight: 500;
}
.btn {
display: inline-block;
padding: 14px 28px;
border: none;
border-radius: 10px;
font-size: 1rem;
font-weight: 600;
cursor: pointer;
transition: all 0.3s;
text-transform: uppercase;
letter-spacing: 0.5px;
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
width: 100%;
margin-top: 20px;
}
.btn-primary:hover:not(:disabled) {
transform: translateY(-2px);
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
}
.btn-primary:disabled {
opacity: 0.6;
cursor: not-allowed;
}
.loading {
display: none;
text-align: center;
padding: 20px;
}
.loading.active {
display: block;
}
.spinner {
width: 40px;
height: 40px;
border: 4px solid #f3f3f3;
border-top: 4px solid #667eea;
border-radius: 50%;
animation: spin 1s linear infinite;
margin: 0 auto 15px;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.results {
display: none;
}
.results.active {
display: block;
}
.result-header {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 20px;
padding-bottom: 15px;
border-bottom: 2px solid #eee;
}
.result-status {
padding: 6px 12px;
border-radius: 20px;
font-size: 0.85rem;
font-weight: 600;
}
.result-status.success {
background: #dcfce7;
color: #166534;
}
.result-status.partial {
background: #fef3c7;
color: #92400e;
}
.result-status.error {
background: #fee2e2;
color: #991b1b;
}
.fields-grid {
display: grid;
grid-template-columns: repeat(2, 1fr);
gap: 12px;
}
.field-item {
padding: 12px;
background: #f8fafc;
border-radius: 10px;
border-left: 4px solid #667eea;
}
.field-item label {
display: block;
font-size: 0.75rem;
color: #64748b;
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 4px;
}
.field-item .value {
font-size: 1.1rem;
font-weight: 600;
color: #1e293b;
}
.field-item .confidence {
font-size: 0.75rem;
color: #10b981;
margin-top: 2px;
}
.visualization {
margin-top: 20px;
}
.visualization img {
width: 100%;
border-radius: 12px;
box-shadow: 0 4px 20px rgba(0,0,0,0.1);
}
.processing-time {
text-align: center;
color: #64748b;
font-size: 0.9rem;
margin-top: 15px;
}
.error-message {
background: #fee2e2;
color: #991b1b;
padding: 15px;
border-radius: 10px;
margin-top: 15px;
}
footer {
text-align: center;
color: white;
opacity: 0.8;
margin-top: 30px;
font-size: 0.9rem;
}
</style>
</head>
<body>
<div class="container">
<header>
<h1>📄 Invoice Field Extraction</h1>
<p>Upload a Swedish invoice (PDF or image) to extract fields automatically</p>
</header>
<div class="main-content">
<div class="card">
<h2>📤 Upload Document</h2>
<div class="upload-area" id="upload-area">
<div class="upload-icon">📁</div>
<p>Drag & drop your file here</p>
<p>or <strong>click to browse</strong></p>
<small>Supports PDF, PNG, JPG (max 50MB)</small>
<input type="file" id="file-input" accept=".pdf,.png,.jpg,.jpeg">
<div class="file-name" id="file-name" style="display: none;"></div>
</div>
<button class="btn btn-primary" id="submit-btn" disabled>
🚀 Extract Fields
</button>
<div class="loading" id="loading">
<div class="spinner"></div>
<p>Processing document...</p>
</div>
</div>
<div class="card">
<h2>📊 Extraction Results</h2>
<div id="placeholder" style="text-align: center; padding: 40px; color: #999;">
<div style="font-size: 64px; margin-bottom: 15px;">🔍</div>
<p>Upload a document to see extraction results</p>
</div>
<div class="results" id="results">
<div class="result-header">
<span>Document: <strong id="doc-id"></strong></span>
<span class="result-status" id="result-status"></span>
</div>
<div class="fields-grid" id="fields-grid"></div>
<div class="processing-time" id="processing-time"></div>
<div class="error-message" id="error-message" style="display: none;"></div>
<div class="visualization" id="visualization" style="display: none;">
<h3 style="margin-bottom: 10px; color: #333;">🎯 Detection Visualization</h3>
<img id="viz-image" src="" alt="Detection visualization">
</div>
</div>
</div>
</div>
<footer>
<p>Powered by ColaCoder</p>
</footer>
</div>
<script>
const uploadArea = document.getElementById('upload-area');
const fileInput = document.getElementById('file-input');
const fileName = document.getElementById('file-name');
const submitBtn = document.getElementById('submit-btn');
const loading = document.getElementById('loading');
const placeholder = document.getElementById('placeholder');
const results = document.getElementById('results');
let selectedFile = null;
// Drag and drop handlers
uploadArea.addEventListener('click', () => fileInput.click());
uploadArea.addEventListener('dragover', (e) => {
e.preventDefault();
uploadArea.classList.add('dragover');
});
uploadArea.addEventListener('dragleave', () => {
uploadArea.classList.remove('dragover');
});
uploadArea.addEventListener('drop', (e) => {
e.preventDefault();
uploadArea.classList.remove('dragover');
const files = e.dataTransfer.files;
if (files.length > 0) {
handleFile(files[0]);
}
});
fileInput.addEventListener('change', (e) => {
if (e.target.files.length > 0) {
handleFile(e.target.files[0]);
}
});
function handleFile(file) {
const validTypes = ['.pdf', '.png', '.jpg', '.jpeg'];
const ext = '.' + file.name.split('.').pop().toLowerCase();
if (!validTypes.includes(ext)) {
alert('Please upload a PDF, PNG, or JPG file.');
return;
}
selectedFile = file;
fileName.textContent = `📎 ${file.name}`;
fileName.style.display = 'block';
uploadArea.classList.add('has-file');
submitBtn.disabled = false;
}
submitBtn.addEventListener('click', async () => {
if (!selectedFile) return;
// Show loading
submitBtn.disabled = true;
loading.classList.add('active');
placeholder.style.display = 'none';
results.classList.remove('active');
try {
const formData = new FormData();
formData.append('file', selectedFile);
const response = await fetch('/api/v1/infer', {
method: 'POST',
body: formData,
});
const data = await response.json();
if (!response.ok) {
throw new Error(data.detail || 'Processing failed');
}
displayResults(data);
} catch (error) {
console.error('Error:', error);
document.getElementById('error-message').textContent = error.message;
document.getElementById('error-message').style.display = 'block';
results.classList.add('active');
} finally {
loading.classList.remove('active');
submitBtn.disabled = false;
}
});
function displayResults(data) {
const result = data.result;
// Document ID
document.getElementById('doc-id').textContent = result.document_id;
// Status
const statusEl = document.getElementById('result-status');
statusEl.textContent = result.success ? 'Success' : 'Partial';
statusEl.className = 'result-status ' + (result.success ? 'success' : 'partial');
// Fields
const fieldsGrid = document.getElementById('fields-grid');
fieldsGrid.innerHTML = '';
const fieldOrder = ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Amount', 'Bankgiro', 'Plusgiro'];
fieldOrder.forEach(field => {
const value = result.fields[field];
const confidence = result.confidence[field];
if (value !== null && value !== undefined) {
const fieldDiv = document.createElement('div');
fieldDiv.className = 'field-item';
fieldDiv.innerHTML = `
<label>${formatFieldName(field)}</label>
<div class="value">${value}</div>
${confidence ? `<div class="confidence">✓ ${(confidence * 100).toFixed(1)}% confident</div>` : ''}
`;
fieldsGrid.appendChild(fieldDiv);
}
});
// Processing time
document.getElementById('processing-time').textContent =
`⏱️ Processed in ${result.processing_time_ms.toFixed(0)}ms`;
// Visualization
if (result.visualization_url) {
const vizDiv = document.getElementById('visualization');
const vizImg = document.getElementById('viz-image');
vizImg.src = result.visualization_url;
vizDiv.style.display = 'block';
}
// Errors
if (result.errors && result.errors.length > 0) {
document.getElementById('error-message').textContent = result.errors.join(', ');
document.getElementById('error-message').style.display = 'block';
} else {
document.getElementById('error-message').style.display = 'none';
}
results.classList.add('active');
}
function formatFieldName(name) {
return name.replace(/([A-Z])/g, ' $1').trim();
}
</script>
</body>
</html>
"""

69
src/web/config.py Normal file
View File

@@ -0,0 +1,69 @@
"""
Web Application Configuration
Centralized configuration for the web application.
"""
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
@dataclass(frozen=True)
class ModelConfig:
"""YOLO model configuration."""
model_path: Path = Path("runs/train/invoice_yolo11n_full/weights/best.pt")
confidence_threshold: float = 0.3
use_gpu: bool = True
dpi: int = 150
@dataclass(frozen=True)
class ServerConfig:
"""Server configuration."""
host: str = "0.0.0.0"
port: int = 8000
debug: bool = False
reload: bool = False
workers: int = 1
@dataclass(frozen=True)
class StorageConfig:
"""File storage configuration."""
upload_dir: Path = Path("uploads")
result_dir: Path = Path("results")
max_file_size_mb: int = 50
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
def __post_init__(self) -> None:
"""Create directories if they don't exist."""
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
object.__setattr__(self, "result_dir", Path(self.result_dir))
self.upload_dir.mkdir(parents=True, exist_ok=True)
self.result_dir.mkdir(parents=True, exist_ok=True)
@dataclass
class AppConfig:
"""Main application configuration."""
model: ModelConfig = field(default_factory=ModelConfig)
server: ServerConfig = field(default_factory=ServerConfig)
storage: StorageConfig = field(default_factory=StorageConfig)
@classmethod
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
"""Create config from dictionary."""
return cls(
model=ModelConfig(**config_dict.get("model", {})),
server=ServerConfig(**config_dict.get("server", {})),
storage=StorageConfig(**config_dict.get("storage", {})),
)
# Default configuration instance
default_config = AppConfig()
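from_dict allows partial overrides, e.g. values parsed from a JSON or YAML file; a short sketch with a hypothetical model path:

    from pathlib import Path

    overrides = {
        "model": {"model_path": Path("runs/train/custom/weights/best.pt"), "confidence_threshold": 0.4},
        "server": {"port": 9000},
    }
    config = AppConfig.from_dict(overrides)
    print(config.model.confidence_threshold, config.server.port, config.storage.upload_dir)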

183
src/web/routes.py Normal file
View File

@@ -0,0 +1,183 @@
"""
API Routes
FastAPI route definitions for the inference API.
"""
from __future__ import annotations
import logging
import shutil
import uuid
from pathlib import Path
from typing import TYPE_CHECKING
from fastapi import APIRouter, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse
from .schemas import (
BatchInferenceResponse,
DetectionResult,
ErrorResponse,
HealthResponse,
InferenceResponse,
InferenceResult,
)
if TYPE_CHECKING:
from .services import InferenceService
from .config import StorageConfig
logger = logging.getLogger(__name__)
def create_api_router(
inference_service: "InferenceService",
storage_config: "StorageConfig",
) -> APIRouter:
"""
Create API router with inference endpoints.
Args:
inference_service: Inference service instance
storage_config: Storage configuration
Returns:
Configured APIRouter
"""
router = APIRouter(prefix="/api/v1", tags=["inference"])
@router.get("/health", response_model=HealthResponse)
async def health_check() -> HealthResponse:
"""Check service health status."""
return HealthResponse(
status="healthy",
model_loaded=inference_service.is_initialized,
gpu_available=inference_service.gpu_available,
version="1.0.0",
)
@router.post(
"/infer",
response_model=InferenceResponse,
responses={
400: {"model": ErrorResponse, "description": "Invalid file"},
500: {"model": ErrorResponse, "description": "Processing error"},
},
)
async def infer_document(
file: UploadFile = File(..., description="PDF or image file to process"),
) -> InferenceResponse:
"""
Process a document and extract invoice fields.
Accepts PDF or image files (PNG, JPG, JPEG).
Returns extracted field values with confidence scores.
"""
# Validate file extension
if not file.filename:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Filename is required",
)
file_ext = Path(file.filename).suffix.lower()
if file_ext not in storage_config.allowed_extensions:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
)
# Generate document ID
doc_id = str(uuid.uuid4())[:8]
# Save uploaded file
upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
try:
with open(upload_path, "wb") as f:
shutil.copyfileobj(file.file, f)
except Exception as e:
logger.error(f"Failed to save uploaded file: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to save uploaded file",
)
try:
# Process based on file type
if file_ext == ".pdf":
service_result = inference_service.process_pdf(
upload_path, document_id=doc_id
)
else:
service_result = inference_service.process_image(
upload_path, document_id=doc_id
)
# Build response
viz_url = None
if service_result.visualization_path:
viz_url = f"/api/v1/results/{service_result.visualization_path.name}"
inference_result = InferenceResult(
document_id=service_result.document_id,
success=service_result.success,
fields=service_result.fields,
confidence=service_result.confidence,
detections=[
DetectionResult(**d) for d in service_result.detections
],
processing_time_ms=service_result.processing_time_ms,
visualization_url=viz_url,
errors=service_result.errors,
)
return InferenceResponse(
status="success" if service_result.success else "partial",
message=f"Processed document {doc_id}",
result=inference_result,
)
except Exception as e:
logger.error(f"Error processing document: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=str(e),
)
finally:
# Cleanup uploaded file
upload_path.unlink(missing_ok=True)
@router.get("/results/{filename}")
async def get_result_image(filename: str) -> FileResponse:
"""Get visualization result image."""
file_path = storage_config.result_dir / filename
if not file_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)
return FileResponse(
path=file_path,
media_type="image/png",
filename=filename,
)
@router.delete("/results/{filename}")
async def delete_result(filename: str) -> dict:
"""Delete a result file."""
file_path = storage_config.result_dir / filename
if not file_path.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Result file not found: {filename}",
)
file_path.unlink()
return {"status": "deleted", "filename": filename}
return router
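From a client's perspective the infer endpoint is a plain multipart upload; a sketch using requests against a locally running instance (URL and file name are assumptions):

    import requests

    with open("sample_invoice.pdf", "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/api/v1/infer",
            files={"file": ("sample_invoice.pdf", fh, "application/pdf")},
            timeout=120,
        )
    resp.raise_for_status()
    payload = resp.json()
    print(payload["status"], payload["result"]["fields"])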

83
src/web/schemas.py Normal file
View File

@@ -0,0 +1,83 @@
"""
API Request/Response Schemas
Pydantic models for API validation and serialization.
"""
from pydantic import BaseModel, Field
from typing import Any
class DetectionResult(BaseModel):
"""Single detection result."""
field: str = Field(..., description="Field type (e.g., invoice_number, amount)")
confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
class ExtractedField(BaseModel):
"""Extracted and normalized field value."""
field_name: str = Field(..., description="Field name")
value: str | None = Field(None, description="Extracted value")
confidence: float = Field(..., ge=0, le=1, description="Extraction confidence")
is_valid: bool = Field(True, description="Whether the value passed validation")
class InferenceResult(BaseModel):
"""Complete inference result for a document."""
document_id: str = Field(..., description="Document identifier")
success: bool = Field(..., description="Whether inference succeeded")
fields: dict[str, str | None] = Field(
default_factory=dict, description="Extracted field values"
)
confidence: dict[str, float] = Field(
default_factory=dict, description="Confidence scores per field"
)
detections: list[DetectionResult] = Field(
default_factory=list, description="Raw YOLO detections"
)
processing_time_ms: float = Field(..., description="Processing time in milliseconds")
visualization_url: str | None = Field(
None, description="URL to visualization image"
)
errors: list[str] = Field(default_factory=list, description="Error messages")
class InferenceResponse(BaseModel):
"""API response for inference endpoint."""
status: str = Field(..., description="Response status: success or error")
message: str = Field(..., description="Response message")
result: InferenceResult | None = Field(None, description="Inference result")
class BatchInferenceResponse(BaseModel):
"""API response for batch inference endpoint."""
status: str = Field(..., description="Response status")
message: str = Field(..., description="Response message")
total: int = Field(..., description="Total documents processed")
successful: int = Field(..., description="Number of successful extractions")
results: list[InferenceResult] = Field(
default_factory=list, description="Individual results"
)
class HealthResponse(BaseModel):
"""Health check response."""
status: str = Field(..., description="Service status")
model_loaded: bool = Field(..., description="Whether model is loaded")
gpu_available: bool = Field(..., description="Whether GPU is available")
version: str = Field(..., description="API version")
class ErrorResponse(BaseModel):
"""Error response."""
status: str = Field(default="error", description="Error status")
message: str = Field(..., description="Error message")
detail: str | None = Field(None, description="Detailed error information")
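A small construction and serialization sketch, assuming Pydantic v2 (where model_dump_json is available); the values are hypothetical:

    result = InferenceResult(
        document_id="doc-0001",
        success=True,
        fields={"InvoiceNumber": "12345", "Amount": "1999.00"},
        confidence={"InvoiceNumber": 0.97, "Amount": 0.91},
        detections=[DetectionResult(field="invoice_number", confidence=0.97, bbox=[10, 20, 180, 45])],
        processing_time_ms=412.0,
    )
    response = InferenceResponse(status="success", message="Processed document doc-0001", result=result)
    print(response.model_dump_json(indent=2))  # on Pydantic v1, .json() would be used instead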

270
src/web/services.py Normal file
View File

@@ -0,0 +1,270 @@
"""
Inference Service
Business logic for invoice field extraction.
"""
from __future__ import annotations
import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
from PIL import Image
if TYPE_CHECKING:
from .config import ModelConfig, StorageConfig
logger = logging.getLogger(__name__)
@dataclass
class ServiceResult:
"""Result from inference service."""
document_id: str
success: bool = False
fields: dict[str, str | None] = field(default_factory=dict)
confidence: dict[str, float] = field(default_factory=dict)
detections: list[dict] = field(default_factory=list)
processing_time_ms: float = 0.0
visualization_path: Path | None = None
errors: list[str] = field(default_factory=list)
class InferenceService:
"""
Service for running invoice field extraction.
Encapsulates YOLO detection and OCR extraction logic.
"""
def __init__(
self,
model_config: ModelConfig,
storage_config: StorageConfig,
) -> None:
"""
Initialize inference service.
Args:
model_config: Model configuration
storage_config: Storage configuration
"""
self.model_config = model_config
self.storage_config = storage_config
self._pipeline = None
self._detector = None
self._is_initialized = False
def initialize(self) -> None:
"""Initialize the inference pipeline (lazy loading)."""
if self._is_initialized:
return
logger.info("Initializing inference service...")
start_time = time.time()
try:
from ..inference.pipeline import InferencePipeline
from ..inference.yolo_detector import YOLODetector
# Initialize YOLO detector for visualization
self._detector = YOLODetector(
str(self.model_config.model_path),
confidence_threshold=self.model_config.confidence_threshold,
device="cuda" if self.model_config.use_gpu else "cpu",
)
# Initialize full pipeline
self._pipeline = InferencePipeline(
model_path=str(self.model_config.model_path),
confidence_threshold=self.model_config.confidence_threshold,
use_gpu=self.model_config.use_gpu,
dpi=self.model_config.dpi,
enable_fallback=True,
)
self._is_initialized = True
elapsed = time.time() - start_time
logger.info(f"Inference service initialized in {elapsed:.2f}s")
except Exception as e:
logger.error(f"Failed to initialize inference service: {e}")
raise
@property
def is_initialized(self) -> bool:
"""Check if service is initialized."""
return self._is_initialized
@property
def gpu_available(self) -> bool:
"""Check if GPU is available."""
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def process_image(
self,
image_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
) -> ServiceResult:
"""
Process an image file and extract invoice fields.
Args:
image_path: Path to image file
document_id: Optional document ID
save_visualization: Whether to save visualization
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Get raw detections for visualization
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Save visualization if requested
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_visualization(image_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing image {image_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def process_pdf(
self,
pdf_path: Path,
document_id: str | None = None,
save_visualization: bool = True,
) -> ServiceResult:
"""
Process a PDF file and extract invoice fields.
Args:
pdf_path: Path to PDF file
document_id: Optional document ID
save_visualization: Whether to save visualization
Returns:
ServiceResult with extracted fields
"""
if not self._is_initialized:
self.initialize()
doc_id = document_id or str(uuid.uuid4())[:8]
start_time = time.time()
result = ServiceResult(document_id=doc_id)
try:
# Run inference pipeline
pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)
result.fields = pipeline_result.fields
result.confidence = pipeline_result.confidence
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Get raw detections
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
# Save visualization (render first page)
if save_visualization and pipeline_result.raw_detections:
viz_path = self._save_pdf_visualization(pdf_path, doc_id)
result.visualization_path = viz_path
except Exception as e:
logger.error(f"Error processing PDF {pdf_path}: {e}")
result.errors.append(str(e))
result.success = False
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
"""Save visualization image with detections."""
from ultralytics import YOLO
# Load model and run prediction with visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(image_path), verbose=False)
# Save annotated image
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
return output_path
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
"""Save visualization for PDF (first page)."""
from ..pdf.renderer import render_pdf_to_images
from ultralytics import YOLO
import io
# Render first page
for page_no, image_bytes in render_pdf_to_images(
pdf_path, dpi=self.model_config.dpi
):
image = Image.open(io.BytesIO(image_bytes))
temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
image.save(temp_path)
# Run YOLO and save visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(temp_path), verbose=False)
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
# Cleanup temp file
temp_path.unlink(missing_ok=True)
return output_path
# If no pages rendered
return None
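The service can also be used outside the web layer, for example in a batch script; a sketch with hypothetical paths:

    from pathlib import Path
    from src.web.config import ModelConfig, StorageConfig
    from src.web.services import InferenceService

    service = InferenceService(
        model_config=ModelConfig(confidence_threshold=0.3, use_gpu=True),
        storage_config=StorageConfig(upload_dir=Path("uploads"), result_dir=Path("results")),
    )
    service.initialize()  # loads the YOLO detector and inference pipeline once

    res = service.process_pdf(Path("sample_invoice.pdf"))
    print(res.success, res.fields, f"{res.processing_time_ms:.0f} ms")
    if res.visualization_path:
        print("visualization saved to", res.visualization_path)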

View File

@@ -1,4 +1,5 @@
 from .annotation_generator import AnnotationGenerator, generate_annotations
 from .dataset_builder import DatasetBuilder
+from .db_dataset import DBYOLODataset, create_datasets
-__all__ = ['AnnotationGenerator', 'generate_annotations', 'DatasetBuilder']
+__all__ = ['AnnotationGenerator', 'generate_annotations', 'DatasetBuilder', 'DBYOLODataset', 'create_datasets']

View File

@@ -174,20 +174,47 @@ class AnnotationGenerator:
         output_path: str | Path,
         train_path: str = 'train/images',
         val_path: str = 'val/images',
-        test_path: str = 'test/images'
+        test_path: str = 'test/images',
+        use_wsl_paths: bool | None = None
     ) -> None:
-        """Generate YOLO dataset YAML configuration."""
+        """
+        Generate YOLO dataset YAML configuration.
+
+        Args:
+            output_path: Path to output YAML file
+            train_path: Relative path to training images
+            val_path: Relative path to validation images
+            test_path: Relative path to test images
+            use_wsl_paths: If True, convert Windows paths to WSL format.
+                If None, auto-detect based on environment.
+        """
+        import os
+        import platform
         output_path = Path(output_path)
         output_path.parent.mkdir(parents=True, exist_ok=True)
-        # Use absolute path for WSL compatibility
         dataset_dir = output_path.parent.absolute()
-        # Convert Windows path to WSL path if needed
-        dataset_path_str = str(dataset_dir).replace('\\', '/')
-        if dataset_path_str[1] == ':':
-            # Windows path like C:/... -> /mnt/c/...
+        dataset_path_str = str(dataset_dir)
+        # Auto-detect WSL environment
+        if use_wsl_paths is None:
+            # Check if running inside WSL
+            is_wsl = 'microsoft' in platform.uname().release.lower() if platform.system() == 'Linux' else False
+            # Check WSL_DISTRO_NAME environment variable (set when running in WSL)
+            is_wsl = is_wsl or os.environ.get('WSL_DISTRO_NAME') is not None
+            use_wsl_paths = is_wsl
+        # Convert path format based on environment
+        if use_wsl_paths:
+            # Running in WSL: convert Windows paths to /mnt/c/... format
+            dataset_path_str = dataset_path_str.replace('\\', '/')
+            if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
                 drive = dataset_path_str[0].lower()
                 dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
+        elif platform.system() == 'Windows':
+            # Running on native Windows: use forward slashes for YOLO compatibility
+            dataset_path_str = dataset_path_str.replace('\\', '/')
         config = f"""# Invoice Field Detection Dataset
 path: {dataset_path_str}
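The conversion branch, reproduced standalone for clarity; the input path is hypothetical:

    p = r"C:\data\invoice-dataset".replace("\\", "/")
    if len(p) > 1 and p[1] == ":":
        p = f"/mnt/{p[0].lower()}{p[2:]}"
    print(p)  # /mnt/c/data/invoice-dataset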

625
src/yolo/db_dataset.py Normal file
View File

@@ -0,0 +1,625 @@
"""
Database-backed YOLO Dataset
Loads images from filesystem and labels from PostgreSQL database.
Generates YOLO format labels dynamically at training time.
"""
from __future__ import annotations
import logging
import random
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Any, Optional
import numpy as np
from PIL import Image
from .annotation_generator import FIELD_CLASSES, YOLOAnnotation
logger = logging.getLogger(__name__)
# Module-level LRU cache for image loading (shared across dataset instances)
@lru_cache(maxsize=256)
def _load_image_cached(image_path: str) -> tuple[np.ndarray, int, int]:
"""
Load and cache image from disk.
Args:
image_path: Path to image file (must be string for hashability)
Returns:
Tuple of (image_array, width, height)
"""
image = Image.open(image_path).convert('RGB')
width, height = image.size
image_array = np.array(image)
return image_array, width, height
def clear_image_cache():
"""Clear the image cache to free memory."""
_load_image_cached.cache_clear()
@dataclass
class DatasetItem:
"""Single item in the dataset."""
document_id: str
image_path: Path
page_no: int
labels: list[YOLOAnnotation]
is_scanned: bool = False # True if bbox is in pixel coords, False if in PDF points
class DBYOLODataset:
"""
YOLO Dataset that reads labels from database.
This dataset:
1. Scans temp directory for rendered images
2. Queries database for bbox data
3. Generates YOLO labels dynamically
4. Performs train/val/test split at runtime
"""
def __init__(
self,
images_dir: str | Path,
db: Any, # DocumentDB instance
split: str = 'train',
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42,
dpi: int = 300,
min_confidence: float = 0.7,
bbox_padding_px: int = 20,
min_bbox_height_px: int = 30,
limit: int | None = None,
):
"""
Initialize database-backed YOLO dataset.
Args:
images_dir: Directory containing temp/{doc_id}/images/*.png
db: DocumentDB instance for label queries
split: Which split to use ('train', 'val', 'test')
train_ratio: Ratio for training set
val_ratio: Ratio for validation set
seed: Random seed for reproducible splits
dpi: DPI used for rendering (for coordinate conversion)
min_confidence: Minimum match score to include
bbox_padding_px: Padding around bboxes
min_bbox_height_px: Minimum bbox height
limit: Maximum number of documents to use (None for all)
"""
self.images_dir = Path(images_dir)
self.db = db
self.split = split
self.train_ratio = train_ratio
self.val_ratio = val_ratio
self.seed = seed
self.dpi = dpi
self.min_confidence = min_confidence
self.bbox_padding_px = bbox_padding_px
self.min_bbox_height_px = min_bbox_height_px
self.limit = limit
# Load and split dataset
self.items: list[DatasetItem] = []
self._all_items: list[DatasetItem] = [] # Cache all items for sharing
self._doc_ids_ordered: list[str] = [] # Cache ordered doc IDs for consistent splits
self._load_dataset()
@classmethod
def from_shared_data(
cls,
source_dataset: 'DBYOLODataset',
split: str,
) -> 'DBYOLODataset':
"""
Create a new dataset instance sharing data from an existing one.
This avoids re-loading data from filesystem and database.
Args:
source_dataset: Dataset to share data from
split: Which split to use ('train', 'val', 'test')
Returns:
New dataset instance with shared data
"""
# Create instance without loading (we'll share data)
instance = object.__new__(cls)
# Copy configuration
instance.images_dir = source_dataset.images_dir
instance.db = source_dataset.db
instance.split = split
instance.train_ratio = source_dataset.train_ratio
instance.val_ratio = source_dataset.val_ratio
instance.seed = source_dataset.seed
instance.dpi = source_dataset.dpi
instance.min_confidence = source_dataset.min_confidence
instance.bbox_padding_px = source_dataset.bbox_padding_px
instance.min_bbox_height_px = source_dataset.min_bbox_height_px
instance.limit = source_dataset.limit
# Share loaded data
instance._all_items = source_dataset._all_items
instance._doc_ids_ordered = source_dataset._doc_ids_ordered
# Split items for this split
instance.items = instance._split_dataset_from_cache()
print(f"Split '{split}': {len(instance.items)} items")
return instance
def _load_dataset(self):
"""Load dataset items from filesystem and database."""
# Find all document directories
temp_dir = self.images_dir / 'temp'
if not temp_dir.exists():
print(f"Temp directory not found: {temp_dir}")
return
# Collect all document IDs with images
doc_image_map: dict[str, list[Path]] = {}
for doc_dir in temp_dir.iterdir():
if not doc_dir.is_dir():
continue
images_path = doc_dir / 'images'
if not images_path.exists():
continue
images = list(images_path.glob('*.png'))
if images:
doc_image_map[doc_dir.name] = sorted(images)
print(f"Found {len(doc_image_map)} documents with images")
# Query database for all document labels
doc_ids = list(doc_image_map.keys())
doc_labels = self._load_labels_from_db(doc_ids)
print(f"Loaded labels for {len(doc_labels)} documents from database")
# Build dataset items
all_items: list[DatasetItem] = []
skipped_no_labels = 0
skipped_no_db_record = 0
total_images = 0
for doc_id, images in doc_image_map.items():
doc_data = doc_labels.get(doc_id)
# Skip documents that don't exist in database
if doc_data is None:
skipped_no_db_record += len(images)
total_images += len(images)
continue
labels_by_page, is_scanned = doc_data
for image_path in images:
total_images += 1
# Extract page number from filename (e.g., "doc_page_000.png")
page_no = self._extract_page_no(image_path.stem)
# Get labels for this page
page_labels = labels_by_page.get(page_no, [])
if page_labels: # Only include pages with labels
all_items.append(DatasetItem(
document_id=doc_id,
image_path=image_path,
page_no=page_no,
labels=page_labels,
is_scanned=is_scanned
))
else:
skipped_no_labels += 1
print(f"Total images found: {total_images}")
print(f"Images with labels: {len(all_items)}")
if skipped_no_db_record > 0:
print(f"Skipped {skipped_no_db_record} images (document not in database)")
if skipped_no_labels > 0:
print(f"Skipped {skipped_no_labels} images (no labels for page)")
# Cache all items for sharing with other splits
self._all_items = all_items
# Split dataset
self.items, self._doc_ids_ordered = self._split_dataset(all_items)
print(f"Split '{self.split}': {len(self.items)} items")
def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]]:
"""
Load labels from database for given document IDs using batch queries.
Returns:
Dict of doc_id -> (page_labels, is_scanned)
where page_labels is {page_no -> list[YOLOAnnotation]}
and is_scanned indicates if bbox is in pixel coords (True) or PDF points (False)
"""
result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]] = {}
# Query in batches using efficient batch method
batch_size = 500
for i in range(0, len(doc_ids), batch_size):
batch_ids = doc_ids[i:i + batch_size]
# Use batch query instead of individual queries (N+1 fix)
docs_batch = self.db.get_documents_batch(batch_ids)
for doc_id, doc in docs_batch.items():
if not doc.get('success'):
continue
# Check if scanned PDF (OCR bbox is in pixels, text PDF bbox is in PDF points)
is_scanned = doc.get('pdf_type') == 'scanned'
page_labels: dict[int, list[YOLOAnnotation]] = {}
for field_result in doc.get('field_results', []):
if not field_result.get('matched'):
continue
field_name = field_result.get('field_name')
if field_name not in FIELD_CLASSES:
continue
score = field_result.get('score', 0)
if score < self.min_confidence:
continue
bbox = field_result.get('bbox')
page_no = field_result.get('page_no', 0)
if bbox and len(bbox) == 4:
annotation = self._create_annotation(
field_name=field_name,
bbox=bbox,
score=score
)
if page_no not in page_labels:
page_labels[page_no] = []
page_labels[page_no].append(annotation)
if page_labels:
result[doc_id] = (page_labels, is_scanned)
return result
def _create_annotation(
self,
field_name: str,
bbox: list[float],
score: float
) -> YOLOAnnotation:
"""
Create a YOLO annotation from bbox.
Note: bbox is in PDF points (72 DPI), will be normalized later.
"""
class_id = FIELD_CLASSES[field_name]
x0, y0, x1, y1 = bbox
# Store raw PDF coordinates - will be normalized when getting item
return YOLOAnnotation(
class_id=class_id,
x_center=(x0 + x1) / 2, # center in PDF points
y_center=(y0 + y1) / 2,
width=x1 - x0,
height=y1 - y0,
confidence=score
)
def _extract_page_no(self, stem: str) -> int:
"""Extract page number from image filename."""
# Format: "{doc_id}_page_{page_no:03d}"
parts = stem.rsplit('_', 1)
if len(parts) == 2:
try:
return int(parts[1])
except ValueError:
pass
return 0
def _split_dataset(self, items: list[DatasetItem]) -> tuple[list[DatasetItem], list[str]]:
"""
Split items into train/val/test based on document ID.
Returns:
Tuple of (split_items, ordered_doc_ids) where ordered_doc_ids can be
reused for consistent splits across shared datasets.
"""
# Group by document ID for proper splitting
doc_items: dict[str, list[DatasetItem]] = {}
for item in items:
if item.document_id not in doc_items:
doc_items[item.document_id] = []
doc_items[item.document_id].append(item)
# Shuffle document IDs
doc_ids = list(doc_items.keys())
random.seed(self.seed)
random.shuffle(doc_ids)
# Apply limit if specified (before splitting)
if self.limit is not None and self.limit < len(doc_ids):
doc_ids = doc_ids[:self.limit]
print(f"Limited to {self.limit} documents")
# Calculate split indices
n_total = len(doc_ids)
n_train = int(n_total * self.train_ratio)
n_val = int(n_total * self.val_ratio)
# Split document IDs
if self.split == 'train':
split_doc_ids = doc_ids[:n_train]
elif self.split == 'val':
split_doc_ids = doc_ids[n_train:n_train + n_val]
else: # test
split_doc_ids = doc_ids[n_train + n_val:]
# Collect items for this split
split_items = []
for doc_id in split_doc_ids:
split_items.extend(doc_items[doc_id])
return split_items, doc_ids
def _split_dataset_from_cache(self) -> list[DatasetItem]:
"""
Split items using cached data from a shared dataset.
Uses pre-computed doc_ids order for consistent splits.
"""
# Group by document ID
doc_items: dict[str, list[DatasetItem]] = {}
for item in self._all_items:
if item.document_id not in doc_items:
doc_items[item.document_id] = []
doc_items[item.document_id].append(item)
# Use cached doc_ids order
doc_ids = self._doc_ids_ordered
# Calculate split indices
n_total = len(doc_ids)
n_train = int(n_total * self.train_ratio)
n_val = int(n_total * self.val_ratio)
# Split document IDs based on split type
if self.split == 'train':
split_doc_ids = doc_ids[:n_train]
elif self.split == 'val':
split_doc_ids = doc_ids[n_train:n_train + n_val]
else: # test
split_doc_ids = doc_ids[n_train + n_val:]
# Collect items for this split
split_items = []
for doc_id in split_doc_ids:
if doc_id in doc_items:
split_items.extend(doc_items[doc_id])
return split_items
def __len__(self) -> int:
return len(self.items)
def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
"""
Get image and labels for index.
Returns:
(image, labels) where:
- image: numpy array (H, W, C)
- labels: numpy array (N, 5) with [class_id, x_center, y_center, width, height]
"""
item = self.items[idx]
# Load image using LRU cache (significant speedup during training)
image_array, img_width, img_height = _load_image_cached(str(item.image_path))
# Convert annotations to YOLO format (normalized)
labels = self._convert_labels(item.labels, img_width, img_height, item.is_scanned)
return image_array, labels
def _convert_labels(
self,
annotations: list[YOLOAnnotation],
img_width: int,
img_height: int,
is_scanned: bool = False
) -> np.ndarray:
"""
Convert annotations to normalized YOLO format.
Args:
annotations: List of annotations
img_width: Actual image width in pixels
img_height: Actual image height in pixels
is_scanned: If True, bbox is already in pixels; if False, bbox is in PDF points
Returns:
numpy array (N, 5) with [class_id, x_center, y_center, width, height]
"""
if not annotations:
return np.zeros((0, 5), dtype=np.float32)
# Scale factor: PDF points (72 DPI) -> rendered pixels
# Note: After the OCR coordinate fix, ALL bbox (both text and scanned PDF)
# are stored in PDF points, so we always apply the same scaling.
scale = self.dpi / 72.0
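# Worked example: with the default dpi=300, scale = 300 / 72 ≈ 4.17,
# so a bbox that is 100 points wide in the PDF becomes ~417 px before padding.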
labels = []
for ann in annotations:
# Convert to pixels (if needed)
x_center_px = ann.x_center * scale
y_center_px = ann.y_center * scale
width_px = ann.width * scale
height_px = ann.height * scale
# Add padding
pad = self.bbox_padding_px
width_px += 2 * pad
height_px += 2 * pad
# Ensure minimum height
if height_px < self.min_bbox_height_px:
height_px = self.min_bbox_height_px
# Normalize to 0-1
x_center = x_center_px / img_width
y_center = y_center_px / img_height
width = width_px / img_width
height = height_px / img_height
# Clamp to valid range
x_center = max(0, min(1, x_center))
y_center = max(0, min(1, y_center))
width = max(0, min(1, width))
height = max(0, min(1, height))
labels.append([ann.class_id, x_center, y_center, width, height])
return np.array(labels, dtype=np.float32)
def get_image_path(self, idx: int) -> Path:
"""Get image path for index."""
return self.items[idx].image_path
def get_labels_for_yolo(self, idx: int) -> str:
"""
Get YOLO format labels as string for index.
Returns:
String with YOLO format labels (one per line)
"""
item = self.items[idx]
# Use cached image loading to avoid duplicate disk reads
_, img_width, img_height = _load_image_cached(str(item.image_path))
labels = self._convert_labels(item.labels, img_width, img_height, item.is_scanned)
lines = []
for label in labels:
class_id = int(label[0])
x_center, y_center, width, height = label[1:5]
lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
return '\n'.join(lines)
def export_to_yolo_format(
self,
output_dir: str | Path,
split_name: Optional[str] = None
) -> int:
"""
Export dataset to standard YOLO format (images + label files).
This is useful for training with standard YOLO training scripts.
Args:
output_dir: Output directory
split_name: Name for the split subdirectory (default: self.split)
Returns:
Number of items exported
"""
import shutil
output_dir = Path(output_dir)
split_name = split_name or self.split
images_out = output_dir / split_name / 'images'
labels_out = output_dir / split_name / 'labels'
# Clear existing directories before export
if images_out.exists():
shutil.rmtree(images_out)
if labels_out.exists():
shutil.rmtree(labels_out)
images_out.mkdir(parents=True, exist_ok=True)
labels_out.mkdir(parents=True, exist_ok=True)
count = 0
for idx in range(len(self)):
item = self.items[idx]
# Copy image
dest_image = images_out / item.image_path.name
shutil.copy2(item.image_path, dest_image)
# Write label file
label_content = self.get_labels_for_yolo(idx)
label_path = labels_out / f"{item.image_path.stem}.txt"
with open(label_path, 'w') as f:
f.write(label_content)
count += 1
print(f"Exported {count} items to {output_dir / split_name}")
return count
def create_datasets(
images_dir: str | Path,
db: Any,
train_ratio: float = 0.8,
val_ratio: float = 0.1,
seed: int = 42,
limit: int | None = None,
**kwargs
) -> dict[str, DBYOLODataset]:
"""
Create train/val/test datasets.
This function loads data once and shares it across all splits for efficiency.
Args:
images_dir: Directory containing temp/{doc_id}/images/
db: DocumentDB instance
train_ratio: Training set ratio
val_ratio: Validation set ratio
seed: Random seed
limit: Maximum number of documents to use (None for all)
**kwargs: Additional arguments for DBYOLODataset
Returns:
Dict with 'train', 'val', 'test' datasets
"""
# Create first dataset which loads all data
print("Loading dataset (this may take a few minutes for large datasets)...")
first_dataset = DBYOLODataset(
images_dir=images_dir,
db=db,
split='train',
train_ratio=train_ratio,
val_ratio=val_ratio,
seed=seed,
limit=limit,
**kwargs
)
# Create other splits by sharing loaded data
datasets = {'train': first_dataset}
for split in ['val', 'test']:
datasets[split] = DBYOLODataset.from_shared_data(
first_dataset,
split=split,
)
return datasets
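
A usage sketch for create_datasets and export_to_yolo_format (not part of the commit; the import path and the stand-in DB are assumptions, a real run would pass a DocumentDB instance):

from src.yolo.db_dataset import create_datasets

class FakeDB:
    """Stand-in for DocumentDB; only the batch lookup used by the dataset is stubbed."""
    def get_documents_batch(self, doc_ids):
        # A real DocumentDB returns {doc_id: extraction_result_dict}
        return {}

datasets = create_datasets(
    images_dir="data/dataset",  # expected to contain temp/{doc_id}/images/*.png
    db=FakeDB(),
    train_ratio=0.8,
    val_ratio=0.1,
    seed=42,
    dpi=300,  # forwarded to DBYOLODataset for point-to-pixel scaling
)

for split, ds in datasets.items():
    print(split, len(ds))
    ds.export_to_yolo_format("data/dataset")  # writes {split}/images and {split}/labels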