Re-structure the project.

This commit is contained in:
Yaojia Wang
2026-01-25 15:21:11 +01:00
parent 8fd61ea928
commit e599424a92
80 changed files with 10672 additions and 1584 deletions

View File

@@ -0,0 +1,534 @@
"""
Tests for the CSV Data Loader Module.
Tests cover all loader functions in src/data/csv_loader.py
Usage:
pytest src/data/test_csv_loader.py -v -o 'addopts='
"""
import pytest
import tempfile
from pathlib import Path
from datetime import date
from decimal import Decimal
from src.data.csv_loader import (
InvoiceRow,
CSVLoader,
load_invoice_csv,
)
class TestInvoiceRow:
    """Checks for the ``InvoiceRow`` record: construction, ``to_dict`` and ``get_field_value``."""

    def test_creation_minimal(self):
        """A row built from ``DocumentId`` alone leaves InvoiceDate and Amount as ``None``."""
        bare = InvoiceRow(DocumentId="DOC001")

        assert bare.DocumentId == "DOC001"
        assert bare.InvoiceDate is None
        assert bare.Amount is None

    def test_creation_full(self):
        """A row can be built with all of the fields listed below supplied."""
        every_field = {
            "DocumentId": "DOC001",
            "InvoiceDate": date(2025, 1, 15),
            "InvoiceNumber": "INV-001",
            "InvoiceDueDate": date(2025, 2, 15),
            "OCR": "1234567890",
            "Message": "Test message",
            "Bankgiro": "5393-9484",
            "Plusgiro": "123456-7",
            "Amount": Decimal("1234.56"),
            "split": "train",
            "customer_number": "CUST001",
            "supplier_name": "Test Supplier",
            "supplier_organisation_number": "556123-4567",
            "supplier_accounts": "BG:5393-9484",
        }
        populated = InvoiceRow(**every_field)

        assert populated.DocumentId == "DOC001"
        assert populated.InvoiceDate == date(2025, 1, 15)
        assert populated.Amount == Decimal("1234.56")

    def test_to_dict(self):
        """``to_dict`` yields the id as-is, the date as ``2025-01-15`` and the amount as ``100.50``."""
        as_mapping = InvoiceRow(
            DocumentId="DOC001",
            InvoiceDate=date(2025, 1, 15),
            Amount=Decimal("100.50"),
        ).to_dict()

        assert as_mapping["DocumentId"] == "DOC001"
        assert as_mapping["InvoiceDate"] == "2025-01-15"
        assert as_mapping["Amount"] == "100.50"

    def test_to_dict_none_values(self):
        """Unset fields appear in the ``to_dict`` result as ``None``."""
        as_mapping = InvoiceRow(DocumentId="DOC001").to_dict()

        assert as_mapping["DocumentId"] == "DOC001"
        assert as_mapping["InvoiceDate"] is None
        assert as_mapping["Amount"] is None

    def test_get_field_value_date(self):
        """A date field is returned by ``get_field_value`` as an ISO string."""
        dated = InvoiceRow(DocumentId="DOC001", InvoiceDate=date(2025, 1, 15))

        assert dated.get_field_value("InvoiceDate") == "2025-01-15"

    def test_get_field_value_decimal(self):
        """A Decimal field is returned by ``get_field_value`` as a string."""
        priced = InvoiceRow(DocumentId="DOC001", Amount=Decimal("1234.56"))

        assert priced.get_field_value("Amount") == "1234.56"

    def test_get_field_value_string(self):
        """A string field is returned by ``get_field_value`` unchanged."""
        numbered = InvoiceRow(DocumentId="DOC001", InvoiceNumber="INV-001")

        assert numbered.get_field_value("InvoiceNumber") == "INV-001"

    def test_get_field_value_none(self):
        """A known field that was never set gives ``None``."""
        bare = InvoiceRow(DocumentId="DOC001")

        assert bare.get_field_value("InvoiceNumber") is None

    def test_get_field_value_unknown_field(self):
        """A field name the row does not have gives ``None``."""
        bare = InvoiceRow(DocumentId="DOC001")

        assert bare.get_field_value("UnknownField") is None
