refactor: split line_items_extractor into smaller modules with comprehensive tests

- Extract models.py (LineItem, LineItemsResult dataclasses)
- Extract html_table_parser.py (ColumnMapper, HtmlTableParser)
- Extract merged_cell_handler.py (MergedCellHandler for PP-StructureV3 merged cells)
- Reduce line_items_extractor.py from 971 to 396 lines
- Add constants for magic numbers (MIN_AMOUNT_THRESHOLD, ROW_GROUPING_THRESHOLD, etc.)
- Fix row grouping algorithm in text_line_items_extractor.py
- Demote INFO logs to DEBUG level in structure_detector.py
- Add 209 tests achieving 85%+ coverage on main modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Yaojia Wang
2026-02-03 23:02:00 +01:00
parent c2c8f2dd04
commit 8723ef4653
11 changed files with 2230 additions and 841 deletions

View File

@@ -272,12 +272,12 @@ class TestLineItemsExtractorFromPdf:
extractor = LineItemsExtractor()
# Create mock table detection result
# Create mock table detection result with proper thead/tbody structure
mock_table = MagicMock(spec=TableDetectionResult)
mock_table.html = """
<table>
<tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr>
<tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr>
<thead><tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr></thead>
<tbody><tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr></tbody>
</table>
"""
@@ -291,6 +291,78 @@ class TestLineItemsExtractorFromPdf:
assert len(result.items) >= 1
class TestPdfPathValidation:
    """Checks how ``_detect_tables_with_parsing`` reacts to different PDF paths."""

    def test_detect_tables_with_nonexistent_path(self):
        """A path that does not exist produces two empty lists."""
        from unittest.mock import MagicMock
        from backend.table.structure_detector import TableDetector

        extractor = LineItemsExtractor()
        detector_stub = MagicMock(spec=TableDetector)

        found_tables, found_parsing = extractor._detect_tables_with_parsing(
            detector_stub, "nonexistent.pdf"
        )

        assert found_tables == []
        assert found_parsing == []

    def test_detect_tables_with_directory_path(self, tmp_path):
        """Pointing at a directory instead of a file produces two empty lists."""
        from unittest.mock import MagicMock
        from backend.table.structure_detector import TableDetector

        extractor = LineItemsExtractor()
        detector_stub = MagicMock(spec=TableDetector)

        # pytest hands us tmp_path as an existing directory, never a file.
        directory = str(tmp_path)
        found_tables, found_parsing = extractor._detect_tables_with_parsing(
            detector_stub, directory
        )

        assert found_tables == []
        assert found_parsing == []

    def test_detect_tables_validates_file_exists(self, tmp_path):
        """An existing file gets past path validation and reaches the renderer.

        The renderer is patched to yield no pages, so the call is expected to
        come back with empty results after having invoked the renderer once.
        """
        from unittest.mock import MagicMock, patch
        from backend.table.structure_detector import TableDetector

        extractor = LineItemsExtractor()

        # A real file on disk whose content is deliberately not a valid PDF.
        pdf_file = tmp_path / "test.pdf"
        pdf_file.write_bytes(b"not a real pdf")

        # Patch the renderer so no genuine PDF processing happens; an empty
        # iterator stands in for "file exists but has no pages".
        with patch(
            "shared.pdf.renderer.render_pdf_to_images",
            return_value=iter([]),
        ) as render_stub:
            detector_stub = MagicMock(spec=TableDetector)
            detector_stub._ensure_initialized = MagicMock()
            detector_stub._pipeline = MagicMock()

            found_tables, found_parsing = extractor._detect_tables_with_parsing(
                detector_stub, str(pdf_file)
            )

            # Reaching the renderer means the path checks let the file through.
            render_stub.assert_called_once()

        assert found_tables == []
        assert found_parsing == []
