refactor: split line_items_extractor into smaller modules with comprehensive tests
- Extract models.py (LineItem, LineItemsResult dataclasses) - Extract html_table_parser.py (ColumnMapper, HtmlTableParser) - Extract merged_cell_handler.py (MergedCellHandler for PP-StructureV3 merged cells) - Reduce line_items_extractor.py from 971 to 396 lines - Add constants for magic numbers (MIN_AMOUNT_THRESHOLD, ROW_GROUPING_THRESHOLD, etc.) - Fix row grouping algorithm in text_line_items_extractor.py - Demote INFO logs to DEBUG level in structure_detector.py - Add 209 tests achieving 85%+ coverage on main modules Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
204
packages/backend/backend/table/html_table_parser.py
Normal file
204
packages/backend/backend/table/html_table_parser.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
"""
|
||||||
|
HTML Table Parser
|
||||||
|
|
||||||
|
Parses HTML tables into structured data and maps columns to field names.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from html.parser import HTMLParser
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Configuration constants
|
||||||
|
# Minimum pattern length to avoid false positives from short substrings
|
||||||
|
MIN_PATTERN_MATCH_LENGTH = 3
|
||||||
|
# Exact match bonus for column mapping priority
|
||||||
|
EXACT_MATCH_BONUS = 100
|
||||||
|
|
||||||
|
# Swedish column name mappings
|
||||||
|
# Extended to support multiple invoice types: product invoices, rental invoices, utility bills
|
||||||
|
COLUMN_MAPPINGS = {
|
||||||
|
"article_number": [
|
||||||
|
"art nummer",
|
||||||
|
"artikelnummer",
|
||||||
|
"artikel",
|
||||||
|
"artnr",
|
||||||
|
"art.nr",
|
||||||
|
"art nr",
|
||||||
|
"objektnummer", # Rental: property reference
|
||||||
|
"objekt",
|
||||||
|
],
|
||||||
|
"description": [
|
||||||
|
"beskrivning",
|
||||||
|
"produktbeskrivning",
|
||||||
|
"produkt",
|
||||||
|
"tjänst",
|
||||||
|
"text",
|
||||||
|
"benämning",
|
||||||
|
"vara/tjänst",
|
||||||
|
"vara",
|
||||||
|
# Rental invoice specific
|
||||||
|
"specifikation",
|
||||||
|
"spec",
|
||||||
|
"hyresperiod", # Rental period
|
||||||
|
"period",
|
||||||
|
"typ", # Type of charge
|
||||||
|
# Utility bills
|
||||||
|
"förbrukning", # Consumption
|
||||||
|
"avläsning", # Meter reading
|
||||||
|
],
|
||||||
|
"quantity": ["antal", "qty", "st", "pcs", "kvantitet", "m²", "kvm"],
|
||||||
|
"unit": ["enhet", "unit"],
|
||||||
|
"unit_price": ["á-pris", "a-pris", "pris", "styckpris", "enhetspris", "à pris"],
|
||||||
|
"amount": [
|
||||||
|
"belopp",
|
||||||
|
"summa",
|
||||||
|
"total",
|
||||||
|
"netto",
|
||||||
|
"rad summa",
|
||||||
|
# Rental specific
|
||||||
|
"hyra", # Rent
|
||||||
|
"avgift", # Fee
|
||||||
|
"kostnad", # Cost
|
||||||
|
"debitering", # Charge
|
||||||
|
"totalt", # Total
|
||||||
|
],
|
||||||
|
"vat_rate": ["moms", "moms%", "vat", "skatt", "moms %"],
|
||||||
|
# Additional field for rental: deductions/adjustments
|
||||||
|
"deduction": [
|
||||||
|
"avdrag", # Deduction
|
||||||
|
"rabatt", # Discount
|
||||||
|
"kredit", # Credit
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Keywords that indicate NOT a line items table
|
||||||
|
SUMMARY_KEYWORDS = [
|
||||||
|
"frakt",
|
||||||
|
"faktura.avg",
|
||||||
|
"fakturavg",
|
||||||
|
"exkl.moms",
|
||||||
|
"att betala",
|
||||||
|
"öresavr",
|
||||||
|
"bankgiro",
|
||||||
|
"plusgiro",
|
||||||
|
"ocr",
|
||||||
|
"forfallodatum",
|
||||||
|
"förfallodatum",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class _TableHTMLParser(HTMLParser):
    """Internal HTML parser that collects table cells row by row.

    Cells inside <thead> end up in ``header_row``; all other rows are
    appended to ``rows``. Cell text is whitespace-stripped.
    """

    def __init__(self):
        super().__init__()
        # Accumulated parse state while feeding HTML.
        self.rows: list[list[str]] = []
        self.current_row: list[str] = []
        self.current_cell: str = ""
        self.in_td = False
        self.in_thead = False
        self.header_row: list[str] = []

    def handle_starttag(self, tag, attrs):
        if tag == "thead":
            self.in_thead = True
            return
        if tag == "tr":
            # A new row starts: forget any cells gathered so far.
            self.current_row = []
            return
        if tag in ("td", "th"):
            self.current_cell = ""
            self.in_td = True

    def handle_endtag(self, tag):
        if tag == "thead":
            self.in_thead = False
            return
        if tag in ("td", "th"):
            # Finish the cell: strip surrounding whitespace and store it.
            self.current_row.append(self.current_cell.strip())
            self.in_td = False
            return
        if tag == "tr" and self.current_row:
            # Header rows live in <thead>; everything else is data.
            if self.in_thead:
                self.header_row = self.current_row
            else:
                self.rows.append(self.current_row)

    def handle_data(self, data):
        # Only text inside an open <td>/<th> belongs to a cell.
        if not self.in_td:
            return
        self.current_cell += data
||||||
|
|
||||||
|
class HTMLTableParser:
    """Parse HTML tables into structured data."""

    def parse(self, html: str) -> tuple[list[str], list[list[str]]]:
        """Parse an HTML table string into header and data rows.

        Args:
            html: HTML string containing table.

        Returns:
            Tuple of (header_row, data_rows). The header is empty when the
            table has no <thead> section.
        """
        worker = _TableHTMLParser()
        worker.feed(html)
        return worker.header_row, worker.rows
||||||
|
|
||||||
|
class ColumnMapper:
    """Map column headers to field names."""

    def __init__(self, mappings: dict[str, list[str]] | None = None):
        """
        Initialize column mapper.

        Args:
            mappings: Custom column mappings. Uses Swedish defaults if None.
        """
        self.mappings = mappings or COLUMN_MAPPINGS

    def map(self, headers: list[str]) -> dict[int, str]:
        """
        Map column indices to field names.

        Args:
            headers: List of column header strings.

        Returns:
            Dictionary mapping column index to field name.
        """
        mapping: dict[int, str] = {}

        # BUGFIX: normalize the patterns with the same rules applied to the
        # headers. Previously patterns containing "." or "-" (e.g. "art.nr",
        # "á-pris", "a-pris") could never match, because _normalize() strips
        # dots and turns dashes into spaces in the header before comparison.
        normalized_mappings = {
            field_name: [self._normalize(p) for p in patterns]
            for field_name, patterns in self.mappings.items()
        }

        for idx, header in enumerate(headers):
            normalized = self._normalize(header)

            if not normalized.strip():
                continue

            best_match = None
            best_match_len = 0

            for field_name, patterns in normalized_mappings.items():
                for pattern in patterns:
                    if pattern == normalized:
                        # Exact match gets highest priority
                        best_match = field_name
                        best_match_len = len(pattern) + EXACT_MATCH_BONUS
                        break
                    # Partial match requires minimum length to avoid false
                    # positives from short substrings; the longest matching
                    # pattern wins.
                    if (
                        pattern in normalized
                        and len(pattern) > best_match_len
                        and len(pattern) >= MIN_PATTERN_MATCH_LENGTH
                    ):
                        best_match = field_name
                        best_match_len = len(pattern)

                if best_match_len > EXACT_MATCH_BONUS:
                    # Found exact match, no need to check other fields
                    break

            if best_match:
                mapping[idx] = best_match

        return mapping

    def _normalize(self, header: str) -> str:
        """Normalize header text for matching: lowercase, drop dots, dashes to spaces."""
        return header.lower().strip().replace(".", "").replace("-", " ")
|
||||||
File diff suppressed because it is too large
Load Diff
423
packages/backend/backend/table/merged_cell_handler.py
Normal file
423
packages/backend/backend/table/merged_cell_handler.py
Normal file
@@ -0,0 +1,423 @@
|
|||||||
|
"""
|
||||||
|
Merged Cell Handler
|
||||||
|
|
||||||
|
Handles detection and extraction of data from tables with merged cells,
|
||||||
|
a common issue with PP-StructureV3 OCR output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from .models import LineItem
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .html_table_parser import ColumnMapper
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Minimum positive amount to consider as line item (filters noise like row indices)
|
||||||
|
MIN_AMOUNT_THRESHOLD = 100
|
||||||
|
|
||||||
|
|
||||||
|
class MergedCellHandler:
    """Handles tables with vertically merged cells from PP-StructureV3."""

    def __init__(self, mapper: "ColumnMapper"):
        """
        Initialize handler.

        Args:
            mapper: ColumnMapper instance for header keyword detection.
        """
        self.mapper = mapper

    def has_vertically_merged_cells(self, rows: list[list[str]]) -> bool:
        """
        Check if table rows contain vertically merged data in single cells.

        PP-StructureV3 sometimes merges multiple table rows into single cells, e.g.:
        ["Produktnr 1457280 1457280 1060381", "", "Antal 6ST 6ST 1ST", "Pris 127,20 127,20 159,20"]

        Detection: cells contain repeating patterns of numbers or keywords suggesting multiple lines.
        """
        if not rows:
            return False

        for row in rows:
            for cell in row:
                # Short cells cannot plausibly hold several merged values.
                if not cell or len(cell) < 20:
                    continue

                # Check for multiple product numbers (7+ digit patterns)
                # NOTE(review): assumes article numbers are exactly 7 digits —
                # confirm this holds across suppliers.
                product_nums = re.findall(r"\b\d{7}\b", cell)
                if len(product_nums) >= 2:
                    logger.debug(f"has_vertically_merged_cells: found {len(product_nums)} product numbers in cell")
                    return True

                # Check for multiple prices (Swedish format: 123,45 or 1 234,56)
                prices = re.findall(r"\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b", cell)
                if len(prices) >= 3:
                    logger.debug(f"has_vertically_merged_cells: found {len(prices)} prices in cell")
                    return True

                # Check for multiple quantity patterns (e.g., "6ST 6ST 1ST")
                quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", cell)
                if len(quantities) >= 2:
                    logger.debug(f"has_vertically_merged_cells: found {len(quantities)} quantities in cell")
                    return True

        return False

    def split_merged_rows(
        self, rows: list[list[str]]
    ) -> tuple[list[str], list[list[str]]]:
        """
        Split vertically merged cells back into separate rows.

        Handles complex cases where PP-StructureV3 merges content across
        multiple HTML rows. For example, 5 line items might be spread across
        3 HTML rows with content mixed together.

        Strategy:
        1. Merge all row content per column
        2. Detect how many actual data rows exist (by counting product numbers)
        3. Split each column's content into that many lines

        Returns header and data rows. On failure to detect more than one data
        row, returns ([], rows) unchanged so the caller can fall back.
        """
        if not rows:
            return [], []

        # Filter out completely empty rows
        non_empty_rows = [r for r in rows if any(cell.strip() for cell in r)]
        if not non_empty_rows:
            return [], rows

        # Determine column count (rows may be ragged, take the widest)
        col_count = max(len(r) for r in non_empty_rows)

        # Merge content from all rows for each column
        merged_columns = []
        for col_idx in range(col_count):
            col_content = []
            for row in non_empty_rows:
                if col_idx < len(row) and row[col_idx].strip():
                    col_content.append(row[col_idx].strip())
            merged_columns.append(" ".join(col_content))

        logger.debug(f"split_merged_rows: merged columns = {merged_columns}")

        # Count how many actual data rows we should have
        # Use the column with most product numbers as reference
        expected_rows = self._count_expected_rows(merged_columns)
        logger.debug(f"split_merged_rows: expecting {expected_rows} data rows")

        if expected_rows <= 1:
            # Not enough data for splitting
            return [], rows

        # Split each column based on expected row count
        split_columns = []
        for col_idx, col_text in enumerate(merged_columns):
            if not col_text.strip():
                split_columns.append([""] * (expected_rows + 1))  # +1 for header
                continue
            lines = self._split_cell_content_for_rows(col_text, expected_rows)
            split_columns.append(lines)

        # Ensure all columns have same number of lines (immutable approach)
        max_lines = max(len(col) for col in split_columns)
        split_columns = [
            col + [""] * (max_lines - len(col))
            for col in split_columns
        ]

        logger.debug(f"split_merged_rows: split into {max_lines} lines total")

        # First line is header, rest are data rows
        header = [col[0] for col in split_columns]
        data_rows = []
        for line_idx in range(1, max_lines):
            row = [col[line_idx] if line_idx < len(col) else "" for col in split_columns]
            # Drop all-empty reconstructed rows
            if any(cell.strip() for cell in row):
                data_rows.append(row)

        logger.debug(f"split_merged_rows: header={header}, data_rows count={len(data_rows)}")
        return header, data_rows

    def _count_expected_rows(self, merged_columns: list[str]) -> int:
        """
        Count how many data rows should exist based on content patterns.

        Returns the maximum count found from:
        - Product numbers (7 digits)
        - Quantity patterns (number + ST/PCS)

        NOTE(review): an earlier docstring also listed "amount patterns", but
        no amount counting is implemented below — confirm whether that was
        intended.
        """
        max_count = 0

        for col_text in merged_columns:
            if not col_text:
                continue

            # Count product numbers (most reliable indicator)
            product_nums = re.findall(r"\b\d{7}\b", col_text)
            max_count = max(max_count, len(product_nums))

            # Count quantities (e.g., "6ST 6ST 1ST 1ST 1ST")
            quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", col_text)
            max_count = max(max_count, len(quantities))

        return max_count

    def _split_cell_content_for_rows(self, cell: str, expected_rows: int) -> list[str]:
        """
        Split cell content knowing how many data rows we expect.

        This is smarter than split_cell_content because it knows the target count.
        Returns [header] + values when a strategy matches the expected count,
        otherwise the whole cell as a single-element list.
        """
        cell = cell.strip()

        # Try product number split first
        product_pattern = re.compile(r"(\b\d{7}\b)")
        products = product_pattern.findall(cell)
        if len(products) == expected_rows:
            parts = product_pattern.split(cell)
            # parts alternates: [prefix, num, text, num, text, ...]
            header = parts[0].strip() if parts else ""
            # Include description text after each product number
            values = []
            for i in range(1, len(parts), 2):  # Odd indices are product numbers
                if i < len(parts):
                    prod_num = parts[i].strip()
                    # Check if there's description text after
                    desc = parts[i + 1].strip() if i + 1 < len(parts) else ""
                    # If description looks like text (not another pattern), include it
                    if desc and not re.match(r"^\d{7}$", desc):
                        # Truncate at next product number pattern if any
                        desc_clean = re.split(r"\d{7}", desc)[0].strip()
                        if desc_clean:
                            values.append(f"{prod_num} {desc_clean}")
                        else:
                            values.append(prod_num)
                    else:
                        values.append(prod_num)
            if len(values) == expected_rows:
                return [header] + values

        # Try quantity split
        qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
        quantities = qty_pattern.findall(cell)
        if len(quantities) == expected_rows:
            parts = qty_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
            if len(values) == expected_rows:
                return [header] + values

        # Try amount split for discount+totalsumma columns
        cell_lower = cell.lower()
        has_discount = any(kw in cell_lower for kw in ["rabatt", "discount"])
        has_total = any(kw in cell_lower for kw in ["totalsumma", "total", "summa", "belopp"])

        if has_discount and has_total:
            # Extract only amounts (3+ digit numbers), skip discount percentages
            amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
            amounts = amount_pattern.findall(cell)
            if len(amounts) >= expected_rows:
                # Take the FIRST expected_rows amounts.
                # NOTE(review): an earlier comment said "last", but the slice
                # below takes the first ones — confirm which is intended.
                return ["Totalsumma"] + amounts[:expected_rows]

        # Try price split
        price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
        prices = price_pattern.findall(cell)
        if len(prices) >= expected_rows:
            parts = price_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
            if len(values) >= expected_rows:
                return [header] + values[:expected_rows]

        # Fall back to original single-value behavior
        return [cell]

    def split_cell_content(self, cell: str) -> list[str]:
        """
        Split a cell containing merged multi-line content.

        Strategies (tried in order; first one with >= 2 hits wins):
        1. Look for product number patterns (7 digits)
        2. Look for quantity patterns (number + ST/PCS)
        3. Handle interleaved discount+amount patterns
        4. Look for price patterns (with decimal)

        Returns [header] + values, or [cell] when no pattern is detected.
        """
        cell = cell.strip()

        # Strategy 1: Split by product numbers (common pattern: "Produktnr 1234567 1234568")
        product_pattern = re.compile(r"(\b\d{7}\b)")
        products = product_pattern.findall(cell)
        if len(products) >= 2:
            # Extract header (text before first product number) and values
            parts = product_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            # NOTE(review): unlike the other strategies, these values are not
            # stripped — confirm whether trailing whitespace matters downstream.
            values = [p for p in parts[1:] if p.strip() and re.match(r"\d{7}", p)]
            return [header] + values

        # Strategy 2: Split by quantities (e.g., "Antal 6ST 6ST 1ST")
        qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
        quantities = qty_pattern.findall(cell)
        if len(quantities) >= 2:
            parts = qty_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
            return [header] + values

        # Strategy 3: Handle interleaved discount+amount (e.g., "Rabatt i% Totalsumma 10,0 686,88 10,0 686,88")
        # Check if header contains two keywords indicating merged columns
        cell_lower = cell.lower()
        has_discount_header = any(kw in cell_lower for kw in ["rabatt", "discount"])
        has_amount_header = any(kw in cell_lower for kw in ["totalsumma", "summa", "belopp", "total"])

        if has_discount_header and has_amount_header:
            # Extract all numbers and pair them (discount, amount, discount, amount, ...)
            # Pattern for amounts: 3+ digit numbers with decimals (e.g., 686,88)
            amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
            amounts = amount_pattern.findall(cell)

            if len(amounts) >= 2:
                # Return header as "Totalsumma" (amount header) so it maps to amount field, not deduction
                # This avoids the "Rabatt" keyword causing is_deduction=True
                header = "Totalsumma"
                return [header] + amounts

        # Strategy 4: Split by prices (e.g., "Pris 127,20 127,20 159,20")
        price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
        prices = price_pattern.findall(cell)
        if len(prices) >= 2:
            parts = price_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
            return [header] + values

        # No pattern detected, return as single value
        return [cell]

    def has_merged_header(self, header: list[str] | None) -> bool:
        """
        Check if header appears to be a merged cell containing multiple column names.

        This happens when OCR merges table headers into a single cell, e.g.:
        "Specifikation 0218103-1201 2 rum och kök Hyra Avdrag" instead of separate columns.

        Also handles cases where PP-StructureV3 produces headers like:
        ["Specifikation ... Hyra Avdrag", "", "", ""] with empty trailing cells.
        """
        if header is None or not header:
            return False

        # Filter out empty cells to find the actual content
        non_empty_cells = [h for h in header if h.strip()]

        # Check if we have a single non-empty cell that contains multiple keywords
        if len(non_empty_cells) == 1:
            header_text = non_empty_cells[0].lower()
            # Count how many column keywords are in this single cell
            keyword_count = 0
            for patterns in self.mapper.mappings.values():
                for pattern in patterns:
                    if pattern in header_text:
                        keyword_count += 1
                        break  # Only count once per field type

            logger.debug(f"has_merged_header: header_text='{header_text}', keyword_count={keyword_count}")
            # Two or more distinct field keywords in one cell => merged header
            return keyword_count >= 2

        return False

    def extract_from_merged_cells(
        self, header: list[str], rows: list[list[str]]
    ) -> list[LineItem]:
        """
        Extract line items from tables with merged cells.

        For poorly OCR'd tables like:
        Header: ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        Row 1: ["", "", "", "8159"]   <- amount row
        Row 2: ["", "", "", "-2 000"] <- deduction row (separate line item)

        Or:
        Row: ["", "", "", "8159 -2 000"] <- both in same row -> 2 line items

        Each amount becomes its own line item. Negative amounts are marked as is_deduction=True.
        """
        items = []

        # Amount pattern for Swedish format - match numbers like "8159" or "8 159" or "-2000" or "-2 000"
        # NOTE(review): the greedy [\d\s]* can swallow two adjacent
        # space-separated positive numbers into one token — confirm that
        # inputs here always separate amounts with a sign or non-digit text.
        amount_pattern = re.compile(
            r"(-?\d[\d\s]*(?:[,\.]\d+)?)"
        )

        # Try to parse header cell for description info
        header_text = " ".join(h for h in header if h.strip()) if header else ""
        logger.debug(f"extract_from_merged_cells: header_text='{header_text}'")
        logger.debug(f"extract_from_merged_cells: rows={rows}")

        # Extract description from header
        description = None
        article_number = None

        # Look for object number pattern (e.g., "0218103-1201")
        obj_match = re.search(r"(\d{7}-\d{4})", header_text)
        if obj_match:
            article_number = obj_match.group(1)

        # Look for description after object number
        desc_match = re.search(r"\d{7}-\d{4}\s+(.+?)(?:\s+(?:Hyra|Avdrag|Belopp))", header_text, re.IGNORECASE)
        if desc_match:
            description = desc_match.group(1).strip()

        row_index = 0
        for row in rows:
            # Combine all non-empty cells in the row
            row_text = " ".join(cell.strip() for cell in row if cell.strip())
            logger.debug(f"extract_from_merged_cells: row text='{row_text}'")

            if not row_text:
                continue

            # Find all amounts in the row
            amounts = amount_pattern.findall(row_text)
            logger.debug(f"extract_from_merged_cells: amounts={amounts}")

            for amt_str in amounts:
                # Clean the amount string
                cleaned = amt_str.replace(" ", "").strip()
                if not cleaned or cleaned == "-":
                    continue

                is_deduction = cleaned.startswith("-")

                # Skip small positive numbers that are likely not amounts
                # (e.g., row indices, small percentages)
                if not is_deduction:
                    try:
                        val = float(cleaned.replace(",", "."))
                        if val < MIN_AMOUNT_THRESHOLD:
                            continue
                    except ValueError:
                        continue

                # Create a line item for each amount. Header-derived
                # description/article only apply to the first extracted item.
                item = LineItem(
                    row_index=row_index,
                    description=description if row_index == 0 else "Avdrag" if is_deduction else None,
                    article_number=article_number if row_index == 0 else None,
                    amount=cleaned,
                    is_deduction=is_deduction,
                    confidence=0.7,
                )
                items.append(item)
                row_index += 1
                logger.debug(f"extract_from_merged_cells: created item amount={cleaned}, is_deduction={is_deduction}")

        return items