class TestCSVLoaderParseDate:
    """Checks for ``CSVLoader._parse_date`` across several date spellings."""

    @staticmethod
    def _date_parser():
        """Return ``_parse_date`` bound to a loader allocated without running ``__init__``."""
        return CSVLoader.__new__(CSVLoader)._parse_date

    def test_parse_iso_format(self):
        """An ISO ``YYYY-MM-DD`` string is parsed."""
        parse = self._date_parser()
        assert parse("2025-01-15") == date(2025, 1, 15)

    def test_parse_iso_with_time(self):
        """An ISO date followed by a time of day yields the date part."""
        parse = self._date_parser()
        assert parse("2025-01-15 12:30:45") == date(2025, 1, 15)

    def test_parse_iso_with_microseconds(self):
        """An ISO timestamp with microseconds yields the date part."""
        parse = self._date_parser()
        assert parse("2025-01-15 12:30:45.123456") == date(2025, 1, 15)

    def test_parse_european_slash(self):
        """``DD/MM/YYYY`` is parsed."""
        parse = self._date_parser()
        assert parse("15/01/2025") == date(2025, 1, 15)

    def test_parse_european_dot(self):
        """``DD.MM.YYYY`` is parsed."""
        parse = self._date_parser()
        assert parse("15.01.2025") == date(2025, 1, 15)

    def test_parse_european_dash(self):
        """``DD-MM-YYYY`` is parsed."""
        parse = self._date_parser()
        assert parse("15-01-2025") == date(2025, 1, 15)

    def test_parse_compact(self):
        """``YYYYMMDD`` without separators is parsed."""
        parse = self._date_parser()
        assert parse("20250115") == date(2025, 1, 15)

    def test_parse_empty(self):
        """Empty and blank strings give ``None``."""
        parse = self._date_parser()
        assert parse("") is None
        assert parse(" ") is None

    def test_parse_none(self):
        """``None`` input gives ``None``."""
        parse = self._date_parser()
        assert parse(None) is None

    def test_parse_invalid(self):
        """Text that is not a date gives ``None``."""
        parse = self._date_parser()
        assert parse("not-a-date") is None
class TestCSVLoaderParseAmount:
    """Checks for ``CSVLoader._parse_amount`` across number and currency spellings."""

    @staticmethod
    def _amount_parser():
        """Return ``_parse_amount`` bound to a loader allocated without running ``__init__``."""
        return CSVLoader.__new__(CSVLoader)._parse_amount

    def test_parse_simple_integer(self):
        """A plain integer string is parsed."""
        parse = self._amount_parser()
        assert parse("100") == Decimal("100")

    def test_parse_decimal_dot(self):
        """A dot is accepted as the decimal separator."""
        parse = self._amount_parser()
        assert parse("100.50") == Decimal("100.50")

    def test_parse_decimal_comma(self):
        """A comma is accepted as the decimal separator."""
        parse = self._amount_parser()
        assert parse("100,50") == Decimal("100.50")

    def test_parse_with_thousand_separator_space(self):
        """A space between digit groups is dropped."""
        parse = self._amount_parser()
        assert parse("1 234,56") == Decimal("1234.56")

    def test_parse_with_thousand_separator_comma(self):
        """A comma groups thousands when a dot marks the decimals."""
        parse = self._amount_parser()
        assert parse("1,234.56") == Decimal("1234.56")

    def test_parse_with_currency_sek(self):
        """A trailing ``SEK`` is removed."""
        parse = self._amount_parser()
        assert parse("100 SEK") == Decimal("100")

    def test_parse_with_currency_kr(self):
        """A trailing ``kr`` is removed."""
        parse = self._amount_parser()
        assert parse("100 kr") == Decimal("100")

    def test_parse_with_colon_dash(self):
        """A trailing ``:-`` is removed."""
        parse = self._amount_parser()
        assert parse("100:-") == Decimal("100")

    def test_parse_empty(self):
        """Empty and blank strings give ``None``."""
        parse = self._amount_parser()
        assert parse("") is None
        assert parse(" ") is None

    def test_parse_none(self):
        """``None`` input gives ``None``."""
        parse = self._amount_parser()
        assert parse(None) is None

    def test_parse_invalid(self):
        """Text that is not an amount gives ``None``."""
        parse = self._amount_parser()
        assert parse("not-an-amount") is None
