refactor: split line_items_extractor into smaller modules with comprehensive tests
- Extract models.py (LineItem, LineItemsResult dataclasses)
- Extract html_table_parser.py (ColumnMapper, HtmlTableParser)
- Extract merged_cell_handler.py (MergedCellHandler for PP-StructureV3 merged cells)
- Reduce line_items_extractor.py from 971 to 396 lines
- Add constants for magic numbers (MIN_AMOUNT_THRESHOLD, ROW_GROUPING_THRESHOLD, etc.)
- Fix row grouping algorithm in text_line_items_extractor.py
- Demote INFO logs to DEBUG level in structure_detector.py
- Add 209 tests achieving 85%+ coverage on main modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -272,12 +272,12 @@ class TestLineItemsExtractorFromPdf:
|
||||
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Create mock table detection result
|
||||
# Create mock table detection result with proper thead/tbody structure
|
||||
mock_table = MagicMock(spec=TableDetectionResult)
|
||||
mock_table.html = """
|
||||
<table>
|
||||
<tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr>
|
||||
<tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr>
|
||||
<thead><tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr></thead>
|
||||
<tbody><tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr></tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
@@ -291,6 +291,78 @@ class TestLineItemsExtractorFromPdf:
|
||||
assert len(result.items) >= 1
|
||||
|
||||
|
||||
class TestPdfPathValidation:
    """Checks how ``_detect_tables_with_parsing`` reacts to the given PDF path."""

    @staticmethod
    def _make_detector():
        """Return a ``MagicMock`` restricted to the ``TableDetector`` interface."""
        from unittest.mock import MagicMock
        from backend.table.structure_detector import TableDetector

        return MagicMock(spec=TableDetector)

    def test_detect_tables_with_nonexistent_path(self):
        """A path that does not exist produces two empty lists."""
        extractor = LineItemsExtractor()

        # The detector is only a stand-in; the call is made with a missing path.
        detector = self._make_detector()
        tables, parsing_res = extractor._detect_tables_with_parsing(
            detector, "nonexistent.pdf"
        )

        assert tables == []
        assert parsing_res == []

    def test_detect_tables_with_directory_path(self, tmp_path):
        """A directory (rather than a file) produces two empty lists."""
        extractor = LineItemsExtractor()
        detector = self._make_detector()

        # pytest's tmp_path fixture is a directory, not a regular file.
        directory = str(tmp_path)
        tables, parsing_res = extractor._detect_tables_with_parsing(
            detector, directory
        )

        assert tables == []
        assert parsing_res == []

    def test_detect_tables_validates_file_exists(self, tmp_path):
        """An existing file gets past path validation and reaches rendering.

        The renderer is patched to yield no pages, so the call is expected to
        invoke it exactly once and still return two empty lists.
        """
        from unittest.mock import MagicMock, patch

        extractor = LineItemsExtractor()

        # A real file on disk; its content is irrelevant because rendering
        # is mocked out below.
        fake_pdf = tmp_path / "test.pdf"
        fake_pdf.write_bytes(b"not a real pdf")

        # Empty iterator: the file exists but "contains" no pages.
        with patch(
            "shared.pdf.renderer.render_pdf_to_images",
            return_value=iter([]),
        ) as mock_render:
            detector = self._make_detector()
            detector._ensure_initialized = MagicMock()
            detector._pipeline = MagicMock()

            tables, parsing_res = extractor._detect_tables_with_parsing(
                detector, str(fake_pdf)
            )

            # Reaching the renderer means the path checks were passed.
            mock_render.assert_called_once()
            assert tables == []
            assert parsing_res == []
|
||||
|
||||
|
||||
class TestLineItemsResult:
|
||||
"""Tests for LineItemsResult dataclass."""
|
||||
|
||||
@@ -462,3 +534,246 @@ class TestMergedCellExtraction:
|
||||
assert result.items[0].is_deduction is False
|
||||
assert result.items[1].amount == "-2000"
|
||||
assert result.items[1].is_deduction is True
|
||||
|
||||
|
||||
class TestTextFallbackExtraction:
    """Tests for text-based fallback extraction."""

    def test_text_fallback_can_be_disabled(self):
        """Passing ``enable_text_fallback=False`` switches the fallback off.

        Renamed from ``test_text_fallback_disabled_by_default``: the old name
        contradicted both this test's intent and the sibling test asserting
        that the fallback is enabled by default.
        """
        extractor = LineItemsExtractor(enable_text_fallback=False)
        assert extractor.enable_text_fallback is False

    def test_text_fallback_enabled_by_default(self):
        """Test text fallback is enabled by default."""
        extractor = LineItemsExtractor()
        assert extractor.enable_text_fallback is True

    def test_try_text_fallback_with_valid_parsing_res(self):
        """Test text fallback with valid parsing results."""
        from unittest.mock import patch
        from backend.table.text_line_items_extractor import (
            TextLineItemsExtractor,
            TextLineItem,
            TextLineItemsResult,
        )

        extractor = LineItemsExtractor()

        # Mock parsing_res_list with text elements
        parsing_res = [
            {"label": "text", "bbox": [0, 100, 200, 120], "text": "Product A"},
            {"label": "text", "bbox": [250, 100, 350, 120], "text": "1 234,56"},
            {"label": "text", "bbox": [0, 150, 200, 170], "text": "Product B"},
            {"label": "text", "bbox": [250, 150, 350, 170], "text": "2 345,67"},
        ]

        # Canned result handed back by the patched text extractor
        mock_text_result = TextLineItemsResult(
            items=[
                TextLineItem(row_index=0, description="Product A", amount="1 234,56"),
                TextLineItem(row_index=1, description="Product B", amount="2 345,67"),
            ],
            header_row=[],
        )

        with patch.object(TextLineItemsExtractor, 'extract_from_parsing_res', return_value=mock_text_result):
            result = extractor._try_text_fallback(parsing_res)

        assert result is not None
        assert len(result.items) == 2
        assert result.items[0].description == "Product A"
        assert result.items[1].description == "Product B"

    def test_try_text_fallback_returns_none_on_failure(self):
        """Test text fallback returns None when extraction fails."""
        from unittest.mock import patch

        extractor = LineItemsExtractor()

        with patch('backend.table.text_line_items_extractor.TextLineItemsExtractor.extract_from_parsing_res', return_value=None):
            result = extractor._try_text_fallback([])
            assert result is None

    def test_extract_from_pdf_uses_text_fallback(self):
        """Test extract_from_pdf uses text fallback when no tables found."""
        from unittest.mock import patch, MagicMock

        extractor = LineItemsExtractor(enable_text_fallback=True)

        # Table detection reports no tables but does return parsing output,
        # which is the situation in which the fallback should be attempted.
        with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
            mock_detect.return_value = ([], [{"label": "text", "text": "test"}])

            with patch.object(extractor, '_try_text_fallback', return_value=MagicMock(items=[MagicMock()])) as mock_fallback:
                extractor.extract_from_pdf("fake.pdf")

                # Text fallback should be called
                mock_fallback.assert_called_once()

    def test_extract_from_pdf_skips_fallback_when_disabled(self):
        """Test extract_from_pdf skips text fallback when disabled."""
        from unittest.mock import patch

        extractor = LineItemsExtractor(enable_text_fallback=False)

        with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
            mock_detect.return_value = ([], [{"label": "text", "text": "test"}])

            result = extractor.extract_from_pdf("fake.pdf")

            # Should return None, not use text fallback
            assert result is None
|
||||
|
||||
|
||||
class TestVerticallyMergedCellExtraction:
    """Exercises the extractor's private helpers for vertically merged cells."""

    def test_detects_vertically_merged_cells(self):
        """A cell holding several product numbers is flagged as merged."""
        # One cell that carries three product numbers at once.
        merged_table = [["Produktnr 1457280 1457281 1060381 merged text here"]]

        extractor = LineItemsExtractor()
        detected = extractor._has_vertically_merged_cells(merged_table)

        assert detected is True

    def test_splits_vertically_merged_rows(self):
        """Splitting merged rows hands back a header list and a data list."""
        merged_table = [
            ["Produktnr 1234567 1234568", "Antal 2ST 3ST"],
        ]

        extractor = LineItemsExtractor()
        header, data = extractor._split_merged_rows(merged_table)

        # Only the container types are pinned here, not their contents.
        assert isinstance(header, list)
        assert isinstance(data, list)
