236 lines
9.7 KiB
Python
236 lines
9.7 KiB
Python
"""
Tests for shared utility modules.
"""
|
||
|
||
import pytest
|
||
from src.utils.text_cleaner import TextCleaner
|
||
from src.utils.format_variants import FormatVariants
|
||
from src.utils.validators import FieldValidators
|
||
|
||
|
||
class TestTextCleaner:
    """Unit tests for the TextCleaner normalisation helpers."""

    def test_clean_unicode_dashes(self):
        """Unicode dash look-alikes are normalised to an ASCII hyphen."""
        dashed_inputs = (
            "556123–4567",  # en-dash
            "556123—4567",  # em-dash
            "556123−4567",  # minus sign
        )
        for raw in dashed_inputs:
            assert TextCleaner.clean_unicode(raw) == "556123-4567"

    def test_clean_unicode_spaces(self):
        """Exotic space characters are normalised or removed."""
        cases = (
            ("1\xa0234", "1 234"),  # non-breaking space -> regular space
            ("123\u200b456", "123456"),  # zero-width space is dropped
        )
        for raw, expected in cases:
            assert TextCleaner.clean_unicode(raw) == expected

    def test_ocr_digit_corrections(self):
        """Letters commonly misread by OCR are mapped back to digits."""
        cases = (
            ("556O23", "556023"),  # O -> 0
            ("556l23", "556123"),  # l -> 1
            ("5S6123", "556123"),  # S -> 5
            ("S56l23-4S67", "556123-4567"),  # several corrections at once
        )
        for raw, expected in cases:
            assert TextCleaner.apply_ocr_digit_corrections(raw) == expected

    def test_extract_digits(self):
        """Digits are extracted, optionally after OCR correction."""
        extract = TextCleaner.extract_digits

        assert extract("556123-4567") == "5561234567"
        assert extract("556O23-4567", apply_ocr_correction=True) == "5560234567"

        # Without OCR correction only genuine digit characters survive.
        assert extract("ABC 123 DEF", apply_ocr_correction=False) == "123"

        # Even with OCR correction, isolated letters are left untouched;
        # only letters adjacent to digits are candidates for correction.
        assert extract("A 123 B", apply_ocr_correction=True) == "123"

    def test_normalize_amount_text(self):
        """Currency markers and digit-group spaces are stripped from amounts."""
        cases = (
            ("1 234,56 kr", "1234,56"),
            ("SEK 1234.56", "1234.56"),
            ("1 234 567,89 kronor", "1234567,89"),
        )
        for raw, expected in cases:
            assert TextCleaner.normalize_amount_text(raw) == expected
|
||
|
||
|
||
class TestFormatVariants:
    """Tests for FormatVariants class."""

    def test_organisation_number_variants(self):
        """Test organisation number variant generation."""
        variants = FormatVariants.organisation_number_variants("5561234567")

        assert "5561234567" in variants  # digits only
        assert "556123-4567" in variants  # with hyphen
        assert "SE556123456701" in variants  # VAT format: "SE" + number + "01"

    def test_organisation_number_from_vat(self):
        """Test extracting org number from VAT format."""
        variants = FormatVariants.organisation_number_variants("SE556123456701")

        assert "5561234567" in variants
        assert "556123-4567" in variants

    def test_bankgiro_variants(self):
        """Test Bankgiro variant generation."""
        # 8 digits
        variants = FormatVariants.bankgiro_variants("53939484")
        assert "53939484" in variants
        assert "5393-9484" in variants

        # 7 digits
        variants = FormatVariants.bankgiro_variants("1234567")
        assert "1234567" in variants
        assert "123-4567" in variants

    def test_plusgiro_variants(self):
        """Test Plusgiro variant generation."""
        variants = FormatVariants.plusgiro_variants("12345678")
        assert "12345678" in variants
        assert "1234567-8" in variants

    def test_amount_variants(self):
        """Test amount variant generation."""
        variants = FormatVariants.amount_variants("1234.56")

        assert "1234.56" in variants
        assert "1234,56" in variants
        # TODO: pin the Swedish thousands-grouped spelling (e.g. "1 234,56")
        # once the expected separator (space vs. non-breaking space) is
        # confirmed against FormatVariants.amount_variants.

    def test_date_variants(self):
        """Test date variant generation."""
        variants = FormatVariants.date_variants("2024-12-29")

        assert "2024-12-29" in variants  # ISO
        assert "29.12.2024" in variants  # European
        assert "29/12/2024" in variants  # European slash
        assert "20241229" in variants  # Compact
        assert "29 december 2024" in variants  # Swedish text

    def test_invoice_number_variants(self):
        """Test invoice number variant generation."""
        variants = FormatVariants.invoice_number_variants("INV-2024-001")

        assert "INV-2024-001" in variants
        assert "INV2024001" in variants  # No separators
        assert "inv-2024-001" in variants  # Lowercase

    def test_get_variants_dispatch(self):
        """Test get_variants dispatches to correct method."""
        # Organisation number
        org_variants = FormatVariants.get_variants("supplier_organisation_number", "5561234567")
        assert "556123-4567" in org_variants

        # Bankgiro
        bg_variants = FormatVariants.get_variants("Bankgiro", "53939484")
        assert "5393-9484" in bg_variants

        # Amount
        amount_variants = FormatVariants.get_variants("Amount", "1234.56")
        assert "1234,56" in amount_variants
