WOP
.claude/README.md (new file, 40 lines)
@@ -0,0 +1,40 @@
|
||||
# Claude Code Configuration
|
||||
|
||||
This directory contains Claude Code-specific configuration files.
|
||||
|
||||
## Configuration Files
|
||||
|
||||
### Main Controller
|
||||
- **Location**: `../CLAUDE.md` (project root)
|
||||
- **Purpose**: Main controller configuration for the Swedish Invoice Extraction System
|
||||
- **Version**: v1.3.0
|
||||
|
||||
### Sub-Agents
|
||||
Located in `agents/` directory:
|
||||
- `developer.md` - Development agent
|
||||
- `code-reviewer.md` - Code review agent
|
||||
- `tester.md` - Testing agent
|
||||
- `researcher.md` - Research agent
|
||||
- `project-manager.md` - Project management agent
|
||||
|
||||
### Skills
|
||||
Located in `skills/` directory:
|
||||
- `code-generation.md` - High-quality code generation skill
|
||||
|
||||
## Important Notes
|
||||
|
||||
⚠️ **The main CLAUDE.md file is in the project root**, not in this directory.
|
||||
|
||||
This is intentional because:
|
||||
1. CLAUDE.md is a project-level configuration
|
||||
2. It should be visible alongside README.md and other important docs
|
||||
3. It serves as the "constitution" for the entire project
|
||||
|
||||
When Claude Code starts, it will read:
|
||||
1. `../CLAUDE.md` (main controller instructions)
|
||||
2. Files in `agents/` (when agents are called)
|
||||
3. Files in `skills/` (when skills are used)
|
||||
|
||||
---
|
||||
|
||||
For the full main controller configuration, see: [../CLAUDE.md](../CLAUDE.md)
|
||||
.claude/config.toml (new file, 7 lines)
@@ -0,0 +1,7 @@
|
||||
[permissions]
|
||||
read = true
|
||||
write = true
|
||||
execute = true
|
||||
|
||||
[permissions.scope]
|
||||
paths = ["."]
|
||||
.claude/product_manager.md (new empty file)
.claude/settings.json (new file, 13 lines)
@@ -0,0 +1,13 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(*)",
|
||||
"Read(*)",
|
||||
"Write(*)",
|
||||
"Edit(*)",
|
||||
"Glob(*)",
|
||||
"Grep(*)",
|
||||
"Task(*)"
|
||||
]
|
||||
}
|
||||
}
|
||||
.claude/settings.local.json (new file, 81 lines)
@@ -0,0 +1,81 @@
|
||||
{
|
||||
"permissions": {
|
||||
"allow": [
|
||||
"Bash(*)",
|
||||
"Bash(wsl*)",
|
||||
"Bash(wsl -e bash*)",
|
||||
"Read(*)",
|
||||
"Write(*)",
|
||||
"Edit(*)",
|
||||
"Glob(*)",
|
||||
"Grep(*)",
|
||||
"WebFetch(*)",
|
||||
"WebSearch(*)",
|
||||
"Task(*)",
|
||||
"Bash(wsl -e bash -c:*)",
|
||||
"Bash(powershell -c:*)",
|
||||
"Bash(dir \"C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2\\\\runs\\\\detect\\\\runs\\\\train\\\\invoice_fields_v3\"\")",
|
||||
"Bash(timeout:*)",
|
||||
"Bash(powershell:*)",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && nvidia-smi 2>/dev/null | head -10\")",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && nohup python3 -m src.cli.train --data data/dataset/dataset.yaml --model yolo11s.pt --epochs 100 --batch 8 --device 0 --name invoice_fields_v4 > training.log 2>&1 &\")",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && sleep 10 && tail -20 training.log 2>/dev/null\":*)",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && cat training.log 2>/dev/null | head -30\")",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la training.log 2>/dev/null && ps aux | grep python\")",
|
||||
"Bash(wsl -e bash -c \"ps aux | grep -E ''python|train''\")",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -m src.cli.train --data data/dataset/dataset.yaml --model yolo11s.pt --epochs 100 --batch 8 --device 0 --name invoice_fields_v4 2>&1 | tee training.log &\")",
|
||||
"Bash(wsl -e bash -c \"sleep 15 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && tail -15 training.log\":*)",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -m src.cli.train --data data/dataset/dataset.yaml --model yolo11s.pt --epochs 100 --batch 8 --device 0 --name invoice_fields_v4\")",
|
||||
"Bash(wsl -e bash -c \"which screen || sudo apt-get install -y screen 2>/dev/null\")",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python3 -c \"\"\nfrom ultralytics import YOLO\n\n# Load model\nmodel = YOLO\\(''runs/detect/runs/train/invoice_fields_v4/weights/best.pt''\\)\n\n# Run inference on a test image\nresults = model.predict\\(\n ''data/dataset/test/images/36a4fd23-0a66-4428-9149-4f95c93db9cb_page_000.png'',\n conf=0.5,\n save=True,\n project=''results'',\n name=''test_inference''\n\\)\n\n# Print results\nfor r in results:\n print\\(''Image:'', r.path\\)\n print\\(''Boxes:'', len\\(r.boxes\\)\\)\n for box in r.boxes:\n cls = int\\(box.cls[0]\\)\n conf = float\\(box.conf[0]\\)\n name = model.names[cls]\n print\\(f'' - {name}: {conf:.2%}''\\)\n\"\"\")",
|
||||
"Bash(python:*)",
|
||||
"Bash(dir:*)",
|
||||
"Bash(timeout 180 tail:*)",
|
||||
"Bash(python3:*)",
|
||||
"Bash(wsl -d Ubuntu-22.04 -- bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source .venv/bin/activate && python -m src.cli.autolabel --csv ''data/structured_data/document_export_20260110_141554_page1.csv,data/structured_data/document_export_20260110_141612_page2.csv'' --report reports/autolabel_test_2csv.jsonl --workers 2\")",
|
||||
"Bash(wsl -d Ubuntu-22.04 -- bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice && python -m src.cli.autolabel --csv ''data/structured_data/document_export_20260110_141554_page1.csv,data/structured_data/document_export_20260110_141612_page2.csv'' --report reports/autolabel_test_2csv.jsonl --workers 2\")",
|
||||
"Bash(wsl -d Ubuntu-22.04 -- bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda info --envs\":*)",
|
||||
"Bash(wsl:*)",
|
||||
"Bash(for f in reports/autolabel_shard_test_part*.jsonl)",
|
||||
"Bash(done)",
|
||||
"Bash(timeout 10 tail:*)",
|
||||
"Bash(more:*)",
|
||||
"Bash(cmd /c type \"C:\\\\Users\\\\yaoji\\\\AppData\\\\Local\\\\Temp\\\\claude\\\\c--Users-yaoji-git-ColaCoder-invoice-master-poc-v2\\\\tasks\\\\b4d8070.output\")",
|
||||
"Bash(cmd /c \"dir C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2\\\\reports\\\\autolabel_report_v4*.jsonl\")",
|
||||
"Bash(wsl wc:*)",
|
||||
"Bash(wsl bash:*)",
|
||||
"Bash(wsl bash -c \"ps aux | grep python | grep -v grep\")",
|
||||
"Bash(wsl bash -c \"kill -9 130864 130870 414045 414046\")",
|
||||
"Bash(wsl bash -c \"ps aux | grep python | grep -v grep | grep -v networkd | grep -v unattended\")",
|
||||
"Bash(wsl bash -c \"kill -9 414046 2>/dev/null; pkill -9 -f autolabel 2>/dev/null; sleep 1; ps aux | grep autolabel | grep -v grep\")",
|
||||
"Bash(python -m src.cli.import_report_to_db:*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && pip install psycopg2-binary\")",
|
||||
"Bash(conda activate:*)",
|
||||
"Bash(python -m src.cli.analyze_report:*)",
|
||||
"Bash(/c/Users/yaoji/miniconda3/envs/yolo11/python.exe -m src.cli.analyze_report:*)",
|
||||
"Bash(c:/Users/yaoji/miniconda3/envs/yolo11/python.exe -m src.cli.analyze_report:*)",
|
||||
"Bash(cmd /c \"cd /d C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2 && C:\\\\Users\\\\yaoji\\\\miniconda3\\\\envs\\\\yolo11\\\\python.exe -m src.cli.analyze_report\")",
|
||||
"Bash(cmd /c \"cd /d C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2 && C:\\\\Users\\\\yaoji\\\\miniconda3\\\\envs\\\\yolo11\\\\python.exe -m src.cli.analyze_report 2>&1\")",
|
||||
"Bash(powershell -Command \"cd C:\\\\Users\\\\yaoji\\\\git\\\\ColaCoder\\\\invoice-master-poc-v2; C:\\\\Users\\\\yaoji\\\\miniconda3\\\\envs\\\\yolo11\\\\python.exe -m src.cli.analyze_report\":*)",
|
||||
"Bash(powershell -Command \"ls C:\\\\Users\\\\yaoji\\\\miniconda3\\\\envs\"\")",
|
||||
"Bash(where:*)",
|
||||
"Bash(\"C:/Users/yaoji/anaconda3/envs/invoice-master/python.exe\" -c \"import psycopg2; print\\(''psycopg2 OK''\\)\")",
|
||||
"Bash(\"C:/Users/yaoji/anaconda3/envs/torch-gpu/python.exe\" -c \"import psycopg2; print\\(''psycopg2 OK''\\)\")",
|
||||
"Bash(\"C:/Users/yaoji/anaconda3/python.exe\" -c \"import psycopg2; print\\(''psycopg2 OK''\\)\")",
|
||||
"Bash(\"C:/Users/yaoji/anaconda3/envs/invoice-master/python.exe\" -m pip install psycopg2-binary)",
|
||||
"Bash(\"C:/Users/yaoji/anaconda3/envs/invoice-master/python.exe\" -m src.cli.analyze_report)",
|
||||
"Bash(wsl -d Ubuntu bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && conda activate invoice-master && python -m src.cli.autolabel --help\":*)",
|
||||
"Bash(wsl -d Ubuntu-22.04 bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-extract && python -m src.cli.train --export-only 2>&1\")",
|
||||
"Bash(wsl -d Ubuntu-22.04 bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la data/dataset/\")",
|
||||
"Bash(wsl -d Ubuntu-22.04 bash -c \"ps aux | grep python | grep -v grep\")",
|
||||
"Bash(wsl -d Ubuntu-22.04 bash -c:*)",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && ls -la\")",
|
||||
"Bash(wsl -e bash -c \"cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-master && python -c \"\"\nimport sys\nsys.path.insert\\(0, ''.''\\)\nfrom src.data.db import DocumentDB\nfrom src.yolo.db_dataset import DBYOLODataset\n\n# Connect to database\ndb = DocumentDB\\(\\)\ndb.connect\\(\\)\n\n# Create dataset\ndataset = DBYOLODataset\\(\n images_dir=''data/dataset'',\n db=db,\n split=''train'',\n train_ratio=0.8,\n val_ratio=0.1,\n seed=42,\n dpi=300\n\\)\n\nprint\\(f''Dataset size: {len\\(dataset\\)}''\\)\n\nif len\\(dataset\\) > 0:\n # Check first few items\n for i in range\\(min\\(3, len\\(dataset\\)\\)\\):\n item = dataset.items[i]\n print\\(f''\\\\n--- Item {i} ---''\\)\n print\\(f''Document: {item.document_id}''\\)\n print\\(f''Is scanned: {item.is_scanned}''\\)\n print\\(f''Image: {item.image_path.name}''\\)\n \n # Get YOLO labels\n yolo_labels = dataset.get_labels_for_yolo\\(i\\)\n print\\(f''YOLO labels:''\\)\n for line in yolo_labels.split\\(''\\\\n''\\)[:3]:\n print\\(f'' {line}''\\)\n # Check if values are normalized\n parts = line.split\\(\\)\n if len\\(parts\\) == 5:\n x, y, w, h = float\\(parts[1]\\), float\\(parts[2]\\), float\\(parts[3]\\), float\\(parts[4]\\)\n if x > 1 or y > 1 or w > 1 or h > 1:\n print\\(f'' WARNING: Values not normalized!''\\)\n elif x == 1.0 or y == 1.0:\n print\\(f'' WARNING: Values clamped to 1.0!''\\)\n else:\n print\\(f'' OK: Values properly normalized''\\)\n\ndb.close\\(\\)\n\"\"\")",
|
||||
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/\")",
|
||||
"Bash(wsl -e bash -c \"ls -la /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2/data/dataset/train/\")"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": [],
|
||||
"defaultMode": "default"
|
||||
}
|
||||
}
|
||||
.gitignore (vendored, 9 lines changed)
@@ -34,13 +34,8 @@ env/
|
||||
*~
|
||||
|
||||
# Data files (large files)
|
||||
data/raw_pdfs/
|
||||
data/dataset/train/images/
|
||||
data/dataset/val/images/
|
||||
data/dataset/test/images/
|
||||
data/dataset/train/labels/
|
||||
data/dataset/val/labels/
|
||||
data/dataset/test/labels/
|
||||
/data/
|
||||
/results/
|
||||
*.pdf
|
||||
*.png
|
||||
*.jpg
|
||||
|
||||
README.md (452 lines changed)
@@ -1,90 +1,62 @@
|
||||
# Invoice Master POC v2
|
||||
|
||||
自动账单信息提取系统 - 使用 YOLO + OCR 从 PDF 发票中提取结构化数据。
|
||||
自动账单信息提取系统 - 使用 YOLOv11 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。
|
||||
|
||||
## 运行环境
|
||||
|
||||
> **重要**: 本项目需要在 **WSL (Windows Subsystem for Linux)** 环境下运行。
|
||||
本项目**必须**在 **WSL + Conda** 环境中运行。
|
||||
|
||||
### 系统要求
|
||||
|
||||
- WSL 2 (Ubuntu 22.04 推荐)
|
||||
- Python 3.10+
|
||||
- **NVIDIA GPU + CUDA 12.x (强烈推荐)** - GPU 训练比 CPU 快 10-50 倍
|
||||
| 环境 | 要求 |
|
||||
|------|------|
|
||||
| **WSL** | WSL 2 + Ubuntu 22.04 |
|
||||
| **Conda** | Miniconda 或 Anaconda |
|
||||
| **Python** | 3.10+ (通过 Conda 管理) |
|
||||
| **GPU** | NVIDIA GPU + CUDA 12.x (强烈推荐) |
|
||||
| **数据库** | PostgreSQL (存储标注结果) |
|
||||
|
||||
## 功能特点
|
||||
|
||||
- **双模式 PDF 处理**: 支持文本层 PDF 和扫描图 PDF
|
||||
- **自动标注**: 利用已有 CSV 结构化数据自动生成 YOLO 训练数据
|
||||
- **字段检测**: 使用 YOLOv8 检测发票字段区域
|
||||
- **OCR 识别**: 使用 PaddleOCR 提取检测区域的文本
|
||||
- **智能匹配**: 支持多种格式规范化和上下文关键词增强
|
||||
- **多池处理架构**: CPU 池处理文本 PDF,GPU 池处理扫描 PDF
|
||||
- **数据库存储**: 标注结果存储在 PostgreSQL,支持增量处理
|
||||
- **YOLO 检测**: 使用 YOLOv11 检测发票字段区域
|
||||
- **OCR 识别**: 使用 PaddleOCR 3.x 提取检测区域的文本
|
||||
- **Web 应用**: 提供 REST API 和可视化界面
|
||||
- **增量训练**: 支持在已训练模型基础上继续训练
|
||||
|
||||
## 支持的字段
|
||||
|
||||
| 字段 | 说明 |
|
||||
|------|------|
|
||||
| InvoiceNumber | 发票号码 |
|
||||
| InvoiceDate | 发票日期 |
|
||||
| InvoiceDueDate | 到期日期 |
|
||||
| OCR | OCR 参考号 (瑞典) |
|
||||
| Bankgiro | Bankgiro 号码 |
|
||||
| Plusgiro | Plusgiro 号码 |
|
||||
| Amount | 金额 |
|
||||
| 类别 ID | 字段名 | 说明 |
|
||||
|---------|--------|------|
|
||||
| 0 | invoice_number | 发票号码 |
|
||||
| 1 | invoice_date | 发票日期 |
|
||||
| 2 | invoice_due_date | 到期日期 |
|
||||
| 3 | ocr_number | OCR 参考号 (瑞典支付系统) |
|
||||
| 4 | bankgiro | Bankgiro 号码 |
|
||||
| 5 | plusgiro | Plusgiro 号码 |
|
||||
| 6 | amount | 金额 |
|
||||
|
||||
## 安装 (WSL)
|
||||
|
||||
### 1. 进入 WSL 环境
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
# 从 Windows 终端进入 WSL
|
||||
wsl
|
||||
# 1. 进入 WSL
|
||||
wsl -d Ubuntu-22.04
|
||||
|
||||
# 进入项目目录 (Windows 路径映射到 /mnt/)
|
||||
# 2. 创建 Conda 环境
|
||||
conda create -n invoice-py311 python=3.11 -y
|
||||
conda activate invoice-py311
|
||||
|
||||
# 3. 进入项目目录
|
||||
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
|
||||
```
|
||||
|
||||
### 2. 安装系统依赖
|
||||
|
||||
```bash
|
||||
# 更新系统
|
||||
sudo apt update && sudo apt upgrade -y
|
||||
|
||||
# 安装 Python 和必要工具
|
||||
sudo apt install -y python3.10 python3.10-venv python3-pip
|
||||
|
||||
# 安装 OpenCV 依赖
|
||||
sudo apt install -y libgl1-mesa-glx libglib2.0-0 libsm6 libxrender1 libxext6
|
||||
```
|
||||
|
||||
### 3. 创建虚拟环境并安装依赖
|
||||
|
||||
```bash
|
||||
# 创建虚拟环境
|
||||
python3 -m venv venv
|
||||
source venv/bin/activate
|
||||
|
||||
# 升级 pip
|
||||
pip install --upgrade pip
|
||||
|
||||
# 安装依赖
|
||||
# 4. 安装依赖
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 或使用 pip install (开发模式)
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### GPU 支持 (可选)
|
||||
|
||||
```bash
|
||||
# 确保 WSL 已配置 CUDA
|
||||
nvidia-smi # 检查 GPU 是否可用
|
||||
|
||||
# 安装 GPU 版本 PaddlePaddle
|
||||
pip install paddlepaddle-gpu
|
||||
|
||||
# 或指定 CUDA 版本
|
||||
pip install paddlepaddle-gpu==2.5.2.post118 -f https://www.paddlepaddle.org.cn/whl/linux/mkl/avx/stable.html
|
||||
# 5. 安装 Web 依赖
|
||||
pip install uvicorn fastapi python-multipart pydantic
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
@@ -92,12 +64,14 @@ pip install paddlepaddle-gpu==2.5.2.post118 -f https://www.paddlepaddle.org.cn/w
|
||||
### 1. 准备数据
|
||||
|
||||
```
|
||||
data/
|
||||
~/invoice-data/
|
||||
├── raw_pdfs/
|
||||
│ ├── {DocumentId}.pdf
|
||||
│ └── ...
|
||||
└── structured_data/
|
||||
└── invoices.csv
|
||||
├── structured_data/
|
||||
│ └── document_export_YYYYMMDD.csv
|
||||
└── dataset/
|
||||
└── temp/ (渲染的图片)
|
||||
```
|
||||
|
||||
CSV 格式:
|
||||
@@ -109,118 +83,336 @@ DocumentId,InvoiceDate,InvoiceNumber,InvoiceDueDate,OCR,Bankgiro,Plusgiro,Amount
|
||||
### 2. 自动标注
|
||||
|
||||
```bash
|
||||
# 使用双池模式 (CPU + GPU)
|
||||
python -m src.cli.autolabel \
|
||||
--csv data/structured_data/invoices.csv \
|
||||
--pdf-dir data/raw_pdfs \
|
||||
--output data/dataset \
|
||||
--report reports/autolabel_report.jsonl
|
||||
--dual-pool \
|
||||
--cpu-workers 3 \
|
||||
--gpu-workers 1
|
||||
|
||||
# 单线程模式
|
||||
python -m src.cli.autolabel --workers 4
|
||||
```
|
||||
|
||||
### 3. 训练模型
|
||||
|
||||
> **重要**: 务必使用 GPU 进行训练!CPU 训练速度非常慢。
|
||||
|
||||
```bash
|
||||
# GPU 训练 (强烈推荐)
|
||||
# 从预训练模型开始训练
|
||||
python -m src.cli.train \
|
||||
--data data/dataset/dataset.yaml \
|
||||
--model yolo11n.pt \
|
||||
--epochs 100 \
|
||||
--batch 16 \
|
||||
--device 0 # 使用 GPU
|
||||
|
||||
# 验证 GPU 可用
|
||||
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else None}')"
|
||||
--name invoice_yolo11n_full \
|
||||
--dpi 150
|
||||
```
|
||||
|
||||
GPU vs CPU 训练时间对比 (100 epochs, 77 训练图片):
|
||||
- **GPU (RTX 5080)**: ~2 分钟
|
||||
- **CPU**: 30+ 分钟
|
||||
### 4. 增量训练
|
||||
|
||||
### 4. 推理
|
||||
当添加新数据后,可以在已训练模型基础上继续训练:
|
||||
|
||||
```bash
|
||||
# 从已训练的 best.pt 继续训练
|
||||
python -m src.cli.train \
|
||||
--model runs/train/invoice_yolo11n_full/weights/best.pt \
|
||||
--epochs 30 \
|
||||
--batch 16 \
|
||||
--name invoice_yolo11n_v2 \
|
||||
--dpi 150
|
||||
```
|
||||
|
||||
**增量训练建议**:
|
||||
|
||||
| 场景 | 建议 |
|
||||
|------|------|
|
||||
| 添加少量新数据 (<20%) | 继续训练 10-30 epochs |
|
||||
| 添加大量新数据 (>50%) | 继续训练 50-100 epochs |
|
||||
| 修正大量标注错误 | 从头训练 |
|
||||
| 添加新的字段类型 | 从头训练 |
|
||||
|
||||
### 5. 推理
|
||||
|
||||
```bash
|
||||
# 命令行推理
|
||||
python -m src.cli.infer \
|
||||
--model runs/train/invoice_fields/weights/best.pt \
|
||||
--model runs/train/invoice_yolo11n_full/weights/best.pt \
|
||||
--input path/to/invoice.pdf \
|
||||
--output result.json
|
||||
--output result.json \
|
||||
--gpu
|
||||
```
|
||||
|
||||
## 输出示例
|
||||
### 6. Web 应用
|
||||
|
||||
```json
|
||||
{
|
||||
"DocumentId": "3be53fd7-d5ea-458c-a229-8d360b8ba6a9",
|
||||
"InvoiceNumber": "100017500321",
|
||||
"InvoiceDate": "2025-12-13",
|
||||
"InvoiceDueDate": "2026-01-03",
|
||||
"OCR": "100017500321",
|
||||
"Bankgiro": "5393-9484",
|
||||
"Plusgiro": null,
|
||||
"Amount": "114.00",
|
||||
"confidence": {
|
||||
"InvoiceNumber": 0.96,
|
||||
"InvoiceDate": 0.92,
|
||||
"Amount": 0.93
|
||||
}
|
||||
}
|
||||
```bash
|
||||
# 启动 Web 服务器
|
||||
python run_server.py --port 8000
|
||||
|
||||
# 开发模式 (自动重载)
|
||||
python run_server.py --debug --reload
|
||||
|
||||
# 禁用 GPU
|
||||
python run_server.py --no-gpu
|
||||
```
|
||||
|
||||
访问 **http://localhost:8000** 使用 Web 界面。
|
||||
|
||||
#### Web API 端点
|
||||
|
||||
| 方法 | 端点 | 描述 |
|
||||
|------|------|------|
|
||||
| GET | `/` | Web UI 界面 |
|
||||
| GET | `/api/v1/health` | 健康检查 |
|
||||
| POST | `/api/v1/infer` | 上传文件并推理 |
|
||||
| GET | `/api/v1/results/{filename}` | 获取可视化图片 |
|
||||
|
||||
## 训练配置
|
||||
|
||||
### YOLO 训练参数
|
||||
|
||||
```bash
|
||||
python -m src.cli.train [OPTIONS]
|
||||
|
||||
Options:
|
||||
--model, -m 基础模型 (默认: yolo11n.pt)
|
||||
--epochs, -e 训练轮数 (默认: 100)
|
||||
--batch, -b 批大小 (默认: 16)
|
||||
--imgsz 图像尺寸 (默认: 1280)
|
||||
--dpi PDF 渲染 DPI (默认: 150)
|
||||
--name 训练名称
|
||||
--limit 限制文档数 (用于测试)
|
||||
--device 设备 (0=GPU, cpu)
|
||||
```
|
||||
|
||||
### 训练最佳实践
|
||||
|
||||
1. **禁用翻转增强** (文本检测):
|
||||
```python
|
||||
fliplr=0.0, flipud=0.0
|
||||
```
|
||||
|
||||
2. **使用 Early Stopping**:
|
||||
```python
|
||||
patience=20
|
||||
```
|
||||
|
||||
3. **启用 AMP** (混合精度训练):
|
||||
```python
|
||||
amp=True
|
||||
```
|
||||
|
||||
4. **保存检查点**:
|
||||
```python
|
||||
save_period=10
|
||||
```
|
||||
|
||||
### 训练结果示例
|
||||
|
||||
使用 15,571 张训练图片,100 epochs 后的结果:
|
||||
|
||||
| 指标 | 值 |
|
||||
|------|-----|
|
||||
| **mAP@0.5** | 98.7% |
|
||||
| **mAP@0.5-0.95** | 87.4% |
|
||||
| **Precision** | 97.5% |
|
||||
| **Recall** | 95.5% |
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
invoice-master-poc-v2/
|
||||
├── src/
|
||||
│ ├── pdf/ # PDF 处理模块
|
||||
│ ├── ocr/ # OCR 提取模块
|
||||
│ ├── normalize/ # 字段规范化模块
|
||||
│ ├── matcher/ # 字段匹配模块
|
||||
│ ├── yolo/ # YOLO 标注生成
|
||||
│ ├── cli/ # 命令行工具
|
||||
│ │ ├── autolabel.py # 自动标注
|
||||
│ │ ├── train.py # 模型训练
|
||||
│ │ ├── infer.py # 推理
|
||||
│ │ └── serve.py # Web 服务器
|
||||
│ ├── pdf/ # PDF 处理
|
||||
│ │ ├── extractor.py # 文本提取
|
||||
│ │ ├── renderer.py # 图像渲染
|
||||
│ │ └── detector.py # 类型检测
|
||||
│ ├── ocr/ # PaddleOCR 封装
|
||||
│ ├── normalize/ # 字段规范化
|
||||
│ ├── matcher/ # 字段匹配
|
||||
│ ├── yolo/ # YOLO 相关
|
||||
│ │ ├── annotation_generator.py
|
||||
│ │ └── db_dataset.py
|
||||
│ ├── inference/ # 推理管道
|
||||
│ ├── data/ # 数据加载模块
|
||||
│ └── cli/ # 命令行工具
|
||||
├── configs/ # 配置文件
|
||||
├── data/ # 数据目录
|
||||
│ │ ├── pipeline.py
|
||||
│ │ ├── yolo_detector.py
|
||||
│ │ └── field_extractor.py
|
||||
│ ├── processing/ # 多池处理架构
|
||||
│ │ ├── worker_pool.py
|
||||
│ │ ├── cpu_pool.py
|
||||
│ │ ├── gpu_pool.py
|
||||
│ │ ├── task_dispatcher.py
|
||||
│ │ └── dual_pool_coordinator.py
|
||||
│ ├── web/ # Web 应用
|
||||
│ │ ├── app.py # FastAPI 应用
|
||||
│ │ ├── routes.py # API 路由
|
||||
│ │ ├── services.py # 业务逻辑
|
||||
│ │ ├── schemas.py # 数据模型
|
||||
│ │ └── config.py # 配置
|
||||
│ └── data/ # 数据处理
|
||||
├── config.py # 配置文件
|
||||
├── run_server.py # Web 服务器启动脚本
|
||||
├── runs/ # 训练输出
|
||||
│ └── train/
|
||||
│ └── invoice_yolo11n_full/
|
||||
│ └── weights/
|
||||
│ ├── best.pt
|
||||
│ └── last.pt
|
||||
└── requirements.txt
|
||||
```
|
||||
|
||||
## 开发优先级
|
||||
## 多池处理架构
|
||||
|
||||
1. ✅ 文本层 PDF 自动标注
|
||||
2. ✅ 扫描图 OCR 自动标注
|
||||
3. 🔄 金额 / OCR / Bankgiro 三字段稳定
|
||||
4. ⏳ 日期、Plusgiro 扩展
|
||||
5. ⏳ 表格 items 处理
|
||||
项目使用 CPU + GPU 双池架构处理不同类型的 PDF:
|
||||
|
||||
## 配置
|
||||
```
|
||||
┌─────────────────────────────────────────────────────┐
|
||||
│ DualPoolCoordinator │
|
||||
│ ┌─────────────────┐ ┌─────────────────┐ │
|
||||
│ │ CPU Pool │ │ GPU Pool │ │
|
||||
│ │ (3 workers) │ │ (1 worker) │ │
|
||||
│ │ │ │ │ │
|
||||
│ │ Text PDFs │ │ Scanned PDFs │ │
|
||||
│ │ ~50-87 it/s │ │ ~1-2 it/s │ │
|
||||
│ └─────────────────┘ └─────────────────┘ │
|
||||
│ │
|
||||
│ TaskDispatcher: 根据 PDF 类型分配任务 │
|
||||
└─────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
编辑 `configs/default.yaml` 自定义:
|
||||
- PDF 渲染 DPI
|
||||
- OCR 语言
|
||||
- 匹配置信度阈值
|
||||
- 上下文关键词
|
||||
- 数据增强参数
|
||||
### 关键设计
|
||||
|
||||
## API 使用
|
||||
- **spawn 启动方式**: 兼容 CUDA 多进程
|
||||
- **as_completed()**: 无死锁结果收集
|
||||
- **进程初始化器**: 每个 worker 加载一次模型
|
||||
- **协调器持久化**: 跨 CSV 文件复用 worker 池
|
||||
|
||||
## 配置文件
|
||||
|
||||
### config.py
|
||||
|
||||
```python
|
||||
# 数据库配置
|
||||
DATABASE = {
|
||||
'host': '192.168.68.31',
|
||||
'port': 5432,
|
||||
'database': 'docmaster',
|
||||
'user': 'docmaster',
|
||||
'password': '******',
|
||||
}
|
||||
|
||||
# 路径配置
|
||||
PATHS = {
|
||||
'csv_dir': '~/invoice-data/structured_data',
|
||||
'pdf_dir': '~/invoice-data/raw_pdfs',
|
||||
'output_dir': '~/invoice-data/dataset',
|
||||
}
|
||||
```
|
||||
|
||||
## CLI 命令参考
|
||||
|
||||
### autolabel
|
||||
|
||||
```bash
|
||||
python -m src.cli.autolabel [OPTIONS]
|
||||
|
||||
Options:
|
||||
--csv, -c CSV 文件路径 (支持 glob)
|
||||
--pdf-dir, -p PDF 文件目录
|
||||
--output, -o 输出目录
|
||||
--workers, -w 单线程模式 worker 数 (默认: 4)
|
||||
--dual-pool 启用双池模式
|
||||
--cpu-workers CPU 池 worker 数 (默认: 3)
|
||||
--gpu-workers GPU 池 worker 数 (默认: 1)
|
||||
--dpi 渲染 DPI (默认: 150)
|
||||
--limit, -l 限制处理文档数
|
||||
```
|
||||
|
||||
### train
|
||||
|
||||
```bash
|
||||
python -m src.cli.train [OPTIONS]
|
||||
|
||||
Options:
|
||||
--model, -m 基础模型路径
|
||||
--epochs, -e 训练轮数 (默认: 100)
|
||||
--batch, -b 批大小 (默认: 16)
|
||||
--imgsz 图像尺寸 (默认: 1280)
|
||||
--dpi PDF 渲染 DPI (默认: 150)
|
||||
--name 训练名称
|
||||
--limit 限制文档数
|
||||
```
|
||||
|
||||
### infer
|
||||
|
||||
```bash
|
||||
python -m src.cli.infer [OPTIONS]
|
||||
|
||||
Options:
|
||||
--model, -m 模型路径
|
||||
--input, -i 输入 PDF/图像
|
||||
--output, -o 输出 JSON 路径
|
||||
--confidence 置信度阈值 (默认: 0.5)
|
||||
--dpi 渲染 DPI (默认: 300)
|
||||
--gpu 使用 GPU
|
||||
```
|
||||
|
||||
### serve
|
||||
|
||||
```bash
|
||||
python run_server.py [OPTIONS]
|
||||
|
||||
Options:
|
||||
--host 绑定地址 (默认: 0.0.0.0)
|
||||
--port 端口 (默认: 8000)
|
||||
--model, -m 模型路径
|
||||
--confidence 置信度阈值 (默认: 0.3)
|
||||
--dpi 渲染 DPI (默认: 150)
|
||||
--no-gpu 禁用 GPU
|
||||
--reload 开发模式自动重载
|
||||
--debug 调试模式
|
||||
```
|
||||
|
||||
## Python API
|
||||
|
||||
```python
|
||||
from src.inference import InferencePipeline
|
||||
|
||||
# 初始化
|
||||
pipeline = InferencePipeline(
|
||||
model_path='models/best.pt',
|
||||
confidence_threshold=0.5,
|
||||
ocr_lang='en'
|
||||
model_path='runs/train/invoice_yolo11n_full/weights/best.pt',
|
||||
confidence_threshold=0.3,
|
||||
use_gpu=True,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
# 处理 PDF
|
||||
result = pipeline.process_pdf('invoice.pdf')
|
||||
|
||||
# 获取字段
|
||||
print(result.fields)
|
||||
print(result.confidence)
|
||||
# 处理图片
|
||||
result = pipeline.process_image('invoice.png')
|
||||
|
||||
# 获取结果
|
||||
print(result.fields) # {'InvoiceNumber': '12345', 'Amount': '1234.56', ...}
|
||||
print(result.confidence) # {'InvoiceNumber': 0.95, 'Amount': 0.92, ...}
|
||||
print(result.to_json()) # JSON 格式输出
|
||||
```
|
||||
|
||||
## 开发状态
|
||||
|
||||
- [x] 文本层 PDF 自动标注
|
||||
- [x] 扫描图 OCR 自动标注
|
||||
- [x] 多池处理架构 (CPU + GPU)
|
||||
- [x] PostgreSQL 数据库存储
|
||||
- [x] YOLO 训练 (98.7% mAP@0.5)
|
||||
- [x] 推理管道
|
||||
- [x] 字段规范化和验证
|
||||
- [x] Web 应用 (FastAPI + 前端 UI)
|
||||
- [x] 增量训练支持
|
||||
- [ ] 表格 items 处理
|
||||
- [ ] 模型量化部署
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT License
|
||||
|
||||
claude.md (new file, 216 lines)
@@ -0,0 +1,216 @@
|
||||
# Claude Code Instructions - Invoice Master POC v2
|
||||
|
||||
## Environment Requirements
|
||||
|
||||
> **IMPORTANT**: This project MUST run in **WSL + Conda** environment.
|
||||
|
||||
| Requirement | Details |
|
||||
|-------------|---------|
|
||||
| **WSL** | WSL 2 with Ubuntu 22.04+ |
|
||||
| **Conda** | Miniconda or Anaconda |
|
||||
| **Python** | 3.10+ (managed by Conda) |
|
||||
| **GPU** | NVIDIA drivers on Windows + CUDA in WSL |
|
||||
|
||||
```bash
|
||||
# Verify environment before running any commands
|
||||
uname -a # Should show "Linux"
|
||||
conda --version # Should show conda version
|
||||
conda activate <env> # Activate project environment
|
||||
which python # Should point to conda environment
|
||||
```
|
||||
|
||||
**All commands must be executed in WSL terminal with Conda environment activated.**
|
||||
|
||||
---
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Automated invoice field extraction system** for Swedish PDF invoices:
|
||||
- **YOLO Object Detection** (YOLOv8/v11) for field region detection
|
||||
- **PaddleOCR** for text extraction
|
||||
- **Multi-strategy matching** for field validation
|
||||
|
||||
**Stack**: Python 3.10+ | PyTorch | Ultralytics | PaddleOCR | PyMuPDF
|
||||
|
||||
**Target Fields**: InvoiceNumber, InvoiceDate, InvoiceDueDate, OCR, Bankgiro, Plusgiro, Amount
|
||||
|
||||
---
|
||||
|
||||
## Architecture Principles
|
||||
|
||||
### SOLID
|
||||
- **Single Responsibility**: Each module handles one concern
|
||||
- **Open/Closed**: Extend via new strategies, not modifying existing code
|
||||
- **Liskov Substitution**: Use Protocol/ABC for interchangeable components
|
||||
- **Interface Segregation**: Small, focused interfaces
|
||||
- **Dependency Inversion**: Depend on abstractions, inject dependencies (see the sketch below)
|
||||
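A minimal sketch of these principles applied together. The strategy classes are illustrative, and the real `FieldMatcher` in `src/matcher/` may expose a different interface:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


class MatchStrategy(Protocol):
    """One focused matching interface (Interface Segregation)."""

    def match(self, expected: str, tokens: list[str]) -> float:
        """Return a confidence score in [0, 1]."""
        ...


@dataclass(frozen=True)
class ExactMatchStrategy:
    """Single concern; new behaviour is added as new strategies (Open/Closed)."""

    def match(self, expected: str, tokens: list[str]) -> float:
        return 1.0 if expected in tokens else 0.0


class FieldMatcher:
    """Depends on the MatchStrategy abstraction and receives strategies by injection."""

    def __init__(self, strategies: list[MatchStrategy]) -> None:
        self._strategies = strategies

    def best_score(self, expected: str, tokens: list[str]) -> float:
        return max((s.match(expected, tokens) for s in self._strategies), default=0.0)
```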
|
||||
### Project Structure
|
||||
```
|
||||
src/
|
||||
├── cli/ # Entry points only, no business logic
|
||||
├── pdf/ # PDF processing (extraction, rendering, detection)
|
||||
├── ocr/ # OCR engines (PaddleOCR wrapper)
|
||||
├── normalize/ # Field normalization and validation
|
||||
├── matcher/ # Multi-strategy field matching
|
||||
├── yolo/ # YOLO annotation and dataset building
|
||||
├── inference/ # Inference pipeline
|
||||
└── data/ # Data loading and reporting
|
||||
```
|
||||
|
||||
### Configuration
|
||||
- `configs/default.yaml` — All tunable parameters
|
||||
- `config.py` — Sensitive data (credentials, use environment variables; see the sketch below)
|
||||
- Never hardcode magic numbers
|
||||
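One hedged way to keep the credentials in `config.py` out of version control is to read them from environment variables; the `INVOICE_DB_*` names below are an assumption, not an existing project convention:

```python
import os

# Hypothetical environment-variable overrides; the committed defaults hold no secrets.
DATABASE = {
    "host": os.environ.get("INVOICE_DB_HOST", "localhost"),
    "port": int(os.environ.get("INVOICE_DB_PORT", "5432")),
    "database": os.environ.get("INVOICE_DB_NAME", "docmaster"),
    "user": os.environ.get("INVOICE_DB_USER", "docmaster"),
    "password": os.environ.get("INVOICE_DB_PASSWORD", ""),
}
```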
|
||||
---
|
||||
|
||||
## Python Standards
|
||||
|
||||
### Required
|
||||
- **Type hints** on all public functions (PEP 484/585)
|
||||
- **Docstrings** in Google style (PEP 257)
|
||||
- **Dataclasses** for data structures (`frozen=True, slots=True` when immutable)
|
||||
- **Protocol** for interfaces (PEP 544)
|
||||
- **Enum** for constants
|
||||
- **pathlib.Path** instead of string paths (see the sketch below)
|
||||
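A compact example combining these requirements; the class names and enum values are illustrative only:

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from pathlib import Path


class FieldClass(Enum):
    """Enum for YOLO class ids instead of bare integers (values illustrative)."""

    INVOICE_NUMBER = 0
    AMOUNT = 6


@dataclass(frozen=True, slots=True)
class LabeledBox:
    """Immutable record for one YOLO label line.

    Attributes:
        field_class: Which invoice field the box covers.
        x_center: Normalised x centre in [0, 1].
        y_center: Normalised y centre in [0, 1].
    """

    field_class: FieldClass
    x_center: float
    y_center: float


def load_labels(label_file: Path) -> list[LabeledBox]:
    """Parse a YOLO label file into LabeledBox records.

    Args:
        label_file: Path to a ``.txt`` YOLO label file.

    Returns:
        Boxes whose class id is modelled by ``FieldClass``.
    """
    boxes: list[LabeledBox] = []
    for line in label_file.read_text().splitlines():
        parts = line.split()
        if len(parts) < 5:
            continue  # skip blank or malformed rows
        try:
            field_class = FieldClass(int(parts[0]))
        except ValueError:
            continue  # class id not covered by this sketch
        boxes.append(LabeledBox(field_class, float(parts[1]), float(parts[2])))
    return boxes
```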
|
||||
### Naming Conventions
|
||||
| Type | Convention | Example |
|
||||
|------|------------|---------|
|
||||
| Functions/Variables | snake_case | `extract_tokens`, `page_count` |
|
||||
| Classes | PascalCase | `FieldMatcher`, `AutoLabelReport` |
|
||||
| Constants | UPPER_SNAKE | `DEFAULT_DPI`, `FIELD_TYPES` |
|
||||
| Private | _prefix | `_parse_date`, `_cache` |
|
||||
|
||||
### Import Order (isort)
|
||||
1. `from __future__ import annotations`
|
||||
2. Standard library
|
||||
3. Third-party
|
||||
4. Local modules
|
||||
5. `if TYPE_CHECKING:` block (see the example below)
|
||||
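A worked example of the ordering; `numpy` stands in for any third-party import, while the local modules are taken from the project layout above:

```python
# 1. Future annotations
from __future__ import annotations

# 2. Standard library
import logging
from pathlib import Path
from typing import TYPE_CHECKING

# 3. Third-party (numpy as an example)
import numpy as np

# 4. Local modules
from src.normalize import normalize_field

# 5. Type-checking-only imports
if TYPE_CHECKING:
    from src.matcher import FieldMatcher
```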
|
||||
### Code Quality Tools
|
||||
| Tool | Purpose | Config |
|
||||
|------|---------|--------|
|
||||
| Black | Formatting | line-length=100 |
|
||||
| Ruff | Linting | E, F, W, I, N, D, UP, B, C4, SIM, ARG, PTH |
|
||||
| MyPy | Type checking | strict=true |
|
||||
| Pytest | Testing | tests/ directory |
|
||||
|
||||
---
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Use **custom exception hierarchy** (base: `InvoiceMasterError`)
|
||||
- Use **logging** instead of print (`logger = logging.getLogger(__name__)`)
|
||||
- Implement **graceful degradation** with fallback strategies
|
||||
- Use **context managers** for resource cleanup (see the sketch below)
|
||||
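A minimal sketch of these conventions. Only `InvoiceMasterError` is named by this document; the subclass and the PyMuPDF helper are illustrative:

```python
from __future__ import annotations

import logging
from contextlib import contextmanager
from typing import Any, Iterator

logger = logging.getLogger(__name__)  # module-level logger, never print()


class InvoiceMasterError(Exception):
    """Base class for all project-specific errors."""


class PdfExtractionError(InvoiceMasterError):
    """Raised when a PDF cannot be parsed (illustrative subclass)."""


@contextmanager
def open_pdf(path: str) -> Iterator[Any]:
    """Open a PDF with PyMuPDF and guarantee it is closed."""
    import fitz  # PyMuPDF, part of the documented stack

    doc = fitz.open(path)
    try:
        yield doc
    except Exception as exc:
        # Graceful degradation: log at warning level, wrap in the project hierarchy.
        logger.warning("PDF handling failed for %s: %s", path, exc)
        raise PdfExtractionError(str(exc)) from exc
    finally:
        doc.close()
```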
|
||||
---
|
||||
|
||||
## Machine Learning Standards
|
||||
|
||||
### Data Management
|
||||
- **Immutable raw data**: Never modify `data/raw/`
|
||||
- **Version datasets**: Track with checksum and metadata
|
||||
- **Reproducible splits**: Use fixed random seed (42)
|
||||
- **Split ratios**: 80% train / 10% val / 10% test (see the sketch below)
|
||||
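A deterministic split helper consistent with these rules; a sketch, not the project's actual dataset builder:

```python
import random


def split_documents(doc_ids: list[str], seed: int = 42) -> dict[str, list[str]]:
    """Shuffle with a fixed seed and cut 80/10/10 into train/val/test."""
    ids = sorted(doc_ids)             # stable starting order
    random.Random(seed).shuffle(ids)  # reproducible shuffle
    n_train = int(len(ids) * 0.8)
    n_val = int(len(ids) * 0.1)
    return {
        "train": ids[:n_train],
        "val": ids[n_train:n_train + n_val],
        "test": ids[n_train + n_val:],
    }
```

With a fixed seed and sorted input, re-running the split on the same document set yields identical partitions.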
|
||||
### YOLO Training
|
||||
- **Disable flips** for text detection (`fliplr=0.0, flipud=0.0`)
|
||||
- **Use early stopping** (`patience=20`)
|
||||
- **Enable AMP** for faster training (`amp=True`)
|
||||
- **Save checkpoints** periodically (`save_period=10`); see the combined sketch below
|
||||
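Taken together, these options map onto an Ultralytics training call roughly as follows; the dataset path is a placeholder:

```python
from ultralytics import YOLO

model = YOLO("yolo11n.pt")
model.train(
    data="data/dataset/dataset.yaml",  # placeholder path
    epochs=100,
    imgsz=1280,
    device=0,
    fliplr=0.0,      # no horizontal flips for text
    flipud=0.0,      # no vertical flips for text
    patience=20,     # early stopping
    amp=True,        # mixed-precision training
    save_period=10,  # periodic checkpoints
)
```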
|
||||
### Reproducibility
|
||||
- Set random seeds: `random`, `numpy`, `torch`
|
||||
- Enable deterministic mode: `torch.backends.cudnn.deterministic = True`
|
||||
- Track experiment config: model, epochs, batch_size, learning_rate, dataset_version, git_commit (see the sketch below)
|
||||
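A seed-setting helper matching these rules; the experiment record is a plain dict with illustrative values:

```python
import random

import numpy as np
import torch


def set_seed(seed: int = 42) -> None:
    """Seed every RNG source and force deterministic cuDNN kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Minimal experiment record mirroring the fields listed above.
experiment_config = {
    "model": "yolo11n.pt",
    "epochs": 100,
    "batch_size": 16,
    "learning_rate": 0.01,    # illustrative value
    "dataset_version": "v1",  # illustrative value
    "git_commit": "HEAD",     # fill in from `git rev-parse HEAD`
}
```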
|
||||
### Evaluation Metrics
|
||||
- Precision, Recall, F1 Score
|
||||
- mAP@0.5, mAP@0.5:0.95
|
||||
- Per-class AP
|
||||
|
||||
---
|
||||
|
||||
## Testing Standards
|
||||
|
||||
### Structure
|
||||
```
|
||||
tests/
|
||||
├── unit/ # Isolated, fast tests
|
||||
├── integration/ # Multi-module tests
|
||||
├── e2e/ # End-to-end workflow tests
|
||||
├── fixtures/ # Test data
|
||||
└── conftest.py # Shared fixtures
|
||||
```
|
||||
|
||||
### Practices
|
||||
- Follow **AAA pattern**: Arrange, Act, Assert
|
||||
- Use **parametrized tests** for multiple inputs
|
||||
- Use **fixtures** for shared setup
|
||||
- Use **mocking** for external dependencies
|
||||
- Mark slow tests with `@pytest.mark.slow` (see the example below)
|
||||
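A small test module showing these practices together; the `normalize_amount` helper is hypothetical and defined inline for the example:

```python
import pytest


def normalize_amount(raw: str) -> str:
    """Hypothetical helper standing in for the project's real normalizer."""
    return raw.strip().replace(",", ".")


@pytest.fixture
def sample_tokens() -> list[str]:
    """Shared setup reused by several tests."""
    return ["Faktura", "100017500321", "114,00"]


@pytest.mark.parametrize(
    ("raw", "expected"),
    [("114,00", "114.00"), ("114.00", "114.00"), (" 114,00 ", "114.00")],
)
def test_normalize_amount(raw: str, expected: str) -> None:
    # Arrange happens in parametrize; Act:
    result = normalize_amount(raw)
    # Assert:
    assert result == expected


@pytest.mark.slow
def test_tokens_fixture(sample_tokens: list[str]) -> None:
    """Marked slow so it can be excluded with ``-m "not slow"``."""
    assert "114,00" in sample_tokens
```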
|
||||
---
|
||||
|
||||
## Performance
|
||||
|
||||
- **Parallel processing**: Use `ProcessPoolExecutor` with progress tracking
|
||||
- **Lazy loading**: Use `@cached_property` for expensive resources
|
||||
- **Generators**: Use for large datasets to save memory
|
||||
- **Batch processing**: Process items in batches when possible (see the sketch below)
|
||||
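A sketch combining these ideas; the worker function, batch helper, and `Renderer` class are placeholders:

```python
from __future__ import annotations

import logging
from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import cached_property
from itertools import islice
from pathlib import Path
from typing import Iterable, Iterator

logger = logging.getLogger(__name__)


def batched(paths: Iterable[Path], size: int) -> Iterator[list[Path]]:
    """Yield fixed-size batches lazily instead of materialising everything."""
    it = iter(paths)
    while batch := list(islice(it, size)):
        yield batch


class Renderer:
    @cached_property
    def dpi_profile(self) -> dict[str, int]:
        """Expensive resource created once per instance, on first access."""
        return {"default": 150}


def process_pdf(path: Path) -> str:
    """Placeholder worker; module-level so it can be pickled by the pool."""
    return path.name


def run_parallel(pdfs: list[Path], workers: int = 4) -> list[str]:
    """Fan work out to worker processes and log progress as futures finish."""
    results: list[str] = []
    with ProcessPoolExecutor(max_workers=workers) as pool:
        futures = [pool.submit(process_pdf, p) for p in pdfs]
        for done, future in enumerate(as_completed(futures), start=1):
            results.append(future.result())
            logger.info("progress: %d/%d", done, len(futures))
    return results
```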
|
||||
---
|
||||
|
||||
## Security
|
||||
|
||||
- **Never commit**: credentials, API keys, `.env` files
|
||||
- **Use environment variables** for sensitive config
|
||||
- **Validate paths**: Prevent path traversal attacks (see the sketch below)
|
||||
- **Validate inputs**: At system boundaries
|
||||
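A minimal path-validation check; the function name and base-directory handling are illustrative:

```python
from pathlib import Path


def resolve_upload_path(base_dir: Path, user_filename: str) -> Path:
    """Reject any filename that would escape base_dir (path traversal)."""
    candidate = (base_dir / user_filename).resolve()
    if not candidate.is_relative_to(base_dir.resolve()):
        raise ValueError(f"Unsafe path: {user_filename}")
    return candidate
```

`Path.is_relative_to` needs Python 3.9+, which the 3.10+ requirement above already satisfies.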
|
||||
---
|
||||
|
||||
## Commands
|
||||
|
||||
| Task | Command |
|
||||
|------|---------|
|
||||
| Run autolabel | `python run_autolabel.py` |
|
||||
| Train YOLO | `python -m src.cli.train --config configs/training.yaml` |
|
||||
| Run inference | `python -m src.cli.infer --model models/best.pt` |
|
||||
| Run tests | `pytest tests/ -v` |
|
||||
| Coverage | `pytest tests/ --cov=src --cov-report=html` |
|
||||
| Format | `black src/ tests/` |
|
||||
| Lint | `ruff check src/ tests/ --fix` |
|
||||
| Type check | `mypy src/` |
|
||||
|
||||
---
|
||||
|
||||
## DO NOT
|
||||
|
||||
- Hardcode file paths or magic numbers
|
||||
- Use `print()` for logging
|
||||
- Skip type hints on public APIs
|
||||
- Write functions longer than 50 lines
|
||||
- Mix business logic with I/O
|
||||
- Commit credentials or `.env` files
|
||||
- Use `# type: ignore` without explanation
|
||||
- Use mutable default arguments
|
||||
- Catch bare `except:`
|
||||
- Use flip augmentation for text detection
|
||||
|
||||
## DO
|
||||
|
||||
- Use type hints everywhere
|
||||
- Write descriptive docstrings
|
||||
- Log with appropriate levels
|
||||
- Use dataclasses for data structures
|
||||
- Use enums for constants
|
||||
- Use Protocol for interfaces
|
||||
- Set random seeds for reproducibility
|
||||
- Track experiment configurations
|
||||
- Use context managers for resources
|
||||
- Validate inputs at boundaries
|
||||
config.py (new file, 64 lines)
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Configuration settings for the invoice extraction system.
|
||||
"""
|
||||
|
||||
import os
|
||||
import platform
|
||||
|
||||
|
||||
def _is_wsl() -> bool:
|
||||
"""Check if running inside WSL (Windows Subsystem for Linux)."""
|
||||
if platform.system() != 'Linux':
|
||||
return False
|
||||
# Check for WSL-specific indicators
|
||||
if os.environ.get('WSL_DISTRO_NAME'):
|
||||
return True
|
||||
try:
|
||||
with open('/proc/version', 'r') as f:
|
||||
return 'microsoft' in f.read().lower()
|
||||
except (FileNotFoundError, PermissionError):
|
||||
return False
|
||||
|
||||
|
||||
# PostgreSQL Database Configuration
|
||||
DATABASE = {
|
||||
'host': '192.168.68.31',
|
||||
'port': 5432,
|
||||
'database': 'docmaster',
|
||||
'user': 'docmaster',
|
||||
'password': '0412220',
|
||||
}
|
||||
|
||||
# Connection string for psycopg2
|
||||
def get_db_connection_string():
|
||||
return f"postgresql://{DATABASE['user']}:{DATABASE['password']}@{DATABASE['host']}:{DATABASE['port']}/{DATABASE['database']}"
|
||||
|
||||
|
||||
# Paths Configuration - auto-detect WSL vs Windows
|
||||
if _is_wsl():
|
||||
# WSL: use native Linux filesystem for better I/O performance
|
||||
PATHS = {
|
||||
'csv_dir': os.path.expanduser('~/invoice-data/structured_data'),
|
||||
'pdf_dir': os.path.expanduser('~/invoice-data/raw_pdfs'),
|
||||
'output_dir': os.path.expanduser('~/invoice-data/dataset'),
|
||||
'reports_dir': 'reports', # Keep reports in project directory
|
||||
}
|
||||
else:
|
||||
# Windows or native Linux: use relative paths
|
||||
PATHS = {
|
||||
'csv_dir': 'data/structured_data',
|
||||
'pdf_dir': 'data/raw_pdfs',
|
||||
'output_dir': 'data/dataset',
|
||||
'reports_dir': 'reports',
|
||||
}
|
||||
|
||||
# Auto-labeling Configuration
|
||||
AUTOLABEL = {
|
||||
'workers': 2,
|
||||
'dpi': 150,
|
||||
'min_confidence': 0.5,
|
||||
'train_ratio': 0.8,
|
||||
'val_ratio': 0.1,
|
||||
'test_ratio': 0.1,
|
||||
'max_records_per_report': 10000,
|
||||
}
|
||||
docs/multi_pool_design.md (new file, 619 lines)
@@ -0,0 +1,619 @@
|
||||
# 多池处理架构设计文档
|
||||
|
||||
## 1. 研究总结
|
||||
|
||||
### 1.1 当前问题分析
|
||||
|
||||
我们之前实现的双池模式存在稳定性问题,主要原因:
|
||||
|
||||
| 问题 | 原因 | 解决方案 |
|
||||
|------|------|----------|
|
||||
| 处理卡住 | 线程 + ProcessPoolExecutor 混用导致死锁 | 使用 asyncio 或纯 Queue 模式 |
|
||||
| Queue.get() 无限阻塞 | 没有超时机制 | 添加 timeout 和哨兵值 |
|
||||
| GPU 内存冲突 | 多进程同时访问 GPU | 限制 GPU worker = 1 |
|
||||
| CUDA fork 问题 | Linux 默认 fork 不兼容 CUDA | 使用 spawn 启动方式 |
|
||||
|
||||
### 1.2 推荐架构方案
|
||||
|
||||
经过研究,最适合我们场景的方案是 **生产者-消费者队列模式**:
|
||||
|
||||
```
|
||||
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
|
||||
│ Main Process │ │ CPU Workers │ │ GPU Worker │
|
||||
│ │ │ (4 processes) │ │ (1 process) │
|
||||
│ ┌───────────┐ │ │ │ │ │
|
||||
│ │ Task │──┼────▶│ Text PDF处理 │ │ Scanned PDF处理 │
|
||||
│ │ Dispatcher│ │ │ (无需OCR) │ │ (PaddleOCR) │
|
||||
│ └───────────┘ │ │ │ │ │
|
||||
│ ▲ │ │ │ │ │ │ │
|
||||
│ │ │ │ ▼ │ │ ▼ │
|
||||
│ ┌───────────┐ │ │ Result Queue │ │ Result Queue │
|
||||
│ │ Result │◀─┼─────│◀────────────────│─────│◀────────────────│
|
||||
│ │ Collector │ │ │ │ │ │
|
||||
│ └───────────┘ │ └─────────────────┘ └─────────────────┘
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌───────────┐ │
|
||||
│ │ Database │ │
|
||||
│ │ Batch │ │
|
||||
│ │ Writer │ │
|
||||
│ └───────────┘ │
|
||||
└─────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 2. 核心设计原则
|
||||
|
||||
### 2.1 CUDA 兼容性
|
||||
|
||||
```python
|
||||
# 关键:使用 spawn 启动方式
|
||||
import multiprocessing as mp
|
||||
ctx = mp.get_context("spawn")
|
||||
|
||||
# GPU worker 初始化时设置设备
|
||||
def init_gpu_worker(gpu_id: int = 0):
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
global _ocr
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(use_gpu=True, ...)
|
||||
```
|
||||
|
||||
### 2.2 Worker 初始化模式
|
||||
|
||||
使用 `initializer` 参数一次性加载模型,避免每个任务重新加载:
|
||||
|
||||
```python
|
||||
# 全局变量保存模型
|
||||
_ocr = None
|
||||
|
||||
def init_worker(use_gpu: bool, gpu_id: int = 0):
|
||||
global _ocr
|
||||
if use_gpu:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
else:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(use_gpu=use_gpu, ...)
|
||||
|
||||
# 创建 Pool 时使用 initializer
|
||||
pool = ProcessPoolExecutor(
|
||||
max_workers=1,
|
||||
initializer=init_worker,
|
||||
initargs=(True, 0), # use_gpu=True, gpu_id=0
|
||||
mp_context=mp.get_context("spawn")
|
||||
)
|
||||
```
|
||||
|
||||
### 2.3 队列模式 vs as_completed
|
||||
|
||||
| 方式 | 优点 | 缺点 | 适用场景 |
|
||||
|------|------|------|----------|
|
||||
| `as_completed()` | 简单、无需管理队列 | 无法跨多个 Pool 使用 | 单池场景 |
|
||||
| `multiprocessing.Queue` | 高性能、灵活 | 需要手动管理、死锁风险 | 多池流水线 |
|
||||
| `Manager().Queue()` | 可 pickle、跨 Pool | 性能较低 | 需要 Pool.map 场景 |
|
||||
|
||||
**推荐**:对于双池场景,使用 `as_completed()` 分别处理每个池,然后合并结果。
|
||||
|
||||
---
|
||||
|
||||
## 3. 详细开发计划
|
||||
|
||||
### 阶段 1:重构基础架构 (2-3天)
|
||||
|
||||
#### 1.1 创建 WorkerPool 抽象类
|
||||
|
||||
```python
|
||||
# src/processing/worker_pool.py
|
||||
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ProcessPoolExecutor, Future
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Any, Optional, Callable
|
||||
import multiprocessing as mp
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
"""任务结果容器"""
|
||||
task_id: str
|
||||
success: bool
|
||||
data: Any
|
||||
error: Optional[str] = None
|
||||
processing_time: float = 0.0
|
||||
|
||||
class WorkerPool(ABC):
|
||||
"""Worker Pool 抽象基类"""
|
||||
|
||||
def __init__(self, max_workers: int, use_gpu: bool = False, gpu_id: int = 0):
|
||||
self.max_workers = max_workers
|
||||
self.use_gpu = use_gpu
|
||||
self.gpu_id = gpu_id
|
||||
self._executor: Optional[ProcessPoolExecutor] = None
|
||||
|
||||
@abstractmethod
|
||||
def get_initializer(self) -> Callable:
|
||||
"""返回 worker 初始化函数"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_init_args(self) -> tuple:
|
||||
"""返回初始化参数"""
|
||||
pass
|
||||
|
||||
def start(self):
|
||||
"""启动 worker pool"""
|
||||
ctx = mp.get_context("spawn")
|
||||
self._executor = ProcessPoolExecutor(
|
||||
max_workers=self.max_workers,
|
||||
mp_context=ctx,
|
||||
initializer=self.get_initializer(),
|
||||
initargs=self.get_init_args()
|
||||
)
|
||||
|
||||
def submit(self, fn: Callable, *args, **kwargs) -> Future:
|
||||
"""提交任务"""
|
||||
if not self._executor:
|
||||
raise RuntimeError("Pool not started")
|
||||
return self._executor.submit(fn, *args, **kwargs)
|
||||
|
||||
def shutdown(self, wait: bool = True):
|
||||
"""关闭 pool"""
|
||||
if self._executor:
|
||||
self._executor.shutdown(wait=wait)
|
||||
self._executor = None
|
||||
|
||||
def __enter__(self):
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.shutdown()
|
||||
```
|
||||
|
||||
#### 1.2 实现 CPU 和 GPU Worker Pool
|
||||
|
||||
```python
|
||||
# src/processing/cpu_pool.py
|
||||
|
||||
class CPUWorkerPool(WorkerPool):
|
||||
"""CPU-only worker pool for text PDF processing"""
|
||||
|
||||
def __init__(self, max_workers: int = 4):
|
||||
super().__init__(max_workers=max_workers, use_gpu=False)
|
||||
|
||||
def get_initializer(self) -> Callable:
|
||||
return init_cpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
return ()
|
||||
|
||||
# src/processing/gpu_pool.py
|
||||
|
||||
class GPUWorkerPool(WorkerPool):
|
||||
"""GPU worker pool for OCR processing"""
|
||||
|
||||
def __init__(self, max_workers: int = 1, gpu_id: int = 0):
|
||||
super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)
|
||||
|
||||
def get_initializer(self) -> Callable:
|
||||
return init_gpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
return (self.gpu_id,)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 2:实现双池协调器 (2-3天)
|
||||
|
||||
#### 2.1 任务分发器
|
||||
|
||||
```python
|
||||
# src/processing/task_dispatcher.py
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
class TaskType(Enum):
|
||||
CPU = auto() # Text PDF
|
||||
GPU = auto() # Scanned PDF
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
id: str
|
||||
task_type: TaskType
|
||||
data: Any
|
||||
|
||||
class TaskDispatcher:
|
||||
"""根据 PDF 类型分发任务到不同的 pool"""
|
||||
|
||||
def classify_task(self, doc_info: dict) -> TaskType:
|
||||
"""判断文档是否需要 OCR"""
|
||||
# 基于 PDF 特征判断
|
||||
if self._is_scanned_pdf(doc_info):
|
||||
return TaskType.GPU
|
||||
return TaskType.CPU
|
||||
|
||||
def _is_scanned_pdf(self, doc_info: dict) -> bool:
|
||||
"""检测是否为扫描件"""
|
||||
# 1. 检查是否有可提取文本
|
||||
# 2. 检查图片比例
|
||||
# 3. 检查文本密度
|
||||
pass
|
||||
|
||||
def partition_tasks(self, tasks: List[Task]) -> Tuple[List[Task], List[Task]]:
|
||||
"""将任务分为 CPU 和 GPU 两组"""
|
||||
cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
|
||||
gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
|
||||
return cpu_tasks, gpu_tasks
|
||||
```
|
||||
|
||||
#### 2.2 双池协调器
|
||||
|
||||
```python
|
||||
# src/processing/dual_pool_coordinator.py
|
||||
|
||||
from concurrent.futures import as_completed
|
||||
from typing import Callable, List, Optional
import logging

from .cpu_pool import CPUWorkerPool
from .gpu_pool import GPUWorkerPool
from .task_dispatcher import Task, TaskDispatcher
from .worker_pool import TaskResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DualPoolCoordinator:
|
||||
"""协调 CPU 和 GPU 两个 worker pool"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_workers: int = 4,
|
||||
gpu_workers: int = 1,
|
||||
gpu_id: int = 0
|
||||
):
|
||||
self.cpu_pool = CPUWorkerPool(max_workers=cpu_workers)
|
||||
self.gpu_pool = GPUWorkerPool(max_workers=gpu_workers, gpu_id=gpu_id)
|
||||
self.dispatcher = TaskDispatcher()
|
||||
|
||||
def __enter__(self):
|
||||
self.cpu_pool.start()
|
||||
self.gpu_pool.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args):
|
||||
self.cpu_pool.shutdown()
|
||||
self.gpu_pool.shutdown()
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
documents: List[dict],
|
||||
cpu_task_fn: Callable,
|
||||
gpu_task_fn: Callable,
|
||||
on_result: Optional[Callable[[TaskResult], None]] = None,
|
||||
on_error: Optional[Callable[[str, Exception], None]] = None
|
||||
) -> List[TaskResult]:
|
||||
"""
|
||||
处理一批文档,自动分发到 CPU 或 GPU pool
|
||||
|
||||
Args:
|
||||
documents: 待处理文档列表
|
||||
cpu_task_fn: CPU 任务处理函数
|
||||
gpu_task_fn: GPU 任务处理函数
|
||||
on_result: 结果回调(可选)
|
||||
on_error: 错误回调(可选)
|
||||
|
||||
Returns:
|
||||
所有任务结果列表
|
||||
"""
|
||||
# 分类任务
|
||||
tasks = [
|
||||
Task(id=doc['id'], task_type=self.dispatcher.classify_task(doc), data=doc)
|
||||
for doc in documents
|
||||
]
|
||||
cpu_tasks, gpu_tasks = self.dispatcher.partition_tasks(tasks)
|
||||
|
||||
logger.info(f"Task partition: {len(cpu_tasks)} CPU, {len(gpu_tasks)} GPU")
|
||||
|
||||
# 提交任务到各自的 pool
|
||||
cpu_futures = {
|
||||
self.cpu_pool.submit(cpu_task_fn, t.data): t.id
|
||||
for t in cpu_tasks
|
||||
}
|
||||
gpu_futures = {
|
||||
self.gpu_pool.submit(gpu_task_fn, t.data): t.id
|
||||
for t in gpu_tasks
|
||||
}
|
||||
|
||||
# 收集结果
|
||||
results = []
|
||||
all_futures = list(cpu_futures.keys()) + list(gpu_futures.keys())
|
||||
|
||||
for future in as_completed(all_futures):
|
||||
task_id = cpu_futures.get(future) or gpu_futures.get(future)
|
||||
pool_type = "CPU" if future in cpu_futures else "GPU"
|
||||
|
||||
try:
|
||||
data = future.result(timeout=300) # 5分钟超时
|
||||
result = TaskResult(task_id=task_id, success=True, data=data)
|
||||
if on_result:
|
||||
on_result(result)
|
||||
except Exception as e:
|
||||
logger.error(f"[{pool_type}] Task {task_id} failed: {e}")
|
||||
result = TaskResult(task_id=task_id, success=False, data=None, error=str(e))
|
||||
if on_error:
|
||||
on_error(task_id, e)
|
||||
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 3:集成到 autolabel (1-2天)
|
||||
|
||||
#### 3.1 修改 autolabel.py
|
||||
|
||||
```python
|
||||
# src/cli/autolabel.py
|
||||
|
||||
def run_autolabel_dual_pool(args):
|
||||
"""使用双池模式运行自动标注"""
|
||||
|
||||
from src.processing.dual_pool_coordinator import DualPoolCoordinator
|
||||
|
||||
# 初始化数据库批处理
|
||||
db_batch = []
|
||||
db_batch_size = 100
|
||||
|
||||
def on_result(result: TaskResult):
|
||||
"""处理成功结果"""
|
||||
nonlocal db_batch
|
||||
db_batch.append(result.data)
|
||||
|
||||
if len(db_batch) >= db_batch_size:
|
||||
save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
|
||||
def on_error(task_id: str, error: Exception):
|
||||
"""处理错误"""
|
||||
logger.error(f"Task {task_id} failed: {error}")
|
||||
|
||||
# 创建双池协调器
|
||||
with DualPoolCoordinator(
|
||||
cpu_workers=args.cpu_workers or 4,
|
||||
gpu_workers=args.gpu_workers or 1,
|
||||
gpu_id=0
|
||||
) as coordinator:
|
||||
|
||||
# 处理所有 CSV
|
||||
for csv_file in csv_files:
|
||||
documents = load_documents_from_csv(csv_file)
|
||||
|
||||
results = coordinator.process_batch(
|
||||
documents=documents,
|
||||
cpu_task_fn=process_text_pdf,
|
||||
gpu_task_fn=process_scanned_pdf,
|
||||
on_result=on_result,
|
||||
on_error=on_error
|
||||
)
|
||||
|
||||
logger.info(f"CSV {csv_file}: {len(results)} processed")
|
||||
|
||||
# 保存剩余批次
|
||||
if db_batch:
|
||||
save_documents_batch(db_batch)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### 阶段 4:测试与验证 (1-2天)
|
||||
|
||||
#### 4.1 单元测试
|
||||
|
||||
```python
|
||||
# tests/unit/test_dual_pool.py
|
||||
|
||||
import pytest
|
||||
from src.processing.dual_pool_coordinator import DualPoolCoordinator, TaskResult
|
||||
|
||||
class TestDualPoolCoordinator:
|
||||
|
||||
def test_cpu_only_batch(self):
|
||||
"""测试纯 CPU 任务批处理"""
|
||||
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
|
||||
docs = [{"id": f"doc_{i}", "type": "text"} for i in range(10)]
|
||||
results = coord.process_batch(docs, cpu_fn, gpu_fn)
|
||||
assert len(results) == 10
|
||||
assert all(r.success for r in results)
|
||||
|
||||
def test_mixed_batch(self):
|
||||
"""测试混合任务批处理"""
|
||||
with DualPoolCoordinator(cpu_workers=2, gpu_workers=1) as coord:
|
||||
docs = [
|
||||
{"id": "text_1", "type": "text"},
|
||||
{"id": "scan_1", "type": "scanned"},
|
||||
{"id": "text_2", "type": "text"},
|
||||
]
|
||||
results = coord.process_batch(docs, cpu_fn, gpu_fn)
|
||||
assert len(results) == 3
|
||||
|
||||
def test_timeout_handling(self):
|
||||
"""测试超时处理"""
|
||||
pass
|
||||
|
||||
def test_error_recovery(self):
|
||||
"""测试错误恢复"""
|
||||
pass
|
||||
```
|
||||
|
||||
#### 4.2 集成测试
|
||||
|
||||
```python
|
||||
# tests/integration/test_autolabel_dual_pool.py
|
||||
|
||||
def test_autolabel_with_dual_pool():
|
||||
"""端到端测试双池模式"""
|
||||
# 使用少量测试数据
|
||||
result = subprocess.run([
|
||||
"python", "-m", "src.cli.autolabel",
|
||||
"--cpu-workers", "2",
|
||||
"--gpu-workers", "1",
|
||||
"--limit", "50"
|
||||
], capture_output=True)
|
||||
|
||||
assert result.returncode == 0
|
||||
# 验证数据库记录
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. 关键技术点
|
||||
|
||||
### 4.1 避免死锁的策略
|
||||
|
||||
```python
|
||||
# 1. 使用 timeout
|
||||
try:
|
||||
result = future.result(timeout=300)
|
||||
except TimeoutError:
|
||||
logger.warning(f"Task timed out")
|
||||
|
||||
# 2. 使用哨兵值
|
||||
SENTINEL = object()
|
||||
queue.put(SENTINEL) # 发送结束信号
|
||||
|
||||
# 3. 检查进程状态
|
||||
if not worker.is_alive():
|
||||
logger.error("Worker died unexpectedly")
|
||||
break
|
||||
|
||||
# 4. 先清空队列再 join
|
||||
while not queue.empty():
|
||||
results.append(queue.get_nowait())
|
||||
worker.join(timeout=5.0)
|
||||
```
|
||||
|
||||
### 4.2 PaddleOCR 特殊处理
|
||||
|
||||
```python
|
||||
# PaddleOCR 必须在 worker 进程中初始化
|
||||
def init_paddle_worker(gpu_id: int):
|
||||
global _ocr
|
||||
import os
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
|
||||
# 延迟导入,确保 CUDA 环境变量生效
|
||||
from paddleocr import PaddleOCR
|
||||
_ocr = PaddleOCR(
|
||||
use_angle_cls=True,
|
||||
lang='en',
|
||||
use_gpu=True,
|
||||
show_log=False,
|
||||
# 重要:设置 GPU 内存比例
|
||||
gpu_mem=2000 # 限制 GPU 内存使用 (MB)
|
||||
)
|
||||
```
|
||||
|
||||
### 4.3 资源监控
|
||||
|
||||
```python
|
||||
import psutil
|
||||
import GPUtil
|
||||
|
||||
def get_resource_usage():
|
||||
"""获取系统资源使用情况"""
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
memory = psutil.virtual_memory()
|
||||
|
||||
gpu_info = []
|
||||
for gpu in GPUtil.getGPUs():
|
||||
gpu_info.append({
|
||||
"id": gpu.id,
|
||||
"memory_used": gpu.memoryUsed,
|
||||
"memory_total": gpu.memoryTotal,
|
||||
"utilization": gpu.load * 100
|
||||
})
|
||||
|
||||
return {
|
||||
"cpu_percent": cpu_percent,
|
||||
"memory_percent": memory.percent,
|
||||
"gpu": gpu_info
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. 风险评估与应对
|
||||
|
||||
| 风险 | 可能性 | 影响 | 应对策略 |
|
||||
|------|--------|------|----------|
|
||||
| GPU 内存不足 | 中 | 高 | 限制 GPU worker = 1,设置 gpu_mem 参数 |
|
||||
| 进程僵死 | 低 | 高 | 添加心跳检测,超时自动重启 |
|
||||
| 任务分类错误 | 中 | 中 | 添加回退机制,CPU 失败后尝试 GPU |
|
||||
| 数据库写入瓶颈 | 低 | 中 | 增大批处理大小,异步写入 |
|
||||
|
||||
---
|
||||
|
||||
## 6. 备选方案
|
||||
|
||||
如果上述方案仍存在问题,可以考虑:
|
||||
|
||||
### 6.1 使用 Ray
|
||||
|
||||
```python
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
def cpu_task(data):
|
||||
return process_text_pdf(data)
|
||||
|
||||
@ray.remote(num_gpus=1)
|
||||
def gpu_task(data):
|
||||
return process_scanned_pdf(data)
|
||||
|
||||
# 自动资源调度
|
||||
futures = [cpu_task.remote(d) for d in cpu_docs]
|
||||
futures += [gpu_task.remote(d) for d in gpu_docs]
|
||||
results = ray.get(futures)
|
||||
```
|
||||
|
||||
### 6.2 单池 + 动态 GPU 调度
|
||||
|
||||
保持单池模式,但在每个任务内部动态决定是否使用 GPU:
|
||||
|
||||
```python
|
||||
def process_document(doc_data):
|
||||
if is_scanned_pdf(doc_data):
|
||||
# 使用 GPU (需要全局锁或信号量控制并发)
|
||||
with gpu_semaphore:
|
||||
return process_with_ocr(doc_data)
|
||||
else:
|
||||
return process_text_only(doc_data)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. 时间线总结
|
||||
|
||||
| 阶段 | 任务 | 预计工作量 |
|
||||
|------|------|------------|
|
||||
| 阶段 1 | 基础架构重构 | 2-3 天 |
|
||||
| 阶段 2 | 双池协调器实现 | 2-3 天 |
|
||||
| 阶段 3 | 集成到 autolabel | 1-2 天 |
|
||||
| 阶段 4 | 测试与验证 | 1-2 天 |
|
||||
| **总计** | | **6-10 天** |
|
||||
|
||||
---
|
||||
|
||||
## 8. 参考资料
|
||||
|
||||
1. [Python concurrent.futures 官方文档](https://docs.python.org/3/library/concurrent.futures.html)
|
||||
2. [PyTorch Multiprocessing Best Practices](https://docs.pytorch.org/docs/stable/notes/multiprocessing.html)
|
||||
3. [Super Fast Python - ProcessPoolExecutor 完整指南](https://superfastpython.com/processpoolexecutor-in-python/)
|
||||
4. [PaddleOCR 并行推理文档](http://www.paddleocr.ai/main/en/version3.x/pipeline_usage/instructions/parallel_inference.html)
|
||||
5. [AWS - 跨 CPU/GPU 并行化 ML 推理](https://aws.amazon.com/blogs/machine-learning/parallelizing-across-multiple-cpu-gpus-to-speed-up-deep-learning-inference-at-the-edge/)
|
||||
6. [Ray 分布式多进程处理](https://docs.ray.io/en/latest/ray-more-libs/multiprocessing.html)
|
||||
run_server.py (new file, 14 lines)
@@ -0,0 +1,14 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Quick start script for the web server.
|
||||
|
||||
Usage:
|
||||
python run_server.py
|
||||
python run_server.py --port 8080
|
||||
python run_server.py --debug --reload
|
||||
"""
|
||||
|
||||
from src.cli.serve import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
src/cli/analyze_labels.py (new file, 600 lines)
@@ -0,0 +1,600 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Label Analysis CLI
|
||||
|
||||
Analyzes auto-generated labels to identify failures and diagnose root causes.
|
||||
Now reads from PostgreSQL database instead of JSONL files.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from config import get_db_connection_string
|
||||
|
||||
from ..normalize import normalize_field
|
||||
from ..matcher import FieldMatcher
|
||||
from ..pdf import is_text_pdf, extract_text_tokens
|
||||
from ..yolo.annotation_generator import FIELD_CLASSES
|
||||
from ..data.db import DocumentDB
|
||||
|
||||
|
||||
@dataclass
|
||||
class FieldAnalysis:
|
||||
"""Analysis result for a single field."""
|
||||
field_name: str
|
||||
csv_value: str
|
||||
expected: bool # True if CSV has value
|
||||
labeled: bool # True if label file has this field
|
||||
matched: bool # True if matcher finds it
|
||||
|
||||
# Diagnosis
|
||||
failure_reason: Optional[str] = None
|
||||
details: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentAnalysis:
|
||||
"""Analysis result for a document."""
|
||||
doc_id: str
|
||||
pdf_exists: bool
|
||||
pdf_type: str # "text" or "scanned"
|
||||
total_pages: int
|
||||
|
||||
# Per-field analysis
|
||||
fields: list[FieldAnalysis] = field(default_factory=list)
|
||||
|
||||
# Summary
|
||||
csv_fields_count: int = 0 # Fields with values in CSV
|
||||
labeled_fields_count: int = 0 # Fields in label files
|
||||
matched_fields_count: int = 0 # Fields matcher can find
|
||||
|
||||
@property
|
||||
def has_issues(self) -> bool:
|
||||
"""Check if document has any labeling issues."""
|
||||
return any(
|
||||
f.expected and not f.labeled
|
||||
for f in self.fields
|
||||
)
|
||||
|
||||
@property
|
||||
def missing_labels(self) -> list[FieldAnalysis]:
|
||||
"""Get fields that should be labeled but aren't."""
|
||||
return [f for f in self.fields if f.expected and not f.labeled]
|
||||
|
||||
|
||||
class LabelAnalyzer:
|
||||
"""Analyzes labels and diagnoses failures."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str,
|
||||
pdf_dir: str,
|
||||
dataset_dir: str,
|
||||
use_db: bool = True
|
||||
):
|
||||
self.csv_path = Path(csv_path)
|
||||
self.pdf_dir = Path(pdf_dir)
|
||||
self.dataset_dir = Path(dataset_dir)
|
||||
self.use_db = use_db
|
||||
|
||||
self.matcher = FieldMatcher()
|
||||
self.csv_data = {}
|
||||
self.label_data = {}
|
||||
self.report_data = {}
|
||||
|
||||
# Database connection
|
||||
self.db = None
|
||||
if use_db:
|
||||
self.db = DocumentDB()
|
||||
self.db.connect()
|
||||
|
||||
# Class ID to name mapping
|
||||
self.class_names = list(FIELD_CLASSES.keys())
|
||||
|
||||
def load_csv(self):
|
||||
"""Load CSV data."""
|
||||
with open(self.csv_path, 'r', encoding='utf-8-sig') as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
doc_id = row['DocumentId']
|
||||
self.csv_data[doc_id] = row
|
||||
print(f"Loaded {len(self.csv_data)} records from CSV")
|
||||
|
||||
def load_labels(self):
|
||||
"""Load all label files from dataset."""
|
||||
for split in ['train', 'val', 'test']:
|
||||
label_dir = self.dataset_dir / split / 'labels'
|
||||
if not label_dir.exists():
|
||||
continue
|
||||
|
||||
for label_file in label_dir.glob('*.txt'):
|
||||
# Parse document ID from filename (uuid_page_XXX.txt)
|
||||
name = label_file.stem
|
||||
parts = name.rsplit('_page_', 1)
|
||||
if len(parts) == 2:
|
||||
doc_id = parts[0]
|
||||
page_no = int(parts[1])
|
||||
else:
|
||||
continue
|
||||
|
||||
if doc_id not in self.label_data:
|
||||
self.label_data[doc_id] = {'pages': {}, 'split': split}
|
||||
|
||||
# Parse label file
|
||||
labels = []
|
||||
with open(label_file, 'r') as f:
|
||||
for line in f:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 5:
|
||||
class_id = int(parts[0])
|
||||
labels.append({
|
||||
'class_id': class_id,
|
||||
'class_name': self.class_names[class_id],
|
||||
'x_center': float(parts[1]),
|
||||
'y_center': float(parts[2]),
|
||||
'width': float(parts[3]),
|
||||
'height': float(parts[4])
|
||||
})
|
||||
|
||||
self.label_data[doc_id]['pages'][page_no] = labels
|
||||
|
||||
total_docs = len(self.label_data)
|
||||
total_labels = sum(
|
||||
len(labels)
|
||||
for doc in self.label_data.values()
|
||||
for labels in doc['pages'].values()
|
||||
)
|
||||
print(f"Loaded labels for {total_docs} documents ({total_labels} total labels)")
|
||||
|
||||
def load_report(self):
|
||||
"""Load autolabel report from database."""
|
||||
if not self.db:
|
||||
print("Database not configured, skipping report loading")
|
||||
return
|
||||
|
||||
# Get document IDs from CSV to query
|
||||
doc_ids = list(self.csv_data.keys())
|
||||
if not doc_ids:
|
||||
return
|
||||
|
||||
# Query in batches to avoid memory issues
|
||||
batch_size = 1000
|
||||
loaded = 0
|
||||
|
||||
for i in range(0, len(doc_ids), batch_size):
|
||||
batch_ids = doc_ids[i:i + batch_size]
|
||||
for doc_id in batch_ids:
|
||||
doc = self.db.get_document(doc_id)
|
||||
if doc:
|
||||
self.report_data[doc_id] = doc
|
||||
loaded += 1
|
||||
|
||||
print(f"Loaded {loaded} autolabel reports from database")
|
||||
|
||||
def analyze_document(self, doc_id: str, skip_missing_pdf: bool = True) -> Optional[DocumentAnalysis]:
|
||||
"""Analyze a single document."""
|
||||
csv_row = self.csv_data.get(doc_id, {})
|
||||
label_info = self.label_data.get(doc_id, {'pages': {}})
|
||||
report = self.report_data.get(doc_id, {})
|
||||
|
||||
# Check PDF
|
||||
pdf_path = self.pdf_dir / f"{doc_id}.pdf"
|
||||
pdf_exists = pdf_path.exists()
|
||||
|
||||
# Skip documents without PDF if requested
|
||||
if skip_missing_pdf and not pdf_exists:
|
||||
return None
|
||||
|
||||
pdf_type = "unknown"
|
||||
total_pages = 0
|
||||
|
||||
if pdf_exists:
|
||||
pdf_type = "scanned" if not is_text_pdf(pdf_path) else "text"
|
||||
total_pages = len(label_info['pages']) or report.get('total_pages', 0)
|
||||
|
||||
analysis = DocumentAnalysis(
|
||||
doc_id=doc_id,
|
||||
pdf_exists=pdf_exists,
|
||||
pdf_type=pdf_type,
|
||||
total_pages=total_pages
|
||||
)
|
||||
|
||||
# Get labeled classes
|
||||
labeled_classes = set()
|
||||
for page_labels in label_info['pages'].values():
|
||||
for label in page_labels:
|
||||
labeled_classes.add(label['class_name'])
|
||||
|
||||
# Analyze each field
|
||||
for field_name in FIELD_CLASSES.keys():
|
||||
csv_value = csv_row.get(field_name, '')
|
||||
if csv_value is None:
|
||||
csv_value = ''
|
||||
csv_value = str(csv_value).strip()
|
||||
|
||||
# Handle datetime values (remove time part)
|
||||
if ' 00:00:00' in csv_value:
|
||||
csv_value = csv_value.replace(' 00:00:00', '')
|
||||
|
||||
expected = bool(csv_value)
|
||||
labeled = field_name in labeled_classes
|
||||
|
||||
field_analysis = FieldAnalysis(
|
||||
field_name=field_name,
|
||||
csv_value=csv_value,
|
||||
expected=expected,
|
||||
labeled=labeled,
|
||||
matched=False
|
||||
)
|
||||
|
||||
if expected:
|
||||
analysis.csv_fields_count += 1
|
||||
if labeled:
|
||||
analysis.labeled_fields_count += 1
|
||||
|
||||
# Diagnose failures
|
||||
if expected and not labeled:
|
||||
field_analysis.failure_reason = self._diagnose_failure(
|
||||
doc_id, field_name, csv_value, pdf_path, pdf_type, report
|
||||
)
|
||||
field_analysis.details = self._get_failure_details(
|
||||
doc_id, field_name, csv_value, pdf_path, pdf_type
|
||||
)
|
||||
elif not expected and labeled:
|
||||
field_analysis.failure_reason = "EXTRA_LABEL"
|
||||
field_analysis.details = {'note': 'Labeled but no CSV value'}
|
||||
|
||||
analysis.fields.append(field_analysis)
|
||||
|
||||
return analysis
|
||||
|
||||
def _diagnose_failure(
|
||||
self,
|
||||
doc_id: str,
|
||||
field_name: str,
|
||||
csv_value: str,
|
||||
pdf_path: Path,
|
||||
pdf_type: str,
|
||||
report: dict
|
||||
) -> str:
|
||||
"""Diagnose why a field wasn't labeled."""
|
||||
|
||||
if not pdf_path.exists():
|
||||
return "PDF_NOT_FOUND"
|
||||
|
||||
if pdf_type == "scanned":
|
||||
return "SCANNED_PDF"
|
||||
|
||||
# Try to match now with current normalizer (not historical report)
|
||||
if pdf_path.exists() and pdf_type == "text":
|
||||
try:
|
||||
# Check all pages
|
||||
for page_no in range(10): # Max 10 pages
|
||||
try:
|
||||
tokens = list(extract_text_tokens(pdf_path, page_no))
|
||||
if not tokens:
|
||||
break
|
||||
|
||||
normalized = normalize_field(field_name, csv_value)
|
||||
matches = self.matcher.find_matches(tokens, field_name, normalized, page_no)
|
||||
|
||||
if matches:
|
||||
return "MATCHER_OK_NOW" # Would match with current normalizer
|
||||
except Exception:
|
||||
break
|
||||
|
||||
return "VALUE_NOT_IN_PDF"
|
||||
|
||||
except Exception as e:
|
||||
return f"PDF_ERROR: {str(e)[:50]}"
|
||||
|
||||
return "UNKNOWN"
|
||||
|
||||
def _get_failure_details(
|
||||
self,
|
||||
doc_id: str,
|
||||
field_name: str,
|
||||
csv_value: str,
|
||||
pdf_path: Path,
|
||||
pdf_type: str
|
||||
) -> dict:
|
||||
"""Get detailed information about a failure."""
|
||||
details = {
|
||||
'csv_value': csv_value,
|
||||
'normalized_candidates': [],
|
||||
'pdf_tokens_sample': [],
|
||||
'potential_matches': []
|
||||
}
|
||||
|
||||
# Get normalized candidates
|
||||
try:
|
||||
details['normalized_candidates'] = normalize_field(field_name, csv_value)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Get PDF tokens if available
|
||||
if pdf_path.exists() and pdf_type == "text":
|
||||
try:
|
||||
tokens = list(extract_text_tokens(pdf_path, 0))[:100]
|
||||
|
||||
# Find tokens that might be related
|
||||
candidates = details['normalized_candidates']
|
||||
for token in tokens:
|
||||
text = token.text.strip()
|
||||
# Check if any candidate is substring or similar
|
||||
for cand in candidates:
|
||||
if cand in text or text in cand:
|
||||
details['potential_matches'].append({
|
||||
'token': text,
|
||||
'candidate': cand,
|
||||
'bbox': token.bbox
|
||||
})
|
||||
break
|
||||
# Also collect date-like or number-like tokens for reference
|
||||
if field_name in ('InvoiceDate', 'InvoiceDueDate'):
|
||||
if any(c.isdigit() for c in text) and len(text) >= 6:
|
||||
details['pdf_tokens_sample'].append(text)
|
||||
elif field_name == 'Amount':
|
||||
if any(c.isdigit() for c in text) and (',' in text or '.' in text or len(text) >= 4):
|
||||
details['pdf_tokens_sample'].append(text)
|
||||
|
||||
# Limit samples
|
||||
details['pdf_tokens_sample'] = details['pdf_tokens_sample'][:10]
|
||||
details['potential_matches'] = details['potential_matches'][:5]
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return details
|
||||
|
||||
def run_analysis(self, limit: Optional[int] = None, skip_missing_pdf: bool = True) -> list[DocumentAnalysis]:
|
||||
"""Run analysis on all documents."""
|
||||
self.load_csv()
|
||||
self.load_labels()
|
||||
self.load_report()
|
||||
|
||||
results = []
|
||||
doc_ids = list(self.csv_data.keys())
|
||||
skipped = 0
|
||||
|
||||
for doc_id in doc_ids:
|
||||
analysis = self.analyze_document(doc_id, skip_missing_pdf=skip_missing_pdf)
|
||||
if analysis is None:
|
||||
skipped += 1
|
||||
continue
|
||||
results.append(analysis)
|
||||
if limit and len(results) >= limit:
|
||||
break
|
||||
|
||||
if skipped > 0:
|
||||
print(f"Skipped {skipped} documents without PDF files")
|
||||
|
||||
return results
|
||||
|
||||
def generate_report(
|
||||
self,
|
||||
results: list[DocumentAnalysis],
|
||||
output_path: str,
|
||||
verbose: bool = False
|
||||
):
|
||||
"""Generate analysis report."""
|
||||
output = Path(output_path)
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Collect statistics
|
||||
stats = {
|
||||
'total_documents': len(results),
|
||||
'documents_with_issues': 0,
|
||||
'total_expected_fields': 0,
|
||||
'total_labeled_fields': 0,
|
||||
'missing_labels': 0,
|
||||
'extra_labels': 0,
|
||||
'failure_reasons': defaultdict(int),
|
||||
'failures_by_field': defaultdict(lambda: defaultdict(int))
|
||||
}
|
||||
|
||||
issues = []
|
||||
|
||||
for analysis in results:
|
||||
stats['total_expected_fields'] += analysis.csv_fields_count
|
||||
stats['total_labeled_fields'] += analysis.labeled_fields_count
|
||||
|
||||
if analysis.has_issues:
|
||||
stats['documents_with_issues'] += 1
|
||||
|
||||
for f in analysis.fields:
|
||||
if f.expected and not f.labeled:
|
||||
stats['missing_labels'] += 1
|
||||
stats['failure_reasons'][f.failure_reason] += 1
|
||||
stats['failures_by_field'][f.field_name][f.failure_reason] += 1
|
||||
|
||||
issues.append({
|
||||
'doc_id': analysis.doc_id,
|
||||
'field': f.field_name,
|
||||
'csv_value': f.csv_value,
|
||||
'reason': f.failure_reason,
|
||||
'details': f.details if verbose else {}
|
||||
})
|
||||
elif not f.expected and f.labeled:
|
||||
stats['extra_labels'] += 1
|
||||
|
||||
# Write JSON report
|
||||
report = {
|
||||
'summary': {
|
||||
'total_documents': stats['total_documents'],
|
||||
'documents_with_issues': stats['documents_with_issues'],
|
||||
                'issue_rate': f"{stats['documents_with_issues'] / max(1, stats['total_documents']) * 100:.1f}%",
|
||||
'total_expected_fields': stats['total_expected_fields'],
|
||||
'total_labeled_fields': stats['total_labeled_fields'],
|
||||
'label_coverage': f"{stats['total_labeled_fields'] / max(1, stats['total_expected_fields']) * 100:.1f}%",
|
||||
'missing_labels': stats['missing_labels'],
|
||||
'extra_labels': stats['extra_labels']
|
||||
},
|
||||
'failure_reasons': dict(stats['failure_reasons']),
|
||||
'failures_by_field': {
|
||||
field: dict(reasons)
|
||||
for field, reasons in stats['failures_by_field'].items()
|
||||
},
|
||||
'issues': issues
|
||||
}
|
||||
|
||||
with open(output, 'w', encoding='utf-8') as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nReport saved to: {output}")
|
||||
|
||||
return report
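    # The JSON written above has four top-level keys: "summary" (document and
    # field counts plus percentage strings), "failure_reasons" (reason ->
    # count), "failures_by_field" (field -> reason -> count) and "issues"
    # (one entry per missing label; per-issue "details" are populated only
    # when verbose=True).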
|
||||
|
||||
|
||||
def print_summary(report: dict):
|
||||
"""Print summary to console."""
|
||||
summary = report['summary']
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("LABEL ANALYSIS SUMMARY")
|
||||
print("=" * 60)
|
||||
|
||||
print(f"\nDocuments:")
|
||||
print(f" Total: {summary['total_documents']}")
|
||||
print(f" With issues: {summary['documents_with_issues']} ({summary['issue_rate']})")
|
||||
|
||||
print(f"\nFields:")
|
||||
print(f" Expected: {summary['total_expected_fields']}")
|
||||
print(f" Labeled: {summary['total_labeled_fields']} ({summary['label_coverage']})")
|
||||
print(f" Missing: {summary['missing_labels']}")
|
||||
print(f" Extra: {summary['extra_labels']}")
|
||||
|
||||
print(f"\nFailure Reasons:")
|
||||
for reason, count in sorted(report['failure_reasons'].items(), key=lambda x: -x[1]):
|
||||
print(f" {reason}: {count}")
|
||||
|
||||
print(f"\nFailures by Field:")
|
||||
for field, reasons in report['failures_by_field'].items():
|
||||
total = sum(reasons.values())
|
||||
print(f" {field}: {total}")
|
||||
for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
|
||||
print(f" - {reason}: {count}")
|
||||
|
||||
# Show sample issues
|
||||
if report['issues']:
|
||||
print(f"\n" + "-" * 60)
|
||||
print("SAMPLE ISSUES (first 10)")
|
||||
print("-" * 60)
|
||||
|
||||
for issue in report['issues'][:10]:
|
||||
print(f"\n[{issue['doc_id']}] {issue['field']}")
|
||||
print(f" CSV value: {issue['csv_value']}")
|
||||
print(f" Reason: {issue['reason']}")
|
||||
|
||||
if issue.get('details'):
|
||||
details = issue['details']
|
||||
if details.get('normalized_candidates'):
|
||||
print(f" Candidates: {details['normalized_candidates'][:5]}")
|
||||
if details.get('pdf_tokens_sample'):
|
||||
print(f" PDF samples: {details['pdf_tokens_sample'][:5]}")
|
||||
if details.get('potential_matches'):
|
||||
print(f" Potential matches:")
|
||||
for pm in details['potential_matches'][:3]:
|
||||
print(f" - token='{pm['token']}' matches candidate='{pm['candidate']}'")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Analyze auto-generated labels and diagnose failures'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--csv', '-c',
|
||||
default='data/structured_data/document_export_20260109_220326.csv',
|
||||
help='Path to structured data CSV file'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--pdf-dir', '-p',
|
||||
default='data/raw_pdfs',
|
||||
help='Directory containing PDF files'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dataset', '-d',
|
||||
default='data/dataset',
|
||||
help='Dataset directory with labels'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output', '-o',
|
||||
default='reports/label_analysis.json',
|
||||
help='Output path for analysis report'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--limit', '-l',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Limit number of documents to analyze'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Include detailed failure information'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--single', '-s',
|
||||
help='Analyze single document ID'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-db',
|
||||
action='store_true',
|
||||
help='Skip database, only analyze label files'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
analyzer = LabelAnalyzer(
|
||||
csv_path=args.csv,
|
||||
pdf_dir=args.pdf_dir,
|
||||
dataset_dir=args.dataset,
|
||||
use_db=not args.no_db
|
||||
)
|
||||
|
||||
if args.single:
|
||||
# Analyze single document
|
||||
analyzer.load_csv()
|
||||
analyzer.load_labels()
|
||||
analyzer.load_report()
|
||||
|
||||
        analysis = analyzer.analyze_document(args.single, skip_missing_pdf=False)
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Document: {analysis.doc_id}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"PDF exists: {analysis.pdf_exists}")
|
||||
print(f"PDF type: {analysis.pdf_type}")
|
||||
print(f"Pages: {analysis.total_pages}")
|
||||
print(f"\nFields (CSV: {analysis.csv_fields_count}, Labeled: {analysis.labeled_fields_count}):")
|
||||
|
||||
for f in analysis.fields:
|
||||
status = "✓" if f.labeled else ("✗" if f.expected else "-")
|
||||
value_str = f.csv_value[:30] if f.csv_value else "(empty)"
|
||||
print(f" [{status}] {f.field_name}: {value_str}")
|
||||
|
||||
if f.failure_reason:
|
||||
print(f" Reason: {f.failure_reason}")
|
||||
if f.details.get('normalized_candidates'):
|
||||
print(f" Candidates: {f.details['normalized_candidates']}")
|
||||
if f.details.get('potential_matches'):
|
||||
print(f" Potential matches in PDF:")
|
||||
for pm in f.details['potential_matches'][:3]:
|
||||
print(f" - '{pm['token']}'")
|
||||
else:
|
||||
# Full analysis
|
||||
print("Running label analysis...")
|
||||
results = analyzer.run_analysis(limit=args.limit)
|
||||
report = analyzer.generate_report(results, args.output, verbose=args.verbose)
|
||||
print_summary(report)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
435
src/cli/analyze_report.py
Normal file
435
src/cli/analyze_report.py
Normal file
@@ -0,0 +1,435 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Analyze Auto-Label Report
|
||||
|
||||
Generates statistics and insights from database or autolabel_report.jsonl
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from config import get_db_connection_string
|
||||
|
||||
|
||||
def load_reports_from_db() -> dict:
|
||||
"""Load statistics directly from database using SQL aggregation."""
|
||||
from ..data.db import DocumentDB
|
||||
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
|
||||
stats = {
|
||||
'total': 0,
|
||||
'successful': 0,
|
||||
'failed': 0,
|
||||
'by_pdf_type': defaultdict(lambda: {'total': 0, 'successful': 0}),
|
||||
'by_field': defaultdict(lambda: {
|
||||
'total': 0,
|
||||
'matched': 0,
|
||||
'exact_match': 0,
|
||||
'flexible_match': 0,
|
||||
'scores': [],
|
||||
'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
|
||||
}),
|
||||
'errors': defaultdict(int),
|
||||
'processing_times': [],
|
||||
}
|
||||
|
||||
conn = db.connect()
|
||||
with conn.cursor() as cursor:
|
||||
# Overall stats
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN success THEN 1 ELSE 0 END) as successful,
|
||||
SUM(CASE WHEN NOT success THEN 1 ELSE 0 END) as failed
|
||||
FROM documents
|
||||
""")
|
||||
row = cursor.fetchone()
|
||||
stats['total'] = row[0] or 0
|
||||
stats['successful'] = row[1] or 0
|
||||
stats['failed'] = row[2] or 0
|
||||
|
||||
# By PDF type
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
pdf_type,
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN success THEN 1 ELSE 0 END) as successful
|
||||
FROM documents
|
||||
GROUP BY pdf_type
|
||||
""")
|
||||
for row in cursor.fetchall():
|
||||
pdf_type = row[0] or 'unknown'
|
||||
stats['by_pdf_type'][pdf_type] = {
|
||||
'total': row[1] or 0,
|
||||
'successful': row[2] or 0
|
||||
}
|
||||
|
||||
# Processing times
|
||||
cursor.execute("""
|
||||
SELECT AVG(processing_time_ms), MIN(processing_time_ms), MAX(processing_time_ms)
|
||||
FROM documents
|
||||
WHERE processing_time_ms > 0
|
||||
""")
|
||||
row = cursor.fetchone()
|
||||
if row[0]:
|
||||
stats['processing_time_stats'] = {
|
||||
'avg_ms': float(row[0]),
|
||||
'min_ms': float(row[1]),
|
||||
'max_ms': float(row[2])
|
||||
}
|
||||
|
||||
# Field stats
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
field_name,
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN matched THEN 1 ELSE 0 END) as matched,
|
||||
SUM(CASE WHEN matched AND score >= 0.99 THEN 1 ELSE 0 END) as exact_match,
|
||||
SUM(CASE WHEN matched AND score < 0.99 THEN 1 ELSE 0 END) as flexible_match,
|
||||
AVG(CASE WHEN matched THEN score END) as avg_score
|
||||
FROM field_results
|
||||
GROUP BY field_name
|
||||
ORDER BY field_name
|
||||
""")
|
||||
for row in cursor.fetchall():
|
||||
field_name = row[0]
|
||||
stats['by_field'][field_name] = {
|
||||
'total': row[1] or 0,
|
||||
'matched': row[2] or 0,
|
||||
'exact_match': row[3] or 0,
|
||||
'flexible_match': row[4] or 0,
|
||||
'avg_score': float(row[5]) if row[5] else 0,
|
||||
'scores': [], # Not loading individual scores for efficiency
|
||||
'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
|
||||
}
|
||||
|
||||
# Field stats by PDF type
|
||||
cursor.execute("""
|
||||
SELECT
|
||||
fr.field_name,
|
||||
d.pdf_type,
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN fr.matched THEN 1 ELSE 0 END) as matched
|
||||
FROM field_results fr
|
||||
JOIN documents d ON fr.document_id = d.document_id
|
||||
GROUP BY fr.field_name, d.pdf_type
|
||||
""")
|
||||
for row in cursor.fetchall():
|
||||
field_name = row[0]
|
||||
pdf_type = row[1] or 'unknown'
|
||||
if field_name in stats['by_field']:
|
||||
stats['by_field'][field_name]['by_pdf_type'][pdf_type] = {
|
||||
'total': row[2] or 0,
|
||||
'matched': row[3] or 0
|
||||
}
|
||||
|
||||
db.close()
|
||||
return stats
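# The SQL above assumes the `documents` and `field_results` tables created by
# DocumentDB / src.cli.import_report_to_db (see the schema later in this
# commit); per-match scores are aggregated in the database, so individual
# score lists are never loaded into memory.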
|
||||
|
||||
|
||||
def load_reports_from_file(report_path: str) -> list[dict]:
|
||||
"""Load all reports from JSONL file(s). Supports glob patterns."""
|
||||
path = Path(report_path)
|
||||
|
||||
# Handle glob pattern
|
||||
if '*' in str(path) or '?' in str(path):
|
||||
parent = path.parent
|
||||
pattern = path.name
|
||||
report_files = sorted(parent.glob(pattern))
|
||||
else:
|
||||
report_files = [path]
|
||||
|
||||
if not report_files:
|
||||
return []
|
||||
|
||||
print(f"Reading {len(report_files)} report file(s):")
|
||||
for f in report_files:
|
||||
print(f" - {f.name}")
|
||||
|
||||
reports = []
|
||||
for report_file in report_files:
|
||||
if not report_file.exists():
|
||||
continue
|
||||
with open(report_file, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
reports.append(json.loads(line))
|
||||
|
||||
return reports
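# Each JSONL line is one document report. The keys consumed by
# analyze_reports() below look roughly like this (values illustrative):
#   {"document_id": "...", "pdf_type": "text", "success": true,
#    "processing_time_ms": 412.7, "errors": [],
#    "field_results": [{"field_name": "Amount", "matched": true, "score": 1.0}]}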
|
||||
|
||||
|
||||
def analyze_reports(reports: list[dict]) -> dict:
|
||||
"""Analyze reports and generate statistics."""
|
||||
stats = {
|
||||
'total': len(reports),
|
||||
'successful': 0,
|
||||
'failed': 0,
|
||||
'by_pdf_type': defaultdict(lambda: {'total': 0, 'successful': 0}),
|
||||
'by_field': defaultdict(lambda: {
|
||||
'total': 0,
|
||||
'matched': 0,
|
||||
            'exact_match': 0,  # score >= 0.99
|
||||
            'flexible_match': 0,  # score < 0.99
|
||||
'scores': [],
|
||||
'by_pdf_type': defaultdict(lambda: {'total': 0, 'matched': 0})
|
||||
}),
|
||||
'errors': defaultdict(int),
|
||||
'processing_times': [],
|
||||
}
|
||||
|
||||
for report in reports:
|
||||
pdf_type = report.get('pdf_type') or 'unknown'
|
||||
success = report.get('success', False)
|
||||
|
||||
# Overall stats
|
||||
if success:
|
||||
stats['successful'] += 1
|
||||
else:
|
||||
stats['failed'] += 1
|
||||
|
||||
# By PDF type
|
||||
stats['by_pdf_type'][pdf_type]['total'] += 1
|
||||
if success:
|
||||
stats['by_pdf_type'][pdf_type]['successful'] += 1
|
||||
|
||||
# Processing time
|
||||
proc_time = report.get('processing_time_ms', 0)
|
||||
if proc_time > 0:
|
||||
stats['processing_times'].append(proc_time)
|
||||
|
||||
# Errors
|
||||
for error in report.get('errors', []):
|
||||
stats['errors'][error] += 1
|
||||
|
||||
# Field results
|
||||
for field_result in report.get('field_results', []):
|
||||
field_name = field_result['field_name']
|
||||
matched = field_result.get('matched', False)
|
||||
score = field_result.get('score', 0.0)
|
||||
|
||||
stats['by_field'][field_name]['total'] += 1
|
||||
stats['by_field'][field_name]['by_pdf_type'][pdf_type]['total'] += 1
|
||||
|
||||
if matched:
|
||||
stats['by_field'][field_name]['matched'] += 1
|
||||
stats['by_field'][field_name]['scores'].append(score)
|
||||
stats['by_field'][field_name]['by_pdf_type'][pdf_type]['matched'] += 1
|
||||
|
||||
if score >= 0.99:
|
||||
stats['by_field'][field_name]['exact_match'] += 1
|
||||
else:
|
||||
stats['by_field'][field_name]['flexible_match'] += 1
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def print_report(stats: dict, verbose: bool = False):
|
||||
"""Print analysis report."""
|
||||
print("\n" + "=" * 60)
|
||||
print("AUTO-LABEL REPORT ANALYSIS")
|
||||
print("=" * 60)
|
||||
|
||||
# Overall stats
|
||||
print(f"\n{'OVERALL STATISTICS':^60}")
|
||||
print("-" * 60)
|
||||
total = stats['total']
|
||||
successful = stats['successful']
|
||||
failed = stats['failed']
|
||||
success_rate = successful / total * 100 if total > 0 else 0
|
||||
|
||||
print(f"Total documents: {total:>8}")
|
||||
print(f"Successful: {successful:>8} ({success_rate:.1f}%)")
|
||||
print(f"Failed: {failed:>8} ({100-success_rate:.1f}%)")
|
||||
|
||||
# Processing time
|
||||
if 'processing_time_stats' in stats:
|
||||
pts = stats['processing_time_stats']
|
||||
print(f"\nProcessing time (ms):")
|
||||
print(f" Average: {pts['avg_ms']:>8.1f}")
|
||||
print(f" Min: {pts['min_ms']:>8.1f}")
|
||||
print(f" Max: {pts['max_ms']:>8.1f}")
|
||||
elif stats.get('processing_times'):
|
||||
times = stats['processing_times']
|
||||
avg_time = sum(times) / len(times)
|
||||
min_time = min(times)
|
||||
max_time = max(times)
|
||||
print(f"\nProcessing time (ms):")
|
||||
print(f" Average: {avg_time:>8.1f}")
|
||||
print(f" Min: {min_time:>8.1f}")
|
||||
print(f" Max: {max_time:>8.1f}")
|
||||
|
||||
# By PDF type
|
||||
print(f"\n{'BY PDF TYPE':^60}")
|
||||
print("-" * 60)
|
||||
print(f"{'Type':<15} {'Total':>10} {'Success':>10} {'Rate':>10}")
|
||||
print("-" * 60)
|
||||
for pdf_type, type_stats in sorted(stats['by_pdf_type'].items()):
|
||||
type_total = type_stats['total']
|
||||
type_success = type_stats['successful']
|
||||
type_rate = type_success / type_total * 100 if type_total > 0 else 0
|
||||
print(f"{pdf_type:<15} {type_total:>10} {type_success:>10} {type_rate:>9.1f}%")
|
||||
|
||||
# By field
|
||||
print(f"\n{'FIELD MATCH STATISTICS':^60}")
|
||||
print("-" * 60)
|
||||
print(f"{'Field':<18} {'Total':>7} {'Match':>7} {'Rate':>7} {'Exact':>7} {'Flex':>7} {'AvgScore':>8}")
|
||||
print("-" * 60)
|
||||
|
||||
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
|
||||
if field_name not in stats['by_field']:
|
||||
continue
|
||||
field_stats = stats['by_field'][field_name]
|
||||
total = field_stats['total']
|
||||
matched = field_stats['matched']
|
||||
exact = field_stats['exact_match']
|
||||
flex = field_stats['flexible_match']
|
||||
rate = matched / total * 100 if total > 0 else 0
|
||||
|
||||
# Handle avg_score from either DB or file analysis
|
||||
if 'avg_score' in field_stats:
|
||||
avg_score = field_stats['avg_score']
|
||||
elif field_stats['scores']:
|
||||
avg_score = sum(field_stats['scores']) / len(field_stats['scores'])
|
||||
else:
|
||||
avg_score = 0
|
||||
|
||||
print(f"{field_name:<18} {total:>7} {matched:>7} {rate:>6.1f}% {exact:>7} {flex:>7} {avg_score:>8.3f}")
|
||||
|
||||
# Field match by PDF type
|
||||
print(f"\n{'FIELD MATCH BY PDF TYPE':^60}")
|
||||
print("-" * 60)
|
||||
|
||||
for pdf_type in sorted(stats['by_pdf_type'].keys()):
|
||||
print(f"\n[{pdf_type.upper()}]")
|
||||
print(f"{'Field':<18} {'Total':>10} {'Matched':>10} {'Rate':>10}")
|
||||
print("-" * 50)
|
||||
|
||||
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
|
||||
if field_name not in stats['by_field']:
|
||||
continue
|
||||
type_stats = stats['by_field'][field_name]['by_pdf_type'].get(pdf_type, {'total': 0, 'matched': 0})
|
||||
total = type_stats['total']
|
||||
matched = type_stats['matched']
|
||||
rate = matched / total * 100 if total > 0 else 0
|
||||
print(f"{field_name:<18} {total:>10} {matched:>10} {rate:>9.1f}%")
|
||||
|
||||
# Errors
|
||||
if stats.get('errors') and verbose:
|
||||
print(f"\n{'ERRORS':^60}")
|
||||
print("-" * 60)
|
||||
for error, count in sorted(stats['errors'].items(), key=lambda x: -x[1])[:20]:
|
||||
print(f"{count:>5}x {error[:50]}")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
|
||||
|
||||
def export_json(stats: dict, output_path: str):
|
||||
"""Export statistics to JSON file."""
|
||||
# Convert defaultdicts to regular dicts for JSON serialization
|
||||
export_data = {
|
||||
'total': stats['total'],
|
||||
'successful': stats['successful'],
|
||||
'failed': stats['failed'],
|
||||
'by_pdf_type': dict(stats['by_pdf_type']),
|
||||
'by_field': {},
|
||||
'errors': dict(stats.get('errors', {})),
|
||||
}
|
||||
|
||||
# Processing time stats
|
||||
if 'processing_time_stats' in stats:
|
||||
export_data['processing_time_stats'] = stats['processing_time_stats']
|
||||
elif stats.get('processing_times'):
|
||||
times = stats['processing_times']
|
||||
export_data['processing_time_stats'] = {
|
||||
'avg_ms': sum(times) / len(times),
|
||||
'min_ms': min(times),
|
||||
'max_ms': max(times),
|
||||
'count': len(times)
|
||||
}
|
||||
|
||||
# Field stats
|
||||
for field_name, field_stats in stats['by_field'].items():
|
||||
avg_score = field_stats.get('avg_score', 0)
|
||||
if not avg_score and field_stats.get('scores'):
|
||||
avg_score = sum(field_stats['scores']) / len(field_stats['scores'])
|
||||
|
||||
export_data['by_field'][field_name] = {
|
||||
'total': field_stats['total'],
|
||||
'matched': field_stats['matched'],
|
||||
'exact_match': field_stats['exact_match'],
|
||||
'flexible_match': field_stats['flexible_match'],
|
||||
'match_rate': field_stats['matched'] / field_stats['total'] if field_stats['total'] > 0 else 0,
|
||||
'avg_score': avg_score,
|
||||
'by_pdf_type': dict(field_stats['by_pdf_type'])
|
||||
}
|
||||
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(export_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nStatistics exported to: {output_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Analyze auto-label report'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--report', '-r',
|
||||
default=None,
|
||||
help='Path to autolabel report JSONL file (uses database if not specified)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output', '-o',
|
||||
help='Export statistics to JSON file'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Show detailed error messages'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--from-file',
|
||||
action='store_true',
|
||||
help='Force reading from JSONL file instead of database'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Decide source
|
||||
use_db = not args.from_file and args.report is None
|
||||
|
||||
if use_db:
|
||||
print("Loading statistics from database...")
|
||||
stats = load_reports_from_db()
|
||||
print(f"Loaded stats for {stats['total']} documents")
|
||||
else:
|
||||
report_path = args.report or 'reports/autolabel_report.jsonl'
|
||||
path = Path(report_path)
|
||||
|
||||
# Check if file exists (handle glob patterns)
|
||||
if '*' not in str(path) and '?' not in str(path) and not path.exists():
|
||||
print(f"Error: Report file not found: {path}")
|
||||
return 1
|
||||
|
||||
print(f"Loading reports from: {report_path}")
|
||||
reports = load_reports_from_file(report_path)
|
||||
print(f"Loaded {len(reports)} reports")
|
||||
stats = analyze_reports(reports)
|
||||
|
||||
print_report(stats, verbose=args.verbose)
|
||||
|
||||
if args.output:
|
||||
export_json(stats, args.output)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
    sys.exit(main())
|
||||
@@ -8,31 +8,83 @@ Generates YOLO training data from PDFs and structured CSV data.
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
import multiprocessing
|
||||
|
||||
# Windows compatibility: use 'spawn' method for multiprocessing
|
||||
# This is required on Windows and is also safer for libraries like PaddleOCR
|
||||
if sys.platform == 'win32':
|
||||
multiprocessing.set_start_method('spawn', force=True)
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from config import get_db_connection_string, PATHS, AUTOLABEL
|
||||
|
||||
# Global OCR engine for worker processes (initialized once per worker)
|
||||
_worker_ocr_engine = None
|
||||
_worker_initialized = False
|
||||
_worker_type = None # 'cpu' or 'gpu'
|
||||
|
||||
|
||||
def _init_cpu_worker():
|
||||
"""Initialize CPU worker (no OCR engine needed)."""
|
||||
global _worker_initialized, _worker_type
|
||||
_worker_initialized = True
|
||||
_worker_type = 'cpu'
|
||||
|
||||
|
||||
def _init_gpu_worker():
|
||||
"""Initialize GPU worker with OCR engine (called once per worker)."""
|
||||
global _worker_ocr_engine, _worker_initialized, _worker_type
|
||||
|
||||
# Suppress PaddlePaddle/PaddleX reinitialization warnings
|
||||
warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
|
||||
warnings.filterwarnings('ignore', message='.*reinitialization.*')
|
||||
|
||||
# Set environment variable to suppress paddle warnings
|
||||
os.environ['GLOG_minloglevel'] = '2' # Suppress INFO and WARNING logs
|
||||
|
||||
# OCR engine will be lazily initialized on first use
|
||||
_worker_ocr_engine = None
|
||||
_worker_initialized = True
|
||||
_worker_type = 'gpu'
|
||||
|
||||
|
||||
def _init_worker():
|
||||
"""Initialize worker process with OCR engine (called once per worker)."""
|
||||
global _worker_ocr_engine
|
||||
# OCR engine will be lazily initialized on first use
|
||||
_worker_ocr_engine = None
|
||||
"""Initialize worker process with OCR engine (called once per worker).
|
||||
Legacy function for backwards compatibility.
|
||||
"""
|
||||
_init_gpu_worker()
|
||||
|
||||
|
||||
def _get_ocr_engine():
|
||||
"""Get or create OCR engine for current worker."""
|
||||
global _worker_ocr_engine
|
||||
|
||||
if _worker_ocr_engine is None:
|
||||
# Suppress warnings during OCR initialization
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore')
|
||||
from ..ocr import OCREngine
|
||||
_worker_ocr_engine = OCREngine()
|
||||
|
||||
return _worker_ocr_engine
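# Each worker process keeps a single OCREngine instance in the module-global
# _worker_ocr_engine, so the OCR model is loaded once per process (on first
# use) rather than once per document.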
|
||||
|
||||
|
||||
def _save_output_img(output_img, image_path: Path) -> None:
|
||||
"""Save OCR output_img to replace the original rendered image."""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
# Convert numpy array to PIL Image and save
|
||||
if output_img is not None:
|
||||
img = PILImage.fromarray(output_img)
|
||||
img.save(str(image_path))
|
||||
# If output_img is None, the original image is already saved
|
||||
|
||||
|
||||
def process_single_document(args_tuple):
|
||||
"""
|
||||
Process a single document (worker function for parallel processing).
|
||||
@@ -47,8 +99,7 @@ def process_single_document(args_tuple):
|
||||
|
||||
# Import inside worker to avoid pickling issues
|
||||
from ..data import AutoLabelReport, FieldMatchResult
|
||||
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
|
||||
from ..pdf.renderer import get_render_dimensions
|
||||
from ..pdf import PDFDocument
|
||||
from ..matcher import FieldMatcher
|
||||
from ..normalize import normalize_field
|
||||
from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
|
||||
@@ -70,8 +121,11 @@ def process_single_document(args_tuple):
|
||||
}
|
||||
|
||||
try:
|
||||
# Check PDF type
|
||||
use_ocr = not is_text_pdf(pdf_path)
|
||||
# Use PDFDocument context manager for efficient PDF handling
|
||||
# Opens PDF only once, caches dimensions, handles cleanup automatically
|
||||
with PDFDocument(pdf_path) as pdf_doc:
|
||||
# Check PDF type (uses cached document)
|
||||
use_ocr = not pdf_doc.is_text_pdf()
|
||||
report.pdf_type = "scanned" if use_ocr else "text"
|
||||
|
||||
# Skip OCR if requested
|
||||
@@ -91,20 +145,37 @@ def process_single_document(args_tuple):
|
||||
|
||||
# Process each page
|
||||
page_annotations = []
|
||||
matched_fields = set()
|
||||
|
||||
for page_no, image_path in render_pdf_to_images(
|
||||
pdf_path,
|
||||
output_dir / 'temp' / doc_id / 'images',
|
||||
dpi=dpi
|
||||
):
|
||||
# Render all pages and process (uses cached document handle)
|
||||
images_dir = output_dir / 'temp' / doc_id / 'images'
|
||||
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
|
||||
report.total_pages += 1
|
||||
img_width, img_height = get_render_dimensions(pdf_path, page_no, dpi)
|
||||
|
||||
# Get dimensions from cache (no additional PDF open)
|
||||
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
|
||||
|
||||
# Extract tokens
|
||||
if use_ocr:
|
||||
tokens = ocr_engine.extract_from_image(str(image_path), page_no)
|
||||
# Use extract_with_image to get both tokens and preprocessed image
|
||||
# PaddleOCR coordinates are relative to output_img, not original image
|
||||
ocr_result = ocr_engine.extract_with_image(
|
||||
str(image_path),
|
||||
page_no,
|
||||
scale_to_pdf_points=72 / dpi
|
||||
)
|
||||
tokens = ocr_result.tokens
|
||||
|
||||
# Save output_img to replace the original rendered image
|
||||
# This ensures coordinates match the saved image
|
||||
_save_output_img(ocr_result.output_img, image_path)
|
||||
|
||||
# Update image dimensions to match output_img
|
||||
if ocr_result.output_img is not None:
|
||||
img_height, img_width = ocr_result.output_img.shape[:2]
|
||||
else:
|
||||
tokens = list(extract_text_tokens(pdf_path, page_no))
|
||||
# Use cached document for text extraction
|
||||
tokens = list(pdf_doc.extract_text_tokens(page_no))
|
||||
|
||||
# Match fields
|
||||
matches = {}
|
||||
@@ -120,6 +191,7 @@ def process_single_document(args_tuple):
|
||||
if field_matches:
|
||||
best = field_matches[0]
|
||||
matches[field_name] = field_matches
|
||||
matched_fields.add(field_name)
|
||||
report.add_field_result(FieldMatchResult(
|
||||
field_name=field_name,
|
||||
csv_value=str(value),
|
||||
@@ -131,23 +203,14 @@ def process_single_document(args_tuple):
|
||||
page_no=page_no,
|
||||
context_keywords=best.context_keywords
|
||||
))
|
||||
else:
|
||||
report.add_field_result(FieldMatchResult(
|
||||
field_name=field_name,
|
||||
csv_value=str(value),
|
||||
matched=False,
|
||||
page_no=page_no
|
||||
))
|
||||
|
||||
# Generate annotations
|
||||
# Count annotations
|
||||
annotations = generator.generate_from_matches(matches, img_width, img_height, dpi=dpi)
|
||||
|
||||
if annotations:
|
||||
label_path = output_dir / 'temp' / doc_id / 'labels' / f"{image_path.stem}.txt"
|
||||
generator.save_annotations(annotations, label_path)
|
||||
page_annotations.append({
|
||||
'image_path': str(image_path),
|
||||
'label_path': str(label_path),
|
||||
'page_no': page_no,
|
||||
'count': len(annotations)
|
||||
})
|
||||
|
||||
@@ -156,6 +219,17 @@ def process_single_document(args_tuple):
|
||||
class_name = list(FIELD_CLASSES.keys())[ann.class_id]
|
||||
result['stats'][class_name] += 1
|
||||
|
||||
# Record unmatched fields
|
||||
for field_name in FIELD_CLASSES.keys():
|
||||
value = row_dict.get(field_name)
|
||||
if value and field_name not in matched_fields:
|
||||
report.add_field_result(FieldMatchResult(
|
||||
field_name=field_name,
|
||||
csv_value=str(value),
|
||||
matched=False,
|
||||
page_no=-1
|
||||
))
|
||||
|
||||
if page_annotations:
|
||||
result['pages'] = page_annotations
|
||||
result['success'] = True
|
||||
@@ -178,47 +252,41 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
'--csv', '-c',
|
||||
default='data/structured_data/document_export_20260109_212743.csv',
|
||||
help='Path to structured data CSV file'
|
||||
default=f"{PATHS['csv_dir']}/*.csv",
|
||||
help='Path to CSV file(s). Supports: single file, glob pattern (*.csv), or comma-separated list'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--pdf-dir', '-p',
|
||||
default='data/raw_pdfs',
|
||||
default=PATHS['pdf_dir'],
|
||||
help='Directory containing PDF files'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output', '-o',
|
||||
default='data/dataset',
|
||||
default=PATHS['output_dir'],
|
||||
help='Output directory for dataset'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dpi',
|
||||
type=int,
|
||||
default=300,
|
||||
help='DPI for PDF rendering (default: 300)'
|
||||
default=AUTOLABEL['dpi'],
|
||||
help=f"DPI for PDF rendering (default: {AUTOLABEL['dpi']})"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--min-confidence',
|
||||
type=float,
|
||||
default=0.7,
|
||||
help='Minimum match confidence (default: 0.7)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--train-ratio',
|
||||
type=float,
|
||||
default=0.8,
|
||||
help='Training set ratio (default: 0.8)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--val-ratio',
|
||||
type=float,
|
||||
default=0.1,
|
||||
help='Validation set ratio (default: 0.1)'
|
||||
default=AUTOLABEL['min_confidence'],
|
||||
help=f"Minimum match confidence (default: {AUTOLABEL['min_confidence']})"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--report',
|
||||
default='reports/autolabel_report.jsonl',
|
||||
help='Path for auto-label report (JSONL)'
|
||||
default=f"{PATHS['reports_dir']}/autolabel_report.jsonl",
|
||||
help='Path for auto-label report (JSONL). With --max-records, creates report_part000.jsonl, etc.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max-records',
|
||||
type=int,
|
||||
default=10000,
|
||||
help='Max records per report file for sharding (default: 10000, 0 = single file)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--single',
|
||||
@@ -233,20 +301,37 @@ def main():
|
||||
'--workers', '-w',
|
||||
type=int,
|
||||
default=4,
|
||||
help='Number of parallel workers (default: 4)'
|
||||
help='Number of parallel workers (default: 4). Use --cpu-workers and --gpu-workers for dual-pool mode.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--cpu-workers',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Number of CPU workers for text PDFs (enables dual-pool mode)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--gpu-workers',
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of GPU workers for scanned PDFs (default: 1, used with --cpu-workers)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--skip-ocr',
|
||||
action='store_true',
|
||||
help='Skip scanned PDFs (text-layer only)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--limit', '-l',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Limit number of documents to process (for testing)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Import here to avoid slow startup
|
||||
from ..data import CSVLoader, AutoLabelReport, FieldMatchResult
|
||||
from ..data.autolabel_report import ReportWriter
|
||||
from ..yolo import DatasetBuilder
|
||||
from ..pdf import is_text_pdf, render_pdf_to_images, extract_text_tokens
|
||||
from ..pdf.renderer import get_render_dimensions
|
||||
from ..ocr import OCREngine
|
||||
@@ -254,66 +339,206 @@ def main():
|
||||
from ..normalize import normalize_field
|
||||
from ..yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
|
||||
|
||||
print(f"Loading CSV data from: {args.csv}")
|
||||
loader = CSVLoader(args.csv, args.pdf_dir)
|
||||
# Handle comma-separated CSV paths
|
||||
csv_input = args.csv
|
||||
if ',' in csv_input and '*' not in csv_input:
|
||||
csv_input = [p.strip() for p in csv_input.split(',')]
|
||||
|
||||
# Validate data
|
||||
issues = loader.validate()
|
||||
if issues:
|
||||
print(f"Warning: Found {len(issues)} validation issues")
|
||||
# Get list of CSV files (don't load all data at once)
|
||||
temp_loader = CSVLoader(csv_input, args.pdf_dir)
|
||||
csv_files = temp_loader.csv_paths
|
||||
pdf_dir = temp_loader.pdf_dir
|
||||
print(f"Found {len(csv_files)} CSV file(s) to process")
|
||||
|
||||
# Setup output directories
|
||||
output_dir = Path(args.output)
|
||||
# Only create temp directory for images (no train/val/test split during labeling)
|
||||
(output_dir / 'temp').mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Report writer with optional sharding
|
||||
report_path = Path(args.report)
|
||||
report_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
report_writer = ReportWriter(args.report, max_records_per_file=args.max_records)
|
||||
|
||||
# Database connection for checking existing documents
|
||||
from ..data.db import DocumentDB
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
print("Connected to database for status checking")
|
||||
|
||||
# Global stats
|
||||
stats = {
|
||||
'total': 0,
|
||||
'successful': 0,
|
||||
'failed': 0,
|
||||
'skipped': 0,
|
||||
'skipped_db': 0, # Skipped because already in DB
|
||||
'retried': 0, # Re-processed failed ones
|
||||
'annotations': 0,
|
||||
'tasks_submitted': 0, # Tracks tasks submitted across all CSVs for limit
|
||||
'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
|
||||
}
|
||||
|
||||
# Track all processed items for final split (write to temp file to save memory)
|
||||
processed_items_file = output_dir / 'temp' / 'processed_items.jsonl'
|
||||
processed_items_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
processed_items_writer = open(processed_items_file, 'w', encoding='utf-8')
|
||||
processed_count = 0
|
||||
seen_doc_ids = set()
|
||||
|
||||
# Batch for database updates
|
||||
db_batch = []
|
||||
DB_BATCH_SIZE = 100
|
||||
|
||||
# Helper function to handle result and update database
|
||||
# Defined outside the loop so nonlocal can properly reference db_batch
|
||||
def handle_result(result):
|
||||
nonlocal processed_count, db_batch
|
||||
|
||||
# Write report to file
|
||||
if result['report']:
|
||||
report_writer.write_dict(result['report'])
|
||||
# Add to database batch
|
||||
db_batch.append(result['report'])
|
||||
if len(db_batch) >= DB_BATCH_SIZE:
|
||||
db.save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
|
||||
if result['success']:
|
||||
# Write to temp file instead of memory
|
||||
import json
|
||||
processed_items_writer.write(json.dumps({
|
||||
'doc_id': result['doc_id'],
|
||||
'pages': result['pages']
|
||||
}) + '\n')
|
||||
processed_items_writer.flush()
|
||||
processed_count += 1
|
||||
stats['successful'] += 1
|
||||
for field, count in result['stats'].items():
|
||||
stats['by_field'][field] += count
|
||||
stats['annotations'] += count
|
||||
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
|
||||
stats['skipped'] += 1
|
||||
else:
|
||||
stats['failed'] += 1
|
||||
|
||||
def handle_error(doc_id, error):
|
||||
nonlocal db_batch
|
||||
stats['failed'] += 1
|
||||
error_report = {
|
||||
'document_id': doc_id,
|
||||
'success': False,
|
||||
'errors': [f"Worker error: {str(error)}"]
|
||||
}
|
||||
report_writer.write_dict(error_report)
|
||||
db_batch.append(error_report)
|
||||
if len(db_batch) >= DB_BATCH_SIZE:
|
||||
db.save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
if args.verbose:
|
||||
for issue in issues[:10]:
|
||||
print(f" - {issue}")
|
||||
print(f"Error processing {doc_id}: {error}")
|
||||
|
||||
rows = loader.load_all()
|
||||
print(f"Loaded {len(rows)} invoice records")
|
||||
# Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
|
||||
dual_pool_coordinator = None
|
||||
use_dual_pool = args.cpu_workers is not None
|
||||
|
||||
if use_dual_pool:
|
||||
from src.processing import DualPoolCoordinator
|
||||
from src.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
|
||||
|
||||
print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
|
||||
dual_pool_coordinator = DualPoolCoordinator(
|
||||
cpu_workers=args.cpu_workers,
|
||||
gpu_workers=args.gpu_workers,
|
||||
gpu_id=0,
|
||||
task_timeout=300.0,
|
||||
)
|
||||
dual_pool_coordinator.start()
|
||||
|
||||
try:
|
||||
# Process CSV files one by one (streaming)
|
||||
for csv_idx, csv_file in enumerate(csv_files):
|
||||
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
|
||||
|
||||
# Load only this CSV file
|
||||
single_loader = CSVLoader(str(csv_file), str(pdf_dir))
|
||||
rows = single_loader.load_all()
|
||||
|
||||
# Filter to single document if specified
|
||||
if args.single:
|
||||
rows = [r for r in rows if r.DocumentId == args.single]
|
||||
if not rows:
|
||||
print(f"Error: Document {args.single} not found")
|
||||
sys.exit(1)
|
||||
print(f"Processing single document: {args.single}")
|
||||
|
||||
# Setup output directories
|
||||
output_dir = Path(args.output)
|
||||
for split in ['train', 'val', 'test']:
|
||||
(output_dir / split / 'images').mkdir(parents=True, exist_ok=True)
|
||||
(output_dir / split / 'labels').mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate YOLO config files
|
||||
AnnotationGenerator.generate_classes_file(output_dir / 'classes.txt')
|
||||
AnnotationGenerator.generate_yaml_config(output_dir / 'dataset.yaml')
|
||||
|
||||
# Report writer
|
||||
report_path = Path(args.report)
|
||||
report_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
report_writer = ReportWriter(args.report)
|
||||
|
||||
# Stats
|
||||
stats = {
|
||||
'total': len(rows),
|
||||
'successful': 0,
|
||||
'failed': 0,
|
||||
'skipped': 0,
|
||||
'annotations': 0,
|
||||
'by_field': {name: 0 for name in FIELD_CLASSES.keys()}
|
||||
}
|
||||
|
||||
# Prepare tasks
|
||||
tasks = []
|
||||
for row in rows:
|
||||
pdf_path = loader.get_pdf_path(row)
|
||||
if not pdf_path:
|
||||
# Write report for missing PDF
|
||||
report = AutoLabelReport(document_id=row.DocumentId)
|
||||
report.errors.append("PDF not found")
|
||||
report_writer.write(report)
|
||||
stats['failed'] += 1
|
||||
continue
|
||||
|
||||
# Convert row to dict for pickling
|
||||
# Deduplicate across CSV files
|
||||
rows = [r for r in rows if r.DocumentId not in seen_doc_ids]
|
||||
for r in rows:
|
||||
seen_doc_ids.add(r.DocumentId)
|
||||
|
||||
if not rows:
|
||||
print(f" Skipping CSV (no new documents)")
|
||||
continue
|
||||
|
||||
# Batch query database for all document IDs in this CSV
|
||||
csv_doc_ids = [r.DocumentId for r in rows]
|
||||
db_status_map = db.check_documents_status_batch(csv_doc_ids)
|
||||
|
||||
# Count how many are already processed successfully
|
||||
already_processed = sum(1 for doc_id in csv_doc_ids if db_status_map.get(doc_id) is True)
|
||||
|
||||
# Skip entire CSV if all documents are already processed
|
||||
if already_processed == len(rows):
|
||||
print(f" Skipping CSV (all {len(rows)} documents already processed)")
|
||||
stats['skipped_db'] += len(rows)
|
||||
continue
|
||||
|
||||
# Count how many new documents need processing in this CSV
|
||||
new_to_process = len(rows) - already_processed
|
||||
print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")
|
||||
|
||||
stats['total'] += len(rows)
|
||||
|
||||
# Prepare tasks for this CSV
|
||||
tasks = []
|
||||
skipped_in_csv = 0
|
||||
retry_in_csv = 0
|
||||
|
||||
# Calculate how many more we can process if limit is set
|
||||
# Use tasks_submitted counter which tracks across all CSVs
|
||||
if args.limit:
|
||||
remaining_limit = args.limit - stats.get('tasks_submitted', 0)
|
||||
if remaining_limit <= 0:
|
||||
print(f" Reached limit of {args.limit} new documents, stopping.")
|
||||
break
|
||||
else:
|
||||
remaining_limit = float('inf')
|
||||
|
||||
for row in rows:
|
||||
# Stop adding tasks if we've reached the limit
|
||||
if len(tasks) >= remaining_limit:
|
||||
break
|
||||
|
||||
doc_id = row.DocumentId
|
||||
|
||||
# Check document status from batch query result
|
||||
db_status = db_status_map.get(doc_id) # None if not in DB
|
||||
|
||||
# Skip if already successful in database
|
||||
if db_status is True:
|
||||
stats['skipped_db'] += 1
|
||||
skipped_in_csv += 1
|
||||
continue
|
||||
|
||||
# Check if this is a retry (was failed before)
|
||||
if db_status is False:
|
||||
stats['retried'] += 1
|
||||
retry_in_csv += 1
|
||||
|
||||
pdf_path = single_loader.get_pdf_path(row)
|
||||
if not pdf_path:
|
||||
stats['skipped'] += 1
|
||||
continue
|
||||
|
||||
row_dict = {
|
||||
'DocumentId': row.DocumentId,
|
||||
'InvoiceNumber': row.InvoiceNumber,
|
||||
@@ -334,33 +559,87 @@ def main():
|
||||
args.skip_ocr
|
||||
))
|
||||
|
||||
if skipped_in_csv > 0 or retry_in_csv > 0:
|
||||
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
|
||||
|
||||
if not tasks:
|
||||
continue
|
||||
|
||||
# Update tasks_submitted counter for limit tracking
|
||||
stats['tasks_submitted'] += len(tasks)
|
||||
|
||||
if use_dual_pool:
|
||||
# Dual-pool mode using pre-initialized DualPoolCoordinator
|
||||
# (process_text_pdf, process_scanned_pdf already imported above)
|
||||
|
||||
# Convert tasks to new format
|
||||
documents = []
|
||||
for task in tasks:
|
||||
row_dict, pdf_path_str, output_dir_str, dpi, min_confidence, skip_ocr = task
|
||||
# Pre-classify PDF type
|
||||
try:
|
||||
is_text = is_text_pdf(pdf_path_str)
|
||||
except Exception:
|
||||
is_text = False
|
||||
|
||||
documents.append({
|
||||
"id": row_dict["DocumentId"],
|
||||
"row_dict": row_dict,
|
||||
"pdf_path": pdf_path_str,
|
||||
"output_dir": output_dir_str,
|
||||
"dpi": dpi,
|
||||
"min_confidence": min_confidence,
|
||||
"is_scanned": not is_text,
|
||||
"has_text": is_text,
|
||||
"text_length": 1000 if is_text else 0, # Approximate
|
||||
})
|
||||
|
||||
# Count task types
|
||||
text_count = sum(1 for d in documents if not d["is_scanned"])
|
||||
scan_count = len(documents) - text_count
|
||||
print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")
|
||||
|
||||
# Progress tracking with tqdm
|
||||
pbar = tqdm(total=len(documents), desc="Processing")
|
||||
|
||||
def on_result(task_result):
|
||||
"""Handle successful result."""
|
||||
result = task_result.data
|
||||
handle_result(result)
|
||||
pbar.update(1)
|
||||
|
||||
def on_error(task_id, error):
|
||||
"""Handle failed task."""
|
||||
handle_error(task_id, error)
|
||||
pbar.update(1)
|
||||
|
||||
# Process with pre-initialized coordinator (workers stay alive)
|
||||
results = dual_pool_coordinator.process_batch(
|
||||
documents=documents,
|
||||
cpu_task_fn=process_text_pdf,
|
||||
gpu_task_fn=process_scanned_pdf,
|
||||
on_result=on_result,
|
||||
on_error=on_error,
|
||||
id_field="id",
|
||||
)
|
||||
|
||||
pbar.close()
|
||||
|
||||
# Log summary
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = len(results) - successful
|
||||
print(f" Batch complete: {successful} successful, {failed} failed")
|
||||
|
||||
else:
|
||||
# Single-pool mode (original behavior)
|
||||
print(f" Processing {len(tasks)} documents with {args.workers} workers...")
|
||||
|
||||
# Process documents in parallel
|
||||
processed_items = []
|
||||
|
||||
# Process documents in parallel (inside CSV loop for streaming)
|
||||
# Use single process for debugging or when workers=1
|
||||
if args.workers == 1:
|
||||
for task in tqdm(tasks, desc="Processing"):
|
||||
result = process_single_document(task)
|
||||
|
||||
# Write report
|
||||
if result['report']:
|
||||
report_writer.write_dict(result['report'])
|
||||
|
||||
if result['success']:
|
||||
processed_items.append({
|
||||
'doc_id': result['doc_id'],
|
||||
'pages': result['pages']
|
||||
})
|
||||
stats['successful'] += 1
|
||||
for field, count in result['stats'].items():
|
||||
stats['by_field'][field] += count
|
||||
stats['annotations'] += count
|
||||
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
|
||||
stats['skipped'] += 1
|
||||
else:
|
||||
stats['failed'] += 1
|
||||
handle_result(result)
|
||||
else:
|
||||
# Parallel processing with worker initialization
|
||||
# Each worker initializes OCR engine once and reuses it
|
||||
@@ -372,67 +651,31 @@ def main():
|
||||
doc_id = futures[future]
|
||||
try:
|
||||
result = future.result()
|
||||
|
||||
# Write report
|
||||
if result['report']:
|
||||
report_writer.write_dict(result['report'])
|
||||
|
||||
if result['success']:
|
||||
processed_items.append({
|
||||
'doc_id': result['doc_id'],
|
||||
'pages': result['pages']
|
||||
})
|
||||
stats['successful'] += 1
|
||||
for field, count in result['stats'].items():
|
||||
stats['by_field'][field] += count
|
||||
stats['annotations'] += count
|
||||
elif 'Skipped' in str(result.get('report', {}).get('errors', [])):
|
||||
stats['skipped'] += 1
|
||||
else:
|
||||
stats['failed'] += 1
|
||||
|
||||
handle_result(result)
|
||||
except Exception as e:
|
||||
stats['failed'] += 1
|
||||
# Write error report for failed documents
|
||||
error_report = {
|
||||
'document_id': doc_id,
|
||||
'success': False,
|
||||
'errors': [f"Worker error: {str(e)}"]
|
||||
}
|
||||
report_writer.write_dict(error_report)
|
||||
if args.verbose:
|
||||
print(f"Error processing {doc_id}: {e}")
|
||||
handle_error(doc_id, e)
|
||||
|
||||
# Split and move files
|
||||
import random
|
||||
random.seed(42)
|
||||
random.shuffle(processed_items)
|
||||
# Flush remaining database batch after each CSV
|
||||
if db_batch:
|
||||
db.save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
|
||||
n_train = int(len(processed_items) * args.train_ratio)
|
||||
n_val = int(len(processed_items) * args.val_ratio)
|
||||
finally:
|
||||
# Shutdown dual-pool coordinator if it was started
|
||||
if dual_pool_coordinator is not None:
|
||||
dual_pool_coordinator.shutdown()
|
||||
|
||||
splits = {
|
||||
'train': processed_items[:n_train],
|
||||
'val': processed_items[n_train:n_train + n_val],
|
||||
'test': processed_items[n_train + n_val:]
|
||||
}
|
||||
# Close temp file
|
||||
processed_items_writer.close()
|
||||
|
||||
import shutil
|
||||
for split_name, items in splits.items():
|
||||
for item in items:
|
||||
for page in item['pages']:
|
||||
# Move image
|
||||
image_path = Path(page['image_path'])
|
||||
label_path = Path(page['label_path'])
|
||||
dest_img = output_dir / split_name / 'images' / image_path.name
|
||||
shutil.move(str(image_path), str(dest_img))
|
||||
# Use the in-memory counter instead of re-reading the file (performance fix)
|
||||
# processed_count already tracks the number of successfully processed items
|
||||
|
||||
# Move label
|
||||
dest_label = output_dir / split_name / 'labels' / label_path.name
|
||||
shutil.move(str(label_path), str(dest_label))
|
||||
# Cleanup processed_items temp file (not needed anymore)
|
||||
processed_items_file.unlink(missing_ok=True)
|
||||
|
||||
# Cleanup temp
|
||||
shutil.rmtree(output_dir / 'temp', ignore_errors=True)
|
||||
# Close database connection
|
||||
db.close()
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
@@ -441,17 +684,22 @@ def main():
|
||||
print(f"Total documents: {stats['total']}")
|
||||
print(f"Successful: {stats['successful']}")
|
||||
print(f"Failed: {stats['failed']}")
|
||||
print(f"Skipped (OCR): {stats['skipped']}")
|
||||
print(f"Skipped (no PDF): {stats['skipped']}")
|
||||
print(f"Skipped (in DB): {stats['skipped_db']}")
|
||||
print(f"Retried (failed): {stats['retried']}")
|
||||
print(f"Total annotations: {stats['annotations']}")
|
||||
print(f"\nDataset split:")
|
||||
print(f" Train: {len(splits['train'])} documents")
|
||||
print(f" Val: {len(splits['val'])} documents")
|
||||
print(f" Test: {len(splits['test'])} documents")
|
||||
print(f"\nImages saved to: {output_dir / 'temp'}")
|
||||
print(f"Labels stored in: PostgreSQL database")
|
||||
print(f"\nAnnotations by field:")
|
||||
for field, count in stats['by_field'].items():
|
||||
print(f" {field}: {count}")
|
||||
print(f"\nOutput: {output_dir}")
|
||||
print(f"Report: {args.report}")
|
||||
shard_files = report_writer.get_shard_files()
|
||||
if len(shard_files) > 1:
|
||||
print(f"\nReport files ({len(shard_files)}):")
|
||||
for sf in shard_files:
|
||||
print(f" - {sf}")
|
||||
else:
|
||||
print(f"\nReport: {shard_files[0] if shard_files else args.report}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
262  src/cli/import_report_to_db.py  (new file)
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Import existing JSONL report files into PostgreSQL database.
|
||||
|
||||
Usage:
|
||||
python -m src.cli.import_report_to_db --report "reports/autolabel_report_v4*.jsonl"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from config import get_db_connection_string, PATHS
|
||||
|
||||
|
||||
def create_tables(conn):
|
||||
"""Create database tables."""
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS documents (
|
||||
document_id TEXT PRIMARY KEY,
|
||||
pdf_path TEXT,
|
||||
pdf_type TEXT,
|
||||
success BOOLEAN,
|
||||
total_pages INTEGER,
|
||||
fields_matched INTEGER,
|
||||
fields_total INTEGER,
|
||||
annotations_generated INTEGER,
|
||||
processing_time_ms REAL,
|
||||
timestamp TIMESTAMPTZ,
|
||||
errors JSONB DEFAULT '[]'
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS field_results (
|
||||
id SERIAL PRIMARY KEY,
|
||||
document_id TEXT NOT NULL REFERENCES documents(document_id) ON DELETE CASCADE,
|
||||
field_name TEXT,
|
||||
csv_value TEXT,
|
||||
matched BOOLEAN,
|
||||
score REAL,
|
||||
matched_text TEXT,
|
||||
candidate_used TEXT,
|
||||
bbox JSONB,
|
||||
page_no INTEGER,
|
||||
context_keywords JSONB DEFAULT '[]',
|
||||
error TEXT
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_success ON documents(success);
|
||||
CREATE INDEX IF NOT EXISTS idx_field_results_document_id ON field_results(document_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_field_results_field_name ON field_results(field_name);
|
||||
CREATE INDEX IF NOT EXISTS idx_field_results_matched ON field_results(matched);
|
||||
""")
|
||||
conn.commit()
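
As a quick illustration, per-field match rates can be pulled straight from this schema; a minimal read-only sketch, assuming the config helpers imported above:

import psycopg2
from config import get_db_connection_string

with psycopg2.connect(get_db_connection_string()) as conn, conn.cursor() as cur:
    cur.execute("""
        SELECT field_name,
               SUM(CASE WHEN matched THEN 1 ELSE 0 END)::float / COUNT(*) AS match_rate
        FROM field_results
        GROUP BY field_name
        ORDER BY match_rate
    """)
    for field_name, match_rate in cur.fetchall():   # e.g. ('Amount', 0.83)
        print(f"{field_name}: {match_rate:.1%}")
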
|
||||
|
||||
|
||||
def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_size: int = 1000) -> dict:
|
||||
"""Import a single JSONL file into database."""
|
||||
stats = {'imported': 0, 'skipped': 0, 'errors': 0}
|
||||
|
||||
# Get existing document IDs if skipping
|
||||
existing_ids = set()
|
||||
if skip_existing:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT document_id FROM documents")
|
||||
existing_ids = {row[0] for row in cursor.fetchall()}
|
||||
|
||||
doc_batch = []
|
||||
field_batch = []
|
||||
|
||||
def flush_batches():
|
||||
nonlocal doc_batch, field_batch
|
||||
if doc_batch:
|
||||
with conn.cursor() as cursor:
|
||||
execute_values(cursor, """
|
||||
INSERT INTO documents
|
||||
(document_id, pdf_path, pdf_type, success, total_pages,
|
||||
fields_matched, fields_total, annotations_generated,
|
||||
processing_time_ms, timestamp, errors)
|
||||
VALUES %s
|
||||
ON CONFLICT (document_id) DO UPDATE SET
|
||||
pdf_path = EXCLUDED.pdf_path,
|
||||
pdf_type = EXCLUDED.pdf_type,
|
||||
success = EXCLUDED.success,
|
||||
total_pages = EXCLUDED.total_pages,
|
||||
fields_matched = EXCLUDED.fields_matched,
|
||||
fields_total = EXCLUDED.fields_total,
|
||||
annotations_generated = EXCLUDED.annotations_generated,
|
||||
processing_time_ms = EXCLUDED.processing_time_ms,
|
||||
timestamp = EXCLUDED.timestamp,
|
||||
errors = EXCLUDED.errors
|
||||
""", doc_batch)
|
||||
doc_batch = []
|
||||
|
||||
if field_batch:
|
||||
with conn.cursor() as cursor:
|
||||
execute_values(cursor, """
|
||||
INSERT INTO field_results
|
||||
(document_id, field_name, csv_value, matched, score,
|
||||
matched_text, candidate_used, bbox, page_no, context_keywords, error)
|
||||
VALUES %s
|
||||
""", field_batch)
|
||||
field_batch = []
|
||||
|
||||
conn.commit()
|
||||
|
||||
with open(jsonl_path, 'r', encoding='utf-8') as f:
|
||||
for line_no, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
record = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f" Warning: Line {line_no} - JSON parse error: {e}")
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
doc_id = record.get('document_id')
|
||||
if not doc_id:
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
# Only import successful documents
|
||||
if not record.get('success'):
|
||||
stats['skipped'] += 1
|
||||
continue
|
||||
|
||||
# Check if already exists
|
||||
if skip_existing and doc_id in existing_ids:
|
||||
stats['skipped'] += 1
|
||||
continue
|
||||
|
||||
# Add to batch
|
||||
doc_batch.append((
|
||||
doc_id,
|
||||
record.get('pdf_path'),
|
||||
record.get('pdf_type'),
|
||||
record.get('success'),
|
||||
record.get('total_pages'),
|
||||
record.get('fields_matched'),
|
||||
record.get('fields_total'),
|
||||
record.get('annotations_generated'),
|
||||
record.get('processing_time_ms'),
|
||||
record.get('timestamp'),
|
||||
json.dumps(record.get('errors', []))
|
||||
))
|
||||
|
||||
for field in record.get('field_results', []):
|
||||
field_batch.append((
|
||||
doc_id,
|
||||
field.get('field_name'),
|
||||
field.get('csv_value'),
|
||||
field.get('matched'),
|
||||
field.get('score'),
|
||||
field.get('matched_text'),
|
||||
field.get('candidate_used'),
|
||||
json.dumps(field.get('bbox')) if field.get('bbox') else None,
|
||||
field.get('page_no'),
|
||||
json.dumps(field.get('context_keywords', [])),
|
||||
field.get('error')
|
||||
))
|
||||
|
||||
stats['imported'] += 1
|
||||
existing_ids.add(doc_id)
|
||||
|
||||
# Flush batch if needed
|
||||
if len(doc_batch) >= batch_size:
|
||||
flush_batches()
|
||||
print(f" Processed {stats['imported'] + stats['skipped']} records...")
|
||||
|
||||
# Final flush
|
||||
flush_batches()
|
||||
|
||||
return stats
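
When driven from Python rather than the CLI below, the import flow is roughly this sketch (the report filename is a placeholder):

import psycopg2
from pathlib import Path
from config import get_db_connection_string

conn = psycopg2.connect(get_db_connection_string())
create_tables(conn)
stats = import_jsonl_file(conn, Path("reports/autolabel_report_part000.jsonl"),
                          skip_existing=True, batch_size=1000)
print(stats)    # e.g. {'imported': 950, 'skipped': 40, 'errors': 10}
conn.close()
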
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Import JSONL reports to PostgreSQL database')
|
||||
parser.add_argument('--report', type=str, default=f"{PATHS['reports_dir']}/autolabel_report*.jsonl",
|
||||
help='Report file path or glob pattern')
|
||||
parser.add_argument('--db', type=str, default=None,
|
||||
help='PostgreSQL connection string (uses config.py if not specified)')
|
||||
parser.add_argument('--no-skip', action='store_true',
|
||||
help='Do not skip existing documents (replace them)')
|
||||
parser.add_argument('--batch-size', type=int, default=1000,
|
||||
help='Batch size for bulk inserts')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Use config if db not specified
|
||||
db_connection = args.db or get_db_connection_string()
|
||||
|
||||
# Find report files
|
||||
report_path = Path(args.report)
|
||||
if '*' in str(report_path) or '?' in str(report_path):
|
||||
parent = report_path.parent
|
||||
pattern = report_path.name
|
||||
report_files = sorted(parent.glob(pattern))
|
||||
else:
|
||||
report_files = [report_path] if report_path.exists() else []
|
||||
|
||||
if not report_files:
|
||||
print(f"No report files found: {args.report}")
|
||||
return
|
||||
|
||||
print(f"Found {len(report_files)} report file(s)")
|
||||
|
||||
# Connect to database
|
||||
conn = psycopg2.connect(db_connection)
|
||||
create_tables(conn)
|
||||
|
||||
# Import each file
|
||||
total_stats = {'imported': 0, 'skipped': 0, 'errors': 0}
|
||||
|
||||
for report_file in report_files:
|
||||
print(f"\nImporting: {report_file.name}")
|
||||
stats = import_jsonl_file(conn, report_file, skip_existing=not args.no_skip, batch_size=args.batch_size)
|
||||
print(f" Imported: {stats['imported']}, Skipped: {stats['skipped']}, Errors: {stats['errors']}")
|
||||
|
||||
for key in total_stats:
|
||||
total_stats[key] += stats[key]
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Import Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total imported: {total_stats['imported']}")
|
||||
print(f"Total skipped: {total_stats['skipped']}")
|
||||
print(f"Total errors: {total_stats['errors']}")
|
||||
|
||||
# Quick stats from database
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT COUNT(*) FROM documents")
|
||||
total_docs = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute("SELECT COUNT(*) FROM documents WHERE success = true")
|
||||
success_docs = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute("SELECT COUNT(*) FROM field_results")
|
||||
total_fields = cursor.fetchone()[0]
|
||||
|
||||
cursor.execute("SELECT COUNT(*) FROM field_results WHERE matched = true")
|
||||
matched_fields = cursor.fetchone()[0]
|
||||
|
||||
conn.close()
|
||||
|
||||
print(f"\nDatabase Stats:")
|
||||
print(f" Documents: {total_docs} ({success_docs} successful)")
|
||||
print(f" Field results: {total_fields} ({matched_fields} matched)")
|
||||
if total_fields > 0:
|
||||
print(f" Match rate: {matched_fields / total_fields * 100:.2f}%")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
158  src/cli/serve.py  (new file)
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Web Server CLI
|
||||
|
||||
Command-line interface for starting the web server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
|
||||
def setup_logging(debug: bool = False) -> None:
|
||||
"""Configure logging."""
|
||||
level = logging.DEBUG if debug else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command-line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Start the Invoice Field Extraction web server",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="Host to bind to",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port to listen on",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
type=Path,
|
||||
default=Path("runs/train/invoice_yolo11n_full/weights/best.pt"),
|
||||
help="Path to YOLO model weights",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--confidence",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="Detection confidence threshold",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dpi",
|
||||
type=int,
|
||||
default=150,
|
||||
help="DPI for PDF rendering",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-gpu",
|
||||
action="store_true",
|
||||
help="Disable GPU acceleration",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reload",
|
||||
action="store_true",
|
||||
help="Enable auto-reload for development",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of worker processes",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Enable debug mode",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point."""
|
||||
args = parse_args()
|
||||
setup_logging(debug=args.debug)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Validate model path
|
||||
if not args.model.exists():
|
||||
logger.error(f"Model file not found: {args.model}")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("Invoice Field Extraction Web Server")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Model: {args.model}")
|
||||
logger.info(f"Confidence threshold: {args.confidence}")
|
||||
logger.info(f"GPU enabled: {not args.no_gpu}")
|
||||
logger.info(f"Server: http://{args.host}:{args.port}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Create config
|
||||
from src.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
|
||||
|
||||
config = AppConfig(
|
||||
model=ModelConfig(
|
||||
model_path=args.model,
|
||||
confidence_threshold=args.confidence,
|
||||
use_gpu=not args.no_gpu,
|
||||
dpi=args.dpi,
|
||||
),
|
||||
server=ServerConfig(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
debug=args.debug,
|
||||
reload=args.reload,
|
||||
workers=args.workers,
|
||||
),
|
||||
storage=StorageConfig(),
|
||||
)
|
||||
|
||||
# Create and run app
|
||||
import uvicorn
|
||||
from src.web.app import create_app
|
||||
|
||||
app = create_app(config)
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=config.server.host,
|
||||
port=config.server.port,
|
||||
reload=config.server.reload,
|
||||
workers=config.server.workers if not config.server.reload else 1,
|
||||
log_level="debug" if config.server.debug else "info",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
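
A few invocation sketches for this CLI (module path, weights path and port are illustrative):

# python -m src.cli.serve                                                # defaults: 0.0.0.0:8000
# python -m src.cli.serve --model runs/train/my_run/weights/best.pt --port 9000 --reload
# python -m src.cli.serve --no-gpu --confidence 0.5 --dpi 200
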
|
||||
135  src/cli/train.py
@@ -2,22 +2,26 @@
|
||||
"""
|
||||
Training CLI
|
||||
|
||||
Trains YOLO model on generated dataset.
|
||||
Trains YOLO model on dataset with labels from PostgreSQL database.
|
||||
Images are read from filesystem, labels are dynamically generated from DB.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from config import PATHS
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Train YOLO model for invoice field detection'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--data', '-d',
|
||||
required=True,
|
||||
help='Path to dataset.yaml file'
|
||||
'--dataset-dir', '-d',
|
||||
default=PATHS['output_dir'],
|
||||
help='Dataset directory containing temp/{doc_id}/images/ (default: data/dataset)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model', '-m',
|
||||
@@ -62,24 +66,117 @@ def main():
|
||||
help='Resume from checkpoint'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--config',
|
||||
help='Path to training config YAML'
|
||||
'--train-ratio',
|
||||
type=float,
|
||||
default=0.8,
|
||||
help='Training set ratio (default: 0.8)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--val-ratio',
|
||||
type=float,
|
||||
default=0.1,
|
||||
help='Validation set ratio (default: 0.1)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--seed',
|
||||
type=int,
|
||||
default=42,
|
||||
help='Random seed for split (default: 42)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dpi',
|
||||
type=int,
|
||||
default=300,
|
||||
help='DPI used for rendering (default: 300)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--export-only',
|
||||
action='store_true',
|
||||
help='Only export dataset to YOLO format, do not train'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--limit',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Limit number of documents for training (default: all)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate data file
|
||||
data_path = Path(args.data)
|
||||
if not data_path.exists():
|
||||
print(f"Error: Dataset file not found: {data_path}")
|
||||
# Validate dataset directory
|
||||
dataset_dir = Path(args.dataset_dir)
|
||||
temp_dir = dataset_dir / 'temp'
|
||||
if not temp_dir.exists():
|
||||
print(f"Error: Temp directory not found: {temp_dir}")
|
||||
print("Run autolabel first to generate images.")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Training YOLO model for invoice field detection")
|
||||
print(f"Dataset: {args.data}")
|
||||
print("=" * 60)
|
||||
print("YOLO Training with Database Labels")
|
||||
print("=" * 60)
|
||||
print(f"Dataset dir: {dataset_dir}")
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Epochs: {args.epochs}")
|
||||
print(f"Batch size: {args.batch}")
|
||||
print(f"Image size: {args.imgsz}")
|
||||
print(f"Split ratio: {args.train_ratio}/{args.val_ratio}/{1-args.train_ratio-args.val_ratio:.1f}")
|
||||
if args.limit:
|
||||
print(f"Document limit: {args.limit}")
|
||||
|
||||
# Connect to database
|
||||
from ..data.db import DocumentDB
|
||||
|
||||
print("\nConnecting to database...")
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
|
||||
# Create datasets from database
|
||||
from ..yolo.db_dataset import create_datasets
|
||||
|
||||
print("Loading dataset from database...")
|
||||
datasets = create_datasets(
|
||||
images_dir=dataset_dir,
|
||||
db=db,
|
||||
train_ratio=args.train_ratio,
|
||||
val_ratio=args.val_ratio,
|
||||
seed=args.seed,
|
||||
dpi=args.dpi,
|
||||
limit=args.limit
|
||||
)
|
||||
|
||||
print(f"\nDataset splits:")
|
||||
print(f" Train: {len(datasets['train'])} items")
|
||||
print(f" Val: {len(datasets['val'])} items")
|
||||
print(f" Test: {len(datasets['test'])} items")
|
||||
|
||||
if len(datasets['train']) == 0:
|
||||
print("\nError: No training data found!")
|
||||
print("Make sure autolabel has been run and images exist in temp directory.")
|
||||
db.close()
|
||||
sys.exit(1)
|
||||
|
||||
# Export to YOLO format (required for Ultralytics training)
|
||||
print("\nExporting dataset to YOLO format...")
|
||||
for split_name, dataset in datasets.items():
|
||||
count = dataset.export_to_yolo_format(dataset_dir, split_name)
|
||||
print(f" {split_name}: {count} items exported")
|
||||
|
||||
# Generate YOLO config files
|
||||
from ..yolo.annotation_generator import AnnotationGenerator
|
||||
|
||||
AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt')
|
||||
AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml')
|
||||
print(f"\nGenerated dataset.yaml at: {dataset_dir / 'dataset.yaml'}")
|
||||
|
||||
if args.export_only:
|
||||
print("\nExport complete (--export-only specified, skipping training)")
|
||||
db.close()
|
||||
return
|
||||
|
||||
# Start training
|
||||
print("\n" + "=" * 60)
|
||||
print("Starting YOLO Training")
|
||||
print("=" * 60)
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
@@ -91,8 +188,9 @@ def main():
|
||||
model = YOLO(args.model)
|
||||
|
||||
# Training arguments
|
||||
data_yaml = dataset_dir / 'dataset.yaml'
|
||||
train_args = {
|
||||
'data': str(data_path.absolute()),
|
||||
'data': str(data_yaml.absolute()),
|
||||
'epochs': args.epochs,
|
||||
'batch': args.batch,
|
||||
'imgsz': args.imgsz,
|
||||
@@ -121,18 +219,21 @@ def main():
|
||||
results = model.train(**train_args)
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 50)
|
||||
print("\n" + "=" * 60)
|
||||
print("Training Complete")
|
||||
print("=" * 50)
|
||||
print("=" * 60)
|
||||
print(f"Best model: {args.project}/{args.name}/weights/best.pt")
|
||||
print(f"Last model: {args.project}/{args.name}/weights/last.pt")
|
||||
|
||||
# Validate on test set
|
||||
print("\nRunning validation...")
|
||||
metrics = model.val()
|
||||
print("\nRunning validation on test set...")
|
||||
metrics = model.val(split='test')
|
||||
print(f"mAP50: {metrics.box.map50:.4f}")
|
||||
print(f"mAP50-95: {metrics.box.map:.4f}")
|
||||
|
||||
# Close database
|
||||
db.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
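
Typical invocations for the reworked CLI, as a sketch (paths and model name are illustrative):

# python -m src.cli.train --dataset-dir data/dataset --model yolo11s.pt --epochs 100 --batch 8
# python -m src.cli.train --export-only --limit 500     # only materialise YOLO files from the DB
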
|
||||
|
||||
@@ -114,57 +114,106 @@ class AutoLabelReport:
|
||||
|
||||
|
||||
class ReportWriter:
|
||||
"""Writes auto-label reports to file."""
|
||||
"""Writes auto-label reports to file with optional sharding."""
|
||||
|
||||
def __init__(self, output_path: str | Path):
|
||||
def __init__(
|
||||
self,
|
||||
output_path: str | Path,
|
||||
max_records_per_file: int = 0
|
||||
):
|
||||
"""
|
||||
Initialize report writer.
|
||||
|
||||
Args:
|
||||
output_path: Path to output JSONL file
|
||||
output_path: Path to output JSONL file (base name if sharding)
|
||||
max_records_per_file: Max records per file (0 = no limit, single file)
|
||||
"""
|
||||
self.output_path = Path(output_path)
|
||||
self.output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.max_records_per_file = max_records_per_file
|
||||
|
||||
# Sharding state
|
||||
self._current_shard = 0
|
||||
self._records_in_current_shard = 0
|
||||
self._shard_files: list[Path] = []
|
||||
|
||||
def _get_shard_path(self) -> Path:
|
||||
"""Get the path for current shard."""
|
||||
if self.max_records_per_file > 0:
|
||||
base = self.output_path.stem
|
||||
suffix = self.output_path.suffix
|
||||
shard_path = self.output_path.parent / f"{base}_part{self._current_shard:03d}{suffix}"
|
||||
else:
|
||||
shard_path = self.output_path
|
||||
|
||||
if shard_path not in self._shard_files:
|
||||
self._shard_files.append(shard_path)
|
||||
|
||||
return shard_path
|
||||
|
||||
def _check_shard_rotation(self) -> None:
|
||||
"""Check if we need to rotate to a new shard file."""
|
||||
if self.max_records_per_file > 0:
|
||||
if self._records_in_current_shard >= self.max_records_per_file:
|
||||
self._current_shard += 1
|
||||
self._records_in_current_shard = 0
|
||||
|
||||
def write(self, report: AutoLabelReport) -> None:
|
||||
"""Append a report to the output file."""
|
||||
with open(self.output_path, 'a', encoding='utf-8') as f:
|
||||
self._check_shard_rotation()
|
||||
shard_path = self._get_shard_path()
|
||||
with open(shard_path, 'a', encoding='utf-8') as f:
|
||||
f.write(report.to_json() + '\n')
|
||||
self._records_in_current_shard += 1
|
||||
|
||||
def write_dict(self, report_dict: dict) -> None:
|
||||
"""Append a report dict to the output file (for parallel processing)."""
|
||||
import json
|
||||
with open(self.output_path, 'a', encoding='utf-8') as f:
|
||||
self._check_shard_rotation()
|
||||
shard_path = self._get_shard_path()
|
||||
with open(shard_path, 'a', encoding='utf-8') as f:
|
||||
f.write(json.dumps(report_dict, ensure_ascii=False) + '\n')
|
||||
f.flush()
|
||||
self._records_in_current_shard += 1
|
||||
|
||||
def write_batch(self, reports: list[AutoLabelReport]) -> None:
|
||||
"""Write multiple reports."""
|
||||
with open(self.output_path, 'a', encoding='utf-8') as f:
|
||||
for report in reports:
|
||||
f.write(report.to_json() + '\n')
|
||||
self.write(report)
|
||||
|
||||
def get_shard_files(self) -> list[Path]:
|
||||
"""Get list of all shard files created."""
|
||||
return self._shard_files.copy()
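
A short usage sketch of the sharded writer (file names and counts are illustrative): with max_records_per_file=2, the third record rotates into a new shard.

writer = ReportWriter("reports/autolabel_report.jsonl", max_records_per_file=2)
for i in range(5):
    writer.write_dict({"document_id": f"doc-{i}", "success": True})
print(writer.get_shard_files())
# -> three shards: autolabel_report_part000.jsonl (2 records),
#    autolabel_report_part001.jsonl (2), autolabel_report_part002.jsonl (1)
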
|
||||
|
||||
|
||||
class ReportReader:
|
||||
"""Reads auto-label reports from file."""
|
||||
"""Reads auto-label reports from file(s)."""
|
||||
|
||||
def __init__(self, input_path: str | Path):
|
||||
"""
|
||||
Initialize report reader.
|
||||
|
||||
Args:
|
||||
input_path: Path to input JSONL file
|
||||
input_path: Path to input JSONL file or glob pattern (e.g., 'reports/*.jsonl')
|
||||
"""
|
||||
self.input_path = Path(input_path)
|
||||
|
||||
# Handle glob pattern
|
||||
if '*' in str(input_path) or '?' in str(input_path):
|
||||
parent = self.input_path.parent
|
||||
pattern = self.input_path.name
|
||||
self.input_paths = sorted(parent.glob(pattern))
|
||||
else:
|
||||
self.input_paths = [self.input_path]
|
||||
|
||||
def read_all(self) -> list[AutoLabelReport]:
|
||||
"""Read all reports from file."""
|
||||
"""Read all reports from file(s)."""
|
||||
reports = []
|
||||
|
||||
if not self.input_path.exists():
|
||||
return reports
|
||||
for input_path in self.input_paths:
|
||||
if not input_path.exists():
|
||||
continue
|
||||
|
||||
with open(self.input_path, 'r', encoding='utf-8') as f:
|
||||
with open(input_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
|
||||
@@ -72,7 +72,7 @@ class CSVLoader:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
csv_path: str | Path,
|
||||
csv_path: str | Path | list[str | Path],
|
||||
pdf_dir: str | Path | None = None,
|
||||
doc_map_path: str | Path | None = None,
|
||||
encoding: str = 'utf-8'
|
||||
@@ -81,13 +81,31 @@ class CSVLoader:
|
||||
Initialize CSV loader.
|
||||
|
||||
Args:
|
||||
csv_path: Path to the CSV file
|
||||
csv_path: Path to CSV file(s). Can be:
|
||||
- Single path: 'data/file.csv'
|
||||
- List of paths: ['data/file1.csv', 'data/file2.csv']
|
||||
- Glob pattern: 'data/*.csv' or 'data/export_*.csv'
|
||||
pdf_dir: Directory containing PDF files (default: data/raw_pdfs)
|
||||
doc_map_path: Optional path to document mapping CSV
|
||||
encoding: CSV file encoding (default: utf-8)
|
||||
"""
|
||||
self.csv_path = Path(csv_path)
|
||||
self.pdf_dir = Path(pdf_dir) if pdf_dir else self.csv_path.parent.parent / 'raw_pdfs'
|
||||
# Handle multiple CSV files
|
||||
if isinstance(csv_path, list):
|
||||
self.csv_paths = [Path(p) for p in csv_path]
|
||||
else:
|
||||
csv_path = Path(csv_path)
|
||||
# Check if it's a glob pattern (contains * or ?)
|
||||
if '*' in str(csv_path) or '?' in str(csv_path):
|
||||
parent = csv_path.parent
|
||||
pattern = csv_path.name
|
||||
self.csv_paths = sorted(parent.glob(pattern))
|
||||
else:
|
||||
self.csv_paths = [csv_path]
|
||||
|
||||
# For backward compatibility
|
||||
self.csv_path = self.csv_paths[0] if self.csv_paths else None
|
||||
|
||||
self.pdf_dir = Path(pdf_dir) if pdf_dir else (self.csv_path.parent.parent / 'raw_pdfs' if self.csv_path else Path('data/raw_pdfs'))
|
||||
self.doc_map_path = Path(doc_map_path) if doc_map_path else None
|
||||
self.encoding = encoding
|
||||
|
||||
@@ -185,21 +203,14 @@ class CSVLoader:
|
||||
raw_data=dict(row)
|
||||
)
|
||||
|
||||
def load_all(self) -> list[InvoiceRow]:
|
||||
"""Load all rows from CSV."""
|
||||
rows = []
|
||||
for row in self.iter_rows():
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
def iter_rows(self) -> Iterator[InvoiceRow]:
|
||||
"""Iterate over CSV rows."""
|
||||
def _iter_single_csv(self, csv_path: Path) -> Iterator[InvoiceRow]:
|
||||
"""Iterate over rows from a single CSV file."""
|
||||
# Handle BOM - try utf-8-sig first to handle BOM correctly
|
||||
encodings = ['utf-8-sig', self.encoding, 'latin-1']
|
||||
|
||||
for enc in encodings:
|
||||
try:
|
||||
with open(self.csv_path, 'r', encoding=enc) as f:
|
||||
with open(csv_path, 'r', encoding=enc) as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
parsed = self._parse_row(row)
|
||||
@@ -209,7 +220,27 @@ class CSVLoader:
|
||||
except UnicodeDecodeError:
|
||||
continue
|
||||
|
||||
raise ValueError(f"Could not read CSV file with any supported encoding")
|
||||
raise ValueError(f"Could not read CSV file {csv_path} with any supported encoding")
|
||||
|
||||
def load_all(self) -> list[InvoiceRow]:
|
||||
"""Load all rows from CSV(s)."""
|
||||
rows = []
|
||||
for row in self.iter_rows():
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
def iter_rows(self) -> Iterator[InvoiceRow]:
|
||||
"""Iterate over CSV rows from all CSV files."""
|
||||
seen_doc_ids = set()
|
||||
|
||||
for csv_path in self.csv_paths:
|
||||
if not csv_path.exists():
|
||||
continue
|
||||
for row in self._iter_single_csv(csv_path):
|
||||
# Deduplicate by DocumentId
|
||||
if row.DocumentId not in seen_doc_ids:
|
||||
seen_doc_ids.add(row.DocumentId)
|
||||
yield row
|
||||
|
||||
def get_pdf_path(self, invoice_row: InvoiceRow) -> Path | None:
|
||||
"""
|
||||
@@ -300,7 +331,7 @@ class CSVLoader:
|
||||
return issues
|
||||
|
||||
|
||||
def load_invoice_csv(csv_path: str | Path, pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
|
||||
"""Convenience function to load invoice CSV."""
|
||||
def load_invoice_csv(csv_path: str | Path | list[str | Path], pdf_dir: str | Path | None = None) -> list[InvoiceRow]:
|
||||
"""Convenience function to load invoice CSV(s)."""
|
||||
loader = CSVLoader(csv_path, pdf_dir)
|
||||
return loader.load_all()
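
A brief sketch of the three accepted csv_path forms (paths are placeholders):

rows = load_invoice_csv("data/structured_data/export_2024.csv")                        # single file
rows = load_invoice_csv(["data/structured_data/a.csv", "data/structured_data/b.csv"])  # explicit list

loader = CSVLoader("data/structured_data/export_*.csv")   # glob pattern
for row in loader.iter_rows():        # rows with a duplicate DocumentId are skipped
    pdf = loader.get_pdf_path(row)    # may be None if the PDF is missing
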
|
||||
|
||||
429  src/data/db.py  (new file)
@@ -0,0 +1,429 @@
|
||||
"""
|
||||
Database utilities for autolabel workflow.
|
||||
"""
|
||||
|
||||
import json
|
||||
import psycopg2
|
||||
from psycopg2.extras import execute_values
|
||||
from typing import Set, Dict, Any, Optional
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||
from config import get_db_connection_string
|
||||
|
||||
|
||||
class DocumentDB:
|
||||
"""Database interface for document processing status."""
|
||||
|
||||
def __init__(self, connection_string: str = None):
|
||||
self.connection_string = connection_string or get_db_connection_string()
|
||||
self.conn = None
|
||||
|
||||
def connect(self):
|
||||
"""Connect to database."""
|
||||
if self.conn is None:
|
||||
self.conn = psycopg2.connect(self.connection_string)
|
||||
return self.conn
|
||||
|
||||
def close(self):
|
||||
"""Close database connection."""
|
||||
if self.conn:
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
|
||||
def __enter__(self):
|
||||
self.connect()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
|
||||
def get_successful_doc_ids(self) -> Set[str]:
|
||||
"""Get all document IDs that have been successfully processed."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT document_id FROM documents WHERE success = true")
|
||||
return {row[0] for row in cursor.fetchall()}
|
||||
|
||||
def get_failed_doc_ids(self) -> Set[str]:
|
||||
"""Get all document IDs that failed processing."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("SELECT document_id FROM documents WHERE success = false")
|
||||
return {row[0] for row in cursor.fetchall()}
|
||||
|
||||
def check_document_status(self, doc_id: str) -> Optional[bool]:
|
||||
"""
|
||||
Check if a document exists and its success status.
|
||||
|
||||
Returns:
|
||||
True if exists and success=true
|
||||
False if exists and success=false
|
||||
None if not exists
|
||||
"""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT success FROM documents WHERE document_id = %s",
|
||||
(doc_id,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return row[0]
|
||||
|
||||
def check_documents_status_batch(self, doc_ids: list[str]) -> Dict[str, Optional[bool]]:
|
||||
"""
|
||||
Batch check document status for multiple IDs.
|
||||
|
||||
Returns:
|
||||
Dict mapping doc_id to status:
|
||||
True if exists and success=true
|
||||
False if exists and success=false
|
||||
(missing from dict if not exists)
|
||||
"""
|
||||
if not doc_ids:
|
||||
return {}
|
||||
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT document_id, success FROM documents WHERE document_id = ANY(%s)",
|
||||
(doc_ids,)
|
||||
)
|
||||
return {row[0]: row[1] for row in cursor.fetchall()}
|
||||
|
||||
def delete_document(self, doc_id: str):
|
||||
"""Delete a document and its field results (for re-processing)."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
# field_results will be cascade deleted
|
||||
cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))
|
||||
conn.commit()
|
||||
|
||||
def get_document(self, doc_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a single document with its field results."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
# Get document
|
||||
cursor.execute("""
|
||||
SELECT document_id, pdf_path, pdf_type, success, total_pages,
|
||||
fields_matched, fields_total, annotations_generated,
|
||||
processing_time_ms, timestamp, errors
|
||||
FROM documents WHERE document_id = %s
|
||||
""", (doc_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
||||
doc = {
|
||||
'document_id': row[0],
|
||||
'pdf_path': row[1],
|
||||
'pdf_type': row[2],
|
||||
'success': row[3],
|
||||
'total_pages': row[4],
|
||||
'fields_matched': row[5],
|
||||
'fields_total': row[6],
|
||||
'annotations_generated': row[7],
|
||||
'processing_time_ms': row[8],
|
||||
'timestamp': str(row[9]) if row[9] else None,
|
||||
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
|
||||
'field_results': []
|
||||
}
|
||||
|
||||
# Get field results
|
||||
cursor.execute("""
|
||||
SELECT field_name, csv_value, matched, score, matched_text,
|
||||
candidate_used, bbox, page_no, context_keywords, error
|
||||
FROM field_results WHERE document_id = %s
|
||||
""", (doc_id,))
|
||||
|
||||
for fr in cursor.fetchall():
|
||||
doc['field_results'].append({
|
||||
'field_name': fr[0],
|
||||
'csv_value': fr[1],
|
||||
'matched': fr[2],
|
||||
'score': fr[3],
|
||||
'matched_text': fr[4],
|
||||
'candidate_used': fr[5],
|
||||
'bbox': fr[6] if isinstance(fr[6], list) else json.loads(fr[6]) if fr[6] else None,
|
||||
'page_no': fr[7],
|
||||
'context_keywords': fr[8] if isinstance(fr[8], list) else json.loads(fr[8] or '[]'),
|
||||
'error': fr[9]
|
||||
})
|
||||
|
||||
return doc
|
||||
|
||||
def get_all_documents_summary(self, success_only: bool = False, limit: int = None) -> list[Dict[str, Any]]:
|
||||
"""Get summary of all documents (without field_results for efficiency)."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
query = """
|
||||
SELECT document_id, pdf_path, pdf_type, success, total_pages,
|
||||
fields_matched, fields_total
|
||||
FROM documents
|
||||
"""
|
||||
if success_only:
|
||||
query += " WHERE success = true"
|
||||
query += " ORDER BY timestamp DESC"
|
||||
if limit:
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
cursor.execute(query)
|
||||
return [
|
||||
{
|
||||
'document_id': row[0],
|
||||
'pdf_path': row[1],
|
||||
'pdf_type': row[2],
|
||||
'success': row[3],
|
||||
'total_pages': row[4],
|
||||
'fields_matched': row[5],
|
||||
'fields_total': row[6]
|
||||
}
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
|
||||
def get_field_stats(self) -> Dict[str, Dict[str, int]]:
|
||||
"""Get match statistics per field."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
SELECT field_name,
|
||||
COUNT(*) as total,
|
||||
SUM(CASE WHEN matched THEN 1 ELSE 0 END) as matched
|
||||
FROM field_results
|
||||
GROUP BY field_name
|
||||
ORDER BY field_name
|
||||
""")
|
||||
return {
|
||||
row[0]: {'total': row[1], 'matched': row[2]}
|
||||
for row in cursor.fetchall()
|
||||
}
|
||||
|
||||
def get_failed_matches(self, field_name: str = None, limit: int = 100) -> list[Dict[str, Any]]:
|
||||
"""Get field results that failed to match."""
|
||||
conn = self.connect()
|
||||
with conn.cursor() as cursor:
|
||||
query = """
|
||||
SELECT fr.document_id, fr.field_name, fr.csv_value, fr.error,
|
||||
d.pdf_type
|
||||
FROM field_results fr
|
||||
JOIN documents d ON fr.document_id = d.document_id
|
||||
WHERE fr.matched = false
|
||||
"""
|
||||
params = []
|
||||
if field_name:
|
||||
query += " AND fr.field_name = %s"
|
||||
params.append(field_name)
|
||||
query += f" LIMIT {limit}"
|
||||
|
||||
cursor.execute(query, params)
|
||||
return [
|
||||
{
|
||||
'document_id': row[0],
|
||||
'field_name': row[1],
|
||||
'csv_value': row[2],
|
||||
'error': row[3],
|
||||
'pdf_type': row[4]
|
||||
}
|
||||
for row in cursor.fetchall()
|
||||
]
|
||||
|
||||
def get_documents_batch(self, doc_ids: list[str]) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
Get multiple documents with their field results in a single batch query.
|
||||
|
||||
This is much more efficient than calling get_document() in a loop.
|
||||
|
||||
Args:
|
||||
doc_ids: List of document IDs to fetch
|
||||
|
||||
Returns:
|
||||
Dict mapping doc_id to document data (with field_results)
|
||||
"""
|
||||
if not doc_ids:
|
||||
return {}
|
||||
|
||||
conn = self.connect()
|
||||
result: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Batch fetch all documents
|
||||
cursor.execute("""
|
||||
SELECT document_id, pdf_path, pdf_type, success, total_pages,
|
||||
fields_matched, fields_total, annotations_generated,
|
||||
processing_time_ms, timestamp, errors
|
||||
FROM documents WHERE document_id = ANY(%s)
|
||||
""", (doc_ids,))
|
||||
|
||||
for row in cursor.fetchall():
|
||||
result[row[0]] = {
|
||||
'document_id': row[0],
|
||||
'pdf_path': row[1],
|
||||
'pdf_type': row[2],
|
||||
'success': row[3],
|
||||
'total_pages': row[4],
|
||||
'fields_matched': row[5],
|
||||
'fields_total': row[6],
|
||||
'annotations_generated': row[7],
|
||||
'processing_time_ms': row[8],
|
||||
'timestamp': str(row[9]) if row[9] else None,
|
||||
'errors': row[10] if isinstance(row[10], list) else json.loads(row[10] or '[]'),
|
||||
'field_results': []
|
||||
}
|
||||
|
||||
if not result:
|
||||
return {}
|
||||
|
||||
# Batch fetch all field results for these documents
|
||||
cursor.execute("""
|
||||
SELECT document_id, field_name, csv_value, matched, score,
|
||||
matched_text, candidate_used, bbox, page_no, context_keywords, error
|
||||
FROM field_results WHERE document_id = ANY(%s)
|
||||
""", (list(result.keys()),))
|
||||
|
||||
for fr in cursor.fetchall():
|
||||
doc_id = fr[0]
|
||||
if doc_id in result:
|
||||
result[doc_id]['field_results'].append({
|
||||
'field_name': fr[1],
|
||||
'csv_value': fr[2],
|
||||
'matched': fr[3],
|
||||
'score': fr[4],
|
||||
'matched_text': fr[5],
|
||||
'candidate_used': fr[6],
|
||||
'bbox': fr[7] if isinstance(fr[7], list) else json.loads(fr[7]) if fr[7] else None,
|
||||
'page_no': fr[8],
|
||||
'context_keywords': fr[9] if isinstance(fr[9], list) else json.loads(fr[9] or '[]'),
|
||||
'error': fr[10]
|
||||
})
|
||||
|
||||
return result
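
A minimal sketch of the intended call pattern (document IDs are placeholders):

with DocumentDB() as db:                                   # connect()/close() via context manager
    status = db.check_documents_status_batch(["doc-001", "doc-002"])
    done = [doc_id for doc_id, ok in status.items() if ok]
    for doc_id, doc in db.get_documents_batch(done).items():
        print(doc_id, f"{doc['fields_matched']}/{doc['fields_total']}")
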
|
||||
|
||||
def save_document(self, report: Dict[str, Any]):
|
||||
"""Save or update a document and its field results using batch operations."""
|
||||
conn = self.connect()
|
||||
doc_id = report.get('document_id')
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Delete existing record if any (for update)
|
||||
cursor.execute("DELETE FROM documents WHERE document_id = %s", (doc_id,))
|
||||
|
||||
# Insert document
|
||||
cursor.execute("""
|
||||
INSERT INTO documents
|
||||
(document_id, pdf_path, pdf_type, success, total_pages,
|
||||
fields_matched, fields_total, annotations_generated,
|
||||
processing_time_ms, timestamp, errors)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
""", (
|
||||
doc_id,
|
||||
report.get('pdf_path'),
|
||||
report.get('pdf_type'),
|
||||
report.get('success'),
|
||||
report.get('total_pages'),
|
||||
report.get('fields_matched'),
|
||||
report.get('fields_total'),
|
||||
report.get('annotations_generated'),
|
||||
report.get('processing_time_ms'),
|
||||
report.get('timestamp'),
|
||||
json.dumps(report.get('errors', []))
|
||||
))
|
||||
|
||||
# Batch insert field results using execute_values
|
||||
field_results = report.get('field_results', [])
|
||||
if field_results:
|
||||
field_values = [
|
||||
(
|
||||
doc_id,
|
||||
field.get('field_name'),
|
||||
field.get('csv_value'),
|
||||
field.get('matched'),
|
||||
field.get('score'),
|
||||
field.get('matched_text'),
|
||||
field.get('candidate_used'),
|
||||
json.dumps(field.get('bbox')) if field.get('bbox') else None,
|
||||
field.get('page_no'),
|
||||
json.dumps(field.get('context_keywords', [])),
|
||||
field.get('error')
|
||||
)
|
||||
for field in field_results
|
||||
]
|
||||
execute_values(cursor, """
|
||||
INSERT INTO field_results
|
||||
(document_id, field_name, csv_value, matched, score,
|
||||
matched_text, candidate_used, bbox, page_no, context_keywords, error)
|
||||
VALUES %s
|
||||
""", field_values)
|
||||
|
||||
conn.commit()
|
||||
|
||||
def save_documents_batch(self, reports: list[Dict[str, Any]]):
|
||||
"""Save multiple documents in a batch."""
|
||||
if not reports:
|
||||
return
|
||||
|
||||
conn = self.connect()
|
||||
doc_ids = [r['document_id'] for r in reports]
|
||||
|
||||
with conn.cursor() as cursor:
|
||||
# Delete existing records
|
||||
cursor.execute(
|
||||
"DELETE FROM documents WHERE document_id = ANY(%s)",
|
||||
(doc_ids,)
|
||||
)
|
||||
|
||||
# Batch insert documents
|
||||
doc_values = [
|
||||
(
|
||||
r.get('document_id'),
|
||||
r.get('pdf_path'),
|
||||
r.get('pdf_type'),
|
||||
r.get('success'),
|
||||
r.get('total_pages'),
|
||||
r.get('fields_matched'),
|
||||
r.get('fields_total'),
|
||||
r.get('annotations_generated'),
|
||||
r.get('processing_time_ms'),
|
||||
r.get('timestamp'),
|
||||
json.dumps(r.get('errors', []))
|
||||
)
|
||||
for r in reports
|
||||
]
|
||||
execute_values(cursor, """
|
||||
INSERT INTO documents
|
||||
(document_id, pdf_path, pdf_type, success, total_pages,
|
||||
fields_matched, fields_total, annotations_generated,
|
||||
processing_time_ms, timestamp, errors)
|
||||
VALUES %s
|
||||
""", doc_values)
|
||||
|
||||
# Batch insert field results
|
||||
field_values = []
|
||||
for r in reports:
|
||||
doc_id = r.get('document_id')
|
||||
for field in r.get('field_results', []):
|
||||
field_values.append((
|
||||
doc_id,
|
||||
field.get('field_name'),
|
||||
field.get('csv_value'),
|
||||
field.get('matched'),
|
||||
field.get('score'),
|
||||
field.get('matched_text'),
|
||||
field.get('candidate_used'),
|
||||
json.dumps(field.get('bbox')) if field.get('bbox') else None,
|
||||
field.get('page_no'),
|
||||
json.dumps(field.get('context_keywords', [])),
|
||||
field.get('error')
|
||||
))
|
||||
|
||||
if field_values:
|
||||
execute_values(cursor, """
|
||||
INSERT INTO field_results
|
||||
(document_id, field_name, csv_value, matched, score,
|
||||
matched_text, candidate_used, bbox, page_no, context_keywords, error)
|
||||
VALUES %s
|
||||
""", field_values)
|
||||
|
||||
conn.commit()
|
||||
@@ -72,7 +72,7 @@ class FieldExtractor:
|
||||
"""Lazy-load OCR engine only when needed."""
|
||||
if self._ocr_engine is None:
|
||||
from ..ocr import OCREngine
|
||||
self._ocr_engine = OCREngine(lang=self.ocr_lang, use_gpu=self.use_gpu)
|
||||
self._ocr_engine = OCREngine(lang=self.ocr_lang)
|
||||
return self._ocr_engine
|
||||
|
||||
def extract_from_detection_with_pdf(
|
||||
@@ -290,31 +290,65 @@ class FieldExtractor:
|
||||
|
||||
def _normalize_amount(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""Normalize monetary amount."""
|
||||
# Remove currency and common suffixes
|
||||
text = re.sub(r'[SEK|kr|:-]+', '', text, flags=re.IGNORECASE)
|
||||
text = text.replace(' ', '').replace('\xa0', '')
|
||||
# Try to extract amount using regex patterns
|
||||
# Pattern 1: Number with comma as decimal (Swedish format: 1 234,56)
|
||||
# Pattern 2: Number with dot as decimal (1234.56)
|
||||
# Pattern 3: Number followed by currency (275,60 kr or 275.60 SEK)
|
||||
|
||||
patterns = [
|
||||
# Swedish format with space thousand separator: 1 234,56 or 1234,56
|
||||
r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?',
|
||||
# Simple decimal: 350.00 or 350,00
|
||||
r'(\d+[,\.]\d{2})',
|
||||
# Integer amount
|
||||
r'(\d{2,})',
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, text, re.IGNORECASE)
|
||||
if matches:
|
||||
# Take the last match (usually the total amount)
|
||||
amount_str = matches[-1]
|
||||
# Clean up
|
||||
amount_str = amount_str.replace(' ', '').replace('\xa0', '')
|
||||
# Handle comma as decimal separator
|
||||
if ',' in text and '.' not in text:
|
||||
text = text.replace(',', '.')
|
||||
if ',' in amount_str:
|
||||
amount_str = amount_str.replace(',', '.')
|
||||
|
||||
# Try to parse as float
|
||||
try:
|
||||
amount = float(text)
|
||||
amount = float(amount_str)
|
||||
if amount > 0:
|
||||
return f"{amount:.2f}", True, None
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return None, False, f"Cannot parse amount: {text}"
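
A condensed, standalone rendition of the amount cascade above, for illustration only (the real method lives on FieldExtractor and also returns match/error flags):

import re

_AMOUNT_PATTERNS = [
    r'(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?',   # Swedish format: "1 234,56 kr"
    r'(\d+[,\.]\d{2})',                        # plain decimal: "350.00" / "350,00"
    r'(\d{2,})',                               # integer fallback
]

def normalize_amount(text: str) -> str | None:
    for pattern in _AMOUNT_PATTERNS:
        matches = re.findall(pattern, text, re.IGNORECASE)
        if not matches:
            continue
        # Last match is usually the payable total; strip spaces/NBSP, unify the decimal separator.
        amount_str = matches[-1].replace(' ', '').replace('\xa0', '').replace(',', '.')
        try:
            amount = float(amount_str)
        except ValueError:
            continue
        if amount > 0:
            return f"{amount:.2f}"
    return None

# normalize_amount("Att betala: 1 234,56 kr")  -> "1234.56"
# normalize_amount("Totalt 350.00 SEK")        -> "350.00"
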
|
||||
|
||||
def _normalize_date(self, text: str) -> tuple[str | None, bool, str | None]:
|
||||
"""Normalize date."""
|
||||
"""
|
||||
Normalize date from text that may contain surrounding text.
|
||||
|
||||
Handles various date formats found in Swedish invoices:
|
||||
- 2025-08-29 (ISO format)
|
||||
- 2025.08.29 (dot separator)
|
||||
- 29/08/2025 (European format)
|
||||
- 29.08.2025 (European with dots)
|
||||
- 20250829 (compact format)
|
||||
"""
|
||||
from datetime import datetime
|
||||
|
||||
# Common date patterns
|
||||
# Common date patterns - order matters, most specific first
|
||||
patterns = [
|
||||
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m[1]}-{int(m[2]):02d}-{int(m[3]):02d}"),
|
||||
(r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"),
|
||||
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m[3]}-{int(m[2]):02d}-{int(m[1]):02d}"),
|
||||
(r'(\d{4})(\d{2})(\d{2})', lambda m: f"{m[1]}-{m[2]}-{m[3]}"),
|
||||
# ISO format: 2025-08-29
|
||||
(r'(\d{4})-(\d{1,2})-(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
|
||||
# Dot format: 2025.08.29 (common in Swedish)
|
||||
(r'(\d{4})\.(\d{1,2})\.(\d{1,2})', lambda m: f"{m.group(1)}-{int(m.group(2)):02d}-{int(m.group(3)):02d}"),
|
||||
# European slash format: 29/08/2025
|
||||
(r'(\d{1,2})/(\d{1,2})/(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
|
||||
# European dot format: 29.08.2025
|
||||
(r'(\d{1,2})\.(\d{1,2})\.(\d{4})', lambda m: f"{m.group(3)}-{int(m.group(2)):02d}-{int(m.group(1)):02d}"),
|
||||
# Compact format: 20250829
|
||||
(r'(?<!\d)(\d{4})(\d{2})(\d{2})(?!\d)', lambda m: f"{m.group(1)}-{m.group(2)}-{m.group(3)}"),
|
||||
]
|
||||
|
||||
for pattern, formatter in patterns:
|
||||
@@ -323,7 +357,9 @@ class FieldExtractor:
|
||||
try:
|
||||
date_str = formatter(match)
|
||||
# Validate date
|
||||
datetime.strptime(date_str, '%Y-%m-%d')
|
||||
parsed_date = datetime.strptime(date_str, '%Y-%m-%d')
|
||||
# Sanity check: year should be reasonable (2000-2100)
|
||||
if 2000 <= parsed_date.year <= 2100:
|
||||
return date_str, True, None
|
||||
except ValueError:
|
||||
continue
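
For illustration, the expected behaviour of the new pattern list on a few invented fragments:

# "Fakturadatum: 2025-08-29"     -> "2025-08-29"   (ISO)
# "2025.08.29"                   -> "2025-08-29"   (dot separator, year first)
# "Förfallodatum 29/08/2025"     -> "2025-08-29"   (European slash)
# "29.08.2025"                   -> "2025-08-29"   (European dot)
# "OCR 20250829"                 -> "2025-08-29"   (compact; lookarounds require a standalone 8-digit run)
# "19990829"                     -> rejected        (year outside the 2000-2100 sanity check)
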
|
||||
|
||||
@@ -4,9 +4,16 @@ Field Matching Module
|
||||
Matches normalized field values to tokens extracted from documents.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol
|
||||
import re
|
||||
from functools import cached_property
|
||||
|
||||
|
||||
# Pre-compiled regex patterns (module-level for efficiency)
|
||||
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
|
||||
_WHITESPACE_PATTERN = re.compile(r'\s+')
|
||||
_NON_DIGIT_PATTERN = re.compile(r'\D')
|
||||
|
||||
|
||||
class TokenLike(Protocol):
|
||||
@@ -16,6 +23,93 @@ class TokenLike(Protocol):
|
||||
page_no: int
|
||||
|
||||
|
||||
class TokenIndex:
|
||||
"""
|
||||
Spatial index for tokens to enable fast nearby token lookup.
|
||||
|
||||
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
|
||||
"""
|
||||
|
||||
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
|
||||
"""
|
||||
Build spatial index from tokens.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens to index
|
||||
grid_size: Size of grid cells in pixels
|
||||
"""
|
||||
self.tokens = tokens
|
||||
self.grid_size = grid_size
|
||||
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
|
||||
self._token_centers: dict[int, tuple[float, float]] = {}
|
||||
self._token_text_lower: dict[int, str] = {}
|
||||
|
||||
# Build index
|
||||
for i, token in enumerate(tokens):
|
||||
# Cache center coordinates
|
||||
center_x = (token.bbox[0] + token.bbox[2]) / 2
|
||||
center_y = (token.bbox[1] + token.bbox[3]) / 2
|
||||
self._token_centers[id(token)] = (center_x, center_y)
|
||||
|
||||
# Cache lowercased text
|
||||
self._token_text_lower[id(token)] = token.text.lower()
|
||||
|
||||
# Add to grid cell
|
||||
grid_x = int(center_x / grid_size)
|
||||
grid_y = int(center_y / grid_size)
|
||||
key = (grid_x, grid_y)
|
||||
if key not in self._grid:
|
||||
self._grid[key] = []
|
||||
self._grid[key].append(token)
|
||||
|
||||
def get_center(self, token: TokenLike) -> tuple[float, float]:
|
||||
"""Get cached center coordinates for token."""
|
||||
return self._token_centers.get(id(token), (
|
||||
(token.bbox[0] + token.bbox[2]) / 2,
|
||||
(token.bbox[1] + token.bbox[3]) / 2
|
||||
))
|
||||
|
||||
def get_text_lower(self, token: TokenLike) -> str:
|
||||
"""Get cached lowercased text for token."""
|
||||
return self._token_text_lower.get(id(token), token.text.lower())
|
||||
|
||||
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
|
||||
"""
|
||||
Find all tokens within radius of the given token.
|
||||
|
||||
Uses grid-based lookup for O(1) average case instead of O(n).
|
||||
"""
|
||||
center = self.get_center(token)
|
||||
center_x, center_y = center
|
||||
|
||||
# Determine which grid cells to search
|
||||
cells_to_check = int(radius / self.grid_size) + 1
|
||||
grid_x = int(center_x / self.grid_size)
|
||||
grid_y = int(center_y / self.grid_size)
|
||||
|
||||
nearby = []
|
||||
radius_sq = radius * radius
|
||||
|
||||
# Check all nearby grid cells
|
||||
for dx in range(-cells_to_check, cells_to_check + 1):
|
||||
for dy in range(-cells_to_check, cells_to_check + 1):
|
||||
key = (grid_x + dx, grid_y + dy)
|
||||
if key not in self._grid:
|
||||
continue
|
||||
|
||||
for other in self._grid[key]:
|
||||
if other is token:
|
||||
continue
|
||||
|
||||
other_center = self.get_center(other)
|
||||
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
|
||||
|
||||
if dist_sq <= radius_sq:
|
||||
nearby.append(other)
|
||||
|
||||
return nearby
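
A minimal usage sketch; _Token is a hypothetical stand-in satisfying the TokenLike protocol above.

from dataclasses import dataclass

@dataclass
class _Token:
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int = 0

tokens = [
    _Token("Fakturanummer", (100, 200, 220, 215)),
    _Token("2465027205",    (240, 200, 330, 215)),
    _Token("Att betala",    (100, 900, 180, 915)),
]
index = TokenIndex(tokens, grid_size=100.0)
print(index.find_nearby(tokens[1], radius=150.0))
# -> [the "Fakturanummer" token]; the footer token near y=900 sits in a distant grid cell
#    and is never visited, which is the point of the spatial hash.
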
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match:
|
||||
"""Represents a matched field in the document."""
|
||||
@@ -57,18 +151,20 @@ class FieldMatcher:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context_radius: float = 100.0, # pixels
|
||||
context_radius: float = 200.0, # pixels - increased to handle label-value spacing in scanned PDFs
|
||||
min_score_threshold: float = 0.5
|
||||
):
|
||||
"""
|
||||
Initialize the matcher.
|
||||
|
||||
Args:
|
||||
context_radius: Distance to search for context keywords
|
||||
context_radius: Distance to search for context keywords (default 200px to handle
|
||||
typical label-value spacing in scanned invoices at 150 DPI)
|
||||
min_score_threshold: Minimum score to consider a match valid
|
||||
"""
|
||||
self.context_radius = context_radius
|
||||
self.min_score_threshold = min_score_threshold
|
||||
self._token_index: TokenIndex | None = None
|
||||
|
||||
def find_matches(
|
||||
self,
|
||||
@@ -92,6 +188,9 @@ class FieldMatcher:
|
||||
matches = []
|
||||
page_tokens = [t for t in tokens if t.page_no == page_no]
|
||||
|
||||
# Build spatial index for efficient nearby token lookup (O(n) -> O(1))
|
||||
self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
|
||||
|
||||
for value in normalized_values:
|
||||
# Strategy 1: Exact token match
|
||||
exact_matches = self._find_exact_matches(page_tokens, value, field_name)
|
||||
@@ -108,7 +207,7 @@ class FieldMatcher:
|
||||
|
||||
# Strategy 4: Substring match (for values embedded in longer text)
|
||||
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
|
||||
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'):
|
||||
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount'):
|
||||
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
|
||||
matches.extend(substring_matches)
|
||||
|
||||
@@ -124,6 +223,9 @@ class FieldMatcher:
|
||||
matches = self._deduplicate_matches(matches)
|
||||
matches.sort(key=lambda m: m.score, reverse=True)
|
||||
|
||||
# Clear token index to free memory
|
||||
self._token_index = None
|
||||
|
||||
return [m for m in matches if m.score >= self.min_score_threshold]
|
||||
|
||||
def _find_exact_matches(
|
||||
@@ -134,6 +236,8 @@ class FieldMatcher:
|
||||
) -> list[Match]:
|
||||
"""Find tokens that exactly match the value."""
|
||||
matches = []
|
||||
value_lower = value.lower()
|
||||
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro') else None
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
@@ -141,13 +245,12 @@ class FieldMatcher:
|
||||
# Exact match
|
||||
if token_text == value:
|
||||
score = 1.0
|
||||
# Case-insensitive match
|
||||
elif token_text.lower() == value.lower():
|
||||
# Case-insensitive match (use cached lowercase from index)
|
||||
elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
|
||||
score = 0.95
|
||||
# Digits-only match for numeric fields
|
||||
elif field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'):
|
||||
token_digits = re.sub(r'\D', '', token_text)
|
||||
value_digits = re.sub(r'\D', '', value)
|
||||
elif value_digits is not None:
|
||||
token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
|
||||
if token_digits and token_digits == value_digits:
|
||||
score = 0.9
|
||||
else:
|
||||
@@ -181,7 +284,7 @@ class FieldMatcher:
|
||||
) -> list[Match]:
|
||||
"""Find value by concatenating adjacent tokens."""
|
||||
matches = []
|
||||
value_clean = re.sub(r'\s+', '', value)
|
||||
value_clean = _WHITESPACE_PATTERN.sub('', value)
|
||||
|
||||
# Sort tokens by position (top-to-bottom, left-to-right)
|
||||
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
|
||||
@@ -213,7 +316,7 @@ class FieldMatcher:
|
||||
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
|
||||
|
||||
# Check for match
|
||||
concat_clean = re.sub(r'\s+', '', concat_text)
|
||||
concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
|
||||
if concat_clean == value_clean:
|
||||
context_keywords, context_boost = self._find_context_keywords(
|
||||
tokens, start_token, field_name
|
||||
@@ -252,7 +355,7 @@ class FieldMatcher:
|
||||
matches = []
|
||||
|
||||
# Supported fields for substring matching
|
||||
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro')
|
||||
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount')
|
||||
if field_name not in supported_fields:
|
||||
return matches
|
||||
|
||||
@@ -390,13 +493,12 @@ class FieldMatcher:
|
||||
|
||||
# Find all date-like tokens in the document
|
||||
date_candidates = []
|
||||
date_pattern = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
|
||||
# Search for date pattern in token
|
||||
for match in date_pattern.finditer(token_text):
|
||||
# Search for date pattern in token (use pre-compiled pattern)
|
||||
for match in _DATE_PATTERN.finditer(token_text):
|
||||
try:
|
||||
found_date = datetime(
|
||||
int(match.group(1)),
|
||||
@@ -491,10 +593,28 @@ class FieldMatcher:
|
||||
target_token: TokenLike,
|
||||
field_name: str
|
||||
) -> tuple[list[str], float]:
|
||||
"""Find context keywords near the target token."""
|
||||
"""
|
||||
Find context keywords near the target token.
|
||||
|
||||
Uses spatial index for O(1) average lookup instead of O(n) scan.
|
||||
"""
|
||||
keywords = CONTEXT_KEYWORDS.get(field_name, [])
|
||||
if not keywords:
|
||||
return [], 0.0
|
||||
|
||||
found_keywords = []
|
||||
|
||||
# Use spatial index for efficient nearby token lookup
|
||||
if self._token_index:
|
||||
nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
|
||||
for token in nearby_tokens:
|
||||
# Use cached lowercase text
|
||||
token_lower = self._token_index.get_text_lower(token)
|
||||
for keyword in keywords:
|
||||
if keyword in token_lower:
|
||||
found_keywords.append(keyword)
|
||||
else:
|
||||
# Fallback to O(n) scan if no index available
|
||||
target_center = (
|
||||
(target_token.bbox[0] + target_token.bbox[2]) / 2,
|
||||
(target_token.bbox[1] + target_token.bbox[3]) / 2
|
||||
@@ -509,7 +629,6 @@ class FieldMatcher:
|
||||
(token.bbox[1] + token.bbox[3]) / 2
|
||||
)
|
||||
|
||||
# Calculate distance
|
||||
distance = (
|
||||
(target_center[0] - token_center[0]) ** 2 +
|
||||
(target_center[1] - token_center[1]) ** 2
|
||||
@@ -522,7 +641,8 @@ class FieldMatcher:
|
||||
found_keywords.append(keyword)
|
||||
|
||||
# Calculate boost based on keywords found
|
||||
boost = min(0.15, len(found_keywords) * 0.05)
|
||||
# Increased boost to better differentiate matches with/without context
|
||||
boost = min(0.25, len(found_keywords) * 0.10)
|
||||
return found_keywords, boost
|
||||
|
||||
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
|
||||
@@ -548,23 +668,62 @@ class FieldMatcher:
|
||||
return None
|
||||
|
||||
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
|
||||
"""Remove duplicate matches based on bbox overlap."""
|
||||
"""
|
||||
Remove duplicate matches based on bbox overlap.
|
||||
|
||||
Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
|
||||
"""
|
||||
if not matches:
|
||||
return []
|
||||
|
||||
# Sort by score descending
|
||||
matches.sort(key=lambda m: m.score, reverse=True)
|
||||
# Sort by: 1) score descending, 2) prefer matches with context keywords,
|
||||
# 3) prefer upper positions (smaller y) for same-score matches
|
||||
# This helps select the "main" occurrence in invoice body rather than footer
|
||||
matches.sort(key=lambda m: (
|
||||
-m.score,
|
||||
-len(m.context_keywords), # More keywords = better
|
||||
m.bbox[1] # Smaller y (upper position) = better
|
||||
))
|
||||
|
||||
# Use spatial grid for efficient overlap checking
|
||||
# Grid cell size based on typical bbox size
|
||||
grid_size = 50.0 # pixels
|
||||
grid: dict[tuple[int, int], list[Match]] = {}
|
||||
unique = []
|
||||
|
||||
for match in matches:
|
||||
bbox = match.bbox
|
||||
# Calculate grid cells this bbox touches
|
||||
min_gx = int(bbox[0] / grid_size)
|
||||
min_gy = int(bbox[1] / grid_size)
|
||||
max_gx = int(bbox[2] / grid_size)
|
||||
max_gy = int(bbox[3] / grid_size)
|
||||
|
||||
# Check for overlap only with matches in nearby grid cells
|
||||
is_duplicate = False
|
||||
for existing in unique:
|
||||
if self._bbox_overlap(match.bbox, existing.bbox) > 0.7:
|
||||
cells_to_check = set()
|
||||
for gx in range(min_gx - 1, max_gx + 2):
|
||||
for gy in range(min_gy - 1, max_gy + 2):
|
||||
cells_to_check.add((gx, gy))
|
||||
|
||||
for cell in cells_to_check:
|
||||
if cell in grid:
|
||||
for existing in grid[cell]:
|
||||
if self._bbox_overlap(bbox, existing.bbox) > 0.7:
|
||||
is_duplicate = True
|
||||
break
|
||||
if is_duplicate:
|
||||
break
|
||||
|
||||
if not is_duplicate:
|
||||
unique.append(match)
|
||||
# Add to all grid cells this bbox touches
|
||||
for gx in range(min_gx, max_gx + 1):
|
||||
for gy in range(min_gy, max_gy + 1):
|
||||
key = (gx, gy)
|
||||
if key not in grid:
|
||||
grid[key] = []
|
||||
grid[key].append(match)
|
||||
|
||||
return unique
|
||||
|
||||
@@ -582,9 +741,9 @@ class FieldMatcher:
|
||||
if x2 <= x1 or y2 <= y1:
|
||||
return 0.0
|
||||
|
||||
intersection = (x2 - x1) * (y2 - y1)
|
||||
area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
|
||||
area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
|
||||
intersection = float(x2 - x1) * float(y2 - y1)
|
||||
area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
|
||||
area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
|
||||
union = area1 + area2 - intersection
|
||||
|
||||
return intersection / union if union > 0 else 0.0
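# The spatial index used above as self._token_index (find_nearby and
# get_text_lower) is not shown in this diff. The sketch below is a
# hypothetical, minimal grid-backed index illustrating how such nearby-token
# lookups can run in O(1) average time; the class name, the 50-pixel cell
# size and the id()-keyed text cache are illustrative assumptions only, not
# the project's actual implementation.
from collections import defaultdict


class SpatialTokenIndexSketch:
    def __init__(self, tokens, cell_size: float = 50.0):
        self.cell_size = cell_size
        self._grid = defaultdict(list)   # (gx, gy) -> tokens whose center falls in that cell
        self._lower = {}                 # id(token) -> cached lowercase text
        for tok in tokens:
            cx = (tok.bbox[0] + tok.bbox[2]) / 2
            cy = (tok.bbox[1] + tok.bbox[3]) / 2
            self._grid[(int(cx // cell_size), int(cy // cell_size))].append(tok)
            self._lower[id(tok)] = tok.text.lower()

    def get_text_lower(self, token) -> str:
        return self._lower[id(token)]

    def find_nearby(self, token, radius: float):
        # Only visit the grid cells that can contain token centers within `radius`
        cx = (token.bbox[0] + token.bbox[2]) / 2
        cy = (token.bbox[1] + token.bbox[3]) / 2
        span = int(radius // self.cell_size) + 1
        gx, gy = int(cx // self.cell_size), int(cy // self.cell_size)
        for dx in range(-span, span + 1):
            for dy in range(-span, span + 1):
                for other in self._grid.get((gx + dx, gy + dy), ()):
                    ox = (other.bbox[0] + other.bbox[2]) / 2
                    oy = (other.bbox[1] + other.bbox[3]) / 2
                    if (ox - cx) ** 2 + (oy - cy) ** 2 <= radius ** 2:
                        yield other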
|
||||
|
||||
@@ -173,12 +173,29 @@ class FieldNormalizer:
|
||||
|
||||
# Integer if no decimals
|
||||
if num == int(num):
|
||||
variants.append(str(int(num)))
|
||||
variants.append(f"{int(num)},00")
|
||||
variants.append(f"{int(num)}.00")
|
||||
int_val = int(num)
|
||||
variants.append(str(int_val))
|
||||
variants.append(f"{int_val},00")
|
||||
variants.append(f"{int_val}.00")
|
||||
|
||||
# European format with dot as thousand separator (e.g., 20.485,00)
|
||||
if int_val >= 1000:
|
||||
# Format: XX.XXX,XX
|
||||
formatted = f"{int_val:,}".replace(',', '.')
|
||||
variants.append(formatted) # 20.485
|
||||
variants.append(f"{formatted},00") # 20.485,00
|
||||
else:
|
||||
variants.append(f"{num:.2f}")
|
||||
variants.append(f"{num:.2f}".replace('.', ','))
|
||||
|
||||
# European format with dot as thousand separator
|
||||
if num >= 1000:
|
||||
# Split integer and decimal parts
|
||||
int_part = int(num)
|
||||
dec_part = num - int_part
|
||||
formatted_int = f"{int_part:,}".replace(',', '.')
|
||||
formatted = f"{formatted_int},{dec_part:.2f}"[2:] # Remove "0."
|
||||
variants.append(f"{formatted_int},{int(dec_part * 100):02d}") # 20.485,00
|
||||
except ValueError:
|
||||
pass
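# Worked illustration (not part of the diff): the integer branch above,
# extracted as a standalone sketch of the logic shown, so the new
# dot-as-thousand-separator variants are easy to see. The function name is
# illustrative only.
def _integer_amount_variants_sketch(num: float) -> list[str]:
    variants: list[str] = []
    int_val = int(num)
    variants.append(str(int_val))           # "20485"
    variants.append(f"{int_val},00")        # "20485,00"
    variants.append(f"{int_val}.00")        # "20485.00"
    if int_val >= 1000:
        formatted = f"{int_val:,}".replace(',', '.')
        variants.append(formatted)          # "20.485"
        variants.append(f"{formatted},00")  # "20.485,00"
    return variants


assert _integer_amount_variants_sketch(20485.0) == [
    "20485", "20485,00", "20485.00", "20.485", "20.485,00"
]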
|
||||
|
||||
@@ -247,9 +264,35 @@ class FieldNormalizer:
|
||||
iso = parsed_date.strftime('%Y-%m-%d')
|
||||
eu_slash = parsed_date.strftime('%d/%m/%Y')
|
||||
eu_dot = parsed_date.strftime('%d.%m.%Y')
|
||||
compact = parsed_date.strftime('%Y%m%d')
|
||||
compact = parsed_date.strftime('%Y%m%d') # YYYYMMDD
|
||||
compact_short = parsed_date.strftime('%y%m%d') # YYMMDD (e.g., 260108)
|
||||
|
||||
variants.extend([iso, eu_slash, eu_dot, compact])
|
||||
# Short year with dot separator (e.g., 02.01.26)
|
||||
eu_dot_short = parsed_date.strftime('%d.%m.%y')
|
||||
|
||||
# Spaced formats (e.g., "2026 01 12", "26 01 12")
|
||||
spaced_full = parsed_date.strftime('%Y %m %d')
|
||||
spaced_short = parsed_date.strftime('%y %m %d')
|
||||
|
||||
# Swedish month name formats (e.g., "9 januari 2026", "9 jan 2026")
|
||||
swedish_months_full = [
|
||||
'januari', 'februari', 'mars', 'april', 'maj', 'juni',
|
||||
'juli', 'augusti', 'september', 'oktober', 'november', 'december'
|
||||
]
|
||||
swedish_months_abbrev = [
|
||||
'jan', 'feb', 'mar', 'apr', 'maj', 'jun',
|
||||
'jul', 'aug', 'sep', 'okt', 'nov', 'dec'
|
||||
]
|
||||
month_full = swedish_months_full[parsed_date.month - 1]
|
||||
month_abbrev = swedish_months_abbrev[parsed_date.month - 1]
|
||||
swedish_format_full = f"{parsed_date.day} {month_full} {parsed_date.year}"
|
||||
swedish_format_abbrev = f"{parsed_date.day} {month_abbrev} {parsed_date.year}"
|
||||
|
||||
variants.extend([
|
||||
iso, eu_slash, eu_dot, compact, compact_short,
|
||||
eu_dot_short, spaced_full, spaced_short,
|
||||
swedish_format_full, swedish_format_abbrev
|
||||
])
|
||||
|
||||
return list(set(v for v in variants if v))
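# Worked illustration (not part of the diff): for a parsed date of 2026-01-09
# the extended variant list above includes, among others:
#   "2026-01-09", "09/01/2026", "09.01.2026", "20260109", "260109",
#   "09.01.26", "2026 01 09", "26 01 09", "9 januari 2026", "9 jan 2026"
# A quick sanity check of the newly added strftime formats:
from datetime import datetime

_d = datetime(2026, 1, 9)
assert _d.strftime('%y%m%d') == '260109'      # compact_short (YYMMDD)
assert _d.strftime('%d.%m.%y') == '09.01.26'  # eu_dot_short
assert _d.strftime('%y %m %d') == '26 01 09'  # spaced_short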
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .paddle_ocr import OCREngine, extract_ocr_tokens
|
||||
from .paddle_ocr import OCREngine, OCRResult, OCRToken, extract_ocr_tokens
|
||||
|
||||
__all__ = ['OCREngine', 'extract_ocr_tokens']
|
||||
__all__ = ['OCREngine', 'OCRResult', 'OCRToken', 'extract_ocr_tokens']
|
||||
|
||||
@@ -4,11 +4,18 @@ OCR Extraction Module using PaddleOCR
|
||||
Extracts text tokens with bounding boxes from scanned PDFs.
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
import numpy as np
|
||||
|
||||
# Suppress PaddlePaddle reinitialization warnings
|
||||
os.environ.setdefault('GLOG_minloglevel', '2')
|
||||
warnings.filterwarnings('ignore', message='.*PDX has already been initialized.*')
|
||||
warnings.filterwarnings('ignore', message='.*reinitialization.*')
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRToken:
|
||||
@@ -39,13 +46,19 @@ class OCRToken:
|
||||
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OCRResult:
|
||||
"""Result from OCR extraction including tokens and preprocessed image."""
|
||||
tokens: list[OCRToken]
|
||||
output_img: np.ndarray | None = None # Preprocessed image from PaddleOCR
|
||||
|
||||
|
||||
class OCREngine:
|
||||
"""PaddleOCR wrapper for text extraction."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lang: str = "en",
|
||||
use_gpu: bool = True, # Default to GPU for better performance
|
||||
det_model_dir: str | None = None,
|
||||
rec_model_dir: str | None = None
|
||||
):
|
||||
@@ -54,17 +67,21 @@ class OCREngine:
|
||||
|
||||
Args:
|
||||
lang: Language code ('en', 'sv', 'ch', etc.)
|
||||
use_gpu: Whether to use GPU acceleration (default: True)
|
||||
det_model_dir: Custom detection model directory
|
||||
rec_model_dir: Custom recognition model directory
|
||||
|
||||
Note:
|
||||
PaddleOCR 3.x automatically uses GPU if available via PaddlePaddle.
|
||||
Use `paddle.set_device('gpu')` before initialization to force GPU.
|
||||
"""
|
||||
# Suppress warnings during import and initialization
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore')
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
# PaddleOCR init with GPU support
|
||||
# PaddleOCR 3.x init (use_gpu removed, device controlled by paddle.set_device)
|
||||
init_params = {
|
||||
'lang': lang,
|
||||
'use_gpu': use_gpu,
|
||||
'show_log': False, # Reduce log noise
|
||||
}
|
||||
if det_model_dir:
|
||||
init_params['text_detection_model_dir'] = det_model_dir
|
||||
@@ -72,12 +89,13 @@ class OCREngine:
|
||||
init_params['text_recognition_model_dir'] = rec_model_dir
|
||||
|
||||
self.ocr = PaddleOCR(**init_params)
|
||||
self.use_gpu = use_gpu
|
||||
|
||||
def extract_from_image(
|
||||
self,
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int = 0
|
||||
page_no: int = 0,
|
||||
max_size: int = 2000,
|
||||
scale_to_pdf_points: float | None = None
|
||||
) -> list[OCRToken]:
|
||||
"""
|
||||
Extract text tokens from an image.
|
||||
@@ -85,17 +103,73 @@ class OCREngine:
|
||||
Args:
|
||||
image: Image path or numpy array
|
||||
page_no: Page number for reference
|
||||
max_size: Maximum image dimension. Larger images will be scaled down
|
||||
to avoid OCR issues with PaddleOCR on large images.
|
||||
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
|
||||
to convert from pixel to PDF point coordinates.
|
||||
Use (72 / dpi) for images rendered at a specific DPI.
|
||||
|
||||
Returns:
|
||||
List of OCRToken objects
|
||||
List of OCRToken objects with bbox in pixel coords (or PDF points if scale_to_pdf_points is set)
|
||||
"""
|
||||
result = self.extract_with_image(image, page_no, max_size, scale_to_pdf_points)
|
||||
return result.tokens
|
||||
|
||||
def extract_with_image(
|
||||
self,
|
||||
image: str | Path | np.ndarray,
|
||||
page_no: int = 0,
|
||||
max_size: int = 2000,
|
||||
scale_to_pdf_points: float | None = None
|
||||
) -> OCRResult:
|
||||
"""
|
||||
Extract text tokens from an image and return the preprocessed image.
|
||||
|
||||
PaddleOCR applies document preprocessing (unwarping, rotation, enhancement)
|
||||
and returns coordinates relative to the preprocessed image (output_img).
|
||||
This method returns both tokens and output_img so the caller can save
|
||||
the correct image that matches the coordinates.
|
||||
|
||||
Args:
|
||||
image: Image path or numpy array
|
||||
page_no: Page number for reference
|
||||
max_size: Maximum image dimension. Larger images will be scaled down
|
||||
to avoid OCR issues with PaddleOCR on large images.
|
||||
scale_to_pdf_points: If provided, scale bbox coordinates by this factor
|
||||
to convert from pixel to PDF point coordinates.
|
||||
Use (72 / dpi) for images rendered at a specific DPI.
|
||||
|
||||
Returns:
|
||||
OCRResult with tokens and output_img (preprocessed image from PaddleOCR)
|
||||
"""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
# Load image if path
|
||||
if isinstance(image, (str, Path)):
|
||||
image = str(image)
|
||||
img = PILImage.open(str(image))
|
||||
img_array = np.array(img)
|
||||
else:
|
||||
img_array = image
|
||||
|
||||
# Check if image needs scaling for OCR
|
||||
h, w = img_array.shape[:2]
|
||||
ocr_scale_factor = 1.0
|
||||
|
||||
if max(h, w) > max_size:
|
||||
ocr_scale_factor = max_size / max(h, w)
|
||||
new_w = int(w * ocr_scale_factor)
|
||||
new_h = int(h * ocr_scale_factor)
|
||||
# Resize image for OCR
|
||||
img = PILImage.fromarray(img_array)
|
||||
img = img.resize((new_w, new_h), PILImage.Resampling.LANCZOS)
|
||||
img_array = np.array(img)
|
||||
|
||||
# PaddleOCR 3.x uses predict() method instead of ocr()
|
||||
result = self.ocr.predict(image)
|
||||
result = self.ocr.predict(img_array)
|
||||
|
||||
tokens = []
|
||||
output_img = None
|
||||
|
||||
if result:
|
||||
for item in result:
|
||||
# PaddleOCR 3.x returns list of dicts with 'rec_texts', 'rec_scores', 'dt_polys'
|
||||
@@ -104,16 +178,30 @@ class OCREngine:
|
||||
rec_scores = item.get('rec_scores', [])
|
||||
dt_polys = item.get('dt_polys', [])
|
||||
|
||||
for i, (text, score, poly) in enumerate(zip(rec_texts, rec_scores, dt_polys)):
|
||||
# Get output_img from doc_preprocessor_res
|
||||
# This is the preprocessed image that coordinates are relative to
|
||||
doc_preproc = item.get('doc_preprocessor_res', {})
|
||||
if isinstance(doc_preproc, dict):
|
||||
output_img = doc_preproc.get('output_img')
|
||||
|
||||
# Coordinates are relative to output_img (preprocessed image)
|
||||
# No rotation compensation needed - just use coordinates directly
|
||||
for text, score, poly in zip(rec_texts, rec_scores, dt_polys):
|
||||
# poly is [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
|
||||
x_coords = [p[0] for p in poly]
|
||||
y_coords = [p[1] for p in poly]
|
||||
x_coords = [float(p[0]) for p in poly]
|
||||
y_coords = [float(p[1]) for p in poly]
|
||||
|
||||
# Apply PDF points scale if requested
|
||||
if scale_to_pdf_points is not None:
|
||||
final_scale = scale_to_pdf_points
|
||||
else:
|
||||
final_scale = 1.0
|
||||
|
||||
bbox = (
|
||||
min(x_coords),
|
||||
min(y_coords),
|
||||
max(x_coords),
|
||||
max(y_coords)
|
||||
min(x_coords) * final_scale,
|
||||
min(y_coords) * final_scale,
|
||||
max(x_coords) * final_scale,
|
||||
max(y_coords) * final_scale
|
||||
)
|
||||
|
||||
tokens.append(OCRToken(
|
||||
@@ -129,11 +217,17 @@ class OCREngine:
|
||||
x_coords = [p[0] for p in bbox_points]
|
||||
y_coords = [p[1] for p in bbox_points]
|
||||
|
||||
# Apply PDF points scale if requested
|
||||
if scale_to_pdf_points is not None:
|
||||
final_scale = scale_to_pdf_points
|
||||
else:
|
||||
final_scale = 1.0
|
||||
|
||||
bbox = (
|
||||
min(x_coords),
|
||||
min(y_coords),
|
||||
max(x_coords),
|
||||
max(y_coords)
|
||||
min(x_coords) * final_scale,
|
||||
min(y_coords) * final_scale,
|
||||
max(x_coords) * final_scale,
|
||||
max(y_coords) * final_scale
|
||||
)
|
||||
|
||||
tokens.append(OCRToken(
|
||||
@@ -143,7 +237,11 @@ class OCREngine:
|
||||
page_no=page_no
|
||||
))
|
||||
|
||||
return tokens
|
||||
# If no output_img was found, use the original input array
|
||||
if output_img is None:
|
||||
output_img = img_array
|
||||
|
||||
return OCRResult(tokens=tokens, output_img=output_img)
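# Hedged usage sketch (not part of the diff): how a caller might use
# extract_with_image so that the saved page image stays consistent with the
# token coordinates. The file name "page_000.png" and the 150 DPI value are
# illustrative assumptions, not values taken from this project.
def _example_extract_with_image() -> None:
    from PIL import Image as PILImage

    engine = OCREngine(lang="en")
    result = engine.extract_with_image(
        "page_000.png",
        page_no=0,
        scale_to_pdf_points=72 / 150,  # image was rendered at 150 DPI
    )
    if result.output_img is not None:
        # Overwrite the rendered image with PaddleOCR's preprocessed output
        # so pixels and bounding boxes line up.
        PILImage.fromarray(result.output_img).save("page_000.png")
    for tok in result.tokens:
        print(tok.text, tok.bbox)  # bbox now in PDF points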
|
||||
|
||||
def extract_from_pdf(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,12 @@
|
||||
from .detector import is_text_pdf, get_pdf_type
|
||||
from .renderer import render_pdf_to_images
|
||||
from .extractor import extract_text_tokens
|
||||
from .extractor import extract_text_tokens, PDFDocument, Token
|
||||
|
||||
__all__ = ['is_text_pdf', 'get_pdf_type', 'render_pdf_to_images', 'extract_text_tokens']
|
||||
__all__ = [
|
||||
'is_text_pdf',
|
||||
'get_pdf_type',
|
||||
'render_pdf_to_images',
|
||||
'extract_text_tokens',
|
||||
'PDFDocument',
|
||||
'Token',
|
||||
]
|
||||
|
||||
@@ -6,7 +6,7 @@ Extracts text tokens with bounding boxes from text-layer PDFs.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
from typing import Generator, Optional
|
||||
import fitz # PyMuPDF
|
||||
|
||||
|
||||
@@ -46,6 +46,134 @@ class Token:
|
||||
return ((self.x0 + self.x1) / 2, (self.y0 + self.y1) / 2)
|
||||
|
||||
|
||||
class PDFDocument:
|
||||
"""
|
||||
Context manager for efficient PDF document handling.
|
||||
|
||||
Caches the open document handle to avoid repeated open/close cycles.
|
||||
Use this when you need to perform multiple operations on the same PDF.
|
||||
"""
|
||||
|
||||
def __init__(self, pdf_path: str | Path):
|
||||
self.pdf_path = Path(pdf_path)
|
||||
self._doc: Optional[fitz.Document] = None
|
||||
self._dimensions_cache: dict[int, tuple[float, float]] = {}
|
||||
|
||||
def __enter__(self) -> 'PDFDocument':
|
||||
self._doc = fitz.open(self.pdf_path)
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._doc:
|
||||
self._doc.close()
|
||||
self._doc = None
|
||||
|
||||
@property
|
||||
def doc(self) -> fitz.Document:
|
||||
if self._doc is None:
|
||||
raise RuntimeError("PDFDocument must be used within a context manager")
|
||||
return self._doc
|
||||
|
||||
@property
|
||||
def page_count(self) -> int:
|
||||
return len(self.doc)
|
||||
|
||||
def is_text_pdf(self, min_chars: int = 30) -> bool:
|
||||
"""Check if PDF has extractable text layer."""
|
||||
if self.page_count == 0:
|
||||
return False
|
||||
first_page = self.doc[0]
|
||||
text = first_page.get_text()
|
||||
return len(text.strip()) > min_chars
|
||||
|
||||
def get_page_dimensions(self, page_no: int = 0) -> tuple[float, float]:
|
||||
"""Get page dimensions in points (cached)."""
|
||||
if page_no not in self._dimensions_cache:
|
||||
page = self.doc[page_no]
|
||||
rect = page.rect
|
||||
self._dimensions_cache[page_no] = (rect.width, rect.height)
|
||||
return self._dimensions_cache[page_no]
|
||||
|
||||
def get_render_dimensions(self, page_no: int = 0, dpi: int = 300) -> tuple[int, int]:
|
||||
"""Get rendered image dimensions in pixels."""
|
||||
width, height = self.get_page_dimensions(page_no)
|
||||
zoom = dpi / 72
|
||||
return int(width * zoom), int(height * zoom)
|
||||
|
||||
def extract_text_tokens(self, page_no: int) -> Generator[Token, None, None]:
|
||||
"""Extract text tokens from a specific page."""
|
||||
page = self.doc[page_no]
|
||||
|
||||
text_dict = page.get_text("dict")
|
||||
tokens_found = False
|
||||
|
||||
for block in text_dict.get("blocks", []):
|
||||
if block.get("type") != 0:
|
||||
continue
|
||||
|
||||
for line in block.get("lines", []):
|
||||
for span in line.get("spans", []):
|
||||
text = span.get("text", "").strip()
|
||||
if not text:
|
||||
continue
|
||||
|
||||
bbox = span.get("bbox")
|
||||
if bbox and all(abs(b) < 1e9 for b in bbox):
|
||||
tokens_found = True
|
||||
yield Token(
|
||||
text=text,
|
||||
bbox=tuple(bbox),
|
||||
page_no=page_no
|
||||
)
|
||||
|
||||
# Fallback: if dict mode failed, use words mode
|
||||
if not tokens_found:
|
||||
words = page.get_text("words")
|
||||
for word_info in words:
|
||||
x0, y0, x1, y1, text, *_ = word_info
|
||||
text = text.strip()
|
||||
if text:
|
||||
yield Token(
|
||||
text=text,
|
||||
bbox=(x0, y0, x1, y1),
|
||||
page_no=page_no
|
||||
)
|
||||
|
||||
def render_page(self, page_no: int, output_path: Path, dpi: int = 300) -> Path:
|
||||
"""Render a page to an image file."""
|
||||
zoom = dpi / 72
|
||||
matrix = fitz.Matrix(zoom, zoom)
|
||||
|
||||
page = self.doc[page_no]
|
||||
pix = page.get_pixmap(matrix=matrix)
|
||||
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
pix.save(str(output_path))
|
||||
|
||||
return output_path
|
||||
|
||||
def render_all_pages(
|
||||
self,
|
||||
output_dir: Path,
|
||||
dpi: int = 300
|
||||
) -> Generator[tuple[int, Path], None, None]:
|
||||
"""Render all pages to images."""
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
pdf_name = self.pdf_path.stem
|
||||
|
||||
zoom = dpi / 72
|
||||
matrix = fitz.Matrix(zoom, zoom)
|
||||
|
||||
for page_no in range(self.page_count):
|
||||
page = self.doc[page_no]
|
||||
pix = page.get_pixmap(matrix=matrix)
|
||||
|
||||
image_path = output_dir / f"{pdf_name}_page_{page_no:03d}.png"
|
||||
pix.save(str(image_path))
|
||||
|
||||
yield page_no, image_path
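# Hedged usage sketch (not part of the diff): typical use of the PDFDocument
# context manager defined above. The file name "invoice.pdf", the "out"
# directory and the 150 DPI value are illustrative assumptions.
def _example_pdf_document_usage() -> None:
    from pathlib import Path

    with PDFDocument("invoice.pdf") as pdf_doc:
        if not pdf_doc.is_text_pdf():
            return  # a scanned PDF would go down the OCR path instead
        for page_no in range(pdf_doc.page_count):
            tokens = list(pdf_doc.extract_text_tokens(page_no))
            width, height = pdf_doc.get_page_dimensions(page_no)
            print(page_no, len(tokens), width, height)
        # Render all pages once, reusing the cached document handle
        for page_no, image_path in pdf_doc.render_all_pages(Path("out"), dpi=150):
            print(page_no, image_path)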
|
||||
|
||||
|
||||
def extract_text_tokens(
|
||||
pdf_path: str | Path,
|
||||
page_no: int | None = None
|
||||
@@ -70,6 +198,7 @@ def extract_text_tokens(
|
||||
# Get text with position info using "dict" mode
|
||||
text_dict = page.get_text("dict")
|
||||
|
||||
tokens_found = False
|
||||
for block in text_dict.get("blocks", []):
|
||||
if block.get("type") != 0: # Skip non-text blocks
|
||||
continue
|
||||
@@ -81,13 +210,28 @@ def extract_text_tokens(
|
||||
continue
|
||||
|
||||
bbox = span.get("bbox")
|
||||
if bbox:
|
||||
# Check for corrupted bbox (overflow values)
|
||||
if bbox and all(abs(b) < 1e9 for b in bbox):
|
||||
tokens_found = True
|
||||
yield Token(
|
||||
text=text,
|
||||
bbox=tuple(bbox),
|
||||
page_no=pg_no
|
||||
)
|
||||
|
||||
# Fallback: if dict mode failed, use words mode
|
||||
if not tokens_found:
|
||||
words = page.get_text("words")
|
||||
for word_info in words:
|
||||
x0, y0, x1, y1, text, *_ = word_info
|
||||
text = text.strip()
|
||||
if text:
|
||||
yield Token(
|
||||
text=text,
|
||||
bbox=(x0, y0, x1, y1),
|
||||
page_no=pg_no
|
||||
)
|
||||
|
||||
doc.close()
22
src/processing/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Processing module for multi-pool parallel processing.
|
||||
|
||||
This module provides a robust dual-pool architecture for processing
|
||||
documents with both CPU-bound and GPU-bound tasks.
|
||||
"""
|
||||
|
||||
from src.processing.worker_pool import WorkerPool, TaskResult
|
||||
from src.processing.cpu_pool import CPUWorkerPool
|
||||
from src.processing.gpu_pool import GPUWorkerPool
|
||||
from src.processing.task_dispatcher import TaskDispatcher, TaskType
|
||||
from src.processing.dual_pool_coordinator import DualPoolCoordinator
|
||||
|
||||
__all__ = [
|
||||
"WorkerPool",
|
||||
"TaskResult",
|
||||
"CPUWorkerPool",
|
||||
"GPUWorkerPool",
|
||||
"TaskDispatcher",
|
||||
"TaskType",
|
||||
"DualPoolCoordinator",
|
||||
]
|
||||
351
src/processing/autolabel_tasks.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Task functions for autolabel processing in multi-pool architecture.
|
||||
|
||||
Provides CPU and GPU task functions that can be called from worker pools.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
# Global OCR instance (initialized once per GPU worker process)
|
||||
_ocr_engine: Optional[Any] = None
|
||||
|
||||
|
||||
def init_cpu_worker() -> None:
|
||||
"""
|
||||
Initialize CPU worker process.
|
||||
|
||||
Disables GPU access and suppresses unnecessary warnings.
|
||||
"""
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
||||
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
||||
|
||||
|
||||
def init_gpu_worker(gpu_id: int = 0, gpu_mem: int = 4000) -> None:
|
||||
"""
|
||||
Initialize GPU worker process with PaddleOCR.
|
||||
|
||||
Args:
|
||||
gpu_id: GPU device ID.
|
||||
gpu_mem: Maximum GPU memory in MB.
|
||||
"""
|
||||
global _ocr_engine
|
||||
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
os.environ["GLOG_minloglevel"] = "2"
|
||||
|
||||
# Suppress PaddleX warnings
|
||||
warnings.filterwarnings("ignore", message=".*PDX has already been initialized.*")
|
||||
warnings.filterwarnings("ignore", message=".*reinitialization.*")
|
||||
|
||||
# Lazy initialization - OCR will be created on first use
|
||||
_ocr_engine = None
|
||||
|
||||
|
||||
def _get_ocr_engine():
|
||||
"""Get or create OCR engine for current GPU worker."""
|
||||
global _ocr_engine
|
||||
|
||||
if _ocr_engine is None:
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
from src.ocr import OCREngine
|
||||
_ocr_engine = OCREngine()
|
||||
|
||||
return _ocr_engine
|
||||
|
||||
|
||||
def _save_output_img(output_img, image_path: Path) -> None:
|
||||
"""Save OCR preprocessed image to replace rendered image."""
|
||||
from PIL import Image as PILImage
|
||||
|
||||
if output_img is not None:
|
||||
img = PILImage.fromarray(output_img)
|
||||
img.save(str(image_path))
|
||||
|
||||
|
||||
def process_text_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a text PDF (CPU task - no OCR needed).
|
||||
|
||||
Args:
|
||||
task_data: Dictionary with keys:
|
||||
- row_dict: Document fields from CSV
|
||||
- pdf_path: Path to PDF file
|
||||
- output_dir: Output directory
|
||||
- dpi: Rendering DPI
|
||||
- min_confidence: Minimum match confidence
|
||||
|
||||
Returns:
|
||||
Result dictionary with success status, annotations, and report.
|
||||
"""
|
||||
from src.data import AutoLabelReport, FieldMatchResult
|
||||
from src.pdf import PDFDocument
|
||||
from src.matcher import FieldMatcher
|
||||
from src.normalize import normalize_field
|
||||
from src.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
|
||||
|
||||
row_dict = task_data["row_dict"]
|
||||
pdf_path = Path(task_data["pdf_path"])
|
||||
output_dir = Path(task_data["output_dir"])
|
||||
dpi = task_data.get("dpi", 150)
|
||||
min_confidence = task_data.get("min_confidence", 0.5)
|
||||
|
||||
start_time = time.time()
|
||||
doc_id = row_dict["DocumentId"]
|
||||
|
||||
report = AutoLabelReport(document_id=doc_id)
|
||||
report.pdf_path = str(pdf_path)
|
||||
report.pdf_type = "text"
|
||||
|
||||
result = {
|
||||
"doc_id": doc_id,
|
||||
"success": False,
|
||||
"pages": [],
|
||||
"report": None,
|
||||
"stats": {name: 0 for name in FIELD_CLASSES.keys()},
|
||||
}
|
||||
|
||||
try:
|
||||
with PDFDocument(pdf_path) as pdf_doc:
|
||||
generator = AnnotationGenerator(min_confidence=min_confidence)
|
||||
matcher = FieldMatcher()
|
||||
|
||||
page_annotations = []
|
||||
matched_fields = set()
|
||||
|
||||
images_dir = output_dir / "temp" / doc_id / "images"
|
||||
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
|
||||
report.total_pages += 1
|
||||
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
|
||||
|
||||
# Text extraction (no OCR)
|
||||
tokens = list(pdf_doc.extract_text_tokens(page_no))
|
||||
|
||||
# Match fields
|
||||
matches = {}
|
||||
for field_name in FIELD_CLASSES.keys():
|
||||
value = row_dict.get(field_name)
|
||||
if not value:
|
||||
continue
|
||||
|
||||
normalized = normalize_field(field_name, str(value))
|
||||
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
|
||||
|
||||
if field_matches:
|
||||
best = field_matches[0]
|
||||
matches[field_name] = field_matches
|
||||
matched_fields.add(field_name)
|
||||
report.add_field_result(
|
||||
FieldMatchResult(
|
||||
field_name=field_name,
|
||||
csv_value=str(value),
|
||||
matched=True,
|
||||
score=best.score,
|
||||
matched_text=best.matched_text,
|
||||
candidate_used=best.value,
|
||||
bbox=best.bbox,
|
||||
page_no=page_no,
|
||||
context_keywords=best.context_keywords,
|
||||
)
|
||||
)
|
||||
|
||||
# Generate annotations
|
||||
annotations = generator.generate_from_matches(
|
||||
matches, img_width, img_height, dpi=dpi
|
||||
)
|
||||
|
||||
if annotations:
|
||||
page_annotations.append(
|
||||
{
|
||||
"image_path": str(image_path),
|
||||
"page_no": page_no,
|
||||
"count": len(annotations),
|
||||
}
|
||||
)
|
||||
report.annotations_generated += len(annotations)
|
||||
for ann in annotations:
|
||||
class_name = list(FIELD_CLASSES.keys())[ann.class_id]
|
||||
result["stats"][class_name] += 1
|
||||
|
||||
# Record unmatched fields
|
||||
for field_name in FIELD_CLASSES.keys():
|
||||
value = row_dict.get(field_name)
|
||||
if value and field_name not in matched_fields:
|
||||
report.add_field_result(
|
||||
FieldMatchResult(
|
||||
field_name=field_name,
|
||||
csv_value=str(value),
|
||||
matched=False,
|
||||
page_no=-1,
|
||||
)
|
||||
)
|
||||
|
||||
if page_annotations:
|
||||
result["pages"] = page_annotations
|
||||
result["success"] = True
|
||||
report.success = True
|
||||
else:
|
||||
report.errors.append("No annotations generated")
|
||||
|
||||
except Exception as e:
|
||||
report.errors.append(str(e))
|
||||
|
||||
report.processing_time_ms = (time.time() - start_time) * 1000
|
||||
result["report"] = report.to_dict()
|
||||
|
||||
return result
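# Hedged illustration (not part of the diff): the shape of the task_data dict
# that process_text_pdf expects. All paths, IDs and field values below are
# placeholder assumptions, not real data from this project.
_example_task_data = {
    "row_dict": {
        "DocumentId": "doc-001",       # used as doc_id and temp dir name
        "InvoiceNumber": "12345",
        "Amount": "20485.00",
    },
    "pdf_path": "invoices/doc-001.pdf",
    "output_dir": "runs/autolabel",
    "dpi": 150,
    "min_confidence": 0.5,
}
# result = process_text_pdf(_example_task_data)
# result["success"], result["pages"], result["stats"], result["report"]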
|
||||
|
||||
|
||||
def process_scanned_pdf(task_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Process a scanned PDF (GPU task - requires OCR).
|
||||
|
||||
Args:
|
||||
task_data: Dictionary with keys:
|
||||
- row_dict: Document fields from CSV
|
||||
- pdf_path: Path to PDF file
|
||||
- output_dir: Output directory
|
||||
- dpi: Rendering DPI
|
||||
- min_confidence: Minimum match confidence
|
||||
|
||||
Returns:
|
||||
Result dictionary with success status, annotations, and report.
|
||||
"""
|
||||
from src.data import AutoLabelReport, FieldMatchResult
|
||||
from src.pdf import PDFDocument
|
||||
from src.matcher import FieldMatcher
|
||||
from src.normalize import normalize_field
|
||||
from src.yolo.annotation_generator import AnnotationGenerator, FIELD_CLASSES
|
||||
|
||||
row_dict = task_data["row_dict"]
|
||||
pdf_path = Path(task_data["pdf_path"])
|
||||
output_dir = Path(task_data["output_dir"])
|
||||
dpi = task_data.get("dpi", 150)
|
||||
min_confidence = task_data.get("min_confidence", 0.5)
|
||||
|
||||
start_time = time.time()
|
||||
doc_id = row_dict["DocumentId"]
|
||||
|
||||
report = AutoLabelReport(document_id=doc_id)
|
||||
report.pdf_path = str(pdf_path)
|
||||
report.pdf_type = "scanned"
|
||||
|
||||
result = {
|
||||
"doc_id": doc_id,
|
||||
"success": False,
|
||||
"pages": [],
|
||||
"report": None,
|
||||
"stats": {name: 0 for name in FIELD_CLASSES.keys()},
|
||||
}
|
||||
|
||||
try:
|
||||
# Get OCR engine from worker cache
|
||||
ocr_engine = _get_ocr_engine()
|
||||
|
||||
with PDFDocument(pdf_path) as pdf_doc:
|
||||
generator = AnnotationGenerator(min_confidence=min_confidence)
|
||||
matcher = FieldMatcher()
|
||||
|
||||
page_annotations = []
|
||||
matched_fields = set()
|
||||
|
||||
images_dir = output_dir / "temp" / doc_id / "images"
|
||||
for page_no, image_path in pdf_doc.render_all_pages(images_dir, dpi=dpi):
|
||||
report.total_pages += 1
|
||||
img_width, img_height = pdf_doc.get_render_dimensions(page_no, dpi)
|
||||
|
||||
# OCR extraction
|
||||
ocr_result = ocr_engine.extract_with_image(
|
||||
str(image_path),
|
||||
page_no,
|
||||
scale_to_pdf_points=72 / dpi,
|
||||
)
|
||||
tokens = ocr_result.tokens
|
||||
|
||||
# Save preprocessed image
|
||||
_save_output_img(ocr_result.output_img, image_path)
|
||||
|
||||
# Update dimensions to match OCR output
|
||||
if ocr_result.output_img is not None:
|
||||
img_height, img_width = ocr_result.output_img.shape[:2]
|
||||
|
||||
# Match fields
|
||||
matches = {}
|
||||
for field_name in FIELD_CLASSES.keys():
|
||||
value = row_dict.get(field_name)
|
||||
if not value:
|
||||
continue
|
||||
|
||||
normalized = normalize_field(field_name, str(value))
|
||||
field_matches = matcher.find_matches(tokens, field_name, normalized, page_no)
|
||||
|
||||
if field_matches:
|
||||
best = field_matches[0]
|
||||
matches[field_name] = field_matches
|
||||
matched_fields.add(field_name)
|
||||
report.add_field_result(
|
||||
FieldMatchResult(
|
||||
field_name=field_name,
|
||||
csv_value=str(value),
|
||||
matched=True,
|
||||
score=best.score,
|
||||
matched_text=best.matched_text,
|
||||
candidate_used=best.value,
|
||||
bbox=best.bbox,
|
||||
page_no=page_no,
|
||||
context_keywords=best.context_keywords,
|
||||
)
|
||||
)
|
||||
|
||||
# Generate annotations
|
||||
annotations = generator.generate_from_matches(
|
||||
matches, img_width, img_height, dpi=dpi
|
||||
)
|
||||
|
||||
if annotations:
|
||||
page_annotations.append(
|
||||
{
|
||||
"image_path": str(image_path),
|
||||
"page_no": page_no,
|
||||
"count": len(annotations),
|
||||
}
|
||||
)
|
||||
report.annotations_generated += len(annotations)
|
||||
for ann in annotations:
|
||||
class_name = list(FIELD_CLASSES.keys())[ann.class_id]
|
||||
result["stats"][class_name] += 1
|
||||
|
||||
# Record unmatched fields
|
||||
for field_name in FIELD_CLASSES.keys():
|
||||
value = row_dict.get(field_name)
|
||||
if value and field_name not in matched_fields:
|
||||
report.add_field_result(
|
||||
FieldMatchResult(
|
||||
field_name=field_name,
|
||||
csv_value=str(value),
|
||||
matched=False,
|
||||
page_no=-1,
|
||||
)
|
||||
)
|
||||
|
||||
if page_annotations:
|
||||
result["pages"] = page_annotations
|
||||
result["success"] = True
|
||||
report.success = True
|
||||
else:
|
||||
report.errors.append("No annotations generated")
|
||||
|
||||
except Exception as e:
|
||||
report.errors.append(str(e))
|
||||
|
||||
report.processing_time_ms = (time.time() - start_time) * 1000
|
||||
result["report"] = report.to_dict()
|
||||
|
||||
return result
|
||||
71
src/processing/cpu_pool.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
CPU Worker Pool for text PDF processing.
|
||||
|
||||
This pool handles CPU-bound tasks like text extraction from native PDFs
|
||||
that don't require OCR.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
|
||||
from src.processing.worker_pool import WorkerPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global resources for CPU workers (initialized once per process)
|
||||
_cpu_initialized: bool = False
|
||||
|
||||
|
||||
def _init_cpu_worker() -> None:
|
||||
"""
|
||||
Initialize a CPU worker process.
|
||||
|
||||
Disables GPU access and sets up CPU-only environment.
|
||||
"""
|
||||
global _cpu_initialized
|
||||
|
||||
# Disable GPU access for CPU workers
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
|
||||
|
||||
# Set threading limits for better CPU utilization
|
||||
os.environ.setdefault("OMP_NUM_THREADS", "1")
|
||||
os.environ.setdefault("MKL_NUM_THREADS", "1")
|
||||
|
||||
_cpu_initialized = True
|
||||
logger.debug(f"CPU worker initialized in process {os.getpid()}")
|
||||
|
||||
|
||||
class CPUWorkerPool(WorkerPool):
|
||||
"""
|
||||
Worker pool for CPU-bound tasks.
|
||||
|
||||
Handles text PDF processing that doesn't require OCR.
|
||||
Each worker is initialized with CUDA disabled to prevent
|
||||
accidental GPU memory consumption.
|
||||
|
||||
Example:
|
||||
with CPUWorkerPool(max_workers=4) as pool:
|
||||
future = pool.submit(process_text_pdf, pdf_path)
|
||||
result = future.result()
|
||||
"""
|
||||
|
||||
def __init__(self, max_workers: int = 4) -> None:
|
||||
"""
|
||||
Initialize CPU worker pool.
|
||||
|
||||
Args:
|
||||
max_workers: Number of CPU worker processes.
|
||||
Defaults to 4 for balanced performance.
|
||||
"""
|
||||
super().__init__(max_workers=max_workers, use_gpu=False, gpu_id=-1)
|
||||
|
||||
def get_initializer(self) -> Optional[Callable[..., None]]:
|
||||
"""Return the CPU worker initializer."""
|
||||
return _init_cpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
"""Return empty args for CPU initializer."""
|
||||
return ()
|
||||
339
src/processing/dual_pool_coordinator.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
Dual Pool Coordinator for managing CPU and GPU worker pools.
|
||||
|
||||
Coordinates task distribution between CPU and GPU pools, handles result
|
||||
collection using as_completed(), and provides callbacks for progress tracking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from concurrent.futures import Future, TimeoutError, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from src.processing.cpu_pool import CPUWorkerPool
|
||||
from src.processing.gpu_pool import GPUWorkerPool
|
||||
from src.processing.task_dispatcher import Task, TaskDispatcher, TaskType
|
||||
from src.processing.worker_pool import TaskResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchStats:
|
||||
"""Statistics for a batch processing run."""
|
||||
|
||||
total: int = 0
|
||||
cpu_submitted: int = 0
|
||||
gpu_submitted: int = 0
|
||||
successful: int = 0
|
||||
failed: int = 0
|
||||
cpu_time: float = 0.0
|
||||
gpu_time: float = 0.0
|
||||
errors: List[str] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def success_rate(self) -> float:
|
||||
"""Calculate success rate as percentage."""
|
||||
if self.total == 0:
|
||||
return 0.0
|
||||
return (self.successful / self.total) * 100
|
||||
|
||||
|
||||
class DualPoolCoordinator:
|
||||
"""
|
||||
Coordinates CPU and GPU worker pools for parallel document processing.
|
||||
|
||||
Uses separate ProcessPoolExecutor instances for CPU and GPU tasks,
|
||||
with as_completed() for efficient result collection across both pools.
|
||||
|
||||
Key features:
|
||||
- Automatic task classification (CPU vs GPU)
|
||||
- Parallel submission to both pools
|
||||
- Unified result collection with timeouts
|
||||
- Progress callbacks for UI integration
|
||||
- Proper resource cleanup
|
||||
|
||||
Example:
|
||||
with DualPoolCoordinator(cpu_workers=4, gpu_workers=1) as coord:
|
||||
results = coord.process_batch(
|
||||
documents=docs,
|
||||
cpu_task_fn=process_text_pdf,
|
||||
gpu_task_fn=process_scanned_pdf,
|
||||
on_result=lambda r: save_to_db(r),
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cpu_workers: int = 4,
|
||||
gpu_workers: int = 1,
|
||||
gpu_id: int = 0,
|
||||
task_timeout: float = 300.0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the dual pool coordinator.
|
||||
|
||||
Args:
|
||||
cpu_workers: Number of CPU worker processes.
|
||||
gpu_workers: Number of GPU worker processes (usually 1).
|
||||
gpu_id: GPU device ID to use.
|
||||
task_timeout: Timeout in seconds for individual tasks.
|
||||
"""
|
||||
self.cpu_workers = cpu_workers
|
||||
self.gpu_workers = gpu_workers
|
||||
self.gpu_id = gpu_id
|
||||
self.task_timeout = task_timeout
|
||||
|
||||
self._cpu_pool: Optional[CPUWorkerPool] = None
|
||||
self._gpu_pool: Optional[GPUWorkerPool] = None
|
||||
self._dispatcher = TaskDispatcher()
|
||||
self._started = False
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start both worker pools."""
|
||||
if self._started:
|
||||
raise RuntimeError("Coordinator already started")
|
||||
|
||||
logger.info(
|
||||
f"Starting DualPoolCoordinator: "
|
||||
f"{self.cpu_workers} CPU workers, {self.gpu_workers} GPU workers"
|
||||
)
|
||||
|
||||
self._cpu_pool = CPUWorkerPool(max_workers=self.cpu_workers)
|
||||
self._gpu_pool = GPUWorkerPool(
|
||||
max_workers=self.gpu_workers,
|
||||
gpu_id=self.gpu_id,
|
||||
)
|
||||
|
||||
self._cpu_pool.start()
|
||||
self._gpu_pool.start()
|
||||
self._started = True
|
||||
|
||||
def shutdown(self, wait: bool = True) -> None:
|
||||
"""Shutdown both worker pools."""
|
||||
logger.info("Shutting down DualPoolCoordinator")
|
||||
|
||||
if self._cpu_pool is not None:
|
||||
self._cpu_pool.shutdown(wait=wait)
|
||||
self._cpu_pool = None
|
||||
|
||||
if self._gpu_pool is not None:
|
||||
self._gpu_pool.shutdown(wait=wait)
|
||||
self._gpu_pool = None
|
||||
|
||||
self._started = False
|
||||
|
||||
def __enter__(self) -> "DualPoolCoordinator":
|
||||
"""Context manager entry."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Context manager exit."""
|
||||
self.shutdown(wait=True)
|
||||
|
||||
def process_batch(
|
||||
self,
|
||||
documents: List[dict],
|
||||
cpu_task_fn: Callable[[dict], Any],
|
||||
gpu_task_fn: Callable[[dict], Any],
|
||||
on_result: Optional[Callable[[TaskResult], None]] = None,
|
||||
on_error: Optional[Callable[[str, Exception], None]] = None,
|
||||
on_progress: Optional[Callable[[int, int], None]] = None,
|
||||
id_field: str = "id",
|
||||
) -> List[TaskResult]:
|
||||
"""
|
||||
Process a batch of documents using both CPU and GPU pools.
|
||||
|
||||
Documents are automatically classified and routed to the appropriate
|
||||
pool. Results are collected as they complete.
|
||||
|
||||
Args:
|
||||
documents: List of document info dicts to process.
|
||||
cpu_task_fn: Function to process text PDFs (called in CPU pool).
|
||||
gpu_task_fn: Function to process scanned PDFs (called in GPU pool).
|
||||
on_result: Callback for each successful result.
|
||||
on_error: Callback for each failed task (task_id, exception).
|
||||
on_progress: Callback for progress updates (completed, total).
|
||||
id_field: Field name to use as task ID.
|
||||
|
||||
Returns:
|
||||
List of TaskResult objects for all tasks.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If coordinator is not started.
|
||||
"""
|
||||
if not self._started:
|
||||
raise RuntimeError("Coordinator not started. Use context manager or call start().")
|
||||
|
||||
if not documents:
|
||||
return []
|
||||
|
||||
stats = BatchStats(total=len(documents))
|
||||
|
||||
# Create and partition tasks
|
||||
tasks = self._dispatcher.create_tasks(documents, id_field=id_field)
|
||||
cpu_tasks, gpu_tasks = self._dispatcher.partition_tasks(tasks)
|
||||
|
||||
# Submit tasks to pools
|
||||
futures_map: Dict[Future, Task] = {}
|
||||
|
||||
# Submit CPU tasks
|
||||
cpu_start = time.time()
|
||||
for task in cpu_tasks:
|
||||
future = self._cpu_pool.submit(cpu_task_fn, task.data)
|
||||
futures_map[future] = task
|
||||
stats.cpu_submitted += 1
|
||||
|
||||
# Submit GPU tasks
|
||||
gpu_start = time.time()
|
||||
for task in gpu_tasks:
|
||||
future = self._gpu_pool.submit(gpu_task_fn, task.data)
|
||||
futures_map[future] = task
|
||||
stats.gpu_submitted += 1
|
||||
|
||||
logger.info(
|
||||
f"Submitted {stats.cpu_submitted} CPU tasks, {stats.gpu_submitted} GPU tasks"
|
||||
)
|
||||
|
||||
# Collect results as they complete
|
||||
results: List[TaskResult] = []
|
||||
completed = 0
|
||||
|
||||
for future in as_completed(futures_map.keys(), timeout=self.task_timeout * len(documents)):
|
||||
task = futures_map[future]
|
||||
pool_type = "CPU" if task.task_type == TaskType.CPU else "GPU"
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
data = future.result(timeout=self.task_timeout)
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
result = TaskResult(
|
||||
task_id=task.id,
|
||||
success=True,
|
||||
data=data,
|
||||
pool_type=pool_type,
|
||||
processing_time=processing_time,
|
||||
)
|
||||
stats.successful += 1
|
||||
|
||||
if pool_type == "CPU":
|
||||
stats.cpu_time += processing_time
|
||||
else:
|
||||
stats.gpu_time += processing_time
|
||||
|
||||
if on_result is not None:
|
||||
try:
|
||||
on_result(result)
|
||||
except Exception as e:
|
||||
logger.warning(f"on_result callback failed: {e}")
|
||||
|
||||
except TimeoutError:
|
||||
error_msg = f"Task timed out after {self.task_timeout}s"
|
||||
logger.error(f"[{pool_type}] Task {task.id}: {error_msg}")
|
||||
|
||||
result = TaskResult(
|
||||
task_id=task.id,
|
||||
success=False,
|
||||
data=None,
|
||||
error=error_msg,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
stats.failed += 1
|
||||
stats.errors.append(f"{task.id}: {error_msg}")
|
||||
|
||||
if on_error is not None:
|
||||
try:
|
||||
on_error(task.id, TimeoutError(error_msg))
|
||||
except Exception as e:
|
||||
logger.warning(f"on_error callback failed: {e}")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"[{pool_type}] Task {task.id} failed: {error_msg}")
|
||||
|
||||
result = TaskResult(
|
||||
task_id=task.id,
|
||||
success=False,
|
||||
data=None,
|
||||
error=error_msg,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
stats.failed += 1
|
||||
stats.errors.append(f"{task.id}: {error_msg}")
|
||||
|
||||
if on_error is not None:
|
||||
try:
|
||||
on_error(task.id, e)
|
||||
except Exception as callback_error:
|
||||
logger.warning(f"on_error callback failed: {callback_error}")
|
||||
|
||||
results.append(result)
|
||||
completed += 1
|
||||
|
||||
if on_progress is not None:
|
||||
try:
|
||||
on_progress(completed, stats.total)
|
||||
except Exception as e:
|
||||
logger.warning(f"on_progress callback failed: {e}")
|
||||
|
||||
# Log final stats
|
||||
logger.info(
|
||||
f"Batch complete: {stats.successful}/{stats.total} successful "
|
||||
f"({stats.success_rate:.1f}%), {stats.failed} failed"
|
||||
)
|
||||
if stats.cpu_submitted > 0:
|
||||
logger.info(f"CPU: {stats.cpu_submitted} tasks, {stats.cpu_time:.2f}s total")
|
||||
if stats.gpu_submitted > 0:
|
||||
logger.info(f"GPU: {stats.gpu_submitted} tasks, {stats.gpu_time:.2f}s total")
|
||||
|
||||
return results
|
||||
|
||||
def process_single(
|
||||
self,
|
||||
document: dict,
|
||||
cpu_task_fn: Callable[[dict], Any],
|
||||
gpu_task_fn: Callable[[dict], Any],
|
||||
id_field: str = "id",
|
||||
) -> TaskResult:
|
||||
"""
|
||||
Process a single document.
|
||||
|
||||
Convenience method for processing one document at a time.
|
||||
|
||||
Args:
|
||||
document: Document info dict.
|
||||
cpu_task_fn: Function for text PDF processing.
|
||||
gpu_task_fn: Function for scanned PDF processing.
|
||||
id_field: Field name for task ID.
|
||||
|
||||
Returns:
|
||||
TaskResult for the document.
|
||||
"""
|
||||
results = self.process_batch(
|
||||
documents=[document],
|
||||
cpu_task_fn=cpu_task_fn,
|
||||
gpu_task_fn=gpu_task_fn,
|
||||
id_field=id_field,
|
||||
)
|
||||
return results[0] if results else TaskResult(
|
||||
task_id=str(document.get(id_field, "unknown")),
|
||||
success=False,
|
||||
data=None,
|
||||
error="No result returned",
|
||||
)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if both pools are running."""
|
||||
return (
|
||||
self._started
|
||||
and self._cpu_pool is not None
|
||||
and self._gpu_pool is not None
|
||||
and self._cpu_pool.is_running
|
||||
and self._gpu_pool.is_running
|
||||
)
|
||||
110
src/processing/gpu_pool.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
GPU Worker Pool for OCR processing.
|
||||
|
||||
This pool handles GPU-bound tasks like PaddleOCR for scanned PDF processing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from src.processing.worker_pool import WorkerPool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global OCR instance for GPU workers (initialized once per process)
|
||||
_ocr_instance: Optional[Any] = None
|
||||
_gpu_initialized: bool = False
|
||||
|
||||
|
||||
def _init_gpu_worker(gpu_id: int = 0) -> None:
|
||||
"""
|
||||
Initialize a GPU worker process with PaddleOCR.
|
||||
|
||||
Args:
|
||||
gpu_id: GPU device ID to use.
|
||||
"""
|
||||
global _ocr_instance, _gpu_initialized
|
||||
|
||||
# Set GPU device before importing paddle
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
|
||||
|
||||
# Reduce logging noise
|
||||
os.environ["GLOG_minloglevel"] = "2"
|
||||
|
||||
# Suppress PaddleX warnings
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore", message=".*PDX has already been initialized.*")
|
||||
warnings.filterwarnings("ignore", message=".*reinitialization.*")
|
||||
|
||||
try:
|
||||
# Import PaddleOCR after setting environment
|
||||
# PaddleOCR 3.x uses paddle.set_device() for GPU control, not use_gpu param
|
||||
import paddle
|
||||
paddle.set_device(f"gpu:{gpu_id}")
|
||||
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
# PaddleOCR 3.x init - minimal params, GPU controlled via paddle.set_device
|
||||
_ocr_instance = PaddleOCR(lang="en")
|
||||
_gpu_initialized = True
|
||||
logger.info(f"GPU worker initialized on GPU {gpu_id} in process {os.getpid()}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize GPU worker: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def get_ocr_instance() -> Any:
|
||||
"""
|
||||
Get the initialized OCR instance for the current worker.
|
||||
|
||||
Returns:
|
||||
PaddleOCR instance.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If OCR is not initialized.
|
||||
"""
|
||||
global _ocr_instance
|
||||
if _ocr_instance is None:
|
||||
raise RuntimeError("OCR not initialized. This function must be called from a GPU worker.")
|
||||
return _ocr_instance
|
||||
|
||||
|
||||
class GPUWorkerPool(WorkerPool):
|
||||
"""
|
||||
Worker pool for GPU-bound OCR tasks.
|
||||
|
||||
Handles scanned PDF processing using PaddleOCR with GPU acceleration.
|
||||
Typically limited to 1 worker to avoid GPU memory conflicts.
|
||||
|
||||
Example:
|
||||
with GPUWorkerPool(max_workers=1, gpu_id=0) as pool:
|
||||
future = pool.submit(process_scanned_pdf, pdf_path)
|
||||
result = future.result()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_workers: int = 1,
|
||||
gpu_id: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize GPU worker pool.
|
||||
|
||||
Args:
|
||||
max_workers: Number of GPU worker processes.
|
||||
Defaults to 1 to avoid GPU memory conflicts.
|
||||
gpu_id: GPU device ID to use.
|
||||
"""
|
||||
super().__init__(max_workers=max_workers, use_gpu=True, gpu_id=gpu_id)
|
||||
|
||||
def get_initializer(self) -> Optional[Callable[..., None]]:
|
||||
"""Return the GPU worker initializer."""
|
||||
return _init_gpu_worker
|
||||
|
||||
def get_init_args(self) -> tuple:
|
||||
"""Return args for GPU initializer."""
|
||||
return (self.gpu_id,)
|
||||
174
src/processing/task_dispatcher.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Task Dispatcher for classifying and routing tasks to appropriate worker pools.
|
||||
|
||||
Determines whether a document should be processed by CPU (text PDF) or
|
||||
GPU (scanned PDF requiring OCR) workers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskType(Enum):
|
||||
"""Task type classification."""
|
||||
|
||||
CPU = auto() # Text PDF - no OCR needed
|
||||
GPU = auto() # Scanned PDF - requires OCR
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""
|
||||
Represents a processing task.
|
||||
|
||||
Attributes:
|
||||
id: Unique task identifier.
|
||||
task_type: Whether task needs CPU or GPU processing.
|
||||
data: Task payload (document info, paths, etc.).
|
||||
"""
|
||||
|
||||
id: str
|
||||
task_type: TaskType
|
||||
data: Any
|
||||
|
||||
|
||||
class TaskDispatcher:
|
||||
"""
|
||||
Classifies and partitions tasks for CPU and GPU worker pools.
|
||||
|
||||
Uses PDF characteristics to determine if OCR is needed:
|
||||
- Text PDFs with extractable text -> CPU
|
||||
- Scanned PDFs / image-based PDFs -> GPU (OCR)
|
||||
|
||||
Example:
|
||||
dispatcher = TaskDispatcher()
|
||||
tasks = [Task(id="1", task_type=dispatcher.classify(doc), data=doc) for doc in docs]
|
||||
cpu_tasks, gpu_tasks = dispatcher.partition_tasks(tasks)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text_char_threshold: int = 100,
|
||||
ocr_ratio_threshold: float = 0.3,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the task dispatcher.
|
||||
|
||||
Args:
|
||||
text_char_threshold: Minimum characters to consider as text PDF.
|
||||
ocr_ratio_threshold: If text/expected ratio below this, use OCR.
|
||||
"""
|
||||
self.text_char_threshold = text_char_threshold
|
||||
self.ocr_ratio_threshold = ocr_ratio_threshold
|
||||
|
||||
def classify_by_pdf_info(
|
||||
self,
|
||||
has_text: bool,
|
||||
text_length: int,
|
||||
page_count: int = 1,
|
||||
) -> TaskType:
|
||||
"""
|
||||
Classify task based on PDF text extraction info.
|
||||
|
||||
Args:
|
||||
has_text: Whether PDF has extractable text layer.
|
||||
text_length: Number of characters extracted.
|
||||
page_count: Number of pages in PDF.
|
||||
|
||||
Returns:
|
||||
TaskType.CPU for text PDFs, TaskType.GPU for scanned PDFs.
|
||||
"""
|
||||
if not has_text:
|
||||
return TaskType.GPU
|
||||
|
||||
# Check if text density is reasonable
|
||||
avg_chars_per_page = text_length / max(page_count, 1)
|
||||
if avg_chars_per_page < self.text_char_threshold:
|
||||
return TaskType.GPU
|
||||
|
||||
return TaskType.CPU
|
||||
|
||||
def classify_document(self, doc_info: dict) -> TaskType:
|
||||
"""
|
||||
Classify a document based on its metadata.
|
||||
|
||||
Args:
|
||||
doc_info: Document information dict with keys like:
|
||||
- 'is_scanned': bool (if known)
|
||||
- 'text_length': int
|
||||
- 'page_count': int
|
||||
- 'pdf_path': str
|
||||
|
||||
Returns:
|
||||
TaskType for the document.
|
||||
"""
|
||||
# If explicitly marked as scanned
|
||||
if doc_info.get("is_scanned", False):
|
||||
return TaskType.GPU
|
||||
|
||||
# If we have text extraction info
|
||||
text_length = doc_info.get("text_length", 0)
|
||||
page_count = doc_info.get("page_count", 1)
|
||||
has_text = doc_info.get("has_text", text_length > 0)
|
||||
|
||||
return self.classify_by_pdf_info(
|
||||
has_text=has_text,
|
||||
text_length=text_length,
|
||||
page_count=page_count,
|
||||
)
|
||||
|
||||
def partition_tasks(
|
||||
self,
|
||||
tasks: List[Task],
|
||||
) -> Tuple[List[Task], List[Task]]:
|
||||
"""
|
||||
Partition tasks into CPU and GPU groups.
|
||||
|
||||
Args:
|
||||
tasks: List of Task objects with task_type set.
|
||||
|
||||
Returns:
|
||||
Tuple of (cpu_tasks, gpu_tasks).
|
||||
"""
|
||||
cpu_tasks = [t for t in tasks if t.task_type == TaskType.CPU]
|
||||
gpu_tasks = [t for t in tasks if t.task_type == TaskType.GPU]
|
||||
|
||||
logger.info(
|
||||
f"Task partition: {len(cpu_tasks)} CPU tasks, {len(gpu_tasks)} GPU tasks"
|
||||
)
|
||||
|
||||
return cpu_tasks, gpu_tasks
|
||||
|
||||
def create_tasks(
|
||||
self,
|
||||
documents: List[dict],
|
||||
id_field: str = "id",
|
||||
) -> List[Task]:
|
||||
"""
|
||||
Create Task objects from document dicts.
|
||||
|
||||
Args:
|
||||
documents: List of document info dicts.
|
||||
id_field: Field name to use as task ID.
|
||||
|
||||
Returns:
|
||||
List of Task objects with types classified.
|
||||
"""
|
||||
tasks = []
|
||||
for doc in documents:
|
||||
task_id = str(doc.get(id_field, id(doc)))
|
||||
task_type = self.classify_document(doc)
|
||||
tasks.append(Task(id=task_id, task_type=task_type, data=doc))
|
||||
|
||||
cpu_count = sum(1 for t in tasks if t.task_type == TaskType.CPU)
|
||||
gpu_count = len(tasks) - cpu_count
|
||||
logger.debug(f"Created {len(tasks)} tasks: {cpu_count} CPU, {gpu_count} GPU")
|
||||
|
||||
return tasks
|
||||
182
src/processing/worker_pool.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""
|
||||
Abstract base class for worker pools.
|
||||
|
||||
Provides a unified interface for CPU and GPU worker pools with proper
|
||||
initialization, task submission, and resource cleanup.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskResult:
|
||||
"""Container for task execution results."""
|
||||
|
||||
task_id: str
|
||||
success: bool
|
||||
data: Any
|
||||
error: Optional[str] = None
|
||||
processing_time: float = 0.0
|
||||
pool_type: str = ""
|
||||
extra: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class WorkerPool(ABC):
|
||||
"""
|
||||
Abstract base class for worker pools.
|
||||
|
||||
Provides a common interface for ProcessPoolExecutor-based worker pools
|
||||
with proper initialization using the 'spawn' start method for CUDA
|
||||
compatibility.
|
||||
|
||||
Attributes:
|
||||
max_workers: Maximum number of worker processes.
|
||||
use_gpu: Whether this pool uses GPU resources.
|
||||
gpu_id: GPU device ID (only relevant if use_gpu=True).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_workers: int,
|
||||
use_gpu: bool = False,
|
||||
gpu_id: int = 0,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the worker pool configuration.
|
||||
|
||||
Args:
|
||||
max_workers: Maximum number of worker processes.
|
||||
use_gpu: Whether this pool uses GPU resources.
|
||||
gpu_id: GPU device ID for GPU pools.
|
||||
"""
|
||||
self.max_workers = max_workers
|
||||
self.use_gpu = use_gpu
|
||||
self.gpu_id = gpu_id
|
||||
self._executor: Optional[ProcessPoolExecutor] = None
|
||||
self._started = False
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the pool name for logging."""
|
||||
return self.__class__.__name__
|
||||
|
||||
@abstractmethod
|
||||
def get_initializer(self) -> Optional[Callable[..., None]]:
|
||||
"""
|
||||
Return the worker initialization function.
|
||||
|
||||
This function is called once per worker process when it starts.
|
||||
Use it to load models, set environment variables, etc.
|
||||
|
||||
Returns:
|
||||
Callable to initialize each worker, or None if no initialization needed.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_init_args(self) -> tuple:
|
||||
"""
|
||||
Return arguments for the initializer function.
|
||||
|
||||
Returns:
|
||||
Tuple of arguments to pass to the initializer.
|
||||
"""
|
||||
pass
|
||||
|
||||
def start(self) -> None:
|
||||
"""
|
||||
Start the worker pool.
|
||||
|
||||
Creates a ProcessPoolExecutor with the 'spawn' start method
|
||||
for CUDA compatibility.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the pool is already started.
|
||||
"""
|
||||
if self._started:
|
||||
raise RuntimeError(f"{self.name} is already started")
|
||||
|
||||
# Use 'spawn' for CUDA compatibility
|
||||
ctx = mp.get_context("spawn")
|
||||
|
||||
initializer = self.get_initializer()
|
||||
initargs = self.get_init_args()
|
||||
|
||||
logger.info(
|
||||
f"Starting {self.name} with {self.max_workers} workers "
|
||||
f"(GPU: {self.use_gpu}, GPU ID: {self.gpu_id})"
|
||||
)
|
||||
|
||||
self._executor = ProcessPoolExecutor(
|
||||
max_workers=self.max_workers,
|
||||
mp_context=ctx,
|
||||
initializer=initializer,
|
||||
initargs=initargs if initializer else (),
|
||||
)
|
||||
self._started = True
|
||||
|
||||
def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Future:
|
||||
"""
|
||||
Submit a task to the worker pool.
|
||||
|
||||
Args:
|
||||
fn: Function to execute.
|
||||
*args: Positional arguments for the function.
|
||||
**kwargs: Keyword arguments for the function.
|
||||
|
||||
Returns:
|
||||
Future representing the pending result.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the pool is not started.
|
||||
"""
|
||||
if not self._started or self._executor is None:
|
||||
raise RuntimeError(f"{self.name} is not started. Call start() first.")
|
||||
|
||||
return self._executor.submit(fn, *args, **kwargs)
|
||||
|
||||
def shutdown(self, wait: bool = True, cancel_futures: bool = False) -> None:
|
||||
"""
|
||||
Shutdown the worker pool.
|
||||
|
||||
Args:
|
||||
wait: If True, wait for all pending futures to complete.
|
||||
cancel_futures: If True, cancel all pending futures.
|
||||
"""
|
||||
if self._executor is not None:
|
||||
logger.info(f"Shutting down {self.name} (wait={wait})")
|
||||
self._executor.shutdown(wait=wait, cancel_futures=cancel_futures)
|
||||
self._executor = None
|
||||
self._started = False
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""Check if the pool is currently running."""
|
||||
return self._started and self._executor is not None
|
||||
|
||||
def __enter__(self) -> "WorkerPool":
|
||||
"""Context manager entry - start the pool."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
"""Context manager exit - shutdown the pool."""
|
||||
self.shutdown(wait=True)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
status = "running" if self.is_running else "stopped"
|
||||
return (
|
||||
f"{self.__class__.__name__}("
|
||||
f"workers={self.max_workers}, "
|
||||
f"gpu={self.use_gpu}, "
|
||||
f"status={status})"
|
||||
)
|
||||
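For context, a minimal concrete subclass (not part of the commit) showing how WorkerPool is meant to be used: implement the two abstract hooks, then drive the pool as a context manager. The submitted function must be defined at module level because the pool uses the 'spawn' start method.

from concurrent.futures import as_completed

class SimpleCpuPool(WorkerPool):
    def get_initializer(self):
        return None  # no per-worker setup needed for this sketch

    def get_init_args(self) -> tuple:
        return ()

def square(x: int) -> int:
    return x * x

if __name__ == "__main__":  # required with the 'spawn' start method
    with SimpleCpuPool(max_workers=2) as pool:
        futures = [pool.submit(square, i) for i in range(4)]
        print([f.result() for f in as_completed(futures)])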
9
src/web/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Web Application Module
|
||||
|
||||
Provides REST API and web interface for invoice field extraction.
|
||||
"""
|
||||
|
||||
from .app import create_app
|
||||
|
||||
__all__ = ["create_app"]
|
||||
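A minimal sketch of serving the exported factory, assuming uvicorn is installed and the package is importable as src.web:

import uvicorn
from src.web import create_app
from src.web.config import default_config

app = create_app()
uvicorn.run(app, host=default_config.server.host, port=default_config.server.port)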
616
src/web/app.py
Normal file
@@ -0,0 +1,616 @@
|
||||
"""
|
||||
FastAPI Application Factory
|
||||
|
||||
Creates and configures the FastAPI application.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from .config import AppConfig, default_config
|
||||
from .routes import create_api_router
|
||||
from .services import InferenceService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_app(config: AppConfig | None = None) -> FastAPI:
|
||||
"""
|
||||
Create and configure FastAPI application.
|
||||
|
||||
Args:
|
||||
config: Application configuration. Uses default if not provided.
|
||||
|
||||
Returns:
|
||||
Configured FastAPI application
|
||||
"""
|
||||
config = config or default_config
|
||||
|
||||
# Create inference service
|
||||
inference_service = InferenceService(
|
||||
model_config=config.model,
|
||||
storage_config=config.storage,
|
||||
)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Application lifespan manager."""
|
||||
logger.info("Starting Invoice Inference API...")
|
||||
|
||||
# Initialize inference service on startup
|
||||
try:
|
||||
inference_service.initialize()
|
||||
logger.info("Inference service ready")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize inference service: {e}")
|
||||
# Continue anyway - service will retry on first request
|
||||
|
||||
yield
|
||||
|
||||
logger.info("Shutting down Invoice Inference API...")
|
||||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Invoice Field Extraction API",
|
||||
description="""
|
||||
REST API for extracting fields from Swedish invoices.
|
||||
|
||||
## Features
|
||||
- YOLO-based field detection
|
||||
- OCR text extraction
|
||||
- Field normalization and validation
|
||||
- Visualization of detections
|
||||
|
||||
## Supported Fields
|
||||
- InvoiceNumber
|
||||
- InvoiceDate
|
||||
- InvoiceDueDate
|
||||
- OCR (reference number)
|
||||
- Bankgiro
|
||||
- Plusgiro
|
||||
- Amount
|
||||
""",
|
||||
version="1.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# Add CORS middleware
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Mount static files for results
|
||||
config.storage.result_dir.mkdir(parents=True, exist_ok=True)
|
||||
app.mount(
|
||||
"/static/results",
|
||||
StaticFiles(directory=str(config.storage.result_dir)),
|
||||
name="results",
|
||||
)
|
||||
|
||||
# Include API routes
|
||||
api_router = create_api_router(inference_service, config.storage)
|
||||
app.include_router(api_router)
|
||||
|
||||
# Root endpoint - serve HTML UI
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root() -> str:
|
||||
"""Serve the web UI."""
|
||||
return get_html_ui()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def get_html_ui() -> str:
|
||||
"""Generate HTML UI for the web application."""
|
||||
return """
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Invoice Field Extraction</title>
|
||||
<style>
|
||||
* {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1200px;
|
||||
margin: 0 auto;
|
||||
}
|
||||
|
||||
header {
|
||||
text-align: center;
|
||||
color: white;
|
||||
margin-bottom: 30px;
|
||||
}
|
||||
|
||||
header h1 {
|
||||
font-size: 2.5rem;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
header p {
|
||||
opacity: 0.9;
|
||||
font-size: 1.1rem;
|
||||
}
|
||||
|
||||
.main-content {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr 1fr;
|
||||
gap: 20px;
|
||||
}
|
||||
|
||||
@media (max-width: 900px) {
|
||||
.main-content {
|
||||
grid-template-columns: 1fr;
|
||||
}
|
||||
}
|
||||
|
||||
.card {
|
||||
background: white;
|
||||
border-radius: 16px;
|
||||
padding: 24px;
|
||||
box-shadow: 0 10px 40px rgba(0,0,0,0.2);
|
||||
}
|
||||
|
||||
.card h2 {
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
font-size: 1.3rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 10px;
|
||||
}
|
||||
|
||||
.upload-area {
|
||||
border: 3px dashed #ddd;
|
||||
border-radius: 12px;
|
||||
padding: 40px;
|
||||
text-align: center;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
background: #fafafa;
|
||||
}
|
||||
|
||||
.upload-area:hover, .upload-area.dragover {
|
||||
border-color: #667eea;
|
||||
background: #f0f4ff;
|
||||
}
|
||||
|
||||
.upload-area.has-file {
|
||||
border-color: #10b981;
|
||||
background: #ecfdf5;
|
||||
}
|
||||
|
||||
.upload-icon {
|
||||
font-size: 48px;
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
|
||||
.upload-area p {
|
||||
color: #666;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.upload-area small {
|
||||
color: #999;
|
||||
}
|
||||
|
||||
#file-input {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.file-name {
|
||||
margin-top: 15px;
|
||||
padding: 10px 15px;
|
||||
background: #e0f2fe;
|
||||
border-radius: 8px;
|
||||
color: #0369a1;
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.btn {
|
||||
display: inline-block;
|
||||
padding: 14px 28px;
|
||||
border: none;
|
||||
border-radius: 10px;
|
||||
font-size: 1rem;
|
||||
font-weight: 600;
|
||||
cursor: pointer;
|
||||
transition: all 0.3s;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
}
|
||||
|
||||
.btn-primary {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
width: 100%;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.btn-primary:hover:not(:disabled) {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 20px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
.btn-primary:disabled {
|
||||
opacity: 0.6;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.loading {
|
||||
display: none;
|
||||
text-align: center;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.loading.active {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 40px;
|
||||
height: 40px;
|
||||
border: 4px solid #f3f3f3;
|
||||
border-top: 4px solid #667eea;
|
||||
border-radius: 50%;
|
||||
animation: spin 1s linear infinite;
|
||||
margin: 0 auto 15px;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
.results {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.results.active {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.result-header {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 20px;
|
||||
padding-bottom: 15px;
|
||||
border-bottom: 2px solid #eee;
|
||||
}
|
||||
|
||||
.result-status {
|
||||
padding: 6px 12px;
|
||||
border-radius: 20px;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.result-status.success {
|
||||
background: #dcfce7;
|
||||
color: #166534;
|
||||
}
|
||||
|
||||
.result-status.partial {
|
||||
background: #fef3c7;
|
||||
color: #92400e;
|
||||
}
|
||||
|
||||
.result-status.error {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
}
|
||||
|
||||
.fields-grid {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(2, 1fr);
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.field-item {
|
||||
padding: 12px;
|
||||
background: #f8fafc;
|
||||
border-radius: 10px;
|
||||
border-left: 4px solid #667eea;
|
||||
}
|
||||
|
||||
.field-item label {
|
||||
display: block;
|
||||
font-size: 0.75rem;
|
||||
color: #64748b;
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.5px;
|
||||
margin-bottom: 4px;
|
||||
}
|
||||
|
||||
.field-item .value {
|
||||
font-size: 1.1rem;
|
||||
font-weight: 600;
|
||||
color: #1e293b;
|
||||
}
|
||||
|
||||
.field-item .confidence {
|
||||
font-size: 0.75rem;
|
||||
color: #10b981;
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.visualization {
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.visualization img {
|
||||
width: 100%;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 4px 20px rgba(0,0,0,0.1);
|
||||
}
|
||||
|
||||
.processing-time {
|
||||
text-align: center;
|
||||
color: #64748b;
|
||||
font-size: 0.9rem;
|
||||
margin-top: 15px;
|
||||
}
|
||||
|
||||
.error-message {
|
||||
background: #fee2e2;
|
||||
color: #991b1b;
|
||||
padding: 15px;
|
||||
border-radius: 10px;
|
||||
margin-top: 15px;
|
||||
}
|
||||
|
||||
footer {
|
||||
text-align: center;
|
||||
color: white;
|
||||
opacity: 0.8;
|
||||
margin-top: 30px;
|
||||
font-size: 0.9rem;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<header>
|
||||
<h1>📄 Invoice Field Extraction</h1>
|
||||
<p>Upload a Swedish invoice (PDF or image) to extract fields automatically</p>
|
||||
</header>
|
||||
|
||||
<div class="main-content">
|
||||
<div class="card">
|
||||
<h2>📤 Upload Document</h2>
|
||||
|
||||
<div class="upload-area" id="upload-area">
|
||||
<div class="upload-icon">📁</div>
|
||||
<p>Drag & drop your file here</p>
|
||||
<p>or <strong>click to browse</strong></p>
|
||||
<small>Supports PDF, PNG, JPG (max 50MB)</small>
|
||||
<input type="file" id="file-input" accept=".pdf,.png,.jpg,.jpeg">
|
||||
<div class="file-name" id="file-name" style="display: none;"></div>
|
||||
</div>
|
||||
|
||||
<button class="btn btn-primary" id="submit-btn" disabled>
|
||||
🚀 Extract Fields
|
||||
</button>
|
||||
|
||||
<div class="loading" id="loading">
|
||||
<div class="spinner"></div>
|
||||
<p>Processing document...</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="card">
|
||||
<h2>📊 Extraction Results</h2>
|
||||
|
||||
<div id="placeholder" style="text-align: center; padding: 40px; color: #999;">
|
||||
<div style="font-size: 64px; margin-bottom: 15px;">🔍</div>
|
||||
<p>Upload a document to see extraction results</p>
|
||||
</div>
|
||||
|
||||
<div class="results" id="results">
|
||||
<div class="result-header">
|
||||
<span>Document: <strong id="doc-id"></strong></span>
|
||||
<span class="result-status" id="result-status"></span>
|
||||
</div>
|
||||
|
||||
<div class="fields-grid" id="fields-grid"></div>
|
||||
|
||||
<div class="processing-time" id="processing-time"></div>
|
||||
|
||||
<div class="error-message" id="error-message" style="display: none;"></div>
|
||||
|
||||
<div class="visualization" id="visualization" style="display: none;">
|
||||
<h3 style="margin-bottom: 10px; color: #333;">🎯 Detection Visualization</h3>
|
||||
<img id="viz-image" src="" alt="Detection visualization">
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<footer>
|
||||
<p>Powered by ColaCoder</p>
|
||||
</footer>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const uploadArea = document.getElementById('upload-area');
|
||||
const fileInput = document.getElementById('file-input');
|
||||
const fileName = document.getElementById('file-name');
|
||||
const submitBtn = document.getElementById('submit-btn');
|
||||
const loading = document.getElementById('loading');
|
||||
const placeholder = document.getElementById('placeholder');
|
||||
const results = document.getElementById('results');
|
||||
|
||||
let selectedFile = null;
|
||||
|
||||
// Drag and drop handlers
|
||||
uploadArea.addEventListener('click', () => fileInput.click());
|
||||
|
||||
uploadArea.addEventListener('dragover', (e) => {
|
||||
e.preventDefault();
|
||||
uploadArea.classList.add('dragover');
|
||||
});
|
||||
|
||||
uploadArea.addEventListener('dragleave', () => {
|
||||
uploadArea.classList.remove('dragover');
|
||||
});
|
||||
|
||||
uploadArea.addEventListener('drop', (e) => {
|
||||
e.preventDefault();
|
||||
uploadArea.classList.remove('dragover');
|
||||
const files = e.dataTransfer.files;
|
||||
if (files.length > 0) {
|
||||
handleFile(files[0]);
|
||||
}
|
||||
});
|
||||
|
||||
fileInput.addEventListener('change', (e) => {
|
||||
if (e.target.files.length > 0) {
|
||||
handleFile(e.target.files[0]);
|
||||
}
|
||||
});
|
||||
|
||||
function handleFile(file) {
|
||||
const validTypes = ['.pdf', '.png', '.jpg', '.jpeg'];
|
||||
const ext = '.' + file.name.split('.').pop().toLowerCase();
|
||||
|
||||
if (!validTypes.includes(ext)) {
|
||||
alert('Please upload a PDF, PNG, or JPG file.');
|
||||
return;
|
||||
}
|
||||
|
||||
selectedFile = file;
|
||||
fileName.textContent = `📎 ${file.name}`;
|
||||
fileName.style.display = 'block';
|
||||
uploadArea.classList.add('has-file');
|
||||
submitBtn.disabled = false;
|
||||
}
|
||||
|
||||
submitBtn.addEventListener('click', async () => {
|
||||
if (!selectedFile) return;
|
||||
|
||||
// Show loading
|
||||
submitBtn.disabled = true;
|
||||
loading.classList.add('active');
|
||||
placeholder.style.display = 'none';
|
||||
results.classList.remove('active');
|
||||
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.append('file', selectedFile);
|
||||
|
||||
const response = await fetch('/api/v1/infer', {
|
||||
method: 'POST',
|
||||
body: formData,
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(data.detail || 'Processing failed');
|
||||
}
|
||||
|
||||
displayResults(data);
|
||||
|
||||
} catch (error) {
|
||||
console.error('Error:', error);
|
||||
document.getElementById('error-message').textContent = error.message;
|
||||
document.getElementById('error-message').style.display = 'block';
|
||||
results.classList.add('active');
|
||||
} finally {
|
||||
loading.classList.remove('active');
|
||||
submitBtn.disabled = false;
|
||||
}
|
||||
});
|
||||
|
||||
function displayResults(data) {
|
||||
const result = data.result;
|
||||
|
||||
// Document ID
|
||||
document.getElementById('doc-id').textContent = result.document_id;
|
||||
|
||||
// Status
|
||||
const statusEl = document.getElementById('result-status');
|
||||
statusEl.textContent = result.success ? 'Success' : 'Partial';
|
||||
statusEl.className = 'result-status ' + (result.success ? 'success' : 'partial');
|
||||
|
||||
// Fields
|
||||
const fieldsGrid = document.getElementById('fields-grid');
|
||||
fieldsGrid.innerHTML = '';
|
||||
|
||||
const fieldOrder = ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Amount', 'Bankgiro', 'Plusgiro'];
|
||||
|
||||
fieldOrder.forEach(field => {
|
||||
const value = result.fields[field];
|
||||
const confidence = result.confidence[field];
|
||||
|
||||
if (value !== null && value !== undefined) {
|
||||
const fieldDiv = document.createElement('div');
|
||||
fieldDiv.className = 'field-item';
|
||||
fieldDiv.innerHTML = `
|
||||
<label>${formatFieldName(field)}</label>
|
||||
<div class="value">${value}</div>
|
||||
${confidence ? `<div class="confidence">✓ ${(confidence * 100).toFixed(1)}% confident</div>` : ''}
|
||||
`;
|
||||
fieldsGrid.appendChild(fieldDiv);
|
||||
}
|
||||
});
|
||||
|
||||
// Processing time
|
||||
document.getElementById('processing-time').textContent =
|
||||
`⏱️ Processed in ${result.processing_time_ms.toFixed(0)}ms`;
|
||||
|
||||
// Visualization
|
||||
if (result.visualization_url) {
|
||||
const vizDiv = document.getElementById('visualization');
|
||||
const vizImg = document.getElementById('viz-image');
|
||||
vizImg.src = result.visualization_url;
|
||||
vizDiv.style.display = 'block';
|
||||
}
|
||||
|
||||
// Errors
|
||||
if (result.errors && result.errors.length > 0) {
|
||||
document.getElementById('error-message').textContent = result.errors.join(', ');
|
||||
document.getElementById('error-message').style.display = 'block';
|
||||
} else {
|
||||
document.getElementById('error-message').style.display = 'none';
|
||||
}
|
||||
|
||||
results.classList.add('active');
|
||||
}
|
||||
|
||||
function formatFieldName(name) {
|
||||
return name.replace(/([A-Z])/g, ' $1').trim();
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
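A client-side sketch against the /api/v1/infer endpoint defined in routes.py, assuming the server runs on localhost:8000 and invoice.pdf exists locally (requires the requests package):

import requests

with open("invoice.pdf", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/api/v1/infer",
        files={"file": ("invoice.pdf", f, "application/pdf")},
    )
resp.raise_for_status()
payload = resp.json()
print(payload["status"], payload["result"]["fields"])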
69
src/web/config.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Web Application Configuration
|
||||
|
||||
Centralized configuration for the web application.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelConfig:
|
||||
"""YOLO model configuration."""
|
||||
|
||||
model_path: Path = Path("runs/train/invoice_yolo11n_full/weights/best.pt")
|
||||
confidence_threshold: float = 0.3
|
||||
use_gpu: bool = True
|
||||
dpi: int = 150
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ServerConfig:
|
||||
"""Server configuration."""
|
||||
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8000
|
||||
debug: bool = False
|
||||
reload: bool = False
|
||||
workers: int = 1
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class StorageConfig:
|
||||
"""File storage configuration."""
|
||||
|
||||
upload_dir: Path = Path("uploads")
|
||||
result_dir: Path = Path("results")
|
||||
max_file_size_mb: int = 50
|
||||
allowed_extensions: tuple[str, ...] = (".pdf", ".png", ".jpg", ".jpeg")
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Create directories if they don't exist."""
|
||||
object.__setattr__(self, "upload_dir", Path(self.upload_dir))
|
||||
object.__setattr__(self, "result_dir", Path(self.result_dir))
|
||||
self.upload_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.result_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AppConfig:
|
||||
"""Main application configuration."""
|
||||
|
||||
model: ModelConfig = field(default_factory=ModelConfig)
|
||||
server: ServerConfig = field(default_factory=ServerConfig)
|
||||
storage: StorageConfig = field(default_factory=StorageConfig)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: dict[str, Any]) -> "AppConfig":
|
||||
"""Create config from dictionary."""
|
||||
return cls(
|
||||
model=ModelConfig(**config_dict.get("model", {})),
|
||||
server=ServerConfig(**config_dict.get("server", {})),
|
||||
storage=StorageConfig(**config_dict.get("storage", {})),
|
||||
)
|
||||
|
||||
|
||||
# Default configuration instance
|
||||
default_config = AppConfig()
|
||||
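An illustrative override of the defaults via AppConfig.from_dict; the paths and port below are made-up values. Note that StorageConfig coerces string paths to Path and creates the directories in __post_init__.

from pathlib import Path
from src.web.config import AppConfig

config = AppConfig.from_dict({
    "model": {"model_path": Path("runs/train/invoice_fields_v4/weights/best.pt"), "confidence_threshold": 0.4},
    "server": {"port": 9000},
    "storage": {"result_dir": "results/web"},  # coerced to Path by __post_init__
})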
183
src/web/routes.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
API Routes
|
||||
|
||||
FastAPI route definitions for the inference API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import APIRouter, File, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from .schemas import (
|
||||
BatchInferenceResponse,
|
||||
DetectionResult,
|
||||
ErrorResponse,
|
||||
HealthResponse,
|
||||
InferenceResponse,
|
||||
InferenceResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .services import InferenceService
|
||||
from .config import StorageConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_api_router(
|
||||
inference_service: "InferenceService",
|
||||
storage_config: "StorageConfig",
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Create API router with inference endpoints.
|
||||
|
||||
Args:
|
||||
inference_service: Inference service instance
|
||||
storage_config: Storage configuration
|
||||
|
||||
Returns:
|
||||
Configured APIRouter
|
||||
"""
|
||||
router = APIRouter(prefix="/api/v1", tags=["inference"])
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health_check() -> HealthResponse:
|
||||
"""Check service health status."""
|
||||
return HealthResponse(
|
||||
status="healthy",
|
||||
model_loaded=inference_service.is_initialized,
|
||||
gpu_available=inference_service.gpu_available,
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/infer",
|
||||
response_model=InferenceResponse,
|
||||
responses={
|
||||
400: {"model": ErrorResponse, "description": "Invalid file"},
|
||||
500: {"model": ErrorResponse, "description": "Processing error"},
|
||||
},
|
||||
)
|
||||
async def infer_document(
|
||||
file: UploadFile = File(..., description="PDF or image file to process"),
|
||||
) -> InferenceResponse:
|
||||
"""
|
||||
Process a document and extract invoice fields.
|
||||
|
||||
Accepts PDF or image files (PNG, JPG, JPEG).
|
||||
Returns extracted field values with confidence scores.
|
||||
"""
|
||||
# Validate file extension
|
||||
if not file.filename:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Filename is required",
|
||||
)
|
||||
|
||||
file_ext = Path(file.filename).suffix.lower()
|
||||
if file_ext not in storage_config.allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unsupported file type: {file_ext}. Allowed: {storage_config.allowed_extensions}",
|
||||
)
|
||||
|
||||
# Generate document ID
|
||||
doc_id = str(uuid.uuid4())[:8]
|
||||
|
||||
# Save uploaded file
|
||||
upload_path = storage_config.upload_dir / f"{doc_id}{file_ext}"
|
||||
try:
|
||||
with open(upload_path, "wb") as f:
|
||||
shutil.copyfileobj(file.file, f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save uploaded file: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to save uploaded file",
|
||||
)
|
||||
|
||||
try:
|
||||
# Process based on file type
|
||||
if file_ext == ".pdf":
|
||||
service_result = inference_service.process_pdf(
|
||||
upload_path, document_id=doc_id
|
||||
)
|
||||
else:
|
||||
service_result = inference_service.process_image(
|
||||
upload_path, document_id=doc_id
|
||||
)
|
||||
|
||||
# Build response
|
||||
viz_url = None
|
||||
if service_result.visualization_path:
|
||||
viz_url = f"/api/v1/results/{service_result.visualization_path.name}"
|
||||
|
||||
inference_result = InferenceResult(
|
||||
document_id=service_result.document_id,
|
||||
success=service_result.success,
|
||||
fields=service_result.fields,
|
||||
confidence=service_result.confidence,
|
||||
detections=[
|
||||
DetectionResult(**d) for d in service_result.detections
|
||||
],
|
||||
processing_time_ms=service_result.processing_time_ms,
|
||||
visualization_url=viz_url,
|
||||
errors=service_result.errors,
|
||||
)
|
||||
|
||||
return InferenceResponse(
|
||||
status="success" if service_result.success else "partial",
|
||||
message=f"Processed document {doc_id}",
|
||||
result=inference_result,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing document: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=str(e),
|
||||
)
|
||||
|
||||
finally:
|
||||
# Cleanup uploaded file
|
||||
upload_path.unlink(missing_ok=True)
|
||||
|
||||
@router.get("/results/{filename}")
|
||||
async def get_result_image(filename: str) -> FileResponse:
|
||||
"""Get visualization result image."""
|
||||
file_path = storage_config.result_dir / filename
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Result file not found: {filename}",
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
path=file_path,
|
||||
media_type="image/png",
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
@router.delete("/results/{filename}")
|
||||
async def delete_result(filename: str) -> dict:
|
||||
"""Delete a result file."""
|
||||
file_path = storage_config.result_dir / filename
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Result file not found: {filename}",
|
||||
)
|
||||
|
||||
file_path.unlink()
|
||||
return {"status": "deleted", "filename": filename}
|
||||
|
||||
return router
|
||||
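A sketch of exercising the health endpoint in a test, assuming FastAPI's TestClient (backed by httpx) is available; instantiating the client without a with-block skips the lifespan hook, so the model is not loaded:

from fastapi.testclient import TestClient
from src.web import create_app

client = TestClient(create_app())
resp = client.get("/api/v1/health")
assert resp.status_code == 200
assert resp.json()["status"] == "healthy"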
83
src/web/schemas.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
API Request/Response Schemas
|
||||
|
||||
Pydantic models for API validation and serialization.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Any
|
||||
|
||||
|
||||
class DetectionResult(BaseModel):
|
||||
"""Single detection result."""
|
||||
|
||||
field: str = Field(..., description="Field type (e.g., invoice_number, amount)")
|
||||
confidence: float = Field(..., ge=0, le=1, description="Detection confidence")
|
||||
bbox: list[float] = Field(..., description="Bounding box [x1, y1, x2, y2]")
|
||||
|
||||
|
||||
class ExtractedField(BaseModel):
|
||||
"""Extracted and normalized field value."""
|
||||
|
||||
field_name: str = Field(..., description="Field name")
|
||||
value: str | None = Field(None, description="Extracted value")
|
||||
confidence: float = Field(..., ge=0, le=1, description="Extraction confidence")
|
||||
is_valid: bool = Field(True, description="Whether the value passed validation")
|
||||
|
||||
|
||||
class InferenceResult(BaseModel):
|
||||
"""Complete inference result for a document."""
|
||||
|
||||
document_id: str = Field(..., description="Document identifier")
|
||||
success: bool = Field(..., description="Whether inference succeeded")
|
||||
fields: dict[str, str | None] = Field(
|
||||
default_factory=dict, description="Extracted field values"
|
||||
)
|
||||
confidence: dict[str, float] = Field(
|
||||
default_factory=dict, description="Confidence scores per field"
|
||||
)
|
||||
detections: list[DetectionResult] = Field(
|
||||
default_factory=list, description="Raw YOLO detections"
|
||||
)
|
||||
processing_time_ms: float = Field(..., description="Processing time in milliseconds")
|
||||
visualization_url: str | None = Field(
|
||||
None, description="URL to visualization image"
|
||||
)
|
||||
errors: list[str] = Field(default_factory=list, description="Error messages")
|
||||
|
||||
|
||||
class InferenceResponse(BaseModel):
|
||||
"""API response for inference endpoint."""
|
||||
|
||||
status: str = Field(..., description="Response status: success or error")
|
||||
message: str = Field(..., description="Response message")
|
||||
result: InferenceResult | None = Field(None, description="Inference result")
|
||||
|
||||
|
||||
class BatchInferenceResponse(BaseModel):
|
||||
"""API response for batch inference endpoint."""
|
||||
|
||||
status: str = Field(..., description="Response status")
|
||||
message: str = Field(..., description="Response message")
|
||||
total: int = Field(..., description="Total documents processed")
|
||||
successful: int = Field(..., description="Number of successful extractions")
|
||||
results: list[InferenceResult] = Field(
|
||||
default_factory=list, description="Individual results"
|
||||
)
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
"""Health check response."""
|
||||
|
||||
status: str = Field(..., description="Service status")
|
||||
model_loaded: bool = Field(..., description="Whether model is loaded")
|
||||
gpu_available: bool = Field(..., description="Whether GPU is available")
|
||||
version: str = Field(..., description="API version")
|
||||
|
||||
|
||||
class ErrorResponse(BaseModel):
|
||||
"""Error response."""
|
||||
|
||||
status: str = Field(default="error", description="Error status")
|
||||
message: str = Field(..., description="Error message")
|
||||
detail: str | None = Field(None, description="Detailed error information")
|
||||
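For reference, constructing and serializing the response models with made-up field values; model_dump_json assumes Pydantic v2 (use .json() on v1):

from src.web.schemas import DetectionResult, InferenceResult, InferenceResponse

result = InferenceResult(
    document_id="a1b2c3d4",
    success=True,
    fields={"InvoiceNumber": "12345", "Amount": "1250.00"},
    confidence={"InvoiceNumber": 0.94, "Amount": 0.88},
    detections=[DetectionResult(field="invoice_number", confidence=0.94, bbox=[10, 20, 110, 40])],
    processing_time_ms=312.5,
)
response = InferenceResponse(status="success", message="Processed document a1b2c3d4", result=result)
print(response.model_dump_json(indent=2))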
270
src/web/services.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""
|
||||
Inference Service
|
||||
|
||||
Business logic for invoice field extraction.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config import ModelConfig, StorageConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServiceResult:
|
||||
"""Result from inference service."""
|
||||
|
||||
document_id: str
|
||||
success: bool = False
|
||||
fields: dict[str, str | None] = field(default_factory=dict)
|
||||
confidence: dict[str, float] = field(default_factory=dict)
|
||||
detections: list[dict] = field(default_factory=list)
|
||||
processing_time_ms: float = 0.0
|
||||
visualization_path: Path | None = None
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class InferenceService:
|
||||
"""
|
||||
Service for running invoice field extraction.
|
||||
|
||||
Encapsulates YOLO detection and OCR extraction logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
storage_config: StorageConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize inference service.
|
||||
|
||||
Args:
|
||||
model_config: Model configuration
|
||||
storage_config: Storage configuration
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.storage_config = storage_config
|
||||
self._pipeline = None
|
||||
self._detector = None
|
||||
self._is_initialized = False
|
||||
|
||||
def initialize(self) -> None:
|
||||
"""Initialize the inference pipeline (lazy loading)."""
|
||||
if self._is_initialized:
|
||||
return
|
||||
|
||||
logger.info("Initializing inference service...")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
from ..inference.pipeline import InferencePipeline
|
||||
from ..inference.yolo_detector import YOLODetector
|
||||
|
||||
# Initialize YOLO detector for visualization
|
||||
self._detector = YOLODetector(
|
||||
str(self.model_config.model_path),
|
||||
confidence_threshold=self.model_config.confidence_threshold,
|
||||
device="cuda" if self.model_config.use_gpu else "cpu",
|
||||
)
|
||||
|
||||
# Initialize full pipeline
|
||||
self._pipeline = InferencePipeline(
|
||||
model_path=str(self.model_config.model_path),
|
||||
confidence_threshold=self.model_config.confidence_threshold,
|
||||
use_gpu=self.model_config.use_gpu,
|
||||
dpi=self.model_config.dpi,
|
||||
enable_fallback=True,
|
||||
)
|
||||
|
||||
self._is_initialized = True
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"Inference service initialized in {elapsed:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize inference service: {e}")
|
||||
raise
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if service is initialized."""
|
||||
return self._is_initialized
|
||||
|
||||
@property
|
||||
def gpu_available(self) -> bool:
|
||||
"""Check if GPU is available."""
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
def process_image(
|
||||
self,
|
||||
image_path: Path,
|
||||
document_id: str | None = None,
|
||||
save_visualization: bool = True,
|
||||
) -> ServiceResult:
|
||||
"""
|
||||
Process an image file and extract invoice fields.
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
document_id: Optional document ID
|
||||
save_visualization: Whether to save visualization
|
||||
|
||||
Returns:
|
||||
ServiceResult with extracted fields
|
||||
"""
|
||||
if not self._is_initialized:
|
||||
self.initialize()
|
||||
|
||||
doc_id = document_id or str(uuid.uuid4())[:8]
|
||||
start_time = time.time()
|
||||
|
||||
result = ServiceResult(document_id=doc_id)
|
||||
|
||||
try:
|
||||
# Run inference pipeline
|
||||
pipeline_result = self._pipeline.process_image(image_path, document_id=doc_id)
|
||||
|
||||
result.fields = pipeline_result.fields
|
||||
result.confidence = pipeline_result.confidence
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Get raw detections for visualization
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
|
||||
# Save visualization if requested
|
||||
if save_visualization and pipeline_result.raw_detections:
|
||||
viz_path = self._save_visualization(image_path, doc_id)
|
||||
result.visualization_path = viz_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image {image_path}: {e}")
|
||||
result.errors.append(str(e))
|
||||
result.success = False
|
||||
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def process_pdf(
|
||||
self,
|
||||
pdf_path: Path,
|
||||
document_id: str | None = None,
|
||||
save_visualization: bool = True,
|
||||
) -> ServiceResult:
|
||||
"""
|
||||
Process a PDF file and extract invoice fields.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to PDF file
|
||||
document_id: Optional document ID
|
||||
save_visualization: Whether to save visualization
|
||||
|
||||
Returns:
|
||||
ServiceResult with extracted fields
|
||||
"""
|
||||
if not self._is_initialized:
|
||||
self.initialize()
|
||||
|
||||
doc_id = document_id or str(uuid.uuid4())[:8]
|
||||
start_time = time.time()
|
||||
|
||||
result = ServiceResult(document_id=doc_id)
|
||||
|
||||
try:
|
||||
# Run inference pipeline
|
||||
pipeline_result = self._pipeline.process_pdf(pdf_path, document_id=doc_id)
|
||||
|
||||
result.fields = pipeline_result.fields
|
||||
result.confidence = pipeline_result.confidence
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Get raw detections
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
|
||||
# Save visualization (render first page)
|
||||
if save_visualization and pipeline_result.raw_detections:
|
||||
viz_path = self._save_pdf_visualization(pdf_path, doc_id)
|
||||
result.visualization_path = viz_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing PDF {pdf_path}: {e}")
|
||||
result.errors.append(str(e))
|
||||
result.success = False
|
||||
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization image with detections."""
|
||||
from ultralytics import YOLO
|
||||
|
||||
# Load model and run prediction with visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(image_path), verbose=False)
|
||||
|
||||
# Save annotated image
|
||||
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
return output_path
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
|
||||
"""Save visualization for PDF (first page)."""
|
||||
from ..pdf.renderer import render_pdf_to_images
|
||||
from ultralytics import YOLO
|
||||
import io
|
||||
|
||||
# Render first page
|
||||
for page_no, image_bytes in render_pdf_to_images(
|
||||
pdf_path, dpi=self.model_config.dpi
|
||||
):
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
temp_path = self.storage_config.result_dir / f"{doc_id}_temp.png"
|
||||
image.save(temp_path)
|
||||
|
||||
# Run YOLO and save visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(temp_path), verbose=False)
|
||||
|
||||
output_path = self.storage_config.result_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
# Cleanup temp file
|
||||
temp_path.unlink(missing_ok=True)
|
||||
return output_path
|
||||
|
||||
# If no pages rendered
|
||||
return None
|
||||
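The service can also be driven outside FastAPI; a sketch assuming the default model weights exist at the path in ModelConfig and invoice.pdf is a local file:

from pathlib import Path
from src.web.config import default_config
from src.web.services import InferenceService

service = InferenceService(default_config.model, default_config.storage)
service.initialize()  # loads the YOLO detector and inference pipeline once
result = service.process_pdf(Path("invoice.pdf"))
print(result.success, result.fields, f"{result.processing_time_ms:.0f} ms")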
src/yolo/__init__.py
@@ -1,4 +1,5 @@
|
||||
from .annotation_generator import AnnotationGenerator, generate_annotations
|
||||
from .dataset_builder import DatasetBuilder
|
||||
from .db_dataset import DBYOLODataset, create_datasets
|
||||
|
||||
__all__ = ['AnnotationGenerator', 'generate_annotations', 'DatasetBuilder']
|
||||
__all__ = ['AnnotationGenerator', 'generate_annotations', 'DatasetBuilder', 'DBYOLODataset', 'create_datasets']
|
||||
|
||||
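A usage sketch for the newly exported database-backed dataset (defined below in src/yolo/db_dataset.py), assuming a DocumentDB instance db and rendered pages under data/dataset/temp/{doc_id}/images/; the base directory and limit are illustrative:

from src.yolo import DBYOLODataset

train_ds = DBYOLODataset("data/dataset", db=db, split="train", limit=1000)
val_ds = DBYOLODataset.from_shared_data(train_ds, split="val")   # reuses loaded items
test_ds = DBYOLODataset.from_shared_data(train_ds, split="test")
image, labels = train_ds[0]  # (H, W, 3) uint8 array and (N, 5) YOLO label array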
src/yolo/annotation_generator.py
@@ -174,20 +174,47 @@ class AnnotationGenerator:
|
||||
output_path: str | Path,
|
||||
train_path: str = 'train/images',
|
||||
val_path: str = 'val/images',
|
||||
test_path: str = 'test/images'
|
||||
test_path: str = 'test/images',
|
||||
use_wsl_paths: bool | None = None
|
||||
) -> None:
|
||||
"""Generate YOLO dataset YAML configuration."""
|
||||
"""
|
||||
Generate YOLO dataset YAML configuration.
|
||||
|
||||
Args:
|
||||
output_path: Path to output YAML file
|
||||
train_path: Relative path to training images
|
||||
val_path: Relative path to validation images
|
||||
test_path: Relative path to test images
|
||||
use_wsl_paths: If True, convert Windows paths to WSL format.
|
||||
If None, auto-detect based on environment.
|
||||
"""
|
||||
import os
|
||||
import platform
|
||||
|
||||
output_path = Path(output_path)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Use absolute path for WSL compatibility
|
||||
dataset_dir = output_path.parent.absolute()
|
||||
# Convert Windows path to WSL path if needed
|
||||
dataset_path_str = str(dataset_dir).replace('\\', '/')
|
||||
if dataset_path_str[1] == ':':
|
||||
# Windows path like C:/... -> /mnt/c/...
|
||||
dataset_path_str = str(dataset_dir)
|
||||
|
||||
# Auto-detect WSL environment
|
||||
if use_wsl_paths is None:
|
||||
# Check if running inside WSL
|
||||
is_wsl = 'microsoft' in platform.uname().release.lower() if platform.system() == 'Linux' else False
|
||||
# Check WSL_DISTRO_NAME environment variable (set when running in WSL)
|
||||
is_wsl = is_wsl or os.environ.get('WSL_DISTRO_NAME') is not None
|
||||
use_wsl_paths = is_wsl
|
||||
|
||||
# Convert path format based on environment
|
||||
if use_wsl_paths:
|
||||
# Running in WSL: convert Windows paths to /mnt/c/... format
|
||||
dataset_path_str = dataset_path_str.replace('\\', '/')
|
||||
if len(dataset_path_str) > 1 and dataset_path_str[1] == ':':
|
||||
drive = dataset_path_str[0].lower()
|
||||
dataset_path_str = f"/mnt/{drive}{dataset_path_str[2:]}"
|
||||
elif platform.system() == 'Windows':
|
||||
# Running on native Windows: use forward slashes for YOLO compatibility
|
||||
dataset_path_str = dataset_path_str.replace('\\', '/')
|
||||
|
||||
config = f"""# Invoice Field Detection Dataset
|
||||
path: {dataset_path_str}
|
||||
|
||||
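The Windows-to-WSL mapping above boils down to a small helper; a stand-alone sketch of the same logic (names are illustrative):

def to_wsl_path(win_path: str) -> str:
    p = win_path.replace("\\", "/")
    if len(p) > 1 and p[1] == ":":
        return f"/mnt/{p[0].lower()}{p[2:]}"  # C:/... -> /mnt/c/...
    return p

assert to_wsl_path(r"C:\data\dataset") == "/mnt/c/data/dataset"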
625
src/yolo/db_dataset.py
Normal file
@@ -0,0 +1,625 @@
|
||||
"""
|
||||
Database-backed YOLO Dataset
|
||||
|
||||
Loads images from filesystem and labels from PostgreSQL database.
|
||||
Generates YOLO format labels dynamically at training time.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .annotation_generator import FIELD_CLASSES, YOLOAnnotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level LRU cache for image loading (shared across dataset instances)
|
||||
@lru_cache(maxsize=256)
|
||||
def _load_image_cached(image_path: str) -> tuple[np.ndarray, int, int]:
|
||||
"""
|
||||
Load and cache image from disk.
|
||||
|
||||
Args:
|
||||
image_path: Path to image file (must be string for hashability)
|
||||
|
||||
Returns:
|
||||
Tuple of (image_array, width, height)
|
||||
"""
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
width, height = image.size
|
||||
image_array = np.array(image)
|
||||
return image_array, width, height
|
||||
|
||||
|
||||
def clear_image_cache():
|
||||
"""Clear the image cache to free memory."""
|
||||
_load_image_cached.cache_clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetItem:
|
||||
"""Single item in the dataset."""
|
||||
document_id: str
|
||||
image_path: Path
|
||||
page_no: int
|
||||
labels: list[YOLOAnnotation]
|
||||
is_scanned: bool = False # True if bbox is in pixel coords, False if in PDF points
|
||||
|
||||
|
||||
class DBYOLODataset:
|
||||
"""
|
||||
YOLO Dataset that reads labels from database.
|
||||
|
||||
This dataset:
|
||||
1. Scans temp directory for rendered images
|
||||
2. Queries database for bbox data
|
||||
3. Generates YOLO labels dynamically
|
||||
4. Performs train/val/test split at runtime
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
images_dir: str | Path,
|
||||
db: Any, # DocumentDB instance
|
||||
split: str = 'train',
|
||||
train_ratio: float = 0.8,
|
||||
val_ratio: float = 0.1,
|
||||
seed: int = 42,
|
||||
dpi: int = 300,
|
||||
min_confidence: float = 0.7,
|
||||
bbox_padding_px: int = 20,
|
||||
min_bbox_height_px: int = 30,
|
||||
limit: int | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize database-backed YOLO dataset.
|
||||
|
||||
Args:
|
||||
images_dir: Directory containing temp/{doc_id}/images/*.png
|
||||
db: DocumentDB instance for label queries
|
||||
split: Which split to use ('train', 'val', 'test')
|
||||
train_ratio: Ratio for training set
|
||||
val_ratio: Ratio for validation set
|
||||
seed: Random seed for reproducible splits
|
||||
dpi: DPI used for rendering (for coordinate conversion)
|
||||
min_confidence: Minimum match score to include
|
||||
bbox_padding_px: Padding around bboxes
|
||||
min_bbox_height_px: Minimum bbox height
|
||||
limit: Maximum number of documents to use (None for all)
|
||||
"""
|
||||
self.images_dir = Path(images_dir)
|
||||
self.db = db
|
||||
self.split = split
|
||||
self.train_ratio = train_ratio
|
||||
self.val_ratio = val_ratio
|
||||
self.seed = seed
|
||||
self.dpi = dpi
|
||||
self.min_confidence = min_confidence
|
||||
self.bbox_padding_px = bbox_padding_px
|
||||
self.min_bbox_height_px = min_bbox_height_px
|
||||
self.limit = limit
|
||||
|
||||
# Load and split dataset
|
||||
self.items: list[DatasetItem] = []
|
||||
self._all_items: list[DatasetItem] = [] # Cache all items for sharing
|
||||
self._doc_ids_ordered: list[str] = [] # Cache ordered doc IDs for consistent splits
|
||||
self._load_dataset()
|
||||
|
||||
@classmethod
|
||||
def from_shared_data(
|
||||
cls,
|
||||
source_dataset: 'DBYOLODataset',
|
||||
split: str,
|
||||
) -> 'DBYOLODataset':
|
||||
"""
|
||||
Create a new dataset instance sharing data from an existing one.
|
||||
|
||||
This avoids re-loading data from filesystem and database.
|
||||
|
||||
Args:
|
||||
source_dataset: Dataset to share data from
|
||||
split: Which split to use ('train', 'val', 'test')
|
||||
|
||||
Returns:
|
||||
New dataset instance with shared data
|
||||
"""
|
||||
# Create instance without loading (we'll share data)
|
||||
instance = object.__new__(cls)
|
||||
|
||||
# Copy configuration
|
||||
instance.images_dir = source_dataset.images_dir
|
||||
instance.db = source_dataset.db
|
||||
instance.split = split
|
||||
instance.train_ratio = source_dataset.train_ratio
|
||||
instance.val_ratio = source_dataset.val_ratio
|
||||
instance.seed = source_dataset.seed
|
||||
instance.dpi = source_dataset.dpi
|
||||
instance.min_confidence = source_dataset.min_confidence
|
||||
instance.bbox_padding_px = source_dataset.bbox_padding_px
|
||||
instance.min_bbox_height_px = source_dataset.min_bbox_height_px
|
||||
instance.limit = source_dataset.limit
|
||||
|
||||
# Share loaded data
|
||||
instance._all_items = source_dataset._all_items
|
||||
instance._doc_ids_ordered = source_dataset._doc_ids_ordered
|
||||
|
||||
# Split items for this split
|
||||
instance.items = instance._split_dataset_from_cache()
|
||||
print(f"Split '{split}': {len(instance.items)} items")
|
||||
|
||||
return instance
|
||||
|
||||
def _load_dataset(self):
|
||||
"""Load dataset items from filesystem and database."""
|
||||
# Find all document directories
|
||||
temp_dir = self.images_dir / 'temp'
|
||||
if not temp_dir.exists():
|
||||
print(f"Temp directory not found: {temp_dir}")
|
||||
return
|
||||
|
||||
# Collect all document IDs with images
|
||||
doc_image_map: dict[str, list[Path]] = {}
|
||||
for doc_dir in temp_dir.iterdir():
|
||||
if not doc_dir.is_dir():
|
||||
continue
|
||||
|
||||
images_path = doc_dir / 'images'
|
||||
if not images_path.exists():
|
||||
continue
|
||||
|
||||
images = list(images_path.glob('*.png'))
|
||||
if images:
|
||||
doc_image_map[doc_dir.name] = sorted(images)
|
||||
|
||||
print(f"Found {len(doc_image_map)} documents with images")
|
||||
|
||||
# Query database for all document labels
|
||||
doc_ids = list(doc_image_map.keys())
|
||||
doc_labels = self._load_labels_from_db(doc_ids)
|
||||
|
||||
print(f"Loaded labels for {len(doc_labels)} documents from database")
|
||||
|
||||
# Build dataset items
|
||||
all_items: list[DatasetItem] = []
|
||||
skipped_no_labels = 0
|
||||
skipped_no_db_record = 0
|
||||
total_images = 0
|
||||
|
||||
for doc_id, images in doc_image_map.items():
|
||||
doc_data = doc_labels.get(doc_id)
|
||||
|
||||
# Skip documents that don't exist in database
|
||||
if doc_data is None:
|
||||
skipped_no_db_record += len(images)
|
||||
total_images += len(images)
|
||||
continue
|
||||
|
||||
labels_by_page, is_scanned = doc_data
|
||||
|
||||
for image_path in images:
|
||||
total_images += 1
|
||||
# Extract page number from filename (e.g., "doc_page_000.png")
|
||||
page_no = self._extract_page_no(image_path.stem)
|
||||
|
||||
# Get labels for this page
|
||||
page_labels = labels_by_page.get(page_no, [])
|
||||
|
||||
if page_labels: # Only include pages with labels
|
||||
all_items.append(DatasetItem(
|
||||
document_id=doc_id,
|
||||
image_path=image_path,
|
||||
page_no=page_no,
|
||||
labels=page_labels,
|
||||
is_scanned=is_scanned
|
||||
))
|
||||
else:
|
||||
skipped_no_labels += 1
|
||||
|
||||
print(f"Total images found: {total_images}")
|
||||
print(f"Images with labels: {len(all_items)}")
|
||||
if skipped_no_db_record > 0:
|
||||
print(f"Skipped {skipped_no_db_record} images (document not in database)")
|
||||
if skipped_no_labels > 0:
|
||||
print(f"Skipped {skipped_no_labels} images (no labels for page)")
|
||||
|
||||
# Cache all items for sharing with other splits
|
||||
self._all_items = all_items
|
||||
|
||||
# Split dataset
|
||||
self.items, self._doc_ids_ordered = self._split_dataset(all_items)
|
||||
print(f"Split '{self.split}': {len(self.items)} items")
|
||||
|
||||
def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]]:
|
||||
"""
|
||||
Load labels from database for given document IDs using batch queries.
|
||||
|
||||
Returns:
|
||||
Dict of doc_id -> (page_labels, is_scanned)
|
||||
where page_labels is {page_no -> list[YOLOAnnotation]}
|
||||
and is_scanned indicates if bbox is in pixel coords (True) or PDF points (False)
|
||||
"""
|
||||
result: dict[str, tuple[dict[int, list[YOLOAnnotation]], bool]] = {}
|
||||
|
||||
# Query in batches using efficient batch method
|
||||
batch_size = 500
|
||||
for i in range(0, len(doc_ids), batch_size):
|
||||
batch_ids = doc_ids[i:i + batch_size]
|
||||
|
||||
# Use batch query instead of individual queries (N+1 fix)
|
||||
docs_batch = self.db.get_documents_batch(batch_ids)
|
||||
|
||||
for doc_id, doc in docs_batch.items():
|
||||
if not doc.get('success'):
|
||||
continue
|
||||
|
||||
# Check if scanned PDF (OCR bbox is in pixels, text PDF bbox is in PDF points)
|
||||
is_scanned = doc.get('pdf_type') == 'scanned'
|
||||
|
||||
page_labels: dict[int, list[YOLOAnnotation]] = {}
|
||||
|
||||
for field_result in doc.get('field_results', []):
|
||||
if not field_result.get('matched'):
|
||||
continue
|
||||
|
||||
field_name = field_result.get('field_name')
|
||||
if field_name not in FIELD_CLASSES:
|
||||
continue
|
||||
|
||||
score = field_result.get('score', 0)
|
||||
if score < self.min_confidence:
|
||||
continue
|
||||
|
||||
bbox = field_result.get('bbox')
|
||||
page_no = field_result.get('page_no', 0)
|
||||
|
||||
if bbox and len(bbox) == 4:
|
||||
annotation = self._create_annotation(
|
||||
field_name=field_name,
|
||||
bbox=bbox,
|
||||
score=score
|
||||
)
|
||||
|
||||
if page_no not in page_labels:
|
||||
page_labels[page_no] = []
|
||||
page_labels[page_no].append(annotation)
|
||||
|
||||
if page_labels:
|
||||
result[doc_id] = (page_labels, is_scanned)
|
||||
|
||||
return result
|
||||
|
||||
def _create_annotation(
|
||||
self,
|
||||
field_name: str,
|
||||
bbox: list[float],
|
||||
score: float
|
||||
) -> YOLOAnnotation:
|
||||
"""
|
||||
Create a YOLO annotation from bbox.
|
||||
|
||||
Note: bbox is in PDF points (72 DPI), will be normalized later.
|
||||
"""
|
||||
class_id = FIELD_CLASSES[field_name]
|
||||
x0, y0, x1, y1 = bbox
|
||||
|
||||
# Store raw PDF coordinates - will be normalized when getting item
|
||||
return YOLOAnnotation(
|
||||
class_id=class_id,
|
||||
x_center=(x0 + x1) / 2, # center in PDF points
|
||||
y_center=(y0 + y1) / 2,
|
||||
width=x1 - x0,
|
||||
height=y1 - y0,
|
||||
confidence=score
|
||||
)
|
||||
|
||||
def _extract_page_no(self, stem: str) -> int:
|
||||
"""Extract page number from image filename."""
|
||||
# Format: "{doc_id}_page_{page_no:03d}"
|
||||
parts = stem.rsplit('_', 1)
|
||||
if len(parts) == 2:
|
||||
try:
|
||||
return int(parts[1])
|
||||
except ValueError:
|
||||
pass
|
||||
return 0

    def _split_dataset(self, items: list[DatasetItem]) -> tuple[list[DatasetItem], list[str]]:
        """
        Split items into train/val/test based on document ID.

        Returns:
            Tuple of (split_items, ordered_doc_ids) where ordered_doc_ids can be
            reused for consistent splits across shared datasets.
        """
        # Group by document ID for proper splitting
        doc_items: dict[str, list[DatasetItem]] = {}
        for item in items:
            if item.document_id not in doc_items:
                doc_items[item.document_id] = []
            doc_items[item.document_id].append(item)

        # Shuffle document IDs
        doc_ids = list(doc_items.keys())
        random.seed(self.seed)
        random.shuffle(doc_ids)

        # Apply limit if specified (before splitting)
        if self.limit is not None and self.limit < len(doc_ids):
            doc_ids = doc_ids[:self.limit]
            print(f"Limited to {self.limit} documents")

        # Calculate split indices
        n_total = len(doc_ids)
        n_train = int(n_total * self.train_ratio)
        n_val = int(n_total * self.val_ratio)

        # Split document IDs
        if self.split == 'train':
            split_doc_ids = doc_ids[:n_train]
        elif self.split == 'val':
            split_doc_ids = doc_ids[n_train:n_train + n_val]
        else:  # test
            split_doc_ids = doc_ids[n_train + n_val:]

        # Collect items for this split
        split_items = []
        for doc_id in split_doc_ids:
            split_items.extend(doc_items[doc_id])

        return split_items, doc_ids
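
    # Split arithmetic example (illustrative): with 100 documents and the default
    # ratios train_ratio=0.8, val_ratio=0.1, n_train=80 and n_val=10, so documents
    # [0:80] form the train split, [80:90] val and [90:100] test; all pages of a
    # document always land in the same split.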

    def _split_dataset_from_cache(self) -> list[DatasetItem]:
        """
        Split items using cached data from a shared dataset.

        Uses pre-computed doc_ids order for consistent splits.
        """
        # Group by document ID
        doc_items: dict[str, list[DatasetItem]] = {}
        for item in self._all_items:
            if item.document_id not in doc_items:
                doc_items[item.document_id] = []
            doc_items[item.document_id].append(item)

        # Use cached doc_ids order
        doc_ids = self._doc_ids_ordered

        # Calculate split indices
        n_total = len(doc_ids)
        n_train = int(n_total * self.train_ratio)
        n_val = int(n_total * self.val_ratio)

        # Split document IDs based on split type
        if self.split == 'train':
            split_doc_ids = doc_ids[:n_train]
        elif self.split == 'val':
            split_doc_ids = doc_ids[n_train:n_train + n_val]
        else:  # test
            split_doc_ids = doc_ids[n_train + n_val:]

        # Collect items for this split
        split_items = []
        for doc_id in split_doc_ids:
            if doc_id in doc_items:
                split_items.extend(doc_items[doc_id])

        return split_items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        """
        Get image and labels for index.

        Returns:
            (image, labels) where:
            - image: numpy array (H, W, C)
            - labels: numpy array (N, 5) with [class_id, x_center, y_center, width, height]
        """
        item = self.items[idx]

        # Load image using LRU cache (significant speedup during training)
        image_array, img_width, img_height = _load_image_cached(str(item.image_path))

        # Convert annotations to YOLO format (normalized)
        labels = self._convert_labels(item.labels, img_width, img_height, item.is_scanned)

        return image_array, labels

    def _convert_labels(
        self,
        annotations: list[YOLOAnnotation],
        img_width: int,
        img_height: int,
        is_scanned: bool = False
    ) -> np.ndarray:
        """
        Convert annotations to normalized YOLO format.

        Args:
            annotations: List of annotations
            img_width: Actual image width in pixels
            img_height: Actual image height in pixels
            is_scanned: Retained for compatibility; not used for scaling here, since
                all bboxes are stored in PDF points after the OCR coordinate fix (see note below)

        Returns:
            numpy array (N, 5) with [class_id, x_center, y_center, width, height]
        """
        if not annotations:
            return np.zeros((0, 5), dtype=np.float32)

        # Scale factor: PDF points (72 DPI) -> rendered pixels
        # Note: After the OCR coordinate fix, ALL bbox (both text and scanned PDF)
        # are stored in PDF points, so we always apply the same scaling.
        scale = self.dpi / 72.0
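
        # Worked example (illustrative): at dpi=150 the scale is 150 / 72, about 2.08,
        # so a bbox 100 points wide in the PDF becomes roughly 208 px wide in the
        # rendered image before padding and normalization are applied.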

        labels = []
        for ann in annotations:
            # Convert to pixels (if needed)
            x_center_px = ann.x_center * scale
            y_center_px = ann.y_center * scale
            width_px = ann.width * scale
            height_px = ann.height * scale

            # Add padding
            pad = self.bbox_padding_px
            width_px += 2 * pad
            height_px += 2 * pad

            # Ensure minimum height
            if height_px < self.min_bbox_height_px:
                height_px = self.min_bbox_height_px

            # Normalize to 0-1
            x_center = x_center_px / img_width
            y_center = y_center_px / img_height
            width = width_px / img_width
            height = height_px / img_height

            # Clamp to valid range
            x_center = max(0, min(1, x_center))
            y_center = max(0, min(1, y_center))
            width = max(0, min(1, width))
            height = max(0, min(1, height))

            labels.append([ann.class_id, x_center, y_center, width, height])

        return np.array(labels, dtype=np.float32)

    def get_image_path(self, idx: int) -> Path:
        """Get image path for index."""
        return self.items[idx].image_path

    def get_labels_for_yolo(self, idx: int) -> str:
        """
        Get YOLO format labels as string for index.

        Returns:
            String with YOLO format labels (one per line)
        """
        item = self.items[idx]

        # Use cached image loading to avoid duplicate disk reads
        _, img_width, img_height = _load_image_cached(str(item.image_path))

        labels = self._convert_labels(item.labels, img_width, img_height, item.is_scanned)

        lines = []
        for label in labels:
            class_id = int(label[0])
            x_center, y_center, width, height = label[1:5]
            lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

        return '\n'.join(lines)
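
    # Example of a single emitted line (illustrative values):
    #   "3 0.512345 0.204311 0.104167 0.018229"
    # i.e. class_id followed by normalized x_center, y_center, width, height.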

    def export_to_yolo_format(
        self,
        output_dir: str | Path,
        split_name: Optional[str] = None
    ) -> int:
        """
        Export dataset to standard YOLO format (images + label files).

        This is useful for training with standard YOLO training scripts.

        Args:
            output_dir: Output directory
            split_name: Name for the split subdirectory (default: self.split)

        Returns:
            Number of items exported
        """
        import shutil

        output_dir = Path(output_dir)
        split_name = split_name or self.split

        images_out = output_dir / split_name / 'images'
        labels_out = output_dir / split_name / 'labels'

        # Clear existing directories before export
        if images_out.exists():
            shutil.rmtree(images_out)
        if labels_out.exists():
            shutil.rmtree(labels_out)

        images_out.mkdir(parents=True, exist_ok=True)
        labels_out.mkdir(parents=True, exist_ok=True)

        count = 0
        for idx in range(len(self)):
            item = self.items[idx]

            # Copy image
            dest_image = images_out / item.image_path.name
            shutil.copy2(item.image_path, dest_image)

            # Write label file
            label_content = self.get_labels_for_yolo(idx)
            label_path = labels_out / f"{item.image_path.stem}.txt"
            with open(label_path, 'w') as f:
                f.write(label_content)

            count += 1

        print(f"Exported {count} items to {output_dir / split_name}")
        return count
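
    # Resulting on-disk layout (file names are illustrative):
    #   <output_dir>/<split_name>/images/doc_0001_page_000.png
    #   <output_dir>/<split_name>/labels/doc_0001_page_000.txt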


def create_datasets(
    images_dir: str | Path,
    db: Any,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: int = 42,
    limit: int | None = None,
    **kwargs
) -> dict[str, DBYOLODataset]:
    """
    Create train/val/test datasets.

    This function loads data once and shares it across all splits for efficiency.

    Args:
        images_dir: Directory containing temp/{doc_id}/images/
        db: DocumentDB instance
        train_ratio: Training set ratio
        val_ratio: Validation set ratio
        seed: Random seed
        limit: Maximum number of documents to use (None for all)
        **kwargs: Additional arguments for DBYOLODataset

    Returns:
        Dict with 'train', 'val', 'test' datasets
    """
    # Create first dataset which loads all data
    print("Loading dataset (this may take a few minutes for large datasets)...")
    first_dataset = DBYOLODataset(
        images_dir=images_dir,
        db=db,
        split='train',
        train_ratio=train_ratio,
        val_ratio=val_ratio,
        seed=seed,
        limit=limit,
        **kwargs
    )

    # Create other splits by sharing loaded data
    datasets = {'train': first_dataset}
    for split in ['val', 'test']:
        datasets[split] = DBYOLODataset.from_shared_data(
            first_dataset,
            split=split,
        )

    return datasets
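
# Example usage sketch (assumes `db` is an already constructed DocumentDB instance
# and that the paths exist; names and paths here are illustrative):
#
#     datasets = create_datasets("data/temp", db, limit=1000)
#     for name, ds in datasets.items():
#         ds.export_to_yolo_format("data/yolo_export", split_name=name)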