class TestCSVLoaderParseString:
    """Checks for ``CSVLoader._parse_string``."""

    @staticmethod
    def _string_parser():
        """Return ``_parse_string`` bound to a loader allocated without running ``__init__``."""
        return CSVLoader.__new__(CSVLoader)._parse_string

    def test_parse_normal_string(self):
        """Surrounding whitespace is stripped from ordinary text."""
        parse = self._string_parser()
        assert parse(" hello ") == "hello"

    def test_parse_empty_string(self):
        """Empty and blank strings give ``None``."""
        parse = self._string_parser()
        assert parse("") is None
        assert parse(" ") is None

    def test_parse_none(self):
        """``None`` input gives ``None``."""
        parse = self._string_parser()
        assert parse(None) is None
class TestCSVLoaderWithFile:
    """Checks for ``CSVLoader`` reading real CSV files written under ``tmp_path``."""

    @pytest.fixture
    def sample_csv(self, tmp_path):
        """Write a three-row CSV (the last row has an empty Bankgiro) and return its path."""
        target = tmp_path / "test.csv"
        payload = """DocumentId,InvoiceDate,InvoiceNumber,Amount,Bankgiro
DOC001,2025-01-15,INV-001,100.50,5393-9484
DOC002,2025-01-16,INV-002,200.00,1234-5678
DOC003,2025-01-17,INV-003,300.75,
"""
        target.write_text(payload, encoding="utf-8")
        return target

    @pytest.fixture
    def sample_csv_with_bom(self, tmp_path):
        """Write a one-row CSV using the ``utf-8-sig`` encoding and return its path."""
        target = tmp_path / "test_bom.csv"
        payload = """DocumentId,InvoiceDate,Amount
DOC001,2025-01-15,100.50
"""
        target.write_text(payload, encoding="utf-8-sig")
        return target

    def test_load_all(self, sample_csv):
        """``load_all`` returns the three rows in file order."""
        loaded = CSVLoader(sample_csv).load_all()

        assert len(loaded) == 3
        for record, expected_id in zip(loaded, ("DOC001", "DOC002", "DOC003")):
            assert record.DocumentId == expected_id

    def test_iter_rows(self, sample_csv):
        """``iter_rows`` yields one item per data row."""
        collected = list(CSVLoader(sample_csv).iter_rows())

        assert len(collected) == 3

    def test_parse_fields_correctly(self, sample_csv):
        """The first row's date, number, amount and Bankgiro come back as typed values."""
        first = CSVLoader(sample_csv).load_all()[0]

        assert first.InvoiceDate == date(2025, 1, 15)
        assert first.InvoiceNumber == "INV-001"
        assert first.Amount == Decimal("100.50")
        assert first.Bankgiro == "5393-9484"

    def test_handles_empty_fields(self, sample_csv):
        """An empty CSV cell is loaded as ``None``."""
        loaded = CSVLoader(sample_csv).load_all()

        # The third row is the one written with an empty Bankgiro column.
        third = loaded[2]
        assert third.Bankgiro is None

    def test_handles_bom(self, sample_csv_with_bom):
        """A file written as ``utf-8-sig`` still loads with the right DocumentId."""
        loaded = CSVLoader(sample_csv_with_bom).load_all()

        assert len(loaded) == 1
        assert loaded[0].DocumentId == "DOC001"

    def test_get_row_by_id(self, sample_csv):
        """``get_row_by_id`` returns the row whose DocumentId matches."""
        match = CSVLoader(sample_csv).get_row_by_id("DOC002")

        assert match is not None
        assert match.InvoiceNumber == "INV-002"

    def test_get_row_by_id_not_found(self, sample_csv):
        """``get_row_by_id`` gives ``None`` for an id that is not in the file."""
        match = CSVLoader(sample_csv).get_row_by_id("NONEXISTENT")

        assert match is None
