Compare commits

...

2 Commits

Author SHA1 Message Date
Yaojia Wang
0990239e9c feat: add field-specific bbox expansion strategies for YOLO training
Implement center-point based bbox scaling with directional compensation
to capture field labels that typically appear above or to the left of
field values. This improves YOLO training data quality by including
contextual information around field values.

Key changes:
- Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function
- Define field-specific strategies (ocr_number, bankgiro, invoice_date, etc.)
- Support manual_mode for minimal padding (no scaling)
- Integrate expand_bbox into AnnotationGenerator
- Add FIELD_TO_CLASS mapping for field_name to class_name lookup
- Comprehensive tests with 100% coverage (45 tests)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 22:56:52 +01:00
Yaojia Wang
8723ef4653 refactor: split line_items_extractor into smaller modules with comprehensive tests
- Extract models.py (LineItem, LineItemsResult dataclasses)
- Extract html_table_parser.py (ColumnMapper, HtmlTableParser)
- Extract merged_cell_handler.py (MergedCellHandler for PP-StructureV3 merged cells)
- Reduce line_items_extractor.py from 971 to 396 lines
- Add constants for magic numbers (MIN_AMOUNT_THRESHOLD, ROW_GROUPING_THRESHOLD, etc.)
- Fix row grouping algorithm in text_line_items_extractor.py
- Demote INFO logs to DEBUG level in structure_detector.py
- Add 209 tests achieving 85%+ coverage on main modules

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-03 23:02:00 +01:00
24 changed files with 3654 additions and 859 deletions

View File

@@ -0,0 +1,204 @@
"""
HTML Table Parser
Parses HTML tables into structured data and maps columns to field names.
"""
from html.parser import HTMLParser
import logging
logger = logging.getLogger(__name__)
# Configuration constants
# Minimum pattern length to avoid false positives from short substrings
MIN_PATTERN_MATCH_LENGTH = 3
# Exact match bonus for column mapping priority
EXACT_MATCH_BONUS = 100
# Swedish column name mappings
# Extended to support multiple invoice types: product invoices, rental invoices, utility bills
COLUMN_MAPPINGS = {
"article_number": [
"art nummer",
"artikelnummer",
"artikel",
"artnr",
"art.nr",
"art nr",
"objektnummer", # Rental: property reference
"objekt",
],
"description": [
"beskrivning",
"produktbeskrivning",
"produkt",
"tjänst",
"text",
"benämning",
"vara/tjänst",
"vara",
# Rental invoice specific
"specifikation",
"spec",
"hyresperiod", # Rental period
"period",
"typ", # Type of charge
# Utility bills
"förbrukning", # Consumption
"avläsning", # Meter reading
],
"quantity": ["antal", "qty", "st", "pcs", "kvantitet", "", "kvm"],
"unit": ["enhet", "unit"],
"unit_price": ["á-pris", "a-pris", "pris", "styckpris", "enhetspris", "à pris"],
"amount": [
"belopp",
"summa",
"total",
"netto",
"rad summa",
# Rental specific
"hyra", # Rent
"avgift", # Fee
"kostnad", # Cost
"debitering", # Charge
"totalt", # Total
],
"vat_rate": ["moms", "moms%", "vat", "skatt", "moms %"],
# Additional field for rental: deductions/adjustments
"deduction": [
"avdrag", # Deduction
"rabatt", # Discount
"kredit", # Credit
],
}
# Keywords that indicate NOT a line items table
SUMMARY_KEYWORDS = [
"frakt",
"faktura.avg",
"fakturavg",
"exkl.moms",
"att betala",
"öresavr",
"bankgiro",
"plusgiro",
"ocr",
"forfallodatum",
"förfallodatum",
]
class _TableHTMLParser(HTMLParser):
"""Internal HTML parser for tables."""
def __init__(self):
super().__init__()
self.rows: list[list[str]] = []
self.current_row: list[str] = []
self.current_cell: str = ""
self.in_td = False
self.in_thead = False
self.header_row: list[str] = []
def handle_starttag(self, tag, attrs):
if tag == "tr":
self.current_row = []
elif tag in ("td", "th"):
self.in_td = True
self.current_cell = ""
elif tag == "thead":
self.in_thead = True
def handle_endtag(self, tag):
if tag in ("td", "th"):
self.in_td = False
self.current_row.append(self.current_cell.strip())
elif tag == "tr":
if self.current_row:
if self.in_thead:
self.header_row = self.current_row
else:
self.rows.append(self.current_row)
elif tag == "thead":
self.in_thead = False
def handle_data(self, data):
if self.in_td:
self.current_cell += data
class HTMLTableParser:
"""Parse HTML tables into structured data."""
def parse(self, html: str) -> tuple[list[str], list[list[str]]]:
"""
Parse HTML table and return header and rows.
Args:
html: HTML string containing table.
Returns:
Tuple of (header_row, data_rows).
"""
parser = _TableHTMLParser()
parser.feed(html)
return parser.header_row, parser.rows
class ColumnMapper:
"""Map column headers to field names."""
def __init__(self, mappings: dict[str, list[str]] | None = None):
"""
Initialize column mapper.
Args:
mappings: Custom column mappings. Uses Swedish defaults if None.
"""
self.mappings = mappings or COLUMN_MAPPINGS
def map(self, headers: list[str]) -> dict[int, str]:
"""
Map column indices to field names.
Args:
headers: List of column header strings.
Returns:
Dictionary mapping column index to field name.
"""
mapping = {}
for idx, header in enumerate(headers):
normalized = self._normalize(header)
if not normalized.strip():
continue
best_match = None
best_match_len = 0
for field_name, patterns in self.mappings.items():
for pattern in patterns:
if pattern == normalized:
# Exact match gets highest priority
best_match = field_name
best_match_len = len(pattern) + EXACT_MATCH_BONUS
break
elif pattern in normalized and len(pattern) > best_match_len:
# Partial match requires minimum length to avoid false positives
if len(pattern) >= MIN_PATTERN_MATCH_LENGTH:
best_match = field_name
best_match_len = len(pattern)
if best_match_len > EXACT_MATCH_BONUS:
# Found exact match, no need to check other fields
break
if best_match:
mapping[idx] = best_match
return mapping
def _normalize(self, header: str) -> str:
"""Normalize header text for matching."""
return header.lower().strip().replace(".", "").replace("-", " ")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,423 @@
"""
Merged Cell Handler
Handles detection and extraction of data from tables with merged cells,
a common issue with PP-StructureV3 OCR output.
"""
import re
import logging
from typing import TYPE_CHECKING
from .models import LineItem
if TYPE_CHECKING:
from .html_table_parser import ColumnMapper
logger = logging.getLogger(__name__)
# Minimum positive amount to consider as line item (filters noise like row indices)
MIN_AMOUNT_THRESHOLD = 100
class MergedCellHandler:
"""Handles tables with vertically merged cells from PP-StructureV3."""
def __init__(self, mapper: "ColumnMapper"):
"""
Initialize handler.
Args:
mapper: ColumnMapper instance for header keyword detection.
"""
self.mapper = mapper
def has_vertically_merged_cells(self, rows: list[list[str]]) -> bool:
"""
Check if table rows contain vertically merged data in single cells.
PP-StructureV3 sometimes merges multiple table rows into single cells, e.g.:
["Produktnr 1457280 1457280 1060381", "", "Antal 6ST 6ST 1ST", "Pris 127,20 127,20 159,20"]
Detection: cells contain repeating patterns of numbers or keywords suggesting multiple lines.
"""
if not rows:
return False
for row in rows:
for cell in row:
if not cell or len(cell) < 20:
continue
# Check for multiple product numbers (7+ digit patterns)
product_nums = re.findall(r"\b\d{7}\b", cell)
if len(product_nums) >= 2:
logger.debug(f"has_vertically_merged_cells: found {len(product_nums)} product numbers in cell")
return True
# Check for multiple prices (Swedish format: 123,45 or 1 234,56)
prices = re.findall(r"\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b", cell)
if len(prices) >= 3:
logger.debug(f"has_vertically_merged_cells: found {len(prices)} prices in cell")
return True
# Check for multiple quantity patterns (e.g., "6ST 6ST 1ST")
quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", cell)
if len(quantities) >= 2:
logger.debug(f"has_vertically_merged_cells: found {len(quantities)} quantities in cell")
return True
return False
def split_merged_rows(
self, rows: list[list[str]]
) -> tuple[list[str], list[list[str]]]:
"""
Split vertically merged cells back into separate rows.
Handles complex cases where PP-StructureV3 merges content across
multiple HTML rows. For example, 5 line items might be spread across
3 HTML rows with content mixed together.
Strategy:
1. Merge all row content per column
2. Detect how many actual data rows exist (by counting product numbers)
3. Split each column's content into that many lines
Returns header and data rows.
"""
if not rows:
return [], []
# Filter out completely empty rows
non_empty_rows = [r for r in rows if any(cell.strip() for cell in r)]
if not non_empty_rows:
return [], rows
# Determine column count
col_count = max(len(r) for r in non_empty_rows)
# Merge content from all rows for each column
merged_columns = []
for col_idx in range(col_count):
col_content = []
for row in non_empty_rows:
if col_idx < len(row) and row[col_idx].strip():
col_content.append(row[col_idx].strip())
merged_columns.append(" ".join(col_content))
logger.debug(f"split_merged_rows: merged columns = {merged_columns}")
# Count how many actual data rows we should have
# Use the column with most product numbers as reference
expected_rows = self._count_expected_rows(merged_columns)
logger.debug(f"split_merged_rows: expecting {expected_rows} data rows")
if expected_rows <= 1:
# Not enough data for splitting
return [], rows
# Split each column based on expected row count
split_columns = []
for col_idx, col_text in enumerate(merged_columns):
if not col_text.strip():
split_columns.append([""] * (expected_rows + 1)) # +1 for header
continue
lines = self._split_cell_content_for_rows(col_text, expected_rows)
split_columns.append(lines)
# Ensure all columns have same number of lines (immutable approach)
max_lines = max(len(col) for col in split_columns)
split_columns = [
col + [""] * (max_lines - len(col))
for col in split_columns
]
logger.debug(f"split_merged_rows: split into {max_lines} lines total")
# First line is header, rest are data rows
header = [col[0] for col in split_columns]
data_rows = []
for line_idx in range(1, max_lines):
row = [col[line_idx] if line_idx < len(col) else "" for col in split_columns]
if any(cell.strip() for cell in row):
data_rows.append(row)
logger.debug(f"split_merged_rows: header={header}, data_rows count={len(data_rows)}")
return header, data_rows
def _count_expected_rows(self, merged_columns: list[str]) -> int:
"""
Count how many data rows should exist based on content patterns.
Returns the maximum count found from:
- Product numbers (7 digits)
- Quantity patterns (number + ST/PCS)
- Amount patterns (in columns likely to be totals)
"""
max_count = 0
for col_text in merged_columns:
if not col_text:
continue
# Count product numbers (most reliable indicator)
product_nums = re.findall(r"\b\d{7}\b", col_text)
max_count = max(max_count, len(product_nums))
# Count quantities (e.g., "6ST 6ST 1ST 1ST 1ST")
quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", col_text)
max_count = max(max_count, len(quantities))
return max_count
def _split_cell_content_for_rows(self, cell: str, expected_rows: int) -> list[str]:
"""
Split cell content knowing how many data rows we expect.
This is smarter than split_cell_content because it knows the target count.
"""
cell = cell.strip()
# Try product number split first
product_pattern = re.compile(r"(\b\d{7}\b)")
products = product_pattern.findall(cell)
if len(products) == expected_rows:
parts = product_pattern.split(cell)
header = parts[0].strip() if parts else ""
# Include description text after each product number
values = []
for i in range(1, len(parts), 2): # Odd indices are product numbers
if i < len(parts):
prod_num = parts[i].strip()
# Check if there's description text after
desc = parts[i + 1].strip() if i + 1 < len(parts) else ""
# If description looks like text (not another pattern), include it
if desc and not re.match(r"^\d{7}$", desc):
# Truncate at next product number pattern if any
desc_clean = re.split(r"\d{7}", desc)[0].strip()
if desc_clean:
values.append(f"{prod_num} {desc_clean}")
else:
values.append(prod_num)
else:
values.append(prod_num)
if len(values) == expected_rows:
return [header] + values
# Try quantity split
qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
quantities = qty_pattern.findall(cell)
if len(quantities) == expected_rows:
parts = qty_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
if len(values) == expected_rows:
return [header] + values
# Try amount split for discount+totalsumma columns
cell_lower = cell.lower()
has_discount = any(kw in cell_lower for kw in ["rabatt", "discount"])
has_total = any(kw in cell_lower for kw in ["totalsumma", "total", "summa", "belopp"])
if has_discount and has_total:
# Extract only amounts (3+ digit numbers), skip discount percentages
amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
amounts = amount_pattern.findall(cell)
if len(amounts) >= expected_rows:
# Take the last expected_rows amounts (they are likely the totals)
return ["Totalsumma"] + amounts[:expected_rows]
# Try price split
price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
prices = price_pattern.findall(cell)
if len(prices) >= expected_rows:
parts = price_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
if len(values) >= expected_rows:
return [header] + values[:expected_rows]
# Fall back to original single-value behavior
return [cell]
def split_cell_content(self, cell: str) -> list[str]:
"""
Split a cell containing merged multi-line content.
Strategies:
1. Look for product number patterns (7 digits)
2. Look for quantity patterns (number + ST/PCS)
3. Look for price patterns (with decimal)
4. Handle interleaved discount+amount patterns
"""
cell = cell.strip()
# Strategy 1: Split by product numbers (common pattern: "Produktnr 1234567 1234568")
product_pattern = re.compile(r"(\b\d{7}\b)")
products = product_pattern.findall(cell)
if len(products) >= 2:
# Extract header (text before first product number) and values
parts = product_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p for p in parts[1:] if p.strip() and re.match(r"\d{7}", p)]
return [header] + values
# Strategy 2: Split by quantities (e.g., "Antal 6ST 6ST 1ST")
qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
quantities = qty_pattern.findall(cell)
if len(quantities) >= 2:
parts = qty_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
return [header] + values
# Strategy 3: Handle interleaved discount+amount (e.g., "Rabatt i% Totalsumma 10,0 686,88 10,0 686,88")
# Check if header contains two keywords indicating merged columns
cell_lower = cell.lower()
has_discount_header = any(kw in cell_lower for kw in ["rabatt", "discount"])
has_amount_header = any(kw in cell_lower for kw in ["totalsumma", "summa", "belopp", "total"])
if has_discount_header and has_amount_header:
# Extract all numbers and pair them (discount, amount, discount, amount, ...)
# Pattern for amounts: 3+ digit numbers with decimals (e.g., 686,88)
amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
amounts = amount_pattern.findall(cell)
if len(amounts) >= 2:
# Return header as "Totalsumma" (amount header) so it maps to amount field, not deduction
# This avoids the "Rabatt" keyword causing is_deduction=True
header = "Totalsumma"
return [header] + amounts
# Strategy 4: Split by prices (e.g., "Pris 127,20 127,20 159,20")
price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
prices = price_pattern.findall(cell)
if len(prices) >= 2:
parts = price_pattern.split(cell)
header = parts[0].strip() if parts else ""
values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
return [header] + values
# No pattern detected, return as single value
return [cell]
def has_merged_header(self, header: list[str] | None) -> bool:
"""
Check if header appears to be a merged cell containing multiple column names.
This happens when OCR merges table headers into a single cell, e.g.:
"Specifikation 0218103-1201 2 rum och kök Hyra Avdrag" instead of separate columns.
Also handles cases where PP-StructureV3 produces headers like:
["Specifikation ... Hyra Avdrag", "", "", ""] with empty trailing cells.
"""
if header is None or not header:
return False
# Filter out empty cells to find the actual content
non_empty_cells = [h for h in header if h.strip()]
# Check if we have a single non-empty cell that contains multiple keywords
if len(non_empty_cells) == 1:
header_text = non_empty_cells[0].lower()
# Count how many column keywords are in this single cell
keyword_count = 0
for patterns in self.mapper.mappings.values():
for pattern in patterns:
if pattern in header_text:
keyword_count += 1
break # Only count once per field type
logger.debug(f"has_merged_header: header_text='{header_text}', keyword_count={keyword_count}")
return keyword_count >= 2
return False
def extract_from_merged_cells(
self, header: list[str], rows: list[list[str]]
) -> list[LineItem]:
"""
Extract line items from tables with merged cells.
For poorly OCR'd tables like:
Header: ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
Row 1: ["", "", "", "8159"] <- amount row
Row 2: ["", "", "", "-2 000"] <- deduction row (separate line item)
Or:
Row: ["", "", "", "8159 -2 000"] <- both in same row -> 2 line items
Each amount becomes its own line item. Negative amounts are marked as is_deduction=True.
"""
items = []
# Amount pattern for Swedish format - match numbers like "8159" or "8 159" or "-2000" or "-2 000"
amount_pattern = re.compile(
r"(-?\d[\d\s]*(?:[,\.]\d+)?)"
)
# Try to parse header cell for description info
header_text = " ".join(h for h in header if h.strip()) if header else ""
logger.debug(f"extract_from_merged_cells: header_text='{header_text}'")
logger.debug(f"extract_from_merged_cells: rows={rows}")
# Extract description from header
description = None
article_number = None
# Look for object number pattern (e.g., "0218103-1201")
obj_match = re.search(r"(\d{7}-\d{4})", header_text)
if obj_match:
article_number = obj_match.group(1)
# Look for description after object number
desc_match = re.search(r"\d{7}-\d{4}\s+(.+?)(?:\s+(?:Hyra|Avdrag|Belopp))", header_text, re.IGNORECASE)
if desc_match:
description = desc_match.group(1).strip()
row_index = 0
for row in rows:
# Combine all non-empty cells in the row
row_text = " ".join(cell.strip() for cell in row if cell.strip())
logger.debug(f"extract_from_merged_cells: row text='{row_text}'")
if not row_text:
continue
# Find all amounts in the row
amounts = amount_pattern.findall(row_text)
logger.debug(f"extract_from_merged_cells: amounts={amounts}")
for amt_str in amounts:
# Clean the amount string
cleaned = amt_str.replace(" ", "").strip()
if not cleaned or cleaned == "-":
continue
is_deduction = cleaned.startswith("-")
# Skip small positive numbers that are likely not amounts
# (e.g., row indices, small percentages)
if not is_deduction:
try:
val = float(cleaned.replace(",", "."))
if val < MIN_AMOUNT_THRESHOLD:
continue
except ValueError:
continue
# Create a line item for each amount
item = LineItem(
row_index=row_index,
description=description if row_index == 0 else "Avdrag" if is_deduction else None,
article_number=article_number if row_index == 0 else None,
amount=cleaned,
is_deduction=is_deduction,
confidence=0.7,
)
items.append(item)
row_index += 1
logger.debug(f"extract_from_merged_cells: created item amount={cleaned}, is_deduction={is_deduction}")
return items