class TestLineItemsResult:
"""Tests for LineItemsResult dataclass."""
@@ -462,3 +534,246 @@ class TestMergedCellExtraction:
assert result.items[0].is_deduction is False
assert result.items[1].amount == "-2000"
assert result.items[1].is_deduction is True
class TestTextFallbackExtraction:
    """Tests for text-based fallback extraction."""

    def test_text_fallback_can_be_disabled(self):
        """Passing ``enable_text_fallback=False`` switches the fallback off.

        Previously named ``test_text_fallback_disabled_by_default``, which
        contradicted both its own docstring and the sibling test asserting
        that the fallback is enabled by default.
        """
        extractor = LineItemsExtractor(enable_text_fallback=False)

        assert extractor.enable_text_fallback is False

    def test_text_fallback_enabled_by_default(self):
        """Test text fallback is enabled by default."""
        extractor = LineItemsExtractor()

        assert extractor.enable_text_fallback is True

    def test_try_text_fallback_with_valid_parsing_res(self):
        """Test text fallback with valid parsing results."""
        from unittest.mock import patch
        from backend.table.text_line_items_extractor import (
            TextLineItemsExtractor,
            TextLineItem,
            TextLineItemsResult,
        )

        extractor = LineItemsExtractor()

        # Mock parsing_res_list with text elements
        parsing_res = [
            {"label": "text", "bbox": [0, 100, 200, 120], "text": "Product A"},
            {"label": "text", "bbox": [250, 100, 350, 120], "text": "1 234,56"},
            {"label": "text", "bbox": [0, 150, 200, 170], "text": "Product B"},
            {"label": "text", "bbox": [250, 150, 350, 170], "text": "2 345,67"},
        ]

        # Canned result returned by the patched text extractor
        mock_text_result = TextLineItemsResult(
            items=[
                TextLineItem(row_index=0, description="Product A", amount="1 234,56"),
                TextLineItem(row_index=1, description="Product B", amount="2 345,67"),
            ],
            header_row=[],
        )

        with patch.object(
            TextLineItemsExtractor,
            'extract_from_parsing_res',
            return_value=mock_text_result,
        ):
            result = extractor._try_text_fallback(parsing_res)

        assert result is not None
        assert len(result.items) == 2
        assert result.items[0].description == "Product A"
        assert result.items[1].description == "Product B"

    def test_try_text_fallback_returns_none_on_failure(self):
        """Test text fallback returns None when extraction fails."""
        from unittest.mock import patch

        extractor = LineItemsExtractor()

        with patch(
            'backend.table.text_line_items_extractor.TextLineItemsExtractor.extract_from_parsing_res',
            return_value=None,
        ):
            result = extractor._try_text_fallback([])

        assert result is None

    def test_extract_from_pdf_uses_text_fallback(self):
        """Test extract_from_pdf uses text fallback when no tables found."""
        from unittest.mock import patch, MagicMock

        extractor = LineItemsExtractor(enable_text_fallback=True)

        # Table detection reports no tables but does return parsing results,
        # which is the situation in which the fallback is expected to run.
        with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
            mock_detect.return_value = ([], [{"label": "text", "text": "test"}])

            with patch.object(
                extractor,
                '_try_text_fallback',
                return_value=MagicMock(items=[MagicMock()]),
            ) as mock_fallback:
                extractor.extract_from_pdf("fake.pdf")

        # Text fallback should be called
        mock_fallback.assert_called_once()

    def test_extract_from_pdf_skips_fallback_when_disabled(self):
        """Test extract_from_pdf skips text fallback when disabled."""
        from unittest.mock import patch

        extractor = LineItemsExtractor(enable_text_fallback=False)

        with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
            mock_detect.return_value = ([], [{"label": "text", "text": "test"}])
            result = extractor.extract_from_pdf("fake.pdf")

        # Should return None, not use text fallback
        assert result is None
class TestVerticallyMergedCellExtraction:
    """Exercises the extractor's private helpers for vertically merged cells."""

    def test_detects_vertically_merged_cells(self):
        """One cell carrying several product numbers is reported as merged."""
        merged_rows = [["Produktnr 1457280 1457281 1060381 merged text here"]]

        detected = LineItemsExtractor()._has_vertically_merged_cells(merged_rows)

        assert detected is True

    def test_splits_vertically_merged_rows(self):
        """Splitting a merged row hands back a header list and a data list."""
        merged_rows = [
            ["Produktnr 1234567 1234568", "Antal 2ST 3ST"],
        ]

        header, data = LineItemsExtractor()._split_merged_rows(merged_rows)

        # Only the shape of the return value is pinned here: both parts are lists.
        assert isinstance(header, list)
        assert isinstance(data, list)
class TestDeductionDetection:
    """Checks the ``is_deduction`` flag on items extracted from small HTML tables."""

    def test_detects_deduction_by_keyword_avdrag(self):
        """A negative row whose description contains 'avdrag' is a deduction."""
        table_html = (
            "\n"
            "<html><body><table>\n"
            "<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>\n"
            "<tbody>\n"
            "<tr><td>Hyresavdrag januari</td><td>-500,00</td></tr>\n"
            "</tbody>\n"
            "</table></body></html>\n"
        )

        items = LineItemsExtractor().extract(table_html).items

        assert len(items) == 1
        assert items[0].is_deduction is True

    def test_detects_deduction_by_keyword_rabatt(self):
        """A negative row whose description starts with 'Rabatt' is a deduction."""
        table_html = (
            "\n"
            "<html><body><table>\n"
            "<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>\n"
            "<tbody>\n"
            "<tr><td>Rabatt 10%</td><td>-100,00</td></tr>\n"
            "</tbody>\n"
            "</table></body></html>\n"
        )

        items = LineItemsExtractor().extract(table_html).items

        assert len(items) == 1
        assert items[0].is_deduction is True

    def test_detects_deduction_by_negative_amount(self):
        """A row with a negative amount and a neutral description is a deduction."""
        table_html = (
            "\n"
            "<html><body><table>\n"
            "<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>\n"
            "<tbody>\n"
            "<tr><td>Some credit</td><td>-250,00</td></tr>\n"
            "</tbody>\n"
            "</table></body></html>\n"
        )

        items = LineItemsExtractor().extract(table_html).items

        assert len(items) == 1
        assert items[0].is_deduction is True

    def test_normal_item_not_deduction(self):
        """A positive amount with a neutral description is not a deduction."""
        table_html = (
            "\n"
            "<html><body><table>\n"
            "<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>\n"
            "<tbody>\n"
            "<tr><td>Normal product</td><td>500,00</td></tr>\n"
            "</tbody>\n"
            "</table></body></html>\n"
        )

        items = LineItemsExtractor().extract(table_html).items

        assert len(items) == 1
        assert items[0].is_deduction is False