|
||||||
61
packages/backend/backend/table/models.py
Normal file
61
packages/backend/backend/table/models.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
"""
|
||||||
|
Line Items Data Models
|
||||||
|
|
||||||
|
Dataclasses for line item extraction results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from decimal import Decimal, InvalidOperation
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LineItem:
    """Single line item from invoice.

    All monetary/quantity fields are kept as raw OCR strings (Swedish
    number format, e.g. "1 234,56"); parsing happens at the point of use
    (see LineItemsResult.total_amount).
    """

    row_index: int  # 0-based position of the item within the source table
    description: str | None = None
    quantity: str | None = None
    unit: str | None = None
    unit_price: str | None = None
    amount: str | None = None  # line total, raw string
    article_number: str | None = None
    vat_rate: str | None = None
    is_deduction: bool = False  # True if this row is a deduction/discount
    confidence: float = 0.9  # extraction confidence heuristic (0..1)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LineItemsResult:
|
||||||
|
"""Result of line items extraction."""
|
||||||
|
|
||||||
|
items: list[LineItem]
|
||||||
|
header_row: list[str]
|
||||||
|
raw_html: str
|
||||||
|
is_reversed: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total_amount(self) -> str | None:
|
||||||
|
"""Calculate total amount from line items (deduction rows have negative amounts)."""
|
||||||
|
if not self.items:
|
||||||
|
return None
|
||||||
|
|
||||||
|
total = Decimal("0")
|
||||||
|
for item in self.items:
|
||||||
|
if item.amount:
|
||||||
|
try:
|
||||||
|
# Parse Swedish number format (1 234,56)
|
||||||
|
amount_str = item.amount.replace(" ", "").replace(",", ".")
|
||||||
|
total += Decimal(amount_str)
|
||||||
|
except InvalidOperation:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if total == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Format back to Swedish format
|
||||||
|
formatted = f"{total:,.2f}".replace(",", " ").replace(".", ",")
|
||||||
|
# Fix the space/comma swap
|
||||||
|
parts = formatted.rsplit(",", 1)
|
||||||
|
if len(parts) == 2:
|
||||||
|
return parts[0].replace(" ", " ") + "," + parts[1]
|
||||||
|
return formatted
|
||||||
@@ -158,36 +158,36 @@ class TableDetector:
|
|||||||
return tables
|
return tables
|
||||||
|
|
||||||
# Log raw result type for debugging
|
# Log raw result type for debugging
|
||||||
logger.info(f"PP-StructureV3 raw results type: {type(results).__name__}")
|
logger.debug(f"PP-StructureV3 raw results type: {type(results).__name__}")
|
||||||
|
|
||||||
# Handle case where results is a single dict-like object (PaddleX 3.x)
|
# Handle case where results is a single dict-like object (PaddleX 3.x)
|
||||||
# rather than a list of results
|
# rather than a list of results
|
||||||
if hasattr(results, "get") and not isinstance(results, list):
|
if hasattr(results, "get") and not isinstance(results, list):
|
||||||
# Single result object - wrap in list for uniform processing
|
# Single result object - wrap in list for uniform processing
|
||||||
logger.info("Results is dict-like, wrapping in list")
|
logger.debug("Results is dict-like, wrapping in list")
|
||||||
results = [results]
|
results = [results]
|
||||||
elif hasattr(results, "__iter__") and not isinstance(results, (list, tuple)):
|
elif hasattr(results, "__iter__") and not isinstance(results, (list, tuple)):
|
||||||
# Iterator or generator - convert to list
|
# Iterator or generator - convert to list
|
||||||
try:
|
try:
|
||||||
results = list(results)
|
results = list(results)
|
||||||
logger.info(f"Converted iterator to list with {len(results)} items")
|
logger.debug(f"Converted iterator to list with {len(results)} items")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to convert results to list: {e}")
|
logger.warning(f"Failed to convert results to list: {e}")
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
logger.info(f"Processing {len(results)} result(s)")
|
logger.debug(f"Processing {len(results)} result(s)")
|
||||||
|
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
try:
|
try:
|
||||||
result_type = type(result).__name__
|
result_type = type(result).__name__
|
||||||
has_get = hasattr(result, "get")
|
has_get = hasattr(result, "get")
|
||||||
has_layout = hasattr(result, "layout_elements")
|
has_layout = hasattr(result, "layout_elements")
|
||||||
logger.info(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
|
logger.debug(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
|
||||||
|
|
||||||
# Try PaddleX 3.x API first (dict-like with table_res_list)
|
# Try PaddleX 3.x API first (dict-like with table_res_list)
|
||||||
if has_get:
|
if has_get:
|
||||||
parsed = self._parse_paddlex_result(result)
|
parsed = self._parse_paddlex_result(result)
|
||||||
logger.info(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
|
logger.debug(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
|
||||||
tables.extend(parsed)
|
tables.extend(parsed)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -201,14 +201,14 @@ class TableDetector:
|
|||||||
if table_result and table_result.confidence >= self.config.min_confidence:
|
if table_result and table_result.confidence >= self.config.min_confidence:
|
||||||
tables.append(table_result)
|
tables.append(table_result)
|
||||||
legacy_count += 1
|
legacy_count += 1
|
||||||
logger.info(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
|
logger.debug(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Result[{i}]: no recognized API (not dict-like and no layout_elements)")
|
logger.warning(f"Result[{i}]: no recognized API (not dict-like and no layout_elements)")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to parse result: {type(result).__name__}, error: {e}")
|
logger.warning(f"Failed to parse result: {type(result).__name__}, error: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
logger.info(f"Total tables detected: {len(tables)}")
|
logger.debug(f"Total tables detected: {len(tables)}")
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
def _parse_paddlex_result(self, result: Any) -> list[TableDetectionResult]:
|
def _parse_paddlex_result(self, result: Any) -> list[TableDetectionResult]:
|
||||||
@@ -223,7 +223,7 @@ class TableDetector:
|
|||||||
result_keys = list(result.keys())
|
result_keys = list(result.keys())
|
||||||
elif hasattr(result, "__dict__"):
|
elif hasattr(result, "__dict__"):
|
||||||
result_keys = list(result.__dict__.keys())
|
result_keys = list(result.__dict__.keys())
|
||||||
logger.info(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
|
logger.debug(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
|
||||||
|
|
||||||
# Get table results from PaddleX 3.x API
|
# Get table results from PaddleX 3.x API
|
||||||
# Handle both dict.get() and attribute access
|
# Handle both dict.get() and attribute access
|
||||||
@@ -234,8 +234,8 @@ class TableDetector:
|
|||||||
table_res_list = getattr(result, "table_res_list", None)
|
table_res_list = getattr(result, "table_res_list", None)
|
||||||
parsing_res_list = getattr(result, "parsing_res_list", [])
|
parsing_res_list = getattr(result, "parsing_res_list", [])
|
||||||
|
|
||||||
logger.info(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
|
logger.debug(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
|
||||||
logger.info(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
|
logger.debug(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
|
||||||
|
|
||||||
if not table_res_list:
|
if not table_res_list:
|
||||||
# Log available keys/attributes for debugging
|
# Log available keys/attributes for debugging
|
||||||
@@ -330,7 +330,7 @@ class TableDetector:
|
|||||||
# Default confidence for PaddleX 3.x results
|
# Default confidence for PaddleX 3.x results
|
||||||
confidence = 0.9
|
confidence = 0.9
|
||||||
|
|
||||||
logger.info(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
|
logger.debug(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
|
||||||
tables.append(TableDetectionResult(
|
tables.append(TableDetectionResult(
|
||||||
bbox=(float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])),
|
bbox=(float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])),
|
||||||
html=html,
|
html=html,
|
||||||
@@ -467,14 +467,14 @@ class TableDetector:
|
|||||||
if not pdf_path.exists():
|
if not pdf_path.exists():
|
||||||
raise FileNotFoundError(f"PDF not found: {pdf_path}")
|
raise FileNotFoundError(f"PDF not found: {pdf_path}")
|
||||||
|
|
||||||
logger.info(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
|
logger.debug(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
|
||||||
|
|
||||||
# Render specific page
|
# Render specific page
|
||||||
for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=dpi):
|
for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=dpi):
|
||||||
if page_no == page_number:
|
if page_no == page_number:
|
||||||
image = Image.open(io.BytesIO(image_bytes))
|
image = Image.open(io.BytesIO(image_bytes))
|
||||||
image_array = np.array(image)
|
image_array = np.array(image)
|
||||||
logger.info(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
|
logger.debug(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
|
||||||
return self.detect(image_array)
|
return self.detect(image_array)
|
||||||
|
|
||||||
raise ValueError(f"Page {page_number} not found in PDF")
|
raise ValueError(f"Page {page_number} not found in PDF")
|
||||||
|
|||||||
@@ -15,6 +15,11 @@ import logging
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Configuration constants
|
||||||
|
DEFAULT_ROW_TOLERANCE = 15.0 # Max vertical distance (pixels) to consider same row
|
||||||
|
MIN_ITEMS_FOR_VALID_EXTRACTION = 2 # Minimum items required for valid extraction
|
||||||
|
MIN_TEXT_ELEMENTS_FOR_EXTRACTION = 5 # Minimum text elements needed to attempt extraction
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TextElement:
|
class TextElement:
|
||||||
@@ -65,7 +70,10 @@ class TextLineItemsResult:
|
|||||||
extraction_method: str = "text_spatial"
|
extraction_method: str = "text_spatial"
|
||||||
|
|
||||||
|
|
||||||
# Swedish amount pattern: 1 234,56 or 1234.56 or 1,234.56
|
# Amount pattern matches Swedish, US, and simple numeric formats
|
||||||
|
# Handles: "1 234,56", "1,234.56", "1234.56", "100 kr", "50:-", "-100,00"
|
||||||
|
# Does NOT handle: amounts with more than 2 decimal places, scientific notation
|
||||||
|
# See tests in test_text_line_items_extractor.py::TestAmountPattern
|
||||||
AMOUNT_PATTERN = re.compile(
|
AMOUNT_PATTERN = re.compile(
|
||||||
r"(?<![0-9])(?:"
|
r"(?<![0-9])(?:"
|
||||||
r"-?\d{1,3}(?:\s\d{3})*(?:,\d{2})?" # Swedish: 1 234,56
|
r"-?\d{1,3}(?:\s\d{3})*(?:,\d{2})?" # Swedish: 1 234,56
|
||||||
@@ -128,17 +136,17 @@ class TextLineItemsExtractor:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
row_tolerance: float = 15.0, # Max vertical distance to consider same row
|
row_tolerance: float = DEFAULT_ROW_TOLERANCE,
|
||||||
min_items_for_valid: int = 2, # Minimum items to consider extraction valid
|
min_items_for_valid: int = MIN_ITEMS_FOR_VALID_EXTRACTION,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize extractor.
|
Initialize extractor.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
row_tolerance: Maximum vertical distance (pixels) between elements
|
row_tolerance: Maximum vertical distance (pixels) between elements
|
||||||
to consider them on the same row.
|
to consider them on the same row. Default: 15.0
|
||||||
min_items_for_valid: Minimum number of line items required for
|
min_items_for_valid: Minimum number of line items required for
|
||||||
extraction to be considered successful.
|
extraction to be considered successful. Default: 2
|
||||||
"""
|
"""
|
||||||
self.row_tolerance = row_tolerance
|
self.row_tolerance = row_tolerance
|
||||||
self.min_items_for_valid = min_items_for_valid
|
self.min_items_for_valid = min_items_for_valid
|
||||||
@@ -161,10 +169,13 @@ class TextLineItemsExtractor:
|
|||||||
|
|
||||||
# Extract text elements from parsing results
|
# Extract text elements from parsing results
|
||||||
text_elements = self._extract_text_elements(parsing_res_list)
|
text_elements = self._extract_text_elements(parsing_res_list)
|
||||||
logger.info(f"TextLineItemsExtractor: found {len(text_elements)} text elements")
|
logger.debug(f"TextLineItemsExtractor: found {len(text_elements)} text elements")
|
||||||
|
|
||||||
if len(text_elements) < 5: # Need at least a few elements
|
if len(text_elements) < MIN_TEXT_ELEMENTS_FOR_EXTRACTION:
|
||||||
logger.debug("Too few text elements for line item extraction")
|
logger.debug(
|
||||||
|
f"Too few text elements ({len(text_elements)}) for line item extraction, "
|
||||||
|
f"need at least {MIN_TEXT_ELEMENTS_FOR_EXTRACTION}"
|
||||||
|
)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.extract_from_text_elements(text_elements)
|
return self.extract_from_text_elements(text_elements)
|
||||||
@@ -183,11 +194,11 @@ class TextLineItemsExtractor:
|
|||||||
"""
|
"""
|
||||||
# Group elements by row
|
# Group elements by row
|
||||||
rows = self._group_by_row(text_elements)
|
rows = self._group_by_row(text_elements)
|
||||||
logger.info(f"TextLineItemsExtractor: grouped into {len(rows)} rows")
|
logger.debug(f"TextLineItemsExtractor: grouped into {len(rows)} rows")
|
||||||
|
|
||||||
# Find the line items section
|
# Find the line items section
|
||||||
item_rows = self._identify_line_item_rows(rows)
|
item_rows = self._identify_line_item_rows(rows)
|
||||||
logger.info(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")
|
logger.debug(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")
|
||||||
|
|
||||||
if len(item_rows) < self.min_items_for_valid:
|
if len(item_rows) < self.min_items_for_valid:
|
||||||
logger.debug(f"Found only {len(item_rows)} item rows, need at least {self.min_items_for_valid}")
|
logger.debug(f"Found only {len(item_rows)} item rows, need at least {self.min_items_for_valid}")
|
||||||
@@ -195,7 +206,7 @@ class TextLineItemsExtractor:
|
|||||||
|
|
||||||
# Extract structured items
|
# Extract structured items
|
||||||
items = self._parse_line_items(item_rows)
|
items = self._parse_line_items(item_rows)
|
||||||
logger.info(f"TextLineItemsExtractor: extracted {len(items)} line items")
|
logger.debug(f"TextLineItemsExtractor: extracted {len(items)} line items")
|
||||||
|
|
||||||
if len(items) < self.min_items_for_valid:
|
if len(items) < self.min_items_for_valid:
|
||||||
return None
|
return None
|
||||||
@@ -209,7 +220,11 @@ class TextLineItemsExtractor:
|
|||||||
def _extract_text_elements(
|
def _extract_text_elements(
|
||||||
self, parsing_res_list: list[dict[str, Any]]
|
self, parsing_res_list: list[dict[str, Any]]
|
||||||
) -> list[TextElement]:
|
) -> list[TextElement]:
|
||||||
"""Extract TextElement objects from parsing_res_list."""
|
"""Extract TextElement objects from parsing_res_list.
|
||||||
|
|
||||||
|
Handles both dict and LayoutBlock object formats from PP-StructureV3.
|
||||||
|
Gracefully skips invalid elements with appropriate logging.
|
||||||
|
"""
|
||||||
elements = []
|
elements = []
|
||||||
|
|
||||||
for elem in parsing_res_list:
|
for elem in parsing_res_list:
|
||||||
@@ -220,11 +235,15 @@ class TextLineItemsExtractor:
|
|||||||
bbox = elem.get("bbox", [])
|
bbox = elem.get("bbox", [])
|
||||||
# Try both 'text' and 'content' keys
|
# Try both 'text' and 'content' keys
|
||||||
text = elem.get("text", "") or elem.get("content", "")
|
text = elem.get("text", "") or elem.get("content", "")
|
||||||
else:
|
elif hasattr(elem, "label"):
|
||||||
label = getattr(elem, "label", "")
|
label = getattr(elem, "label", "")
|
||||||
bbox = getattr(elem, "bbox", [])
|
bbox = getattr(elem, "bbox", [])
|
||||||
# LayoutBlock objects use 'content' attribute
|
# LayoutBlock objects use 'content' attribute
|
||||||
text = getattr(elem, "content", "") or getattr(elem, "text", "")
|
text = getattr(elem, "content", "") or getattr(elem, "text", "")
|
||||||
|
else:
|
||||||
|
# Element is neither dict nor has expected attributes
|
||||||
|
logger.debug(f"Skipping element with unexpected type: {type(elem).__name__}")
|
||||||
|
continue
|
||||||
|
|
||||||
# Only process text elements (skip images, tables, etc.)
|
# Only process text elements (skip images, tables, etc.)
|
||||||
if label not in ("text", "paragraph_title", "aside_text"):
|
if label not in ("text", "paragraph_title", "aside_text"):
|
||||||
@@ -232,6 +251,7 @@ class TextLineItemsExtractor:
|
|||||||
|
|
||||||
# Validate bbox
|
# Validate bbox
|
||||||
if not self._valid_bbox(bbox):
|
if not self._valid_bbox(bbox):
|
||||||
|
logger.debug(f"Skipping element with invalid bbox: {bbox}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Clean text
|
# Clean text
|
||||||
@@ -250,8 +270,13 @@ class TextLineItemsExtractor:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
except (KeyError, TypeError, ValueError, AttributeError) as e:
|
||||||
|
# Expected format issues - log at debug level
|
||||||
|
logger.debug(f"Skipping element due to format issue: {e}")
|
||||||
|
continue
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Failed to parse element: {e}")
|
# Unexpected errors - log at warning level for visibility
|
||||||
|
logger.warning(f"Unexpected error parsing element: {type(e).__name__}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return elements
|
return elements
|
||||||
@@ -270,6 +295,7 @@ class TextLineItemsExtractor:
|
|||||||
Group text elements into rows based on vertical position.
|
Group text elements into rows based on vertical position.
|
||||||
|
|
||||||
Elements within row_tolerance of each other are considered same row.
|
Elements within row_tolerance of each other are considered same row.
|
||||||
|
Uses dynamic average center_y to handle varying element heights more accurately.
|
||||||
"""
|
"""
|
||||||
if not elements:
|
if not elements:
|
||||||
return []
|
return []
|
||||||
@@ -277,22 +303,22 @@ class TextLineItemsExtractor:
|
|||||||
# Sort by vertical position
|
# Sort by vertical position
|
||||||
sorted_elements = sorted(elements, key=lambda e: e.center_y)
|
sorted_elements = sorted(elements, key=lambda e: e.center_y)
|
||||||
|
|
||||||
rows = []
|
rows: list[list[TextElement]] = []
|
||||||
current_row = [sorted_elements[0]]
|
current_row: list[TextElement] = [sorted_elements[0]]
|
||||||
current_y = sorted_elements[0].center_y
|
|
||||||
|
|
||||||
for elem in sorted_elements[1:]:
|
for elem in sorted_elements[1:]:
|
||||||
if abs(elem.center_y - current_y) <= self.row_tolerance:
|
# Calculate dynamic average center_y for current row
|
||||||
# Same row
|
avg_center_y = sum(e.center_y for e in current_row) / len(current_row)
|
||||||
|
|
||||||
|
if abs(elem.center_y - avg_center_y) <= self.row_tolerance:
|
||||||
|
# Same row - add element and recalculate average on next iteration
|
||||||
current_row.append(elem)
|
current_row.append(elem)
|
||||||
else:
|
else:
|
||||||
# New row
|
# New row - finalize current row
|
||||||
if current_row:
|
# Sort row by horizontal position (left to right)
|
||||||
# Sort row by horizontal position
|
current_row.sort(key=lambda e: e.center_x)
|
||||||
current_row.sort(key=lambda e: e.center_x)
|
rows.append(current_row)
|
||||||
rows.append(current_row)
|
|
||||||
current_row = [elem]
|
current_row = [elem]
|
||||||
current_y = elem.center_y
|
|
||||||
|
|
||||||
# Don't forget last row
|
# Don't forget last row
|
||||||
if current_row:
|
if current_row:
|
||||||
|
|||||||
@@ -272,12 +272,12 @@ class TestLineItemsExtractorFromPdf:
|
|||||||
|
|
||||||
extractor = LineItemsExtractor()
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
# Create mock table detection result
|
# Create mock table detection result with proper thead/tbody structure
|
||||||
mock_table = MagicMock(spec=TableDetectionResult)
|
mock_table = MagicMock(spec=TableDetectionResult)
|
||||||
mock_table.html = """
|
mock_table.html = """
|
||||||
<table>
|
<table>
|
||||||
<tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr>
|
<thead><tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr></thead>
|
||||||
<tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr>
|
<tbody><tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr></tbody>
|
||||||
</table>
|
</table>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -291,6 +291,78 @@ class TestLineItemsExtractorFromPdf:
|
|||||||
assert len(result.items) >= 1
|
assert len(result.items) >= 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestPdfPathValidation:
|
||||||
|
"""Tests for PDF path validation."""
|
||||||
|
|
||||||
|
def test_detect_tables_with_nonexistent_path(self):
|
||||||
|
"""Test that non-existent PDF path returns empty results."""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
# Create detector and call _detect_tables_with_parsing with non-existent path
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from backend.table.structure_detector import TableDetector
|
||||||
|
|
||||||
|
mock_detector = MagicMock(spec=TableDetector)
|
||||||
|
tables, parsing_res = extractor._detect_tables_with_parsing(
|
||||||
|
mock_detector, "nonexistent.pdf"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tables == []
|
||||||
|
assert parsing_res == []
|
||||||
|
|
||||||
|
def test_detect_tables_with_directory_path(self, tmp_path):
|
||||||
|
"""Test that directory path (not file) returns empty results."""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from backend.table.structure_detector import TableDetector
|
||||||
|
|
||||||
|
mock_detector = MagicMock(spec=TableDetector)
|
||||||
|
|
||||||
|
# tmp_path is a directory, not a file
|
||||||
|
tables, parsing_res = extractor._detect_tables_with_parsing(
|
||||||
|
mock_detector, str(tmp_path)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert tables == []
|
||||||
|
assert parsing_res == []
|
||||||
|
|
||||||
|
def test_detect_tables_validates_file_exists(self, tmp_path):
|
||||||
|
"""Test path validation for file existence.
|
||||||
|
|
||||||
|
This test verifies that the method correctly validates the path exists
|
||||||
|
and is a file before attempting to process it.
|
||||||
|
"""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
# Create a real file path that exists
|
||||||
|
fake_pdf = tmp_path / "test.pdf"
|
||||||
|
fake_pdf.write_bytes(b"not a real pdf")
|
||||||
|
|
||||||
|
# Mock render_pdf_to_images to avoid actual PDF processing
|
||||||
|
with patch("shared.pdf.renderer.render_pdf_to_images") as mock_render:
|
||||||
|
# Return empty iterator - simulates file exists but no pages
|
||||||
|
mock_render.return_value = iter([])
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from backend.table.structure_detector import TableDetector
|
||||||
|
|
||||||
|
mock_detector = MagicMock(spec=TableDetector)
|
||||||
|
mock_detector._ensure_initialized = MagicMock()
|
||||||
|
mock_detector._pipeline = MagicMock()
|
||||||
|
|
||||||
|
tables, parsing_res = extractor._detect_tables_with_parsing(
|
||||||
|
mock_detector, str(fake_pdf)
|
||||||
|
)
|
||||||
|
|
||||||
|
# render_pdf_to_images was called (path validation passed)
|
||||||
|
mock_render.assert_called_once()
|
||||||
|
assert tables == []
|
||||||
|
assert parsing_res == []
|
||||||
|
|
||||||
|
|
||||||
class TestLineItemsResult:
|
class TestLineItemsResult:
|
||||||
"""Tests for LineItemsResult dataclass."""
|
"""Tests for LineItemsResult dataclass."""
|
||||||
|
|
||||||
@@ -462,3 +534,246 @@ class TestMergedCellExtraction:
|
|||||||
assert result.items[0].is_deduction is False
|
assert result.items[0].is_deduction is False
|
||||||
assert result.items[1].amount == "-2000"
|
assert result.items[1].amount == "-2000"
|
||||||
assert result.items[1].is_deduction is True
|
assert result.items[1].is_deduction is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestTextFallbackExtraction:
|
||||||
|
"""Tests for text-based fallback extraction."""
|
||||||
|
|
||||||
|
def test_text_fallback_disabled_by_default(self):
|
||||||
|
"""Test text fallback can be disabled."""
|
||||||
|
extractor = LineItemsExtractor(enable_text_fallback=False)
|
||||||
|
assert extractor.enable_text_fallback is False
|
||||||
|
|
||||||
|
def test_text_fallback_enabled_by_default(self):
|
||||||
|
"""Test text fallback is enabled by default."""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
assert extractor.enable_text_fallback is True
|
||||||
|
|
||||||
|
def test_try_text_fallback_with_valid_parsing_res(self):
|
||||||
|
"""Test text fallback with valid parsing results."""
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from backend.table.text_line_items_extractor import (
|
||||||
|
TextLineItemsExtractor,
|
||||||
|
TextLineItem,
|
||||||
|
TextLineItemsResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
# Mock parsing_res_list with text elements
|
||||||
|
parsing_res = [
|
||||||
|
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Product A"},
|
||||||
|
{"label": "text", "bbox": [250, 100, 350, 120], "text": "1 234,56"},
|
||||||
|
{"label": "text", "bbox": [0, 150, 200, 170], "text": "Product B"},
|
||||||
|
{"label": "text", "bbox": [250, 150, 350, 170], "text": "2 345,67"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create mock text extraction result
|
||||||
|
mock_text_result = TextLineItemsResult(
|
||||||
|
items=[
|
||||||
|
TextLineItem(row_index=0, description="Product A", amount="1 234,56"),
|
||||||
|
TextLineItem(row_index=1, description="Product B", amount="2 345,67"),
|
||||||
|
],
|
||||||
|
header_row=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(TextLineItemsExtractor, 'extract_from_parsing_res', return_value=mock_text_result):
|
||||||
|
result = extractor._try_text_fallback(parsing_res)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result.items) == 2
|
||||||
|
assert result.items[0].description == "Product A"
|
||||||
|
assert result.items[1].description == "Product B"
|
||||||
|
|
||||||
|
def test_try_text_fallback_returns_none_on_failure(self):
|
||||||
|
"""Test text fallback returns None when extraction fails."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
with patch('backend.table.text_line_items_extractor.TextLineItemsExtractor.extract_from_parsing_res', return_value=None):
|
||||||
|
result = extractor._try_text_fallback([])
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_extract_from_pdf_uses_text_fallback(self):
|
||||||
|
"""Test extract_from_pdf uses text fallback when no tables found."""
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from backend.table.text_line_items_extractor import TextLineItem, TextLineItemsResult
|
||||||
|
|
||||||
|
extractor = LineItemsExtractor(enable_text_fallback=True)
|
||||||
|
|
||||||
|
# Mock _detect_tables_with_parsing to return no tables but parsing_res
|
||||||
|
mock_text_result = TextLineItemsResult(
|
||||||
|
items=[
|
||||||
|
TextLineItem(row_index=0, description="Product", amount="100,00"),
|
||||||
|
TextLineItem(row_index=1, description="Product 2", amount="200,00"),
|
||||||
|
],
|
||||||
|
header_row=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
|
||||||
|
mock_detect.return_value = ([], [{"label": "text", "text": "test"}])
|
||||||
|
|
||||||
|
with patch.object(extractor, '_try_text_fallback', return_value=MagicMock(items=[MagicMock()])) as mock_fallback:
|
||||||
|
result = extractor.extract_from_pdf("fake.pdf")
|
||||||
|
|
||||||
|
# Text fallback should be called
|
||||||
|
mock_fallback.assert_called_once()
|
||||||
|
|
||||||
|
def test_extract_from_pdf_skips_fallback_when_disabled(self):
|
||||||
|
"""Test extract_from_pdf skips text fallback when disabled."""
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
extractor = LineItemsExtractor(enable_text_fallback=False)
|
||||||
|
|
||||||
|
with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
|
||||||
|
mock_detect.return_value = ([], [{"label": "text", "text": "test"}])
|
||||||
|
|
||||||
|
result = extractor.extract_from_pdf("fake.pdf")
|
||||||
|
|
||||||
|
# Should return None, not use text fallback
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestVerticallyMergedCellExtraction:
|
||||||
|
"""Tests for vertically merged cell extraction."""
|
||||||
|
|
||||||
|
def test_detects_vertically_merged_cells(self):
|
||||||
|
"""Test detection of vertically merged cells in rows."""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
# Rows with multiple product numbers in single cell
|
||||||
|
rows = [["Produktnr 1457280 1457281 1060381 merged text here"]]
|
||||||
|
assert extractor._has_vertically_merged_cells(rows) is True
|
||||||
|
|
||||||
|
def test_splits_vertically_merged_rows(self):
|
||||||
|
"""Test splitting vertically merged rows."""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
["Produktnr 1234567 1234568", "Antal 2ST 3ST"],
|
||||||
|
]
|
||||||
|
header, data = extractor._split_merged_rows(rows)
|
||||||
|
|
||||||
|
# Should split into header + data rows
|
||||||
|
assert isinstance(header, list)
|
||||||
|
assert isinstance(data, list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeductionDetection:
|
||||||
|
"""Tests for deduction/discount detection."""
|
||||||
|
|
||||||
|
def test_detects_deduction_by_keyword_avdrag(self):
|
||||||
|
"""Test detection of deduction by 'avdrag' keyword."""
|
||||||
|
html = """
|
||||||
|
<html><body><table>
|
||||||
|
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||||
|
<tbody>
|
||||||
|
<tr><td>Hyresavdrag januari</td><td>-500,00</td></tr>
|
||||||
|
</tbody>
|
||||||
|
</table></body></html>
|
||||||
|
"""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
result = extractor.extract(html)
|
||||||
|
|
||||||
|
assert len(result.items) == 1
|
||||||
|
assert result.items[0].is_deduction is True
|
||||||
|
|
||||||
|
def test_detects_deduction_by_keyword_rabatt(self):
|
||||||
|
"""Test detection of deduction by 'rabatt' keyword."""
|
||||||
|
html = """
|
||||||
|
<html><body><table>
|
||||||
|
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||||
|
<tbody>
|
||||||
|
<tr><td>Rabatt 10%</td><td>-100,00</td></tr>
|
||||||
|
</tbody>
|
||||||
|
</table></body></html>
|
||||||
|
"""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
result = extractor.extract(html)
|
||||||
|
|
||||||
|
assert len(result.items) == 1
|
||||||
|
assert result.items[0].is_deduction is True
|
||||||
|
|
||||||
|
def test_detects_deduction_by_negative_amount(self):
|
||||||
|
"""Test detection of deduction by negative amount."""
|
||||||
|
html = """
|
||||||
|
<html><body><table>
|
||||||
|
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||||
|
<tbody>
|
||||||
|
<tr><td>Some credit</td><td>-250,00</td></tr>
|
||||||
|
</tbody>
|
||||||
|
</table></body></html>
|
||||||
|
"""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
result = extractor.extract(html)
|
||||||
|
|
||||||
|
assert len(result.items) == 1
|
||||||
|
assert result.items[0].is_deduction is True
|
||||||
|
|
||||||
|
def test_normal_item_not_deduction(self):
|
||||||
|
"""Test normal item is not marked as deduction."""
|
||||||
|
html = """
|
||||||
|
<html><body><table>
|
||||||
|
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||||
|
<tbody>
|
||||||
|
<tr><td>Normal product</td><td>500,00</td></tr>
|
||||||
|
</tbody>
|
||||||
|
</table></body></html>
|
||||||
|
"""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
result = extractor.extract(html)
|
||||||
|
|
||||||
|
assert len(result.items) == 1
|
||||||
|
assert result.items[0].is_deduction is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestHeaderDetection:
|
||||||
|
"""Tests for header row detection."""
|
||||||
|
|
||||||
|
def test_detect_header_at_bottom(self):
|
||||||
|
"""Test detecting header at bottom of table (reversed)."""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
["100,00", "Product A", "1"],
|
||||||
|
["200,00", "Product B", "2"],
|
||||||
|
["Belopp", "Beskrivning", "Antal"], # Header at bottom
|
||||||
|
]
|
||||||
|
|
||||||
|
header_idx, header, is_at_end = extractor._detect_header_row(rows)
|
||||||
|
|
||||||
|
assert header_idx == 2
|
||||||
|
assert is_at_end is True
|
||||||
|
assert "Belopp" in header
|
||||||
|
|
||||||
|
def test_detect_header_at_top(self):
|
||||||
|
"""Test detecting header at top of table."""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
["Belopp", "Beskrivning", "Antal"], # Header at top
|
||||||
|
["100,00", "Product A", "1"],
|
||||||
|
["200,00", "Product B", "2"],
|
||||||
|
]
|
||||||
|
|
||||||
|
header_idx, header, is_at_end = extractor._detect_header_row(rows)
|
||||||
|
|
||||||
|
assert header_idx == 0
|
||||||
|
assert is_at_end is False
|
||||||
|
assert "Belopp" in header
|
||||||
|
|
||||||
|
def test_no_header_detected(self):
|
||||||
|
"""Test when no header is detected."""
|
||||||
|
extractor = LineItemsExtractor()
|
||||||
|
|
||||||
|
rows = [
|
||||||
|
["100,00", "Product A", "1"],
|
||||||
|
["200,00", "Product B", "2"],
|
||||||
|
]
|
||||||
|
|
||||||
|
header_idx, header, is_at_end = extractor._detect_header_row(rows)
|
||||||
|
|
||||||
|
assert header_idx == -1
|
||||||
|
assert header == []
|
||||||
|
assert is_at_end is False
|
||||||
|
|||||||
448
tests/table/test_merged_cell_handler.py
Normal file
448
tests/table/test_merged_cell_handler.py
Normal file
@@ -0,0 +1,448 @@
|
|||||||
|
"""
|
||||||
|
Tests for Merged Cell Handler
|
||||||
|
|
||||||
|
Tests the detection and extraction of data from tables with merged cells,
|
||||||
|
a common issue with PP-StructureV3 OCR output.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from backend.table.merged_cell_handler import MergedCellHandler, MIN_AMOUNT_THRESHOLD
|
||||||
|
from backend.table.html_table_parser import ColumnMapper
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def handler():
|
||||||
|
"""Create a MergedCellHandler with default ColumnMapper."""
|
||||||
|
return MergedCellHandler(ColumnMapper())
|
||||||
|
|
||||||
|
|
||||||
|
class TestHasVerticallyMergedCells:
|
||||||
|
"""Tests for has_vertically_merged_cells detection."""
|
||||||
|
|
||||||
|
def test_empty_rows_returns_false(self, handler):
|
||||||
|
"""Test empty rows returns False."""
|
||||||
|
assert handler.has_vertically_merged_cells([]) is False
|
||||||
|
|
||||||
|
def test_short_cells_ignored(self, handler):
|
||||||
|
"""Test cells shorter than 20 chars are ignored."""
|
||||||
|
rows = [["Short cell", "Also short"]]
|
||||||
|
assert handler.has_vertically_merged_cells(rows) is False
|
||||||
|
|
||||||
|
def test_detects_multiple_product_numbers(self, handler):
|
||||||
|
"""Test detection of multiple 7-digit product numbers in cell."""
|
||||||
|
rows = [["Produktnr 1457280 1457281 1060381 and more text here"]]
|
||||||
|
assert handler.has_vertically_merged_cells(rows) is True
|
||||||
|
|
||||||
|
def test_single_product_number_not_merged(self, handler):
|
||||||
|
"""Test single product number doesn't trigger detection."""
|
||||||
|
rows = [["Produktnr 1457280 and more text here for length"]]
|
||||||
|
assert handler.has_vertically_merged_cells(rows) is False
|
||||||
|
|
||||||
|
def test_detects_multiple_prices(self, handler):
|
||||||
|
"""Test detection of 3+ prices in cell (Swedish format)."""
|
||||||
|
rows = [["Pris 127,20 234,56 159,20 total amounts"]]
|
||||||
|
assert handler.has_vertically_merged_cells(rows) is True
|
||||||
|
|
||||||
|
def test_two_prices_not_merged(self, handler):
|
||||||
|
"""Test two prices doesn't trigger detection (needs 3+)."""
|
||||||
|
rows = [["Pris 127,20 234,56 total amount here"]]
|
||||||
|
assert handler.has_vertically_merged_cells(rows) is False
|
||||||
|
|
||||||
|
def test_detects_multiple_quantities(self, handler):
|
||||||
|
"""Test detection of multiple quantity patterns."""
|
||||||
|
rows = [["Antal 6ST 6ST 1ST more text here"]]
|
||||||
|
assert handler.has_vertically_merged_cells(rows) is True
|
||||||
|
|
||||||
|
def test_single_quantity_not_merged(self, handler):
|
||||||
|
"""Test single quantity doesn't trigger detection."""
|
||||||
|
rows = [["Antal 6ST and more text here for length"]]
|
||||||
|
assert handler.has_vertically_merged_cells(rows) is False
|
||||||
|
|
||||||
|
def test_empty_cell_skipped(self, handler):
|
||||||
|
"""Test empty cells are skipped."""
|
||||||
|
rows = [["", None, "Valid but short"]]
|
||||||
|
assert handler.has_vertically_merged_cells(rows) is False
|
||||||
|
|
||||||
|
def test_multiple_rows_checked(self, handler):
|
||||||
|
"""Test all rows are checked for merged content."""
|
||||||
|
rows = [
|
||||||
|
["Normal row with nothing special"],
|
||||||
|
["Produktnr 1457280 1457281 1060381 merged content"],
|
||||||
|
]
|
||||||
|
assert handler.has_vertically_merged_cells(rows) is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitMergedRows:
    """Behaviour of the split_merged_rows method."""

    def test_empty_rows_returns_empty(self, handler):
        """An empty table yields an empty header and no data rows."""
        header, data = handler.split_merged_rows([])
        assert header == []
        assert data == []

    def test_all_empty_rows_returns_original(self, handler):
        """Rows containing only blanks are passed through unchanged."""
        blanks = [["", ""], ["", ""]]
        header, data = handler.split_merged_rows(blanks)
        assert header == []
        assert data == blanks

    def test_splits_by_product_numbers(self, handler):
        """Two product numbers in one row produce two output rows."""
        merged = [
            ["Produktnr 1234567 1234568", "Antal 2ST 3ST", "Pris 100,00 200,00"],
        ]
        header, data = handler.split_merged_rows(merged)
        assert len(header) == 3
        assert header[0] == "Produktnr"
        assert len(data) == 2

    def test_splits_by_quantities(self, handler):
        """Quantity patterns alone are enough to drive the split."""
        merged = [
            ["Description text", "Antal 5ST 10ST", "Belopp 500,00 1000,00"],
        ]
        header, data = handler.split_merged_rows(merged)
        # Two quantities detected, so at least one data row must come back.
        assert len(data) >= 1

    def test_single_row_not_split(self, handler):
        """A row describing a single item is left alone."""
        single = [
            ["Produktnr 1234567", "Antal 2ST", "Pris 100,00"],
        ]
        header, data = handler.split_merged_rows(single)
        # Only one product number means expected_rows <= 1: nothing to split.
        assert header == []
        assert data == single

    def test_handles_missing_columns(self, handler):
        """Rows of uneven width must not raise."""
        ragged = [
            ["Produktnr 1234567 1234568", ""],
            ["Antal 2ST 3ST"],
        ]
        header, data = handler.split_merged_rows(ragged)
        # Graceful handling: both outputs are still lists.
        assert isinstance(header, list)
        assert isinstance(data, list)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCountExpectedRows:
    """Behaviour of the _count_expected_rows helper."""

    def test_counts_product_numbers(self, handler):
        """Three product numbers in one column give a count of 3."""
        cols = ["Produktnr 1234567 1234568 1234569", "Other"]
        assert handler._count_expected_rows(cols) == 3

    def test_counts_quantities(self, handler):
        """Four quantity tokens give a count of 4."""
        cols = ["Nothing here", "Antal 5ST 10ST 15ST 20ST"]
        assert handler._count_expected_rows(cols) == 4

    def test_returns_max_count(self, handler):
        """The largest per-column count wins."""
        cols = [
            "Produktnr 1234567 1234568",  # 2 products
            "Antal 5ST 10ST 15ST",  # 3 quantities
        ]
        assert handler._count_expected_rows(cols) == 3

    def test_empty_columns_return_zero(self, handler):
        """Blank, None and too-short cells contribute nothing."""
        cols = ["", None, "Short"]
        assert handler._count_expected_rows(cols) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitCellContentForRows:
    """Behaviour of the _split_cell_content_for_rows helper."""

    def test_splits_by_product_numbers(self, handler):
        """Two product numbers yield a header plus two value entries."""
        pieces = handler._split_cell_content_for_rows("Produktnr 1234567 1234568", 2)
        assert len(pieces) == 3  # header + 2 values
        assert pieces[0] == "Produktnr"
        assert "1234567" in pieces[1]
        assert "1234568" in pieces[2]

    def test_splits_by_quantities(self, handler):
        """Two quantity tokens yield a header plus two value entries."""
        pieces = handler._split_cell_content_for_rows("Antal 5ST 10ST", 2)
        assert len(pieces) == 3  # header + 2 values
        assert pieces[0] == "Antal"

    def test_splits_discount_totalsumma(self, handler):
        """Combined discount/total columns collapse to a Totalsumma header."""
        pieces = handler._split_cell_content_for_rows(
            "Rabatt i% Totalsumma 686,88 123,45", 2
        )
        assert pieces[0] == "Totalsumma"
        assert "686,88" in pieces[1]
        assert "123,45" in pieces[2]

    def test_splits_by_prices(self, handler):
        """Price patterns are a valid split criterion."""
        pieces = handler._split_cell_content_for_rows("Pris 127,20 234,56", 2)
        assert len(pieces) >= 2

    def test_fallback_returns_original(self, handler):
        """With no recognised pattern the cell comes back untouched."""
        pieces = handler._split_cell_content_for_rows("No patterns here", 2)
        assert pieces == ["No patterns here"]

    def test_product_number_with_description(self, handler):
        """Trailing description text stays attached to its product number."""
        pieces = handler._split_cell_content_for_rows(
            "Art 1234567 Widget A 1234568 Widget B", 2
        )
        assert len(pieces) == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestSplitCellContent:
    """Behaviour of the split_cell_content method."""

    def test_splits_by_product_numbers(self, handler):
        """Three product numbers are all recovered after the header."""
        parts = handler.split_cell_content("Produktnr 1234567 1234568 1234569")
        assert parts[0] == "Produktnr"
        for number in ("1234567", "1234568", "1234569"):
            assert number in parts

    def test_splits_by_quantities(self, handler):
        """Multiple quantity tokens split into header plus values."""
        parts = handler.split_cell_content("Antal 6ST 6ST 1ST")
        assert parts[0] == "Antal"
        assert len(parts) >= 3

    def test_splits_discount_amount_interleaved(self, handler):
        """Interleaved percent/amount pairs keep only the amounts."""
        parts = handler.split_cell_content("Rabatt i% Totalsumma 10,0 686,88 10,0 123,45")
        # The amounts (3+ digit decimals) survive under the Totalsumma header.
        assert parts[0] == "Totalsumma"
        assert "686,88" in parts
        assert "123,45" in parts

    def test_splits_by_prices(self, handler):
        """Repeated prices split under a Pris header."""
        parts = handler.split_cell_content("Pris 127,20 127,20 159,20")
        assert parts[0] == "Pris"

    def test_single_value_not_split(self, handler):
        """A cell without repeated patterns is returned whole."""
        assert handler.split_cell_content("Single value") == ["Single value"]

    def test_single_product_not_split(self, handler):
        """One product number alone is not a reason to split."""
        assert handler.split_cell_content("Produktnr 1234567") == ["Produktnr 1234567"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestHasMergedHeader:
    """Behaviour of the has_merged_header method."""

    def test_none_header_returns_false(self, handler):
        """None is treated as no header at all."""
        assert handler.has_merged_header(None) is False

    def test_empty_header_returns_false(self, handler):
        """An empty list is not a merged header."""
        assert handler.has_merged_header([]) is False

    def test_multiple_non_empty_cells_returns_false(self, handler):
        """Several populated cells means the header parsed normally."""
        cells = ["Beskrivning", "Antal", "Belopp"]
        assert handler.has_merged_header(cells) is False

    def test_single_cell_with_keywords_returns_true(self, handler):
        """One cell carrying several header keywords is a merged header."""
        cells = ["Specifikation 0218103-1201 rum och kök Hyra Avdrag"]
        assert handler.has_merged_header(cells) is True

    def test_single_cell_one_keyword_returns_false(self, handler):
        """A single keyword is not enough evidence of merging."""
        cells = ["Beskrivning only"]
        assert handler.has_merged_header(cells) is False

    def test_ignores_empty_trailing_cells(self, handler):
        """Trailing blank cells do not mask a merged first cell."""
        cells = ["Specifikation Hyra Avdrag", "", "", ""]
        assert handler.has_merged_header(cells) is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractFromMergedCells:
    """Behaviour of the extract_from_merged_cells method."""

    def test_extracts_single_amount(self, handler):
        """A lone positive amount produces one fully populated item."""
        header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        items = handler.extract_from_merged_cells(header, [["", "", "", "8159"]])
        assert len(items) == 1
        first = items[0]
        assert first.amount == "8159"
        assert first.is_deduction is False
        assert first.article_number == "0218103-1201"
        assert first.description == "2 rum och kök"

    def test_extracts_deduction(self, handler):
        """A negative amount is flagged as a deduction."""
        items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "-2000"]]
        )
        assert len(items) == 1
        assert items[0].amount == "-2000"
        assert items[0].is_deduction is True
        # The first item (row_index=0) takes its description from the header,
        # so "Avdrag" is only applied to subsequent deduction items.
        assert items[0].description is None

    def test_extracts_multiple_amounts_same_row(self, handler):
        """Two amounts in the same cell become two items."""
        header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        items = handler.extract_from_merged_cells(header, [["", "", "", "8159 -2000"]])
        assert len(items) == 2
        assert items[0].amount == "8159"
        assert items[1].amount == "-2000"

    def test_extracts_amounts_from_multiple_rows(self, handler):
        """Amounts spread over several rows are all collected."""
        table = [
            ["", "", "", "8159"],
            ["", "", "", "-2000"],
        ]
        items = handler.extract_from_merged_cells(["Specifikation"], table)
        assert len(items) == 2

    def test_skips_small_amounts(self, handler):
        """Values under MIN_AMOUNT_THRESHOLD (100) are ignored."""
        items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "50"]]
        )
        assert len(items) == 0

    def test_skips_empty_rows(self, handler):
        """Rows with no content yield nothing."""
        items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", ""]]
        )
        assert len(items) == 0

    def test_handles_swedish_format_with_spaces(self, handler):
        """Thousand-separator spaces are stripped from amounts."""
        items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "8 159"]]
        )
        assert len(items) == 1
        assert items[0].amount == "8159"

    def test_confidence_is_lower_for_merged(self, handler):
        """Merged-cell extraction carries a reduced confidence of 0.7."""
        items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "8159"]]
        )
        assert items[0].confidence == 0.7

    def test_empty_header_still_extracts(self, handler):
        """Amounts are still extracted when no header is available."""
        items = handler.extract_from_merged_cells([], [["", "", "", "8159"]])
        assert len(items) == 1
        assert items[0].description is None
        assert items[0].article_number is None

    def test_row_index_increments(self, handler):
        """Each extracted item records the row it came from."""
        # Separate rows are used here to avoid regex grouping issues.
        table = [
            ["", "", "", "8159"],
            ["", "", "", "5000"],
            ["", "", "", "-2000"],
        ]
        items = handler.extract_from_merged_cells(["Specifikation"], table)
        assert len(items) == 3
        assert [item.row_index for item in items] == [0, 1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
class TestMinAmountThreshold:
    """Behaviour around the MIN_AMOUNT_THRESHOLD constant."""

    def test_threshold_value(self):
        """The constant itself is pinned at 100."""
        assert MIN_AMOUNT_THRESHOLD == 100

    def test_amounts_at_threshold_included(self, handler):
        """An amount exactly equal to the threshold is kept."""
        items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "100"]]
        )
        assert len(items) == 1
        assert items[0].amount == "100"

    def test_amounts_below_threshold_excluded(self, handler):
        """An amount one below the threshold is dropped."""
        items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "99"]]
        )
        assert len(items) == 0