View File

@@ -0,0 +1,61 @@
"""
Line Items Data Models
Dataclasses for line item extraction results.
"""
from dataclasses import dataclass
from decimal import Decimal, InvalidOperation
@dataclass
class LineItem:
"""Single line item from invoice."""
row_index: int
description: str | None = None
quantity: str | None = None
unit: str | None = None
unit_price: str | None = None
amount: str | None = None
article_number: str | None = None
vat_rate: str | None = None
is_deduction: bool = False # True if this row is a deduction/discount
confidence: float = 0.9
@dataclass
class LineItemsResult:
"""Result of line items extraction."""
items: list[LineItem]
header_row: list[str]
raw_html: str
is_reversed: bool = False
@property
def total_amount(self) -> str | None:
"""Calculate total amount from line items (deduction rows have negative amounts)."""
if not self.items:
return None
total = Decimal("0")
for item in self.items:
if item.amount:
try:
# Parse Swedish number format (1 234,56)
amount_str = item.amount.replace(" ", "").replace(",", ".")
total += Decimal(amount_str)
except InvalidOperation:
pass
if total == 0:
return None
# Format back to Swedish format
formatted = f"{total:,.2f}".replace(",", " ").replace(".", ",")
# Fix the space/comma swap
parts = formatted.rsplit(",", 1)
if len(parts) == 2:
return parts[0].replace(" ", " ") + "," + parts[1]
return formatted

View File