|
||||
|
||||
|
||||
class TestDeductionDetection:
    """Checks the ``is_deduction`` flag on items extracted from HTML tables."""

    @staticmethod
    def _extract(html):
        """Run a fresh ``LineItemsExtractor`` over *html* and return its result."""
        return LineItemsExtractor().extract(html)

    def test_detects_deduction_by_keyword_avdrag(self):
        """A description containing 'avdrag' is flagged as a deduction."""
        table_html = """
        <html><body><table>
        <thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
        <tbody>
        <tr><td>Hyresavdrag januari</td><td>-500,00</td></tr>
        </tbody>
        </table></body></html>
        """
        result = self._extract(table_html)

        assert len(result.items) == 1
        only_item = result.items[0]
        assert only_item.is_deduction is True

    def test_detects_deduction_by_keyword_rabatt(self):
        """A description containing 'rabatt' is flagged as a deduction."""
        table_html = """
        <html><body><table>
        <thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
        <tbody>
        <tr><td>Rabatt 10%</td><td>-100,00</td></tr>
        </tbody>
        </table></body></html>
        """
        result = self._extract(table_html)

        assert len(result.items) == 1
        only_item = result.items[0]
        assert only_item.is_deduction is True

    def test_detects_deduction_by_negative_amount(self):
        """A negative amount is flagged as a deduction."""
        table_html = """
        <html><body><table>
        <thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
        <tbody>
        <tr><td>Some credit</td><td>-250,00</td></tr>
        </tbody>
        </table></body></html>
        """
        result = self._extract(table_html)

        assert len(result.items) == 1
        only_item = result.items[0]
        assert only_item.is_deduction is True

    def test_normal_item_not_deduction(self):
        """An ordinary positive line is not flagged as a deduction."""
        table_html = """
        <html><body><table>
        <thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
        <tbody>
        <tr><td>Normal product</td><td>500,00</td></tr>
        </tbody>
        </table></body></html>
        """
        result = self._extract(table_html)

        assert len(result.items) == 1
        only_item = result.items[0]
        assert only_item.is_deduction is False
|
||||
|
||||
|
||||
class TestHeaderDetection:
    """Checks the (index, header, is_at_end) triple from ``_detect_header_row``."""

    def test_detect_header_at_bottom(self):
        """A header in the last row is found and reported as being at the end."""
        table = [
            ["100,00", "Product A", "1"],
            ["200,00", "Product B", "2"],
            ["Belopp", "Beskrivning", "Antal"],  # header is the final row
        ]

        detection = LineItemsExtractor()._detect_header_row(table)
        header_idx, header, is_at_end = detection

        assert header_idx == 2
        assert is_at_end is True
        assert "Belopp" in header

    def test_detect_header_at_top(self):
        """A header in the first row is found and not reported as at the end."""
        table = [
            ["Belopp", "Beskrivning", "Antal"],  # header is the first row
            ["100,00", "Product A", "1"],
            ["200,00", "Product B", "2"],
        ]

        detection = LineItemsExtractor()._detect_header_row(table)
        header_idx, header, is_at_end = detection

        assert header_idx == 0
        assert is_at_end is False
        assert "Belopp" in header

    def test_no_header_detected(self):
        """Without any header row the result is ``(-1, [], False)``."""
        table = [
            ["100,00", "Product A", "1"],
            ["200,00", "Product B", "2"],
        ]

        detection = LineItemsExtractor()._detect_header_row(table)
        header_idx, header, is_at_end = detection

        assert header_idx == -1
        assert header == []
        assert is_at_end is False