|
||||||
157
tests/table/test_models.py
Normal file
157
tests/table/test_models.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""
|
||||||
|
Tests for Line Items Data Models
|
||||||
|
|
||||||
|
Tests for LineItem and LineItemsResult dataclasses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from backend.table.models import LineItem, LineItemsResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestLineItem:
    """Tests for the LineItem dataclass."""

    def test_default_values(self):
        """Optional fields default to None/False; confidence defaults to 0.9."""
        item = LineItem(row_index=0)
        assert item.row_index == 0
        for field_name in (
            "description",
            "quantity",
            "unit",
            "unit_price",
            "amount",
            "article_number",
            "vat_rate",
        ):
            assert getattr(item, field_name) is None
        assert item.is_deduction is False
        assert item.confidence == 0.9

    def test_custom_confidence(self):
        """Confidence can be overridden at construction time."""
        assert LineItem(row_index=0, confidence=0.7).confidence == 0.7

    def test_is_deduction_true(self):
        """The is_deduction flag can be set explicitly."""
        assert LineItem(row_index=0, is_deduction=True).is_deduction is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestLineItemsResult:
    """Tests for the LineItemsResult dataclass."""

    @staticmethod
    def _result(items, **kwargs):
        """Build a LineItemsResult with empty header/html boilerplate."""
        return LineItemsResult(items=items, header_row=[], raw_html="", **kwargs)

    def test_total_amount_empty_items(self):
        """No items means no total."""
        assert self._result([]).total_amount is None

    def test_total_amount_single_item(self):
        """A single amount is returned as-is."""
        result = self._result([LineItem(row_index=0, amount="100,00")])
        assert result.total_amount == "100,00"

    def test_total_amount_multiple_items(self):
        """Amounts are summed in Swedish decimal-comma format."""
        result = self._result(
            [
                LineItem(row_index=0, amount="100,00"),
                LineItem(row_index=1, amount="200,50"),
            ]
        )
        assert result.total_amount == "300,50"

    def test_total_amount_with_deduction(self):
        """Negative deduction amounts reduce the total."""
        result = self._result(
            [
                LineItem(row_index=0, amount="1000,00"),
                LineItem(row_index=1, amount="-200,00", is_deduction=True),
            ]
        )
        assert result.total_amount == "800,00"

    def test_total_amount_swedish_format_with_spaces(self):
        """Space thousand separators are parsed correctly."""
        result = self._result(
            [
                LineItem(row_index=0, amount="1 234,56"),
                LineItem(row_index=1, amount="2 000,00"),
            ]
        )
        assert result.total_amount == "3 234,56"

    def test_total_amount_invalid_amount_skipped(self):
        """Unparseable amounts are ignored rather than fatal."""
        result = self._result(
            [
                LineItem(row_index=0, amount="100,00"),
                LineItem(row_index=1, amount="invalid"),
                LineItem(row_index=2, amount="200,00"),
            ]
        )
        # The invalid middle amount is simply skipped.
        assert result.total_amount == "300,00"

    def test_total_amount_none_amount_skipped(self):
        """Items without any amount are skipped."""
        result = self._result(
            [
                LineItem(row_index=0, amount="100,00"),
                LineItem(row_index=1, amount=None),
            ]
        )
        assert result.total_amount == "100,00"

    def test_total_amount_all_invalid_returns_none(self):
        """If nothing parses, the total is None."""
        result = self._result(
            [
                LineItem(row_index=0, amount="invalid"),
                LineItem(row_index=1, amount="also invalid"),
            ]
        )
        assert result.total_amount is None

    def test_total_amount_large_numbers(self):
        """Large totals keep their space-grouped formatting."""
        result = self._result(
            [
                LineItem(row_index=0, amount="123 456,78"),
                LineItem(row_index=1, amount="876 543,22"),
            ]
        )
        assert result.total_amount == "1 000 000,00"

    def test_total_amount_decimal_precision(self):
        """Two-decimal precision survives summation."""
        result = self._result(
            [
                LineItem(row_index=0, amount="0,01"),
                LineItem(row_index=1, amount="0,02"),
            ]
        )
        assert result.total_amount == "0,03"

    def test_is_reversed_default_false(self):
        """is_reversed defaults to False."""
        assert self._result([]).is_reversed is False

    def test_is_reversed_can_be_set(self):
        """is_reversed can be switched on explicitly."""
        assert self._result([], is_reversed=True).is_reversed is True

    def test_header_row_preserved(self):
        """The header row is stored verbatim."""
        columns = ["Beskrivning", "Antal", "Belopp"]
        result = LineItemsResult(items=[], header_row=columns, raw_html="")
        assert result.header_row == columns

    def test_raw_html_preserved(self):
        """The raw table HTML is stored verbatim."""
        markup = "<table><tr><td>Test</td></tr></table>"
        result = LineItemsResult(items=[], header_row=[], raw_html=markup)
        assert result.raw_html == markup