class TestHeaderDetection:
    """Checks ``_detect_header_row`` for headers at the top, bottom, or absent."""

    def test_detect_header_at_bottom(self):
        """A header in the last row is found and flagged as being at the end."""
        table_rows = [
            ["100,00", "Product A", "1"],
            ["200,00", "Product B", "2"],
            ["Belopp", "Beskrivning", "Antal"],  # the header sits last
        ]

        index, header_cells, at_end = LineItemsExtractor()._detect_header_row(table_rows)

        assert index == 2
        assert at_end is True
        assert "Belopp" in header_cells

    def test_detect_header_at_top(self):
        """A header in the first row is found and not flagged as being at the end."""
        table_rows = [
            ["Belopp", "Beskrivning", "Antal"],  # the header sits first
            ["100,00", "Product A", "1"],
            ["200,00", "Product B", "2"],
        ]

        index, header_cells, at_end = LineItemsExtractor()._detect_header_row(table_rows)

        assert index == 0
        assert at_end is False
        assert "Belopp" in header_cells

    def test_no_header_detected(self):
        """Rows without any header give index -1, an empty header and no end flag."""
        table_rows = [
            ["100,00", "Product A", "1"],
            ["200,00", "Product B", "2"],
        ]

        index, header_cells, at_end = LineItemsExtractor()._detect_header_row(table_rows)

        assert index == -1
        assert header_cells == []
        assert at_end is False

View File

@@ -0,0 +1,448 @@
"""
Tests for Merged Cell Handler
Tests the detection and extraction of data from tables with merged cells,
a common issue with PP-StructureV3 OCR output.
"""
import pytest
from backend.table.merged_cell_handler import MergedCellHandler, MIN_AMOUNT_THRESHOLD
from backend.table.html_table_parser import ColumnMapper
@pytest.fixture
def handler():
    """Provide a MergedCellHandler built around a freshly constructed ColumnMapper."""
    column_mapper = ColumnMapper()
    return MergedCellHandler(column_mapper)
class TestHasVerticallyMergedCells:
    """Checks what ``has_vertically_merged_cells`` reports for various row contents."""

    def test_empty_rows_returns_false(self, handler):
        """No rows at all means nothing is merged."""
        detected = handler.has_vertically_merged_cells([])
        assert detected is False

    def test_short_cells_ignored(self, handler):
        """A row made only of short cells is not reported as merged."""
        short_row = ["Short cell", "Also short"]
        assert handler.has_vertically_merged_cells([short_row]) is False

    def test_detects_multiple_product_numbers(self, handler):
        """Three 7-digit product numbers in one cell are reported as merged."""
        cell = "Produktnr 1457280 1457281 1060381 and more text here"
        assert handler.has_vertically_merged_cells([[cell]]) is True

    def test_single_product_number_not_merged(self, handler):
        """A lone product number in a cell is not reported as merged."""
        cell = "Produktnr 1457280 and more text here for length"
        assert handler.has_vertically_merged_cells([[cell]]) is False

    def test_detects_multiple_prices(self, handler):
        """Three comma-decimal prices in one cell are reported as merged."""
        cell = "Pris 127,20 234,56 159,20 total amounts"
        assert handler.has_vertically_merged_cells([[cell]]) is True

    def test_two_prices_not_merged(self, handler):
        """Only two prices in one cell are not enough to be reported as merged."""
        cell = "Pris 127,20 234,56 total amount here"
        assert handler.has_vertically_merged_cells([[cell]]) is False

    def test_detects_multiple_quantities(self, handler):
        """Several quantity tokens such as '6ST' in one cell are reported as merged."""
        cell = "Antal 6ST 6ST 1ST more text here"
        assert handler.has_vertically_merged_cells([[cell]]) is True

    def test_single_quantity_not_merged(self, handler):
        """A lone quantity token in a cell is not reported as merged."""
        cell = "Antal 6ST and more text here for length"
        assert handler.has_vertically_merged_cells([[cell]]) is False

    def test_empty_cell_skipped(self, handler):
        """Empty-string and None cells do not break detection or trigger it."""
        mixed_row = ["", None, "Valid but short"]
        assert handler.has_vertically_merged_cells([mixed_row]) is False

    def test_multiple_rows_checked(self, handler):
        """Merged content in a later row is still found after an ordinary first row."""
        table_rows = [
            ["Normal row with nothing special"],
            ["Produktnr 1457280 1457281 1060381 merged content"],
        ]
        detected = handler.has_vertically_merged_cells(table_rows)
        assert detected is True
class TestSplitMergedRows:
    """Checks the ``(header, data)`` pair returned by ``split_merged_rows``."""

    def test_empty_rows_returns_empty(self, handler):
        """With no input rows both the header and the data come back empty."""
        header, data = handler.split_merged_rows([])

        assert header == []
        assert data == []

    def test_all_empty_rows_returns_original(self, handler):
        """Rows of empty strings come back untouched with an empty header."""
        blank_rows = [["", ""], ["", ""]]

        header, data = handler.split_merged_rows(blank_rows)

        assert header == []
        assert data == blank_rows

    def test_splits_by_product_numbers(self, handler):
        """Two product numbers per cell give a 3-column header and two data rows."""
        merged = [
            ["Produktnr 1234567 1234568", "Antal 2ST 3ST", "Pris 100,00 200,00"],
        ]

        header, data = handler.split_merged_rows(merged)

        assert len(header) == 3
        assert header[0] == "Produktnr"
        assert len(data) == 2

    def test_splits_by_quantities(self, handler):
        """Without product numbers, quantity tokens still yield at least one data row."""
        merged = [
            ["Description text", "Antal 5ST 10ST", "Belopp 500,00 1000,00"],
        ]

        _header, data = handler.split_merged_rows(merged)

        # Two quantities are present; the test only requires that data is non-empty.
        assert len(data) >= 1

    def test_single_row_not_split(self, handler):
        """A row describing a single item is returned as-is with an empty header."""
        single_item = [
            ["Produktnr 1234567", "Antal 2ST", "Pris 100,00"],
        ]

        header, data = handler.split_merged_rows(single_item)

        # Only one product number is present, so there is nothing to split.
        assert header == []
        assert data == single_item

    def test_handles_missing_columns(self, handler):
        """Rows with differing column counts still produce two lists."""
        ragged = [
            ["Produktnr 1234567 1234568", ""],
            ["Antal 2ST 3ST"],
        ]

        header, data = handler.split_merged_rows(ragged)

        # Only graceful handling is required: both parts must be lists.
        assert isinstance(header, list)
        assert isinstance(data, list)