|
||||
|
||||
448
tests/table/test_merged_cell_handler.py
Normal file
448
tests/table/test_merged_cell_handler.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""
|
||||
Tests for Merged Cell Handler
|
||||
|
||||
Tests the detection and extraction of data from tables with merged cells,
|
||||
a common issue with PP-StructureV3 OCR output.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from backend.table.merged_cell_handler import MergedCellHandler, MIN_AMOUNT_THRESHOLD
|
||||
from backend.table.html_table_parser import ColumnMapper
|
||||
|
||||
|
||||
@pytest.fixture
def handler():
    """Provide a ``MergedCellHandler`` wired to a default-constructed ``ColumnMapper``."""
    column_mapper = ColumnMapper()
    return MergedCellHandler(column_mapper)
|
||||
|
||||
|
||||
class TestHasVerticallyMergedCells:
    """Checks ``has_vertically_merged_cells`` against merged and plain cells."""

    def test_empty_rows_returns_false(self, handler):
        """With no rows at all the answer is False."""
        no_rows = []
        assert handler.has_vertically_merged_cells(no_rows) is False

    def test_short_cells_ignored(self, handler):
        """Cells under 20 characters are not considered."""
        table = [["Short cell", "Also short"]]
        verdict = handler.has_vertically_merged_cells(table)
        assert verdict is False

    def test_detects_multiple_product_numbers(self, handler):
        """Several 7-digit product numbers in one cell count as merged."""
        table = [["Produktnr 1457280 1457281 1060381 and more text here"]]
        verdict = handler.has_vertically_merged_cells(table)
        assert verdict is True

    def test_single_product_number_not_merged(self, handler):
        """One product number alone is not enough to count as merged."""
        table = [["Produktnr 1457280 and more text here for length"]]
        verdict = handler.has_vertically_merged_cells(table)
        assert verdict is False

    def test_detects_multiple_prices(self, handler):
        """Three or more Swedish-format prices in one cell count as merged."""
        table = [["Pris 127,20 234,56 159,20 total amounts"]]
        verdict = handler.has_vertically_merged_cells(table)
        assert verdict is True

    def test_two_prices_not_merged(self, handler):
        """Two prices stay below the 3+ requirement and are not merged."""
        table = [["Pris 127,20 234,56 total amount here"]]
        verdict = handler.has_vertically_merged_cells(table)
        assert verdict is False

    def test_detects_multiple_quantities(self, handler):
        """Several quantity tokens in one cell count as merged."""
        table = [["Antal 6ST 6ST 1ST more text here"]]
        verdict = handler.has_vertically_merged_cells(table)
        assert verdict is True

    def test_single_quantity_not_merged(self, handler):
        """One quantity token alone is not enough to count as merged."""
        table = [["Antal 6ST and more text here for length"]]
        verdict = handler.has_vertically_merged_cells(table)
        assert verdict is False

    def test_empty_cell_skipped(self, handler):
        """Empty-string and ``None`` cells are passed over."""
        table = [["", None, "Valid but short"]]
        verdict = handler.has_vertically_merged_cells(table)
        assert verdict is False

    def test_multiple_rows_checked(self, handler):
        """Merged content in a later row is still found."""
        table = [
            ["Normal row with nothing special"],
            ["Produktnr 1457280 1457281 1060381 merged content"],
        ]
        verdict = handler.has_vertically_merged_cells(table)
        assert verdict is True
|
||||
|
||||
|
||||
class TestSplitMergedRows:
    """Checks the ``(header, data)`` pair produced by ``split_merged_rows``."""

    def test_empty_rows_returns_empty(self, handler):
        """No input rows gives an empty header and empty data."""
        split = handler.split_merged_rows([])
        header, data = split

        assert header == []
        assert data == []

    def test_all_empty_rows_returns_original(self, handler):
        """Rows made only of empty cells come back unchanged as data."""
        blank_table = [["", ""], ["", ""]]

        split = handler.split_merged_rows(blank_table)
        header, data = split

        assert header == []
        assert data == blank_table

    def test_splits_by_product_numbers(self, handler):
        """Two product numbers lead to a 3-column header and 2 data rows."""
        merged_table = [
            ["Produktnr 1234567 1234568", "Antal 2ST 3ST", "Pris 100,00 200,00"],
        ]

        split = handler.split_merged_rows(merged_table)
        header, data = split

        assert len(header) == 3
        assert header[0] == "Produktnr"
        assert len(data) == 2

    def test_splits_by_quantities(self, handler):
        """Quantity tokens alone still yield at least one data row."""
        merged_table = [
            ["Description text", "Antal 5ST 10ST", "Belopp 500,00 1000,00"],
        ]

        split = handler.split_merged_rows(merged_table)
        header, data = split

        # Two quantities are present; only a lower bound is asserted.
        assert len(data) >= 1

    def test_single_row_not_split(self, handler):
        """A row describing one item is returned as-is with no header."""
        single_item_table = [
            ["Produktnr 1234567", "Antal 2ST", "Pris 100,00"],
        ]

        split = handler.split_merged_rows(single_item_table)
        header, data = split

        # One product number only, so expected_rows <= 1 and nothing is split.
        assert header == []
        assert data == single_item_table

    def test_handles_missing_columns(self, handler):
        """Rows with differing column counts do not break the split."""
        ragged_table = [
            ["Produktnr 1234567 1234568", ""],
            ["Antal 2ST 3ST"],
        ]

        split = handler.split_merged_rows(ragged_table)
        header, data = split

        # Only the container types are pinned for this ragged input.
        assert isinstance(header, list)
        assert isinstance(data, list)
