WOP
@@ -4,9 +4,16 @@ Field Matching Module
 Matches normalized field values to tokens extracted from documents.
 """
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Protocol
 import re
+from functools import cached_property
+
+
+# Pre-compiled regex patterns (module-level for efficiency)
+_DATE_PATTERN = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
+_WHITESPACE_PATTERN = re.compile(r'\s+')
+_NON_DIGIT_PATTERN = re.compile(r'\D')
 
 
 class TokenLike(Protocol):
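
(Reviewer note: Python's re module already caches compiled patterns internally, so module-level pre-compilation mainly saves the per-call cache lookup; the clearer win is that the formats the matcher recognizes are now documented in one place.)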
@@ -16,6 +23,93 @@ class TokenLike(Protocol):
     page_no: int
 
 
+class TokenIndex:
+    """
+    Spatial index for tokens to enable fast nearby token lookup.
+
+    Uses grid-based spatial hashing for O(1) average lookup instead of O(n).
+    """
+
+    def __init__(self, tokens: list[TokenLike], grid_size: float = 100.0):
+        """
+        Build spatial index from tokens.
+
+        Args:
+            tokens: List of tokens to index
+            grid_size: Size of grid cells in pixels
+        """
+        self.tokens = tokens
+        self.grid_size = grid_size
+        self._grid: dict[tuple[int, int], list[TokenLike]] = {}
+        self._token_centers: dict[int, tuple[float, float]] = {}
+        self._token_text_lower: dict[int, str] = {}
+
+        # Build index
+        for token in tokens:
+            # Cache center coordinates
+            center_x = (token.bbox[0] + token.bbox[2]) / 2
+            center_y = (token.bbox[1] + token.bbox[3]) / 2
+            self._token_centers[id(token)] = (center_x, center_y)
+
+            # Cache lowercased text
+            self._token_text_lower[id(token)] = token.text.lower()
+
+            # Add to grid cell
+            grid_x = int(center_x / grid_size)
+            grid_y = int(center_y / grid_size)
+            key = (grid_x, grid_y)
+            if key not in self._grid:
+                self._grid[key] = []
+            self._grid[key].append(token)
+
+    def get_center(self, token: TokenLike) -> tuple[float, float]:
+        """Get cached center coordinates for token."""
+        return self._token_centers.get(id(token), (
+            (token.bbox[0] + token.bbox[2]) / 2,
+            (token.bbox[1] + token.bbox[3]) / 2
+        ))
+
+    def get_text_lower(self, token: TokenLike) -> str:
+        """Get cached lowercased text for token."""
+        return self._token_text_lower.get(id(token), token.text.lower())
+
+    def find_nearby(self, token: TokenLike, radius: float) -> list[TokenLike]:
+        """
+        Find all tokens within radius of the given token.
+
+        Uses grid-based lookup for O(1) average case instead of O(n).
+        """
+        center = self.get_center(token)
+        center_x, center_y = center
+
+        # Determine which grid cells to search
+        cells_to_check = int(radius / self.grid_size) + 1
+        grid_x = int(center_x / self.grid_size)
+        grid_y = int(center_y / self.grid_size)
+
+        nearby = []
+        radius_sq = radius * radius
+
+        # Check all nearby grid cells
+        for dx in range(-cells_to_check, cells_to_check + 1):
+            for dy in range(-cells_to_check, cells_to_check + 1):
+                key = (grid_x + dx, grid_y + dy)
+                if key not in self._grid:
+                    continue
+
+                for other in self._grid[key]:
+                    if other is token:
+                        continue
+
+                    other_center = self.get_center(other)
+                    dist_sq = (center_x - other_center[0]) ** 2 + (center_y - other_center[1]) ** 2
+
+                    if dist_sq <= radius_sq:
+                        nearby.append(other)
+
+        return nearby
+
+
 @dataclass
 class Match:
     """Represents a matched field in the document."""
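
A minimal usage sketch of the new index (reviewer's illustration, not part of the commit; SimpleToken is a hypothetical stand-in for whatever concrete type satisfies TokenLike in the pipeline):

    from dataclasses import dataclass

    @dataclass
    class SimpleToken:  # hypothetical stand-in for the pipeline's token type
        text: str
        bbox: tuple[float, float, float, float]  # (x0, y0, x1, y1)
        page_no: int

    tokens = [
        SimpleToken("Fakturanummer:", (100, 200, 220, 215), 0),
        SimpleToken("2465027205", (230, 200, 320, 215), 0),
        SimpleToken("Footer", (100, 1100, 160, 1115), 0),
    ]
    index = TokenIndex(tokens, grid_size=200.0)

    # The value token sits ~115px from the label, so it is returned;
    # the footer token is ~900px away and its grid cells are never visited.
    print([t.text for t in index.find_nearby(tokens[0], radius=200.0)])
    # -> ['2465027205']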
@@ -57,18 +151,20 @@ class FieldMatcher:
 
     def __init__(
         self,
-        context_radius: float = 100.0,  # pixels
+        context_radius: float = 200.0,  # pixels - increased to handle label-value spacing in scanned PDFs
         min_score_threshold: float = 0.5
     ):
         """
         Initialize the matcher.
 
         Args:
-            context_radius: Distance to search for context keywords
+            context_radius: Distance to search for context keywords (default 200px to handle
+                typical label-value spacing in scanned invoices at 150 DPI)
             min_score_threshold: Minimum score to consider a match valid
         """
         self.context_radius = context_radius
         self.min_score_threshold = min_score_threshold
+        self._token_index: TokenIndex | None = None
 
     def find_matches(
         self,
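
(Reviewer note: at the 150 DPI assumed in the docstring, 200 px is roughly 1.3 in / 3.4 cm of page distance, which spans a typical label-to-value gap; the old 100 px default is the figure this commit found too tight for scanned PDFs.)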
@@ -92,6 +188,9 @@ class FieldMatcher:
         matches = []
         page_tokens = [t for t in tokens if t.page_no == page_no]
 
+        # Build spatial index for efficient nearby token lookup (O(n) -> O(1))
+        self._token_index = TokenIndex(page_tokens, grid_size=self.context_radius)
+
         for value in normalized_values:
             # Strategy 1: Exact token match
             exact_matches = self._find_exact_matches(page_tokens, value, field_name)
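
(Design note: with grid_size=self.context_radius, a later find_nearby call at radius=self.context_radius computes cells_to_check = int(radius / grid_size) + 1 = 2, so each lookup visits at most a 5x5 block of grid cells regardless of how many tokens the page holds.)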
@@ -108,7 +207,7 @@ class FieldMatcher:
 
         # Strategy 4: Substring match (for values embedded in longer text)
         # e.g., "Fakturanummer: 2465027205" should match OCR value "2465027205"
-        if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'):
+        if field_name in ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount'):
             substring_matches = self._find_substring_matches(page_tokens, value, field_name)
             matches.extend(substring_matches)
 
@@ -124,6 +223,9 @@ class FieldMatcher:
         matches = self._deduplicate_matches(matches)
         matches.sort(key=lambda m: m.score, reverse=True)
 
+        # Clear token index to free memory
+        self._token_index = None
+
         return [m for m in matches if m.score >= self.min_score_threshold]
 
     def _find_exact_matches(
@@ -134,6 +236,8 @@ class FieldMatcher:
     ) -> list[Match]:
         """Find tokens that exactly match the value."""
         matches = []
+        value_lower = value.lower()
+        value_digits = _NON_DIGIT_PATTERN.sub('', value) if field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro') else None
 
         for token in tokens:
             token_text = token.text.strip()
@@ -141,13 +245,12 @@ class FieldMatcher:
             # Exact match
             if token_text == value:
                 score = 1.0
-            # Case-insensitive match
-            elif token_text.lower() == value.lower():
+            # Case-insensitive match (use cached lowercase from index)
+            elif self._token_index and self._token_index.get_text_lower(token).strip() == value_lower:
                 score = 0.95
             # Digits-only match for numeric fields
-            elif field_name in ('InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro'):
-                token_digits = re.sub(r'\D', '', token_text)
-                value_digits = re.sub(r'\D', '', value)
+            elif value_digits is not None:
+                token_digits = _NON_DIGIT_PATTERN.sub('', token_text)
                 if token_digits and token_digits == value_digits:
                     score = 0.9
             else:
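
A quick worked example of the digits-only branch (values illustrative):

    >>> _NON_DIGIT_PATTERN.sub('', 'OCR: 2465-027-205')
    '2465027205'
    >>> _NON_DIGIT_PATTERN.sub('', '2465027205')
    '2465027205'

Both sides normalize to the same digit string, so the token scores 0.9; the refactor also computes value_digits once per value instead of once per token.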
@@ -181,7 +284,7 @@ class FieldMatcher:
     ) -> list[Match]:
         """Find value by concatenating adjacent tokens."""
         matches = []
-        value_clean = re.sub(r'\s+', '', value)
+        value_clean = _WHITESPACE_PATTERN.sub('', value)
 
         # Sort tokens by position (top-to-bottom, left-to-right)
         sorted_tokens = sorted(tokens, key=lambda t: (t.bbox[1], t.bbox[0]))
@@ -213,7 +316,7 @@ class FieldMatcher:
                 concat_bbox[3] = max(concat_bbox[3], next_token.bbox[3])
 
                 # Check for match
-                concat_clean = re.sub(r'\s+', '', concat_text)
+                concat_clean = _WHITESPACE_PATTERN.sub('', concat_text)
                 if concat_clean == value_clean:
                     context_keywords, context_boost = self._find_context_keywords(
                         tokens, start_token, field_name
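
(The same normalization is what lets a value split across tokens match: an illustrative value 'SE12 3456 7890' and adjacent tokens 'SE12', '3456', '7890' both collapse to 'SE1234567890' before the equality check.)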
@@ -252,7 +355,7 @@ class FieldMatcher:
         matches = []
 
         # Supported fields for substring matching
-        supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro')
+        supported_fields = ('InvoiceDate', 'InvoiceDueDate', 'InvoiceNumber', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount')
         if field_name not in supported_fields:
             return matches
 
@@ -390,13 +493,12 @@ class FieldMatcher:
 
         # Find all date-like tokens in the document
         date_candidates = []
-        date_pattern = re.compile(r'(\d{4})-(\d{2})-(\d{2})')
 
         for token in tokens:
             token_text = token.text.strip()
 
-            # Search for date pattern in token
-            for match in date_pattern.finditer(token_text):
+            # Search for date pattern in token (use pre-compiled pattern)
+            for match in _DATE_PATTERN.finditer(token_text):
                 try:
                     found_date = datetime(
                         int(match.group(1)),
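
(Because finditer is used, an ISO date embedded in a longer token still matches; e.g. an illustrative token 'Betalas:2024-03-15' yields groups ('2024', '03', '15'), which feed the datetime(...) constructor below.)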
@@ -491,38 +593,56 @@ class FieldMatcher:
         target_token: TokenLike,
         field_name: str
     ) -> tuple[list[str], float]:
-        """Find context keywords near the target token."""
+        """
+        Find context keywords near the target token.
+
+        Uses spatial index for O(1) average lookup instead of O(n) scan.
+        """
         keywords = CONTEXT_KEYWORDS.get(field_name, [])
         if not keywords:
             return [], 0.0
 
         found_keywords = []
 
-        target_center = (
-            (target_token.bbox[0] + target_token.bbox[2]) / 2,
-            (target_token.bbox[1] + target_token.bbox[3]) / 2
-        )
-
-        for token in tokens:
-            if token is target_token:
-                continue
-
-            token_center = (
-                (token.bbox[0] + token.bbox[2]) / 2,
-                (token.bbox[1] + token.bbox[3]) / 2
-            )
-
-            # Calculate distance
-            distance = (
-                (target_center[0] - token_center[0]) ** 2 +
-                (target_center[1] - token_center[1]) ** 2
-            ) ** 0.5
-
-            if distance <= self.context_radius:
-                token_lower = token.text.lower()
+        # Use spatial index for efficient nearby token lookup
+        if self._token_index:
+            nearby_tokens = self._token_index.find_nearby(target_token, self.context_radius)
+            for token in nearby_tokens:
+                # Use cached lowercase text
+                token_lower = self._token_index.get_text_lower(token)
                 for keyword in keywords:
                     if keyword in token_lower:
                         found_keywords.append(keyword)
+        else:
+            # Fallback to O(n) scan if no index available
+            target_center = (
+                (target_token.bbox[0] + target_token.bbox[2]) / 2,
+                (target_token.bbox[1] + target_token.bbox[3]) / 2
+            )
+
+            for token in tokens:
+                if token is target_token:
+                    continue
+
+                token_center = (
+                    (token.bbox[0] + token.bbox[2]) / 2,
+                    (token.bbox[1] + token.bbox[3]) / 2
+                )
+
+                distance = (
+                    (target_center[0] - token_center[0]) ** 2 +
+                    (target_center[1] - token_center[1]) ** 2
+                ) ** 0.5
+
+                if distance <= self.context_radius:
+                    token_lower = token.text.lower()
+                    for keyword in keywords:
+                        if keyword in token_lower:
+                            found_keywords.append(keyword)
 
         # Calculate boost based on keywords found
-        boost = min(0.15, len(found_keywords) * 0.05)
+        # Increased boost to better differentiate matches with/without context
+        boost = min(0.25, len(found_keywords) * 0.10)
         return found_keywords, boost
 
     def _tokens_on_same_line(self, token1: TokenLike, token2: TokenLike) -> bool:
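
Worked values for the new boost: one nearby keyword gives 0.10, two give 0.20, three or more cap at 0.25 (previously 0.05 / 0.10 / 0.15). Assuming the boost is added onto the base match score as elsewhere in this matcher, a 0.9 digits-only match backed by two keywords (1.10) now outranks a bare 1.00 exact match elsewhere on the page, which appears to be the intent of the change.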
@@ -548,23 +668,62 @@ class FieldMatcher:
         return None
 
     def _deduplicate_matches(self, matches: list[Match]) -> list[Match]:
-        """Remove duplicate matches based on bbox overlap."""
+        """
+        Remove duplicate matches based on bbox overlap.
+
+        Uses grid-based spatial hashing to reduce O(n²) to O(n) average case.
+        """
         if not matches:
             return []
 
-        # Sort by score descending
-        matches.sort(key=lambda m: m.score, reverse=True)
+        # Sort by: 1) score descending, 2) prefer matches with context keywords,
+        # 3) prefer upper positions (smaller y) for same-score matches
+        # This helps select the "main" occurrence in invoice body rather than footer
+        matches.sort(key=lambda m: (
+            -m.score,
+            -len(m.context_keywords),  # More keywords = better
+            m.bbox[1]  # Smaller y (upper position) = better
+        ))
+
+        # Use spatial grid for efficient overlap checking
+        # Grid cell size based on typical bbox size
+        grid_size = 50.0  # pixels
+        grid: dict[tuple[int, int], list[Match]] = {}
         unique = []
 
         for match in matches:
+            bbox = match.bbox
+            # Calculate grid cells this bbox touches
+            min_gx = int(bbox[0] / grid_size)
+            min_gy = int(bbox[1] / grid_size)
+            max_gx = int(bbox[2] / grid_size)
+            max_gy = int(bbox[3] / grid_size)
+
+            # Check for overlap only with matches in nearby grid cells
             is_duplicate = False
-            for existing in unique:
-                if self._bbox_overlap(match.bbox, existing.bbox) > 0.7:
-                    is_duplicate = True
-                    break
+            cells_to_check = set()
+            for gx in range(min_gx - 1, max_gx + 2):
+                for gy in range(min_gy - 1, max_gy + 2):
+                    cells_to_check.add((gx, gy))
+
+            for cell in cells_to_check:
+                if cell in grid:
+                    for existing in grid[cell]:
+                        if self._bbox_overlap(bbox, existing.bbox) > 0.7:
+                            is_duplicate = True
+                            break
+                if is_duplicate:
+                    break
 
             if not is_duplicate:
                 unique.append(match)
+                # Add to all grid cells this bbox touches
+                for gx in range(min_gx, max_gx + 1):
+                    for gy in range(min_gy, max_gy + 1):
+                        key = (gx, gy)
+                        if key not in grid:
+                            grid[key] = []
+                        grid[key].append(match)
 
         return unique
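
(Note on the ±1 cell ring in cells_to_check: two boxes can only exceed 0.7 IoU if they physically intersect, and intersecting boxes already share at least one grid cell, so the extra ring is a cheap safety margin rather than a correctness requirement.)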
@@ -582,9 +741,9 @@ class FieldMatcher:
         if x2 <= x1 or y2 <= y1:
             return 0.0
 
-        intersection = (x2 - x1) * (y2 - y1)
-        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
-        area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])
+        intersection = float(x2 - x1) * float(y2 - y1)
+        area1 = float(bbox1[2] - bbox1[0]) * float(bbox1[3] - bbox1[1])
+        area2 = float(bbox2[2] - bbox2[0]) * float(bbox2[3] - bbox2[1])
         union = area1 + area2 - intersection
 
         return intersection / union if union > 0 else 0.0
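
A quick numeric check of the 0.7 duplicate threshold (boxes illustrative):

    matcher = FieldMatcher()
    a = (0.0, 0.0, 100.0, 50.0)
    b = (10.0, 0.0, 110.0, 50.0)
    # intersection: 90 * 50 = 4500; each area: 5000; union: 5500
    print(matcher._bbox_overlap(a, b))  # 0.818... -> treated as duplicates

The float(...) casts presumably guard against fixed-width integer coordinates (e.g. numpy int32) overflowing in the products; plain Python ints cannot overflow, so for those inputs the change is a no-op.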