class TestCountExpectedRows:
    """Checks the row count returned by the private ``_count_expected_rows`` helper."""

    def test_counts_product_numbers(self, handler):
        """Three product numbers in a column give a count of three."""
        expected = handler._count_expected_rows(
            ["Produktnr 1234567 1234568 1234569", "Other"]
        )
        assert expected == 3

    def test_counts_quantities(self, handler):
        """Four quantity tokens in a column give a count of four."""
        expected = handler._count_expected_rows(
            ["Nothing here", "Antal 5ST 10ST 15ST 20ST"]
        )
        assert expected == 4

    def test_returns_max_count(self, handler):
        """The largest per-column count is the one reported."""
        two_products = "Produktnr 1234567 1234568"
        three_quantities = "Antal 5ST 10ST 15ST"

        expected = handler._count_expected_rows([two_products, three_quantities])

        assert expected == 3

    def test_empty_columns_return_zero(self, handler):
        """Empty, None and pattern-free columns give a count of zero."""
        expected = handler._count_expected_rows(["", None, "Short"])
        assert expected == 0
class TestSplitCellContentForRows:
    """Checks the private ``_split_cell_content_for_rows`` helper with a row count of 2."""

    def test_splits_by_product_numbers(self, handler):
        """A cell with two product numbers splits into a label plus two values."""
        parts = handler._split_cell_content_for_rows("Produktnr 1234567 1234568", 2)

        assert len(parts) == 3  # label followed by the two values
        assert parts[0] == "Produktnr"
        assert "1234567" in parts[1]
        assert "1234568" in parts[2]

    def test_splits_by_quantities(self, handler):
        """A cell with two quantity tokens splits into a label plus two values."""
        parts = handler._split_cell_content_for_rows("Antal 5ST 10ST", 2)

        assert len(parts) == 3  # label followed by the two values
        assert parts[0] == "Antal"

    def test_splits_discount_totalsumma(self, handler):
        """A combined discount/total cell is labelled 'Totalsumma' and keeps both amounts."""
        parts = handler._split_cell_content_for_rows(
            "Rabatt i% Totalsumma 686,88 123,45", 2
        )

        assert parts[0] == "Totalsumma"
        assert "686,88" in parts[1]
        assert "123,45" in parts[2]

    def test_splits_by_prices(self, handler):
        """A cell with two prices splits into at least two parts."""
        parts = handler._split_cell_content_for_rows("Pris 127,20 234,56", 2)

        assert len(parts) >= 2

    def test_fallback_returns_original(self, handler):
        """A cell without any recognised pattern comes back as a one-element list."""
        plain_text = "No patterns here"

        parts = handler._split_cell_content_for_rows(plain_text, 2)

        assert parts == ["No patterns here"]

    def test_product_number_with_description(self, handler):
        """Product numbers followed by description text still split into three parts."""
        parts = handler._split_cell_content_for_rows(
            "Art 1234567 Widget A 1234568 Widget B", 2
        )

        assert len(parts) == 3
class TestSplitCellContent:
    """Checks the list of parts returned by the public ``split_cell_content`` method."""

    def test_splits_by_product_numbers(self, handler):
        """Three product numbers come back as separate entries after the label."""
        parts = handler.split_cell_content("Produktnr 1234567 1234568 1234569")

        assert parts[0] == "Produktnr"
        assert "1234567" in parts
        assert "1234568" in parts
        assert "1234569" in parts

    def test_splits_by_quantities(self, handler):
        """Three quantity tokens give the label first and at least three parts."""
        parts = handler.split_cell_content("Antal 6ST 6ST 1ST")

        assert parts[0] == "Antal"
        assert len(parts) >= 3

    def test_splits_discount_amount_interleaved(self, handler):
        """Interleaved discount/amount values keep the amounts under 'Totalsumma'."""
        parts = handler.split_cell_content(
            "Rabatt i% Totalsumma 10,0 686,88 10,0 123,45"
        )

        # The two larger values are the amounts expected in the output.
        assert parts[0] == "Totalsumma"
        assert "686,88" in parts
        assert "123,45" in parts

    def test_splits_by_prices(self, handler):
        """A cell with three prices yields the 'Pris' label as its first part."""
        parts = handler.split_cell_content("Pris 127,20 127,20 159,20")

        assert parts[0] == "Pris"

    def test_single_value_not_split(self, handler):
        """Plain text without patterns is returned whole."""
        parts = handler.split_cell_content("Single value")

        assert parts == ["Single value"]

    def test_single_product_not_split(self, handler):
        """A label with just one product number is returned whole."""
        parts = handler.split_cell_content("Produktnr 1234567")

        assert parts == ["Produktnr 1234567"]
