"""
Tests for the CSV Data Loader Module.

Tests cover all loader functions in src/data/csv_loader.py

Usage:
    pytest src/data/test_csv_loader.py -v -o 'addopts='
"""

import pytest
import tempfile
from pathlib import Path
from datetime import date
from decimal import Decimal

from src.data.csv_loader import (
    InvoiceRow,
    CSVLoader,
    load_invoice_csv,
)


@pytest.fixture
def bare_loader():
    """Return a CSVLoader instance created without running ``__init__``.

    The ``_parse_*`` helper tests below need no CSV input, so the instance
    is built with ``CSVLoader.__new__`` and carries no constructor-set state.
    Shared here instead of repeating ``CSVLoader.__new__(CSVLoader)`` in
    every parser test.
    """
    return CSVLoader.__new__(CSVLoader)


class TestInvoiceRow:
    """Tests for InvoiceRow dataclass."""

    def test_creation_minimal(self):
        """Should create InvoiceRow with only required field."""
        row = InvoiceRow(DocumentId="DOC001")
        assert row.DocumentId == "DOC001"
        assert row.InvoiceDate is None
        assert row.Amount is None

    def test_creation_full(self):
        """Should create InvoiceRow with all fields."""
        row = InvoiceRow(
            DocumentId="DOC001",
            InvoiceDate=date(2025, 1, 15),
            InvoiceNumber="INV-001",
            InvoiceDueDate=date(2025, 2, 15),
            OCR="1234567890",
            Message="Test message",
            Bankgiro="5393-9484",
            Plusgiro="123456-7",
            Amount=Decimal("1234.56"),
            split="train",
            customer_number="CUST001",
            supplier_name="Test Supplier",
            supplier_organisation_number="556123-4567",
            supplier_accounts="BG:5393-9484",
        )
        assert row.DocumentId == "DOC001"
        assert row.InvoiceDate == date(2025, 1, 15)
        assert row.Amount == Decimal("1234.56")

    def test_to_dict(self):
        """Should convert to dictionary correctly."""
        row = InvoiceRow(
            DocumentId="DOC001",
            InvoiceDate=date(2025, 1, 15),
            Amount=Decimal("100.50"),
        )
        d = row.to_dict()
        assert d["DocumentId"] == "DOC001"
        assert d["InvoiceDate"] == "2025-01-15"
        assert d["Amount"] == "100.50"

    def test_to_dict_none_values(self):
        """Should handle None values in to_dict."""
        row = InvoiceRow(DocumentId="DOC001")
        d = row.to_dict()
        assert d["DocumentId"] == "DOC001"
        assert d["InvoiceDate"] is None
        assert d["Amount"] is None

    def test_get_field_value_date(self):
        """Should get date field as ISO string."""
        row = InvoiceRow(
            DocumentId="DOC001",
            InvoiceDate=date(2025, 1, 15),
        )
        assert row.get_field_value("InvoiceDate") == "2025-01-15"

    def test_get_field_value_decimal(self):
        """Should get Decimal field as string."""
        row = InvoiceRow(
            DocumentId="DOC001",
            Amount=Decimal("1234.56"),
        )
        assert row.get_field_value("Amount") == "1234.56"

    def test_get_field_value_string(self):
        """Should get string field as-is."""
        row = InvoiceRow(
            DocumentId="DOC001",
            InvoiceNumber="INV-001",
        )
        assert row.get_field_value("InvoiceNumber") == "INV-001"

    def test_get_field_value_none(self):
        """Should return None for missing field."""
        row = InvoiceRow(DocumentId="DOC001")
        assert row.get_field_value("InvoiceNumber") is None

    def test_get_field_value_unknown_field(self):
        """Should return None for unknown field."""
        row = InvoiceRow(DocumentId="DOC001")
        assert row.get_field_value("UnknownField") is None