|
||||||
@@ -658,3 +658,245 @@ class TestPaddleX3xAPI:
|
|||||||
assert len(results) == 1
|
assert len(results) == 1
|
||||||
assert results[0].cells == [] # Empty cells list
|
assert results[0].cells == [] # Empty cells list
|
||||||
assert results[0].html == "<table></table>"
|
assert results[0].html == "<table></table>"
|
||||||
|
|
||||||
|
def test_parse_paddlex_result_with_dict_ocr_data(self):
    """PaddleX 3.x dict-shaped table_ocr_pred is mapped onto cell texts."""
    pipeline = MagicMock()
    paddlex_result = {
        "table_res_list": [
            {
                "cell_box_list": [[0, 0, 50, 20], [50, 0, 100, 20]],
                "pred_html": "<table><tr><td>A</td><td>B</td></tr></table>",
                "table_ocr_pred": {
                    "rec_texts": ["A", "B"],
                    "rec_scores": [0.99, 0.98],
                },
            }
        ],
        "parsing_res_list": [
            {"label": "table", "bbox": [10, 20, 200, 300]},
        ],
    }
    pipeline.predict.return_value = [paddlex_result]

    detector = TableDetector(pipeline=pipeline)
    tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

    assert len(tables) == 1
    cells = tables[0].cells
    assert len(cells) == 2
    assert cells[0]["text"] == "A"
    assert cells[1]["text"] == "B"
