Re-structure the project.

src/matcher/README.md (new file, 358 lines)
@@ -0,0 +1,358 @@
# Matcher Module - Field Matching

Matches normalized field values against the tokens extracted from a PDF document and returns each field's position (bbox) in the document, which is used to generate YOLO training annotations.

## 📁 Module Structure

```
src/matcher/
├── __init__.py                  # Exports the public interface
├── field_matcher.py             # Main class (205 lines, down from 876)
├── models.py                    # Data models
├── token_index.py               # Spatial index
├── context.py                   # Context keywords
├── utils.py                     # Utility functions
└── strategies/                  # Matching strategies
    ├── __init__.py
    ├── base.py                  # Base strategy class
    ├── exact_matcher.py         # Exact matching
    ├── concatenated_matcher.py  # Multi-token concatenation matching
    ├── substring_matcher.py     # Substring matching
    ├── fuzzy_matcher.py         # Fuzzy matching (amounts)
    └── flexible_date_matcher.py # Flexible date matching
```

## 🎯 Core Features

### FieldMatcher

The main class; it coordinates the individual matching strategies:

```python
from src.matcher import FieldMatcher

matcher = FieldMatcher(
    context_radius=200.0,     # search radius for context keywords (pixels)
    min_score_threshold=0.5   # minimum accepted match score
)

# Match a field
matches = matcher.find_matches(
    tokens=tokens,                                           # tokens extracted from the PDF
    field_name="InvoiceNumber",                              # field name
    normalized_values=["100017500321", "INV-100017500321"],  # normalized variants
    page_no=0                                                # page number
)

# matches: List[Match]
for match in matches:
    print(f"Field: {match.field}")
    print(f"Value: {match.value}")
    print(f"BBox: {match.bbox}")
    print(f"Score: {match.score}")
    print(f"Context: {match.context_keywords}")
```

### The 5 Matching Strategies

#### 1. ExactMatcher - exact matching
```python
from src.matcher.strategies import ExactMatcher

matcher = ExactMatcher(context_radius=200.0)
matches = matcher.find_matches(tokens, "100017500321", "InvoiceNumber")
```

Matching rules (combined as sketched below):
- Exact match: score = 1.0
- Case-insensitive match: score = 0.95
- Digits-only match: score = 0.9
- Context-keyword bonus: +0.1 per keyword (capped at +0.25)
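
A minimal sketch of how these pieces combine, assuming the constants above (`base` is whichever of the three rule scores fired; the keyword boost mirrors `find_context_keywords` in `context.py`):

```python
def combined_score(base: float, context_keywords: list[str]) -> float:
    """Illustrative only: base score plus the capped context-keyword boost."""
    boost = min(0.25, 0.10 * len(context_keywords))  # +0.1 per keyword, at most +0.25
    return min(1.0, base + boost)                    # the final score never exceeds 1.0

combined_score(0.9, [])                   # 0.9  (digits-only match, no context)
combined_score(0.9, ["fakturanr", "nr"])  # 1.0  (0.9 + 0.2, capped at 1.0)
```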

#### 2. ConcatenatedMatcher - concatenation matching
```python
from src.matcher.strategies import ConcatenatedMatcher

matcher = ConcatenatedMatcher()
matches = matcher.find_matches(tokens, "100017500321", "InvoiceNumber")
```

Handles cases where OCR has split a single value across several tokens, as illustrated below.
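
For example, if OCR breaks an invoice number into adjacent fragments on one line, the concatenated text is compared against the target value and the resulting bbox spans all fragments. A hypothetical illustration (the `Token` dataclass below is only a stand-in for whatever token type the PDF module produces):

```python
from dataclasses import dataclass

from src.matcher.strategies import ConcatenatedMatcher


@dataclass
class Token:  # hypothetical stand-in for the PDF extractor's token type
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int = 0


# "100017500321" split into three fragments on the same line
tokens = [
    Token("1000", (100.0, 200.0, 140.0, 215.0)),
    Token("1750", (145.0, 200.0, 185.0, 215.0)),
    Token("0321", (190.0, 200.0, 230.0, 215.0)),
]

matches = ConcatenatedMatcher().find_matches(tokens, "100017500321", "InvoiceNumber")
# Expected: a single match whose bbox covers all three fragments
```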

#### 3. SubstringMatcher - substring matching
```python
from src.matcher.strategies import SubstringMatcher

matcher = SubstringMatcher()
matches = matcher.find_matches(tokens, "2026-01-09", "InvoiceDate")
```

Matches field values embedded in longer text:
- `"Fakturadatum: 2026-01-09"` matches `"2026-01-09"`
- `"Fakturanummer: 2465027205"` matches `"2465027205"`

#### 4. FuzzyMatcher - fuzzy matching
```python
from src.matcher.strategies import FuzzyMatcher

matcher = FuzzyMatcher()
matches = matcher.find_matches(tokens, "1234.56", "Amount")
```

Used for amount fields; tolerates small numeric differences (±0.01), as sketched below.
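
The tolerance is applied after both sides are parsed as numbers, so formatting differences between the CSV value and the invoice text do not matter. A rough sketch of the comparison, using `parse_amount` from `src.matcher.utils`:

```python
from src.matcher.utils import parse_amount

csv_value = "1234.56"
token_text = "1 234,56 kr"  # Swedish formatting as printed on the invoice

# FuzzyMatcher accepts the token when the parsed values differ by less than 0.01
print(abs(parse_amount(token_text) - parse_amount(csv_value)) < 0.01)  # True
```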

#### 5. FlexibleDateMatcher - flexible date matching
```python
from src.matcher.strategies import FlexibleDateMatcher

matcher = FlexibleDateMatcher()
matches = matcher.find_matches(tokens, "2025-01-15", "InvoiceDate")
```

Used when exact matching fails; scores are tiered by date distance (sketched after this list):
- Same year and month: score = 0.7-0.8
- Within 7 days: score = 0.75+
- Within 3 days: score = 0.8+
- Within 14 days: score = 0.6
- Within 30 days: score = 0.55
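
A minimal sketch of the tiering above (the real matcher also adds context-keyword boosts and caps the result at 1.0):

```python
from datetime import date


def flexible_date_score(target: date, found: date) -> float | None:
    """Illustrative tiering only; mirrors the score table above."""
    days_diff = abs((found - target).days)
    same_year_month = (found.year, found.month) == (target.year, target.month)

    if same_year_month:
        if days_diff <= 3:
            return 0.8
        if days_diff <= 7:
            return 0.75
        return 0.7
    if days_diff <= 14:
        return 0.6
    if days_diff <= 30:
        return 0.55
    return None  # too far apart; only considered with strong context keywords


flexible_date_score(date(2025, 1, 15), date(2025, 1, 17))  # 0.8
flexible_date_score(date(2025, 1, 15), date(2025, 2, 10))  # 0.55
```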

### Data Models

#### Match - match result
```python
from src.matcher.models import Match

match = Match(
    field="InvoiceNumber",
    value="100017500321",
    bbox=(100.0, 200.0, 300.0, 220.0),
    page_no=0,
    score=0.95,
    matched_text="100017500321",
    context_keywords=["fakturanr"]
)

# Convert to YOLO format
yolo_annotation = match.to_yolo_format(
    image_width=1200,
    image_height=1600,
    class_id=0
)
# "0 0.166667 0.131250 0.166667 0.012500"
```

#### TokenIndex - spatial index
```python
from src.matcher.token_index import TokenIndex

# Build the index
index = TokenIndex(tokens, grid_size=100.0)

# Fast lookup of nearby tokens (O(1) on average)
nearby = index.find_nearby(token, radius=200.0)

# Get cached center coordinates
center = index.get_center(token)

# Get cached lowercased text
text_lower = index.get_text_lower(token)
```

### Context Keywords

```python
from src.matcher.context import CONTEXT_KEYWORDS, find_context_keywords

# Inspect a field's context keywords
keywords = CONTEXT_KEYWORDS["InvoiceNumber"]
# ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', ...]

# Find keywords near a token
found_keywords, boost_score = find_context_keywords(
    tokens=tokens,
    target_token=token,
    field_name="InvoiceNumber",
    context_radius=200.0,
    token_index=index  # optional; enables O(1) lookup when provided
)
```

Supported fields:
- InvoiceNumber
- InvoiceDate
- InvoiceDueDate
- OCR
- Bankgiro
- Plusgiro
- Amount
- supplier_organisation_number
- supplier_accounts

### Utility Functions

```python
from src.matcher.utils import (
    normalize_dashes,
    parse_amount,
    tokens_on_same_line,
    bbox_overlap,
    DATE_PATTERN,
    WHITESPACE_PATTERN,
    NON_DIGIT_PATTERN,
    DASH_PATTERN,
)

# Normalize the various dash characters
text = normalize_dashes("123–456")  # "123-456"

# Parse Swedish amount formats
amount = parse_amount("1 234,56 kr")  # 1234.56
amount = parse_amount("239 00")       # 239.00 (öre format)

# Check whether two tokens are on the same line
same_line = tokens_on_same_line(token1, token2)

# Compute bbox overlap (IoU)
overlap = bbox_overlap(bbox1, bbox2)  # 0.0 - 1.0
```

## 🧪 Testing

```bash
# Run inside WSL
conda activate invoice-py311

# Run all matcher tests
pytest tests/matcher/ -v

# Run the tests for a specific strategy
pytest tests/matcher/strategies/test_exact_matcher.py -v

# Check coverage
pytest tests/matcher/ --cov=src/matcher --cov-report=html
```

Test coverage:
- ✅ All 77 tests pass
- ✅ TokenIndex spatial index
- ✅ The 5 matching strategies
- ✅ Context keywords
- ✅ Utility functions
- ✅ Deduplication logic

## 📊 Refactoring Results

| Metric | Before | After | Improvement |
|--------|--------|-------|-------------|
| field_matcher.py | 876 lines | 205 lines | ↓ 76% |
| Number of modules | 1 | 11 | Clearer structure |
| Largest file | 876 lines | 154 lines | Easier to read |
| Test pass rate | - | 100% | ✅ |

## 🚀 Usage Examples

### End-to-end workflow

```python
from src.matcher import FieldMatcher, find_field_matches

# 1. Extract PDF tokens (using the PDF module)
from src.pdf import PDFExtractor
extractor = PDFExtractor("invoice.pdf")
tokens = extractor.extract_tokens()

# 2. Prepare the field values (from CSV or a database)
field_values = {
    "InvoiceNumber": "100017500321",
    "InvoiceDate": "2026-01-09",
    "Amount": "1234.56",
}

# 3. Find matches for all fields
results = find_field_matches(tokens, field_values, page_no=0)

# 4. Use the results
for field_name, matches in results.items():
    if matches:
        best_match = matches[0]  # already sorted by score, descending
        print(f"{field_name}: {best_match.value} @ {best_match.bbox}")
        print(f"  Score: {best_match.score:.2f}")
        print(f"  Context: {best_match.context_keywords}")
```

### Adding a custom strategy

```python
from src.matcher.strategies.base import BaseMatchStrategy
from src.matcher.models import Match

class CustomMatcher(BaseMatchStrategy):
    """A custom matching strategy."""

    def find_matches(self, tokens, value, field_name, token_index=None):
        matches = []
        # Implement your matching logic here
        for token in tokens:
            if self._custom_match_logic(token.text, value):
                match = Match(
                    field=field_name,
                    value=value,
                    bbox=token.bbox,
                    page_no=token.page_no,
                    score=0.85,
                    matched_text=token.text,
                    context_keywords=[]
                )
                matches.append(match)
        return matches

    def _custom_match_logic(self, token_text, value):
        # Your matching logic
        return True

# Use it with FieldMatcher
from src.matcher import FieldMatcher
matcher = FieldMatcher()
matcher.custom_matcher = CustomMatcher()
```

## 🔧 Maintenance Guide

### Adding new context keywords

Edit [src/matcher/context.py](context.py):

```python
CONTEXT_KEYWORDS = {
    'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'new-keyword'],
    # ...
}
```

### Tuning match scores

Edit the corresponding strategy file:
- [exact_matcher.py](strategies/exact_matcher.py) - exact-match scores
- [fuzzy_matcher.py](strategies/fuzzy_matcher.py) - fuzzy-match tolerance
- [flexible_date_matcher.py](strategies/flexible_date_matcher.py) - date-distance scores

### Performance Tuning

1. **TokenIndex grid size**: defaults to 100 px; adjust to suit your documents
2. **Context radius**: defaults to 200 px; adjust for the scan DPI (see the sketch below)
3. **Deduplication grid**: defaults to 50 px; affects the performance of bbox-overlap detection
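
A sketch of where the first two knobs are set, assuming `tokens` has already been extracted as in the examples above (the deduplication grid appears to be internal to the deduplication step rather than a constructor argument):

```python
from src.matcher import FieldMatcher
from src.matcher.token_index import TokenIndex

# Wider context radius for low-DPI scans where labels sit further from values
matcher = FieldMatcher(context_radius=300.0, min_score_threshold=0.5)

# Coarser spatial grid for sparse documents (fewer, larger cells)
index = TokenIndex(tokens, grid_size=150.0)
nearby = index.find_nearby(tokens[0], radius=300.0)
```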

## 📚 Related Documentation

- [PDF module](../pdf/README.md) - token extraction
- [Normalize module](../normalize/README.md) - field-value normalization
- [YOLO module](../yolo/README.md) - annotation generation

## ✅ Summary

This modular matcher system provides:
- **Clear separation of concerns**: each strategy focuses on a single matching method
- **Easy to test**: each component can be tested in isolation
- **High performance**: O(1) spatial index and smart deduplication
- **Extensible**: new strategies are easy to add
- **Fully tested**: 77 tests, 100% passing

src/matcher/__init__.py
@@ -1,3 +1,4 @@
-from .field_matcher import FieldMatcher, Match, find_field_matches
+from .field_matcher import FieldMatcher, find_field_matches
+from .models import Match, TokenLike
 
-__all__ = ['FieldMatcher', 'Match', 'find_field_matches']
+__all__ = ['FieldMatcher', 'Match', 'TokenLike', 'find_field_matches']

src/matcher/context.py (new file, 92 lines)
@@ -0,0 +1,92 @@
"""
Context keywords for field matching.
"""

from .models import TokenLike
from .token_index import TokenIndex


# Context keywords for each field type (Swedish invoice terms)
CONTEXT_KEYWORDS = {
    'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
    'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
    'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
                       'förfallodag', 'oss tillhanda senast', 'senast'],
    'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
    'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
    'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
    'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
    'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
                                     'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
    'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
}


def find_context_keywords(
    tokens: list[TokenLike],
    target_token: TokenLike,
    field_name: str,
    context_radius: float,
    token_index: TokenIndex | None = None
) -> tuple[list[str], float]:
    """
    Find context keywords near the target token.

    Uses spatial index for O(1) average lookup instead of O(n) scan.

    Args:
        tokens: List of all tokens
        target_token: The token to find context for
        field_name: Name of the field
        context_radius: Search radius in pixels
        token_index: Optional spatial index for efficient lookup

    Returns:
        Tuple of (found_keywords, boost_score)
    """
    keywords = CONTEXT_KEYWORDS.get(field_name, [])
    if not keywords:
        return [], 0.0

    found_keywords = []

    # Use spatial index for efficient nearby token lookup
    if token_index:
        nearby_tokens = token_index.find_nearby(target_token, context_radius)
        for token in nearby_tokens:
            # Use cached lowercase text
            token_lower = token_index.get_text_lower(token)
            for keyword in keywords:
                if keyword in token_lower:
                    found_keywords.append(keyword)
    else:
        # Fallback to O(n) scan if no index available
        target_center = (
            (target_token.bbox[0] + target_token.bbox[2]) / 2,
            (target_token.bbox[1] + target_token.bbox[3]) / 2
        )

        for token in tokens:
            if token is target_token:
                continue

            token_center = (
                (token.bbox[0] + token.bbox[2]) / 2,
                (token.bbox[1] + token.bbox[3]) / 2
            )

            distance = (
                (target_center[0] - token_center[0]) ** 2 +
                (target_center[1] - token_center[1]) ** 2
            ) ** 0.5

            if distance <= context_radius:
                token_lower = token.text.lower()
                for keyword in keywords:
                    if keyword in token_lower:
                        found_keywords.append(keyword)

    # Calculate boost based on keywords found
    # Increased boost to better differentiate matches with/without context
    boost = min(0.25, len(found_keywords) * 0.10)
    return found_keywords, boost

src/matcher/field_matcher.py
@@ -1,158 +1,19 @@
 """
-Field Matching Module
+Field Matching Module - Refactored
 
 Matches normalized field values to tokens extracted from documents.
 """
 
-from dataclasses import dataclass, field
-from typing import Protocol
-import re
-from functools import cached_property
-
-
-# Pre-compiled regex patterns (module-level for efficiency)
-_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
-_WHITESPACE_PATTERN = re.compile(r'\s+')
-_NON_DIGIT_PATTERN = re.compile(r'\D')
-_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]')  # en-dash, em-dash, minus sign, middle dot
-
-
-def _normalize_dashes(text: str) -> str:
-    """Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
-    return _DASH_PATTERN.sub('-', text)
-
-
-class TokenLike(Protocol):
-    """Protocol for token objects."""
-    text: str
-    bbox: tuple[float, float, float, float]
-    page_no: int
-
-
-class TokenIndex:
-    """
-    Spatial index for tokens to enable fast nearby token lookup.
-
-    Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
-    """
-
-    def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
-        """
-        Build spatial index from tokens.
-
-        Args:
-            tokens: List of tokens to index
-            grid_size: Size of grid cells in pixels
-        """
-        self.tokens = tokens
-        self.grid_size = grid_size
-        self._grid: dict[tuple[int, int], list[TokenLike]] = {}
-        self._token_centers: dict[int, tuple[float, float]] = {}
-        self._token_text_lower: dict[int, str] = {}
-
-        # Build index
-        for i, token in enumerate(tokens):
-            # Cache center coordinates
-            center_x = (token.bbox[0] + token.bbox[2]) / 2
-            center_y = (token.bbox[1] + token.bbox[3]) / 2
-            self._token_centers[id(token)] = (center_x, center_y)
-
-            # Cache lowercased text
-            self._token_text_lower[id(token)] = token.text.lower()
-
-            # Add to grid cell
-            grid_x = int(center_x / grid_size)
-            grid_y = int(center_y / grid_size)
-            key = (grid_x, grid_y)
-            if key not in self._grid:
-                self._grid[key] = []
-            self._grid[key].append(token)
-
-    def get_center(self, token: TokenLike) -> tuple[float, float]:
-        """Get cached center coordinates for token."""
-        return self._token_centers.get(id(token), (
-            (token.bbox[0] + token.bbox[2]) / 2,
-            (token.bbox[1] + token.bbox[3]) / 2
-        ))
-
-    def get_text_lower(self, token: TokenLike) -> str:
-        """Get cached lowercased text for token."""
-        return self._token_text_lower.get(id(token), token.text.lower())
-
-    def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
-        """
-        Find all tokens within radius of the given token.
-
-        Uses grid-based lookup for O(1) average case instead of O(n).
-        """
-        center = self.get_center(token)
-        center_x, center_y = center
-
-        # Determine which grid cells to search
-        cells_to_check = int(radius / self.grid_size) + 1
-        grid_x = int(center_x / self.grid_size)
-        grid_y = int(center_y / self.grid_size)
-
-        nearby = []
-        radius_sq = radius * radius
-
-        # Check all nearby grid cells
-        for dx in range(-cells_to_check, cells_to_check + 1):
-            for dy in range(-cells_to_check, cells_to_check + 1):
-                key = (grid_x + dx, grid_y + dy)
-                if key not in self._grid:
-                    continue
-
-                for other in self._grid[key]:
-                    if other is token:
-                        continue
-
-                    other_center = self.get_center(other)
-                    dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
-
-                    if dist_sq <= radius_sq:
-                        nearby.append(other)
-
-        return nearby
-
-
-@dataclass
-class Match:
-    """Represents a matched field in the document."""
-    field: str
-    value: str
-    bbox: tuple[float, float, float, float]  # (x0, y0, x1, y1)
-    page_no: int
-    score: float  # 0-1 confidence score
-    matched_text: str  # Actual text that matched
-    context_keywords: list[str]  # Nearby keywords that boosted confidence
-
-    def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
-        """Convert to YOLO annotation format."""
-        x0, y0, x1, y1 = self.bbox
-
-        x_center = (x0 + x1) / 2 / image_width
-        y_center = (y0 + y1) / 2 / image_height
-        width = (x1 - x0) / image_width
-        height = (y1 - y0) / image_height
-
-        return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
-
-
-# Context keywords for each field type (Swedish invoice terms)
-CONTEXT_KEYWORDS = {
-    'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
-    'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
-    'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
-                       'förfallodag', 'oss tillhanda senast', 'senast'],
-    'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
-    'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
-    'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
-    'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
-    'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
-                                     'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
-    'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
-}
+from .models import TokenLike, Match
+from .token_index import TokenIndex
+from .utils import bbox_overlap
+from .strategies import (
+    ExactMatcher,
+    ConcatenatedMatcher,
+    SubstringMatcher,
+    FuzzyMatcher,
+    FlexibleDateMatcher,
+)
 
 
 class FieldMatcher:
@@ -175,6 +36,13 @@ class FieldMatcher:
         self.min_score_threshold = min_score_threshold
         self._token_index: TokenIndex | None = None
 
+        # Initialize matching strategies
+        self.exact_matcher = ExactMatcher(context_radius)
+        self.concatenated_matcher = ConcatenatedMatcher(context_radius)
+        self.substring_matcher = SubstringMatcher(context_radius)
+        self.fuzzy_matcher = FuzzyMatcher(context_radius)
+        self.flexible_date_matcher = FlexibleDateMatcher(context_radius)
+
     def find_matches(
         self,
         tokens: list[TokenLike],
@@ -208,34 +76,46 @@
 
         for value in normalized_values:
             # Strategy 1: Exact token match
-            exact_matches = self._find_exact_matches(page_tokens, value, field_name)
+            exact_matches = self.exact_matcher.find_matches(
+                page_tokens, value, field_name, self._token_index
+            )
             matches.extend(exact_matches)
 
             # Strategy 2: Multi-token concatenation
-            concat_matches = self._find_concatenated_matches(page_tokens, value, field_name)
+            concat_matches = self.concatenated_matcher.find_matches(
+                page_tokens, value, field_name, self._token_index
+            )
             matches.extend(concat_matches)
 
             # Strategy 3: Fuzzy match (for amounts and dates only)
             if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
-                fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name)
+                fuzzy_matches = self.fuzzy_matcher.find_matches(
+                    page_tokens, value, field_name, self._token_index
+                )
                 matches.extend(fuzzy_matches)
 
             # Strategy 4: Substring match (for values embedded in longer text)
             # e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
             # Note: Amount is excluded because short numbers like "451" can incorrectly match
             # in OCR payment lines or other unrelated text
-            if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
-                              'supplier_organisation_number', 'supplier_accounts', 'customer_number'):
-                substring_matches = self._find_substring_matches(page_tokens, value, field_name)
+            if field_name in (
+                'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
+                'Bankgiro', 'Plusgiro', 'supplier_organisation_number',
+                'supplier_accounts', 'customer_number'
+            ):
+                substring_matches = self.substring_matcher.find_matches(
+                    page_tokens, value, field_name, self._token_index
+                )
                 matches.extend(substring_matches)
 
         # Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
         # Only if no exact matches found for date fields
         if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
-            flexible_matches = self._find_flexible_date_matches(
-                page_tokens, normalized_values, field_name
-            )
-            matches.extend(flexible_matches)
+            for value in normalized_values:
+                flexible_matches = self.flexible_date_matcher.find_matches(
+                    page_tokens, value, field_name, self._token_index
+                )
+                matches.extend(flexible_matches)
 
         # Deduplicate and sort by score
         matches = self._deduplicate_matches(matches)
@@ -246,521 +126,6 @@ class FieldMatcher:
|
||||
|
||||
return [m for m in matches if m.score >= self.min_score_threshold]
|
||||
|
||||
def _find_exact_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str
|
||||
) -> list[Match]:
|
||||
"""Find tokens that exactly match the value."""
|
||||
matches = []
|
||||
value_lower = value.lower()
|
||||
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
|
||||
'supplier_organisation_number', 'supplier_accounts') else None
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
|
||||
# Exact match
|
||||
if token_text == value:
|
||||
score = 1.0
|
||||
# Case-insensitive match (use cached lowercase from index)
|
||||
elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
|
||||
score = 0.95
|
||||
# Digits-only match for numeric fields
|
||||
elif value_digits is not None:
|
||||
token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
|
||||
if token_digits and token_digits == value_digits:
|
||||
score = 0.9
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
# Boost score if context keywords are nearby
|
||||
context_keywords, context_boost = self._find_context_keywords(
|
||||
tokens, token, field_name
|
||||
)
|
||||
score = min(1.0, score + context_boost)
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=token.bbox,
|
||||
page_no=token.page_no,
|
||||
score=score,
|
||||
matched_text=token_text,
|
||||
context_keywords=context_keywords
|
||||
))
|
||||
|
||||
return matches
|
||||
|
||||
def _find_concatenated_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str
|
||||
) -> list[Match]:
|
||||
"""Find value by concatenating adjacent tokens."""
|
||||
matches = []
|
||||
value_clean = _WHITESPACE_PATTERN.sub('', value)
|
||||
|
||||
# Sort tokens by position (top-to-bottom, left-to-right)
|
||||
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
|
||||
|
||||
for i, start_token in enumerate(sorted_tokens):
|
||||
# Try to build the value by concatenating nearby tokens
|
||||
concat_text = start_token.text.strip()
|
||||
concat_bbox = list(start_token.bbox)
|
||||
used_tokens = [start_token]
|
||||
|
||||
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
|
||||
next_token = sorted_tokens[j]
|
||||
|
||||
# Check if tokens are on the same line (y overlap)
|
||||
if not self._tokens_on_same_line(start_token, next_token):
|
||||
break
|
||||
|
||||
# Check horizontal proximity
|
||||
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
|
||||
break
|
||||
|
||||
concat_text += next_token.text.strip()
|
||||
used_tokens.append(next_token)
|
||||
|
||||
# Update bounding box
|
||||
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
|
||||
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
|
||||
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
|
||||
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
|
||||
|
||||
# Check for match
|
||||
concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
|
||||
if concat_clean == value_clean:
|
||||
context_keywords, context_boost = self._find_context_keywords(
|
||||
tokens, start_token, field_name
|
||||
)
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=tuple(concat_bbox),
|
||||
page_no=start_token.page_no,
|
||||
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
|
||||
matched_text=concat_text,
|
||||
context_keywords=context_keywords
|
||||
))
|
||||
break
|
||||
|
||||
return matches
|
||||
|
||||
def _find_substring_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str
|
||||
) -> list[Match]:
|
||||
"""
|
||||
Find value as a substring within longer tokens.
|
||||
|
||||
Handles cases like:
|
||||
- 'Fakturadatum: 2026-01-09' where the date is embedded
|
||||
- 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
|
||||
- 'OCR: 1234567890' where reference number is embedded
|
||||
|
||||
Uses lower score (0.75-0.85) than exact match to prefer exact matches.
|
||||
Only matches if the value appears as a distinct segment (not part of a larger number).
|
||||
"""
|
||||
matches = []
|
||||
|
||||
# Supported fields for substring matching
|
||||
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
|
||||
'supplier_organisation_number', 'supplier_accounts', 'customer_number')
|
||||
if field_name not in supported_fields:
|
||||
return matches
|
||||
|
||||
# Fields where spaces/dashes should be ignored during matching
|
||||
# (e.g., org number "55 65 74-6624" should match "5565746624")
|
||||
ignore_spaces_fields = ('supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts')
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
# Normalize different dash types to hyphen-minus for matching
|
||||
token_text_normalized = _normalize_dashes(token_text)
|
||||
|
||||
# For certain fields, also try matching with spaces/dashes removed
|
||||
if field_name in ignore_spaces_fields:
|
||||
token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
|
||||
value_compact = value.replace(' ', '').replace('-', '')
|
||||
else:
|
||||
token_text_compact = None
|
||||
value_compact = None
|
||||
|
||||
# Skip if token is the same length as value (would be exact match)
|
||||
if len(token_text_normalized) <= len(value):
|
||||
continue
|
||||
|
||||
# Check if value appears as substring (using normalized text)
|
||||
# Try case-sensitive first, then case-insensitive
|
||||
idx = None
|
||||
case_sensitive_match = True
|
||||
used_compact = False
|
||||
|
||||
if value in token_text_normalized:
|
||||
idx = token_text_normalized.find(value)
|
||||
elif value.lower() in token_text_normalized.lower():
|
||||
idx = token_text_normalized.lower().find(value.lower())
|
||||
case_sensitive_match = False
|
||||
elif token_text_compact and value_compact in token_text_compact:
|
||||
# Try compact matching (spaces/dashes removed)
|
||||
idx = token_text_compact.find(value_compact)
|
||||
used_compact = True
|
||||
elif token_text_compact and value_compact.lower() in token_text_compact.lower():
|
||||
idx = token_text_compact.lower().find(value_compact.lower())
|
||||
case_sensitive_match = False
|
||||
used_compact = True
|
||||
|
||||
if idx is None:
|
||||
continue
|
||||
|
||||
# For compact matching, boundary check is simpler (just check it's 10 consecutive digits)
|
||||
if used_compact:
|
||||
# Verify proper boundary in compact text
|
||||
if idx > 0 and token_text_compact[idx - 1].isdigit():
|
||||
continue
|
||||
end_idx = idx + len(value_compact)
|
||||
if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
|
||||
continue
|
||||
else:
|
||||
# Verify it's a proper boundary match (not part of a larger number)
|
||||
# Check character before (if exists)
|
||||
if idx > 0:
|
||||
char_before = token_text_normalized[idx - 1]
|
||||
# Must be non-digit (allow : space - etc)
|
||||
if char_before.isdigit():
|
||||
continue
|
||||
|
||||
# Check character after (if exists)
|
||||
end_idx = idx + len(value)
|
||||
if end_idx < len(token_text_normalized):
|
||||
char_after = token_text_normalized[end_idx]
|
||||
# Must be non-digit
|
||||
if char_after.isdigit():
|
||||
continue
|
||||
|
||||
# Found valid substring match
|
||||
context_keywords, context_boost = self._find_context_keywords(
|
||||
tokens, token, field_name
|
||||
)
|
||||
|
||||
# Check if context keyword is in the same token (like "Fakturadatum:")
|
||||
token_lower = token_text.lower()
|
||||
inline_context = []
|
||||
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
|
||||
if keyword in token_lower:
|
||||
inline_context.append(keyword)
|
||||
|
||||
# Boost score if keyword is inline
|
||||
inline_boost = 0.1 if inline_context else 0
|
||||
|
||||
# Lower score for case-insensitive match
|
||||
base_score = 0.75 if case_sensitive_match else 0.70
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=token.bbox, # Use full token bbox
|
||||
page_no=token.page_no,
|
||||
score=min(1.0, base_score + context_boost + inline_boost),
|
||||
matched_text=token_text,
|
||||
context_keywords=context_keywords + inline_context
|
||||
))
|
||||
|
||||
return matches
|
||||
|
||||
def _find_fuzzy_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str
|
||||
) -> list[Match]:
|
||||
"""Find approximate matches for amounts and dates."""
|
||||
matches = []
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
|
||||
if field_name == 'Amount':
|
||||
# Try to parse both as numbers
|
||||
try:
|
||||
token_num = self._parse_amount(token_text)
|
||||
value_num = self._parse_amount(value)
|
||||
|
||||
if token_num is not None and value_num is not None:
|
||||
if abs(token_num - value_num) < 0.01: # Within 1 cent
|
||||
context_keywords, context_boost = self._find_context_keywords(
|
||||
tokens, token, field_name
|
||||
)
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=token.bbox,
|
||||
page_no=token.page_no,
|
||||
score=min(1.0, 0.8 + context_boost),
|
||||
matched_text=token_text,
|
||||
context_keywords=context_keywords
|
||||
))
|
||||
except:
|
||||
pass
|
||||
|
||||
return matches
|
||||
|
||||
def _find_flexible_date_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
normalized_values: list[str],
|
||||
field_name: str
|
||||
) -> list[Match]:
|
||||
"""
|
||||
Flexible date matching when exact match fails.
|
||||
|
||||
Strategies:
|
||||
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
|
||||
2. Nearby date match: Match dates within 7 days of CSV value
|
||||
3. Heuristic selection: Use context keywords to select the best date
|
||||
|
||||
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
|
||||
but we can still find a reasonable date to label.
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
matches = []
|
||||
|
||||
# Parse the target date from normalized values
|
||||
target_date = None
|
||||
for value in normalized_values:
|
||||
# Try to parse YYYY-MM-DD format
|
||||
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
|
||||
if date_match:
|
||||
try:
|
||||
target_date = datetime(
|
||||
int(date_match.group(1)),
|
||||
int(date_match.group(2)),
|
||||
int(date_match.group(3))
|
||||
)
|
||||
break
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not target_date:
|
||||
return matches
|
||||
|
||||
# Find all date-like tokens in the document
|
||||
date_candidates = []
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
|
||||
# Search for date pattern in token (use pre-compiled pattern)
|
||||
for match in _DATE_PATTERN.finditer(token_text):
|
||||
try:
|
||||
found_date = datetime(
|
||||
int(match.group(1)),
|
||||
int(match.group(2)),
|
||||
int(match.group(3))
|
||||
)
|
||||
date_str = match.group(0)
|
||||
|
||||
# Calculate date difference
|
||||
days_diff = abs((found_date - target_date).days)
|
||||
|
||||
# Check for context keywords
|
||||
context_keywords, context_boost = self._find_context_keywords(
|
||||
tokens, token, field_name
|
||||
)
|
||||
|
||||
# Check if keyword is in the same token
|
||||
token_lower = token_text.lower()
|
||||
inline_keywords = []
|
||||
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
|
||||
if keyword in token_lower:
|
||||
inline_keywords.append(keyword)
|
||||
|
||||
date_candidates.append({
|
||||
'token': token,
|
||||
'date': found_date,
|
||||
'date_str': date_str,
|
||||
'matched_text': token_text,
|
||||
'days_diff': days_diff,
|
||||
'context_keywords': context_keywords + inline_keywords,
|
||||
'context_boost': context_boost + (0.1 if inline_keywords else 0),
|
||||
'same_year_month': (found_date.year == target_date.year and
|
||||
found_date.month == target_date.month),
|
||||
})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not date_candidates:
|
||||
return matches
|
||||
|
||||
# Score and rank candidates
|
||||
for candidate in date_candidates:
|
||||
score = 0.0
|
||||
|
||||
# Strategy 1: Same year-month gets higher score
|
||||
if candidate['same_year_month']:
|
||||
score = 0.7
|
||||
# Bonus if day is close
|
||||
if candidate['days_diff'] <= 7:
|
||||
score = 0.75
|
||||
if candidate['days_diff'] <= 3:
|
||||
score = 0.8
|
||||
# Strategy 2: Nearby dates (within 14 days)
|
||||
elif candidate['days_diff'] <= 14:
|
||||
score = 0.6
|
||||
elif candidate['days_diff'] <= 30:
|
||||
score = 0.55
|
||||
else:
|
||||
# Too far apart, skip unless has strong context
|
||||
if not candidate['context_keywords']:
|
||||
continue
|
||||
score = 0.5
|
||||
|
||||
# Strategy 3: Boost with context keywords
|
||||
score = min(1.0, score + candidate['context_boost'])
|
||||
|
||||
# For InvoiceDate, prefer dates that appear near invoice-related keywords
|
||||
# For InvoiceDueDate, prefer dates near due-date keywords
|
||||
if candidate['context_keywords']:
|
||||
score = min(1.0, score + 0.05)
|
||||
|
||||
if score >= self.min_score_threshold:
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=candidate['date_str'],
|
||||
bbox=candidate['token'].bbox,
|
||||
page_no=candidate['token'].page_no,
|
||||
score=score,
|
||||
matched_text=candidate['matched_text'],
|
||||
context_keywords=candidate['context_keywords']
|
||||
))
|
||||
|
||||
# Sort by score and return best matches
|
||||
matches.sort(key=lambda m: m.score, reverse=True)
|
||||
|
||||
# Only return the best match to avoid multiple labels for same field
|
||||
return matches[:1] if matches else []
|
||||
|
||||
def _find_context_keywords(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
target_token: TokenLike,
|
||||
field_name: str
|
||||
) -> tuple[list[str], float]:
|
||||
"""
|
||||
Find context keywords near the target token.
|
||||
|
||||
Uses spatial index for O(1) average lookup instead of O(n) scan.
|
||||
"""
|
||||
keywords = CONTEXT_KEYWORDS.get(field_name, [])
|
||||
if not keywords:
|
||||
return [], 0.0
|
||||
|
||||
found_keywords = []
|
||||
|
||||
# Use spatial index for efficient nearby token lookup
|
||||
if self._token_index:
|
||||
nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
|
||||
for token in nearby_tokens:
|
||||
# Use cached lowercase text
|
||||
token_lower = self._token_index.get_text_lower(token)
|
||||
for keyword in keywords:
|
||||
if keyword in token_lower:
|
||||
found_keywords.append(keyword)
|
||||
else:
|
||||
# Fallback to O(n) scan if no index available
|
||||
target_center = (
|
||||
(target_token.bbox[0] + target_token.bbox[2]) / 2,
|
||||
(target_token.bbox[1] + target_token.bbox[3]) / 2
|
||||
)
|
||||
|
||||
for token in tokens:
|
||||
if token is target_token:
|
||||
continue
|
||||
|
||||
token_center = (
|
||||
(token.bbox[0] + token.bbox[2]) / 2,
|
||||
(token.bbox[1] + token.bbox[3]) / 2
|
||||
)
|
||||
|
||||
distance = (
|
||||
(target_center[0] - token_center[0]) ** 2 +
|
||||
(target_center[1] - token_center[1]) ** 2
|
||||
) ** 0.5
|
||||
|
||||
if distance <= self.context_radius:
|
||||
token_lower = token.text.lower()
|
||||
for keyword in keywords:
|
||||
if keyword in token_lower:
|
||||
found_keywords.append(keyword)
|
||||
|
||||
# Calculate boost based on keywords found
|
||||
# Increased boost to better differentiate matches with/without context
|
||||
boost = min(0.25, len(found_keywords) * 0.10)
|
||||
return found_keywords, boost
|
||||
|
||||
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
|
||||
"""Check if two tokens are on the same line."""
|
||||
# Check vertical overlap
|
||||
y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
|
||||
min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
|
||||
return y_overlap > min_height * 0.5
|
||||
|
||||
def _parse_amount(self, text: str | int | float) -> float | None:
|
||||
"""Try to parse text as a monetary amount."""
|
||||
# Convert to string first
|
||||
text = str(text)
|
||||
|
||||
# First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
|
||||
# Pattern: digits + space + exactly 2 digits at end
|
||||
ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
|
||||
if ore_match:
|
||||
kronor = ore_match.group(1)
|
||||
ore = ore_match.group(2)
|
||||
try:
|
||||
return float(f"{kronor}.{ore}")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Remove everything after and including parentheses (e.g., "(inkl. moms)")
|
||||
text = re.sub(r'\s*\(.*\)', '', text)
|
||||
|
||||
# Remove currency symbols and common suffixes (including trailing dots from "kr.")
|
||||
text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
|
||||
text = re.sub(r'[:-]', '', text)
|
||||
|
||||
# Remove spaces (thousand separators) but be careful with öre format
|
||||
text = text.replace(' ', '').replace('\xa0', '')
|
||||
|
||||
# Handle comma as decimal separator
|
||||
# Swedish format: "500,00" means 500.00
|
||||
# Need to handle cases like "500,00." (after removing "kr.")
|
||||
if ',' in text:
|
||||
# Remove any trailing dots first (from "kr." removal)
|
||||
text = text.rstrip('.')
|
||||
# Now replace comma with dot
|
||||
if '.' not in text:
|
||||
text = text.replace(',', '.')
|
||||
|
||||
# Remove any remaining non-numeric characters except dot
|
||||
text = re.sub(r'[^\d.]', '', text)
|
||||
|
||||
try:
|
||||
return float(text)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
|
||||
"""
|
||||
Remove duplicate matches based on bbox overlap.
|
||||
@@ -803,7 +168,7 @@ class FieldMatcher:
             for cell in cells_to_check:
                 if cell in grid:
                     for existing in grid[cell]:
-                        if self._bbox_overlap(bbox, existing.bbox) > 0.7:
+                        if bbox_overlap(bbox, existing.bbox) > 0.7:
                             is_duplicate = True
                             break
                 if is_duplicate:
@@ -821,27 +186,6 @@ class FieldMatcher:
 
         return unique
 
-    def _bbox_overlap(
-        self,
-        bbox1: tuple[float, float, float, float],
-        bbox2: tuple[float, float, float, float]
-    ) -> float:
-        """Calculate IoU (Intersection over Union) of two bounding boxes."""
-        x1 = max(bbox1[0], bbox2[0])
-        y1 = max(bbox1[1], bbox2[1])
-        x2 = min(bbox1[2], bbox2[2])
-        y2 = min(bbox1[3], bbox2[3])
-
-        if x2 <= x1 or y2 <= y1:
-            return 0.0
-
-        intersection = float(x2 - x1) * float(y2 - y1)
-        area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
-        area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
-        union = area1 + area2 - intersection
-
-        return intersection / union if union > 0 else 0.0
-
 
 def find_field_matches(
     tokens: list[TokenLike],

src/matcher/field_matcher_old.py (new file, 875 lines)
@@ -0,0 +1,875 @@
|
||||
"""
|
||||
Field Matching Module
|
||||
|
||||
Matches normalized field values to tokens extracted from documents.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Protocol
|
||||
import re
|
||||
from functools import cached_property
|
||||
|
||||
|
||||
# Pre-compiled regex patterns (module-level for efficiency)
|
||||
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
|
||||
_WHITESPACE_PATTERN = re.compile(r'\s+')
|
||||
_NON_DIGIT_PATTERN = re.compile(r'\D')
|
||||
_DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]') # en-dash, em-dash, minus sign, middle dot
|
||||
|
||||
|
||||
def _normalize_dashes(text: str) -> str:
|
||||
"""Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
|
||||
return _DASH_PATTERN.sub('-', text)
|
||||
|
||||
|
||||
class TokenLike(Protocol):
|
||||
"""Protocol for token objects."""
|
||||
text: str
|
||||
bbox: tuple[float, float, float, float]
|
||||
page_no: int
|
||||
|
||||
|
||||
class TokenIndex:
|
||||
"""
|
||||
Spatial index for tokens to enable fast nearby token lookup.
|
||||
|
||||
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
|
||||
"""
|
||||
|
||||
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
|
||||
"""
|
||||
Build spatial index from tokens.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens to index
|
||||
grid_size: Size of grid cells in pixels
|
||||
"""
|
||||
self.tokens = tokens
|
||||
self.grid_size = grid_size
|
||||
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
|
||||
self._token_centers: dict[int, tuple[float, float]] = {}
|
||||
self._token_text_lower: dict[int, str] = {}
|
||||
|
||||
# Build index
|
||||
for i, token in enumerate(tokens):
|
||||
# Cache center coordinates
|
||||
center_x = (token.bbox[0] + token.bbox[2]) / 2
|
||||
center_y = (token.bbox[1] + token.bbox[3]) / 2
|
||||
self._token_centers[id(token)] = (center_x, center_y)
|
||||
|
||||
# Cache lowercased text
|
||||
self._token_text_lower[id(token)] = token.text.lower()
|
||||
|
||||
# Add to grid cell
|
||||
grid_x = int(center_x / grid_size)
|
||||
grid_y = int(center_y / grid_size)
|
||||
key = (grid_x, grid_y)
|
||||
if key not in self._grid:
|
||||
self._grid[key] = []
|
||||
self._grid[key].append(token)
|
||||
|
||||
def get_center(self, token: TokenLike) -> tuple[float, float]:
|
||||
"""Get cached center coordinates for token."""
|
||||
return self._token_centers.get(id(token), (
|
||||
(token.bbox[0] + token.bbox[2]) / 2,
|
||||
(token.bbox[1] + token.bbox[3]) / 2
|
||||
))
|
||||
|
||||
def get_text_lower(self, token: TokenLike) -> str:
|
||||
"""Get cached lowercased text for token."""
|
||||
return self._token_text_lower.get(id(token), token.text.lower())
|
||||
|
||||
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
|
||||
"""
|
||||
Find all tokens within radius of the given token.
|
||||
|
||||
Uses grid-based lookup for O(1) average case instead of O(n).
|
||||
"""
|
||||
center = self.get_center(token)
|
||||
center_x, center_y = center
|
||||
|
||||
# Determine which grid cells to search
|
||||
cells_to_check = int(radius / self.grid_size) + 1
|
||||
grid_x = int(center_x / self.grid_size)
|
||||
grid_y = int(center_y / self.grid_size)
|
||||
|
||||
nearby = []
|
||||
radius_sq = radius * radius
|
||||
|
||||
# Check all nearby grid cells
|
||||
for dx in range(-cells_to_check, cells_to_check + 1):
|
||||
for dy in range(-cells_to_check, cells_to_check + 1):
|
||||
key = (grid_x + dx, grid_y + dy)
|
||||
if key not in self._grid:
|
||||
continue
|
||||
|
||||
for other in self._grid[key]:
|
||||
if other is token:
|
||||
continue
|
||||
|
||||
other_center = self.get_center(other)
|
||||
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
|
||||
|
||||
if dist_sq <= radius_sq:
|
||||
nearby.append(other)
|
||||
|
||||
return nearby
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match:
|
||||
"""Represents a matched field in the document."""
|
||||
field: str
|
||||
value: str
|
||||
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
|
||||
page_no: int
|
||||
score: float # 0-1 confidence score
|
||||
matched_text: str # Actual text that matched
|
||||
context_keywords: list[str] # Nearby keywords that boosted confidence
|
||||
|
||||
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
|
||||
"""Convert to YOLO annotation format."""
|
||||
x0, y0, x1, y1 = self.bbox
|
||||
|
||||
x_center = (x0 + x1) / 2 / image_width
|
||||
y_center = (y0 + y1) / 2 / image_height
|
||||
width = (x1 - x0) / image_width
|
||||
height = (y1 - y0) / image_height
|
||||
|
||||
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
||||
|
||||
|
||||
# Context keywords for each field type (Swedish invoice terms)
|
||||
CONTEXT_KEYWORDS = {
|
||||
'InvoiceNumber': ['fakturanr', 'fakturanummer', 'invoice', 'inv.nr', 'inv nr', 'nr'],
|
||||
'InvoiceDate': ['fakturadatum', 'datum', 'date', 'utfärdad', 'utskriftsdatum', 'dokumentdatum'],
|
||||
'InvoiceDueDate': ['förfallodatum', 'förfaller', 'due date', 'betalas senast', 'att betala senast',
|
||||
'förfallodag', 'oss tillhanda senast', 'senast'],
|
||||
'OCR': ['ocr', 'referens', 'betalningsreferens', 'ref'],
|
||||
'Bankgiro': ['bankgiro', 'bg', 'bg-nr', 'bg nr'],
|
||||
'Plusgiro': ['plusgiro', 'pg', 'pg-nr', 'pg nr'],
|
||||
'Amount': ['att betala', 'summa', 'total', 'belopp', 'amount', 'totalt', 'att erlägga', 'sek', 'kr'],
|
||||
'supplier_organisation_number': ['organisationsnummer', 'org.nr', 'org nr', 'orgnr', 'org.nummer',
|
||||
'momsreg', 'momsnr', 'moms nr', 'vat', 'corporate id'],
|
||||
'supplier_accounts': ['konto', 'kontonr', 'konto nr', 'account', 'klientnr', 'kundnr'],
|
||||
}
|
||||
|
||||
|
||||
class FieldMatcher:
|
||||
"""Matches field values to document tokens."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
context_radius: float = 200.0, # pixels - increased to handle label-value spacing in scanned PDFs
|
||||
min_score_threshold: float = 0.5
|
||||
):
|
||||
"""
|
||||
Initialize the matcher.
|
||||
|
||||
Args:
|
||||
context_radius: Distance to search for context keywords (default 200px to handle
|
||||
typical label-value spacing in scanned invoices at 150 DPI)
|
||||
min_score_threshold: Minimum score to consider a match valid
|
||||
"""
|
||||
self.context_radius = context_radius
|
||||
self.min_score_threshold = min_score_threshold
|
||||
self._token_index: TokenIndex | None = None
|
||||
|
||||
def find_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
field_name: str,
|
||||
normalized_values: list[str],
|
||||
page_no: int = 0
|
||||
) -> list[Match]:
|
||||
"""
|
||||
Find all matches for a field in the token list.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens from the document
|
||||
field_name: Name of the field to match
|
||||
normalized_values: List of normalized value variants to search for
|
||||
page_no: Page number to filter tokens
|
||||
|
||||
Returns:
|
||||
List of Match objects sorted by score (descending)
|
||||
"""
|
||||
matches = []
|
||||
# Filter tokens by page and exclude hidden metadata tokens
|
||||
# Hidden tokens often have bbox with y < 0 or y > page_height
|
||||
# These are typically PDF metadata stored as invisible text
|
||||
page_tokens = [
|
||||
t for t in tokens
|
||||
if t.page_no == page_no and t.bbox[1] >= 0 and t.bbox[3] > t.bbox[1]
|
||||
]
|
||||
|
||||
# Build spatial index for efficient nearby token lookup (O(n) -> O(1))
|
||||
self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
|
||||
|
||||
for value in normalized_values:
|
||||
# Strategy 1: Exact token match
|
||||
exact_matches = self._find_exact_matches(page_tokens, value, field_name)
|
||||
matches.extend(exact_matches)
|
||||
|
||||
# Strategy 2: Multi-token concatenation
|
||||
concat_matches = self._find_concatenated_matches(page_tokens, value, field_name)
|
||||
matches.extend(concat_matches)
|
||||
|
||||
# Strategy 3: Fuzzy match (for amounts and dates only)
|
||||
if field_name in ('Amount', 'InvoiceDate', 'InvoiceDueDate'):
|
||||
fuzzy_matches = self._find_fuzzy_matches(page_tokens, value, field_name)
|
||||
matches.extend(fuzzy_matches)
|
||||
|
||||
# Strategy 4: Substring match (for values embedded in longer text)
|
||||
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
|
||||
# Note: Amount is excluded because short numbers like "451" can incorrectly match
|
||||
# in OCR payment lines or other unrelated text
|
||||
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
|
||||
'supplier_organisation_number', 'supplier_accounts', 'customer_number'):
|
||||
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
|
||||
matches.extend(substring_matches)
|
||||
|
||||
# Strategy 5: Flexible date matching (year-month match, nearby dates, heuristic selection)
|
||||
# Only if no exact matches found for date fields
|
||||
if field_name in ('InvoiceDate', 'InvoiceDueDate') and not matches:
|
||||
flexible_matches = self._find_flexible_date_matches(
|
||||
page_tokens, normalized_values, field_name
|
||||
)
|
||||
matches.extend(flexible_matches)
|
||||
|
||||
# Deduplicate and sort by score
|
||||
matches = self._deduplicate_matches(matches)
|
||||
matches.sort(key=lambda m: m.score, reverse=True)
|
||||
|
||||
# Clear token index to free memory
|
||||
self._token_index = None
|
||||
|
||||
return [m for m in matches if m.score >= self.min_score_threshold]
|
||||
|
||||
def _find_exact_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str
|
||||
) -> list[Match]:
|
||||
"""Find tokens that exactly match the value."""
|
||||
matches = []
|
||||
value_lower = value.lower()
|
||||
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
|
||||
'supplier_organisation_number', 'supplier_accounts') else None
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
|
||||
# Exact match
|
||||
if token_text == value:
|
||||
score = 1.0
|
||||
# Case-insensitive match (use cached lowercase from index)
|
||||
elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
|
||||
score = 0.95
|
||||
# Digits-only match for numeric fields
|
||||
elif value_digits is not None:
|
||||
token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
|
||||
if token_digits and token_digits == value_digits:
|
||||
score = 0.9
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
# Boost score if context keywords are nearby
|
||||
context_keywords, context_boost = self._find_context_keywords(
|
||||
tokens, token, field_name
|
||||
)
|
||||
score = min(1.0, score + context_boost)
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=token.bbox,
|
||||
page_no=token.page_no,
|
||||
score=score,
|
||||
matched_text=token_text,
|
||||
context_keywords=context_keywords
|
||||
))
|
||||
|
||||
return matches
|
||||
|
||||
def _find_concatenated_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str
|
||||
) -> list[Match]:
|
||||
"""Find value by concatenating adjacent tokens."""
|
||||
matches = []
|
||||
value_clean = _WHITESPACE_PATTERN.sub('', value)
|
||||
|
||||
# Sort tokens by position (top-to-bottom, left-to-right)
|
||||
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
|
||||
|
||||
for i, start_token in enumerate(sorted_tokens):
|
||||
# Try to build the value by concatenating nearby tokens
|
||||
concat_text = start_token.text.strip()
|
||||
concat_bbox = list(start_token.bbox)
|
||||
used_tokens = [start_token]
|
||||
|
||||
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
|
||||
next_token = sorted_tokens[j]
|
||||
|
||||
# Check if tokens are on the same line (y overlap)
|
||||
if not self._tokens_on_same_line(start_token, next_token):
|
||||
break
|
||||
|
||||
# Check horizontal proximity
|
||||
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
|
||||
break
|
||||
|
||||
concat_text += next_token.text.strip()
|
||||
used_tokens.append(next_token)
|
||||
|
||||
# Update bounding box
|
||||
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
|
||||
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
|
||||
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
|
||||
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
|
||||
|
||||
# Check for match
|
||||
concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
|
||||
if concat_clean == value_clean:
|
||||
context_keywords, context_boost = self._find_context_keywords(
|
||||
tokens, start_token, field_name
|
||||
)
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=tuple(concat_bbox),
|
||||
page_no=start_token.page_no,
|
||||
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
|
||||
matched_text=concat_text,
|
||||
context_keywords=context_keywords
|
||||
))
|
||||
break
|
||||
|
||||
return matches
|
||||
|
||||
    def _find_substring_matches(
            self,
            tokens: list[TokenLike],
            value: str,
            field_name: str
    ) -> list[Match]:
        """
        Find value as a substring within longer tokens.

        Handles cases like:
        - 'Fakturadatum: 2026-01-09' where the date is embedded
        - 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
        - 'OCR: 1234567890' where reference number is embedded

        Uses lower score (0.75-0.85) than exact match to prefer exact matches.
        Only matches if the value appears as a distinct segment (not part of a larger number).
        """
        matches = []

        # Supported fields for substring matching
        supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount',
                            'supplier_organisation_number', 'supplier_accounts', 'customer_number')
        if field_name not in supported_fields:
            return matches

        # Fields where spaces/dashes should be ignored during matching
        # (e.g., org number "55 65 74-6624" should match "5565746624")
        ignore_spaces_fields = ('supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts')

        for token in tokens:
            token_text = token.text.strip()
            # Normalize different dash types to hyphen-minus for matching
            token_text_normalized = _normalize_dashes(token_text)

            # For certain fields, also try matching with spaces/dashes removed
            if field_name in ignore_spaces_fields:
                token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
                value_compact = value.replace(' ', '').replace('-', '')
            else:
                token_text_compact = None
                value_compact = None

            # Skip tokens that are not longer than the value (those are handled by exact matching)
            if len(token_text_normalized) <= len(value):
                continue

            # Check if value appears as substring (using normalized text)
            # Try case-sensitive first, then case-insensitive
            idx = None
            case_sensitive_match = True
            used_compact = False

            if value in token_text_normalized:
                idx = token_text_normalized.find(value)
            elif value.lower() in token_text_normalized.lower():
                idx = token_text_normalized.lower().find(value.lower())
                case_sensitive_match = False
            elif token_text_compact and value_compact in token_text_compact:
                # Try compact matching (spaces/dashes removed)
                idx = token_text_compact.find(value_compact)
                used_compact = True
            elif token_text_compact and value_compact.lower() in token_text_compact.lower():
                idx = token_text_compact.lower().find(value_compact.lower())
                case_sensitive_match = False
                used_compact = True

            if idx is None:
                continue

            # For compact matching, the boundary check runs on the space/dash-stripped string
            if used_compact:
                # Verify proper boundary in compact text
                if idx > 0 and token_text_compact[idx - 1].isdigit():
                    continue
                end_idx = idx + len(value_compact)
                if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
                    continue
            else:
                # Verify it's a proper boundary match (not part of a larger number)
                # Check character before (if exists)
                if idx > 0:
                    char_before = token_text_normalized[idx - 1]
                    # Must be non-digit (allow : space - etc)
                    if char_before.isdigit():
                        continue

                # Check character after (if exists)
                end_idx = idx + len(value)
                if end_idx < len(token_text_normalized):
                    char_after = token_text_normalized[end_idx]
                    # Must be non-digit
                    if char_after.isdigit():
                        continue

            # Found valid substring match
            context_keywords, context_boost = self._find_context_keywords(
                tokens, token, field_name
            )

            # Check if context keyword is in the same token (like "Fakturadatum:")
            token_lower = token_text.lower()
            inline_context = []
            for keyword in CONTEXT_KEYWORDS.get(field_name, []):
                if keyword in token_lower:
                    inline_context.append(keyword)

            # Boost score if keyword is inline
            inline_boost = 0.1 if inline_context else 0

            # Lower score for case-insensitive match
            base_score = 0.75 if case_sensitive_match else 0.70

            matches.append(Match(
                field=field_name,
                value=value,
                bbox=token.bbox,  # Use full token bbox
                page_no=token.page_no,
                score=min(1.0, base_score + context_boost + inline_boost),
                matched_text=token_text,
                context_keywords=context_keywords + inline_context
            ))

        return matches

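    # Illustrative boundary behaviour of the substring strategy above: "Fakturanummer: 12345"
    # matches the value "12345" because ':' and ' ' are non-digit boundaries, while "9912345"
    # and "12345678" are rejected because a digit sits directly before or after the value;
    # "55 65 74-6624" matches "5565746624" via the compact (space/dash-stripped) path used
    # for account-style fields.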
    def _find_fuzzy_matches(
            self,
            tokens: list[TokenLike],
            value: str,
            field_name: str
    ) -> list[Match]:
        """Find approximate matches for amount fields."""
        matches = []

        for token in tokens:
            token_text = token.text.strip()

            if field_name == 'Amount':
                # Try to parse both as numbers
                try:
                    token_num = self._parse_amount(token_text)
                    value_num = self._parse_amount(value)

                    if token_num is not None and value_num is not None:
                        if abs(token_num - value_num) < 0.01:  # Within 1 cent
                            context_keywords, context_boost = self._find_context_keywords(
                                tokens, token, field_name
                            )

                            matches.append(Match(
                                field=field_name,
                                value=value,
                                bbox=token.bbox,
                                page_no=token.page_no,
                                score=min(1.0, 0.8 + context_boost),
                                matched_text=token_text,
                                context_keywords=context_keywords
                            ))
                except (ValueError, TypeError):
                    # Unparseable token or value text: skip this token rather than fail the field
                    pass

        return matches

    def _find_flexible_date_matches(
            self,
            tokens: list[TokenLike],
            normalized_values: list[str],
            field_name: str
    ) -> list[Match]:
        """
        Flexible date matching when exact match fails.

        Strategies:
        1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
        2. Nearby date match: Match dates within 7 days of CSV value
        3. Heuristic selection: Use context keywords to select the best date

        This handles cases where CSV InvoiceDate doesn't exactly match PDF,
        but we can still find a reasonable date to label.
        """
        from datetime import datetime, timedelta

        matches = []

        # Parse the target date from normalized values
        target_date = None
        for value in normalized_values:
            # Try to parse YYYY-MM-DD format
            date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
            if date_match:
                try:
                    target_date = datetime(
                        int(date_match.group(1)),
                        int(date_match.group(2)),
                        int(date_match.group(3))
                    )
                    break
                except ValueError:
                    continue

        if not target_date:
            return matches

        # Find all date-like tokens in the document
        date_candidates = []

        for token in tokens:
            token_text = token.text.strip()

            # Search for date pattern in token (use pre-compiled pattern)
            for match in _DATE_PATTERN.finditer(token_text):
                try:
                    found_date = datetime(
                        int(match.group(1)),
                        int(match.group(2)),
                        int(match.group(3))
                    )
                    date_str = match.group(0)

                    # Calculate date difference
                    days_diff = abs((found_date - target_date).days)

                    # Check for context keywords
                    context_keywords, context_boost = self._find_context_keywords(
                        tokens, token, field_name
                    )

                    # Check if keyword is in the same token
                    token_lower = token_text.lower()
                    inline_keywords = []
                    for keyword in CONTEXT_KEYWORDS.get(field_name, []):
                        if keyword in token_lower:
                            inline_keywords.append(keyword)

                    date_candidates.append({
                        'token': token,
                        'date': found_date,
                        'date_str': date_str,
                        'matched_text': token_text,
                        'days_diff': days_diff,
                        'context_keywords': context_keywords + inline_keywords,
                        'context_boost': context_boost + (0.1 if inline_keywords else 0),
                        'same_year_month': (found_date.year == target_date.year and
                                            found_date.month == target_date.month),
                    })
                except ValueError:
                    continue

        if not date_candidates:
            return matches

        # Score and rank candidates
        for candidate in date_candidates:
            score = 0.0

            # Strategy 1: Same year-month gets higher score
            if candidate['same_year_month']:
                score = 0.7
                # Bonus if day is close
                if candidate['days_diff'] <= 7:
                    score = 0.75
                if candidate['days_diff'] <= 3:
                    score = 0.8
            # Strategy 2: Nearby dates (within 14 days)
            elif candidate['days_diff'] <= 14:
                score = 0.6
            elif candidate['days_diff'] <= 30:
                score = 0.55
            else:
                # Too far apart, skip unless has strong context
                if not candidate['context_keywords']:
                    continue
                score = 0.5

            # Strategy 3: Boost with context keywords
            score = min(1.0, score + candidate['context_boost'])

            # For InvoiceDate, prefer dates that appear near invoice-related keywords
            # For InvoiceDueDate, prefer dates near due-date keywords
            if candidate['context_keywords']:
                score = min(1.0, score + 0.05)

            if score >= self.min_score_threshold:
                matches.append(Match(
                    field=field_name,
                    value=candidate['date_str'],
                    bbox=candidate['token'].bbox,
                    page_no=candidate['token'].page_no,
                    score=score,
                    matched_text=candidate['matched_text'],
                    context_keywords=candidate['context_keywords']
                ))

        # Sort by score and return best matches
        matches.sort(key=lambda m: m.score, reverse=True)

        # Only return the best match to avoid multiple labels for same field
        return matches[:1] if matches else []

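    # Illustrative scoring for the flexible date strategy above: with target 2025-01-15, a
    # token containing 2025-01-17 (same year-month, 2 days away) starts at 0.8, while a date
    # in a neighbouring month 20-30 days away starts at 0.55; any inline or nearby context
    # keyword then adds its boost on top, capped at 1.0, and only the single best candidate
    # is returned.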
    def _find_context_keywords(
            self,
            tokens: list[TokenLike],
            target_token: TokenLike,
            field_name: str
    ) -> tuple[list[str], float]:
        """
        Find context keywords near the target token.

        Uses spatial index for O(1) average lookup instead of O(n) scan.
        """
        keywords = CONTEXT_KEYWORDS.get(field_name, [])
        if not keywords:
            return [], 0.0

        found_keywords = []

        # Use spatial index for efficient nearby token lookup
        if self._token_index:
            nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
            for token in nearby_tokens:
                # Use cached lowercase text
                token_lower = self._token_index.get_text_lower(token)
                for keyword in keywords:
                    if keyword in token_lower:
                        found_keywords.append(keyword)
        else:
            # Fallback to O(n) scan if no index available
            target_center = (
                (target_token.bbox[0] + target_token.bbox[2]) / 2,
                (target_token.bbox[1] + target_token.bbox[3]) / 2
            )

            for token in tokens:
                if token is target_token:
                    continue

                token_center = (
                    (token.bbox[0] + token.bbox[2]) / 2,
                    (token.bbox[1] + token.bbox[3]) / 2
                )

                distance = (
                    (target_center[0] - token_center[0]) ** 2 +
                    (target_center[1] - token_center[1]) ** 2
                ) ** 0.5

                if distance <= self.context_radius:
                    token_lower = token.text.lower()
                    for keyword in keywords:
                        if keyword in token_lower:
                            found_keywords.append(keyword)

        # Calculate boost based on keywords found
        # Increased boost to better differentiate matches with/without context
        boost = min(0.25, len(found_keywords) * 0.10)
        return found_keywords, boost

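    # Illustrative: one nearby keyword yields a boost of 0.10, two yield 0.20, and three or
    # more are capped at 0.25 by min(0.25, len(found_keywords) * 0.10).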
    def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
        """Check if two tokens are on the same line."""
        # Check vertical overlap
        y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
        min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
        return y_overlap > min_height * 0.5

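    # Illustrative: bboxes (0, 10, 50, 30) and (60, 12, 110, 28) overlap vertically by 16px,
    # which exceeds half of the smaller height (16 * 0.5 = 8), so they count as one line.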
    def _parse_amount(self, text: str | int | float) -> float | None:
        """Try to parse text as a monetary amount."""
        # Convert to string first
        text = str(text)

        # First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
        # Pattern: digits + space + exactly 2 digits at end
        ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
        if ore_match:
            kronor = ore_match.group(1)
            ore = ore_match.group(2)
            try:
                return float(f"{kronor}.{ore}")
            except ValueError:
                pass

        # Remove everything after and including parentheses (e.g., "(inkl. moms)")
        text = re.sub(r'\s*\(.*\)', '', text)

        # Remove currency symbols and common suffixes (including trailing dots from "kr.")
        text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
        text = re.sub(r'[:-]', '', text)

        # Remove spaces (thousand separators) but be careful with öre format
        text = text.replace(' ', '').replace('\xa0', '')

        # Handle comma as decimal separator
        # Swedish format: "500,00" means 500.00
        # Need to handle cases like "500,00." (after removing "kr.")
        if ',' in text:
            # Remove any trailing dots first (from "kr." removal)
            text = text.rstrip('.')
            # Now replace comma with dot
            if '.' not in text:
                text = text.replace(',', '.')

        # Remove any remaining non-numeric characters except dot
        text = re.sub(r'[^\d.]', '', text)

        try:
            return float(text)
        except ValueError:
            return None

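    # Worked examples for the normalisation above (assuming plain string input):
    #   "239 00"    -> 239.00   (Swedish kronor/öre split)
    #   "1 234,56"  -> 1234.56  (space thousand separator, comma decimal)
    #   "100 kr"    -> 100.0    (currency suffix stripped)
    #   "abc", ""   -> None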
    def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
        """
        Remove duplicate matches based on bbox overlap.

        Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
        """
        if not matches:
            return []

        # Sort by: 1) score descending, 2) prefer matches with context keywords,
        # 3) prefer upper positions (smaller y) for same-score matches
        # This helps select the "main" occurrence in invoice body rather than footer
        matches.sort(key=lambda m: (
            -m.score,
            -len(m.context_keywords),  # More keywords = better
            m.bbox[1]  # Smaller y (upper position) = better
        ))

        # Use spatial grid for efficient overlap checking
        # Grid cell size based on typical bbox size
        grid_size = 50.0  # pixels
        grid: dict[tuple[int, int], list[Match]] = {}
        unique = []

        for match in matches:
            bbox = match.bbox
            # Calculate grid cells this bbox touches
            min_gx = int(bbox[0] / grid_size)
            min_gy = int(bbox[1] / grid_size)
            max_gx = int(bbox[2] / grid_size)
            max_gy = int(bbox[3] / grid_size)

            # Check for overlap only with matches in nearby grid cells
            is_duplicate = False
            cells_to_check = set()
            for gx in range(min_gx - 1, max_gx + 2):
                for gy in range(min_gy - 1, max_gy + 2):
                    cells_to_check.add((gx, gy))

            for cell in cells_to_check:
                if cell in grid:
                    for existing in grid[cell]:
                        if self._bbox_overlap(bbox, existing.bbox) > 0.7:
                            is_duplicate = True
                            break
                if is_duplicate:
                    break

            if not is_duplicate:
                unique.append(match)
                # Add to all grid cells this bbox touches
                for gx in range(min_gx, max_gx + 1):
                    for gy in range(min_gy, max_gy + 1):
                        key = (gx, gy)
                        if key not in grid:
                            grid[key] = []
                        grid[key].append(match)

        return unique

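    # Illustrative: with grid_size = 50.0, a match with bbox (120, 30, 180, 55) is hashed to
    # grid cells (2..3, 0..1); any later, lower-ranked match landing in a neighbouring cell is
    # compared via _bbox_overlap and dropped once the IoU exceeds 0.7.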
    def _bbox_overlap(
            self,
            bbox1: tuple[float, float, float, float],
            bbox2: tuple[float, float, float, float]
    ) -> float:
        """Calculate IoU (Intersection over Union) of two bounding boxes."""
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])

        if x2 <= x1 or y2 <= y1:
            return 0.0

        intersection = float(x2 - x1) * float(y2 - y1)
        area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
        area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0.0


def find_field_matches(
        tokens: list[TokenLike],
        field_values: dict[str, str],
        page_no: int = 0
) -> dict[str, list[Match]]:
    """
    Convenience function to find matches for multiple fields.

    Args:
        tokens: List of tokens from the document
        field_values: Dict of field_name -> value to search for
        page_no: Page number

    Returns:
        Dict of field_name -> list of matches
    """
    from ..normalize import normalize_field

    matcher = FieldMatcher()
    results = {}

    for field_name, value in field_values.items():
        if value is None or str(value).strip() == '':
            continue

        normalized_values = normalize_field(field_name, str(value))
        matches = matcher.find_matches(tokens, field_name, normalized_values, page_no)
        results[field_name] = matches

    return results
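
# Minimal usage sketch (illustrative; extract_tokens is a hypothetical helper and the field
# values are made up):
#
#     tokens = extract_tokens("invoice.pdf")   # any objects satisfying TokenLike
#     results = find_field_matches(tokens, {
#         "InvoiceNumber": "12345",
#         "Amount": "1 234,56",
#     })
#     for field, matches in results.items():
#         best = matches[0] if matches else None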
36
src/matcher/models.py
Normal file
@@ -0,0 +1,36 @@
"""
|
||||
Data models for field matching.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
|
||||
class TokenLike(Protocol):
|
||||
"""Protocol for token objects."""
|
||||
text: str
|
||||
bbox: tuple[float, float, float, float]
|
||||
page_no: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match:
|
||||
"""Represents a matched field in the document."""
|
||||
field: str
|
||||
value: str
|
||||
bbox: tuple[float, float, float, float] # (x0, y0, x1, y1)
|
||||
page_no: int
|
||||
score: float # 0-1 confidence score
|
||||
matched_text: str # Actual text that matched
|
||||
context_keywords: list[str] # Nearby keywords that boosted confidence
|
||||
|
||||
def to_yolo_format(self, image_width: float, image_height: float, class_id: int) -> str:
|
||||
"""Convert to YOLO annotation format."""
|
||||
x0, y0, x1, y1 = self.bbox
|
||||
|
||||
x_center = (x0 + x1) / 2 / image_width
|
||||
y_center = (y0 + y1) / 2 / image_height
|
||||
width = (x1 - x0) / image_width
|
||||
height = (y1 - y0) / image_height
|
||||
|
||||
return f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
|
||||
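
# Worked example (illustrative): a bbox of (100, 200, 200, 250) on a 1000x1000 image gives
# x_center = 0.15, y_center = 0.225, width = 0.1, height = 0.05, so
# to_yolo_format(1000, 1000, class_id=5) returns "5 0.150000 0.225000 0.100000 0.050000".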
17
src/matcher/strategies/__init__.py
Normal file
@@ -0,0 +1,17 @@
"""
|
||||
Matching strategies for field matching.
|
||||
"""
|
||||
|
||||
from .exact_matcher import ExactMatcher
|
||||
from .concatenated_matcher import ConcatenatedMatcher
|
||||
from .substring_matcher import SubstringMatcher
|
||||
from .fuzzy_matcher import FuzzyMatcher
|
||||
from .flexible_date_matcher import FlexibleDateMatcher
|
||||
|
||||
__all__ = [
|
||||
'ExactMatcher',
|
||||
'ConcatenatedMatcher',
|
||||
'SubstringMatcher',
|
||||
'FuzzyMatcher',
|
||||
'FlexibleDateMatcher',
|
||||
]
|
||||
42
src/matcher/strategies/base.py
Normal file
@@ -0,0 +1,42 @@
"""
|
||||
Base class for matching strategies.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from ..models import TokenLike, Match
|
||||
from ..token_index import TokenIndex
|
||||
|
||||
|
||||
class BaseMatchStrategy(ABC):
|
||||
"""Base class for all matching strategies."""
|
||||
|
||||
def __init__(self, context_radius: float = 200.0):
|
||||
"""
|
||||
Initialize the strategy.
|
||||
|
||||
Args:
|
||||
context_radius: Distance to search for context keywords
|
||||
"""
|
||||
self.context_radius = context_radius
|
||||
|
||||
@abstractmethod
|
||||
def find_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str,
|
||||
token_index: TokenIndex | None = None
|
||||
) -> list[Match]:
|
||||
"""
|
||||
Find matches for the given value.
|
||||
|
||||
Args:
|
||||
tokens: List of tokens to search
|
||||
value: Value to find
|
||||
field_name: Name of the field
|
||||
token_index: Optional spatial index for efficient lookup
|
||||
|
||||
Returns:
|
||||
List of Match objects
|
||||
"""
|
||||
pass
|
||||
73
src/matcher/strategies/concatenated_matcher.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Concatenated match strategy - finds value by concatenating adjacent tokens.
|
||||
"""
|
||||
|
||||
from .base import BaseMatchStrategy
|
||||
from ..models import TokenLike, Match
|
||||
from ..token_index import TokenIndex
|
||||
from ..context import find_context_keywords
|
||||
from ..utils import WHITESPACE_PATTERN, tokens_on_same_line
|
||||
|
||||
|
||||
class ConcatenatedMatcher(BaseMatchStrategy):
|
||||
"""Find value by concatenating adjacent tokens."""
|
||||
|
||||
def find_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str,
|
||||
token_index: TokenIndex | None = None
|
||||
) -> list[Match]:
|
||||
"""Find concatenated matches."""
|
||||
matches = []
|
||||
value_clean = WHITESPACE_PATTERN.sub('', value)
|
||||
|
||||
# Sort tokens by position (top-to-bottom, left-to-right)
|
||||
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
|
||||
|
||||
for i, start_token in enumerate(sorted_tokens):
|
||||
# Try to build the value by concatenating nearby tokens
|
||||
concat_text = start_token.text.strip()
|
||||
concat_bbox = list(start_token.bbox)
|
||||
used_tokens = [start_token]
|
||||
|
||||
for j in range(i + 1, min(i + 5, len(sorted_tokens))): # Max 5 tokens
|
||||
next_token = sorted_tokens[j]
|
||||
|
||||
# Check if tokens are on the same line (y overlap)
|
||||
if not tokens_on_same_line(start_token, next_token):
|
||||
break
|
||||
|
||||
# Check horizontal proximity
|
||||
if next_token.bbox[0] - concat_bbox[2] > 50: # Max 50px gap
|
||||
break
|
||||
|
||||
concat_text += next_token.text.strip()
|
||||
used_tokens.append(next_token)
|
||||
|
||||
# Update bounding box
|
||||
concat_bbox[0] = min(concat_bbox[0], next_token.bbox[0])
|
||||
concat_bbox[1] = min(concat_bbox[1], next_token.bbox[1])
|
||||
concat_bbox[2] = max(concat_bbox[2], next_token.bbox[2])
|
||||
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
|
||||
|
||||
# Check for match
|
||||
concat_clean = WHITESPACE_PATTERN.sub('', concat_text)
|
||||
if concat_clean == value_clean:
|
||||
context_keywords, context_boost = find_context_keywords(
|
||||
tokens, start_token, field_name, self.context_radius, token_index
|
||||
)
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=tuple(concat_bbox),
|
||||
page_no=start_token.page_no,
|
||||
score=min(1.0, 0.85 + context_boost), # Slightly lower base score
|
||||
matched_text=concat_text,
|
||||
context_keywords=context_keywords
|
||||
))
|
||||
break
|
||||
|
||||
return matches
|
||||
65
src/matcher/strategies/exact_matcher.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
Exact match strategy.
|
||||
"""
|
||||
|
||||
from .base import BaseMatchStrategy
|
||||
from ..models import TokenLike, Match
|
||||
from ..token_index import TokenIndex
|
||||
from ..context import find_context_keywords
|
||||
from ..utils import NON_DIGIT_PATTERN
|
||||
|
||||
|
||||
class ExactMatcher(BaseMatchStrategy):
|
||||
"""Find tokens that exactly match the value."""
|
||||
|
||||
def find_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str,
|
||||
token_index: TokenIndex | None = None
|
||||
) -> list[Match]:
|
||||
"""Find exact matches."""
|
||||
matches = []
|
||||
value_lower = value.lower()
|
||||
value_digits = NON_DIGIT_PATTERN.sub('', value) if field_name in (
|
||||
'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro',
|
||||
'supplier_organisation_number', 'supplier_accounts'
|
||||
) else None
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
|
||||
# Exact match
|
||||
if token_text == value:
|
||||
score = 1.0
|
||||
# Case-insensitive match (use cached lowercase from index)
|
||||
elif token_index and token_index.get_text_lower(token).strip() == value_lower:
|
||||
score = 0.95
|
||||
# Digits-only match for numeric fields
|
||||
elif value_digits is not None:
|
||||
token_digits = NON_DIGIT_PATTERN.sub('', token_text)
|
||||
if token_digits and token_digits == value_digits:
|
||||
score = 0.9
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
# Boost score if context keywords are nearby
|
||||
context_keywords, context_boost = find_context_keywords(
|
||||
tokens, token, field_name, self.context_radius, token_index
|
||||
)
|
||||
score = min(1.0, score + context_boost)
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=token.bbox,
|
||||
page_no=token.page_no,
|
||||
score=score,
|
||||
matched_text=token_text,
|
||||
context_keywords=context_keywords
|
||||
))
|
||||
|
||||
return matches
|
||||
149
src/matcher/strategies/flexible_date_matcher.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Flexible date match strategy - finds dates with year-month or nearby date matching.
|
||||
"""
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from .base import BaseMatchStrategy
|
||||
from ..models import TokenLike, Match
|
||||
from ..token_index import TokenIndex
|
||||
from ..context import find_context_keywords, CONTEXT_KEYWORDS
|
||||
from ..utils import DATE_PATTERN
|
||||
|
||||
|
||||
class FlexibleDateMatcher(BaseMatchStrategy):
|
||||
"""
|
||||
Flexible date matching when exact match fails.
|
||||
|
||||
Strategies:
|
||||
1. Year-month match: If CSV has 2026-01-15, match any 2026-01-XX date
|
||||
2. Nearby date match: Match dates within 7 days of CSV value
|
||||
3. Heuristic selection: Use context keywords to select the best date
|
||||
|
||||
This handles cases where CSV InvoiceDate doesn't exactly match PDF,
|
||||
but we can still find a reasonable date to label.
|
||||
"""
|
||||
|
||||
def find_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str,
|
||||
token_index: TokenIndex | None = None
|
||||
) -> list[Match]:
|
||||
"""Find flexible date matches."""
|
||||
matches = []
|
||||
|
||||
# Parse the target date from normalized values
|
||||
target_date = None
|
||||
|
||||
# Try to parse YYYY-MM-DD format
|
||||
date_match = re.match(r'^(\d{4})-(\d{2})-(\d{2})$', value)
|
||||
if date_match:
|
||||
try:
|
||||
target_date = datetime(
|
||||
int(date_match.group(1)),
|
||||
int(date_match.group(2)),
|
||||
int(date_match.group(3))
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if not target_date:
|
||||
return matches
|
||||
|
||||
# Find all date-like tokens in the document
|
||||
date_candidates = []
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
|
||||
# Search for date pattern in token (use pre-compiled pattern)
|
||||
for match in DATE_PATTERN.finditer(token_text):
|
||||
try:
|
||||
found_date = datetime(
|
||||
int(match.group(1)),
|
||||
int(match.group(2)),
|
||||
int(match.group(3))
|
||||
)
|
||||
date_str = match.group(0)
|
||||
|
||||
# Calculate date difference
|
||||
days_diff = abs((found_date - target_date).days)
|
||||
|
||||
# Check for context keywords
|
||||
context_keywords, context_boost = find_context_keywords(
|
||||
tokens, token, field_name, self.context_radius, token_index
|
||||
)
|
||||
|
||||
# Check if keyword is in the same token
|
||||
token_lower = token_text.lower()
|
||||
inline_keywords = []
|
||||
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
|
||||
if keyword in token_lower:
|
||||
inline_keywords.append(keyword)
|
||||
|
||||
date_candidates.append({
|
||||
'token': token,
|
||||
'date': found_date,
|
||||
'date_str': date_str,
|
||||
'matched_text': token_text,
|
||||
'days_diff': days_diff,
|
||||
'context_keywords': context_keywords + inline_keywords,
|
||||
'context_boost': context_boost + (0.1 if inline_keywords else 0),
|
||||
'same_year_month': (found_date.year == target_date.year and
|
||||
found_date.month == target_date.month),
|
||||
})
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
if not date_candidates:
|
||||
return matches
|
||||
|
||||
# Score and rank candidates
|
||||
for candidate in date_candidates:
|
||||
score = 0.0
|
||||
|
||||
# Strategy 1: Same year-month gets higher score
|
||||
if candidate['same_year_month']:
|
||||
score = 0.7
|
||||
# Bonus if day is close
|
||||
if candidate['days_diff'] <= 7:
|
||||
score = 0.75
|
||||
if candidate['days_diff'] <= 3:
|
||||
score = 0.8
|
||||
# Strategy 2: Nearby dates (within 14 days)
|
||||
elif candidate['days_diff'] <= 14:
|
||||
score = 0.6
|
||||
elif candidate['days_diff'] <= 30:
|
||||
score = 0.55
|
||||
else:
|
||||
# Too far apart, skip unless has strong context
|
||||
if not candidate['context_keywords']:
|
||||
continue
|
||||
score = 0.5
|
||||
|
||||
# Strategy 3: Boost with context keywords
|
||||
score = min(1.0, score + candidate['context_boost'])
|
||||
|
||||
# For InvoiceDate, prefer dates that appear near invoice-related keywords
|
||||
# For InvoiceDueDate, prefer dates near due-date keywords
|
||||
if candidate['context_keywords']:
|
||||
score = min(1.0, score + 0.05)
|
||||
|
||||
if score >= 0.5: # Min threshold for flexible matching
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=candidate['date_str'],
|
||||
bbox=candidate['token'].bbox,
|
||||
page_no=candidate['token'].page_no,
|
||||
score=score,
|
||||
matched_text=candidate['matched_text'],
|
||||
context_keywords=candidate['context_keywords']
|
||||
))
|
||||
|
||||
# Sort by score and return best matches
|
||||
matches.sort(key=lambda m: m.score, reverse=True)
|
||||
|
||||
# Only return the best match to avoid multiple labels for same field
|
||||
return matches[:1] if matches else []
|
||||
52
src/matcher/strategies/fuzzy_matcher.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Fuzzy match strategy for amounts and dates.
|
||||
"""
|
||||
|
||||
from .base import BaseMatchStrategy
|
||||
from ..models import TokenLike, Match
|
||||
from ..token_index import TokenIndex
|
||||
from ..context import find_context_keywords
|
||||
from ..utils import parse_amount
|
||||
|
||||
|
||||
class FuzzyMatcher(BaseMatchStrategy):
|
||||
"""Find approximate matches for amounts and dates."""
|
||||
|
||||
def find_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str,
|
||||
token_index: TokenIndex | None = None
|
||||
) -> list[Match]:
|
||||
"""Find fuzzy matches."""
|
||||
matches = []
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
|
||||
if field_name == 'Amount':
|
||||
# Try to parse both as numbers
|
||||
try:
|
||||
token_num = parse_amount(token_text)
|
||||
value_num = parse_amount(value)
|
||||
|
||||
if token_num is not None and value_num is not None:
|
||||
if abs(token_num - value_num) < 0.01: # Within 1 cent
|
||||
context_keywords, context_boost = find_context_keywords(
|
||||
tokens, token, field_name, self.context_radius, token_index
|
||||
)
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=token.bbox,
|
||||
page_no=token.page_no,
|
||||
score=min(1.0, 0.8 + context_boost),
|
||||
matched_text=token_text,
|
||||
context_keywords=context_keywords
|
||||
))
|
||||
except:
|
||||
pass
|
||||
|
||||
return matches
|
||||
143
src/matcher/strategies/substring_matcher.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
Substring match strategy - finds value as substring within longer tokens.
|
||||
"""
|
||||
|
||||
from .base import BaseMatchStrategy
|
||||
from ..models import TokenLike, Match
|
||||
from ..token_index import TokenIndex
|
||||
from ..context import find_context_keywords, CONTEXT_KEYWORDS
|
||||
from ..utils import normalize_dashes
|
||||
|
||||
|
||||
class SubstringMatcher(BaseMatchStrategy):
|
||||
"""
|
||||
Find value as a substring within longer tokens.
|
||||
|
||||
Handles cases like:
|
||||
- 'Fakturadatum: 2026-01-09' where the date is embedded
|
||||
- 'Fakturanummer: 2465027205' where OCR/invoice number is embedded
|
||||
- 'OCR: 1234567890' where reference number is embedded
|
||||
|
||||
Uses lower score (0.75-0.85) than exact match to prefer exact matches.
|
||||
Only matches if the value appears as a distinct segment (not part of a larger number).
|
||||
"""
|
||||
|
||||
def find_matches(
|
||||
self,
|
||||
tokens: list[TokenLike],
|
||||
value: str,
|
||||
field_name: str,
|
||||
token_index: TokenIndex | None = None
|
||||
) -> list[Match]:
|
||||
"""Find substring matches."""
|
||||
matches = []
|
||||
|
||||
# Supported fields for substring matching
|
||||
supported_fields = (
|
||||
'InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR',
|
||||
'Bankgiro', 'Plusgiro', 'Amount',
|
||||
'supplier_organisation_number', 'supplier_accounts', 'customer_number'
|
||||
)
|
||||
if field_name not in supported_fields:
|
||||
return matches
|
||||
|
||||
# Fields where spaces/dashes should be ignored during matching
|
||||
# (e.g., org number "55 65 74-6624" should match "5565746624")
|
||||
ignore_spaces_fields = (
|
||||
'supplier_organisation_number', 'Bankgiro', 'Plusgiro', 'supplier_accounts'
|
||||
)
|
||||
|
||||
for token in tokens:
|
||||
token_text = token.text.strip()
|
||||
# Normalize different dash types to hyphen-minus for matching
|
||||
token_text_normalized = normalize_dashes(token_text)
|
||||
|
||||
# For certain fields, also try matching with spaces/dashes removed
|
||||
if field_name in ignore_spaces_fields:
|
||||
token_text_compact = token_text_normalized.replace(' ', '').replace('-', '')
|
||||
value_compact = value.replace(' ', '').replace('-', '')
|
||||
else:
|
||||
token_text_compact = None
|
||||
value_compact = None
|
||||
|
||||
# Skip if token is the same length as value (would be exact match)
|
||||
if len(token_text_normalized) <= len(value):
|
||||
continue
|
||||
|
||||
# Check if value appears as substring (using normalized text)
|
||||
# Try case-sensitive first, then case-insensitive
|
||||
idx = None
|
||||
case_sensitive_match = True
|
||||
used_compact = False
|
||||
|
||||
if value in token_text_normalized:
|
||||
idx = token_text_normalized.find(value)
|
||||
elif value.lower() in token_text_normalized.lower():
|
||||
idx = token_text_normalized.lower().find(value.lower())
|
||||
case_sensitive_match = False
|
||||
elif token_text_compact and value_compact in token_text_compact:
|
||||
# Try compact matching (spaces/dashes removed)
|
||||
idx = token_text_compact.find(value_compact)
|
||||
used_compact = True
|
||||
elif token_text_compact and value_compact.lower() in token_text_compact.lower():
|
||||
idx = token_text_compact.lower().find(value_compact.lower())
|
||||
case_sensitive_match = False
|
||||
used_compact = True
|
||||
|
||||
if idx is None:
|
||||
continue
|
||||
|
||||
# For compact matching, boundary check is simpler (just check it's 10 consecutive digits)
|
||||
if used_compact:
|
||||
# Verify proper boundary in compact text
|
||||
if idx > 0 and token_text_compact[idx - 1].isdigit():
|
||||
continue
|
||||
end_idx = idx + len(value_compact)
|
||||
if end_idx < len(token_text_compact) and token_text_compact[end_idx].isdigit():
|
||||
continue
|
||||
else:
|
||||
# Verify it's a proper boundary match (not part of a larger number)
|
||||
# Check character before (if exists)
|
||||
if idx > 0:
|
||||
char_before = token_text_normalized[idx - 1]
|
||||
# Must be non-digit (allow : space - etc)
|
||||
if char_before.isdigit():
|
||||
continue
|
||||
|
||||
# Check character after (if exists)
|
||||
end_idx = idx + len(value)
|
||||
if end_idx < len(token_text_normalized):
|
||||
char_after = token_text_normalized[end_idx]
|
||||
# Must be non-digit
|
||||
if char_after.isdigit():
|
||||
continue
|
||||
|
||||
# Found valid substring match
|
||||
context_keywords, context_boost = find_context_keywords(
|
||||
tokens, token, field_name, self.context_radius, token_index
|
||||
)
|
||||
|
||||
# Check if context keyword is in the same token (like "Fakturadatum:")
|
||||
token_lower = token_text.lower()
|
||||
inline_context = []
|
||||
for keyword in CONTEXT_KEYWORDS.get(field_name, []):
|
||||
if keyword in token_lower:
|
||||
inline_context.append(keyword)
|
||||
|
||||
# Boost score if keyword is inline
|
||||
inline_boost = 0.1 if inline_context else 0
|
||||
|
||||
# Lower score for case-insensitive match
|
||||
base_score = 0.75 if case_sensitive_match else 0.70
|
||||
|
||||
matches.append(Match(
|
||||
field=field_name,
|
||||
value=value,
|
||||
bbox=token.bbox, # Use full token bbox
|
||||
page_no=token.page_no,
|
||||
score=min(1.0, base_score + context_boost + inline_boost),
|
||||
matched_text=token_text,
|
||||
context_keywords=context_keywords + inline_context
|
||||
))
|
||||
|
||||
return matches
|
||||
@@ -1,896 +0,0 @@
|
||||
"""
|
||||
Tests for the Field Matching Module.
|
||||
|
||||
Tests cover all matcher functions in src/matcher/field_matcher.py
|
||||
|
||||
Usage:
|
||||
pytest src/matcher/test_field_matcher.py -v -o 'addopts='
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
from src.matcher.field_matcher import (
|
||||
FieldMatcher,
|
||||
Match,
|
||||
TokenIndex,
|
||||
CONTEXT_KEYWORDS,
|
||||
_normalize_dashes,
|
||||
find_field_matches,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToken:
|
||||
"""Mock token for testing."""
|
||||
text: str
|
||||
bbox: tuple[float, float, float, float]
|
||||
page_no: int = 0
|
||||
|
||||
|
||||
class TestNormalizeDashes:
|
||||
"""Tests for _normalize_dashes function."""
|
||||
|
||||
def test_normalize_en_dash(self):
|
||||
"""Should normalize en-dash to hyphen."""
|
||||
assert _normalize_dashes("123\u2013456") == "123-456"
|
||||
|
||||
def test_normalize_em_dash(self):
|
||||
"""Should normalize em-dash to hyphen."""
|
||||
assert _normalize_dashes("123\u2014456") == "123-456"
|
||||
|
||||
def test_normalize_minus_sign(self):
|
||||
"""Should normalize minus sign to hyphen."""
|
||||
assert _normalize_dashes("123\u2212456") == "123-456"
|
||||
|
||||
def test_normalize_middle_dot(self):
|
||||
"""Should normalize middle dot to hyphen."""
|
||||
assert _normalize_dashes("123\u00b7456") == "123-456"
|
||||
|
||||
def test_normal_hyphen_unchanged(self):
|
||||
"""Should keep normal hyphen unchanged."""
|
||||
assert _normalize_dashes("123-456") == "123-456"
|
||||
|
||||
|
||||
class TestTokenIndex:
|
||||
"""Tests for TokenIndex class."""
|
||||
|
||||
def test_build_index(self):
|
||||
"""Should build spatial index from tokens."""
|
||||
tokens = [
|
||||
MockToken("hello", (0, 0, 50, 20)),
|
||||
MockToken("world", (60, 0, 110, 20)),
|
||||
]
|
||||
index = TokenIndex(tokens)
|
||||
assert len(index.tokens) == 2
|
||||
|
||||
def test_get_center(self):
|
||||
"""Should return correct center coordinates."""
|
||||
token = MockToken("test", (0, 0, 100, 50))
|
||||
tokens = [token]
|
||||
index = TokenIndex(tokens)
|
||||
center = index.get_center(token)
|
||||
assert center == (50.0, 25.0)
|
||||
|
||||
def test_get_text_lower(self):
|
||||
"""Should return lowercase text."""
|
||||
token = MockToken("HELLO World", (0, 0, 100, 20))
|
||||
tokens = [token]
|
||||
index = TokenIndex(tokens)
|
||||
assert index.get_text_lower(token) == "hello world"
|
||||
|
||||
def test_find_nearby_within_radius(self):
|
||||
"""Should find tokens within radius."""
|
||||
token1 = MockToken("hello", (0, 0, 50, 20))
|
||||
token2 = MockToken("world", (60, 0, 110, 20)) # 60px away
|
||||
token3 = MockToken("far", (500, 0, 550, 20)) # 500px away
|
||||
tokens = [token1, token2, token3]
|
||||
index = TokenIndex(tokens)
|
||||
|
||||
nearby = index.find_nearby(token1, radius=100)
|
||||
assert len(nearby) == 1
|
||||
assert nearby[0].text == "world"
|
||||
|
||||
def test_find_nearby_excludes_self(self):
|
||||
"""Should not include the target token itself."""
|
||||
token1 = MockToken("hello", (0, 0, 50, 20))
|
||||
token2 = MockToken("world", (60, 0, 110, 20))
|
||||
tokens = [token1, token2]
|
||||
index = TokenIndex(tokens)
|
||||
|
||||
nearby = index.find_nearby(token1, radius=100)
|
||||
assert token1 not in nearby
|
||||
|
||||
def test_find_nearby_empty_when_none_in_range(self):
|
||||
"""Should return empty list when no tokens in range."""
|
||||
token1 = MockToken("hello", (0, 0, 50, 20))
|
||||
token2 = MockToken("far", (500, 0, 550, 20))
|
||||
tokens = [token1, token2]
|
||||
index = TokenIndex(tokens)
|
||||
|
||||
nearby = index.find_nearby(token1, radius=50)
|
||||
assert len(nearby) == 0
|
||||
|
||||
|
||||
class TestMatch:
|
||||
"""Tests for Match dataclass."""
|
||||
|
||||
def test_match_creation(self):
|
||||
"""Should create Match with all fields."""
|
||||
match = Match(
|
||||
field="InvoiceNumber",
|
||||
value="12345",
|
||||
bbox=(0, 0, 100, 20),
|
||||
page_no=0,
|
||||
score=0.95,
|
||||
matched_text="12345",
|
||||
context_keywords=["fakturanr"]
|
||||
)
|
||||
assert match.field == "InvoiceNumber"
|
||||
assert match.value == "12345"
|
||||
assert match.score == 0.95
|
||||
|
||||
def test_to_yolo_format(self):
|
||||
"""Should convert to YOLO annotation format."""
|
||||
match = Match(
|
||||
field="Amount",
|
||||
value="100",
|
||||
bbox=(100, 200, 200, 250), # x0, y0, x1, y1
|
||||
page_no=0,
|
||||
score=1.0,
|
||||
matched_text="100",
|
||||
context_keywords=[]
|
||||
)
|
||||
# Image: 1000x1000
|
||||
yolo = match.to_yolo_format(1000, 1000, class_id=5)
|
||||
|
||||
# Expected: center_x=150, center_y=225, width=100, height=50
|
||||
# Normalized: x_center=0.15, y_center=0.225, w=0.1, h=0.05
|
||||
assert yolo.startswith("5 ")
|
||||
parts = yolo.split()
|
||||
assert len(parts) == 5
|
||||
assert float(parts[1]) == pytest.approx(0.15, rel=1e-4)
|
||||
assert float(parts[2]) == pytest.approx(0.225, rel=1e-4)
|
||||
assert float(parts[3]) == pytest.approx(0.1, rel=1e-4)
|
||||
assert float(parts[4]) == pytest.approx(0.05, rel=1e-4)
|
||||
|
||||
|
||||
class TestFieldMatcher:
|
||||
"""Tests for FieldMatcher class."""
|
||||
|
||||
def test_init_defaults(self):
|
||||
"""Should initialize with default values."""
|
||||
matcher = FieldMatcher()
|
||||
assert matcher.context_radius == 200.0
|
||||
assert matcher.min_score_threshold == 0.5
|
||||
|
||||
def test_init_custom_params(self):
|
||||
"""Should initialize with custom parameters."""
|
||||
matcher = FieldMatcher(context_radius=300.0, min_score_threshold=0.7)
|
||||
assert matcher.context_radius == 300.0
|
||||
assert matcher.min_score_threshold == 0.7
|
||||
|
||||
|
||||
class TestFieldMatcherExactMatch:
|
||||
"""Tests for exact matching."""
|
||||
|
||||
def test_exact_match_full_score(self):
|
||||
"""Should find exact match with full score."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [MockToken("12345", (0, 0, 50, 20))]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["12345"])
|
||||
|
||||
assert len(matches) >= 1
|
||||
assert matches[0].score == 1.0
|
||||
assert matches[0].matched_text == "12345"
|
||||
|
||||
def test_case_insensitive_match(self):
|
||||
"""Should find case-insensitive match with lower score."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [MockToken("HELLO", (0, 0, 50, 20))]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["hello"])
|
||||
|
||||
assert len(matches) >= 1
|
||||
assert matches[0].score == 0.95
|
||||
|
||||
def test_digits_only_match(self):
|
||||
"""Should match by digits only for numeric fields."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [MockToken("INV-12345", (0, 0, 80, 20))]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["12345"])
|
||||
|
||||
assert len(matches) >= 1
|
||||
assert matches[0].score == 0.9
|
||||
|
||||
def test_no_match_when_different(self):
|
||||
"""Should return empty when no match found."""
|
||||
matcher = FieldMatcher(min_score_threshold=0.8)
|
||||
tokens = [MockToken("99999", (0, 0, 50, 20))]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["12345"])
|
||||
|
||||
assert len(matches) == 0
|
||||
|
||||
|
||||
class TestFieldMatcherContextKeywords:
|
||||
"""Tests for context keyword boosting."""
|
||||
|
||||
def test_context_boost_with_nearby_keyword(self):
|
||||
"""Should boost score when context keyword is nearby."""
|
||||
matcher = FieldMatcher(context_radius=200)
|
||||
tokens = [
|
||||
MockToken("fakturanr", (0, 0, 80, 20)), # Context keyword
|
||||
MockToken("12345", (100, 0, 150, 20)), # Value
|
||||
]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["12345"])
|
||||
|
||||
assert len(matches) >= 1
|
||||
# Score should be boosted above 1.0 (capped at 1.0)
|
||||
assert matches[0].score == 1.0
|
||||
assert "fakturanr" in matches[0].context_keywords
|
||||
|
||||
def test_no_boost_when_keyword_far_away(self):
|
||||
"""Should not boost when keyword is too far."""
|
||||
matcher = FieldMatcher(context_radius=50)
|
||||
tokens = [
|
||||
MockToken("fakturanr", (0, 0, 80, 20)), # Context keyword
|
||||
MockToken("12345", (500, 0, 550, 20)), # Value - far away
|
||||
]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["12345"])
|
||||
|
||||
assert len(matches) >= 1
|
||||
assert "fakturanr" not in matches[0].context_keywords
|
||||
|
||||
|
||||
class TestFieldMatcherConcatenatedMatch:
|
||||
"""Tests for concatenated token matching."""
|
||||
|
||||
def test_concatenate_adjacent_tokens(self):
|
||||
"""Should match value split across adjacent tokens."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [
|
||||
MockToken("123", (0, 0, 30, 20)),
|
||||
MockToken("456", (35, 0, 65, 20)), # Adjacent, same line
|
||||
]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["123456"])
|
||||
|
||||
assert len(matches) >= 1
|
||||
assert "123456" in matches[0].matched_text or matches[0].value == "123456"
|
||||
|
||||
def test_no_concatenate_when_gap_too_large(self):
|
||||
"""Should not concatenate when gap is too large."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [
|
||||
MockToken("123", (0, 0, 30, 20)),
|
||||
MockToken("456", (100, 0, 130, 20)), # Gap > 50px
|
||||
]
|
||||
|
||||
# This might still match if exact matches work differently
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["123456"])
|
||||
# No concatenated match expected (only from exact/substring)
|
||||
concat_matches = [m for m in matches if "123456" in m.matched_text]
|
||||
# May or may not find depending on strategy
|
||||
|
||||
|
||||
class TestFieldMatcherSubstringMatch:
|
||||
"""Tests for substring matching."""
|
||||
|
||||
def test_substring_match_in_longer_text(self):
|
||||
"""Should find value as substring in longer token."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [MockToken("Fakturanummer: 12345", (0, 0, 150, 20))]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["12345"])
|
||||
|
||||
assert len(matches) >= 1
|
||||
# Substring match should have lower score
|
||||
substring_match = [m for m in matches if "12345" in m.matched_text]
|
||||
assert len(substring_match) >= 1
|
||||
|
||||
def test_no_substring_match_when_part_of_larger_number(self):
|
||||
"""Should not match when value is part of a larger number."""
|
||||
matcher = FieldMatcher(min_score_threshold=0.6)
|
||||
tokens = [MockToken("123456789", (0, 0, 100, 20))]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["456"])
|
||||
|
||||
# Should not match because 456 is embedded in larger number
|
||||
assert len(matches) == 0
|
||||
|
||||
|
||||
class TestFieldMatcherFuzzyMatch:
|
||||
"""Tests for fuzzy amount matching."""
|
||||
|
||||
def test_fuzzy_amount_match(self):
|
||||
"""Should match amounts that are numerically equal."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [MockToken("1234,56", (0, 0, 70, 20))]
|
||||
|
||||
matches = matcher.find_matches(tokens, "Amount", ["1234.56"])
|
||||
|
||||
assert len(matches) >= 1
|
||||
|
||||
def test_fuzzy_amount_with_different_formats(self):
|
||||
"""Should match amounts in different formats."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [MockToken("1 234,56", (0, 0, 80, 20))]
|
||||
|
||||
matches = matcher.find_matches(tokens, "Amount", ["1234,56"])
|
||||
|
||||
assert len(matches) >= 1
|
||||
|
||||
|
||||
class TestFieldMatcherParseAmount:
|
||||
"""Tests for _parse_amount method."""
|
||||
|
||||
def test_parse_simple_integer(self):
|
||||
"""Should parse simple integer."""
|
||||
matcher = FieldMatcher()
|
||||
assert matcher._parse_amount("100") == 100.0
|
||||
|
||||
def test_parse_decimal_with_dot(self):
|
||||
"""Should parse decimal with dot."""
|
||||
matcher = FieldMatcher()
|
||||
assert matcher._parse_amount("100.50") == 100.50
|
||||
|
||||
def test_parse_decimal_with_comma(self):
|
||||
"""Should parse decimal with comma (European format)."""
|
||||
matcher = FieldMatcher()
|
||||
assert matcher._parse_amount("100,50") == 100.50
|
||||
|
||||
def test_parse_with_thousand_separator(self):
|
||||
"""Should parse with thousand separator."""
|
||||
matcher = FieldMatcher()
|
||||
assert matcher._parse_amount("1 234,56") == 1234.56
|
||||
|
||||
def test_parse_with_currency_suffix(self):
|
||||
"""Should parse and remove currency suffix."""
|
||||
matcher = FieldMatcher()
|
||||
assert matcher._parse_amount("100 SEK") == 100.0
|
||||
assert matcher._parse_amount("100 kr") == 100.0
|
||||
|
||||
def test_parse_swedish_ore_format(self):
|
||||
"""Should parse Swedish öre format (kronor space öre)."""
|
||||
matcher = FieldMatcher()
|
||||
assert matcher._parse_amount("239 00") == 239.00
|
||||
assert matcher._parse_amount("1234 50") == 1234.50
|
||||
|
||||
def test_parse_invalid_returns_none(self):
|
||||
"""Should return None for invalid input."""
|
||||
matcher = FieldMatcher()
|
||||
assert matcher._parse_amount("abc") is None
|
||||
assert matcher._parse_amount("") is None
|
||||
|
||||
|
||||
class TestFieldMatcherTokensOnSameLine:
|
||||
"""Tests for _tokens_on_same_line method."""
|
||||
|
||||
def test_same_line_tokens(self):
|
||||
"""Should detect tokens on same line."""
|
||||
matcher = FieldMatcher()
|
||||
token1 = MockToken("hello", (0, 10, 50, 30))
|
||||
token2 = MockToken("world", (60, 12, 110, 28)) # Slight y variation
|
||||
|
||||
assert matcher._tokens_on_same_line(token1, token2) is True
|
||||
|
||||
def test_different_line_tokens(self):
|
||||
"""Should detect tokens on different lines."""
|
||||
matcher = FieldMatcher()
|
||||
token1 = MockToken("hello", (0, 10, 50, 30))
|
||||
token2 = MockToken("world", (0, 50, 50, 70)) # Different y
|
||||
|
||||
assert matcher._tokens_on_same_line(token1, token2) is False
|
||||
|
||||
|
||||
class TestFieldMatcherBboxOverlap:
|
||||
"""Tests for _bbox_overlap method."""
|
||||
|
||||
def test_full_overlap(self):
|
||||
"""Should return 1.0 for identical bboxes."""
|
||||
matcher = FieldMatcher()
|
||||
bbox = (0, 0, 100, 50)
|
||||
assert matcher._bbox_overlap(bbox, bbox) == 1.0
|
||||
|
||||
def test_partial_overlap(self):
|
||||
"""Should calculate partial overlap correctly."""
|
||||
matcher = FieldMatcher()
|
||||
bbox1 = (0, 0, 100, 100)
|
||||
bbox2 = (50, 50, 150, 150) # 50% overlap on each axis
|
||||
|
||||
overlap = matcher._bbox_overlap(bbox1, bbox2)
|
||||
# Intersection: 50x50=2500, Union: 10000+10000-2500=17500
|
||||
# IoU = 2500/17500 ≈ 0.143
|
||||
assert 0.1 < overlap < 0.2
|
||||
|
||||
def test_no_overlap(self):
|
||||
"""Should return 0.0 for non-overlapping bboxes."""
|
||||
matcher = FieldMatcher()
|
||||
bbox1 = (0, 0, 50, 50)
|
||||
bbox2 = (100, 100, 150, 150)
|
||||
|
||||
assert matcher._bbox_overlap(bbox1, bbox2) == 0.0
|
||||
|
||||
|
||||
class TestFieldMatcherDeduplication:
|
||||
"""Tests for match deduplication."""
|
||||
|
||||
def test_deduplicate_overlapping_matches(self):
|
||||
"""Should keep only highest scoring match for overlapping bboxes."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [
|
||||
MockToken("12345", (0, 0, 50, 20)),
|
||||
]
|
||||
|
||||
# Find matches with multiple values that could match same token
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["12345", "12345"])
|
||||
|
||||
# Should deduplicate to single match
|
||||
assert len(matches) == 1
|
||||
|
||||
|
||||
class TestFieldMatcherFlexibleDateMatch:
|
||||
"""Tests for flexible date matching."""
|
||||
|
||||
def test_flexible_date_same_month(self):
|
||||
"""Should match dates in same year-month when exact match fails."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [
|
||||
MockToken("2025-01-15", (0, 0, 80, 20)), # Slightly different day
|
||||
]
|
||||
|
||||
# Search for different day in same month
|
||||
matches = matcher.find_matches(
|
||||
tokens, "InvoiceDate", ["2025-01-10"]
|
||||
)
|
||||
|
||||
# Should find flexible match (lower score)
|
||||
# Note: This depends on exact match failing first
|
||||
# If exact match works, flexible won't be tried
|
||||
|
||||
|
||||
class TestFieldMatcherPageFiltering:
|
||||
"""Tests for page number filtering."""
|
||||
|
||||
def test_filters_by_page_number(self):
|
||||
"""Should only match tokens on specified page."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [
|
||||
MockToken("12345", (0, 0, 50, 20), page_no=0),
|
||||
MockToken("12345", (0, 0, 50, 20), page_no=1),
|
||||
]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["12345"], page_no=0)
|
||||
|
||||
assert all(m.page_no == 0 for m in matches)
|
||||
|
||||
def test_excludes_hidden_tokens(self):
|
||||
"""Should exclude tokens with negative y coordinates (metadata)."""
|
||||
matcher = FieldMatcher()
|
||||
tokens = [
|
||||
MockToken("12345", (0, -100, 50, -80), page_no=0), # Hidden metadata
|
||||
MockToken("67890", (0, 0, 50, 20), page_no=0), # Visible
|
||||
]
|
||||
|
||||
matches = matcher.find_matches(tokens, "InvoiceNumber", ["12345"], page_no=0)
|
||||
|
||||
# Should not match the hidden token
|
||||
assert len(matches) == 0
|
||||
|
||||
|
||||
class TestContextKeywordsMapping:
|
||||
"""Tests for CONTEXT_KEYWORDS constant."""
|
||||
|
||||
def test_all_fields_have_keywords(self):
|
||||
"""Should have keywords for all expected fields."""
|
||||
expected_fields = [
|
||||
"InvoiceNumber",
|
||||
"InvoiceDate",
|
||||
"InvoiceDueDate",
|
||||
"OCR",
|
||||
"Bankgiro",
|
||||
"Plusgiro",
|
||||
"Amount",
|
||||
"supplier_organisation_number",
|
||||
"supplier_accounts",
|
||||
]
|
||||
for field in expected_fields:
|
||||
assert field in CONTEXT_KEYWORDS
|
||||
assert len(CONTEXT_KEYWORDS[field]) > 0
|
||||
|
||||
def test_keywords_are_lowercase(self):
|
||||
"""All keywords should be lowercase."""
|
||||
for field, keywords in CONTEXT_KEYWORDS.items():
|
||||
for kw in keywords:
|
||||
assert kw == kw.lower(), f"Keyword '{kw}' in {field} should be lowercase"
|
||||
|
||||
|
||||
class TestFindFieldMatches:
|
||||
"""Tests for find_field_matches convenience function."""
|
||||
|
||||
def test_finds_multiple_fields(self):
|
||||
"""Should find matches for multiple fields."""
|
||||
tokens = [
|
||||
MockToken("12345", (0, 0, 50, 20)),
|
||||
MockToken("100,00", (0, 30, 60, 50)),
|
||||
]
|
||||
field_values = {
|
||||
"InvoiceNumber": "12345",
|
||||
"Amount": "100",
|
||||
}
|
||||
|
||||
results = find_field_matches(tokens, field_values)
|
||||
|
||||
assert "InvoiceNumber" in results
|
||||
assert "Amount" in results
|
||||
assert len(results["InvoiceNumber"]) >= 1
|
||||
assert len(results["Amount"]) >= 1
|
||||
|
||||
def test_skips_empty_values(self):
|
||||
"""Should skip fields with None or empty values."""
|
||||
tokens = [MockToken("12345", (0, 0, 50, 20))]
|
||||
field_values = {
|
||||
"InvoiceNumber": "12345",
|
||||
"Amount": None,
|
||||
"OCR": "",
|
||||
}
|
||||
|
||||
results = find_field_matches(tokens, field_values)
|
||||
|
||||
assert "InvoiceNumber" in results
|
||||
assert "Amount" not in results
|
||||
assert "OCR" not in results
|
||||
|
||||
|
||||
class TestSubstringMatchEdgeCases:
    """Additional edge case tests for substring matching."""

    def test_unsupported_field_returns_empty(self):
        """Should return empty for unsupported field types."""
        # Line 380: field_name not in supported_fields
        matcher = FieldMatcher()
        tokens = [MockToken("Faktura: 12345", (0, 0, 100, 20))]

        # Message is not a supported field for substring matching
        matches = matcher._find_substring_matches(tokens, "12345", "Message")
        assert len(matches) == 0

    def test_case_insensitive_substring_match(self):
        """Should find case-insensitive substring match."""
        # Line 397-398: case-insensitive substring matching
        matcher = FieldMatcher()
        # Use token without inline keyword to isolate case-insensitive behavior
        tokens = [MockToken("REF: ABC123", (0, 0, 100, 20))]

        matches = matcher._find_substring_matches(tokens, "abc123", "InvoiceNumber")

        assert len(matches) >= 1
        # Case-insensitive base score is 0.70 (vs 0.75 for case-sensitive)
        # Score may have context boost but base should be lower
        assert matches[0].score <= 0.80  # 0.70 base + possible small boost

    def test_substring_with_digit_before(self):
        """Should not match when digit appears before value."""
        # Line 407-408: char_before.isdigit() continue
        matcher = FieldMatcher()
        tokens = [MockToken("9912345", (0, 0, 60, 20))]

        matches = matcher._find_substring_matches(tokens, "12345", "InvoiceNumber")
        assert len(matches) == 0

    def test_substring_with_digit_after(self):
        """Should not match when digit appears after value."""
        # Line 413-416: char_after.isdigit() continue
        matcher = FieldMatcher()
        tokens = [MockToken("12345678", (0, 0, 70, 20))]

        matches = matcher._find_substring_matches(tokens, "12345", "InvoiceNumber")
        assert len(matches) == 0

    def test_substring_with_inline_keyword(self):
        """Should boost score when keyword is in same token."""
        matcher = FieldMatcher()
        tokens = [MockToken("Fakturanr: 12345", (0, 0, 100, 20))]

        matches = matcher._find_substring_matches(tokens, "12345", "InvoiceNumber")

        assert len(matches) >= 1
        # Should have inline keyword boost
        assert "fakturanr" in matches[0].context_keywords

class TestFlexibleDateMatchEdgeCases:
    """Additional edge case tests for flexible date matching."""

    def test_no_valid_date_in_normalized_values(self):
        """Should return empty when no valid date in normalized values."""
        # Line 520-521, 524: target_date parsing failures
        matcher = FieldMatcher()
        tokens = [MockToken("2025-01-15", (0, 0, 80, 20))]

        # Pass non-date values
        matches = matcher._find_flexible_date_matches(
            tokens, ["not-a-date", "also-not-date"], "InvoiceDate"
        )
        assert len(matches) == 0

    def test_no_date_tokens_found(self):
        """Should return empty when no date tokens in document."""
        # Line 571-572: no date_candidates
        matcher = FieldMatcher()
        tokens = [MockToken("Hello World", (0, 0, 80, 20))]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-15"], "InvoiceDate"
        )
        assert len(matches) == 0

    def test_flexible_date_within_7_days(self):
        """Should score higher for dates within 7 days."""
        # Line 582-583: days_diff <= 7
        matcher = FieldMatcher(min_score_threshold=0.5)
        tokens = [
            MockToken("2025-01-18", (0, 0, 80, 20)),  # 3 days from target
        ]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-15"], "InvoiceDate"
        )

        assert len(matches) >= 1
        assert matches[0].score >= 0.75

    def test_flexible_date_within_3_days(self):
        """Should score highest for dates within 3 days."""
        # Line 584-585: days_diff <= 3
        matcher = FieldMatcher(min_score_threshold=0.5)
        tokens = [
            MockToken("2025-01-17", (0, 0, 80, 20)),  # 2 days from target
        ]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-15"], "InvoiceDate"
        )

        assert len(matches) >= 1
        assert matches[0].score >= 0.8

    def test_flexible_date_within_14_days_different_month(self):
        """Should match dates within 14 days even in different month."""
        # Line 587-588: days_diff <= 14, different year-month
        matcher = FieldMatcher(min_score_threshold=0.5)
        tokens = [
            MockToken("2025-02-05", (0, 0, 80, 20)),  # 10 days from Jan 26
        ]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-26"], "InvoiceDate"
        )

        assert len(matches) >= 1

    def test_flexible_date_within_30_days(self):
        """Should match dates within 30 days with lower score."""
        # Line 589-590: days_diff <= 30
        matcher = FieldMatcher(min_score_threshold=0.5)
        tokens = [
            MockToken("2025-02-10", (0, 0, 80, 20)),  # 25 days from target
        ]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-16"], "InvoiceDate"
        )

        assert len(matches) >= 1
        assert matches[0].score >= 0.55

    def test_flexible_date_far_apart_without_context(self):
        """Should skip dates too far apart without context keywords."""
        # Line 591-595: > 30 days, no context
        matcher = FieldMatcher(min_score_threshold=0.5)
        tokens = [
            MockToken("2025-06-15", (0, 0, 80, 20)),  # Many months from target
        ]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-15"], "InvoiceDate"
        )

        # Should be empty - too far apart and no context
        assert len(matches) == 0

    def test_flexible_date_far_with_context(self):
        """Should match distant dates if context keywords present."""
        # Line 592-595: > 30 days but has context
        matcher = FieldMatcher(min_score_threshold=0.5, context_radius=200)
        tokens = [
            MockToken("fakturadatum", (0, 0, 80, 20)),  # Context keyword
            MockToken("2025-06-15", (90, 0, 170, 20)),  # Distant date
        ]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-15"], "InvoiceDate"
        )

        # No assertion on the match list: a date this far from the target may or
        # may not be returned depending on how context keywords are picked up in
        # the flexible path; the test only verifies the call completes without error.

    def test_flexible_date_boost_with_context(self):
        """Should boost flexible date score with context keywords."""
        # Line 598, 602-603: context_boost applied
        matcher = FieldMatcher(min_score_threshold=0.5, context_radius=200)
        tokens = [
            MockToken("fakturadatum", (0, 0, 80, 20)),
            MockToken("2025-01-18", (90, 0, 170, 20)),  # 3 days from target
        ]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-15"], "InvoiceDate"
        )

        if len(matches) > 0:
            assert len(matches[0].context_keywords) >= 0

class TestContextKeywordFallback:
    """Tests for context keyword lookup fallback (no spatial index)."""

    def test_fallback_context_lookup_without_index(self):
        """Should find context using O(n) scan when no index available."""
        # Line 650-673: fallback context lookup
        matcher = FieldMatcher(context_radius=200)
        # Don't use find_matches (which builds the index); call the internal method directly

        tokens = [
            MockToken("fakturanr", (0, 0, 80, 20)),
            MockToken("12345", (100, 0, 150, 20)),
        ]

        # _token_index is None, so fallback is used
        keywords, boost = matcher._find_context_keywords(tokens, tokens[1], "InvoiceNumber")

        assert "fakturanr" in keywords
        assert boost > 0

    def test_context_lookup_skips_self(self):
        """Should skip the target token itself in fallback search."""
        # Line 656-657: token is target_token continue
        matcher = FieldMatcher(context_radius=200)
        matcher._token_index = None  # Force fallback

        token = MockToken("fakturanr 12345", (0, 0, 150, 20))
        tokens = [token]

        keywords, boost = matcher._find_context_keywords(tokens, token, "InvoiceNumber")

        # The only token in the list is the target itself; the fallback scan must
        # skip it without raising. No assertion is made on the keywords found.

class TestFieldWithoutContextKeywords:
    """Tests for fields without defined context keywords."""

    def test_field_without_keywords_returns_empty(self):
        """Should return empty keywords for fields not in CONTEXT_KEYWORDS."""
        # Line 633-635: keywords empty, return early
        matcher = FieldMatcher()
        matcher._token_index = None

        tokens = [MockToken("hello", (0, 0, 50, 20))]

        # "UnknownField" is not in CONTEXT_KEYWORDS
        keywords, boost = matcher._find_context_keywords(tokens, tokens[0], "UnknownField")

        assert keywords == []
        assert boost == 0.0

class TestParseAmountEdgeCases:
    """Additional edge case tests for _parse_amount."""

    def test_parse_amount_with_parentheses(self):
        """Should remove parenthesized text like (inkl. moms)."""
        matcher = FieldMatcher()
        result = matcher._parse_amount("100 (inkl. moms)")
        assert result == 100.0

    def test_parse_amount_with_kronor_suffix(self):
        """Should handle 'kronor' suffix."""
        matcher = FieldMatcher()
        result = matcher._parse_amount("100 kronor")
        assert result == 100.0

    def test_parse_amount_numeric_input(self):
        """Should handle numeric input (int/float)."""
        matcher = FieldMatcher()
        assert matcher._parse_amount(100) == 100.0
        assert matcher._parse_amount(100.5) == 100.5

class TestFuzzyMatchExceptionHandling:
    """Tests for exception handling in fuzzy matching."""

    def test_fuzzy_match_with_unparseable_token(self):
        """Should handle tokens that can't be parsed as amounts."""
        # Line 481-482: except clause in fuzzy matching
        matcher = FieldMatcher()
        # Create a token that will cause parse issues
        tokens = [MockToken("abc xyz", (0, 0, 50, 20))]

        # This should not raise, just return empty matches
        matches = matcher._find_fuzzy_matches(tokens, "100", "Amount")
        assert len(matches) == 0

    def test_fuzzy_match_exception_in_context_lookup(self):
        """Should catch exceptions during fuzzy match processing."""
        # Line 481-482: general exception handler
        from unittest.mock import patch

        matcher = FieldMatcher()
        tokens = [MockToken("100", (0, 0, 50, 20))]

        # Mock _find_context_keywords to raise an exception
        with patch.object(matcher, '_find_context_keywords', side_effect=RuntimeError("Test error")):
            # Should not raise, exception should be caught
            matches = matcher._find_fuzzy_matches(tokens, "100", "Amount")
            # Should return empty due to exception
            assert len(matches) == 0

class TestFlexibleDateInvalidDateParsing:
    """Tests for invalid date parsing in flexible date matching."""

    def test_invalid_date_in_normalized_values(self):
        """Should handle invalid dates in normalized values gracefully."""
        # Line 520-521: ValueError continue in target date parsing
        matcher = FieldMatcher()
        tokens = [MockToken("2025-01-15", (0, 0, 80, 20))]

        # Pass a value that matches the date pattern but is not a valid date,
        # e.g. 2025-13-45: the pattern matches but month 13 is invalid
        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-13-45"], "InvoiceDate"
        )
        # Should return empty as no valid target date could be parsed
        assert len(matches) == 0

    def test_invalid_date_token_in_document(self):
        """Should skip invalid date-like tokens in document."""
        # Line 568-569: ValueError continue in date token parsing
        matcher = FieldMatcher(min_score_threshold=0.5)
        tokens = [
            MockToken("2025-99-99", (0, 0, 80, 20)),  # Invalid date in doc
            MockToken("2025-01-18", (0, 50, 80, 70)),  # Valid date
        ]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-15"], "InvoiceDate"
        )

        # Should only match the valid date
        assert len(matches) >= 1
        assert matches[0].value == "2025-01-18"

    def test_flexible_date_with_inline_keyword(self):
        """Should detect inline keywords in date tokens."""
        # Line 555: inline_keywords append
        matcher = FieldMatcher(min_score_threshold=0.5)
        tokens = [
            MockToken("Fakturadatum: 2025-01-18", (0, 0, 150, 20)),
        ]

        matches = matcher._find_flexible_date_matches(
            tokens, ["2025-01-15"], "InvoiceDate"
        )

        # Should find match with inline keyword
        assert len(matches) >= 1
        assert "fakturadatum" in matches[0].context_keywords


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
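
The tests above build their fixtures with a small `MockToken` helper defined earlier in the test module (not shown in this excerpt). A minimal sketch of what such a helper needs to provide — only `text`, `bbox`, and `page_no`, matching the token interface the matcher expects — might look like the following; this is an illustrative assumption, not the project's actual definition:

```python
from dataclasses import dataclass


@dataclass
class MockToken:
    """Minimal stand-in for a PDF token: text, bbox (x0, y0, x1, y1), page number."""
    text: str
    bbox: tuple[float, float, float, float]
    page_no: int = 0
```

The file can be run directly via the `__main__` block above, or with `pytest -v` pointed at the test file.
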
92
src/matcher/token_index.py
Normal file
92
src/matcher/token_index.py
Normal file
@@ -0,0 +1,92 @@
"""
Spatial index for fast token lookup.
"""

from .models import TokenLike


class TokenIndex:
    """
    Spatial index for tokens to enable fast nearby token lookup.

    Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
    """

    def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
        """
        Build spatial index from tokens.

        Args:
            tokens: List of tokens to index
            grid_size: Size of grid cells in pixels
        """
        self.tokens = tokens
        self.grid_size = grid_size
        self._grid: dict[tuple[int, int], list[TokenLike]] = {}
        self._token_centers: dict[int, tuple[float, float]] = {}
        self._token_text_lower: dict[int, str] = {}

        # Build index
        for i, token in enumerate(tokens):
            # Cache center coordinates
            center_x = (token.bbox[0] + token.bbox[2]) / 2
            center_y = (token.bbox[1] + token.bbox[3]) / 2
            self._token_centers[id(token)] = (center_x, center_y)

            # Cache lowercased text
            self._token_text_lower[id(token)] = token.text.lower()

            # Add to grid cell
            grid_x = int(center_x / grid_size)
            grid_y = int(center_y / grid_size)
            key = (grid_x, grid_y)
            if key not in self._grid:
                self._grid[key] = []
            self._grid[key].append(token)

    def get_center(self, token: TokenLike) -> tuple[float, float]:
        """Get cached center coordinates for token."""
        return self._token_centers.get(id(token), (
            (token.bbox[0] + token.bbox[2]) / 2,
            (token.bbox[1] + token.bbox[3]) / 2
        ))

    def get_text_lower(self, token: TokenLike) -> str:
        """Get cached lowercased text for token."""
        return self._token_text_lower.get(id(token), token.text.lower())

    def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
        """
        Find all tokens within radius of the given token.

        Uses grid-based lookup for O(1) average case instead of O(n).
        """
        center = self.get_center(token)
        center_x, center_y = center

        # Determine which grid cells to search
        cells_to_check = int(radius / self.grid_size) + 1
        grid_x = int(center_x / self.grid_size)
        grid_y = int(center_y / self.grid_size)

        nearby = []
        radius_sq = radius * radius

        # Check all nearby grid cells
        for dx in range(-cells_to_check, cells_to_check + 1):
            for dy in range(-cells_to_check, cells_to_check + 1):
                key = (grid_x + dx, grid_y + dy)
                if key not in self._grid:
                    continue

                for other in self._grid[key]:
                    if other is token:
                        continue

                    other_center = self.get_center(other)
                    dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2

                    if dist_sq <= radius_sq:
                        nearby.append(other)

        return nearby
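
A minimal usage sketch of the spatial index. This is illustrative only: it assumes the module is importable as `src.matcher.token_index` and uses a hypothetical `Tok` dataclass, since only `text` and `bbox` attributes are needed here:

```python
from dataclasses import dataclass

from src.matcher.token_index import TokenIndex


@dataclass
class Tok:  # hypothetical minimal token; only .text and .bbox are used by the index
    text: str
    bbox: tuple[float, float, float, float]


tokens = [
    Tok("Fakturanr", (0, 0, 80, 20)),
    Tok("12345", (100, 0, 150, 20)),
    Tok("Totalt", (0, 800, 60, 820)),
]

index = TokenIndex(tokens, grid_size=100.0)

# Tokens whose centers lie within 200 px of the "12345" token; only the grid
# cells around that token are scanned rather than the whole token list.
nearby = index.find_nearby(tokens[1], radius=200.0)
print([t.text for t in nearby])  # -> ['Fakturanr']
```

Choosing `grid_size` close to the typical search radius keeps the number of cells scanned per lookup small while avoiding overly full cells.
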
91
src/matcher/utils.py
Normal file
91
src/matcher/utils.py
Normal file
@@ -0,0 +1,91 @@
"""
Utility functions for field matching.
"""

import re


# Pre-compiled regex patterns (module-level for efficiency)
DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
WHITESPACE_PATTERN = re.compile(r'\s+')
NON_DIGIT_PATTERN = re.compile(r'\D')
DASH_PATTERN = re.compile(r'[\u2013\u2014\u2212\u00b7]')  # en-dash, em-dash, minus sign, middle dot


def normalize_dashes(text: str) -> str:
    """Normalize different dash types and middle dots to standard hyphen-minus (ASCII 45)."""
    return DASH_PATTERN.sub('-', text)


def parse_amount(text: str | int | float) -> float | None:
    """Try to parse text as a monetary amount."""
    # Convert to string first
    text = str(text)

    # First, handle Swedish öre format: "239 00" means 239.00 (239 kr 00 öre)
    # Pattern: digits + space + exactly 2 digits at end
    ore_match = re.match(r'^(\d+)\s+(\d{2})$', text.strip())
    if ore_match:
        kronor = ore_match.group(1)
        ore = ore_match.group(2)
        try:
            return float(f"{kronor}.{ore}")
        except ValueError:
            pass

    # Remove parenthesized text, e.g. "100 (inkl. moms)" -> "100"
    text = re.sub(r'\s*\(.*\)', '', text)

    # Remove currency symbols and common suffixes (including trailing dots from "kr.")
    text = re.sub(r'\b(SEK|kr|kronor|öre)\b\.?', '', text, flags=re.IGNORECASE)
    text = re.sub(r'[:-]', '', text)

    # Remove spaces (thousand separators) but be careful with öre format
    text = text.replace(' ', '').replace('\xa0', '')

    # Handle comma as decimal separator
    # Swedish format: "500,00" means 500.00
    # Need to handle cases like "500,00." (after removing "kr.")
    if ',' in text:
        # Remove any trailing dots first (from "kr." removal)
        text = text.rstrip('.')
        # Now replace comma with dot
        if '.' not in text:
            text = text.replace(',', '.')

    # Remove any remaining non-numeric characters except dot
    text = re.sub(r'[^\d.]', '', text)

    try:
        return float(text)
    except ValueError:
        return None

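# Illustrative behaviour of parse_amount on common Swedish invoice strings.
# Expected values are inferred from the rules above; this is a sketch, not an
# exhaustive specification:
#
#     parse_amount("239 00")            -> 239.0    (öre format: kronor + öre)
#     parse_amount("1 234,56 kr")       -> 1234.56  (space thousand separator, comma decimal)
#     parse_amount("100 (inkl. moms)")  -> 100.0    (parenthesized text stripped)
#     parse_amount("SEK 500,00")        -> 500.0
#     parse_amount("not an amount")     -> None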

def tokens_on_same_line(token1, token2) -> bool:
    """Check if two tokens are on the same line."""
    # Check vertical overlap
    y_overlap = min(token1.bbox[3], token2.bbox[3]) - max(token1.bbox[1], token2.bbox[1])
    min_height = min(token1.bbox[3] - token1.bbox[1], token2.bbox[3] - token2.bbox[1])
    return y_overlap > min_height * 0.5


def bbox_overlap(
    bbox1: tuple[float, float, float, float],
    bbox2: tuple[float, float, float, float]
) -> float:
    """Calculate IoU (Intersection over Union) of two bounding boxes."""
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])

    if x2 <= x1 or y2 <= y1:
        return 0.0

    intersection = float(x2 - x1) * float(y2 - y1)
    area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
    area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
    union = area1 + area2 - intersection

    return intersection / union if union > 0 else 0.0
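
A quick worked example of the IoU helper, with numbers chosen for easy arithmetic (illustrative; it assumes the module is importable as `src.matcher.utils`):

```python
from src.matcher.utils import bbox_overlap

# Two 100x100 boxes offset by 50 px horizontally:
# intersection = 50 * 100 = 5000, union = 10000 + 10000 - 5000 = 15000
iou = bbox_overlap((0, 0, 100, 100), (50, 0, 150, 100))
print(round(iou, 3))  # 0.333

# Disjoint boxes have zero overlap
print(bbox_overlap((0, 0, 10, 10), (20, 20, 30, 30)))  # 0.0
```
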