@@ -158,36 +158,36 @@ class TableDetector:
return tables
# Log raw result type for debugging
logger.info(f"PP-StructureV3 raw results type: {type(results).__name__}")
logger.debug(f"PP-StructureV3 raw results type: {type(results).__name__}")
# Handle case where results is a single dict-like object (PaddleX 3.x)
# rather than a list of results
if hasattr(results, "get") and not isinstance(results, list):
# Single result object - wrap in list for uniform processing
logger.info("Results is dict-like, wrapping in list")
logger.debug("Results is dict-like, wrapping in list")
results = [results]
elif hasattr(results, "__iter__") and not isinstance(results, (list, tuple)):
# Iterator or generator - convert to list
try:
results = list(results)
logger.info(f"Converted iterator to list with {len(results)} items")
logger.debug(f"Converted iterator to list with {len(results)} items")
except Exception as e:
logger.warning(f"Failed to convert results to list: {e}")
return tables
logger.info(f"Processing {len(results)} result(s)")
logger.debug(f"Processing {len(results)} result(s)")
for i, result in enumerate(results):
try:
result_type = type(result).__name__
has_get = hasattr(result, "get")
has_layout = hasattr(result, "layout_elements")
logger.info(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
logger.debug(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
# Try PaddleX 3.x API first (dict-like with table_res_list)
if has_get:
parsed = self._parse_paddlex_result(result)
logger.info(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
logger.debug(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
tables.extend(parsed)
continue
@@ -201,14 +201,14 @@ class TableDetector:
if table_result and table_result.confidence >= self.config.min_confidence:
tables.append(table_result)
legacy_count += 1
logger.info(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
logger.debug(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
else:
logger.warning(f"Result[{i}]: no recognized API (not dict-like and no layout_elements)")
except Exception as e:
logger.warning(f"Failed to parse result: {type(result).__name__}, error: {e}")
continue
logger.info(f"Total tables detected: {len(tables)}")
logger.debug(f"Total tables detected: {len(tables)}")
return tables
def _parse_paddlex_result(self, result: Any) -> list[TableDetectionResult]:
@@ -223,7 +223,7 @@ class TableDetector:
result_keys = list(result.keys())
elif hasattr(result, "__dict__"):
result_keys = list(result.__dict__.keys())
logger.info(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
logger.debug(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
# Get table results from PaddleX 3.x API
# Handle both dict.get() and attribute access
@@ -234,8 +234,8 @@ class TableDetector:
table_res_list = getattr(result, "table_res_list", None)
parsing_res_list = getattr(result, "parsing_res_list", [])
logger.info(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
logger.info(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
logger.debug(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
logger.debug(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
if not table_res_list:
# Log available keys/attributes for debugging
@@ -330,7 +330,7 @@ class TableDetector:
# Default confidence for PaddleX 3.x results
confidence = 0.9
logger.info(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
logger.debug(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
tables.append(TableDetectionResult(
bbox=(float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])),
html=html,
@@ -467,14 +467,14 @@ class TableDetector:
if not pdf_path.exists():
raise FileNotFoundError(f"PDF not found: {pdf_path}")
logger.info(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
logger.debug(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
# Render specific page
for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=dpi):
if page_no == page_number:
image = Image.open(io.BytesIO(image_bytes))
image_array = np.array(image)
logger.info(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
logger.debug(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
return self.detect(image_array)
raise ValueError(f"Page {page_number} not found in PDF")

View File

@@ -15,6 +15,11 @@ import logging
logger = logging.getLogger(__name__)
# Configuration constants
DEFAULT_ROW_TOLERANCE = 15.0 # Max vertical distance (pixels) to consider same row
MIN_ITEMS_FOR_VALID_EXTRACTION = 2 # Minimum items required for valid extraction
MIN_TEXT_ELEMENTS_FOR_EXTRACTION = 5 # Minimum text elements needed to attempt extraction
@dataclass
class TextElement:
@@ -65,7 +70,10 @@ class TextLineItemsResult:
extraction_method: str = "text_spatial"
# Swedish amount pattern: 1 234,56 or 1234.56 or 1,234.56
# Amount pattern matches Swedish, US, and simple numeric formats
# Handles: "1 234,56", "1,234.56", "1234.56", "100 kr", "50:-", "-100,00"
# Does NOT handle: amounts with more than 2 decimal places, scientific notation
# See tests in test_text_line_items_extractor.py::TestAmountPattern
AMOUNT_PATTERN = re.compile(
r"(?<![0-9])(?:"
r"-?\d{1,3}(?:\s\d{3})*(?:,\d{2})?" # Swedish: 1 234,56
@@ -128,17 +136,17 @@ class TextLineItemsExtractor:
def __init__(
self,
row_tolerance: float = 15.0, # Max vertical distance to consider same row
min_items_for_valid: int = 2, # Minimum items to consider extraction valid
row_tolerance: float = DEFAULT_ROW_TOLERANCE,
min_items_for_valid: int = MIN_ITEMS_FOR_VALID_EXTRACTION,
):
"""
Initialize extractor.
Args:
row_tolerance: Maximum vertical distance (pixels) between elements
to consider them on the same row.
to consider them on the same row. Default: 15.0
min_items_for_valid: Minimum number of line items required for
extraction to be considered successful.
extraction to be considered successful. Default: 2
"""
self.row_tolerance = row_tolerance
self.min_items_for_valid = min_items_for_valid
@@ -161,10 +169,13 @@ class TextLineItemsExtractor:
# Extract text elements from parsing results
text_elements = self._extract_text_elements(parsing_res_list)
logger.info(f"TextLineItemsExtractor: found {len(text_elements)} text elements")
logger.debug(f"TextLineItemsExtractor: found {len(text_elements)} text elements")
if len(text_elements) < 5: # Need at least a few elements
logger.debug("Too few text elements for line item extraction")
if len(text_elements) < MIN_TEXT_ELEMENTS_FOR_EXTRACTION:
logger.debug(
f"Too few text elements ({len(text_elements)}) for line item extraction, "
f"need at least {MIN_TEXT_ELEMENTS_FOR_EXTRACTION}"
)
return None
return self.extract_from_text_elements(text_elements)
@@ -183,11 +194,11 @@ class TextLineItemsExtractor:
"""
# Group elements by row
rows = self._group_by_row(text_elements)
logger.info(f"TextLineItemsExtractor: grouped into {len(rows)} rows")
logger.debug(f"TextLineItemsExtractor: grouped into {len(rows)} rows")
# Find the line items section
item_rows = self._identify_line_item_rows(rows)
logger.info(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")
logger.debug(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")
if len(item_rows) < self.min_items_for_valid:
logger.debug(f"Found only {len(item_rows)} item rows, need at least {self.min_items_for_valid}")
@@ -195,7 +206,7 @@ class TextLineItemsExtractor:
# Extract structured items
items = self._parse_line_items(item_rows)
logger.info(f"TextLineItemsExtractor: extracted {len(items)} line items")
logger.debug(f"TextLineItemsExtractor: extracted {len(items)} line items")
if len(items) < self.min_items_for_valid:
return None
@@ -209,7 +220,11 @@ class TextLineItemsExtractor:
def _extract_text_elements(
self, parsing_res_list: list[dict[str, Any]]
) -> list[TextElement]:
"""Extract TextElement objects from parsing_res_list."""
"""Extract TextElement objects from parsing_res_list.
Handles both dict and LayoutBlock object formats from PP-StructureV3.
Gracefully skips invalid elements with appropriate logging.
"""
elements = []
for elem in parsing_res_list:
@@ -220,11 +235,15 @@ class TextLineItemsExtractor:
bbox = elem.get("bbox", [])
# Try both 'text' and 'content' keys
text = elem.get("text", "") or elem.get("content", "")
else:
elif hasattr(elem, "label"):
label = getattr(elem, "label", "")
bbox = getattr(elem, "bbox", [])
# LayoutBlock objects use 'content' attribute
text = getattr(elem, "content", "") or getattr(elem, "text", "")
else:
# Element is neither dict nor has expected attributes
logger.debug(f"Skipping element with unexpected type: {type(elem).__name__}")
continue
# Only process text elements (skip images, tables, etc.)
if label not in ("text", "paragraph_title", "aside_text"):
@@ -232,6 +251,7 @@ class TextLineItemsExtractor:
# Validate bbox
if not self._valid_bbox(bbox):
logger.debug(f"Skipping element with invalid bbox: {bbox}")
continue
# Clean text
@@ -250,8 +270,13 @@ class TextLineItemsExtractor:
),
)
)
except (KeyError, TypeError, ValueError, AttributeError) as e:
# Expected format issues - log at debug level
logger.debug(f"Skipping element due to format issue: {e}")
continue
except Exception as e:
logger.debug(f"Failed to parse element: {e}")
# Unexpected errors - log at warning level for visibility
logger.warning(f"Unexpected error parsing element: {type(e).__name__}: {e}")
continue
return elements
@@ -270,6 +295,7 @@ class TextLineItemsExtractor:
Group text elements into rows based on vertical position.
Elements within row_tolerance of each other are considered same row.
Uses dynamic average center_y to handle varying element heights more accurately.
"""
if not elements:
return []
@@ -277,22 +303,22 @@ class TextLineItemsExtractor:
# Sort by vertical position
sorted_elements = sorted(elements, key=lambda e: e.center_y)
rows = []
current_row = [sorted_elements[0]]
current_y = sorted_elements[0].center_y
rows: list[list[TextElement]] = []
current_row: list[TextElement] = [sorted_elements[0]]
for elem in sorted_elements[1:]:
if abs(elem.center_y - current_y) <= self.row_tolerance:
# Same row
# Calculate dynamic average center_y for current row
avg_center_y = sum(e.center_y for e in current_row) / len(current_row)
if abs(elem.center_y - avg_center_y) <= self.row_tolerance:
# Same row - add element and recalculate average on next iteration
current_row.append(elem)
else:
# New row
if current_row:
# Sort row by horizontal position
current_row.sort(key=lambda e: e.center_x)
rows.append(current_row)
# New row - finalize current row
# Sort row by horizontal position (left to right)
current_row.sort(key=lambda e: e.center_x)
rows.append(current_row)
current_row = [elem]
current_y = elem.center_y
# Don't forget last row
if current_row:

View File

@@ -0,0 +1,37 @@
"""
BBox Scale Strategy Module.
Provides field-specific bounding box expansion strategies for YOLO training data.
Expands bboxes using center-point scaling with directional compensation to capture
field labels that typically appear above or to the left of field values.
Two modes are supported:
- Auto-label: Field-specific scale strategies with directional compensation
- Manual-label: Minimal padding only to prevent edge clipping
Usage:
from shared.bbox import expand_bbox, ScaleStrategy, FIELD_SCALE_STRATEGIES
Available exports:
- ScaleStrategy: Dataclass for scale strategy configuration
- DEFAULT_STRATEGY: Default strategy for unknown fields (auto-label)
- MANUAL_LABEL_STRATEGY: Minimal padding strategy for manual labels
- FIELD_SCALE_STRATEGIES: dict[str, ScaleStrategy] - field-specific strategies
- expand_bbox: Function to expand bbox using field-specific strategy
"""
from .scale_strategy import (
ScaleStrategy,
DEFAULT_STRATEGY,
MANUAL_LABEL_STRATEGY,
FIELD_SCALE_STRATEGIES,
)
from .expander import expand_bbox
__all__ = [
"ScaleStrategy",
"DEFAULT_STRATEGY",
"MANUAL_LABEL_STRATEGY",
"FIELD_SCALE_STRATEGIES",
"expand_bbox",
]

View File

@@ -0,0 +1,101 @@
"""
BBox Expander Module.
Provides functions to expand bounding boxes using field-specific strategies.
Expansion is center-point based with directional compensation.
Two modes:
- Auto-label (default): Field-specific scale strategies
- Manual-label: Minimal padding only to prevent edge clipping
"""
from .scale_strategy import (
ScaleStrategy,
DEFAULT_STRATEGY,
MANUAL_LABEL_STRATEGY,
FIELD_SCALE_STRATEGIES,
)
def expand_bbox(
bbox: tuple[float, float, float, float],
image_width: float,
image_height: float,
field_type: str,
strategies: dict[str, ScaleStrategy] | None = None,
manual_mode: bool = False,
) -> tuple[int, int, int, int]:
"""
Expand bbox using field-specific scale strategy.
The expansion follows these steps:
1. Scale bbox around center point (scale_x, scale_y)
2. Apply directional compensation (extra_*_ratio)
3. Clamp expansion to max_pad limits
4. Clamp to image boundaries
Args:
bbox: (x0, y0, x1, y1) in pixels
image_width: Image width for boundary clamping
image_height: Image height for boundary clamping
field_type: Field class_name (e.g., "ocr_number")
strategies: Custom strategies dict, defaults to FIELD_SCALE_STRATEGIES
manual_mode: If True, use MANUAL_LABEL_STRATEGY (minimal padding only)
Returns:
Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds
"""
x0, y0, x1, y1 = bbox
w = x1 - x0
h = y1 - y0
# Get strategy based on mode
if manual_mode:
strategy = MANUAL_LABEL_STRATEGY
elif strategies is None:
strategy = FIELD_SCALE_STRATEGIES.get(field_type, DEFAULT_STRATEGY)
else:
strategy = strategies.get(field_type, DEFAULT_STRATEGY)
# Step 1: Scale around center point
cx = (x0 + x1) / 2
cy = (y0 + y1) / 2
new_w = w * strategy.scale_x
new_h = h * strategy.scale_y
nx0 = cx - new_w / 2
nx1 = cx + new_w / 2
ny0 = cy - new_h / 2
ny1 = cy + new_h / 2
# Step 2: Apply directional compensation
nx0 -= w * strategy.extra_left_ratio
nx1 += w * strategy.extra_right_ratio
ny0 -= h * strategy.extra_top_ratio
ny1 += h * strategy.extra_bottom_ratio
# Step 3: Clamp expansion to max_pad limits (preserve asymmetry)
left_pad = min(x0 - nx0, strategy.max_pad_x)
right_pad = min(nx1 - x1, strategy.max_pad_x)
top_pad = min(y0 - ny0, strategy.max_pad_y)
bottom_pad = min(ny1 - y1, strategy.max_pad_y)
# Ensure pads are non-negative (in case of contraction)
left_pad = max(0, left_pad)
right_pad = max(0, right_pad)
top_pad = max(0, top_pad)
bottom_pad = max(0, bottom_pad)
nx0 = x0 - left_pad
nx1 = x1 + right_pad
ny0 = y0 - top_pad
ny1 = y1 + bottom_pad
# Step 4: Clamp to image boundaries
nx0 = max(0, int(nx0))
ny0 = max(0, int(ny0))
nx1 = min(int(image_width), int(nx1))
ny1 = min(int(image_height), int(ny1))
return (nx0, ny0, nx1, ny1)

View File

@@ -0,0 +1,140 @@
"""
Scale Strategy Configuration.
Defines field-specific bbox expansion strategies for YOLO training data.
Each strategy controls how bboxes are expanded around field values to
capture contextual information like labels.
"""
from dataclasses import dataclass
from typing import Final
@dataclass(frozen=True)
class ScaleStrategy:
"""Immutable scale strategy for bbox expansion.
Attributes:
scale_x: Horizontal scale factor (1.0 = no scaling)
scale_y: Vertical scale factor (1.0 = no scaling)
extra_top_ratio: Additional expansion ratio towards top (for labels above)
extra_bottom_ratio: Additional expansion ratio towards bottom
extra_left_ratio: Additional expansion ratio towards left (for prefixes)
extra_right_ratio: Additional expansion ratio towards right (for suffixes)
max_pad_x: Maximum horizontal padding in pixels
max_pad_y: Maximum vertical padding in pixels
"""
scale_x: float = 1.15
scale_y: float = 1.15
extra_top_ratio: float = 0.0
extra_bottom_ratio: float = 0.0
extra_left_ratio: float = 0.0
extra_right_ratio: float = 0.0
max_pad_x: int = 50
max_pad_y: int = 50
# Default strategy for unknown fields (auto-label mode)
DEFAULT_STRATEGY: Final[ScaleStrategy] = ScaleStrategy()
# Manual label strategy - minimal padding to prevent edge clipping
# No scaling, no directional compensation, just small uniform padding
MANUAL_LABEL_STRATEGY: Final[ScaleStrategy] = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_top_ratio=0.0,
extra_bottom_ratio=0.0,
extra_left_ratio=0.0,
extra_right_ratio=0.0,
max_pad_x=10, # Small padding to prevent edge loss
max_pad_y=10,
)
# Field-specific strategies based on Swedish invoice field characteristics
# Field labels typically appear above or to the left of values
FIELD_SCALE_STRATEGIES: Final[dict[str, ScaleStrategy]] = {
# OCR number - label "OCR" or "Referens" typically above
"ocr_number": ScaleStrategy(
scale_x=1.15,
scale_y=1.80,
extra_top_ratio=0.60,
max_pad_x=50,
max_pad_y=140,
),
# Bankgiro - prefix "Bankgiro:" or "BG:" typically to the left
"bankgiro": ScaleStrategy(
scale_x=1.45,
scale_y=1.35,
extra_left_ratio=0.80,
max_pad_x=160,
max_pad_y=90,
),
# Plusgiro - prefix "Plusgiro:" or "PG:" typically to the left
"plusgiro": ScaleStrategy(
scale_x=1.45,
scale_y=1.35,
extra_left_ratio=0.80,
max_pad_x=160,
max_pad_y=90,
),
# Invoice date - label "Fakturadatum" typically above
"invoice_date": ScaleStrategy(
scale_x=1.25,
scale_y=1.55,
extra_top_ratio=0.40,
max_pad_x=80,
max_pad_y=110,
),
# Due date - label "Forfalldatum" typically above, sometimes left
"invoice_due_date": ScaleStrategy(
scale_x=1.30,
scale_y=1.65,
extra_top_ratio=0.45,
extra_left_ratio=0.35,
max_pad_x=100,
max_pad_y=120,
),
# Amount - currency symbol "SEK" or "kr" may be to the right
"amount": ScaleStrategy(
scale_x=1.20,
scale_y=1.35,
extra_right_ratio=0.30,
max_pad_x=70,
max_pad_y=80,
),
# Invoice number - label "Fakturanummer" typically above
"invoice_number": ScaleStrategy(
scale_x=1.20,
scale_y=1.50,
extra_top_ratio=0.40,
max_pad_x=80,
max_pad_y=100,
),
# Supplier org number - label "Org.nr" typically above or left
"supplier_org_number": ScaleStrategy(
scale_x=1.25,
scale_y=1.40,
extra_top_ratio=0.30,
extra_left_ratio=0.20,
max_pad_x=90,
max_pad_y=90,
),
# Customer number - label "Kundnummer" typically above or left
"customer_number": ScaleStrategy(
scale_x=1.25,
scale_y=1.45,
extra_top_ratio=0.35,
extra_left_ratio=0.25,
max_pad_x=90,
max_pad_y=100,
),
# Payment line - machine-readable code, minimal expansion needed
"payment_line": ScaleStrategy(
scale_x=1.10,
scale_y=1.20,
max_pad_x=40,
max_pad_y=40,
),
}

View File

@@ -16,6 +16,7 @@ Available exports:
- FIELD_CLASSES: dict[int, str] - class_id to class_name
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
- FIELD_TO_CLASS: dict[str, str] - field_name to class_name
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
@@ -27,6 +28,7 @@ from .mappings import (
FIELD_CLASSES,
FIELD_CLASS_IDS,
CLASS_TO_FIELD,
FIELD_TO_CLASS,
CSV_TO_CLASS_MAPPING,
TRAINING_FIELD_CLASSES,
ACCOUNT_FIELD_MAPPING,
@@ -40,6 +42,7 @@ __all__ = [
"FIELD_CLASSES",
"FIELD_CLASS_IDS",
"CLASS_TO_FIELD",
"FIELD_TO_CLASS",
"CSV_TO_CLASS_MAPPING",
"TRAINING_FIELD_CLASSES",
"ACCOUNT_FIELD_MAPPING",

View File

@@ -47,6 +47,12 @@ TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
}
# field_name -> class_name mapping (reverse of CLASS_TO_FIELD)
# Example: {"InvoiceNumber": "invoice_number", "OCR": "ocr_number", ...}
FIELD_TO_CLASS: Final[dict[str, str]] = {
fd.field_name: fd.class_name for fd in FIELD_DEFINITIONS
}
# Account field mapping for supplier_accounts special handling
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {

View File

@@ -2,6 +2,7 @@
YOLO Annotation Generator
Generates YOLO format annotations from matched fields.
Uses field-specific bbox expansion strategies for optimal training data.
"""
from dataclasses import dataclass
@@ -14,7 +15,9 @@ from shared.fields import (
TRAINING_FIELD_CLASSES as FIELD_CLASSES,
CLASS_NAMES,
ACCOUNT_FIELD_MAPPING,
FIELD_TO_CLASS,
)
from shared.bbox import expand_bbox
@dataclass
@@ -38,19 +41,16 @@ class AnnotationGenerator:
def __init__(
self,
min_confidence: float = 0.7,
bbox_padding_px: int = 20, # Absolute padding in pixels
min_bbox_height_px: int = 30 # Minimum bbox height
min_bbox_height_px: int = 30, # Minimum bbox height
):
"""
Initialize annotation generator.
Args:
min_confidence: Minimum match score to include in training
bbox_padding_px: Absolute padding in pixels to add around bboxes
min_bbox_height_px: Minimum bbox height in pixels
"""
self.min_confidence = min_confidence
self.bbox_padding_px = bbox_padding_px
self.min_bbox_height_px = min_bbox_height_px
def generate_from_matches(
@@ -63,6 +63,10 @@ class AnnotationGenerator:
"""
Generate YOLO annotations from field matches.
Uses field-specific bbox expansion strategies for optimal training data.
Each field type has customized scale factors and directional compensation
to capture field labels and context.
Args:
matches: Dict of field_name -> list of Match objects
image_width: Width of the rendered image in pixels
@@ -82,6 +86,8 @@ class AnnotationGenerator:
continue
class_id = FIELD_CLASSES[field_name]
# Get class_name for bbox expansion strategy
class_name = FIELD_TO_CLASS.get(field_name, field_name)
# Take only the best match per field
if field_matches:
@@ -94,19 +100,20 @@ class AnnotationGenerator:
x0, y0, x1, y1 = best_match.bbox
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
# Add absolute padding
pad = self.bbox_padding_px
x0 = max(0, x0 - pad)
y0 = max(0, y0 - pad)
x1 = min(image_width, x1 + pad)
y1 = min(image_height, y1 + pad)
# Apply field-specific bbox expansion strategy
x0, y0, x1, y1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=image_width,
image_height=image_height,
field_type=class_name,
)
# Ensure minimum height
current_height = y1 - y0
if current_height < self.min_bbox_height_px:
extra = (self.min_bbox_height_px - current_height) / 2
y0 = max(0, y0 - extra)
y1 = min(image_height, y1 + extra)
y0 = max(0, int(y0 - extra))
y1 = min(int(image_height), int(y1 + extra))
# Convert to YOLO format (normalized center + size)
x_center = (x0 + x1) / 2 / image_width
@@ -143,6 +150,9 @@ class AnnotationGenerator:
"""
Add payment_line annotation from machine code parser result.
Uses "payment_line" scale strategy for minimal expansion
(machine-readable code needs less context).
Args:
annotations: Existing list of annotations to append to
payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
@@ -163,12 +173,13 @@ class AnnotationGenerator:
x0, y0, x1, y1 = payment_line_bbox
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
# Add absolute padding
pad = self.bbox_padding_px
x0 = max(0, x0 - pad)
y0 = max(0, y0 - pad)
x1 = min(image_width, x1 + pad)
y1 = min(image_height, y1 + pad)
# Apply field-specific bbox expansion strategy for payment_line
x0, y0, x1, y1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=image_width,
image_height=image_height,
field_type="payment_line",
)
# Convert to YOLO format (normalized center + size)
x_center = (x0 + x1) / 2 / image_width

View File

@@ -0,0 +1 @@
"""Tests for shared.bbox module."""

View File

@@ -0,0 +1,556 @@
"""
Tests for expand_bbox function.
Tests verify that bbox expansion works correctly with center-point scaling,
directional compensation, max padding clamping, and image boundary handling.
"""
import pytest
from shared.bbox import (
expand_bbox,
ScaleStrategy,
FIELD_SCALE_STRATEGIES,
DEFAULT_STRATEGY,
)
class TestExpandBboxCenterScaling:
"""Tests for center-point based scaling."""
def test_center_scaling_expands_symmetrically(self):
"""Verify bbox expands symmetrically around center when no extra ratios."""
# 100x50 bbox at (100, 200)
bbox = (100, 200, 200, 250)
strategy = ScaleStrategy(
scale_x=1.2, # 20% wider
scale_y=1.4, # 40% taller
max_pad_x=1000, # Large to avoid clamping
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Original: width=100, height=50
# New: width=120, height=70
# Center: (150, 225)
# Expected: x0=150-60=90, x1=150+60=210, y0=225-35=190, y1=225+35=260
assert result[0] == 90 # x0
assert result[1] == 190 # y0
assert result[2] == 210 # x1
assert result[3] == 260 # y1
def test_no_scaling_returns_original(self):
"""Verify scale=1.0 with no extras returns original bbox."""
bbox = (100, 200, 200, 250)
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
assert result == (100, 200, 200, 250)
class TestExpandBboxDirectionalCompensation:
"""Tests for directional compensation (extra ratios)."""
def test_extra_top_expands_upward(self):
"""Verify extra_top_ratio adds expansion toward top."""
bbox = (100, 200, 200, 250) # width=100, height=50
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_top_ratio=0.5, # Add 50% of height to top
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# extra_top = 50 * 0.5 = 25
assert result[0] == 100 # x0 unchanged
assert result[1] == 175 # y0 = 200 - 25
assert result[2] == 200 # x1 unchanged
assert result[3] == 250 # y1 unchanged
def test_extra_left_expands_leftward(self):
"""Verify extra_left_ratio adds expansion toward left."""
bbox = (100, 200, 200, 250) # width=100
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_left_ratio=0.8, # Add 80% of width to left
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# extra_left = 100 * 0.8 = 80
assert result[0] == 20 # x0 = 100 - 80
assert result[1] == 200 # y0 unchanged
assert result[2] == 200 # x1 unchanged
assert result[3] == 250 # y1 unchanged
def test_extra_right_expands_rightward(self):
"""Verify extra_right_ratio adds expansion toward right."""
bbox = (100, 200, 200, 250) # width=100
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_right_ratio=0.3, # Add 30% of width to right
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# extra_right = 100 * 0.3 = 30
assert result[0] == 100 # x0 unchanged
assert result[1] == 200 # y0 unchanged
assert result[2] == 230 # x1 = 200 + 30
assert result[3] == 250 # y1 unchanged
def test_extra_bottom_expands_downward(self):
"""Verify extra_bottom_ratio adds expansion toward bottom."""
bbox = (100, 200, 200, 250) # height=50
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_bottom_ratio=0.4, # Add 40% of height to bottom
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# extra_bottom = 50 * 0.4 = 20
assert result[0] == 100 # x0 unchanged
assert result[1] == 200 # y0 unchanged
assert result[2] == 200 # x1 unchanged
assert result[3] == 270 # y1 = 250 + 20
def test_combined_scaling_and_directional(self):
"""Verify scale + directional compensation work together."""
bbox = (100, 200, 200, 250) # width=100, height=50
strategy = ScaleStrategy(
scale_x=1.2, # 20% wider -> 120 width
scale_y=1.0, # no height change
extra_left_ratio=0.5, # Add 50% of width to left
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Center: x=150
# After scale: width=120 -> x0=150-60=90, x1=150+60=210
# After extra_left: x0 = 90 - (100 * 0.5) = 40
assert result[0] == 40 # x0
assert result[2] == 210 # x1
class TestExpandBboxMaxPadClamping:
"""Tests for max padding clamping."""
def test_max_pad_x_limits_horizontal_expansion(self):
"""Verify max_pad_x limits expansion on left and right."""
bbox = (100, 200, 200, 250) # width=100
strategy = ScaleStrategy(
scale_x=2.0, # Double width (would add 50 each side)
scale_y=1.0,
max_pad_x=30, # Limit to 30 pixels each side
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Scale would make: x0=100, x1=200 -> x0=50, x1=250 (50px each side)
# But max_pad_x=30 limits to: x0=70, x1=230
assert result[0] == 70 # x0 = 100 - 30
assert result[2] == 230 # x1 = 200 + 30
def test_max_pad_y_limits_vertical_expansion(self):
"""Verify max_pad_y limits expansion on top and bottom."""
bbox = (100, 200, 200, 250) # height=50
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=3.0, # Triple height (would add 50 each side)
max_pad_x=1000,
max_pad_y=20, # Limit to 20 pixels each side
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Scale would make: y0=175, y1=275 (50px each side)
# But max_pad_y=20 limits to: y0=180, y1=270
assert result[1] == 180 # y0 = 200 - 20
assert result[3] == 270 # y1 = 250 + 20
def test_max_pad_preserves_asymmetry(self):
"""Verify max_pad clamping preserves asymmetric expansion."""
bbox = (100, 200, 200, 250) # width=100
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_left_ratio=1.0, # 100px left expansion
extra_right_ratio=0.0, # No right expansion
max_pad_x=50, # Limit to 50 pixels
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Left would expand 100, clamped to 50
# Right stays at 0
assert result[0] == 50 # x0 = 100 - 50
assert result[2] == 200 # x1 unchanged
class TestExpandBboxImageBoundaryClamping:
"""Tests for image boundary clamping."""
def test_clamps_to_left_boundary(self):
"""Verify x0 is clamped to 0."""
bbox = (10, 200, 110, 250) # Close to left edge
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_left_ratio=0.5, # Would push x0 below 0
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
assert result[0] == 0 # Clamped to 0
def test_clamps_to_top_boundary(self):
"""Verify y0 is clamped to 0."""
bbox = (100, 10, 200, 60) # Close to top edge
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_top_ratio=0.5, # Would push y0 below 0
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
assert result[1] == 0 # Clamped to 0
def test_clamps_to_right_boundary(self):
"""Verify x1 is clamped to image_width."""
bbox = (900, 200, 990, 250) # Close to right edge
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_right_ratio=0.5, # Would push x1 beyond image_width
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
assert result[2] == 1000 # Clamped to image_width
def test_clamps_to_bottom_boundary(self):
"""Verify y1 is clamped to image_height."""
bbox = (100, 940, 200, 990) # Close to bottom edge
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_bottom_ratio=0.5, # Would push y1 beyond image_height
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
assert result[3] == 1000 # Clamped to image_height
class TestExpandBboxUnknownField:
"""Tests for unknown field handling."""
def test_unknown_field_uses_default_strategy(self):
"""Verify unknown field types use DEFAULT_STRATEGY."""
bbox = (100, 200, 200, 250)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="unknown_field_xyz",
)
# DEFAULT_STRATEGY: scale_x=1.15, scale_y=1.15
# Original: width=100, height=50
# New: width=115, height=57.5
# Center: (150, 225)
# x0 = 150 - 57.5 = 92.5 -> 92
# x1 = 150 + 57.5 = 207.5 -> 207
# y0 = 225 - 28.75 = 196.25 -> 196
# y1 = 225 + 28.75 = 253.75 -> 253
# But max_pad_x=50 may clamp...
# Left pad = 100 - 92.5 = 7.5 (< 50, ok)
# Right pad = 207.5 - 200 = 7.5 (< 50, ok)
assert result[0] == 92
assert result[2] == 207
class TestExpandBboxWithRealStrategies:
"""Tests using actual FIELD_SCALE_STRATEGIES."""
def test_ocr_number_expands_significantly_upward(self):
"""Verify ocr_number field gets significant upward expansion."""
bbox = (100, 200, 200, 230) # Small height=30
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="ocr_number",
)
# extra_top_ratio=0.60 -> 30 * 0.6 = 18 extra top
# y0 should decrease significantly
assert result[1] < 200 - 10 # At least 10px upward expansion
def test_bankgiro_expands_significantly_leftward(self):
"""Verify bankgiro field gets significant leftward expansion."""
bbox = (200, 200, 300, 230) # width=100
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro",
)
# extra_left_ratio=0.80 -> 100 * 0.8 = 80 extra left
# x0 should decrease significantly
assert result[0] < 200 - 30 # At least 30px leftward expansion
def test_amount_expands_rightward(self):
"""Verify amount field gets rightward expansion for currency."""
bbox = (100, 200, 200, 230) # width=100
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="amount",
)
# extra_right_ratio=0.30 -> 100 * 0.3 = 30 extra right
# x1 should increase
assert result[2] > 200 + 10 # At least 10px rightward expansion
class TestExpandBboxReturnType:
"""Tests for return type and value format."""
def test_returns_tuple_of_four_ints(self):
"""Verify return type is tuple of 4 integers."""
bbox = (100.5, 200.3, 200.7, 250.9)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="invoice_number",
)
assert isinstance(result, tuple)
assert len(result) == 4
assert all(isinstance(v, int) for v in result)
def test_returns_valid_bbox_format(self):
"""Verify returned bbox has x0 < x1 and y0 < y1."""
bbox = (100, 200, 200, 250)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="invoice_number",
)
x0, y0, x1, y1 = result
assert x0 < x1, "x0 should be less than x1"
assert y0 < y1, "y0 should be less than y1"
class TestManualLabelMode:
"""Tests for manual_mode parameter."""
def test_manual_mode_uses_minimal_padding(self):
"""Verify manual_mode uses MANUAL_LABEL_STRATEGY with minimal padding."""
bbox = (100, 200, 200, 250) # width=100, height=50
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro", # Would normally expand left significantly
manual_mode=True,
)
# MANUAL_LABEL_STRATEGY: scale=1.0, max_pad=10
# Should only add 10px padding each side (but scale=1.0 means no scaling)
# Actually with scale=1.0, no extra ratios, we get 0 expansion from scaling
# Only max_pad=10 applies as a limit, but there's no expansion to limit
# So result should be same as original
assert result == (100, 200, 200, 250)
def test_manual_mode_ignores_field_type(self):
"""Verify manual_mode ignores field-specific strategies."""
bbox = (100, 200, 200, 250)
# Different fields should give same result in manual_mode
result_bankgiro = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro",
manual_mode=True,
)
result_ocr = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="ocr_number",
manual_mode=True,
)
assert result_bankgiro == result_ocr
def test_manual_mode_vs_auto_mode_different(self):
"""Verify manual_mode produces different results than auto mode."""
bbox = (100, 200, 200, 250)
auto_result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro", # Has extra_left_ratio=0.80
manual_mode=False,
)
manual_result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro",
manual_mode=True,
)
# Auto mode should expand more (especially to the left for bankgiro)
assert auto_result[0] < manual_result[0] # Auto x0 is more left
def test_manual_mode_clamps_to_image_bounds(self):
"""Verify manual_mode still respects image boundaries."""
bbox = (5, 5, 50, 50) # Close to top-left corner
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test",
manual_mode=True,
)
# Should clamp to 0
assert result[0] >= 0
assert result[1] >= 0

View File

@@ -0,0 +1,192 @@
"""
Tests for ScaleStrategy configuration.
Tests verify that scale strategies are properly defined, immutable,
and cover all required fields.
"""
import pytest
from shared.bbox import (
ScaleStrategy,
DEFAULT_STRATEGY,
MANUAL_LABEL_STRATEGY,
FIELD_SCALE_STRATEGIES,
)
from shared.fields import CLASS_NAMES
class TestScaleStrategyDataclass:
"""Tests for ScaleStrategy dataclass behavior."""
def test_default_strategy_values(self):
"""Verify default strategy has expected default values."""
strategy = ScaleStrategy()
assert strategy.scale_x == 1.15
assert strategy.scale_y == 1.15
assert strategy.extra_top_ratio == 0.0
assert strategy.extra_bottom_ratio == 0.0
assert strategy.extra_left_ratio == 0.0
assert strategy.extra_right_ratio == 0.0
assert strategy.max_pad_x == 50
assert strategy.max_pad_y == 50
def test_scale_strategy_immutability(self):
"""Verify ScaleStrategy is frozen (immutable)."""
strategy = ScaleStrategy()
with pytest.raises(AttributeError):
strategy.scale_x = 2.0 # type: ignore
def test_custom_strategy_values(self):
"""Verify custom values are properly set."""
strategy = ScaleStrategy(
scale_x=1.5,
scale_y=1.8,
extra_top_ratio=0.6,
extra_left_ratio=0.8,
max_pad_x=100,
max_pad_y=150,
)
assert strategy.scale_x == 1.5
assert strategy.scale_y == 1.8
assert strategy.extra_top_ratio == 0.6
assert strategy.extra_left_ratio == 0.8
assert strategy.max_pad_x == 100
assert strategy.max_pad_y == 150
class TestDefaultStrategy:
"""Tests for DEFAULT_STRATEGY constant."""
def test_default_strategy_is_scale_strategy(self):
"""Verify DEFAULT_STRATEGY is a ScaleStrategy instance."""
assert isinstance(DEFAULT_STRATEGY, ScaleStrategy)
def test_default_strategy_matches_default_values(self):
"""Verify DEFAULT_STRATEGY has same values as ScaleStrategy()."""
expected = ScaleStrategy()
assert DEFAULT_STRATEGY == expected
class TestManualLabelStrategy:
"""Tests for MANUAL_LABEL_STRATEGY constant."""
def test_manual_label_strategy_is_scale_strategy(self):
"""Verify MANUAL_LABEL_STRATEGY is a ScaleStrategy instance."""
assert isinstance(MANUAL_LABEL_STRATEGY, ScaleStrategy)
def test_manual_label_strategy_has_no_scaling(self):
"""Verify MANUAL_LABEL_STRATEGY has scale factors of 1.0."""
assert MANUAL_LABEL_STRATEGY.scale_x == 1.0
assert MANUAL_LABEL_STRATEGY.scale_y == 1.0
def test_manual_label_strategy_has_no_directional_expansion(self):
"""Verify MANUAL_LABEL_STRATEGY has no directional expansion."""
assert MANUAL_LABEL_STRATEGY.extra_top_ratio == 0.0
assert MANUAL_LABEL_STRATEGY.extra_bottom_ratio == 0.0
assert MANUAL_LABEL_STRATEGY.extra_left_ratio == 0.0
assert MANUAL_LABEL_STRATEGY.extra_right_ratio == 0.0
def test_manual_label_strategy_has_small_max_pad(self):
"""Verify MANUAL_LABEL_STRATEGY has small max padding."""
assert MANUAL_LABEL_STRATEGY.max_pad_x <= 15
assert MANUAL_LABEL_STRATEGY.max_pad_y <= 15
class TestFieldScaleStrategies:
"""Tests for FIELD_SCALE_STRATEGIES dictionary."""
def test_all_class_names_have_strategies(self):
"""Verify all field class names have defined strategies."""
for class_name in CLASS_NAMES:
assert class_name in FIELD_SCALE_STRATEGIES, (
f"Missing strategy for field: {class_name}"
)
def test_strategies_are_scale_strategy_instances(self):
"""Verify all strategies are ScaleStrategy instances."""
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
assert isinstance(strategy, ScaleStrategy), (
f"Strategy for {field_name} is not a ScaleStrategy"
)
def test_scale_values_are_greater_than_one(self):
"""Verify all scale values are >= 1.0 (expansion, not contraction)."""
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
assert strategy.scale_x >= 1.0, (
f"{field_name} scale_x should be >= 1.0"
)
assert strategy.scale_y >= 1.0, (
f"{field_name} scale_y should be >= 1.0"
)
def test_extra_ratios_are_non_negative(self):
"""Verify all extra ratios are >= 0."""
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
assert strategy.extra_top_ratio >= 0, (
f"{field_name} extra_top_ratio should be >= 0"
)
assert strategy.extra_bottom_ratio >= 0, (
f"{field_name} extra_bottom_ratio should be >= 0"
)
assert strategy.extra_left_ratio >= 0, (
f"{field_name} extra_left_ratio should be >= 0"
)
assert strategy.extra_right_ratio >= 0, (
f"{field_name} extra_right_ratio should be >= 0"
)
def test_max_pad_values_are_positive(self):
"""Verify all max_pad values are > 0."""
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
assert strategy.max_pad_x > 0, (
f"{field_name} max_pad_x should be > 0"
)
assert strategy.max_pad_y > 0, (
f"{field_name} max_pad_y should be > 0"
)
class TestSpecificFieldStrategies:
"""Tests for specific field strategy configurations."""
def test_ocr_number_expands_upward(self):
"""Verify ocr_number strategy expands upward to capture label."""
strategy = FIELD_SCALE_STRATEGIES["ocr_number"]
assert strategy.extra_top_ratio > 0.0
assert strategy.extra_top_ratio >= 0.5 # Significant upward expansion
def test_bankgiro_expands_leftward(self):
"""Verify bankgiro strategy expands leftward to capture prefix."""
strategy = FIELD_SCALE_STRATEGIES["bankgiro"]
assert strategy.extra_left_ratio > 0.0
assert strategy.extra_left_ratio >= 0.5 # Significant leftward expansion
def test_plusgiro_expands_leftward(self):
"""Verify plusgiro strategy expands leftward to capture prefix."""
strategy = FIELD_SCALE_STRATEGIES["plusgiro"]
assert strategy.extra_left_ratio > 0.0
assert strategy.extra_left_ratio >= 0.5
def test_amount_expands_rightward(self):
"""Verify amount strategy expands rightward for currency symbol."""
strategy = FIELD_SCALE_STRATEGIES["amount"]
assert strategy.extra_right_ratio > 0.0
def test_invoice_date_expands_upward(self):
"""Verify invoice_date strategy expands upward to capture label."""
strategy = FIELD_SCALE_STRATEGIES["invoice_date"]
assert strategy.extra_top_ratio > 0.0
def test_invoice_due_date_expands_upward_and_leftward(self):
"""Verify invoice_due_date strategy expands both up and left."""
strategy = FIELD_SCALE_STRATEGIES["invoice_due_date"]
assert strategy.extra_top_ratio > 0.0
assert strategy.extra_left_ratio > 0.0
def test_payment_line_has_minimal_expansion(self):
"""Verify payment_line has conservative expansion (machine code)."""
strategy = FIELD_SCALE_STRATEGIES["payment_line"]
# Payment line is machine-readable, needs minimal expansion
assert strategy.scale_x <= 1.2
assert strategy.scale_y <= 1.3

View File

@@ -16,6 +16,7 @@ from shared.fields import (
FIELD_CLASSES,
FIELD_CLASS_IDS,
CLASS_TO_FIELD,
FIELD_TO_CLASS,
CSV_TO_CLASS_MAPPING,
TRAINING_FIELD_CLASSES,
NUM_CLASSES,
@@ -133,6 +134,20 @@ class TestMappingConsistency:
assert fd.field_name in TRAINING_FIELD_CLASSES
assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id
def test_field_to_class_is_inverse_of_class_to_field(self):
"""Verify FIELD_TO_CLASS and CLASS_TO_FIELD are proper inverses."""
for class_name, field_name in CLASS_TO_FIELD.items():
assert FIELD_TO_CLASS[field_name] == class_name
for field_name, class_name in FIELD_TO_CLASS.items():
assert CLASS_TO_FIELD[class_name] == field_name
def test_field_to_class_has_all_fields(self):
"""Verify FIELD_TO_CLASS has mapping for all field names."""
for fd in FIELD_DEFINITIONS:
assert fd.field_name in FIELD_TO_CLASS
assert FIELD_TO_CLASS[fd.field_name] == fd.class_name
class TestSpecificFieldDefinitions:
"""Tests for specific field definitions to catch common mistakes."""

View File

@@ -272,12 +272,12 @@ class TestLineItemsExtractorFromPdf:
extractor = LineItemsExtractor()
# Create mock table detection result
# Create mock table detection result with proper thead/tbody structure
mock_table = MagicMock(spec=TableDetectionResult)
mock_table.html = """
<table>
<tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr>
<tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr>
<thead><tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr></thead>
<tbody><tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr></tbody>
</table>
"""
@@ -291,6 +291,78 @@ class TestLineItemsExtractorFromPdf:
assert len(result.items) >= 1
class TestPdfPathValidation:
"""Tests for PDF path validation."""
def test_detect_tables_with_nonexistent_path(self):
"""Test that non-existent PDF path returns empty results."""
extractor = LineItemsExtractor()
# Create detector and call _detect_tables_with_parsing with non-existent path
from unittest.mock import MagicMock
from backend.table.structure_detector import TableDetector
mock_detector = MagicMock(spec=TableDetector)
tables, parsing_res = extractor._detect_tables_with_parsing(
mock_detector, "nonexistent.pdf"
)
assert tables == []
assert parsing_res == []
def test_detect_tables_with_directory_path(self, tmp_path):
"""Test that directory path (not file) returns empty results."""
extractor = LineItemsExtractor()
from unittest.mock import MagicMock
from backend.table.structure_detector import TableDetector
mock_detector = MagicMock(spec=TableDetector)
# tmp_path is a directory, not a file
tables, parsing_res = extractor._detect_tables_with_parsing(
mock_detector, str(tmp_path)
)
assert tables == []
assert parsing_res == []
def test_detect_tables_validates_file_exists(self, tmp_path):
"""Test path validation for file existence.
This test verifies that the method correctly validates the path exists
and is a file before attempting to process it.
"""
from unittest.mock import patch
extractor = LineItemsExtractor()
# Create a real file path that exists
fake_pdf = tmp_path / "test.pdf"
fake_pdf.write_bytes(b"not a real pdf")
# Mock render_pdf_to_images to avoid actual PDF processing
with patch("shared.pdf.renderer.render_pdf_to_images") as mock_render:
# Return empty iterator - simulates file exists but no pages
mock_render.return_value = iter([])
from unittest.mock import MagicMock
from backend.table.structure_detector import TableDetector
mock_detector = MagicMock(spec=TableDetector)
mock_detector._ensure_initialized = MagicMock()
mock_detector._pipeline = MagicMock()
tables, parsing_res = extractor._detect_tables_with_parsing(
mock_detector, str(fake_pdf)
)
# render_pdf_to_images was called (path validation passed)
mock_render.assert_called_once()
assert tables == []
assert parsing_res == []
class TestLineItemsResult:
"""Tests for LineItemsResult dataclass."""
@@ -462,3 +534,246 @@ class TestMergedCellExtraction:
assert result.items[0].is_deduction is False
assert result.items[1].amount == "-2000"
assert result.items[1].is_deduction is True
class TestTextFallbackExtraction:
"""Tests for text-based fallback extraction."""
def test_text_fallback_disabled_by_default(self):
"""Test text fallback can be disabled."""
extractor = LineItemsExtractor(enable_text_fallback=False)
assert extractor.enable_text_fallback is False
def test_text_fallback_enabled_by_default(self):
"""Test text fallback is enabled by default."""
extractor = LineItemsExtractor()
assert extractor.enable_text_fallback is True
def test_try_text_fallback_with_valid_parsing_res(self):
"""Test text fallback with valid parsing results."""
from unittest.mock import patch, MagicMock
from backend.table.text_line_items_extractor import (
TextLineItemsExtractor,
TextLineItem,
TextLineItemsResult,
)
extractor = LineItemsExtractor()
# Mock parsing_res_list with text elements
parsing_res = [
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Product A"},
{"label": "text", "bbox": [250, 100, 350, 120], "text": "1 234,56"},
{"label": "text", "bbox": [0, 150, 200, 170], "text": "Product B"},
{"label": "text", "bbox": [250, 150, 350, 170], "text": "2 345,67"},
]
# Create mock text extraction result
mock_text_result = TextLineItemsResult(
items=[
TextLineItem(row_index=0, description="Product A", amount="1 234,56"),
TextLineItem(row_index=1, description="Product B", amount="2 345,67"),
],
header_row=[],
)
with patch.object(TextLineItemsExtractor, 'extract_from_parsing_res', return_value=mock_text_result):
result = extractor._try_text_fallback(parsing_res)
assert result is not None
assert len(result.items) == 2
assert result.items[0].description == "Product A"
assert result.items[1].description == "Product B"
def test_try_text_fallback_returns_none_on_failure(self):
"""Test text fallback returns None when extraction fails."""
from unittest.mock import patch
extractor = LineItemsExtractor()
with patch('backend.table.text_line_items_extractor.TextLineItemsExtractor.extract_from_parsing_res', return_value=None):
result = extractor._try_text_fallback([])
assert result is None
def test_extract_from_pdf_uses_text_fallback(self):
"""Test extract_from_pdf uses text fallback when no tables found."""
from unittest.mock import patch, MagicMock
from backend.table.text_line_items_extractor import TextLineItem, TextLineItemsResult
extractor = LineItemsExtractor(enable_text_fallback=True)
# Mock _detect_tables_with_parsing to return no tables but parsing_res
mock_text_result = TextLineItemsResult(
items=[
TextLineItem(row_index=0, description="Product", amount="100,00"),
TextLineItem(row_index=1, description="Product 2", amount="200,00"),
],
header_row=[],
)
with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
mock_detect.return_value = ([], [{"label": "text", "text": "test"}])
with patch.object(extractor, '_try_text_fallback', return_value=MagicMock(items=[MagicMock()])) as mock_fallback:
result = extractor.extract_from_pdf("fake.pdf")
# Text fallback should be called
mock_fallback.assert_called_once()
def test_extract_from_pdf_skips_fallback_when_disabled(self):
"""Test extract_from_pdf skips text fallback when disabled."""
from unittest.mock import patch
extractor = LineItemsExtractor(enable_text_fallback=False)
with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
mock_detect.return_value = ([], [{"label": "text", "text": "test"}])
result = extractor.extract_from_pdf("fake.pdf")
# Should return None, not use text fallback
assert result is None
class TestVerticallyMergedCellExtraction:
"""Tests for vertically merged cell extraction."""
def test_detects_vertically_merged_cells(self):
"""Test detection of vertically merged cells in rows."""
extractor = LineItemsExtractor()
# Rows with multiple product numbers in single cell
rows = [["Produktnr 1457280 1457281 1060381 merged text here"]]
assert extractor._has_vertically_merged_cells(rows) is True
def test_splits_vertically_merged_rows(self):
"""Test splitting vertically merged rows."""
extractor = LineItemsExtractor()
rows = [
["Produktnr 1234567 1234568", "Antal 2ST 3ST"],
]
header, data = extractor._split_merged_rows(rows)
# Should split into header + data rows
assert isinstance(header, list)
assert isinstance(data, list)
class TestDeductionDetection:
"""Tests for deduction/discount detection."""
def test_detects_deduction_by_keyword_avdrag(self):
"""Test detection of deduction by 'avdrag' keyword."""
html = """
<html><body><table>
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
<tbody>
<tr><td>Hyresavdrag januari</td><td>-500,00</td></tr>
</tbody>
</table></body></html>
"""
extractor = LineItemsExtractor()
result = extractor.extract(html)
assert len(result.items) == 1
assert result.items[0].is_deduction is True
def test_detects_deduction_by_keyword_rabatt(self):
"""Test detection of deduction by 'rabatt' keyword."""
html = """
<html><body><table>
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
<tbody>
<tr><td>Rabatt 10%</td><td>-100,00</td></tr>
</tbody>
</table></body></html>
"""
extractor = LineItemsExtractor()
result = extractor.extract(html)
assert len(result.items) == 1
assert result.items[0].is_deduction is True
def test_detects_deduction_by_negative_amount(self):
"""Test detection of deduction by negative amount."""
html = """
<html><body><table>
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
<tbody>
<tr><td>Some credit</td><td>-250,00</td></tr>
</tbody>
</table></body></html>
"""
extractor = LineItemsExtractor()
result = extractor.extract(html)
assert len(result.items) == 1
assert result.items[0].is_deduction is True
def test_normal_item_not_deduction(self):
"""Test normal item is not marked as deduction."""
html = """
<html><body><table>
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
<tbody>
<tr><td>Normal product</td><td>500,00</td></tr>
</tbody>
</table></body></html>
"""
extractor = LineItemsExtractor()
result = extractor.extract(html)
assert len(result.items) == 1
assert result.items[0].is_deduction is False
class TestHeaderDetection:
"""Tests for header row detection."""
def test_detect_header_at_bottom(self):
"""Test detecting header at bottom of table (reversed)."""
extractor = LineItemsExtractor()
rows = [
["100,00", "Product A", "1"],
["200,00", "Product B", "2"],
["Belopp", "Beskrivning", "Antal"], # Header at bottom
]
header_idx, header, is_at_end = extractor._detect_header_row(rows)
assert header_idx == 2
assert is_at_end is True
assert "Belopp" in header
def test_detect_header_at_top(self):
"""Test detecting header at top of table."""
extractor = LineItemsExtractor()
rows = [
["Belopp", "Beskrivning", "Antal"], # Header at top
["100,00", "Product A", "1"],
["200,00", "Product B", "2"],
]
header_idx, header, is_at_end = extractor._detect_header_row(rows)
assert header_idx == 0
assert is_at_end is False
assert "Belopp" in header
def test_no_header_detected(self):
"""Test when no header is detected."""
extractor = LineItemsExtractor()
rows = [
["100,00", "Product A", "1"],
["200,00", "Product B", "2"],
]
header_idx, header, is_at_end = extractor._detect_header_row(rows)
assert header_idx == -1
assert header == []
assert is_at_end is False

View File

@@ -0,0 +1,448 @@
"""
Tests for Merged Cell Handler
Tests the detection and extraction of data from tables with merged cells,
a common issue with PP-StructureV3 OCR output.
"""
import pytest
from backend.table.merged_cell_handler import MergedCellHandler, MIN_AMOUNT_THRESHOLD
from backend.table.html_table_parser import ColumnMapper
@pytest.fixture
def handler():
"""Create a MergedCellHandler with default ColumnMapper."""
return MergedCellHandler(ColumnMapper())
class TestHasVerticallyMergedCells:
"""Tests for has_vertically_merged_cells detection."""
def test_empty_rows_returns_false(self, handler):
"""Test empty rows returns False."""
assert handler.has_vertically_merged_cells([]) is False
def test_short_cells_ignored(self, handler):
"""Test cells shorter than 20 chars are ignored."""
rows = [["Short cell", "Also short"]]
assert handler.has_vertically_merged_cells(rows) is False
def test_detects_multiple_product_numbers(self, handler):
"""Test detection of multiple 7-digit product numbers in cell."""
rows = [["Produktnr 1457280 1457281 1060381 and more text here"]]
assert handler.has_vertically_merged_cells(rows) is True
def test_single_product_number_not_merged(self, handler):
"""Test single product number doesn't trigger detection."""
rows = [["Produktnr 1457280 and more text here for length"]]
assert handler.has_vertically_merged_cells(rows) is False
def test_detects_multiple_prices(self, handler):
"""Test detection of 3+ prices in cell (Swedish format)."""
rows = [["Pris 127,20 234,56 159,20 total amounts"]]
assert handler.has_vertically_merged_cells(rows) is True
def test_two_prices_not_merged(self, handler):
"""Test two prices doesn't trigger detection (needs 3+)."""
rows = [["Pris 127,20 234,56 total amount here"]]
assert handler.has_vertically_merged_cells(rows) is False
def test_detects_multiple_quantities(self, handler):
"""Test detection of multiple quantity patterns."""
rows = [["Antal 6ST 6ST 1ST more text here"]]
assert handler.has_vertically_merged_cells(rows) is True
def test_single_quantity_not_merged(self, handler):
"""Test single quantity doesn't trigger detection."""
rows = [["Antal 6ST and more text here for length"]]
assert handler.has_vertically_merged_cells(rows) is False
def test_empty_cell_skipped(self, handler):
"""Test empty cells are skipped."""
rows = [["", None, "Valid but short"]]
assert handler.has_vertically_merged_cells(rows) is False
def test_multiple_rows_checked(self, handler):
"""Test all rows are checked for merged content."""
rows = [
["Normal row with nothing special"],
["Produktnr 1457280 1457281 1060381 merged content"],
]
assert handler.has_vertically_merged_cells(rows) is True
class TestSplitMergedRows:
"""Tests for split_merged_rows method."""
def test_empty_rows_returns_empty(self, handler):
"""Test empty rows returns empty result."""
header, data = handler.split_merged_rows([])
assert header == []
assert data == []
def test_all_empty_rows_returns_original(self, handler):
"""Test all empty rows returns original rows."""
rows = [["", ""], ["", ""]]
header, data = handler.split_merged_rows(rows)
assert header == []
assert data == rows
def test_splits_by_product_numbers(self, handler):
"""Test splitting rows by product numbers."""
rows = [
["Produktnr 1234567 1234568", "Antal 2ST 3ST", "Pris 100,00 200,00"],
]
header, data = handler.split_merged_rows(rows)
assert len(header) == 3
assert header[0] == "Produktnr"
assert len(data) == 2
def test_splits_by_quantities(self, handler):
"""Test splitting rows by quantity patterns."""
rows = [
["Description text", "Antal 5ST 10ST", "Belopp 500,00 1000,00"],
]
header, data = handler.split_merged_rows(rows)
# Should detect 2 quantities and split accordingly
assert len(data) >= 1
def test_single_row_not_split(self, handler):
"""Test single item row is not split."""
rows = [
["Produktnr 1234567", "Antal 2ST", "Pris 100,00"],
]
header, data = handler.split_merged_rows(rows)
# Only 1 product number, so expected_rows <= 1
assert header == []
assert data == rows
def test_handles_missing_columns(self, handler):
"""Test handles rows with different column counts."""
rows = [
["Produktnr 1234567 1234568", ""],
["Antal 2ST 3ST"],
]
header, data = handler.split_merged_rows(rows)
# Should handle gracefully
assert isinstance(header, list)
assert isinstance(data, list)
class TestCountExpectedRows:
"""Tests for _count_expected_rows helper."""
def test_counts_product_numbers(self, handler):
"""Test counting product numbers."""
columns = ["Produktnr 1234567 1234568 1234569", "Other"]
count = handler._count_expected_rows(columns)
assert count == 3
def test_counts_quantities(self, handler):
"""Test counting quantity patterns."""
columns = ["Nothing here", "Antal 5ST 10ST 15ST 20ST"]
count = handler._count_expected_rows(columns)
assert count == 4
def test_returns_max_count(self, handler):
"""Test returns maximum count across columns."""
columns = [
"Produktnr 1234567 1234568", # 2 products
"Antal 5ST 10ST 15ST", # 3 quantities
]
count = handler._count_expected_rows(columns)
assert count == 3
def test_empty_columns_return_zero(self, handler):
"""Test empty columns return 0."""
columns = ["", None, "Short"]
count = handler._count_expected_rows(columns)
assert count == 0
class TestSplitCellContentForRows:
"""Tests for _split_cell_content_for_rows helper."""
def test_splits_by_product_numbers(self, handler):
"""Test splitting by product numbers with expected count."""
cell = "Produktnr 1234567 1234568"
result = handler._split_cell_content_for_rows(cell, 2)
assert len(result) == 3 # header + 2 values
assert result[0] == "Produktnr"
assert "1234567" in result[1]
assert "1234568" in result[2]
def test_splits_by_quantities(self, handler):
"""Test splitting by quantity patterns."""
cell = "Antal 5ST 10ST"
result = handler._split_cell_content_for_rows(cell, 2)
assert len(result) == 3 # header + 2 values
assert result[0] == "Antal"
def test_splits_discount_totalsumma(self, handler):
"""Test splitting discount+totalsumma columns."""
cell = "Rabatt i% Totalsumma 686,88 123,45"
result = handler._split_cell_content_for_rows(cell, 2)
assert result[0] == "Totalsumma"
assert "686,88" in result[1]
assert "123,45" in result[2]
def test_splits_by_prices(self, handler):
"""Test splitting by price patterns."""
cell = "Pris 127,20 234,56"
result = handler._split_cell_content_for_rows(cell, 2)
assert len(result) >= 2
def test_fallback_returns_original(self, handler):
"""Test fallback returns original cell."""
cell = "No patterns here"
result = handler._split_cell_content_for_rows(cell, 2)
assert result == ["No patterns here"]
def test_product_number_with_description(self, handler):
"""Test product numbers include trailing description text."""
cell = "Art 1234567 Widget A 1234568 Widget B"
result = handler._split_cell_content_for_rows(cell, 2)
assert len(result) == 3
class TestSplitCellContent:
"""Tests for split_cell_content method."""
def test_splits_by_product_numbers(self, handler):
"""Test splitting by multiple product numbers."""
cell = "Produktnr 1234567 1234568 1234569"
result = handler.split_cell_content(cell)
assert result[0] == "Produktnr"
assert "1234567" in result
assert "1234568" in result
assert "1234569" in result
def test_splits_by_quantities(self, handler):
"""Test splitting by multiple quantities."""
cell = "Antal 6ST 6ST 1ST"
result = handler.split_cell_content(cell)
assert result[0] == "Antal"
assert len(result) >= 3
def test_splits_discount_amount_interleaved(self, handler):
"""Test splitting interleaved discount+amount patterns."""
cell = "Rabatt i% Totalsumma 10,0 686,88 10,0 123,45"
result = handler.split_cell_content(cell)
# Should extract amounts (3+ digit numbers with decimals)
assert result[0] == "Totalsumma"
assert "686,88" in result
assert "123,45" in result
def test_splits_by_prices(self, handler):
"""Test splitting by prices."""
cell = "Pris 127,20 127,20 159,20"
result = handler.split_cell_content(cell)
assert result[0] == "Pris"
def test_single_value_not_split(self, handler):
"""Test single value is not split."""
cell = "Single value"
result = handler.split_cell_content(cell)
assert result == ["Single value"]
def test_single_product_not_split(self, handler):
"""Test single product number is not split."""
cell = "Produktnr 1234567"
result = handler.split_cell_content(cell)
assert result == ["Produktnr 1234567"]
class TestHasMergedHeader:
"""Tests for has_merged_header method."""
def test_none_header_returns_false(self, handler):
"""Test None header returns False."""
assert handler.has_merged_header(None) is False
def test_empty_header_returns_false(self, handler):
"""Test empty header returns False."""
assert handler.has_merged_header([]) is False
def test_multiple_non_empty_cells_returns_false(self, handler):
"""Test multiple non-empty cells returns False."""
header = ["Beskrivning", "Antal", "Belopp"]
assert handler.has_merged_header(header) is False
def test_single_cell_with_keywords_returns_true(self, handler):
"""Test single cell with multiple keywords returns True."""
header = ["Specifikation 0218103-1201 rum och kök Hyra Avdrag"]
assert handler.has_merged_header(header) is True
def test_single_cell_one_keyword_returns_false(self, handler):
"""Test single cell with only one keyword returns False."""
header = ["Beskrivning only"]
assert handler.has_merged_header(header) is False
def test_ignores_empty_trailing_cells(self, handler):
"""Test ignores empty trailing cells."""
header = ["Specifikation Hyra Avdrag", "", "", ""]
assert handler.has_merged_header(header) is True
class TestExtractFromMergedCells:
"""Tests for extract_from_merged_cells method."""
def test_extracts_single_amount(self, handler):
"""Test extracting a single amount."""
header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
rows = [["", "", "", "8159"]]
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 1
assert items[0].amount == "8159"
assert items[0].is_deduction is False
assert items[0].article_number == "0218103-1201"
assert items[0].description == "2 rum och kök"
def test_extracts_deduction(self, handler):
"""Test extracting a deduction (negative amount)."""
header = ["Specifikation"]
rows = [["", "", "", "-2000"]]
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 1
assert items[0].amount == "-2000"
assert items[0].is_deduction is True
# First item (row_index=0) gets description from header, not "Avdrag"
# "Avdrag" is only set for subsequent deduction items
assert items[0].description is None
def test_extracts_multiple_amounts_same_row(self, handler):
"""Test extracting multiple amounts from same row."""
header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
rows = [["", "", "", "8159 -2000"]]
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 2
assert items[0].amount == "8159"
assert items[1].amount == "-2000"
def test_extracts_amounts_from_multiple_rows(self, handler):
"""Test extracting amounts from multiple rows."""
header = ["Specifikation"]
rows = [
["", "", "", "8159"],
["", "", "", "-2000"],
]
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 2
def test_skips_small_amounts(self, handler):
"""Test skipping small amounts below threshold."""
header = ["Specifikation"]
rows = [["", "", "", "50"]] # Below MIN_AMOUNT_THRESHOLD (100)
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 0
def test_skips_empty_rows(self, handler):
"""Test skipping empty rows."""
header = ["Specifikation"]
rows = [["", "", "", ""]]
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 0
def test_handles_swedish_format_with_spaces(self, handler):
"""Test handling Swedish number format with spaces."""
header = ["Specifikation"]
rows = [["", "", "", "8 159"]]
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 1
assert items[0].amount == "8159"
def test_confidence_is_lower_for_merged(self, handler):
"""Test confidence is 0.7 for merged cell extraction."""
header = ["Specifikation"]
rows = [["", "", "", "8159"]]
items = handler.extract_from_merged_cells(header, rows)
assert items[0].confidence == 0.7
def test_empty_header_still_extracts(self, handler):
"""Test extraction works with empty header."""
header = []
rows = [["", "", "", "8159"]]
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 1
assert items[0].description is None
assert items[0].article_number is None
def test_row_index_increments(self, handler):
"""Test row_index increments for each item."""
header = ["Specifikation"]
# Use separate rows to avoid regex grouping issues
rows = [
["", "", "", "8159"],
["", "", "", "5000"],
["", "", "", "-2000"],
]
items = handler.extract_from_merged_cells(header, rows)
# Should have 3 items from 3 rows
assert len(items) == 3
assert items[0].row_index == 0
assert items[1].row_index == 1
assert items[2].row_index == 2
class TestMinAmountThreshold:
"""Tests for MIN_AMOUNT_THRESHOLD constant."""
def test_threshold_value(self):
"""Test the threshold constant value."""
assert MIN_AMOUNT_THRESHOLD == 100
def test_amounts_at_threshold_included(self, handler):
"""Test amounts exactly at threshold are included."""
header = ["Specifikation"]
rows = [["", "", "", "100"]] # Exactly at threshold
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 1
assert items[0].amount == "100"
def test_amounts_below_threshold_excluded(self, handler):
"""Test amounts below threshold are excluded."""
header = ["Specifikation"]
rows = [["", "", "", "99"]] # Below threshold
items = handler.extract_from_merged_cells(header, rows)
assert len(items) == 0

157
tests/table/test_models.py Normal file
View File

@@ -0,0 +1,157 @@
"""
Tests for Line Items Data Models
Tests for LineItem and LineItemsResult dataclasses.
"""
import pytest
from backend.table.models import LineItem, LineItemsResult
class TestLineItem:
"""Tests for LineItem dataclass."""
def test_default_values(self):
"""Test default values for optional fields."""
item = LineItem(row_index=0)
assert item.row_index == 0
assert item.description is None
assert item.quantity is None
assert item.unit is None
assert item.unit_price is None
assert item.amount is None
assert item.article_number is None
assert item.vat_rate is None
assert item.is_deduction is False
assert item.confidence == 0.9
def test_custom_confidence(self):
"""Test setting custom confidence."""
item = LineItem(row_index=0, confidence=0.7)
assert item.confidence == 0.7
def test_is_deduction_true(self):
"""Test is_deduction flag."""
item = LineItem(row_index=0, is_deduction=True)
assert item.is_deduction is True
class TestLineItemsResult:
"""Tests for LineItemsResult dataclass."""
def test_total_amount_empty_items(self):
"""Test total_amount returns None for empty items."""
result = LineItemsResult(items=[], header_row=[], raw_html="")
assert result.total_amount is None
def test_total_amount_single_item(self):
"""Test total_amount with single item."""
items = [LineItem(row_index=0, amount="100,00")]
result = LineItemsResult(items=items, header_row=[], raw_html="")
assert result.total_amount == "100,00"
def test_total_amount_multiple_items(self):
"""Test total_amount with multiple items."""
items = [
LineItem(row_index=0, amount="100,00"),
LineItem(row_index=1, amount="200,50"),
]
result = LineItemsResult(items=items, header_row=[], raw_html="")
assert result.total_amount == "300,50"
def test_total_amount_with_deduction(self):
"""Test total_amount includes negative amounts (deductions)."""
items = [
LineItem(row_index=0, amount="1000,00"),
LineItem(row_index=1, amount="-200,00", is_deduction=True),
]
result = LineItemsResult(items=items, header_row=[], raw_html="")
assert result.total_amount == "800,00"
def test_total_amount_swedish_format_with_spaces(self):
"""Test total_amount handles Swedish format with spaces."""
items = [
LineItem(row_index=0, amount="1 234,56"),
LineItem(row_index=1, amount="2 000,00"),
]
result = LineItemsResult(items=items, header_row=[], raw_html="")
assert result.total_amount == "3 234,56"
def test_total_amount_invalid_amount_skipped(self):
"""Test total_amount skips invalid amounts."""
items = [
LineItem(row_index=0, amount="100,00"),
LineItem(row_index=1, amount="invalid"),
LineItem(row_index=2, amount="200,00"),
]
result = LineItemsResult(items=items, header_row=[], raw_html="")
# Invalid amount is skipped
assert result.total_amount == "300,00"
def test_total_amount_none_amount_skipped(self):
"""Test total_amount skips None amounts."""
items = [
LineItem(row_index=0, amount="100,00"),
LineItem(row_index=1, amount=None),
]
result = LineItemsResult(items=items, header_row=[], raw_html="")
assert result.total_amount == "100,00"
def test_total_amount_all_invalid_returns_none(self):
"""Test total_amount returns None when all amounts are invalid."""
items = [
LineItem(row_index=0, amount="invalid"),
LineItem(row_index=1, amount="also invalid"),
]
result = LineItemsResult(items=items, header_row=[], raw_html="")
assert result.total_amount is None
def test_total_amount_large_numbers(self):
"""Test total_amount handles large numbers."""
items = [
LineItem(row_index=0, amount="123 456,78"),
LineItem(row_index=1, amount="876 543,22"),
]
result = LineItemsResult(items=items, header_row=[], raw_html="")
assert result.total_amount == "1 000 000,00"
def test_total_amount_decimal_precision(self):
"""Test total_amount maintains decimal precision."""
items = [
LineItem(row_index=0, amount="0,01"),
LineItem(row_index=1, amount="0,02"),
]
result = LineItemsResult(items=items, header_row=[], raw_html="")
assert result.total_amount == "0,03"
def test_is_reversed_default_false(self):
"""Test is_reversed defaults to False."""
result = LineItemsResult(items=[], header_row=[], raw_html="")
assert result.is_reversed is False
def test_is_reversed_can_be_set(self):
"""Test is_reversed can be set to True."""
result = LineItemsResult(items=[], header_row=[], raw_html="", is_reversed=True)
assert result.is_reversed is True
def test_header_row_preserved(self):
"""Test header_row is preserved."""
header = ["Beskrivning", "Antal", "Belopp"]
result = LineItemsResult(items=[], header_row=header, raw_html="")
assert result.header_row == header
def test_raw_html_preserved(self):
"""Test raw_html is preserved."""
html = "<table><tr><td>Test</td></tr></table>"
result = LineItemsResult(items=[], header_row=[], raw_html=html)
assert result.raw_html == html

View File

@@ -658,3 +658,245 @@ class TestPaddleX3xAPI:
assert len(results) == 1
assert results[0].cells == [] # Empty cells list
assert results[0].html == "<table></table>"
def test_parse_paddlex_result_with_dict_ocr_data(self):
"""Test parsing PaddleX 3.x result with dict-format table_ocr_pred."""
mock_pipeline = MagicMock()
mock_result = {
"table_res_list": [
{
"cell_box_list": [[0, 0, 50, 20], [50, 0, 100, 20]],
"pred_html": "<table><tr><td>A</td><td>B</td></tr></table>",
"table_ocr_pred": {
"rec_texts": ["A", "B"],
"rec_scores": [0.99, 0.98],
},
}
],
"parsing_res_list": [
{"label": "table", "bbox": [10, 20, 200, 300]},
],
}
mock_pipeline.predict.return_value = [mock_result]
detector = TableDetector(pipeline=mock_pipeline)
image = np.zeros((100, 100, 3), dtype=np.uint8)
results = detector.detect(image)
assert len(results) == 1
assert len(results[0].cells) == 2
assert results[0].cells[0]["text"] == "A"
assert results[0].cells[1]["text"] == "B"
def test_parse_paddlex_result_no_bbox_in_parsing_res(self):
"""Test parsing PaddleX 3.x result when table bbox not in parsing_res."""
mock_pipeline = MagicMock()
mock_result = {
"table_res_list": [
{
"cell_box_list": [[0, 0, 50, 20]],
"pred_html": "<table><tr><td>A</td></tr></table>",
"table_ocr_pred": ["A"],
}
],
"parsing_res_list": [
{"label": "text", "bbox": [10, 20, 200, 300]}, # Not a table
],
}
mock_pipeline.predict.return_value = [mock_result]
detector = TableDetector(pipeline=mock_pipeline)
image = np.zeros((100, 100, 3), dtype=np.uint8)
results = detector.detect(image)
assert len(results) == 1
# Should use default bbox [0,0,0,0] when not found
assert results[0].bbox == (0.0, 0.0, 0.0, 0.0)
class TestIteratorResults:
"""Tests for iterator/generator result handling."""
def test_handles_iterator_results(self):
"""Test handling of iterator results from pipeline."""
mock_pipeline = MagicMock()
# Return a generator instead of list
def result_generator():
element = MagicMock()
element.label = "table"
element.bbox = [0, 0, 100, 100]
element.html = "<table></table>"
element.score = 0.9
element.cells = []
mock_result = MagicMock(spec=["layout_elements"])
mock_result.layout_elements = [element]
yield mock_result
mock_pipeline.predict.return_value = result_generator()
detector = TableDetector(pipeline=mock_pipeline)
image = np.zeros((100, 100, 3), dtype=np.uint8)
results = detector.detect(image)
assert len(results) == 1
def test_handles_failed_iterator_conversion(self):
"""Test handling when iterator conversion fails."""
mock_pipeline = MagicMock()
# Create an object that has __iter__ but fails when converted to list
class FailingIterator:
def __iter__(self):
raise RuntimeError("Iterator failed")
mock_pipeline.predict.return_value = FailingIterator()
detector = TableDetector(pipeline=mock_pipeline)
image = np.zeros((100, 100, 3), dtype=np.uint8)
results = detector.detect(image)
# Should return empty list, not raise
assert results == []
class TestPathConversion:
"""Tests for path handling."""
def test_converts_path_object_to_string(self):
"""Test that Path objects are converted to strings."""
from pathlib import Path
mock_pipeline = MagicMock()
mock_pipeline.predict.return_value = []
detector = TableDetector(pipeline=mock_pipeline)
path = Path("/some/path/to/image.png")
detector.detect(path)
# Should be called with string, not Path
mock_pipeline.predict.assert_called_with("/some/path/to/image.png")
class TestHtmlExtraction:
"""Tests for HTML extraction from different element formats."""
def test_extracts_html_from_res_dict(self):
"""Test extracting HTML from element.res dictionary."""
mock_pipeline = MagicMock()
element = MagicMock()
element.label = "table"
element.bbox = [0, 0, 100, 100]
element.res = {"html": "<table><tr><td>From res</td></tr></table>"}
element.score = 0.9
element.cells = []
# Remove direct html attribute
del element.html
del element.table_html
mock_result = MagicMock(spec=["layout_elements"])
mock_result.layout_elements = [element]
mock_pipeline.predict.return_value = [mock_result]
detector = TableDetector(pipeline=mock_pipeline)
image = np.zeros((100, 100, 3), dtype=np.uint8)
results = detector.detect(image)
assert len(results) == 1
assert results[0].html == "<table><tr><td>From res</td></tr></table>"
def test_returns_empty_html_when_not_found(self):
"""Test empty HTML when no html attribute found."""
mock_pipeline = MagicMock()
element = MagicMock()
element.label = "table"
element.bbox = [0, 0, 100, 100]
element.score = 0.9
element.cells = []
# Remove all html attributes
del element.html
del element.table_html
del element.res
mock_result = MagicMock(spec=["layout_elements"])
mock_result.layout_elements = [element]
mock_pipeline.predict.return_value = [mock_result]
detector = TableDetector(pipeline=mock_pipeline)
image = np.zeros((100, 100, 3), dtype=np.uint8)
results = detector.detect(image)
assert len(results) == 1
assert results[0].html == ""
class TestTableTypeDetection:
"""Tests for table type detection."""
def test_detects_borderless_table(self):
"""Test detection of borderless table type via _get_table_type."""
detector = TableDetector()
# Create mock element with borderless label
element = MagicMock()
element.label = "borderless_table"
result = detector._get_table_type(element)
assert result == "wireless"
def test_detects_wireless_table_label(self):
"""Test detection of wireless table type."""
detector = TableDetector()
element = MagicMock()
element.label = "wireless_table"
result = detector._get_table_type(element)
assert result == "wireless"
def test_defaults_to_wired_table(self):
"""Test default table type is wired."""
detector = TableDetector()
element = MagicMock()
element.label = "table"
result = detector._get_table_type(element)
assert result == "wired"
def test_type_attribute_instead_of_label(self):
"""Test table type detection using type attribute."""
detector = TableDetector()
element = MagicMock()
element.type = "wireless"
del element.label # Remove label
result = detector._get_table_type(element)
assert result == "wireless"
class TestPipelineRuntimeError:
"""Tests for pipeline runtime errors."""
def test_raises_runtime_error_when_pipeline_none(self):
"""Test RuntimeError when pipeline is None during detect."""
detector = TableDetector()
detector._initialized = True # Bypass lazy init
detector._pipeline = None
image = np.zeros((100, 100, 3), dtype=np.uint8)
with pytest.raises(RuntimeError) as exc_info:
detector.detect(image)
assert "not initialized" in str(exc_info.value).lower()

View File

@@ -142,6 +142,33 @@ class TestTextLineItemsExtractor:
rows = extractor._group_by_row(elements)
assert len(rows) == 2
def test_group_by_row_varying_heights_uses_average(self, extractor):
"""Test grouping handles varying element heights using dynamic average.
When elements have varying heights, the row center should be recalculated
as new elements are added, preventing tall elements from being incorrectly
grouped with the next row.
"""
# First element: small height, center_y = 105
# Second element: tall, center_y = 115 (but should still be same row)
# Third element: next row, center_y = 160
elements = [
TextElement(text="Short", bbox=(0, 100, 100, 110)), # center_y = 105
TextElement(text="Tall item", bbox=(150, 100, 250, 130)), # center_y = 115
TextElement(text="Next row", bbox=(0, 150, 100, 170)), # center_y = 160
]
rows = extractor._group_by_row(elements)
# With dynamic average, both first and second element should be same row
assert len(rows) == 2
assert len(rows[0]) == 2 # Short and Tall item
assert len(rows[1]) == 1 # Next row
def test_group_by_row_empty_input(self, extractor):
"""Test grouping with empty input returns empty list."""
rows = extractor._group_by_row([])
assert rows == []
def test_looks_like_line_item_with_amount(self, extractor):
"""Test line item detection with amount."""
row = [
@@ -253,6 +280,67 @@ class TestTextLineItemsExtractor:
assert len(elements) == 4
class TestExceptionHandling:
"""Tests for exception handling in text element extraction."""
def test_extract_text_elements_handles_missing_bbox(self):
"""Test that missing bbox is handled gracefully."""
extractor = TextLineItemsExtractor()
parsing_res = [
{"label": "text", "text": "No bbox"}, # Missing bbox
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
]
elements = extractor._extract_text_elements(parsing_res)
# Should only have 1 valid element
assert len(elements) == 1
assert elements[0].text == "Valid"
def test_extract_text_elements_handles_invalid_bbox(self):
"""Test that invalid bbox (less than 4 values) is handled."""
extractor = TextLineItemsExtractor()
parsing_res = [
{"label": "text", "bbox": [0, 100], "text": "Invalid bbox"}, # Only 2 values
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
]
elements = extractor._extract_text_elements(parsing_res)
assert len(elements) == 1
assert elements[0].text == "Valid"
def test_extract_text_elements_handles_none_text(self):
"""Test that None text is handled."""
extractor = TextLineItemsExtractor()
parsing_res = [
{"label": "text", "bbox": [0, 100, 200, 120], "text": None},
{"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
]
elements = extractor._extract_text_elements(parsing_res)
assert len(elements) == 1
assert elements[0].text == "Valid"
def test_extract_text_elements_handles_empty_string(self):
"""Test that empty string text is skipped."""
extractor = TextLineItemsExtractor()
parsing_res = [
{"label": "text", "bbox": [0, 100, 200, 120], "text": ""},
{"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
]
elements = extractor._extract_text_elements(parsing_res)
assert len(elements) == 1
assert elements[0].text == "Valid"
def test_extract_text_elements_handles_malformed_element(self):
"""Test that completely malformed elements are handled."""
extractor = TextLineItemsExtractor()
parsing_res = [
"not a dict", # String instead of dict
123, # Number instead of dict
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
]
elements = extractor._extract_text_elements(parsing_res)
assert len(elements) == 1
assert elements[0].text == "Valid"
class TestConvertTextLineItem:
"""Tests for convert_text_line_item function."""

View File

@@ -0,0 +1 @@
"""Tests for training package."""

View File

@@ -0,0 +1 @@
"""Tests for training.yolo module."""

View File

@@ -0,0 +1,342 @@
"""
Tests for AnnotationGenerator with field-specific bbox expansion.
Tests verify that annotations are generated correctly using
field-specific scale strategies.
"""
from dataclasses import dataclass
import pytest
from training.yolo.annotation_generator import (
AnnotationGenerator,
YOLOAnnotation,
)
from shared.fields import TRAINING_FIELD_CLASSES, CLASS_NAMES
@dataclass
class MockMatch:
"""Mock Match object for testing."""
bbox: tuple[float, float, float, float]
score: float
class TestYOLOAnnotation:
"""Tests for YOLOAnnotation dataclass."""
def test_to_string_format(self):
"""Verify YOLO format string output."""
ann = YOLOAnnotation(
class_id=0,
x_center=0.5,
y_center=0.5,
width=0.1,
height=0.05,
confidence=0.9
)
result = ann.to_string()
assert result == "0 0.500000 0.500000 0.100000 0.050000"
def test_default_confidence(self):
"""Verify default confidence is 1.0."""
ann = YOLOAnnotation(
class_id=0,
x_center=0.5,
y_center=0.5,
width=0.1,
height=0.05,
)
assert ann.confidence == 1.0
class TestAnnotationGeneratorInit:
"""Tests for AnnotationGenerator initialization."""
def test_default_values(self):
"""Verify default initialization values."""
gen = AnnotationGenerator()
assert gen.min_confidence == 0.7
assert gen.min_bbox_height_px == 30
def test_custom_values(self):
"""Verify custom initialization values."""
gen = AnnotationGenerator(
min_confidence=0.8,
min_bbox_height_px=40,
)
assert gen.min_confidence == 0.8
assert gen.min_bbox_height_px == 40
class TestGenerateFromMatches:
"""Tests for generate_from_matches method."""
def test_generates_annotation_for_valid_match(self):
"""Verify annotation is generated for valid match."""
gen = AnnotationGenerator(min_confidence=0.5)
# Mock match in PDF points (72 DPI)
# At 150 DPI, coords multiply by 150/72 = 2.083
matches = {
"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.8)]
}
annotations = gen.generate_from_matches(
matches=matches,
image_width=1000,
image_height=1000,
dpi=150
)
assert len(annotations) == 1
ann = annotations[0]
assert ann.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"]
assert ann.confidence == 0.8
# Normalized values should be in 0-1 range
assert 0 <= ann.x_center <= 1
assert 0 <= ann.y_center <= 1
assert 0 < ann.width <= 1
assert 0 < ann.height <= 1
def test_skips_low_confidence_match(self):
"""Verify low confidence matches are skipped."""
gen = AnnotationGenerator(min_confidence=0.7)
matches = {
"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.5)]
}
annotations = gen.generate_from_matches(
matches=matches,
image_width=1000,
image_height=1000,
dpi=150
)
assert len(annotations) == 0
def test_skips_unknown_field(self):
"""Verify unknown fields are skipped."""
gen = AnnotationGenerator(min_confidence=0.5)
matches = {
"UnknownField": [MockMatch(bbox=(100, 200, 200, 230), score=0.9)]
}
annotations = gen.generate_from_matches(
matches=matches,
image_width=1000,
image_height=1000,
dpi=150
)
assert len(annotations) == 0
def test_takes_best_match_only(self):
"""Verify only the best match is used per field."""
gen = AnnotationGenerator(min_confidence=0.5)
matches = {
"InvoiceNumber": [
MockMatch(bbox=(100, 200, 200, 230), score=0.9), # Best
MockMatch(bbox=(300, 400, 400, 430), score=0.7),
]
}
annotations = gen.generate_from_matches(
matches=matches,
image_width=1000,
image_height=1000,
dpi=150
)
assert len(annotations) == 1
assert annotations[0].confidence == 0.9
def test_handles_empty_matches(self):
"""Verify empty matches list is handled."""
gen = AnnotationGenerator()
matches = {
"InvoiceNumber": []
}
annotations = gen.generate_from_matches(
matches=matches,
image_width=1000,
image_height=1000,
dpi=150
)
assert len(annotations) == 0
def test_applies_field_specific_expansion(self):
"""Verify different fields get different expansion."""
gen = AnnotationGenerator(min_confidence=0.5)
# Same bbox, different fields
bbox = (100, 200, 200, 230)
matches_invoice_number = {
"InvoiceNumber": [MockMatch(bbox=bbox, score=0.9)]
}
matches_bankgiro = {
"Bankgiro": [MockMatch(bbox=bbox, score=0.9)]
}
ann_invoice = gen.generate_from_matches(
matches=matches_invoice_number,
image_width=1000,
image_height=1000,
dpi=150
)[0]
ann_bankgiro = gen.generate_from_matches(
matches=matches_bankgiro,
image_width=1000,
image_height=1000,
dpi=150
)[0]
# Bankgiro has extra_left_ratio=0.80, invoice_number has extra_top_ratio=0.40
# They should have different widths due to different expansion
# Bankgiro expands more to the left
assert ann_bankgiro.width != ann_invoice.width or ann_bankgiro.x_center != ann_invoice.x_center
def test_enforces_min_bbox_height(self):
"""Verify minimum bbox height is enforced."""
gen = AnnotationGenerator(min_confidence=0.5, min_bbox_height_px=50)
# Very small bbox
matches = {
"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 210), score=0.9)]
}
annotations = gen.generate_from_matches(
matches=matches,
image_width=1000,
image_height=1000,
dpi=72 # 1:1 scale
)
assert len(annotations) == 1
# Height should be at least min_bbox_height_px / image_height
# After scale strategy expansion, height should be >= 50/1000 = 0.05
# Actually the min_bbox_height check happens AFTER expand_bbox
# So the final height should meet the minimum
class TestAddPaymentLineAnnotation:
"""Tests for add_payment_line_annotation method."""
def test_adds_payment_line_annotation(self):
"""Verify payment_line annotation is added."""
gen = AnnotationGenerator(min_confidence=0.5)
annotations = []
result = gen.add_payment_line_annotation(
annotations=annotations,
payment_line_bbox=(100, 200, 400, 230),
confidence=0.9,
image_width=1000,
image_height=1000,
dpi=150
)
assert len(result) == 1
ann = result[0]
assert ann.class_id == TRAINING_FIELD_CLASSES["payment_line"]
assert ann.confidence == 0.9
def test_skips_none_bbox(self):
"""Verify None bbox is handled."""
gen = AnnotationGenerator(min_confidence=0.5)
annotations = []
result = gen.add_payment_line_annotation(
annotations=annotations,
payment_line_bbox=None,
confidence=0.9,
image_width=1000,
image_height=1000,
dpi=150
)
assert len(result) == 0
def test_skips_low_confidence(self):
"""Verify low confidence is skipped."""
gen = AnnotationGenerator(min_confidence=0.7)
annotations = []
result = gen.add_payment_line_annotation(
annotations=annotations,
payment_line_bbox=(100, 200, 400, 230),
confidence=0.5,
image_width=1000,
image_height=1000,
dpi=150
)
assert len(result) == 0
def test_appends_to_existing_annotations(self):
"""Verify payment_line is appended to existing list."""
gen = AnnotationGenerator(min_confidence=0.5)
existing = [YOLOAnnotation(0, 0.5, 0.5, 0.1, 0.1, 0.9)]
result = gen.add_payment_line_annotation(
annotations=existing,
payment_line_bbox=(100, 200, 400, 230),
confidence=0.9,
image_width=1000,
image_height=1000,
dpi=150
)
assert len(result) == 2
assert result[0].class_id == 0 # Original
assert result[1].class_id == TRAINING_FIELD_CLASSES["payment_line"]
class TestMultipleFieldsIntegration:
"""Integration tests for multiple fields."""
def test_generates_annotations_for_all_field_types(self):
"""Verify annotations can be generated for all field types."""
gen = AnnotationGenerator(min_confidence=0.5)
# Create matches for each field (except payment_line which is derived)
field_names = [
"InvoiceNumber",
"InvoiceDate",
"InvoiceDueDate",
"OCR",
"Bankgiro",
"Plusgiro",
"Amount",
"supplier_organisation_number",
"customer_number",
]
matches = {}
for i, field_name in enumerate(field_names):
# Stagger bboxes to avoid overlap
matches[field_name] = [
MockMatch(bbox=(100 + i * 50, 100 + i * 30, 200 + i * 50, 130 + i * 30), score=0.9)
]
annotations = gen.generate_from_matches(
matches=matches,
image_width=2000,
image_height=2000,
dpi=150
)
assert len(annotations) == len(field_names)
# Verify all class_ids are present
class_ids = {ann.class_id for ann in annotations}
expected_class_ids = {TRAINING_FIELD_CLASSES[fn] for fn in field_names}
assert class_ids == expected_class_ids