This commit is contained in:
Yaojia Wang
2026-01-22 22:03:24 +01:00
parent 4ea4bc96d4
commit 8fd61ea928
19 changed files with 4069 additions and 226 deletions

393
src/utils/validators.py Normal file
View File

@@ -0,0 +1,393 @@
"""
Field Validators Module
Provides validation functions for Swedish invoice fields.
Used by both inference (to validate extracted values) and matching (to filter candidates).
"""
import re
from datetime import datetime
from typing import Optional
from .text_cleaner import TextCleaner
class FieldValidators:
"""
Validators for Swedish invoice field values.
Includes:
- Luhn (Mod10) checksum validation
- Format validation for specific field types
- Range validation for dates and amounts
"""
# =========================================================================
# Luhn (Mod10) Checksum
# =========================================================================
@classmethod
def luhn_checksum(cls, digits: str) -> bool:
"""
Validate using Luhn (Mod10) algorithm.
Used for:
- Bankgiro numbers
- Plusgiro numbers
- OCR reference numbers
- Swedish organization numbers
The checksum is valid if the total modulo 10 equals 0.
"""
# 只保留数字
digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
if not digits or not digits.isdigit():
return False
total = 0
for i, char in enumerate(reversed(digits)):
digit = int(char)
if i % 2 == 1: # 从右往左,每隔一位加倍
digit *= 2
if digit > 9:
digit -= 9
total += digit
return total % 10 == 0
@classmethod
def calculate_luhn_check_digit(cls, digits: str) -> int:
"""
Calculate the Luhn check digit for a number.
Given a number without check digit, returns the digit that would make it valid.
"""
digits = TextCleaner.extract_digits(digits, apply_ocr_correction=False)
# 计算现有数字的 Luhn 和
total = 0
for i, char in enumerate(reversed(digits)):
digit = int(char)
if i % 2 == 0: # 注意:因为还要加一位,所以偶数位置加倍
digit *= 2
if digit > 9:
digit -= 9
total += digit
# 计算需要的校验位
check_digit = (10 - (total % 10)) % 10
return check_digit
# =========================================================================
# Organisation Number Validation
# =========================================================================
@classmethod
def is_valid_organisation_number(cls, value: str) -> bool:
"""
Validate Swedish organisation number.
Format: NNNNNN-NNNN (10 digits)
- First digit: 1-9
- Third digit: >= 2 (distinguishes from personal numbers)
- Last digit: Luhn check digit
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
# 处理 VAT 格式
if len(digits) == 12 and digits.endswith('01'):
digits = digits[:10]
elif len(digits) == 14 and digits.startswith('46') and digits.endswith('01'):
digits = digits[2:12]
if len(digits) != 10:
return False
# 第一位 1-9
if digits[0] == '0':
return False
# 第三位 >= 2 (区分组织号和个人号)
# 注意:有些特殊组织可能不符合此规则,所以这里放宽
# if int(digits[2]) < 2:
# return False
# Luhn 校验
return cls.luhn_checksum(digits)
# =========================================================================
# Bankgiro Validation
# =========================================================================
@classmethod
def is_valid_bankgiro(cls, value: str) -> bool:
"""
Validate Swedish Bankgiro number.
Format: 7 or 8 digits with Luhn checksum
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 7 or len(digits) > 8:
return False
return cls.luhn_checksum(digits)
@classmethod
def format_bankgiro(cls, value: str) -> Optional[str]:
"""
Format Bankgiro number to standard format.
Returns: XXX-XXXX (7 digits) or XXXX-XXXX (8 digits), or None if invalid
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) == 7:
return f"{digits[:3]}-{digits[3:]}"
elif len(digits) == 8:
return f"{digits[:4]}-{digits[4:]}"
else:
return None
# =========================================================================
# Plusgiro Validation
# =========================================================================
@classmethod
def is_valid_plusgiro(cls, value: str) -> bool:
"""
Validate Swedish Plusgiro number.
Format: 2-8 digits with Luhn checksum
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 2 or len(digits) > 8:
return False
return cls.luhn_checksum(digits)
@classmethod
def format_plusgiro(cls, value: str) -> Optional[str]:
"""
Format Plusgiro number to standard format.
Returns: XXXXXXX-X format, or None if invalid
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 2 or len(digits) > 8:
return None
return f"{digits[:-1]}-{digits[-1]}"
# =========================================================================
# OCR Number Validation
# =========================================================================
@classmethod
def is_valid_ocr_number(cls, value: str, validate_checksum: bool = True) -> bool:
"""
Validate Swedish OCR reference number.
- Typically 10-25 digits
- Usually has Luhn checksum (but not always enforced)
"""
digits = TextCleaner.extract_digits(value, apply_ocr_correction=True)
if len(digits) < 5 or len(digits) > 25:
return False
if validate_checksum:
return cls.luhn_checksum(digits)
return True
# =========================================================================
# Amount Validation
# =========================================================================
@classmethod
def is_valid_amount(cls, value: str, min_amount: float = 0.0, max_amount: float = 10_000_000.0) -> bool:
"""
Validate monetary amount.
- Must be positive (or at least >= min_amount)
- Should be within reasonable range
"""
try:
# 尝试解析
text = TextCleaner.normalize_amount_text(value)
# 统一为点作为小数分隔符
text = text.replace(' ', '').replace(',', '.')
# 如果有多个点,保留最后一个
if text.count('.') > 1:
parts = text.rsplit('.', 1)
text = parts[0].replace('.', '') + '.' + parts[1]
amount = float(text)
return min_amount <= amount <= max_amount
except (ValueError, TypeError):
return False
@classmethod
def parse_amount(cls, value: str) -> Optional[float]:
"""
Parse amount from string, handling various formats.
Returns float or None if parsing fails.
"""
try:
text = TextCleaner.normalize_amount_text(value)
text = text.replace(' ', '')
# 检测格式并解析
# 瑞典/德国格式: 逗号是小数点
if re.match(r'^[\d.]+,\d{1,2}$', text):
text = text.replace('.', '').replace(',', '.')
# 美国格式: 点是小数点
elif re.match(r'^[\d,]+\.\d{1,2}$', text):
text = text.replace(',', '')
else:
# 简单格式
text = text.replace(',', '.')
if text.count('.') > 1:
parts = text.rsplit('.', 1)
text = parts[0].replace('.', '') + '.' + parts[1]
return float(text)
except (ValueError, TypeError):
return None
# =========================================================================
# Date Validation
# =========================================================================
@classmethod
def is_valid_date(cls, value: str, min_year: int = 2000, max_year: int = 2100) -> bool:
"""
Validate date string.
- Year should be within reasonable range
- Month 1-12
- Day 1-31 (basic check)
"""
parsed = cls.parse_date(value)
if parsed is None:
return False
year, month, day = parsed
if not (min_year <= year <= max_year):
return False
if not (1 <= month <= 12):
return False
if not (1 <= day <= 31):
return False
# 更精确的日期验证
try:
datetime(year, month, day)
return True
except ValueError:
return False
@classmethod
def parse_date(cls, value: str) -> Optional[tuple[int, int, int]]:
"""
Parse date from string.
Returns (year, month, day) tuple or None.
"""
from .format_variants import FormatVariants
return FormatVariants._parse_date(value)
@classmethod
def format_date_iso(cls, value: str) -> Optional[str]:
"""
Format date to ISO format (YYYY-MM-DD).
Returns formatted string or None if parsing fails.
"""
parsed = cls.parse_date(value)
if parsed is None:
return None
year, month, day = parsed
return f"{year}-{month:02d}-{day:02d}"
# =========================================================================
# Invoice Number Validation
# =========================================================================
@classmethod
def is_valid_invoice_number(cls, value: str, min_length: int = 1, max_length: int = 30) -> bool:
"""
Validate invoice number.
Basic validation - just length check since invoice numbers are highly variable.
"""
clean = TextCleaner.clean_text(value)
if not clean:
return False
# 提取有意义的字符(字母和数字)
meaningful = re.sub(r'[^a-zA-Z0-9]', '', clean)
return min_length <= len(meaningful) <= max_length
# =========================================================================
# Generic Validation
# =========================================================================
@classmethod
def validate_field(cls, field_name: str, value: str) -> tuple[bool, Optional[str]]:
"""
Validate a field by name.
Returns (is_valid, error_message).
"""
if not value:
return False, "Empty value"
field_lower = field_name.lower()
if 'organisation' in field_lower or 'org' in field_lower:
if cls.is_valid_organisation_number(value):
return True, None
return False, "Invalid organisation number format or checksum"
elif 'bankgiro' in field_lower:
if cls.is_valid_bankgiro(value):
return True, None
return False, "Invalid Bankgiro format or checksum"
elif 'plusgiro' in field_lower:
if cls.is_valid_plusgiro(value):
return True, None
return False, "Invalid Plusgiro format or checksum"
elif 'ocr' in field_lower:
if cls.is_valid_ocr_number(value, validate_checksum=False):
return True, None
return False, "Invalid OCR number length"
elif 'amount' in field_lower:
if cls.is_valid_amount(value):
return True, None
return False, "Invalid amount format"
elif 'date' in field_lower:
if cls.is_valid_date(value):
return True, None
return False, "Invalid date format"
elif 'invoice' in field_lower and 'number' in field_lower:
if cls.is_valid_invoice_number(value):
return True, None
return False, "Invalid invoice number"
else:
# 默认:只检查非空
if TextCleaner.clean_text(value):
return True, None
return False, "Empty value after cleaning"