|
||||
|
||||
|
||||
class TestCountExpectedRows:
    """Checks the row count estimated by the ``_count_expected_rows`` helper."""

    def test_counts_product_numbers(self, handler):
        """Three product numbers in a column give a count of 3."""
        cells = ["Produktnr 1234567 1234568 1234569", "Other"]
        expected_rows = handler._count_expected_rows(cells)
        assert expected_rows == 3

    def test_counts_quantities(self, handler):
        """Four quantity tokens in a column give a count of 4."""
        cells = ["Nothing here", "Antal 5ST 10ST 15ST 20ST"]
        expected_rows = handler._count_expected_rows(cells)
        assert expected_rows == 4

    def test_returns_max_count(self, handler):
        """The largest per-column count is the one returned."""
        two_products = "Produktnr 1234567 1234568"
        three_quantities = "Antal 5ST 10ST 15ST"
        cells = [two_products, three_quantities]

        expected_rows = handler._count_expected_rows(cells)
        assert expected_rows == 3

    def test_empty_columns_return_zero(self, handler):
        """Empty, ``None`` and pattern-free columns give a count of 0."""
        cells = ["", None, "Short"]
        expected_rows = handler._count_expected_rows(cells)
        assert expected_rows == 0
|
||||
|
||||
|
||||
class TestSplitCellContentForRows:
    """Checks ``_split_cell_content_for_rows`` for a given expected row count."""

    def test_splits_by_product_numbers(self, handler):
        """Two product numbers split into a header plus two values."""
        parts = handler._split_cell_content_for_rows("Produktnr 1234567 1234568", 2)

        assert len(parts) == 3  # one header entry followed by two values
        assert parts[0] == "Produktnr"
        assert "1234567" in parts[1]
        assert "1234568" in parts[2]

    def test_splits_by_quantities(self, handler):
        """Two quantity tokens split into a header plus two values."""
        parts = handler._split_cell_content_for_rows("Antal 5ST 10ST", 2)

        assert len(parts) == 3  # one header entry followed by two values
        assert parts[0] == "Antal"

    def test_splits_discount_totalsumma(self, handler):
        """A combined discount/Totalsumma cell is split under 'Totalsumma'."""
        merged_cell = "Rabatt i% Totalsumma 686,88 123,45"
        parts = handler._split_cell_content_for_rows(merged_cell, 2)

        assert parts[0] == "Totalsumma"
        assert "686,88" in parts[1]
        assert "123,45" in parts[2]

    def test_splits_by_prices(self, handler):
        """A cell with two prices yields at least two parts."""
        parts = handler._split_cell_content_for_rows("Pris 127,20 234,56", 2)

        assert len(parts) >= 2

    def test_fallback_returns_original(self, handler):
        """A cell without any known pattern is returned as a one-element list."""
        plain_cell = "No patterns here"
        parts = handler._split_cell_content_for_rows(plain_cell, 2)

        assert parts == ["No patterns here"]

    def test_product_number_with_description(self, handler):
        """Product numbers followed by description text still give three parts."""
        merged_cell = "Art 1234567 Widget A 1234568 Widget B"
        parts = handler._split_cell_content_for_rows(merged_cell, 2)

        assert len(parts) == 3
|
||||
|
||||
|
||||
class TestSplitCellContent:
    """Checks the list of parts returned by ``split_cell_content``."""

    def test_splits_by_product_numbers(self, handler):
        """Each product number becomes its own entry after the header."""
        parts = handler.split_cell_content("Produktnr 1234567 1234568 1234569")

        assert parts[0] == "Produktnr"
        assert "1234567" in parts
        assert "1234568" in parts
        assert "1234569" in parts

    def test_splits_by_quantities(self, handler):
        """Several quantity tokens are split out after the header."""
        parts = handler.split_cell_content("Antal 6ST 6ST 1ST")

        assert parts[0] == "Antal"
        assert len(parts) >= 3

    def test_splits_discount_amount_interleaved(self, handler):
        """Interleaved discount/amount values keep only the amounts."""
        merged_cell = "Rabatt i% Totalsumma 10,0 686,88 10,0 123,45"
        parts = handler.split_cell_content(merged_cell)

        # The amounts (3+ digit numbers with decimals) are what must survive.
        assert parts[0] == "Totalsumma"
        assert "686,88" in parts
        assert "123,45" in parts

    def test_splits_by_prices(self, handler):
        """A price cell keeps its header as the first entry."""
        parts = handler.split_cell_content("Pris 127,20 127,20 159,20")

        assert parts[0] == "Pris"

    def test_single_value_not_split(self, handler):
        """Plain text comes back as a one-element list."""
        plain_cell = "Single value"
        parts = handler.split_cell_content(plain_cell)

        assert parts == ["Single value"]

    def test_single_product_not_split(self, handler):
        """A lone product number is left intact."""
        single_cell = "Produktnr 1234567"
        parts = handler.split_cell_content(single_cell)

        assert parts == ["Produktnr 1234567"]