class TestHasMergedHeader:
    """Checks what ``has_merged_header`` reports for different header rows."""

    def test_none_header_returns_false(self, handler):
        """A missing header (None) is not a merged header."""
        verdict = handler.has_merged_header(None)
        assert verdict is False

    def test_empty_header_returns_false(self, handler):
        """An empty header list is not a merged header."""
        verdict = handler.has_merged_header([])
        assert verdict is False

    def test_multiple_non_empty_cells_returns_false(self, handler):
        """A header spread over several filled cells is not merged."""
        regular_header = ["Beskrivning", "Antal", "Belopp"]
        assert handler.has_merged_header(regular_header) is False

    def test_single_cell_with_keywords_returns_true(self, handler):
        """One cell containing several header keywords counts as merged."""
        merged_header = ["Specifikation 0218103-1201 rum och kök Hyra Avdrag"]
        assert handler.has_merged_header(merged_header) is True

    def test_single_cell_one_keyword_returns_false(self, handler):
        """One cell containing a single header keyword does not count as merged."""
        lone_keyword = ["Beskrivning only"]
        assert handler.has_merged_header(lone_keyword) is False

    def test_ignores_empty_trailing_cells(self, handler):
        """Trailing empty cells do not stop a merged first cell from being detected."""
        padded_header = ["Specifikation Hyra Avdrag", "", "", ""]
        assert handler.has_merged_header(padded_header) is True
class TestExtractFromMergedCells:
    """Checks the items produced by ``extract_from_merged_cells``."""

    def test_extracts_single_amount(self, handler):
        """One positive amount gives one item with data taken from the merged header."""
        merged_header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        data_rows = [["", "", "", "8159"]]

        extracted = handler.extract_from_merged_cells(merged_header, data_rows)

        assert len(extracted) == 1
        only_item = extracted[0]
        assert only_item.amount == "8159"
        assert only_item.is_deduction is False
        assert only_item.article_number == "0218103-1201"
        assert only_item.description == "2 rum och kök"

    def test_extracts_deduction(self, handler):
        """A negative amount gives an item flagged as a deduction."""
        data_rows = [["", "", "", "-2000"]]

        extracted = handler.extract_from_merged_cells(["Specifikation"], data_rows)

        assert len(extracted) == 1
        only_item = extracted[0]
        assert only_item.amount == "-2000"
        assert only_item.is_deduction is True
        # The first item (row_index 0) takes its description from the header
        # rather than "Avdrag"; "Avdrag" is reserved for later deduction items.
        assert only_item.description is None

    def test_extracts_multiple_amounts_same_row(self, handler):
        """Two amounts sharing one cell give two items in order."""
        merged_header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        data_rows = [["", "", "", "8159 -2000"]]

        extracted = handler.extract_from_merged_cells(merged_header, data_rows)

        assert len(extracted) == 2
        assert extracted[0].amount == "8159"
        assert extracted[1].amount == "-2000"

    def test_extracts_amounts_from_multiple_rows(self, handler):
        """One amount per row over two rows gives two items."""
        data_rows = [
            ["", "", "", "8159"],
            ["", "", "", "-2000"],
        ]

        extracted = handler.extract_from_merged_cells(["Specifikation"], data_rows)

        assert len(extracted) == 2

    def test_skips_small_amounts(self, handler):
        """An amount of 50, below MIN_AMOUNT_THRESHOLD (100), gives no item."""
        data_rows = [["", "", "", "50"]]

        extracted = handler.extract_from_merged_cells(["Specifikation"], data_rows)

        assert len(extracted) == 0

    def test_skips_empty_rows(self, handler):
        """A row consisting only of empty cells gives no item."""
        data_rows = [["", "", "", ""]]

        extracted = handler.extract_from_merged_cells(["Specifikation"], data_rows)

        assert len(extracted) == 0

    def test_handles_swedish_format_with_spaces(self, handler):
        """A space-grouped amount such as '8 159' is extracted as '8159'."""
        data_rows = [["", "", "", "8 159"]]

        extracted = handler.extract_from_merged_cells(["Specifikation"], data_rows)

        assert len(extracted) == 1
        assert extracted[0].amount == "8159"

    def test_confidence_is_lower_for_merged(self, handler):
        """Items extracted from merged cells carry a confidence of 0.7."""
        data_rows = [["", "", "", "8159"]]

        extracted = handler.extract_from_merged_cells(["Specifikation"], data_rows)

        assert extracted[0].confidence == 0.7

    def test_empty_header_still_extracts(self, handler):
        """With an empty header the amount is still extracted, without header data."""
        data_rows = [["", "", "", "8159"]]

        extracted = handler.extract_from_merged_cells([], data_rows)

        assert len(extracted) == 1
        assert extracted[0].description is None
        assert extracted[0].article_number is None

    def test_row_index_increments(self, handler):
        """Each extracted item gets the next row_index, starting from zero."""
        # One amount per row, which sidesteps regex grouping issues within a cell.
        data_rows = [
            ["", "", "", "8159"],
            ["", "", "", "5000"],
            ["", "", "", "-2000"],
        ]

        extracted = handler.extract_from_merged_cells(["Specifikation"], data_rows)

        # Three rows in, three items out, numbered 0..2.
        assert len(extracted) == 3
        assert extracted[0].row_index == 0
        assert extracted[1].row_index == 1
        assert extracted[2].row_index == 2