class TestCSVLoaderMultipleFiles:
    """Checks for ``CSVLoader`` when it is given more than one CSV file."""

    @pytest.fixture
    def multiple_csvs(self, tmp_path):
        """Write two CSV files of two rows each and return their paths as a list."""
        first_file = tmp_path / "file1.csv"
        first_payload = """DocumentId,InvoiceNumber
DOC001,INV-001
DOC002,INV-002
"""
        first_file.write_text(first_payload, encoding="utf-8")

        second_file = tmp_path / "file2.csv"
        second_payload = """DocumentId,InvoiceNumber
DOC003,INV-003
DOC004,INV-004
"""
        second_file.write_text(second_payload, encoding="utf-8")

        return [first_file, second_file]

    def test_load_from_list(self, multiple_csvs):
        """A list of paths is loaded as one combined set of rows."""
        loaded = CSVLoader(multiple_csvs).load_all()

        assert len(loaded) == 4
        seen_ids = [record.DocumentId for record in loaded]
        assert "DOC001" in seen_ids
        assert "DOC004" in seen_ids

    def test_load_from_glob(self, multiple_csvs, tmp_path):
        """A ``*.csv`` pattern inside the directory picks up both files."""
        pattern = tmp_path / "*.csv"
        loaded = CSVLoader(pattern).load_all()

        assert len(loaded) == 4

    def test_deduplicates_by_doc_id(self, tmp_path):
        """When two files share a DocumentId only one row is returned."""
        first_file = tmp_path / "file1.csv"
        first_payload = """DocumentId,InvoiceNumber
DOC001,INV-001
"""
        first_file.write_text(first_payload, encoding="utf-8")

        second_file = tmp_path / "file2.csv"
        second_payload = """DocumentId,InvoiceNumber
DOC001,INV-001-DUPLICATE
"""
        second_file.write_text(second_payload, encoding="utf-8")

        loaded = CSVLoader([first_file, second_file]).load_all()

        assert len(loaded) == 1
        # The row from the file listed first is the one that is kept.
        assert loaded[0].InvoiceNumber == "INV-001"
class TestCSVLoaderPDFPath:
    """Tests for CSVLoader.get_pdf_path method.

    Every test loads the four CSV rows DOC001-DOC004 and resolves one of them
    against a directory that holds ``DOC001.pdf``, ``doc002.pdf`` and
    ``INVOICE_DOC003.pdf``.
    """

    @pytest.fixture
    def setup_pdf_dir(self, tmp_path):
        """Create the PDF directory and the CSV file used by the tests.

        Returns:
            A ``(csv_file, pdf_dir)`` tuple of paths under ``tmp_path``.
        """
        pdf_dir = tmp_path / "pdfs"
        pdf_dir.mkdir()

        # Empty placeholder files: only the names matter for path lookup.
        (pdf_dir / "DOC001.pdf").touch()
        (pdf_dir / "doc002.pdf").touch()
        (pdf_dir / "INVOICE_DOC003.pdf").touch()

        csv_file = tmp_path / "test.csv"
        csv_file.write_text("""DocumentId,InvoiceNumber
DOC001,INV-001
DOC002,INV-002
DOC003,INV-003
DOC004,INV-004
""", encoding="utf-8")

        return csv_file, pdf_dir

    def test_find_exact_match(self, setup_pdf_dir):
        """Should find PDF with exact name match."""
        csv_file, pdf_dir = setup_pdf_dir
        loader = CSVLoader(csv_file, pdf_dir)
        rows = loader.load_all()

        pdf_path = loader.get_pdf_path(rows[0])  # DOC001

        assert pdf_path is not None
        assert pdf_path.name == "DOC001.pdf"

    def test_find_lowercase_match(self, setup_pdf_dir):
        """Should find PDF with lowercase name."""
        csv_file, pdf_dir = setup_pdf_dir
        loader = CSVLoader(csv_file, pdf_dir)
        rows = loader.load_all()

        pdf_path = loader.get_pdf_path(rows[1])  # DOC002 -> doc002.pdf

        assert pdf_path is not None
        # Compare case-insensitively: on a case-insensitive filesystem a lookup
        # for "DOC002.pdf" also hits doc002.pdf, so the reported name may keep
        # the upper-case spelling even though the right file was found.
        assert pdf_path.name.lower() == "doc002.pdf"

    def test_find_glob_match(self, setup_pdf_dir):
        """Should find PDF using glob pattern."""
        csv_file, pdf_dir = setup_pdf_dir
        loader = CSVLoader(csv_file, pdf_dir)
        rows = loader.load_all()

        pdf_path = loader.get_pdf_path(rows[2])  # DOC003 -> INVOICE_DOC003.pdf

        assert pdf_path is not None
        assert "DOC003" in pdf_path.name

    def test_not_found(self, setup_pdf_dir):
        """Should return None when PDF not found."""
        csv_file, pdf_dir = setup_pdf_dir
        loader = CSVLoader(csv_file, pdf_dir)
        rows = loader.load_all()

        pdf_path = loader.get_pdf_path(rows[3])  # DOC004 - no PDF

        assert pdf_path is None