|
||||
|
||||
|
||||
class TestHasMergedHeader:
    """Checks ``has_merged_header`` for ordinary and single-cell headers."""

    def test_none_header_returns_false(self, handler):
        """``None`` in place of a header gives False."""
        verdict = handler.has_merged_header(None)
        assert verdict is False

    def test_empty_header_returns_false(self, handler):
        """An empty header list gives False."""
        verdict = handler.has_merged_header([])
        assert verdict is False

    def test_multiple_non_empty_cells_returns_false(self, handler):
        """A header spread over several filled cells is not merged."""
        columns = ["Beskrivning", "Antal", "Belopp"]
        verdict = handler.has_merged_header(columns)
        assert verdict is False

    def test_single_cell_with_keywords_returns_true(self, handler):
        """One cell carrying several header keywords is a merged header."""
        columns = ["Specifikation 0218103-1201 rum och kök Hyra Avdrag"]
        verdict = handler.has_merged_header(columns)
        assert verdict is True

    def test_single_cell_one_keyword_returns_false(self, handler):
        """One cell with just a single keyword is not a merged header."""
        columns = ["Beskrivning only"]
        verdict = handler.has_merged_header(columns)
        assert verdict is False

    def test_ignores_empty_trailing_cells(self, handler):
        """Trailing empty cells do not stop a merged header being detected."""
        columns = ["Specifikation Hyra Avdrag", "", "", ""]
        verdict = handler.has_merged_header(columns)
        assert verdict is True
|
||||
|
||||
|
||||
class TestExtractFromMergedCells:
    """Checks the line items built by ``extract_from_merged_cells``."""

    def test_extracts_single_amount(self, handler):
        """One amount gives one item with article number and description."""
        merged_header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        body = [["", "", "", "8159"]]

        items = handler.extract_from_merged_cells(merged_header, body)

        assert len(items) == 1
        first = items[0]
        assert first.amount == "8159"
        assert first.is_deduction is False
        assert first.article_number == "0218103-1201"
        assert first.description == "2 rum och kök"

    def test_extracts_deduction(self, handler):
        """A negative amount gives one item flagged as a deduction."""
        merged_header = ["Specifikation"]
        body = [["", "", "", "-2000"]]

        items = handler.extract_from_merged_cells(merged_header, body)

        assert len(items) == 1
        first = items[0]
        assert first.amount == "-2000"
        assert first.is_deduction is True
        # The item at row_index 0 takes its description from the header
        # rather than "Avdrag"; "Avdrag" is reserved for later deductions.
        assert first.description is None

    def test_extracts_multiple_amounts_same_row(self, handler):
        """Two amounts sharing a cell give two items in order."""
        merged_header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        body = [["", "", "", "8159 -2000"]]

        items = handler.extract_from_merged_cells(merged_header, body)

        assert len(items) == 2
        assert items[0].amount == "8159"
        assert items[1].amount == "-2000"

    def test_extracts_amounts_from_multiple_rows(self, handler):
        """Amounts spread over two rows give two items."""
        merged_header = ["Specifikation"]
        body = [
            ["", "", "", "8159"],
            ["", "", "", "-2000"],
        ]

        items = handler.extract_from_merged_cells(merged_header, body)

        assert len(items) == 2

    def test_skips_small_amounts(self, handler):
        """An amount under the minimum threshold gives no item."""
        merged_header = ["Specifikation"]
        body = [["", "", "", "50"]]  # under MIN_AMOUNT_THRESHOLD (100)

        items = handler.extract_from_merged_cells(merged_header, body)

        assert len(items) == 0

    def test_skips_empty_rows(self, handler):
        """A row of empty cells gives no item."""
        merged_header = ["Specifikation"]
        body = [["", "", "", ""]]

        items = handler.extract_from_merged_cells(merged_header, body)

        assert len(items) == 0

    def test_handles_swedish_format_with_spaces(self, handler):
        """A space-grouped Swedish number is read as a single amount."""
        merged_header = ["Specifikation"]
        body = [["", "", "", "8 159"]]

        items = handler.extract_from_merged_cells(merged_header, body)

        assert len(items) == 1
        assert items[0].amount == "8159"

    def test_confidence_is_lower_for_merged(self, handler):
        """Items from merged cells carry a confidence of 0.7."""
        merged_header = ["Specifikation"]
        body = [["", "", "", "8159"]]

        items = handler.extract_from_merged_cells(merged_header, body)

        assert items[0].confidence == 0.7

    def test_empty_header_still_extracts(self, handler):
        """An empty header still gives an item, without description or article."""
        body = [["", "", "", "8159"]]

        items = handler.extract_from_merged_cells([], body)

        assert len(items) == 1
        first = items[0]
        assert first.description is None
        assert first.article_number is None

    def test_row_index_increments(self, handler):
        """Each extracted item receives the next ``row_index``."""
        merged_header = ["Specifikation"]
        # One amount per row sidesteps regex grouping within a single cell.
        body = [
            ["", "", "", "8159"],
            ["", "", "", "5000"],
            ["", "", "", "-2000"],
        ]

        items = handler.extract_from_merged_cells(merged_header, body)

        # Three rows in, three items out, numbered 0..2.
        assert len(items) == 3
        assert [item.row_index for item in items] == [0, 1, 2]
|
||||
|
||||
|
||||
class TestMinAmountThreshold:
    """Checks ``MIN_AMOUNT_THRESHOLD`` and the behaviour around its boundary."""

    def test_threshold_value(self):
        """The constant is pinned at 100."""
        assert MIN_AMOUNT_THRESHOLD == 100

    def test_amounts_at_threshold_included(self, handler):
        """An amount equal to the threshold is kept."""
        body = [["", "", "", "100"]]  # sits exactly on the threshold

        items = handler.extract_from_merged_cells(["Specifikation"], body)

        assert len(items) == 1
        assert items[0].amount == "100"

    def test_amounts_below_threshold_excluded(self, handler):
        """An amount just under the threshold is dropped."""
        body = [["", "", "", "99"]]  # one below the threshold

        items = handler.extract_from_merged_cells(["Specifikation"], body)

        assert len(items) == 0