class TestMinAmountThreshold:
    """Pins the MIN_AMOUNT_THRESHOLD constant and the behaviour around its boundary."""

    def test_threshold_value(self):
        """The constant is expected to be exactly 100."""
        assert MIN_AMOUNT_THRESHOLD == 100

    def test_amounts_at_threshold_included(self, handler):
        """An amount of exactly 100 is kept."""
        boundary_rows = [["", "", "", "100"]]

        extracted = handler.extract_from_merged_cells(["Specifikation"], boundary_rows)

        assert len(extracted) == 1
        assert extracted[0].amount == "100"

    def test_amounts_below_threshold_excluded(self, handler):
        """An amount of 99, just under the threshold, is dropped."""
        below_rows = [["", "", "", "99"]]

        extracted = handler.extract_from_merged_cells(["Specifikation"], below_rows)

        assert len(extracted) == 0

157
tests/table/test_models.py Normal file
View File

@@ -0,0 +1,157 @@
"""
Tests for Line Items Data Models
Tests for LineItem and LineItemsResult dataclasses.
"""
import pytest
from backend.table.models import LineItem, LineItemsResult
class TestLineItem:
    """Checks field defaults and overrides on the LineItem dataclass."""

    def test_default_values(self):
        """Only row_index is supplied; every other field must hold its default."""
        bare_item = LineItem(row_index=0)

        assert bare_item.row_index == 0

        # All optional text fields are expected to default to None.
        assert bare_item.description is None
        assert bare_item.quantity is None
        assert bare_item.unit is None
        assert bare_item.unit_price is None
        assert bare_item.amount is None
        assert bare_item.article_number is None
        assert bare_item.vat_rate is None

        assert bare_item.is_deduction is False
        assert bare_item.confidence == 0.9

    def test_custom_confidence(self):
        """A confidence passed to the constructor is stored unchanged."""
        assert LineItem(row_index=0, confidence=0.7).confidence == 0.7

    def test_is_deduction_true(self):
        """The is_deduction flag can be switched on through the constructor."""
        assert LineItem(row_index=0, is_deduction=True).is_deduction is True
class TestLineItemsResult:
    """Checks ``total_amount`` and the plain fields of the LineItemsResult dataclass."""

    @staticmethod
    def _result_with(items):
        """Build a LineItemsResult around *items* with an empty header and empty HTML."""
        return LineItemsResult(items=items, header_row=[], raw_html="")

    def test_total_amount_empty_items(self):
        """No items means there is no total."""
        assert self._result_with([]).total_amount is None

    def test_total_amount_single_item(self):
        """With one item the total equals that item's amount."""
        line_items = [LineItem(row_index=0, amount="100,00")]

        assert self._result_with(line_items).total_amount == "100,00"

    def test_total_amount_multiple_items(self):
        """Two amounts are added together."""
        line_items = [
            LineItem(row_index=0, amount="100,00"),
            LineItem(row_index=1, amount="200,50"),
        ]

        assert self._result_with(line_items).total_amount == "300,50"

    def test_total_amount_with_deduction(self):
        """A negative (deduction) amount lowers the total."""
        line_items = [
            LineItem(row_index=0, amount="1000,00"),
            LineItem(row_index=1, amount="-200,00", is_deduction=True),
        ]

        assert self._result_with(line_items).total_amount == "800,00"

    def test_total_amount_swedish_format_with_spaces(self):
        """Space-grouped amounts are summed and the total is space-grouped too."""
        line_items = [
            LineItem(row_index=0, amount="1 234,56"),
            LineItem(row_index=1, amount="2 000,00"),
        ]

        assert self._result_with(line_items).total_amount == "3 234,56"

    def test_total_amount_invalid_amount_skipped(self):
        """An unparseable amount is left out of the total."""
        line_items = [
            LineItem(row_index=0, amount="100,00"),
            LineItem(row_index=1, amount="invalid"),
            LineItem(row_index=2, amount="200,00"),
        ]

        # Only the two valid amounts contribute.
        assert self._result_with(line_items).total_amount == "300,00"

    def test_total_amount_none_amount_skipped(self):
        """An item whose amount is None is left out of the total."""
        line_items = [
            LineItem(row_index=0, amount="100,00"),
            LineItem(row_index=1, amount=None),
        ]

        assert self._result_with(line_items).total_amount == "100,00"

    def test_total_amount_all_invalid_returns_none(self):
        """When no amount can be parsed there is no total."""
        line_items = [
            LineItem(row_index=0, amount="invalid"),
            LineItem(row_index=1, amount="also invalid"),
        ]

        assert self._result_with(line_items).total_amount is None

    def test_total_amount_large_numbers(self):
        """A sum reaching one million is formatted with two group separators."""
        line_items = [
            LineItem(row_index=0, amount="123 456,78"),
            LineItem(row_index=1, amount="876 543,22"),
        ]

        assert self._result_with(line_items).total_amount == "1 000 000,00"

    def test_total_amount_decimal_precision(self):
        """Tiny amounts add up exactly, without rounding artefacts."""
        line_items = [
            LineItem(row_index=0, amount="0,01"),
            LineItem(row_index=1, amount="0,02"),
        ]

        assert self._result_with(line_items).total_amount == "0,03"

    def test_is_reversed_default_false(self):
        """is_reversed is False unless stated otherwise."""
        assert self._result_with([]).is_reversed is False

    def test_is_reversed_can_be_set(self):
        """is_reversed can be switched on through the constructor."""
        reversed_result = LineItemsResult(
            items=[], header_row=[], raw_html="", is_reversed=True
        )

        assert reversed_result.is_reversed is True

    def test_header_row_preserved(self):
        """The header row handed to the constructor is readable back unchanged."""
        header_cells = ["Beskrivning", "Antal", "Belopp"]

        stored = LineItemsResult(items=[], header_row=header_cells, raw_html="")

        assert stored.header_row == header_cells

    def test_raw_html_preserved(self):
        """The raw HTML handed to the constructor is readable back unchanged."""
        markup = "<table><tr><td>Test</td></tr></table>"

        stored = LineItemsResult(items=[], header_row=[], raw_html=markup)

        assert stored.raw_html == markup