class TestCSVLoaderParseDate:
    """Tests for CSVLoader._parse_date method."""

    def test_parse_iso_format(self, bare_loader):
        """Should parse ISO date format."""
        assert bare_loader._parse_date("2025-01-15") == date(2025, 1, 15)

    def test_parse_iso_with_time(self, bare_loader):
        """Should parse ISO format with time."""
        assert bare_loader._parse_date("2025-01-15 12:30:45") == date(2025, 1, 15)

    def test_parse_iso_with_microseconds(self, bare_loader):
        """Should parse ISO format with microseconds."""
        assert bare_loader._parse_date("2025-01-15 12:30:45.123456") == date(2025, 1, 15)

    def test_parse_european_slash(self, bare_loader):
        """Should parse DD/MM/YYYY format."""
        assert bare_loader._parse_date("15/01/2025") == date(2025, 1, 15)

    def test_parse_european_dot(self, bare_loader):
        """Should parse DD.MM.YYYY format."""
        assert bare_loader._parse_date("15.01.2025") == date(2025, 1, 15)

    def test_parse_european_dash(self, bare_loader):
        """Should parse DD-MM-YYYY format."""
        assert bare_loader._parse_date("15-01-2025") == date(2025, 1, 15)

    def test_parse_compact(self, bare_loader):
        """Should parse YYYYMMDD format."""
        assert bare_loader._parse_date("20250115") == date(2025, 1, 15)

    def test_parse_empty(self, bare_loader):
        """Should return None for empty string."""
        assert bare_loader._parse_date("") is None
        assert bare_loader._parse_date(" ") is None

    def test_parse_none(self, bare_loader):
        """Should return None for None input."""
        assert bare_loader._parse_date(None) is None

    def test_parse_invalid(self, bare_loader):
        """Should return None for invalid date."""
        assert bare_loader._parse_date("not-a-date") is None


class TestCSVLoaderParseAmount:
    """Tests for CSVLoader._parse_amount method."""

    def test_parse_simple_integer(self, bare_loader):
        """Should parse simple integer."""
        assert bare_loader._parse_amount("100") == Decimal("100")

    def test_parse_decimal_dot(self, bare_loader):
        """Should parse decimal with dot."""
        assert bare_loader._parse_amount("100.50") == Decimal("100.50")

    def test_parse_decimal_comma(self, bare_loader):
        """Should parse European format with comma."""
        assert bare_loader._parse_amount("100,50") == Decimal("100.50")

    def test_parse_with_thousand_separator_space(self, bare_loader):
        """Should handle space as thousand separator."""
        assert bare_loader._parse_amount("1 234,56") == Decimal("1234.56")

    def test_parse_with_thousand_separator_comma(self, bare_loader):
        """Should handle comma as thousand separator when dot is decimal."""
        assert bare_loader._parse_amount("1,234.56") == Decimal("1234.56")

    def test_parse_with_currency_sek(self, bare_loader):
        """Should remove SEK suffix."""
        assert bare_loader._parse_amount("100 SEK") == Decimal("100")

    def test_parse_with_currency_kr(self, bare_loader):
        """Should remove kr suffix."""
        assert bare_loader._parse_amount("100 kr") == Decimal("100")

    def test_parse_with_colon_dash(self, bare_loader):
        """Should remove :- suffix."""
        assert bare_loader._parse_amount("100:-") == Decimal("100")

    def test_parse_empty(self, bare_loader):
        """Should return None for empty string."""
        assert bare_loader._parse_amount("") is None
        assert bare_loader._parse_amount(" ") is None

    def test_parse_none(self, bare_loader):
        """Should return None for None input."""
        assert bare_loader._parse_amount(None) is None

    def test_parse_invalid(self, bare_loader):
        """Should return None for invalid amount."""
        assert bare_loader._parse_amount("not-an-amount") is None


class TestCSVLoaderParseString:
    """Tests for CSVLoader._parse_string method."""

    def test_parse_normal_string(self, bare_loader):
        """Should return stripped string."""
        assert bare_loader._parse_string(" hello ") == "hello"

    def test_parse_empty_string(self, bare_loader):
        """Should return None for empty string."""
        assert bare_loader._parse_string("") is None
        assert bare_loader._parse_string(" ") is None

    def test_parse_none(self, bare_loader):
        """Should return None for None input."""
        assert bare_loader._parse_string(None) is None