|
||||
157
tests/table/test_models.py
Normal file
157
tests/table/test_models.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Tests for Line Items Data Models
|
||||
|
||||
Tests for LineItem and LineItemsResult dataclasses.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from backend.table.models import LineItem, LineItemsResult
|
||||
|
||||
|
||||
class TestLineItem:
    """Checks field defaults and overrides on the ``LineItem`` dataclass."""

    def test_default_values(self):
        """Only ``row_index`` is supplied; every other field keeps its default."""
        item = LineItem(row_index=0)

        assert item.row_index == 0
        # Optional fields are checked in the same order as before.
        optional_fields = (
            "description",
            "quantity",
            "unit",
            "unit_price",
            "amount",
            "article_number",
            "vat_rate",
        )
        for field_name in optional_fields:
            assert getattr(item, field_name) is None
        assert item.is_deduction is False
        assert item.confidence == 0.9

    def test_custom_confidence(self):
        """An explicit ``confidence`` replaces the default."""
        low_confidence_item = LineItem(row_index=0, confidence=0.7)
        assert low_confidence_item.confidence == 0.7

    def test_is_deduction_true(self):
        """``is_deduction=True`` is stored on the item."""
        deduction_item = LineItem(row_index=0, is_deduction=True)
        assert deduction_item.is_deduction is True
|
||||
|
||||
|
||||
class TestLineItemsResult:
    """Exercise LineItemsResult: the total_amount value and field passthrough."""

    @staticmethod
    def _result_for(items):
        """Wrap *items* in a LineItemsResult with an empty header and empty HTML."""
        return LineItemsResult(items=items, header_row=[], raw_html="")

    @classmethod
    def _total_of(cls, *amounts):
        """Return total_amount for line items carrying *amounts*, indexed in order."""
        rows = [
            LineItem(row_index=position, amount=value)
            for position, value in enumerate(amounts)
        ]
        return cls._result_for(rows).total_amount

    def test_total_amount_empty_items(self):
        """With no items at all, total_amount is expected to be None."""
        assert self._result_for([]).total_amount is None

    def test_total_amount_single_item(self):
        """A single amount is expected to come back as the total."""
        assert self._total_of("100,00") == "100,00"

    def test_total_amount_multiple_items(self):
        """Several amounts are expected to be summed."""
        assert self._total_of("100,00", "200,50") == "300,50"

    def test_total_amount_with_deduction(self):
        """A negative (deduction) amount is expected to reduce the total."""
        rows = [
            LineItem(row_index=0, amount="1000,00"),
            LineItem(row_index=1, amount="-200,00", is_deduction=True),
        ]

        assert self._result_for(rows).total_amount == "800,00"

    def test_total_amount_swedish_format_with_spaces(self):
        """Space-separated thousands are expected on both input and output."""
        assert self._total_of("1 234,56", "2 000,00") == "3 234,56"

    def test_total_amount_invalid_amount_skipped(self):
        """An unparseable amount is expected to be left out of the sum."""
        # Only the first and last entries are numeric.
        assert self._total_of("100,00", "invalid", "200,00") == "300,00"

    def test_total_amount_none_amount_skipped(self):
        """An item whose amount is None is expected to be left out of the sum."""
        assert self._total_of("100,00", None) == "100,00"

    def test_total_amount_all_invalid_returns_none(self):
        """When no amount can be parsed, total_amount is expected to be None."""
        assert self._total_of("invalid", "also invalid") is None

    def test_total_amount_large_numbers(self):
        """A seven-digit sum is expected to be grouped in thousands."""
        assert self._total_of("123 456,78", "876 543,22") == "1 000 000,00"

    def test_total_amount_decimal_precision(self):
        """Small fractional amounts are expected to add up exactly."""
        assert self._total_of("0,01", "0,02") == "0,03"

    def test_is_reversed_default_false(self):
        """is_reversed is expected to be False unless supplied."""
        assert self._result_for([]).is_reversed is False

    def test_is_reversed_can_be_set(self):
        """is_reversed is expected to accept True through the constructor."""
        flipped = LineItemsResult(
            items=[], header_row=[], raw_html="", is_reversed=True
        )

        assert flipped.is_reversed is True

    def test_header_row_preserved(self):
        """The header_row passed in is expected to be readable back unchanged."""
        columns = ["Beskrivning", "Antal", "Belopp"]
        outcome = LineItemsResult(items=[], header_row=columns, raw_html="")

        assert outcome.header_row == columns

    def test_raw_html_preserved(self):
        """The raw_html passed in is expected to be readable back unchanged."""
        markup = "<table><tr><td>Test</td></tr></table>"
        outcome = LineItemsResult(items=[], header_row=[], raw_html=markup)

        assert outcome.raw_html == markup