class TestCSVLoaderValidate:
    """Checks for the issue list returned by ``CSVLoader.validate``."""

    def test_validate_missing_pdf(self, tmp_path):
        """A row with no PDF in the PDF directory produces exactly one ``PDF`` issue."""
        source = tmp_path / "test.csv"
        payload = """DocumentId,InvoiceNumber
DOC001,INV-001
"""
        source.write_text(payload, encoding="utf-8")

        # tmp_path doubles as the PDF directory and contains no PDF files.
        reported = CSVLoader(source, tmp_path).validate()

        assert len(reported) >= 1
        about_pdf = [entry for entry in reported if entry.get("field") == "PDF"]
        assert len(about_pdf) == 1

    def test_validate_no_matchable_fields(self, tmp_path):
        """A row carrying only a Message produces exactly one ``All`` issue."""
        source = tmp_path / "test.csv"
        payload = """DocumentId,Message
DOC001,Just a message
"""
        source.write_text(payload, encoding="utf-8")

        # Provide the PDF so that a missing-PDF issue is not expected here.
        pdf_folder = tmp_path / "pdfs"
        pdf_folder.mkdir()
        (pdf_folder / "DOC001.pdf").touch()

        reported = CSVLoader(source, pdf_folder).validate()

        about_all = [entry for entry in reported if entry.get("field") == "All"]
        assert len(about_all) == 1
class TestCSVLoaderAlternateFieldNames:
    """Checks that alternative CSV header spellings map onto ``InvoiceRow`` fields."""

    def test_lowercase_field_names(self, tmp_path):
        """Snake-case headers such as ``document_id`` and ``invoice_date`` are recognised."""
        source = tmp_path / "test.csv"
        payload = """document_id,invoice_date,invoice_number,amount
DOC001,2025-01-15,INV-001,100.50
"""
        source.write_text(payload, encoding="utf-8")

        loaded = CSVLoader(source).load_all()

        assert len(loaded) == 1
        only = loaded[0]
        assert only.DocumentId == "DOC001"
        assert only.InvoiceDate == date(2025, 1, 15)

    def test_alternate_amount_field(self, tmp_path):
        """An ``invoice_data_amount`` column is read into ``Amount``."""
        source = tmp_path / "test.csv"
        payload = """DocumentId,invoice_data_amount
DOC001,100.50
"""
        source.write_text(payload, encoding="utf-8")

        loaded = CSVLoader(source).load_all()

        assert loaded[0].Amount == Decimal("100.50")
class TestLoadInvoiceCSV:
    """Checks for the module-level ``load_invoice_csv`` helper."""

    def test_load_single_file(self, tmp_path):
        """A single CSV path is loaded into a one-element row sequence."""
        source = tmp_path / "test.csv"
        payload = """DocumentId,InvoiceNumber
DOC001,INV-001
"""
        source.write_text(payload, encoding="utf-8")

        loaded = load_invoice_csv(source)

        assert len(loaded) == 1
        assert loaded[0].DocumentId == "DOC001"
if __name__ == "__main__":
    # pytest.main() returns the exit code instead of exiting the process.
    # Raise SystemExit with it so that running this file directly ends with a
    # non-zero status when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))