Yaojia Wang
2026-01-13 00:10:27 +01:00
parent 1b7c61cdd8
commit b26fd61852
43 changed files with 7751 additions and 578 deletions


@@ -4,9 +4,16 @@ Field Matching Module
Matches normalized field values to tokens extracted from documents.
"""
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Protocol
import re
from functools import cached_property
# Pre-compiled regex patterns (module-level for efficiency)
_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
_WHITESPACE_PATTERN = re.compile(r'\s+')
_NON_DIGIT_PATTERN = re.compile(r'\D')
class TokenLike(Protocol):
@@ -16,6 +23,93 @@ class TokenLike(Protocol):
page_no: int
class TokenIndex:
"""
Spatial index for tokens to enable fast nearby token lookup.
Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
"""
def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
"""
Build spatial index from tokens.
Args:
tokens: List of tokens to index
grid_size: Size of grid cells in pixels
"""
self.tokens = tokens
self.grid_size = grid_size
self._grid: dict[tuple[int, int], list[TokenLike]] = {}
self._token_centers: dict[int, tuple[float, float]] = {}
self._token_text_lower: dict[int, str] = {}
# Build index
for i, token in enumerate(tokens):
# Cache center coordinates
center_x = (token.bbox[0] + token.bbox[2]) / 2
center_y = (token.bbox[1] + token.bbox[3]) / 2
self._token_centers[id(token)] = (center_x, center_y)
# Cache lowercased text
self._token_text_lower[id(token)] = token.text.lower()
# Add to grid cell
grid_x = int(center_x / grid_size)
grid_y = int(center_y / grid_size)
key = (grid_x, grid_y)
if key not in self._grid:
self._grid[key] = []
self._grid[key].append(token)
def get_center(self, token: TokenLike) -> tuple[float, float]:
"""Get cached center coordinates for token."""
return self._token_centers.get(id(token), (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
))
def get_text_lower(self, token: TokenLike) -> str:
"""Get cached lowercased text for token."""
return self._token_text_lower.get(id(token), token.text.lower())
def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
"""
Find all tokens within radius of the given token.
Uses grid-based lookup for O(1) average case instead of O(n).
"""
center = self.get_center(token)
center_x, center_y = center
# Determine which grid cells to search
cells_to_check = int(radius / self.grid_size) + 1
grid_x = int(center_x / self.grid_size)
grid_y = int(center_y / self.grid_size)
nearby = []
radius_sq = radius * radius
# Check all nearby grid cells
for dx in range(-cells_to_check, cells_to_check + 1):
for dy in range(-cells_to_check, cells_to_check + 1):
key = (grid_x + dx, grid_y + dy)
if key not in self._grid:
continue
for other in self._grid[key]:
if other is token:
continue
other_center = self.get_center(other)
dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
if dist_sq <= radius_sq:
nearby.append(other)
return nearby
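A minimal usage sketch of the index, assuming a hypothetical Token dataclass that satisfies the TokenLike protocol (text, bbox, page_no); the coordinates are made up for illustration:

from dataclasses import dataclass

@dataclass
class Token:
    text: str
    bbox: tuple[float, float, float, float]  # (x0, y0, x1, y1)
    page_no: int

tokens = [
    Token("Fakturanummer:", (50, 100, 160, 115), 0),
    Token("2465027205", (170, 100, 250, 115), 0),
    Token("Footer", (50, 900, 100, 915), 0),
]
index = TokenIndex(tokens, grid_size=100.0)
# Only the label token lies within 200 px of the value token; the footer does not.
nearby = index.find_nearby(tokens[1], radius=200.0)
assert [t.text for t in nearby] == ["Fakturanummer:"]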
@dataclass
class Match:
"""Represents a matched field in the document."""
@@ -57,18 +151,20 @@ class FieldMatcher:
def __init__(
self,
context_radius: float = 100.0, # pixels
context_radius: float = 200.0, # pixels - increased to handle label-value spacing in scanned PDFs
min_score_threshold: float = 0.5
):
"""
Initialize the matcher.
Args:
context_radius: Distance to search for context keywords
context_radius: Distance to search for context keywords (default 200px to handle
typical label-value spacing in scanned invoices at 150 DPI)
min_score_threshold: Minimum score to consider a match valid
"""
self.context_radius = context_radius
self.min_score_threshold = min_score_threshold
self._token_index: TokenIndex | None = None
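For scale, 200 px at the 150 DPI mentioned above is about 1.33 in (roughly 34 mm) on the page, compared with roughly 17 mm for the old 100 px default.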
def find_matches(
self,
@@ -92,6 +188,9 @@ class FieldMatcher:
matches = []
page_tokens = [t for t in tokens if t.page_no == page_no]
# Build spatial index for efficient nearby token lookup (O(n) -> O(1))
self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
for value in normalized_values:
# Strategy 1: Exact token match
exact_matches = self._find_exact_matches(page_tokens, value, field_name)
@@ -108,7 +207,7 @@ class FieldMatcher:
# Strategy 4: Substring match (for values embedded in longer text)
# e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'):
if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount'):
substring_matches = self._find_substring_matches(page_tokens, value, field_name)
matches.extend(substring_matches)
@@ -124,6 +223,9 @@ class FieldMatcher:
matches = self._deduplicate_matches(matches)
matches.sort(key=lambda m: m.score, reverse=True)
# Clear token index to free memory
self._token_index = None
return [m for m in matches if m.score >= self.min_score_threshold]
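A rough usage sketch of the matcher; the keyword names below are inferred from the body shown above and from the TokenLike protocol, so the exact signature may differ:

matcher = FieldMatcher(context_radius=200.0, min_score_threshold=0.5)
matches = matcher.find_matches(
    tokens=page_tokens,                  # TokenLike objects for the whole document
    normalized_values=["2465027205"],    # candidate values from upstream normalization
    field_name="InvoiceNumber",
    page_no=0,
)
for m in matches:
    print(m.score, m.bbox, m.context_keywords)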
def _find_exact_matches(
@@ -134,6 +236,8 @@ class FieldMatcher:
) -> list[Match]:
"""Find tokens that exactly match the value."""
matches = []
value_lower = value.lower()
value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro') else None
for token in tokens:
token_text = token.text.strip()
@@ -141,13 +245,12 @@ class FieldMatcher:
# Exact match
if token_text == value:
score = 1.0
# Case-insensitive match
elif token_text.lower() == value.lower():
# Case-insensitive match (use cached lowercase from index)
elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
score = 0.95
# Digits-only match for numeric fields
elif field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'):
token_digits = re.sub(r'\D', '', token_text)
value_digits = re.sub(r'\D', '', value)
elif value_digits is not None:
token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
if token_digits and token_digits == value_digits:
score = 0.9
else:
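To illustrate the digits-only branch with a hypothetical bankgiro value: both sides are reduced to their digit sequences, so punctuation and spacing differences do not block the match.

import re
_NON_DIGIT_PATTERN = re.compile(r'\D')

token_text = "Bg 5050-1055"
value = "5050-1055"
assert _NON_DIGIT_PATTERN.sub('', token_text) == _NON_DIGIT_PATTERN.sub('', value) == "50501055"
# such a hit is scored 0.9, just below the 1.0 of an exact textual match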
@@ -181,7 +284,7 @@ class FieldMatcher:
) -> list[Match]:
"""Find value by concatenating adjacent tokens."""
matches = []
value_clean = re.sub(r'\s+', '', value)
value_clean = _WHITESPACE_PATTERN.sub('', value)
# Sort tokens by position (top-to-bottom, left-to-right)
sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
@@ -213,7 +316,7 @@ class FieldMatcher:
concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
# Check for match
concat_clean = re.sub(r'\s+', '', concat_text)
concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
if concat_clean == value_clean:
context_keywords, context_boost = self._find_context_keywords(
tokens, start_token, field_name
@@ -252,7 +355,7 @@ class FieldMatcher:
matches = []
# Supported fields for substring matching
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro')
supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount')
if field_name not in supported_fields:
return matches
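The intent of substring matching, sketched independently of the implementation (the rest of the method is not shown in this hunk): the normalized value may be embedded in a longer OCR token, as in the Fakturanummer example from the comment above.

token_text = "Fakturanummer: 2465027205"
value = "2465027205"
assert value in token_text  # value occurs inside a longer token, so an exact-token match would miss it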
@@ -390,13 +493,12 @@ class FieldMatcher:
# Find all date-like tokens in the document
date_candidates = []
date_pattern = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
for token in tokens:
token_text = token.text.strip()
# Search for date pattern in token
for match in date_pattern.finditer(token_text):
# Search for date pattern in token (use pre-compiled pattern)
for match in _DATE_PATTERN.finditer(token_text):
try:
found_date = datetime(
int(match.group(1)),
@@ -491,38 +593,56 @@ class FieldMatcher:
target_token: TokenLike,
field_name: str
) -> tuple[list[str], float]:
"""Find context keywords near the target token."""
"""
Find context keywords near the target token.
Uses the spatial index for O(1) average lookup instead of an O(n) scan.
"""
keywords = CONTEXT_KEYWORDS.get(field_name, [])
if not keywords:
return [], 0.0
found_keywords = []
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
# Calculate distance
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= self.context_radius:
token_lower = token.text.lower()
# Use spatial index for efficient nearby token lookup
if self._token_index:
nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
for token in nearby_tokens:
# Use cached lowercase text
token_lower = self._token_index.get_text_lower(token)
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
else:
# Fallback to O(n) scan if no index available
target_center = (
(target_token.bbox[0] + target_token.bbox[2]) / 2,
(target_token.bbox[1] + target_token.bbox[3]) / 2
)
for token in tokens:
if token is target_token:
continue
token_center = (
(token.bbox[0] + token.bbox[2]) / 2,
(token.bbox[1] + token.bbox[3]) / 2
)
distance = (
(target_center[0] - token_center[0]) ** 2 +
(target_center[1] - token_center[1]) ** 2
) ** 0.5
if distance <= self.context_radius:
token_lower = token.text.lower()
for keyword in keywords:
if keyword in token_lower:
found_keywords.append(keyword)
# Calculate boost based on keywords found
boost = min(0.15, len(found_keywords) * 0.05)
# Increased boost to better differentiate matches with/without context
boost = min(0.25, len(found_keywords) * 0.10)
return found_keywords, boost
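As a worked example of the new constants: two keyword hits give a boost of min(0.25, 2 × 0.10) = 0.20, whereas the previous formula capped the same case at min(0.15, 2 × 0.05) = 0.10.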
def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
@@ -548,23 +668,62 @@ class FieldMatcher:
return None
def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
"""Remove duplicate matches based on bbox overlap."""
"""
Remove duplicate matches based on bbox overlap.
Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
"""
if not matches:
return []
# Sort by score descending
matches.sort(key=lambda m: m.score, reverse=True)
# Sort by: 1) score descending, 2) prefer matches with context keywords,
# 3) prefer upper positions (smaller y) for same-score matches
# This helps select the "main" occurrence in the invoice body rather than the footer
matches.sort(key=lambda m: (
-m.score,
-len(m.context_keywords), # More keywords = better
m.bbox[1] # Smaller y (upper position) = better
))
# Use spatial grid for efficient overlap checking
# Grid cell size based on typical bbox size
grid_size = 50.0 # pixels
grid: dict[tuple[int, int], list[Match]] = {}
unique = []
for match in matches:
bbox = match.bbox
# Calculate grid cells this bbox touches
min_gx = int(bbox[0] / grid_size)
min_gy = int(bbox[1] / grid_size)
max_gx = int(bbox[2] / grid_size)
max_gy = int(bbox[3] / grid_size)
# Check for overlap only with matches in nearby grid cells
is_duplicate = False
for existing in unique:
if self._bbox_overlap(match.bbox, existing.bbox) > 0.7:
is_duplicate = True
break
cells_to_check = set()
for gx in range(min_gx - 1, max_gx + 2):
for gy in range(min_gy - 1, max_gy + 2):
cells_to_check.add((gx, gy))
for cell in cells_to_check:
if cell in grid:
for existing in grid[cell]:
if self._bbox_overlap(bbox, existing.bbox) > 0.7:
is_duplicate = True
break
if is_duplicate:
break
if not is_duplicate:
unique.append(match)
# Add to all grid cells this bbox touches
for gx in range(min_gx, max_gx + 1):
for gy in range(min_gy, max_gy + 1):
key = (gx, gy)
if key not in grid:
grid[key] = []
grid[key].append(match)
return unique
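To make the cell bookkeeping concrete with made-up numbers: a match whose bbox is (120, 230, 260, 250) occupies grid cells gx 2–5, gy 4–5 at the 50 px cell size, and is compared only against earlier matches stored in the neighbouring cells gx 1–6, gy 3–6 rather than against every match seen so far.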
@@ -582,9 +741,9 @@ class FieldMatcher:
if x2 <= x1 or y2 <= y1:
return 0.0
intersection = (x2 - x1) * (y2 - y1)
area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
intersection = float(x2 - x1) * float(y2 - y1)
area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
union = area1 + area2 - intersection
return intersection / union if union > 0 else 0.0
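Worked example: for bbox1 = (0, 0, 100, 50) and bbox2 = (50, 0, 150, 50), the intersection is 50 × 50 = 2500, each area is 5000, the union is 7500, and the IoU is 2500 / 7500 ≈ 0.33, well below the 0.7 threshold used in _deduplicate_matches, so the two would not be merged as duplicates.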