|
||||
@@ -658,3 +658,245 @@ class TestPaddleX3xAPI:
|
||||
assert len(results) == 1
|
||||
assert results[0].cells == [] # Empty cells list
|
||||
assert results[0].html == "<table></table>"
|
||||
|
||||
def test_parse_paddlex_result_with_dict_ocr_data(self):
    """Check detection when table_ocr_pred is a dict of rec_texts/rec_scores.

    The mocked pipeline returns one table with two cell boxes; the test
    expects one detected table whose two cells carry the texts "A" and "B".
    """
    table_entry = {
        "cell_box_list": [[0, 0, 50, 20], [50, 0, 100, 20]],
        "pred_html": "<table><tr><td>A</td><td>B</td></tr></table>",
        "table_ocr_pred": {
            "rec_texts": ["A", "B"],
            "rec_scores": [0.99, 0.98],
        },
    }
    page_result = {
        "table_res_list": [table_entry],
        "parsing_res_list": [
            {"label": "table", "bbox": [10, 20, 200, 300]},
        ],
    }

    pipeline = MagicMock()
    pipeline.predict.return_value = [page_result]
    blank_image = np.zeros((100, 100, 3), dtype=np.uint8)

    detected = TableDetector(pipeline=pipeline).detect(blank_image)

    assert len(detected) == 1
    cells = detected[0].cells
    assert len(cells) == 2
    assert cells[0]["text"] == "A"
    assert cells[1]["text"] == "B"
|
||||
|
||||
def test_parse_paddlex_result_no_bbox_in_parsing_res(self):
    """Check the bbox fallback when parsing_res_list has no table entry.

    The only parsing entry is labelled "text", so the test expects the
    detected table to carry the all-zero bbox.
    """
    table_entry = {
        "cell_box_list": [[0, 0, 50, 20]],
        "pred_html": "<table><tr><td>A</td></tr></table>",
        "table_ocr_pred": ["A"],
    }
    page_result = {
        "table_res_list": [table_entry],
        "parsing_res_list": [
            {"label": "text", "bbox": [10, 20, 200, 300]},  # Not a table
        ],
    }

    pipeline = MagicMock()
    pipeline.predict.return_value = [page_result]
    blank_image = np.zeros((100, 100, 3), dtype=np.uint8)

    detected = TableDetector(pipeline=pipeline).detect(blank_image)

    assert len(detected) == 1
    # No table-labelled entry supplies a bbox, so the default is expected.
    assert detected[0].bbox == (0.0, 0.0, 0.0, 0.0)
|
||||
|
||||
|
||||
class TestIteratorResults:
    """Exercise detect() when the pipeline returns a non-list iterable."""

    def test_handles_iterator_results(self):
        """A generator yielding one table result is expected to give one detection."""
        table_element = MagicMock(
            label="table",
            bbox=[0, 0, 100, 100],
            html="<table></table>",
            score=0.9,
            cells=[],
        )
        # spec restricts the mock to the single attribute set below.
        page = MagicMock(spec=["layout_elements"])
        page.layout_elements = [table_element]

        pipeline = MagicMock()
        # A generator expression, so predict() hands back a generator, not a list.
        pipeline.predict.return_value = (entry for entry in [page])

        blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
        detected = TableDetector(pipeline=pipeline).detect(blank_image)

        assert len(detected) == 1

    def test_handles_failed_iterator_conversion(self):
        """An iterable whose __iter__ raises is expected to give an empty list."""

        class FailingIterator:
            """Advertises __iter__ but raises as soon as iteration starts."""

            def __iter__(self):
                raise RuntimeError("Iterator failed")

        pipeline = MagicMock()
        pipeline.predict.return_value = FailingIterator()

        blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
        detected = TableDetector(pipeline=pipeline).detect(blank_image)

        # The error must be absorbed by detect(), not propagated to the caller.
        assert detected == []
|
||||
|
||||
|
||||
class TestPathConversion:
    """Tests for path handling."""

    def test_converts_path_object_to_string(self):
        """Test that a Path input reaches the pipeline as a plain string.

        The expected argument is derived from ``str(path)`` instead of a
        hard-coded POSIX literal: the string form of a ``Path`` uses the
        platform's native separators, so a literal with forward slashes
        would make this test fail on Windows.
        """
        from pathlib import Path

        mock_pipeline = MagicMock()
        mock_pipeline.predict.return_value = []

        detector = TableDetector(pipeline=mock_pipeline)
        path = Path("/some/path/to/image.png")

        detector.detect(path)

        # Should be called with string, not Path. A Path never compares equal
        # to a str, so this still fails if the Path object itself is forwarded.
        # NOTE(review): presumes the detector converts with str(path) - confirm.
        mock_pipeline.predict.assert_called_with(str(path))
|
||||
|
||||
|
||||
class TestHtmlExtraction:
    """Exercise where detect() finds a table element's HTML."""

    def test_extracts_html_from_res_dict(self):
        """With html and table_html missing, res["html"] is expected to be used."""
        table_element = MagicMock(
            label="table",
            bbox=[0, 0, 100, 100],
            res={"html": "<table><tr><td>From res</td></tr></table>"},
            score=0.9,
            cells=[],
        )
        # Deleting an attribute on a mock makes later access raise
        # AttributeError, which simulates an element without direct HTML.
        for absent in ("html", "table_html"):
            delattr(table_element, absent)

        page = MagicMock(spec=["layout_elements"])
        page.layout_elements = [table_element]
        pipeline = MagicMock()
        pipeline.predict.return_value = [page]

        blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
        detected = TableDetector(pipeline=pipeline).detect(blank_image)

        assert len(detected) == 1
        assert detected[0].html == "<table><tr><td>From res</td></tr></table>"

    def test_returns_empty_html_when_not_found(self):
        """With html, table_html and res all missing, html is expected to be ""."""
        table_element = MagicMock(
            label="table",
            bbox=[0, 0, 100, 100],
            score=0.9,
            cells=[],
        )
        # Strip every attribute the HTML could come from.
        for absent in ("html", "table_html", "res"):
            delattr(table_element, absent)

        page = MagicMock(spec=["layout_elements"])
        page.layout_elements = [table_element]
        pipeline = MagicMock()
        pipeline.predict.return_value = [page]

        blank_image = np.zeros((100, 100, 3), dtype=np.uint8)
        detected = TableDetector(pipeline=pipeline).detect(blank_image)

        assert len(detected) == 1
        assert detected[0].html == ""