|
||||||
|
|
||||||
|
def test_parse_paddlex_result_no_bbox_in_parsing_res(self):
    """Without a table entry in parsing_res_list, bbox falls back to zeros."""
    pipeline = MagicMock()
    paddlex_result = {
        "table_res_list": [
            {
                "cell_box_list": [[0, 0, 50, 20]],
                "pred_html": "<table><tr><td>A</td></tr></table>",
                "table_ocr_pred": ["A"],
            }
        ],
        "parsing_res_list": [
            {"label": "text", "bbox": [10, 20, 200, 300]},  # Not a table
        ],
    }
    pipeline.predict.return_value = [paddlex_result]

    detector = TableDetector(pipeline=pipeline)
    tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

    assert len(tables) == 1
    # The default bbox of all zeros is used when no table bbox was found.
    assert tables[0].bbox == (0.0, 0.0, 0.0, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
class TestIteratorResults:
    """Handling of iterator/generator results from the pipeline."""

    def test_handles_iterator_results(self):
        """A generator of results is consumed just like a list."""
        pipeline = MagicMock()

        def yield_results():
            element = MagicMock()
            element.label = "table"
            element.bbox = [0, 0, 100, 100]
            element.html = "<table></table>"
            element.score = 0.9
            element.cells = []
            wrapped = MagicMock(spec=["layout_elements"])
            wrapped.layout_elements = [element]
            yield wrapped

        pipeline.predict.return_value = yield_results()

        detector = TableDetector(pipeline=pipeline)
        tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        assert len(tables) == 1

    def test_handles_failed_iterator_conversion(self):
        """An iterator that blows up during conversion yields an empty result."""
        pipeline = MagicMock()

        class ExplodingIterable:
            # Has __iter__ but fails as soon as it is materialised.
            def __iter__(self):
                raise RuntimeError("Iterator failed")

        pipeline.predict.return_value = ExplodingIterable()

        detector = TableDetector(pipeline=pipeline)
        tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        # Must swallow the error and return an empty list, not raise.
        assert tables == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathConversion:
    """Input path handling."""

    def test_converts_path_object_to_string(self):
        """pathlib.Path inputs are stringified before reaching the pipeline."""
        from pathlib import Path

        pipeline = MagicMock()
        pipeline.predict.return_value = []

        detector = TableDetector(pipeline=pipeline)
        detector.detect(Path("/some/path/to/image.png"))

        # The pipeline must see a plain string, never a Path object.
        pipeline.predict.assert_called_with("/some/path/to/image.png")
|
||||||
|
|
||||||
|
|
||||||
|
class TestHtmlExtraction:
    """HTML extraction from the various element shapes."""

    def test_extracts_html_from_res_dict(self):
        """HTML nested under element.res is used when direct attrs are absent."""
        pipeline = MagicMock()
        element = MagicMock()
        element.label = "table"
        element.bbox = [0, 0, 100, 100]
        element.res = {"html": "<table><tr><td>From res</td></tr></table>"}
        element.score = 0.9
        element.cells = []
        # Strip the direct html attributes so the res fallback is exercised.
        del element.html
        del element.table_html

        wrapped = MagicMock(spec=["layout_elements"])
        wrapped.layout_elements = [element]
        pipeline.predict.return_value = [wrapped]

        detector = TableDetector(pipeline=pipeline)
        tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        assert len(tables) == 1
        assert tables[0].html == "<table><tr><td>From res</td></tr></table>"

    def test_returns_empty_html_when_not_found(self):
        """With no html source anywhere, the html field comes back empty."""
        pipeline = MagicMock()
        element = MagicMock()
        element.label = "table"
        element.bbox = [0, 0, 100, 100]
        element.score = 0.9
        element.cells = []
        # Remove every attribute the extractor could read html from.
        del element.html
        del element.table_html
        del element.res

        wrapped = MagicMock(spec=["layout_elements"])
        wrapped.layout_elements = [element]
        pipeline.predict.return_value = [wrapped]

        detector = TableDetector(pipeline=pipeline)
        tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        assert len(tables) == 1
        assert tables[0].html == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestTableTypeDetection:
    """Classification of table type via _get_table_type."""

    @staticmethod
    def _element_with_label(label):
        """Build a mock layout element carrying the given label."""
        element = MagicMock()
        element.label = label
        return element

    def test_detects_borderless_table(self):
        """A borderless_table label maps to the wireless type."""
        detector = TableDetector()
        element = self._element_with_label("borderless_table")
        assert detector._get_table_type(element) == "wireless"

    def test_detects_wireless_table_label(self):
        """A wireless_table label maps to the wireless type."""
        detector = TableDetector()
        element = self._element_with_label("wireless_table")
        assert detector._get_table_type(element) == "wireless"

    def test_defaults_to_wired_table(self):
        """Plain tables default to the wired type."""
        detector = TableDetector()
        element = self._element_with_label("table")
        assert detector._get_table_type(element) == "wired"

    def test_type_attribute_instead_of_label(self):
        """The type attribute is consulted when label is missing."""
        detector = TableDetector()
        element = MagicMock()
        element.type = "wireless"
        del element.label  # Force the fallback path
        assert detector._get_table_type(element) == "wireless"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineRuntimeError:
    """Runtime failures caused by a missing pipeline."""

    def test_raises_runtime_error_when_pipeline_none(self):
        """detect() raises RuntimeError if the pipeline is still None."""
        detector = TableDetector()
        # Pretend lazy initialisation already ran but produced no pipeline.
        detector._initialized = True
        detector._pipeline = None

        with pytest.raises(RuntimeError) as exc_info:
            detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        assert "not initialized" in str(exc_info.value).lower()
|
||||||
|
|||||||
@@ -142,6 +142,33 @@ class TestTextLineItemsExtractor:
|
|||||||
rows = extractor._group_by_row(elements)
|
rows = extractor._group_by_row(elements)
|
||||||
assert len(rows) == 2
|
assert len(rows) == 2
|
||||||
|
|
||||||
|
def test_group_by_row_varying_heights_uses_average(self, extractor):
    """Row grouping tracks a running average centre as heights vary.

    Recomputing the row centre while new elements are added keeps a tall
    element in its current row instead of spilling it into the next one.
    """
    elements = [
        TextElement(text="Short", bbox=(0, 100, 100, 110)),  # center_y = 105
        TextElement(text="Tall item", bbox=(150, 100, 250, 130)),  # center_y = 115
        TextElement(text="Next row", bbox=(0, 150, 100, 170)),  # center_y = 160
    ]
    grouped = extractor._group_by_row(elements)

    # The dynamic average keeps the first two elements in one row.
    assert len(grouped) == 2
    assert len(grouped[0]) == 2  # Short and Tall item
    assert len(grouped[1]) == 1  # Next row
|
||||||
|
|
||||||
|
def test_group_by_row_empty_input(self, extractor):
|
||||||
|
"""Test grouping with empty input returns empty list."""
|
||||||
|
rows = extractor._group_by_row([])
|
||||||
|
assert rows == []
|
||||||
|
|
||||||
def test_looks_like_line_item_with_amount(self, extractor):
|
def test_looks_like_line_item_with_amount(self, extractor):
|
||||||
"""Test line item detection with amount."""
|
"""Test line item detection with amount."""
|
||||||
row = [
|
row = [
|
||||||
@@ -253,6 +280,67 @@ class TestTextLineItemsExtractor:
|
|||||||
assert len(elements) == 4
|
assert len(elements) == 4
|
||||||
|
|
||||||
|
|
||||||
|
class TestExceptionHandling:
    """Tests for exception handling in text element extraction."""

    @staticmethod
    def _extract(parsing_res):
        """Run _extract_text_elements on *parsing_res* with a fresh extractor."""
        return TextLineItemsExtractor()._extract_text_elements(parsing_res)

    @staticmethod
    def _assert_only_valid(elements):
        """Shared check: exactly one surviving element, the 'Valid' one."""
        assert len(elements) == 1
        assert elements[0].text == "Valid"

    def test_extract_text_elements_handles_missing_bbox(self):
        """Test that missing bbox is handled gracefully."""
        survivors = self._extract(
            [
                {"label": "text", "text": "No bbox"},  # Missing bbox
                {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
            ]
        )
        self._assert_only_valid(survivors)

    def test_extract_text_elements_handles_invalid_bbox(self):
        """Test that invalid bbox (less than 4 values) is handled."""
        survivors = self._extract(
            [
                {"label": "text", "bbox": [0, 100], "text": "Invalid bbox"},  # Only 2 values
                {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
            ]
        )
        self._assert_only_valid(survivors)

    def test_extract_text_elements_handles_none_text(self):
        """Test that None text is handled."""
        survivors = self._extract(
            [
                {"label": "text", "bbox": [0, 100, 200, 120], "text": None},
                {"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
            ]
        )
        self._assert_only_valid(survivors)

    def test_extract_text_elements_handles_empty_string(self):
        """Test that empty string text is skipped."""
        survivors = self._extract(
            [
                {"label": "text", "bbox": [0, 100, 200, 120], "text": ""},
                {"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
            ]
        )
        self._assert_only_valid(survivors)

    def test_extract_text_elements_handles_malformed_element(self):
        """Test that completely malformed elements are handled."""
        survivors = self._extract(
            [
                "not a dict",  # String instead of dict
                123,  # Number instead of dict
                {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
            ]
        )
        self._assert_only_valid(survivors)
|
||||||
|
|
||||||
|
|
||||||
class TestConvertTextLineItem:
|
class TestConvertTextLineItem:
|
||||||
"""Tests for convert_text_line_item function."""
|
"""Tests for convert_text_line_item function."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user