class TestCSVLoaderWithFile:
    """Tests for CSVLoader with actual CSV files."""

    @pytest.fixture
    def sample_csv(self, tmp_path):
        """Create a sample CSV file for testing."""
        # The last row deliberately has an empty trailing Bankgiro column.
        csv_content = """DocumentId,InvoiceDate,InvoiceNumber,Amount,Bankgiro
DOC001,2025-01-15,INV-001,100.50,5393-9484
DOC002,2025-01-16,INV-002,200.00,1234-5678
DOC003,2025-01-17,INV-003,300.75,
"""
        csv_file = tmp_path / "test.csv"
        csv_file.write_text(csv_content, encoding="utf-8")
        return csv_file

    @pytest.fixture
    def sample_csv_with_bom(self, tmp_path):
        """Create a CSV file with BOM."""
        csv_content = """DocumentId,InvoiceDate,Amount
DOC001,2025-01-15,100.50
"""
        csv_file = tmp_path / "test_bom.csv"
        # "utf-8-sig" prepends a UTF-8 byte order mark when writing.
        csv_file.write_text(csv_content, encoding="utf-8-sig")
        return csv_file

    def test_load_all(self, sample_csv):
        """Should load all rows from CSV."""
        loader = CSVLoader(sample_csv)
        rows = loader.load_all()
        assert len(rows) == 3
        assert rows[0].DocumentId == "DOC001"
        assert rows[1].DocumentId == "DOC002"
        assert rows[2].DocumentId == "DOC003"

    def test_iter_rows(self, sample_csv):
        """Should iterate over rows."""
        loader = CSVLoader(sample_csv)
        rows = list(loader.iter_rows())
        assert len(rows) == 3

    def test_parse_fields_correctly(self, sample_csv):
        """Should parse all fields correctly."""
        loader = CSVLoader(sample_csv)
        rows = loader.load_all()
        row = rows[0]
        assert row.InvoiceDate == date(2025, 1, 15)
        assert row.InvoiceNumber == "INV-001"
        assert row.Amount == Decimal("100.50")
        assert row.Bankgiro == "5393-9484"

    def test_handles_empty_fields(self, sample_csv):
        """Should handle empty fields as None."""
        loader = CSVLoader(sample_csv)
        rows = loader.load_all()
        row = rows[2]  # Last row has empty Bankgiro
        assert row.Bankgiro is None

    def test_handles_bom(self, sample_csv_with_bom):
        """Should handle files with BOM correctly."""
        loader = CSVLoader(sample_csv_with_bom)
        rows = loader.load_all()
        assert len(rows) == 1
        assert rows[0].DocumentId == "DOC001"

    def test_get_row_by_id(self, sample_csv):
        """Should get specific row by DocumentId."""
        loader = CSVLoader(sample_csv)
        row = loader.get_row_by_id("DOC002")
        assert row is not None
        assert row.InvoiceNumber == "INV-002"

    def test_get_row_by_id_not_found(self, sample_csv):
        """Should return None for non-existent DocumentId."""
        loader = CSVLoader(sample_csv)
        row = loader.get_row_by_id("NONEXISTENT")
        assert row is None


class TestCSVLoaderMultipleFiles:
    """Tests for CSVLoader with multiple CSV files."""

    @pytest.fixture
    def multiple_csvs(self, tmp_path):
        """Create multiple CSV files for testing."""
        csv1 = tmp_path / "file1.csv"
        csv1.write_text("""DocumentId,InvoiceNumber
DOC001,INV-001
DOC002,INV-002
""", encoding="utf-8")

        csv2 = tmp_path / "file2.csv"
        csv2.write_text("""DocumentId,InvoiceNumber
DOC003,INV-003
DOC004,INV-004
""", encoding="utf-8")

        return [csv1, csv2]

    def test_load_from_list(self, multiple_csvs):
        """Should load from list of CSV paths."""
        loader = CSVLoader(multiple_csvs)
        rows = loader.load_all()
        assert len(rows) == 4
        doc_ids = [r.DocumentId for r in rows]
        assert "DOC001" in doc_ids
        assert "DOC004" in doc_ids

    def test_load_from_glob(self, multiple_csvs, tmp_path):
        """Should load from glob pattern."""
        loader = CSVLoader(tmp_path / "*.csv")
        rows = loader.load_all()
        assert len(rows) == 4

    def test_deduplicates_by_doc_id(self, tmp_path):
        """Should deduplicate rows by DocumentId across files."""
        csv1 = tmp_path / "file1.csv"
        csv1.write_text("""DocumentId,InvoiceNumber
DOC001,INV-001
""", encoding="utf-8")

        csv2 = tmp_path / "file2.csv"
        csv2.write_text("""DocumentId,InvoiceNumber
DOC001,INV-001-DUPLICATE
""", encoding="utf-8")

        loader = CSVLoader([csv1, csv2])
        rows = loader.load_all()
        assert len(rows) == 1
        assert rows[0].InvoiceNumber == "INV-001"  # First one wins