|
||
|
||
|
||
class TestFieldValidators:
    """Tests for FieldValidators class."""

    def test_luhn_checksum_valid(self):
        """Test Luhn validation with valid numbers."""
        # Valid Bankgiro numbers (with correct check digit)
        assert FieldValidators.luhn_checksum("53939484") is True
        # Valid OCR numbers
        assert FieldValidators.luhn_checksum("1234567897") is True  # check digit 7

    def test_luhn_checksum_invalid(self):
        """Test Luhn validation with invalid numbers."""
        assert FieldValidators.luhn_checksum("53939485") is False  # wrong check digit
        assert FieldValidators.luhn_checksum("1234567890") is False

    def test_calculate_luhn_check_digit(self):
        """Test Luhn check digit calculation."""
        # For "123456789", the check digit should make it valid
        check = FieldValidators.calculate_luhn_check_digit("123456789")
        full_number = "123456789" + str(check)
        assert FieldValidators.luhn_checksum(full_number) is True

    def test_is_valid_organisation_number(self):
        """Test organisation number validation."""
        # No verified organisation numbers are available as fixtures yet.
        # Skip explicitly so the report shows this test as pending; an empty
        # `pass` body would be reported as a passing test that checks nothing.
        # Note: the earlier placeholder 5565006245 fails the Luhn check
        # (digit sum 32), so it cannot serve as a "valid" example.
        pytest.skip("No verified organisation number test data available")

    def test_is_valid_bankgiro(self):
        """Test Bankgiro validation."""
        # Valid 8-digit Bankgiro with Luhn
        assert FieldValidators.is_valid_bankgiro("53939484") is True
        # Invalid (wrong length)
        assert FieldValidators.is_valid_bankgiro("123") is False
        assert FieldValidators.is_valid_bankgiro("123456789") is False  # 9 digits

    def test_format_bankgiro(self):
        """Test Bankgiro formatting."""
        assert FieldValidators.format_bankgiro("53939484") == "5393-9484"
        assert FieldValidators.format_bankgiro("1234567") == "123-4567"
        assert FieldValidators.format_bankgiro("123") is None

    def test_is_valid_plusgiro(self):
        """Test Plusgiro validation."""
        # Valid Plusgiro (2-8 digits with Luhn)
        assert FieldValidators.is_valid_plusgiro("18") is True  # minimal
        # Invalid (wrong length)
        assert FieldValidators.is_valid_plusgiro("1") is False

    def test_format_plusgiro(self):
        """Test Plusgiro formatting."""
        assert FieldValidators.format_plusgiro("12345678") == "1234567-8"
        assert FieldValidators.format_plusgiro("123456") == "12345-6"

    def test_is_valid_amount(self):
        """Test amount validation."""
        assert FieldValidators.is_valid_amount("1234.56") is True
        assert FieldValidators.is_valid_amount("1 234,56") is True
        assert FieldValidators.is_valid_amount("abc") is False
        assert FieldValidators.is_valid_amount("-100") is False  # below min
        assert FieldValidators.is_valid_amount("100000000") is False  # above max

    def test_parse_amount(self):
        """Test amount parsing."""
        assert FieldValidators.parse_amount("1234.56") == 1234.56
        assert FieldValidators.parse_amount("1 234,56") == 1234.56
        assert FieldValidators.parse_amount("1.234,56") == 1234.56  # German
        assert FieldValidators.parse_amount("1,234.56") == 1234.56  # US

    def test_is_valid_date(self):
        """Test date validation."""
        assert FieldValidators.is_valid_date("2024-12-29") is True
        assert FieldValidators.is_valid_date("29.12.2024") is True
        assert FieldValidators.is_valid_date("29/12/2024") is True
        assert FieldValidators.is_valid_date("not a date") is False
        assert FieldValidators.is_valid_date("1900-01-01") is False  # out of range

    def test_format_date_iso(self):
        """Test date ISO formatting."""
        assert FieldValidators.format_date_iso("29.12.2024") == "2024-12-29"
        assert FieldValidators.format_date_iso("29/12/2024") == "2024-12-29"
        assert FieldValidators.format_date_iso("2024-12-29") == "2024-12-29"

    def test_validate_field_dispatch(self):
        """Test validate_field dispatches correctly."""
        # Only the validity flag is checked; the error message is ignored.
        # Organisation number
        is_valid, _ = FieldValidators.validate_field("supplier_organisation_number", "")
        assert is_valid is False

        # Amount
        is_valid, _ = FieldValidators.validate_field("Amount", "1234.56")
        assert is_valid is True

        # Date
        is_valid, _ = FieldValidators.validate_field("InvoiceDate", "2024-12-29")
        assert is_valid is True
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate pytest's exit status so shell/CI callers see test failures
    # (the return value of pytest.main was previously discarded -> exit 0).
    raise SystemExit(pytest.main([__file__, "-v"]))
|