View File

@@ -658,3 +658,245 @@ class TestPaddleX3xAPI:
assert len(results) == 1
assert results[0].cells == [] # Empty cells list
assert results[0].html == "<table></table>"
def test_parse_paddlex_result_with_dict_ocr_data(self):
    """A ``table_ocr_pred`` given as a dict of rec_texts/rec_scores yields cell texts."""
    table_entry = {
        "cell_box_list": [[0, 0, 50, 20], [50, 0, 100, 20]],
        "pred_html": "<table><tr><td>A</td><td>B</td></tr></table>",
        "table_ocr_pred": {
            "rec_texts": ["A", "B"],
            "rec_scores": [0.99, 0.98],
        },
    }
    page_result = {
        "table_res_list": [table_entry],
        "parsing_res_list": [
            {"label": "table", "bbox": [10, 20, 200, 300]},
        ],
    }

    pipeline = MagicMock()
    pipeline.predict.return_value = [page_result]

    blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
    detected = TableDetector(pipeline=pipeline).detect(blank_image)

    assert len(detected) == 1
    cells = detected[0].cells
    assert len(cells) == 2
    assert cells[0]["text"] == "A"
    assert cells[1]["text"] == "B"
def test_parse_paddlex_result_no_bbox_in_parsing_res(self):
    """Without a 'table' entry in parsing_res_list the bbox is all zeros."""
    table_entry = {
        "cell_box_list": [[0, 0, 50, 20]],
        "pred_html": "<table><tr><td>A</td></tr></table>",
        "table_ocr_pred": ["A"],
    }
    page_result = {
        "table_res_list": [table_entry],
        "parsing_res_list": [
            # The only layout entry is labelled "text", so no table bbox exists.
            {"label": "text", "bbox": [10, 20, 200, 300]},
        ],
    }

    pipeline = MagicMock()
    pipeline.predict.return_value = [page_result]

    blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
    detected = TableDetector(pipeline=pipeline).detect(blank_image)

    assert len(detected) == 1
    # With no matching bbox available, the zero box is what the test expects.
    assert detected[0].bbox == (0.0, 0.0, 0.0, 0.0)
class TestIteratorResults:
    """Covers pipelines whose ``predict`` returns something other than a plain list."""

    def test_handles_iterator_results(self):
        """A result yielded lazily by a generator is still turned into one detection."""

        def result_generator():
            # One table-labelled layout element wrapped in a single page result.
            table_element = MagicMock(
                label="table",
                bbox=[0, 0, 100, 100],
                html="<table></table>",
                score=0.9,
                cells=[],
            )
            page = MagicMock(spec=["layout_elements"])
            page.layout_elements = [table_element]
            yield page

        pipeline = MagicMock()
        # predict hands back the generator object itself, not a list.
        pipeline.predict.return_value = result_generator()

        blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
        detected = TableDetector(pipeline=pipeline).detect(blank_image)

        assert len(detected) == 1

    def test_handles_failed_iterator_conversion(self):
        """An iterable that raises while being iterated yields an empty result list."""

        class ExplodingIterable:
            """Has ``__iter__`` but raises as soon as iteration is attempted."""

            def __iter__(self):
                raise RuntimeError("Iterator failed")

        pipeline = MagicMock()
        pipeline.predict.return_value = ExplodingIterable()

        blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
        detected = TableDetector(pipeline=pipeline).detect(blank_image)

        # The error must be absorbed: an empty list comes back, nothing is raised.
        assert detected == []
class TestPathConversion:
    """Checks how ``detect`` forwards path-like inputs to the pipeline."""

    def test_converts_path_object_to_string(self):
        """A ``Path`` argument reaches ``predict`` as a plain string."""
        from pathlib import Path

        pipeline = MagicMock()
        pipeline.predict.return_value = []

        image_path = Path("/some/path/to/image.png")
        TableDetector(pipeline=pipeline).detect(image_path)

        # The pipeline must receive a str, never the Path object itself.
        # NOTE(review): the literal below assumes POSIX separators — confirm
        # whether this suite is ever expected to run on Windows.
        pipeline.predict.assert_called_with("/some/path/to/image.png")
class TestHtmlExtraction:
    """Checks where the detector finds table HTML on a layout element."""

    @staticmethod
    def _detect_with_element(element):
        """Run ``detect`` on a blank image with *element* as the only layout entry."""
        page_result = MagicMock(spec=["layout_elements"])
        page_result.layout_elements = [element]

        pipeline = MagicMock()
        pipeline.predict.return_value = [page_result]

        blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
        return TableDetector(pipeline=pipeline).detect(blank_image)

    def test_extracts_html_from_res_dict(self):
        """HTML stored under ``element.res["html"]`` ends up on the result."""
        element = MagicMock(
            label="table",
            bbox=[0, 0, 100, 100],
            res={"html": "<table><tr><td>From res</td></tr></table>"},
            score=0.9,
            cells=[],
        )
        # Deleting the direct attributes leaves ``res`` as the only HTML source.
        for direct_attr in ("html", "table_html"):
            delattr(element, direct_attr)

        detected = self._detect_with_element(element)

        assert len(detected) == 1
        assert detected[0].html == "<table><tr><td>From res</td></tr></table>"

    def test_returns_empty_html_when_not_found(self):
        """With no HTML-bearing attribute at all, the result HTML is empty."""
        element = MagicMock(
            label="table",
            bbox=[0, 0, 100, 100],
            score=0.9,
            cells=[],
        )
        # Strip every attribute the HTML could come from.
        for html_attr in ("html", "table_html", "res"):
            delattr(element, html_attr)

        detected = self._detect_with_element(element)

        assert len(detected) == 1
        assert detected[0].html == ""