class TestCSVLoaderPDFPath:
    """Tests for CSVLoader.get_pdf_path method."""

    @pytest.fixture
    def setup_pdf_dir(self, tmp_path):
        """Create PDF directory with some files."""
        pdf_dir = tmp_path / "pdfs"
        pdf_dir.mkdir()

        # Create some dummy PDF files
        (pdf_dir / "DOC001.pdf").touch()
        (pdf_dir / "doc002.pdf").touch()
        (pdf_dir / "INVOICE_DOC003.pdf").touch()

        csv_file = tmp_path / "test.csv"
        csv_file.write_text("""DocumentId,InvoiceNumber
DOC001,INV-001
DOC002,INV-002
DOC003,INV-003
DOC004,INV-004
""", encoding="utf-8")

        return csv_file, pdf_dir

    def test_find_exact_match(self, setup_pdf_dir):
        """Should find PDF with exact name match."""
        csv_file, pdf_dir = setup_pdf_dir
        loader = CSVLoader(csv_file, pdf_dir)
        rows = loader.load_all()
        pdf_path = loader.get_pdf_path(rows[0])  # DOC001
        assert pdf_path is not None
        assert pdf_path.name == "DOC001.pdf"

    def test_find_lowercase_match(self, setup_pdf_dir):
        """Should find PDF with lowercase name."""
        csv_file, pdf_dir = setup_pdf_dir
        loader = CSVLoader(csv_file, pdf_dir)
        rows = loader.load_all()
        pdf_path = loader.get_pdf_path(rows[1])  # DOC002 -> doc002.pdf
        assert pdf_path is not None
        assert pdf_path.name == "doc002.pdf"

    def test_find_glob_match(self, setup_pdf_dir):
        """Should find PDF using glob pattern."""
        csv_file, pdf_dir = setup_pdf_dir
        loader = CSVLoader(csv_file, pdf_dir)
        rows = loader.load_all()
        pdf_path = loader.get_pdf_path(rows[2])  # DOC003 -> INVOICE_DOC003.pdf
        assert pdf_path is not None
        assert "DOC003" in pdf_path.name

    def test_not_found(self, setup_pdf_dir):
        """Should return None when PDF not found."""
        csv_file, pdf_dir = setup_pdf_dir
        loader = CSVLoader(csv_file, pdf_dir)
        rows = loader.load_all()
        pdf_path = loader.get_pdf_path(rows[3])  # DOC004 - no PDF
        assert pdf_path is None


class TestCSVLoaderValidate:
    """Tests for CSVLoader.validate method."""

    def test_validate_missing_pdf(self, tmp_path):
        """Should report missing PDF files."""
        csv_file = tmp_path / "test.csv"
        csv_file.write_text("""DocumentId,InvoiceNumber
DOC001,INV-001
""", encoding="utf-8")

        # tmp_path is used as the PDF dir; it contains no DOC001 PDF.
        loader = CSVLoader(csv_file, tmp_path)
        issues = loader.validate()
        assert len(issues) >= 1
        pdf_issues = [i for i in issues if i.get("field") == "PDF"]
        assert len(pdf_issues) == 1

    def test_validate_no_matchable_fields(self, tmp_path):
        """Should report rows with no matchable fields."""
        csv_file = tmp_path / "test.csv"
        csv_file.write_text("""DocumentId,Message
DOC001,Just a message
""", encoding="utf-8")

        # Create a PDF so we only get the matchable fields issue
        pdf_dir = tmp_path / "pdfs"
        pdf_dir.mkdir()
        (pdf_dir / "DOC001.pdf").touch()

        loader = CSVLoader(csv_file, pdf_dir)
        issues = loader.validate()
        field_issues = [i for i in issues if i.get("field") == "All"]
        assert len(field_issues) == 1


class TestCSVLoaderAlternateFieldNames:
    """Tests for alternate field name support."""

    def test_lowercase_field_names(self, tmp_path):
        """Should accept lowercase field names."""
        csv_file = tmp_path / "test.csv"
        csv_file.write_text("""document_id,invoice_date,invoice_number,amount
DOC001,2025-01-15,INV-001,100.50
""", encoding="utf-8")

        loader = CSVLoader(csv_file)
        rows = loader.load_all()
        assert len(rows) == 1
        assert rows[0].DocumentId == "DOC001"
        assert rows[0].InvoiceDate == date(2025, 1, 15)

    def test_alternate_amount_field(self, tmp_path):
        """Should accept invoice_data_amount as Amount field."""
        csv_file = tmp_path / "test.csv"
        csv_file.write_text("""DocumentId,invoice_data_amount
DOC001,100.50
""", encoding="utf-8")

        loader = CSVLoader(csv_file)
        rows = loader.load_all()
        assert rows[0].Amount == Decimal("100.50")


class TestLoadInvoiceCSV:
    """Tests for load_invoice_csv convenience function."""

    def test_load_single_file(self, tmp_path):
        """Should load from single CSV file."""
        csv_file = tmp_path / "test.csv"
        csv_file.write_text("""DocumentId,InvoiceNumber
DOC001,INV-001
""", encoding="utf-8")

        rows = load_invoice_csv(csv_file)
        assert len(rows) == 1
        assert rows[0].DocumentId == "DOC001"


if __name__ == "__main__":
    # Match the documented invocation (clear project-level addopts) and
    # propagate pytest's exit code; previously the return value was dropped,
    # so running this file directly always exited 0 even on test failures.
    raise SystemExit(pytest.main([__file__, "-v", "-o", "addopts="]))