|
||||
|
||||
|
||||
class TestTableTypeDetection:
    """Exercise TableDetector._get_table_type against mocked elements."""

    @staticmethod
    def _type_for_label(label):
        """Return _get_table_type() for a mock element carrying *label*."""
        labelled = MagicMock()
        labelled.label = label
        return TableDetector()._get_table_type(labelled)

    def test_detects_borderless_table(self):
        """A "borderless_table" label is expected to map to "wireless"."""
        assert self._type_for_label("borderless_table") == "wireless"

    def test_detects_wireless_table_label(self):
        """A "wireless_table" label is expected to map to "wireless"."""
        assert self._type_for_label("wireless_table") == "wireless"

    def test_defaults_to_wired_table(self):
        """A plain "table" label is expected to map to "wired"."""
        assert self._type_for_label("table") == "wired"

    def test_type_attribute_instead_of_label(self):
        """Without a label, the element's type attribute is expected to decide."""
        unlabelled = MagicMock()
        unlabelled.type = "wireless"
        # Deleting label makes access raise AttributeError on the mock.
        del unlabelled.label

        outcome = TableDetector()._get_table_type(unlabelled)

        assert outcome == "wireless"
|
||||
|
||||
|
||||
class TestPipelineRuntimeError:
    """Exercise detect() when no pipeline is available."""

    def test_raises_runtime_error_when_pipeline_none(self):
        """detect() is expected to raise RuntimeError mentioning "not initialized"."""
        detector = TableDetector()
        # Mark the detector initialized so the lazy init path is bypassed,
        # then leave it without a pipeline.
        detector._initialized = True
        detector._pipeline = None

        blank_image = np.zeros((100, 100, 3), dtype=np.uint8)

        with pytest.raises(RuntimeError) as caught:
            detector.detect(blank_image)

        message = str(caught.value)
        assert "not initialized" in message.lower()
|
||||
|
||||
@@ -142,6 +142,33 @@ class TestTextLineItemsExtractor:
|
||||
rows = extractor._group_by_row(elements)
|
||||
assert len(rows) == 2
|
||||
|
||||
def test_group_by_row_varying_heights_uses_average(self, extractor):
    """Check row grouping when elements on one row differ in height.

    Two elements share the same top edge but have different heights, so
    their vertical centers differ (105 vs 115). A third element sits
    clearly lower (center 160). The test expects the first two to be
    grouped into one row and the third to form its own row.
    """
    short_box = TextElement(text="Short", bbox=(0, 100, 100, 110))  # center_y = 105
    tall_box = TextElement(text="Tall item", bbox=(150, 100, 250, 130))  # center_y = 115
    lower_box = TextElement(text="Next row", bbox=(0, 150, 100, 170))  # center_y = 160

    grouped = extractor._group_by_row([short_box, tall_box, lower_box])

    # Expected layout: [Short, Tall item] then [Next row].
    assert len(grouped) == 2
    first_row, second_row = grouped[0], grouped[1]
    assert len(first_row) == 2
    assert len(second_row) == 1
|
||||
|
||||
def test_group_by_row_empty_input(self, extractor):
    """Grouping no elements is expected to produce an empty list of rows."""
    no_elements = []

    assert extractor._group_by_row(no_elements) == []
|
||||
|
||||
def test_looks_like_line_item_with_amount(self, extractor):
|
||||
"""Test line item detection with amount."""
|
||||
row = [
|
||||
@@ -253,6 +280,67 @@ class TestTextLineItemsExtractor:
|
||||
assert len(elements) == 4
|
||||
|
||||
|
||||
class TestExceptionHandling:
    """Exercise _extract_text_elements against malformed parsing results."""

    @staticmethod
    def _assert_only_valid_kept(parsing_res):
        """Extract from *parsing_res* and expect exactly one element, "Valid"."""
        extracted = TextLineItemsExtractor()._extract_text_elements(parsing_res)

        assert len(extracted) == 1
        assert extracted[0].text == "Valid"

    def test_extract_text_elements_handles_missing_bbox(self):
        """An entry with no bbox key is expected to be dropped."""
        self._assert_only_valid_kept(
            [
                {"label": "text", "text": "No bbox"},  # Missing bbox
                {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
            ]
        )

    def test_extract_text_elements_handles_invalid_bbox(self):
        """An entry whose bbox has fewer than 4 values is expected to be dropped."""
        self._assert_only_valid_kept(
            [
                {"label": "text", "bbox": [0, 100], "text": "Invalid bbox"},  # Only 2 values
                {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
            ]
        )

    def test_extract_text_elements_handles_none_text(self):
        """An entry whose text is None is expected to be dropped."""
        self._assert_only_valid_kept(
            [
                {"label": "text", "bbox": [0, 100, 200, 120], "text": None},
                {"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
            ]
        )

    def test_extract_text_elements_handles_empty_string(self):
        """An entry whose text is the empty string is expected to be dropped."""
        self._assert_only_valid_kept(
            [
                {"label": "text", "bbox": [0, 100, 200, 120], "text": ""},
                {"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
            ]
        )

    def test_extract_text_elements_handles_malformed_element(self):
        """Entries that are not dicts at all are expected to be dropped."""
        self._assert_only_valid_kept(
            [
                "not a dict",  # String instead of dict
                123,  # Number instead of dict
                {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
            ]
        )
|
||||
|
||||
|
||||
class TestConvertTextLineItem:
|
||||
"""Tests for convert_text_line_item function."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user