class TestTableTypeDetection:
    """Checks the wired/wireless classification done by ``_get_table_type``."""

    def test_detects_borderless_table(self):
        """A "borderless_table" label is reported as "wireless"."""
        borderless = MagicMock(label="borderless_table")

        assert TableDetector()._get_table_type(borderless) == "wireless"

    def test_detects_wireless_table_label(self):
        """A "wireless_table" label is reported as "wireless"."""
        wireless = MagicMock(label="wireless_table")

        assert TableDetector()._get_table_type(wireless) == "wireless"

    def test_defaults_to_wired_table(self):
        """A plain "table" label is reported as "wired"."""
        plain = MagicMock(label="table")

        assert TableDetector()._get_table_type(plain) == "wired"

    def test_type_attribute_instead_of_label(self):
        """Without ``label``, the ``type`` attribute decides the table type."""
        typed_only = MagicMock(type="wireless")
        # Deleting ``label`` makes any access to it raise AttributeError.
        del typed_only.label

        assert TableDetector()._get_table_type(typed_only) == "wireless"
class TestPipelineRuntimeError:
    """Checks the error raised when detection runs without a pipeline."""

    def test_raises_runtime_error_when_pipeline_none(self):
        """``detect`` raises RuntimeError mentioning "not initialized"."""
        detector = TableDetector()
        # Mark as initialized so lazy init is bypassed, then remove the pipeline.
        detector._initialized = True
        detector._pipeline = None

        blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
        with pytest.raises(RuntimeError) as raised:
            detector.detect(blank_image)

        # Inspect the message after the block; inside it this would be unreachable.
        message = str(raised.value)
        assert "not initialized" in message.lower()

View File

@@ -142,6 +142,33 @@ class TestTextLineItemsExtractor:
rows = extractor._group_by_row(elements)
assert len(rows) == 2
def test_group_by_row_varying_heights_uses_average(self, extractor):
    """Elements of different heights on one visual line share a row.

    The intent is that a tall element next to a short one stays in the
    short element's row and is not pushed into, or merged with, the row
    that follows.
    """
    short = TextElement(text="Short", bbox=(0, 100, 100, 110))  # center_y = 105
    # Taller box on the same line: center_y = 115.
    tall = TextElement(text="Tall item", bbox=(150, 100, 250, 130))
    # Clearly lower on the page: center_y = 160.
    below = TextElement(text="Next row", bbox=(0, 150, 100, 170))

    grouped = extractor._group_by_row([short, tall, below])

    # Two rows expected: [Short, Tall item] followed by [Next row].
    assert [len(row) for row in grouped] == [2, 1]
def test_group_by_row_empty_input(self, extractor):
    """Grouping an empty element list yields an empty list of rows."""
    no_elements = []

    assert extractor._group_by_row(no_elements) == []
def test_looks_like_line_item_with_amount(self, extractor):
"""Test line item detection with amount."""
row = [
@@ -253,6 +280,67 @@ class TestTextLineItemsExtractor:
assert len(elements) == 4
class TestExceptionHandling:
    """Checks that ``_extract_text_elements`` skips unusable entries.

    Each scenario pairs one unusable entry (or several) with a single
    well-formed entry whose text is "Valid" and expects only the latter
    to be returned.
    """

    @staticmethod
    def _assert_only_valid_survives(parsing_res):
        """Extract from *parsing_res* and require exactly the "Valid" element."""
        extracted = TextLineItemsExtractor()._extract_text_elements(parsing_res)

        assert len(extracted) == 1
        assert extracted[0].text == "Valid"

    def test_extract_text_elements_handles_missing_bbox(self):
        """An entry with no ``bbox`` key is dropped."""
        self._assert_only_valid_survives(
            [
                {"label": "text", "text": "No bbox"},  # no bbox key at all
                {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
            ]
        )

    def test_extract_text_elements_handles_invalid_bbox(self):
        """An entry whose bbox has fewer than four values is dropped."""
        self._assert_only_valid_survives(
            [
                # bbox carries only two coordinates
                {"label": "text", "bbox": [0, 100], "text": "Invalid bbox"},
                {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
            ]
        )

    def test_extract_text_elements_handles_none_text(self):
        """An entry whose text is ``None`` is dropped."""
        self._assert_only_valid_survives(
            [
                {"label": "text", "bbox": [0, 100, 200, 120], "text": None},
                {"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
            ]
        )

    def test_extract_text_elements_handles_empty_string(self):
        """An entry whose text is the empty string is dropped."""
        self._assert_only_valid_survives(
            [
                {"label": "text", "bbox": [0, 100, 200, 120], "text": ""},
                {"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
            ]
        )

    def test_extract_text_elements_handles_malformed_element(self):
        """Entries that are not dicts at all are dropped."""
        self._assert_only_valid_survives(
            [
                "not a dict",  # a string where a dict is expected
                123,  # a number where a dict is expected
                {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
            ]
        )
class TestConvertTextLineItem:
"""Tests for convert_text_line_item function."""