Compare commits
2 Commits
c2c8f2dd04
...
0990239e9c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0990239e9c | ||
|
|
8723ef4653 |
204
packages/backend/backend/table/html_table_parser.py
Normal file
204
packages/backend/backend/table/html_table_parser.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
HTML Table Parser
|
||||
|
||||
Parses HTML tables into structured data and maps columns to field names.
|
||||
"""
|
||||
|
||||
from html.parser import HTMLParser
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration constants
|
||||
# Minimum pattern length to avoid false positives from short substrings
|
||||
MIN_PATTERN_MATCH_LENGTH = 3
|
||||
# Exact match bonus for column mapping priority
|
||||
EXACT_MATCH_BONUS = 100
|
||||
|
||||
# Swedish column name mappings
|
||||
# Extended to support multiple invoice types: product invoices, rental invoices, utility bills
|
||||
COLUMN_MAPPINGS = {
|
||||
"article_number": [
|
||||
"art nummer",
|
||||
"artikelnummer",
|
||||
"artikel",
|
||||
"artnr",
|
||||
"art.nr",
|
||||
"art nr",
|
||||
"objektnummer", # Rental: property reference
|
||||
"objekt",
|
||||
],
|
||||
"description": [
|
||||
"beskrivning",
|
||||
"produktbeskrivning",
|
||||
"produkt",
|
||||
"tjänst",
|
||||
"text",
|
||||
"benämning",
|
||||
"vara/tjänst",
|
||||
"vara",
|
||||
# Rental invoice specific
|
||||
"specifikation",
|
||||
"spec",
|
||||
"hyresperiod", # Rental period
|
||||
"period",
|
||||
"typ", # Type of charge
|
||||
# Utility bills
|
||||
"förbrukning", # Consumption
|
||||
"avläsning", # Meter reading
|
||||
],
|
||||
"quantity": ["antal", "qty", "st", "pcs", "kvantitet", "m²", "kvm"],
|
||||
"unit": ["enhet", "unit"],
|
||||
"unit_price": ["á-pris", "a-pris", "pris", "styckpris", "enhetspris", "à pris"],
|
||||
"amount": [
|
||||
"belopp",
|
||||
"summa",
|
||||
"total",
|
||||
"netto",
|
||||
"rad summa",
|
||||
# Rental specific
|
||||
"hyra", # Rent
|
||||
"avgift", # Fee
|
||||
"kostnad", # Cost
|
||||
"debitering", # Charge
|
||||
"totalt", # Total
|
||||
],
|
||||
"vat_rate": ["moms", "moms%", "vat", "skatt", "moms %"],
|
||||
# Additional field for rental: deductions/adjustments
|
||||
"deduction": [
|
||||
"avdrag", # Deduction
|
||||
"rabatt", # Discount
|
||||
"kredit", # Credit
|
||||
],
|
||||
}
|
||||
|
||||
# Keywords that indicate NOT a line items table
|
||||
SUMMARY_KEYWORDS = [
|
||||
"frakt",
|
||||
"faktura.avg",
|
||||
"fakturavg",
|
||||
"exkl.moms",
|
||||
"att betala",
|
||||
"öresavr",
|
||||
"bankgiro",
|
||||
"plusgiro",
|
||||
"ocr",
|
||||
"forfallodatum",
|
||||
"förfallodatum",
|
||||
]
|
||||
|
||||
|
||||
class _TableHTMLParser(HTMLParser):
|
||||
"""Internal HTML parser for tables."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.rows: list[list[str]] = []
|
||||
self.current_row: list[str] = []
|
||||
self.current_cell: str = ""
|
||||
self.in_td = False
|
||||
self.in_thead = False
|
||||
self.header_row: list[str] = []
|
||||
|
||||
def handle_starttag(self, tag, attrs):
|
||||
if tag == "tr":
|
||||
self.current_row = []
|
||||
elif tag in ("td", "th"):
|
||||
self.in_td = True
|
||||
self.current_cell = ""
|
||||
elif tag == "thead":
|
||||
self.in_thead = True
|
||||
|
||||
def handle_endtag(self, tag):
|
||||
if tag in ("td", "th"):
|
||||
self.in_td = False
|
||||
self.current_row.append(self.current_cell.strip())
|
||||
elif tag == "tr":
|
||||
if self.current_row:
|
||||
if self.in_thead:
|
||||
self.header_row = self.current_row
|
||||
else:
|
||||
self.rows.append(self.current_row)
|
||||
elif tag == "thead":
|
||||
self.in_thead = False
|
||||
|
||||
def handle_data(self, data):
|
||||
if self.in_td:
|
||||
self.current_cell += data
|
||||
|
||||
|
||||
class HTMLTableParser:
    """Parse HTML tables into structured data."""

    def parse(self, html: str) -> tuple[list[str], list[list[str]]]:
        """
        Parse an HTML table and return its header and data rows.

        Args:
            html: HTML string containing table.

        Returns:
            Tuple of (header_row, data_rows).
        """
        worker = _TableHTMLParser()
        worker.feed(html)
        header, data = worker.header_row, worker.rows
        return header, data
|
||||
|
||||
|
||||
class ColumnMapper:
    """Map column headers to field names."""

    def __init__(self, mappings: dict[str, list[str]] | None = None):
        """
        Initialize column mapper.

        Args:
            mappings: Custom column mappings. Uses Swedish defaults if None.
        """
        self.mappings = mappings or COLUMN_MAPPINGS

    def map(self, headers: list[str]) -> dict[int, str]:
        """
        Map column indices to field names.

        Headers and patterns are both normalized (lowercased, dots removed,
        hyphens turned into spaces) before comparison. Exact matches win over
        substring matches; among substring matches the longest pattern wins.

        Args:
            headers: List of column header strings.

        Returns:
            Dictionary mapping column index to field name.
        """
        mapping = {}
        for idx, header in enumerate(headers):
            normalized = self._normalize(header)

            if not normalized.strip():
                continue

            best_match = None
            best_match_len = 0

            for field_name, patterns in self.mappings.items():
                for pattern in patterns:
                    # Bug fix: normalize the pattern with the same rules as the
                    # header. Previously, patterns containing "." or "-" (e.g.
                    # "art.nr", "á-pris") could never match, because those
                    # characters were stripped from the header but not from the
                    # pattern being compared against it.
                    norm_pattern = self._normalize(pattern)
                    if norm_pattern == normalized:
                        # Exact match gets highest priority
                        best_match = field_name
                        best_match_len = len(norm_pattern) + EXACT_MATCH_BONUS
                        break
                    elif norm_pattern in normalized and len(norm_pattern) > best_match_len:
                        # Partial match requires minimum length to avoid false positives
                        if len(norm_pattern) >= MIN_PATTERN_MATCH_LENGTH:
                            best_match = field_name
                            best_match_len = len(norm_pattern)

                if best_match_len > EXACT_MATCH_BONUS:
                    # Found exact match, no need to check other fields
                    break

            if best_match:
                mapping[idx] = best_match

        return mapping

    def _normalize(self, header: str) -> str:
        """Normalize text for matching: lowercase, strip, drop dots, hyphens to spaces."""
        return header.lower().strip().replace(".", "").replace("-", " ")
|
||||
File diff suppressed because it is too large
Load Diff
423
packages/backend/backend/table/merged_cell_handler.py
Normal file
423
packages/backend/backend/table/merged_cell_handler.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Merged Cell Handler
|
||||
|
||||
Handles detection and extraction of data from tables with merged cells,
|
||||
a common issue with PP-StructureV3 OCR output.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .models import LineItem
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .html_table_parser import ColumnMapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimum positive amount to consider as line item (filters noise like row indices)
|
||||
MIN_AMOUNT_THRESHOLD = 100
|
||||
|
||||
|
||||
class MergedCellHandler:
    """Handles tables with vertically merged cells from PP-StructureV3.

    All detection is heuristic and regex-driven: 7-digit runs are treated as
    product numbers, "<n>ST"/"<n>PCS" as quantities, and "123,45"-style
    numbers as Swedish-format prices.
    """

    def __init__(self, mapper: "ColumnMapper"):
        """
        Initialize handler.

        Args:
            mapper: ColumnMapper instance for header keyword detection.
        """
        self.mapper = mapper

    def has_vertically_merged_cells(self, rows: list[list[str]]) -> bool:
        """
        Check if table rows contain vertically merged data in single cells.

        PP-StructureV3 sometimes merges multiple table rows into single cells, e.g.:
        ["Produktnr 1457280 1457280 1060381", "", "Antal 6ST 6ST 1ST", "Pris 127,20 127,20 159,20"]

        Detection: cells contain repeating patterns of numbers or keywords suggesting multiple lines.

        Returns:
            True as soon as any cell matches one of the repeat heuristics below.
        """
        if not rows:
            return False

        for row in rows:
            for cell in row:
                # Short cells cannot plausibly hold multiple merged values.
                if not cell or len(cell) < 20:
                    continue

                # Check for multiple product numbers (7+ digit patterns)
                product_nums = re.findall(r"\b\d{7}\b", cell)
                if len(product_nums) >= 2:
                    logger.debug(f"has_vertically_merged_cells: found {len(product_nums)} product numbers in cell")
                    return True

                # Check for multiple prices (Swedish format: 123,45 or 1 234,56)
                # Threshold is 3 (not 2) because a single line item can
                # legitimately contain two prices (unit price + total).
                prices = re.findall(r"\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b", cell)
                if len(prices) >= 3:
                    logger.debug(f"has_vertically_merged_cells: found {len(prices)} prices in cell")
                    return True

                # Check for multiple quantity patterns (e.g., "6ST 6ST 1ST")
                quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", cell)
                if len(quantities) >= 2:
                    logger.debug(f"has_vertically_merged_cells: found {len(quantities)} quantities in cell")
                    return True

        return False

    def split_merged_rows(
        self, rows: list[list[str]]
    ) -> tuple[list[str], list[list[str]]]:
        """
        Split vertically merged cells back into separate rows.

        Handles complex cases where PP-StructureV3 merges content across
        multiple HTML rows. For example, 5 line items might be spread across
        3 HTML rows with content mixed together.

        Strategy:
        1. Merge all row content per column
        2. Detect how many actual data rows exist (by counting product numbers)
        3. Split each column's content into that many lines

        Returns header and data rows. On failure to split (no rows, or fewer
        than two detectable data rows) returns ([], original rows) unchanged.
        """
        if not rows:
            return [], []

        # Filter out completely empty rows
        non_empty_rows = [r for r in rows if any(cell.strip() for cell in r)]
        if not non_empty_rows:
            return [], rows

        # Determine column count (rows may be ragged, so take the widest)
        col_count = max(len(r) for r in non_empty_rows)

        # Merge content from all rows for each column
        merged_columns = []
        for col_idx in range(col_count):
            col_content = []
            for row in non_empty_rows:
                if col_idx < len(row) and row[col_idx].strip():
                    col_content.append(row[col_idx].strip())
            merged_columns.append(" ".join(col_content))

        logger.debug(f"split_merged_rows: merged columns = {merged_columns}")

        # Count how many actual data rows we should have
        # Use the column with most product numbers as reference
        expected_rows = self._count_expected_rows(merged_columns)
        logger.debug(f"split_merged_rows: expecting {expected_rows} data rows")

        if expected_rows <= 1:
            # Not enough data for splitting
            return [], rows

        # Split each column based on expected row count
        split_columns = []
        for col_idx, col_text in enumerate(merged_columns):
            if not col_text.strip():
                split_columns.append([""] * (expected_rows + 1))  # +1 for header
                continue
            lines = self._split_cell_content_for_rows(col_text, expected_rows)
            split_columns.append(lines)

        # Ensure all columns have same number of lines (immutable approach)
        max_lines = max(len(col) for col in split_columns)
        split_columns = [
            col + [""] * (max_lines - len(col))
            for col in split_columns
        ]

        logger.debug(f"split_merged_rows: split into {max_lines} lines total")

        # First line is header, rest are data rows
        header = [col[0] for col in split_columns]
        data_rows = []
        for line_idx in range(1, max_lines):
            row = [col[line_idx] if line_idx < len(col) else "" for col in split_columns]
            if any(cell.strip() for cell in row):
                data_rows.append(row)

        logger.debug(f"split_merged_rows: header={header}, data_rows count={len(data_rows)}")
        return header, data_rows

    def _count_expected_rows(self, merged_columns: list[str]) -> int:
        """
        Count how many data rows should exist based on content patterns.

        Returns the maximum count found from:
        - Product numbers (7 digits)
        - Quantity patterns (number + ST/PCS)
        - Amount patterns (in columns likely to be totals)
        """
        max_count = 0

        for col_text in merged_columns:
            if not col_text:
                continue

            # Count product numbers (most reliable indicator)
            product_nums = re.findall(r"\b\d{7}\b", col_text)
            max_count = max(max_count, len(product_nums))

            # Count quantities (e.g., "6ST 6ST 1ST 1ST 1ST")
            quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", col_text)
            max_count = max(max_count, len(quantities))

        return max_count

    def _split_cell_content_for_rows(self, cell: str, expected_rows: int) -> list[str]:
        """
        Split cell content knowing how many data rows we expect.

        This is smarter than split_cell_content because it knows the target count.
        Strategies are tried in order (product numbers, quantities,
        discount+total amounts, prices); each is only accepted when it yields
        exactly/at least `expected_rows` values. Returns [header] + values on
        success, or [cell] unchanged when no strategy fits.
        """
        cell = cell.strip()

        # Try product number split first
        product_pattern = re.compile(r"(\b\d{7}\b)")
        products = product_pattern.findall(cell)
        if len(products) == expected_rows:
            parts = product_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            # Include description text after each product number
            values = []
            for i in range(1, len(parts), 2):  # Odd indices are product numbers
                if i < len(parts):
                    prod_num = parts[i].strip()
                    # Check if there's description text after
                    desc = parts[i + 1].strip() if i + 1 < len(parts) else ""
                    # If description looks like text (not another pattern), include it
                    if desc and not re.match(r"^\d{7}$", desc):
                        # Truncate at next product number pattern if any
                        desc_clean = re.split(r"\d{7}", desc)[0].strip()
                        if desc_clean:
                            values.append(f"{prod_num} {desc_clean}")
                        else:
                            values.append(prod_num)
                    else:
                        values.append(prod_num)
            if len(values) == expected_rows:
                return [header] + values

        # Try quantity split
        qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
        quantities = qty_pattern.findall(cell)
        if len(quantities) == expected_rows:
            parts = qty_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
            if len(values) == expected_rows:
                return [header] + values

        # Try amount split for discount+totalsumma columns
        cell_lower = cell.lower()
        has_discount = any(kw in cell_lower for kw in ["rabatt", "discount"])
        has_total = any(kw in cell_lower for kw in ["totalsumma", "total", "summa", "belopp"])

        if has_discount and has_total:
            # Extract only amounts (3+ digit numbers), skip discount percentages
            amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
            amounts = amount_pattern.findall(cell)
            if len(amounts) >= expected_rows:
                # NOTE(review): the slice below takes the FIRST expected_rows
                # amounts, although the original comment claimed the "last"
                # ones are likely the totals — confirm which end is intended.
                return ["Totalsumma"] + amounts[:expected_rows]

        # Try price split
        price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
        prices = price_pattern.findall(cell)
        if len(prices) >= expected_rows:
            parts = price_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
            if len(values) >= expected_rows:
                return [header] + values[:expected_rows]

        # Fall back to original single-value behavior
        return [cell]

    def split_cell_content(self, cell: str) -> list[str]:
        """
        Split a cell containing merged multi-line content.

        Unlike _split_cell_content_for_rows, this variant does not know the
        expected row count; a strategy is accepted as soon as it finds 2+
        matches.

        Strategies:
        1. Look for product number patterns (7 digits)
        2. Look for quantity patterns (number + ST/PCS)
        3. Look for price patterns (with decimal)
        4. Handle interleaved discount+amount patterns

        Returns [header] + values on success, or [cell] when nothing matches.
        """
        cell = cell.strip()

        # Strategy 1: Split by product numbers (common pattern: "Produktnr 1234567 1234568")
        product_pattern = re.compile(r"(\b\d{7}\b)")
        products = product_pattern.findall(cell)
        if len(products) >= 2:
            # Extract header (text before first product number) and values
            parts = product_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            values = [p for p in parts[1:] if p.strip() and re.match(r"\d{7}", p)]
            return [header] + values

        # Strategy 2: Split by quantities (e.g., "Antal 6ST 6ST 1ST")
        qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
        quantities = qty_pattern.findall(cell)
        if len(quantities) >= 2:
            parts = qty_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
            return [header] + values

        # Strategy 3: Handle interleaved discount+amount (e.g., "Rabatt i% Totalsumma 10,0 686,88 10,0 686,88")
        # Check if header contains two keywords indicating merged columns
        cell_lower = cell.lower()
        has_discount_header = any(kw in cell_lower for kw in ["rabatt", "discount"])
        has_amount_header = any(kw in cell_lower for kw in ["totalsumma", "summa", "belopp", "total"])

        if has_discount_header and has_amount_header:
            # Extract all numbers and pair them (discount, amount, discount, amount, ...)
            # Pattern for amounts: 3+ digit numbers with decimals (e.g., 686,88)
            amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
            amounts = amount_pattern.findall(cell)

            if len(amounts) >= 2:
                # Return header as "Totalsumma" (amount header) so it maps to amount field, not deduction
                # This avoids the "Rabatt" keyword causing is_deduction=True
                header = "Totalsumma"
                return [header] + amounts

        # Strategy 4: Split by prices (e.g., "Pris 127,20 127,20 159,20")
        price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
        prices = price_pattern.findall(cell)
        if len(prices) >= 2:
            parts = price_pattern.split(cell)
            header = parts[0].strip() if parts else ""
            values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
            return [header] + values

        # No pattern detected, return as single value
        return [cell]

    def has_merged_header(self, header: list[str] | None) -> bool:
        """
        Check if header appears to be a merged cell containing multiple column names.

        This happens when OCR merges table headers into a single cell, e.g.:
        "Specifikation 0218103-1201 2 rum och kök Hyra Avdrag" instead of separate columns.

        Also handles cases where PP-StructureV3 produces headers like:
        ["Specifikation ... Hyra Avdrag", "", "", ""] with empty trailing cells.

        Returns True when exactly one non-empty header cell contains keywords
        from at least two different field types in the mapper's mappings.
        """
        if header is None or not header:
            return False

        # Filter out empty cells to find the actual content
        non_empty_cells = [h for h in header if h.strip()]

        # Check if we have a single non-empty cell that contains multiple keywords
        if len(non_empty_cells) == 1:
            header_text = non_empty_cells[0].lower()
            # Count how many column keywords are in this single cell
            keyword_count = 0
            for patterns in self.mapper.mappings.values():
                for pattern in patterns:
                    if pattern in header_text:
                        keyword_count += 1
                        break  # Only count once per field type

            logger.debug(f"has_merged_header: header_text='{header_text}', keyword_count={keyword_count}")
            return keyword_count >= 2

        return False

    def extract_from_merged_cells(
        self, header: list[str], rows: list[list[str]]
    ) -> list[LineItem]:
        """
        Extract line items from tables with merged cells.

        For poorly OCR'd tables like:
            Header: ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
            Row 1: ["", "", "", "8159"]      <- amount row
            Row 2: ["", "", "", "-2 000"]    <- deduction row (separate line item)

        Or:
            Row: ["", "", "", "8159 -2 000"] <- both in same row -> 2 line items

        Each amount becomes its own line item. Negative amounts are marked as is_deduction=True.
        Description/article number parsed from the header are attached only to
        the first item; items get a fixed confidence of 0.7.
        """
        items = []

        # Amount pattern for Swedish format - match numbers like "8159" or "8 159" or "-2000" or "-2 000"
        amount_pattern = re.compile(
            r"(-?\d[\d\s]*(?:[,\.]\d+)?)"
        )

        # Try to parse header cell for description info
        header_text = " ".join(h for h in header if h.strip()) if header else ""
        logger.debug(f"extract_from_merged_cells: header_text='{header_text}'")
        logger.debug(f"extract_from_merged_cells: rows={rows}")

        # Extract description from header
        description = None
        article_number = None

        # Look for object number pattern (e.g., "0218103-1201")
        obj_match = re.search(r"(\d{7}-\d{4})", header_text)
        if obj_match:
            article_number = obj_match.group(1)

        # Look for description after object number
        desc_match = re.search(r"\d{7}-\d{4}\s+(.+?)(?:\s+(?:Hyra|Avdrag|Belopp))", header_text, re.IGNORECASE)
        if desc_match:
            description = desc_match.group(1).strip()

        row_index = 0
        for row in rows:
            # Combine all non-empty cells in the row
            row_text = " ".join(cell.strip() for cell in row if cell.strip())
            logger.debug(f"extract_from_merged_cells: row text='{row_text}'")

            if not row_text:
                continue

            # Find all amounts in the row
            amounts = amount_pattern.findall(row_text)
            logger.debug(f"extract_from_merged_cells: amounts={amounts}")

            for amt_str in amounts:
                # Clean the amount string
                cleaned = amt_str.replace(" ", "").strip()
                if not cleaned or cleaned == "-":
                    continue

                is_deduction = cleaned.startswith("-")

                # Skip small positive numbers that are likely not amounts
                # (e.g., row indices, small percentages)
                # Note: negative amounts bypass this filter entirely.
                if not is_deduction:
                    try:
                        val = float(cleaned.replace(",", "."))
                        if val < MIN_AMOUNT_THRESHOLD:
                            continue
                    except ValueError:
                        continue

                # Create a line item for each amount
                item = LineItem(
                    row_index=row_index,
                    description=description if row_index == 0 else "Avdrag" if is_deduction else None,
                    article_number=article_number if row_index == 0 else None,
                    amount=cleaned,
                    is_deduction=is_deduction,
                    confidence=0.7,
                )
                items.append(item)
                # row_index counts emitted items, not source rows: several
                # amounts in one physical row become consecutive indices.
                row_index += 1
                logger.debug(f"extract_from_merged_cells: created item amount={cleaned}, is_deduction={is_deduction}")

        return items
|
||||
61
packages/backend/backend/table/models.py
Normal file
61
packages/backend/backend/table/models.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Line Items Data Models
|
||||
|
||||
Dataclasses for line item extraction results.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal, InvalidOperation
|
||||
|
||||
|
||||
@dataclass
class LineItem:
    """Single line item from invoice.

    Monetary and quantity fields are stored as raw strings exactly as
    extracted (e.g. Swedish "1 234,56"); numeric parsing is done by consumers.
    """

    row_index: int  # Position of this item within the extracted table
    description: str | None = None
    quantity: str | None = None
    unit: str | None = None
    unit_price: str | None = None
    amount: str | None = None
    article_number: str | None = None
    vat_rate: str | None = None
    is_deduction: bool = False  # True if this row is a deduction/discount
    confidence: float = 0.9  # Extraction confidence score
|
||||
|
||||
|
||||
@dataclass
class LineItemsResult:
    """Result of line items extraction."""

    items: list[LineItem]
    header_row: list[str]
    raw_html: str
    is_reversed: bool = False  # True when source rows were in reversed order

    @property
    def total_amount(self) -> str | None:
        """Calculate total amount from line items (deduction rows have negative amounts).

        Returns the sum formatted in Swedish style ("1 234,56"), or None when
        there are no items, no amount parsed, or the amounts sum to exactly
        zero (a zero total is indistinguishable from "nothing parseable").
        """
        if not self.items:
            return None

        total = Decimal("0")
        for item in self.items:
            if item.amount:
                try:
                    # Parse Swedish number format (1 234,56)
                    amount_str = item.amount.replace(" ", "").replace(",", ".")
                    total += Decimal(amount_str)
                except InvalidOperation:
                    # Unparseable amounts are skipped rather than failing the whole total.
                    pass

        if total == 0:
            return None

        # Format back to Swedish format: "," thousands -> space, "." decimal -> ","
        formatted = f"{total:,.2f}".replace(",", " ").replace(".", ",")
        # Fix the space/comma swap
        # NOTE(review): after the two replaces above the string is already in
        # Swedish form; this rsplit/replace looks like a no-op (the replace
        # swaps a space for a visually identical space) — confirm whether a
        # non-breaking space was intended for the thousands separator.
        parts = formatted.rsplit(",", 1)
        if len(parts) == 2:
            return parts[0].replace(" ", " ") + "," + parts[1]
        return formatted
|
||||
@@ -158,36 +158,36 @@ class TableDetector:
|
||||
return tables
|
||||
|
||||
# Log raw result type for debugging
|
||||
logger.info(f"PP-StructureV3 raw results type: {type(results).__name__}")
|
||||
logger.debug(f"PP-StructureV3 raw results type: {type(results).__name__}")
|
||||
|
||||
# Handle case where results is a single dict-like object (PaddleX 3.x)
|
||||
# rather than a list of results
|
||||
if hasattr(results, "get") and not isinstance(results, list):
|
||||
# Single result object - wrap in list for uniform processing
|
||||
logger.info("Results is dict-like, wrapping in list")
|
||||
logger.debug("Results is dict-like, wrapping in list")
|
||||
results = [results]
|
||||
elif hasattr(results, "__iter__") and not isinstance(results, (list, tuple)):
|
||||
# Iterator or generator - convert to list
|
||||
try:
|
||||
results = list(results)
|
||||
logger.info(f"Converted iterator to list with {len(results)} items")
|
||||
logger.debug(f"Converted iterator to list with {len(results)} items")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert results to list: {e}")
|
||||
return tables
|
||||
|
||||
logger.info(f"Processing {len(results)} result(s)")
|
||||
logger.debug(f"Processing {len(results)} result(s)")
|
||||
|
||||
for i, result in enumerate(results):
|
||||
try:
|
||||
result_type = type(result).__name__
|
||||
has_get = hasattr(result, "get")
|
||||
has_layout = hasattr(result, "layout_elements")
|
||||
logger.info(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
|
||||
logger.debug(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
|
||||
|
||||
# Try PaddleX 3.x API first (dict-like with table_res_list)
|
||||
if has_get:
|
||||
parsed = self._parse_paddlex_result(result)
|
||||
logger.info(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
|
||||
logger.debug(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
|
||||
tables.extend(parsed)
|
||||
continue
|
||||
|
||||
@@ -201,14 +201,14 @@ class TableDetector:
|
||||
if table_result and table_result.confidence >= self.config.min_confidence:
|
||||
tables.append(table_result)
|
||||
legacy_count += 1
|
||||
logger.info(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
|
||||
logger.debug(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
|
||||
else:
|
||||
logger.warning(f"Result[{i}]: no recognized API (not dict-like and no layout_elements)")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse result: {type(result).__name__}, error: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Total tables detected: {len(tables)}")
|
||||
logger.debug(f"Total tables detected: {len(tables)}")
|
||||
return tables
|
||||
|
||||
def _parse_paddlex_result(self, result: Any) -> list[TableDetectionResult]:
|
||||
@@ -223,7 +223,7 @@ class TableDetector:
|
||||
result_keys = list(result.keys())
|
||||
elif hasattr(result, "__dict__"):
|
||||
result_keys = list(result.__dict__.keys())
|
||||
logger.info(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
|
||||
logger.debug(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
|
||||
|
||||
# Get table results from PaddleX 3.x API
|
||||
# Handle both dict.get() and attribute access
|
||||
@@ -234,8 +234,8 @@ class TableDetector:
|
||||
table_res_list = getattr(result, "table_res_list", None)
|
||||
parsing_res_list = getattr(result, "parsing_res_list", [])
|
||||
|
||||
logger.info(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
|
||||
logger.info(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
|
||||
logger.debug(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
|
||||
logger.debug(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
|
||||
|
||||
if not table_res_list:
|
||||
# Log available keys/attributes for debugging
|
||||
@@ -330,7 +330,7 @@ class TableDetector:
|
||||
# Default confidence for PaddleX 3.x results
|
||||
confidence = 0.9
|
||||
|
||||
logger.info(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
|
||||
logger.debug(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
|
||||
tables.append(TableDetectionResult(
|
||||
bbox=(float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])),
|
||||
html=html,
|
||||
@@ -467,14 +467,14 @@ class TableDetector:
|
||||
if not pdf_path.exists():
|
||||
raise FileNotFoundError(f"PDF not found: {pdf_path}")
|
||||
|
||||
logger.info(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
|
||||
logger.debug(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
|
||||
|
||||
# Render specific page
|
||||
for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=dpi):
|
||||
if page_no == page_number:
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
logger.info(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
|
||||
logger.debug(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
|
||||
return self.detect(image_array)
|
||||
|
||||
raise ValueError(f"Page {page_number} not found in PDF")
|
||||
|
||||
@@ -15,6 +15,11 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration constants
|
||||
DEFAULT_ROW_TOLERANCE = 15.0 # Max vertical distance (pixels) to consider same row
|
||||
MIN_ITEMS_FOR_VALID_EXTRACTION = 2 # Minimum items required for valid extraction
|
||||
MIN_TEXT_ELEMENTS_FOR_EXTRACTION = 5 # Minimum text elements needed to attempt extraction
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextElement:
|
||||
@@ -65,7 +70,10 @@ class TextLineItemsResult:
|
||||
extraction_method: str = "text_spatial"
|
||||
|
||||
|
||||
# Swedish amount pattern: 1 234,56 or 1234.56 or 1,234.56
|
||||
# Amount pattern matches Swedish, US, and simple numeric formats
|
||||
# Handles: "1 234,56", "1,234.56", "1234.56", "100 kr", "50:-", "-100,00"
|
||||
# Does NOT handle: amounts with more than 2 decimal places, scientific notation
|
||||
# See tests in test_text_line_items_extractor.py::TestAmountPattern
|
||||
AMOUNT_PATTERN = re.compile(
|
||||
r"(?<![0-9])(?:"
|
||||
r"-?\d{1,3}(?:\s\d{3})*(?:,\d{2})?" # Swedish: 1 234,56
|
||||
@@ -128,17 +136,17 @@ class TextLineItemsExtractor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
row_tolerance: float = 15.0, # Max vertical distance to consider same row
|
||||
min_items_for_valid: int = 2, # Minimum items to consider extraction valid
|
||||
row_tolerance: float = DEFAULT_ROW_TOLERANCE,
|
||||
min_items_for_valid: int = MIN_ITEMS_FOR_VALID_EXTRACTION,
|
||||
):
|
||||
"""
|
||||
Initialize extractor.
|
||||
|
||||
Args:
|
||||
row_tolerance: Maximum vertical distance (pixels) between elements
|
||||
to consider them on the same row.
|
||||
to consider them on the same row. Default: 15.0
|
||||
min_items_for_valid: Minimum number of line items required for
|
||||
extraction to be considered successful.
|
||||
extraction to be considered successful. Default: 2
|
||||
"""
|
||||
self.row_tolerance = row_tolerance
|
||||
self.min_items_for_valid = min_items_for_valid
|
||||
@@ -161,10 +169,13 @@ class TextLineItemsExtractor:
|
||||
|
||||
# Extract text elements from parsing results
|
||||
text_elements = self._extract_text_elements(parsing_res_list)
|
||||
logger.info(f"TextLineItemsExtractor: found {len(text_elements)} text elements")
|
||||
logger.debug(f"TextLineItemsExtractor: found {len(text_elements)} text elements")
|
||||
|
||||
if len(text_elements) < 5: # Need at least a few elements
|
||||
logger.debug("Too few text elements for line item extraction")
|
||||
if len(text_elements) < MIN_TEXT_ELEMENTS_FOR_EXTRACTION:
|
||||
logger.debug(
|
||||
f"Too few text elements ({len(text_elements)}) for line item extraction, "
|
||||
f"need at least {MIN_TEXT_ELEMENTS_FOR_EXTRACTION}"
|
||||
)
|
||||
return None
|
||||
|
||||
return self.extract_from_text_elements(text_elements)
|
||||
@@ -183,11 +194,11 @@ class TextLineItemsExtractor:
|
||||
"""
|
||||
# Group elements by row
|
||||
rows = self._group_by_row(text_elements)
|
||||
logger.info(f"TextLineItemsExtractor: grouped into {len(rows)} rows")
|
||||
logger.debug(f"TextLineItemsExtractor: grouped into {len(rows)} rows")
|
||||
|
||||
# Find the line items section
|
||||
item_rows = self._identify_line_item_rows(rows)
|
||||
logger.info(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")
|
||||
logger.debug(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")
|
||||
|
||||
if len(item_rows) < self.min_items_for_valid:
|
||||
logger.debug(f"Found only {len(item_rows)} item rows, need at least {self.min_items_for_valid}")
|
||||
@@ -195,7 +206,7 @@ class TextLineItemsExtractor:
|
||||
|
||||
# Extract structured items
|
||||
items = self._parse_line_items(item_rows)
|
||||
logger.info(f"TextLineItemsExtractor: extracted {len(items)} line items")
|
||||
logger.debug(f"TextLineItemsExtractor: extracted {len(items)} line items")
|
||||
|
||||
if len(items) < self.min_items_for_valid:
|
||||
return None
|
||||
@@ -209,7 +220,11 @@ class TextLineItemsExtractor:
|
||||
def _extract_text_elements(
|
||||
self, parsing_res_list: list[dict[str, Any]]
|
||||
) -> list[TextElement]:
|
||||
"""Extract TextElement objects from parsing_res_list."""
|
||||
"""Extract TextElement objects from parsing_res_list.
|
||||
|
||||
Handles both dict and LayoutBlock object formats from PP-StructureV3.
|
||||
Gracefully skips invalid elements with appropriate logging.
|
||||
"""
|
||||
elements = []
|
||||
|
||||
for elem in parsing_res_list:
|
||||
@@ -220,11 +235,15 @@ class TextLineItemsExtractor:
|
||||
bbox = elem.get("bbox", [])
|
||||
# Try both 'text' and 'content' keys
|
||||
text = elem.get("text", "") or elem.get("content", "")
|
||||
else:
|
||||
elif hasattr(elem, "label"):
|
||||
label = getattr(elem, "label", "")
|
||||
bbox = getattr(elem, "bbox", [])
|
||||
# LayoutBlock objects use 'content' attribute
|
||||
text = getattr(elem, "content", "") or getattr(elem, "text", "")
|
||||
else:
|
||||
# Element is neither dict nor has expected attributes
|
||||
logger.debug(f"Skipping element with unexpected type: {type(elem).__name__}")
|
||||
continue
|
||||
|
||||
# Only process text elements (skip images, tables, etc.)
|
||||
if label not in ("text", "paragraph_title", "aside_text"):
|
||||
@@ -232,6 +251,7 @@ class TextLineItemsExtractor:
|
||||
|
||||
# Validate bbox
|
||||
if not self._valid_bbox(bbox):
|
||||
logger.debug(f"Skipping element with invalid bbox: {bbox}")
|
||||
continue
|
||||
|
||||
# Clean text
|
||||
@@ -250,8 +270,13 @@ class TextLineItemsExtractor:
|
||||
),
|
||||
)
|
||||
)
|
||||
except (KeyError, TypeError, ValueError, AttributeError) as e:
|
||||
# Expected format issues - log at debug level
|
||||
logger.debug(f"Skipping element due to format issue: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse element: {e}")
|
||||
# Unexpected errors - log at warning level for visibility
|
||||
logger.warning(f"Unexpected error parsing element: {type(e).__name__}: {e}")
|
||||
continue
|
||||
|
||||
return elements
|
||||
@@ -270,6 +295,7 @@ class TextLineItemsExtractor:
|
||||
Group text elements into rows based on vertical position.
|
||||
|
||||
Elements within row_tolerance of each other are considered same row.
|
||||
Uses dynamic average center_y to handle varying element heights more accurately.
|
||||
"""
|
||||
if not elements:
|
||||
return []
|
||||
@@ -277,22 +303,22 @@ class TextLineItemsExtractor:
|
||||
# Sort by vertical position
|
||||
sorted_elements = sorted(elements, key=lambda e: e.center_y)
|
||||
|
||||
rows = []
|
||||
current_row = [sorted_elements[0]]
|
||||
current_y = sorted_elements[0].center_y
|
||||
rows: list[list[TextElement]] = []
|
||||
current_row: list[TextElement] = [sorted_elements[0]]
|
||||
|
||||
for elem in sorted_elements[1:]:
|
||||
if abs(elem.center_y - current_y) <= self.row_tolerance:
|
||||
# Same row
|
||||
# Calculate dynamic average center_y for current row
|
||||
avg_center_y = sum(e.center_y for e in current_row) / len(current_row)
|
||||
|
||||
if abs(elem.center_y - avg_center_y) <= self.row_tolerance:
|
||||
# Same row - add element and recalculate average on next iteration
|
||||
current_row.append(elem)
|
||||
else:
|
||||
# New row
|
||||
if current_row:
|
||||
# Sort row by horizontal position
|
||||
current_row.sort(key=lambda e: e.center_x)
|
||||
rows.append(current_row)
|
||||
# New row - finalize current row
|
||||
# Sort row by horizontal position (left to right)
|
||||
current_row.sort(key=lambda e: e.center_x)
|
||||
rows.append(current_row)
|
||||
current_row = [elem]
|
||||
current_y = elem.center_y
|
||||
|
||||
# Don't forget last row
|
||||
if current_row:
|
||||
|
||||
37
packages/shared/shared/bbox/__init__.py
Normal file
37
packages/shared/shared/bbox/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
BBox Scale Strategy Module.
|
||||
|
||||
Provides field-specific bounding box expansion strategies for YOLO training data.
|
||||
Expands bboxes using center-point scaling with directional compensation to capture
|
||||
field labels that typically appear above or to the left of field values.
|
||||
|
||||
Two modes are supported:
|
||||
- Auto-label: Field-specific scale strategies with directional compensation
|
||||
- Manual-label: Minimal padding only to prevent edge clipping
|
||||
|
||||
Usage:
|
||||
from shared.bbox import expand_bbox, ScaleStrategy, FIELD_SCALE_STRATEGIES
|
||||
|
||||
Available exports:
|
||||
- ScaleStrategy: Dataclass for scale strategy configuration
|
||||
- DEFAULT_STRATEGY: Default strategy for unknown fields (auto-label)
|
||||
- MANUAL_LABEL_STRATEGY: Minimal padding strategy for manual labels
|
||||
- FIELD_SCALE_STRATEGIES: dict[str, ScaleStrategy] - field-specific strategies
|
||||
- expand_bbox: Function to expand bbox using field-specific strategy
|
||||
"""
|
||||
|
||||
from .scale_strategy import (
|
||||
ScaleStrategy,
|
||||
DEFAULT_STRATEGY,
|
||||
MANUAL_LABEL_STRATEGY,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
)
|
||||
from .expander import expand_bbox
|
||||
|
||||
__all__ = [
|
||||
"ScaleStrategy",
|
||||
"DEFAULT_STRATEGY",
|
||||
"MANUAL_LABEL_STRATEGY",
|
||||
"FIELD_SCALE_STRATEGIES",
|
||||
"expand_bbox",
|
||||
]
|
||||
101
packages/shared/shared/bbox/expander.py
Normal file
101
packages/shared/shared/bbox/expander.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
BBox Expander Module.
|
||||
|
||||
Provides functions to expand bounding boxes using field-specific strategies.
|
||||
Expansion is center-point based with directional compensation.
|
||||
|
||||
Two modes:
|
||||
- Auto-label (default): Field-specific scale strategies
|
||||
- Manual-label: Minimal padding only to prevent edge clipping
|
||||
"""
|
||||
|
||||
from .scale_strategy import (
|
||||
ScaleStrategy,
|
||||
DEFAULT_STRATEGY,
|
||||
MANUAL_LABEL_STRATEGY,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
)
|
||||
|
||||
|
||||
def expand_bbox(
    bbox: tuple[float, float, float, float],
    image_width: float,
    image_height: float,
    field_type: str,
    strategies: dict[str, ScaleStrategy] | None = None,
    manual_mode: bool = False,
) -> tuple[int, int, int, int]:
    """
    Expand a bounding box according to the field's scale strategy.

    The expansion pipeline:
    1. Scale the bbox around its center point (scale_x, scale_y)
    2. Push individual edges outward by the directional extra ratios
    3. Re-express the growth as per-side padding and cap it at max_pad
    4. Clamp the resulting box to the image boundaries

    Args:
        bbox: (x0, y0, x1, y1) in pixels
        image_width: Image width for boundary clamping
        image_height: Image height for boundary clamping
        field_type: Field class_name (e.g., "ocr_number")
        strategies: Custom strategies dict, defaults to FIELD_SCALE_STRATEGIES
        manual_mode: If True, use MANUAL_LABEL_STRATEGY (minimal padding only)

    Returns:
        Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds
    """
    x0, y0, x1, y1 = bbox
    width = x1 - x0
    height = y1 - y0

    # Resolve which strategy applies. Manual mode overrides everything;
    # otherwise look up the field in the caller-supplied (or default) table.
    if manual_mode:
        strategy = MANUAL_LABEL_STRATEGY
    else:
        table = FIELD_SCALE_STRATEGIES if strategies is None else strategies
        strategy = table.get(field_type, DEFAULT_STRATEGY)

    # Steps 1 + 2: scale around the center, then apply directional extras.
    center_x = (x0 + x1) / 2
    center_y = (y0 + y1) / 2
    half_w = width * strategy.scale_x / 2
    half_h = height * strategy.scale_y / 2

    left_edge = center_x - half_w - width * strategy.extra_left_ratio
    right_edge = center_x + half_w + width * strategy.extra_right_ratio
    top_edge = center_y - half_h - height * strategy.extra_top_ratio
    bottom_edge = center_y + half_h + height * strategy.extra_bottom_ratio

    # Step 3: convert the expansion into per-side padding and clamp each
    # side independently to [0, max_pad] — asymmetric growth is preserved,
    # and a contracting strategy (scale < 1.0) can never shrink the bbox.
    pad_left = max(0, min(x0 - left_edge, strategy.max_pad_x))
    pad_right = max(0, min(right_edge - x1, strategy.max_pad_x))
    pad_top = max(0, min(y0 - top_edge, strategy.max_pad_y))
    pad_bottom = max(0, min(bottom_edge - y1, strategy.max_pad_y))

    # Step 4: rebuild from the original edges and clamp to the image.
    out_x0 = max(0, int(x0 - pad_left))
    out_y0 = max(0, int(y0 - pad_top))
    out_x1 = min(int(image_width), int(x1 + pad_right))
    out_y1 = min(int(image_height), int(y1 + pad_bottom))

    return (out_x0, out_y0, out_x1, out_y1)
|
||||
140
packages/shared/shared/bbox/scale_strategy.py
Normal file
140
packages/shared/shared/bbox/scale_strategy.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Scale Strategy Configuration.
|
||||
|
||||
Defines field-specific bbox expansion strategies for YOLO training data.
|
||||
Each strategy controls how bboxes are expanded around field values to
|
||||
capture contextual information like labels.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScaleStrategy:
|
||||
"""Immutable scale strategy for bbox expansion.
|
||||
|
||||
Attributes:
|
||||
scale_x: Horizontal scale factor (1.0 = no scaling)
|
||||
scale_y: Vertical scale factor (1.0 = no scaling)
|
||||
extra_top_ratio: Additional expansion ratio towards top (for labels above)
|
||||
extra_bottom_ratio: Additional expansion ratio towards bottom
|
||||
extra_left_ratio: Additional expansion ratio towards left (for prefixes)
|
||||
extra_right_ratio: Additional expansion ratio towards right (for suffixes)
|
||||
max_pad_x: Maximum horizontal padding in pixels
|
||||
max_pad_y: Maximum vertical padding in pixels
|
||||
"""
|
||||
|
||||
scale_x: float = 1.15
|
||||
scale_y: float = 1.15
|
||||
extra_top_ratio: float = 0.0
|
||||
extra_bottom_ratio: float = 0.0
|
||||
extra_left_ratio: float = 0.0
|
||||
extra_right_ratio: float = 0.0
|
||||
max_pad_x: int = 50
|
||||
max_pad_y: int = 50
|
||||
|
||||
|
||||
# Default strategy for unknown fields (auto-label mode).
# Mild uniform 15% expansion with moderate padding caps — a safe middle
# ground when no field-specific strategy is registered.
DEFAULT_STRATEGY: Final[ScaleStrategy] = ScaleStrategy()

# Manual label strategy - minimal padding to prevent edge clipping.
# No scaling, no directional compensation, just small uniform padding:
# manually drawn boxes are presumed already correct, so only a tiny
# margin is added to avoid clipping glyph edges at render time.
MANUAL_LABEL_STRATEGY: Final[ScaleStrategy] = ScaleStrategy(
    scale_x=1.0,
    scale_y=1.0,
    extra_top_ratio=0.0,
    extra_bottom_ratio=0.0,
    extra_left_ratio=0.0,
    extra_right_ratio=0.0,
    max_pad_x=10,  # Small padding to prevent edge loss
    max_pad_y=10,
)
|
||||
|
||||
|
||||
# Field-specific strategies based on Swedish invoice field characteristics.
# Field labels typically appear above or to the left of values.
# NOTE(review): the scale/pad values below look empirically tuned to a
# specific render resolution — confirm against the training pipeline's DPI
# before reusing elsewhere.
FIELD_SCALE_STRATEGIES: Final[dict[str, ScaleStrategy]] = {
    # OCR number - label "OCR" or "Referens" typically above
    "ocr_number": ScaleStrategy(
        scale_x=1.15,
        scale_y=1.80,
        extra_top_ratio=0.60,
        max_pad_x=50,
        max_pad_y=140,
    ),
    # Bankgiro - prefix "Bankgiro:" or "BG:" typically to the left
    "bankgiro": ScaleStrategy(
        scale_x=1.45,
        scale_y=1.35,
        extra_left_ratio=0.80,
        max_pad_x=160,
        max_pad_y=90,
    ),
    # Plusgiro - prefix "Plusgiro:" or "PG:" typically to the left
    # (intentionally identical to "bankgiro": same visual layout)
    "plusgiro": ScaleStrategy(
        scale_x=1.45,
        scale_y=1.35,
        extra_left_ratio=0.80,
        max_pad_x=160,
        max_pad_y=90,
    ),
    # Invoice date - label "Fakturadatum" typically above
    "invoice_date": ScaleStrategy(
        scale_x=1.25,
        scale_y=1.55,
        extra_top_ratio=0.40,
        max_pad_x=80,
        max_pad_y=110,
    ),
    # Due date - label "Forfalldatum" typically above, sometimes left
    "invoice_due_date": ScaleStrategy(
        scale_x=1.30,
        scale_y=1.65,
        extra_top_ratio=0.45,
        extra_left_ratio=0.35,
        max_pad_x=100,
        max_pad_y=120,
    ),
    # Amount - currency symbol "SEK" or "kr" may be to the right
    "amount": ScaleStrategy(
        scale_x=1.20,
        scale_y=1.35,
        extra_right_ratio=0.30,
        max_pad_x=70,
        max_pad_y=80,
    ),
    # Invoice number - label "Fakturanummer" typically above
    "invoice_number": ScaleStrategy(
        scale_x=1.20,
        scale_y=1.50,
        extra_top_ratio=0.40,
        max_pad_x=80,
        max_pad_y=100,
    ),
    # Supplier org number - label "Org.nr" typically above or left
    "supplier_org_number": ScaleStrategy(
        scale_x=1.25,
        scale_y=1.40,
        extra_top_ratio=0.30,
        extra_left_ratio=0.20,
        max_pad_x=90,
        max_pad_y=90,
    ),
    # Customer number - label "Kundnummer" typically above or left
    "customer_number": ScaleStrategy(
        scale_x=1.25,
        scale_y=1.45,
        extra_top_ratio=0.35,
        extra_left_ratio=0.25,
        max_pad_x=90,
        max_pad_y=100,
    ),
    # Payment line - machine-readable code, minimal expansion needed
    "payment_line": ScaleStrategy(
        scale_x=1.10,
        scale_y=1.20,
        max_pad_x=40,
        max_pad_y=40,
    ),
}
|
||||
@@ -16,6 +16,7 @@ Available exports:
|
||||
- FIELD_CLASSES: dict[int, str] - class_id to class_name
|
||||
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
|
||||
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
|
||||
- FIELD_TO_CLASS: dict[str, str] - field_name to class_name
|
||||
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
|
||||
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
|
||||
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
|
||||
@@ -27,6 +28,7 @@ from .mappings import (
|
||||
FIELD_CLASSES,
|
||||
FIELD_CLASS_IDS,
|
||||
CLASS_TO_FIELD,
|
||||
FIELD_TO_CLASS,
|
||||
CSV_TO_CLASS_MAPPING,
|
||||
TRAINING_FIELD_CLASSES,
|
||||
ACCOUNT_FIELD_MAPPING,
|
||||
@@ -40,6 +42,7 @@ __all__ = [
|
||||
"FIELD_CLASSES",
|
||||
"FIELD_CLASS_IDS",
|
||||
"CLASS_TO_FIELD",
|
||||
"FIELD_TO_CLASS",
|
||||
"CSV_TO_CLASS_MAPPING",
|
||||
"TRAINING_FIELD_CLASSES",
|
||||
"ACCOUNT_FIELD_MAPPING",
|
||||
|
||||
@@ -47,6 +47,12 @@ TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
|
||||
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
|
||||
}
|
||||
|
||||
# field_name -> class_name mapping (reverse of CLASS_TO_FIELD)
|
||||
# Example: {"InvoiceNumber": "invoice_number", "OCR": "ocr_number", ...}
|
||||
FIELD_TO_CLASS: Final[dict[str, str]] = {
|
||||
fd.field_name: fd.class_name for fd in FIELD_DEFINITIONS
|
||||
}
|
||||
|
||||
# Account field mapping for supplier_accounts special handling
|
||||
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
|
||||
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
YOLO Annotation Generator
|
||||
|
||||
Generates YOLO format annotations from matched fields.
|
||||
Uses field-specific bbox expansion strategies for optimal training data.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
@@ -14,7 +15,9 @@ from shared.fields import (
|
||||
TRAINING_FIELD_CLASSES as FIELD_CLASSES,
|
||||
CLASS_NAMES,
|
||||
ACCOUNT_FIELD_MAPPING,
|
||||
FIELD_TO_CLASS,
|
||||
)
|
||||
from shared.bbox import expand_bbox
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,19 +41,16 @@ class AnnotationGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
min_confidence: float = 0.7,
|
||||
bbox_padding_px: int = 20, # Absolute padding in pixels
|
||||
min_bbox_height_px: int = 30 # Minimum bbox height
|
||||
min_bbox_height_px: int = 30, # Minimum bbox height
|
||||
):
|
||||
"""
|
||||
Initialize annotation generator.
|
||||
|
||||
Args:
|
||||
min_confidence: Minimum match score to include in training
|
||||
bbox_padding_px: Absolute padding in pixels to add around bboxes
|
||||
min_bbox_height_px: Minimum bbox height in pixels
|
||||
"""
|
||||
self.min_confidence = min_confidence
|
||||
self.bbox_padding_px = bbox_padding_px
|
||||
self.min_bbox_height_px = min_bbox_height_px
|
||||
|
||||
def generate_from_matches(
|
||||
@@ -63,6 +63,10 @@ class AnnotationGenerator:
|
||||
"""
|
||||
Generate YOLO annotations from field matches.
|
||||
|
||||
Uses field-specific bbox expansion strategies for optimal training data.
|
||||
Each field type has customized scale factors and directional compensation
|
||||
to capture field labels and context.
|
||||
|
||||
Args:
|
||||
matches: Dict of field_name -> list of Match objects
|
||||
image_width: Width of the rendered image in pixels
|
||||
@@ -82,6 +86,8 @@ class AnnotationGenerator:
|
||||
continue
|
||||
|
||||
class_id = FIELD_CLASSES[field_name]
|
||||
# Get class_name for bbox expansion strategy
|
||||
class_name = FIELD_TO_CLASS.get(field_name, field_name)
|
||||
|
||||
# Take only the best match per field
|
||||
if field_matches:
|
||||
@@ -94,19 +100,20 @@ class AnnotationGenerator:
|
||||
x0, y0, x1, y1 = best_match.bbox
|
||||
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
||||
|
||||
# Add absolute padding
|
||||
pad = self.bbox_padding_px
|
||||
x0 = max(0, x0 - pad)
|
||||
y0 = max(0, y0 - pad)
|
||||
x1 = min(image_width, x1 + pad)
|
||||
y1 = min(image_height, y1 + pad)
|
||||
# Apply field-specific bbox expansion strategy
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Ensure minimum height
|
||||
current_height = y1 - y0
|
||||
if current_height < self.min_bbox_height_px:
|
||||
extra = (self.min_bbox_height_px - current_height) / 2
|
||||
y0 = max(0, y0 - extra)
|
||||
y1 = min(image_height, y1 + extra)
|
||||
y0 = max(0, int(y0 - extra))
|
||||
y1 = min(int(image_height), int(y1 + extra))
|
||||
|
||||
# Convert to YOLO format (normalized center + size)
|
||||
x_center = (x0 + x1) / 2 / image_width
|
||||
@@ -143,6 +150,9 @@ class AnnotationGenerator:
|
||||
"""
|
||||
Add payment_line annotation from machine code parser result.
|
||||
|
||||
Uses "payment_line" scale strategy for minimal expansion
|
||||
(machine-readable code needs less context).
|
||||
|
||||
Args:
|
||||
annotations: Existing list of annotations to append to
|
||||
payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
|
||||
@@ -163,12 +173,13 @@ class AnnotationGenerator:
|
||||
x0, y0, x1, y1 = payment_line_bbox
|
||||
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
||||
|
||||
# Add absolute padding
|
||||
pad = self.bbox_padding_px
|
||||
x0 = max(0, x0 - pad)
|
||||
y0 = max(0, y0 - pad)
|
||||
x1 = min(image_width, x1 + pad)
|
||||
y1 = min(image_height, y1 + pad)
|
||||
# Apply field-specific bbox expansion strategy for payment_line
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
field_type="payment_line",
|
||||
)
|
||||
|
||||
# Convert to YOLO format (normalized center + size)
|
||||
x_center = (x0 + x1) / 2 / image_width
|
||||
|
||||
1
tests/shared/bbox/__init__.py
Normal file
1
tests/shared/bbox/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for shared.bbox module."""
|
||||
556
tests/shared/bbox/test_expander.py
Normal file
556
tests/shared/bbox/test_expander.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
Tests for expand_bbox function.
|
||||
|
||||
Tests verify that bbox expansion works correctly with center-point scaling,
|
||||
directional compensation, max padding clamping, and image boundary handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.bbox import (
|
||||
expand_bbox,
|
||||
ScaleStrategy,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
DEFAULT_STRATEGY,
|
||||
)
|
||||
|
||||
|
||||
class TestExpandBboxCenterScaling:
    """Tests for center-point based scaling.

    Uses large max_pad values and a custom strategies dict so only the
    scaling step of expand_bbox is exercised.
    """

    def test_center_scaling_expands_symmetrically(self):
        """Verify bbox expands symmetrically around center when no extra ratios."""
        # 100x50 bbox at (100, 200)
        bbox = (100, 200, 200, 250)
        strategy = ScaleStrategy(
            scale_x=1.2,  # 20% wider
            scale_y=1.4,  # 40% taller
            max_pad_x=1000,  # Large to avoid clamping
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        # Original: width=100, height=50
        # New: width=120, height=70
        # Center: (150, 225)
        # Expected: x0=150-60=90, x1=150+60=210, y0=225-35=190, y1=225+35=260
        assert result[0] == 90  # x0
        assert result[1] == 190  # y0
        assert result[2] == 210  # x1
        assert result[3] == 260  # y1

    def test_no_scaling_returns_original(self):
        """Verify scale=1.0 with no extras returns original bbox (identity)."""
        bbox = (100, 200, 200, 250)
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        assert result == (100, 200, 200, 250)
|
||||
|
||||
|
||||
class TestExpandBboxDirectionalCompensation:
    """Tests for directional compensation (extra ratios).

    Each test isolates one extra_*_ratio with scale=1.0 so the expected
    offset is simply ratio * width (or ratio * height).
    """

    def test_extra_top_expands_upward(self):
        """Verify extra_top_ratio adds expansion toward top."""
        bbox = (100, 200, 200, 250)  # width=100, height=50
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            extra_top_ratio=0.5,  # Add 50% of height to top
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        # extra_top = 50 * 0.5 = 25
        assert result[0] == 100  # x0 unchanged
        assert result[1] == 175  # y0 = 200 - 25
        assert result[2] == 200  # x1 unchanged
        assert result[3] == 250  # y1 unchanged

    def test_extra_left_expands_leftward(self):
        """Verify extra_left_ratio adds expansion toward left."""
        bbox = (100, 200, 200, 250)  # width=100
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            extra_left_ratio=0.8,  # Add 80% of width to left
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        # extra_left = 100 * 0.8 = 80
        assert result[0] == 20  # x0 = 100 - 80
        assert result[1] == 200  # y0 unchanged
        assert result[2] == 200  # x1 unchanged
        assert result[3] == 250  # y1 unchanged

    def test_extra_right_expands_rightward(self):
        """Verify extra_right_ratio adds expansion toward right."""
        bbox = (100, 200, 200, 250)  # width=100
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            extra_right_ratio=0.3,  # Add 30% of width to right
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        # extra_right = 100 * 0.3 = 30
        assert result[0] == 100  # x0 unchanged
        assert result[1] == 200  # y0 unchanged
        assert result[2] == 230  # x1 = 200 + 30
        assert result[3] == 250  # y1 unchanged

    def test_extra_bottom_expands_downward(self):
        """Verify extra_bottom_ratio adds expansion toward bottom."""
        bbox = (100, 200, 200, 250)  # height=50
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            extra_bottom_ratio=0.4,  # Add 40% of height to bottom
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        # extra_bottom = 50 * 0.4 = 20
        assert result[0] == 100  # x0 unchanged
        assert result[1] == 200  # y0 unchanged
        assert result[2] == 200  # x1 unchanged
        assert result[3] == 270  # y1 = 250 + 20

    def test_combined_scaling_and_directional(self):
        """Verify scale + directional compensation work together (additive)."""
        bbox = (100, 200, 200, 250)  # width=100, height=50
        strategy = ScaleStrategy(
            scale_x=1.2,  # 20% wider -> 120 width
            scale_y=1.0,  # no height change
            extra_left_ratio=0.5,  # Add 50% of width to left
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        # Center: x=150
        # After scale: width=120 -> x0=150-60=90, x1=150+60=210
        # After extra_left: x0 = 90 - (100 * 0.5) = 40
        assert result[0] == 40  # x0
        assert result[2] == 210  # x1
|
||||
|
||||
|
||||
class TestExpandBboxMaxPadClamping:
    """Tests for max padding clamping.

    Strategies here request more growth than max_pad allows, so the
    per-side padding cap is the behavior under test.
    """

    def test_max_pad_x_limits_horizontal_expansion(self):
        """Verify max_pad_x limits expansion on left and right."""
        bbox = (100, 200, 200, 250)  # width=100
        strategy = ScaleStrategy(
            scale_x=2.0,  # Double width (would add 50 each side)
            scale_y=1.0,
            max_pad_x=30,  # Limit to 30 pixels each side
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        # Scale would make: x0=100, x1=200 -> x0=50, x1=250 (50px each side)
        # But max_pad_x=30 limits to: x0=70, x1=230
        assert result[0] == 70  # x0 = 100 - 30
        assert result[2] == 230  # x1 = 200 + 30

    def test_max_pad_y_limits_vertical_expansion(self):
        """Verify max_pad_y limits expansion on top and bottom."""
        bbox = (100, 200, 200, 250)  # height=50
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=3.0,  # Triple height (would add 50 each side)
            max_pad_x=1000,
            max_pad_y=20,  # Limit to 20 pixels each side
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        # Scale would make: y0=175, y1=275 (50px each side)
        # But max_pad_y=20 limits to: y0=180, y1=270
        assert result[1] == 180  # y0 = 200 - 20
        assert result[3] == 270  # y1 = 250 + 20

    def test_max_pad_preserves_asymmetry(self):
        """Verify max_pad clamping preserves asymmetric expansion.

        Each side is clamped independently: a capped left expansion must
        not leak any growth to the right side.
        """
        bbox = (100, 200, 200, 250)  # width=100
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            extra_left_ratio=1.0,  # 100px left expansion
            extra_right_ratio=0.0,  # No right expansion
            max_pad_x=50,  # Limit to 50 pixels
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        # Left would expand 100, clamped to 50
        # Right stays at 0
        assert result[0] == 50  # x0 = 100 - 50
        assert result[2] == 200  # x1 unchanged
|
||||
|
||||
|
||||
class TestExpandBboxImageBoundaryClamping:
    """Tests for image boundary clamping.

    Each test places the bbox near one image edge and requests an
    expansion that would cross it; the result must stay inside the image.
    """

    def test_clamps_to_left_boundary(self):
        """Verify x0 is clamped to 0."""
        bbox = (10, 200, 110, 250)  # Close to left edge
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            extra_left_ratio=0.5,  # Would push x0 below 0
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        assert result[0] == 0  # Clamped to 0

    def test_clamps_to_top_boundary(self):
        """Verify y0 is clamped to 0."""
        bbox = (100, 10, 200, 60)  # Close to top edge
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            extra_top_ratio=0.5,  # Would push y0 below 0
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        assert result[1] == 0  # Clamped to 0

    def test_clamps_to_right_boundary(self):
        """Verify x1 is clamped to image_width."""
        bbox = (900, 200, 990, 250)  # Close to right edge
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            extra_right_ratio=0.5,  # Would push x1 beyond image_width
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        assert result[2] == 1000  # Clamped to image_width

    def test_clamps_to_bottom_boundary(self):
        """Verify y1 is clamped to image_height."""
        bbox = (100, 940, 200, 990)  # Close to bottom edge
        strategy = ScaleStrategy(
            scale_x=1.0,
            scale_y=1.0,
            extra_bottom_ratio=0.5,  # Would push y1 beyond image_height
            max_pad_x=1000,
            max_pad_y=1000,
        )

        result = expand_bbox(
            bbox=bbox,
            image_width=1000,
            image_height=1000,
            field_type="test_field",
            strategies={"test_field": strategy},
        )

        assert result[3] == 1000  # Clamped to image_height
|
||||
|
||||
|
||||
class TestExpandBboxUnknownField:
|
||||
"""Tests for unknown field handling."""
|
||||
|
||||
def test_unknown_field_uses_default_strategy(self):
|
||||
"""Verify unknown field types use DEFAULT_STRATEGY."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="unknown_field_xyz",
|
||||
)
|
||||
|
||||
# DEFAULT_STRATEGY: scale_x=1.15, scale_y=1.15
|
||||
# Original: width=100, height=50
|
||||
# New: width=115, height=57.5
|
||||
# Center: (150, 225)
|
||||
# x0 = 150 - 57.5 = 92.5 -> 92
|
||||
# x1 = 150 + 57.5 = 207.5 -> 207
|
||||
# y0 = 225 - 28.75 = 196.25 -> 196
|
||||
# y1 = 225 + 28.75 = 253.75 -> 253
|
||||
# But max_pad_x=50 may clamp...
|
||||
# Left pad = 100 - 92.5 = 7.5 (< 50, ok)
|
||||
# Right pad = 207.5 - 200 = 7.5 (< 50, ok)
|
||||
assert result[0] == 92
|
||||
assert result[2] == 207
|
||||
|
||||
|
||||
class TestExpandBboxWithRealStrategies:
|
||||
"""Tests using actual FIELD_SCALE_STRATEGIES."""
|
||||
|
||||
def test_ocr_number_expands_significantly_upward(self):
|
||||
"""Verify ocr_number field gets significant upward expansion."""
|
||||
bbox = (100, 200, 200, 230) # Small height=30
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="ocr_number",
|
||||
)
|
||||
|
||||
# extra_top_ratio=0.60 -> 30 * 0.6 = 18 extra top
|
||||
# y0 should decrease significantly
|
||||
assert result[1] < 200 - 10 # At least 10px upward expansion
|
||||
|
||||
def test_bankgiro_expands_significantly_leftward(self):
|
||||
"""Verify bankgiro field gets significant leftward expansion."""
|
||||
bbox = (200, 200, 300, 230) # width=100
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro",
|
||||
)
|
||||
|
||||
# extra_left_ratio=0.80 -> 100 * 0.8 = 80 extra left
|
||||
# x0 should decrease significantly
|
||||
assert result[0] < 200 - 30 # At least 30px leftward expansion
|
||||
|
||||
def test_amount_expands_rightward(self):
|
||||
"""Verify amount field gets rightward expansion for currency."""
|
||||
bbox = (100, 200, 200, 230) # width=100
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="amount",
|
||||
)
|
||||
|
||||
# extra_right_ratio=0.30 -> 100 * 0.3 = 30 extra right
|
||||
# x1 should increase
|
||||
assert result[2] > 200 + 10 # At least 10px rightward expansion
|
||||
|
||||
|
||||
class TestExpandBboxReturnType:
|
||||
"""Tests for return type and value format."""
|
||||
|
||||
def test_returns_tuple_of_four_ints(self):
|
||||
"""Verify return type is tuple of 4 integers."""
|
||||
bbox = (100.5, 200.3, 200.7, 250.9)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 4
|
||||
assert all(isinstance(v, int) for v in result)
|
||||
|
||||
def test_returns_valid_bbox_format(self):
|
||||
"""Verify returned bbox has x0 < x1 and y0 < y1."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
x0, y0, x1, y1 = result
|
||||
assert x0 < x1, "x0 should be less than x1"
|
||||
assert y0 < y1, "y0 should be less than y1"
|
||||
|
||||
|
||||
class TestManualLabelMode:
|
||||
"""Tests for manual_mode parameter."""
|
||||
|
||||
def test_manual_mode_uses_minimal_padding(self):
|
||||
"""Verify manual_mode uses MANUAL_LABEL_STRATEGY with minimal padding."""
|
||||
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro", # Would normally expand left significantly
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# MANUAL_LABEL_STRATEGY: scale=1.0, max_pad=10
|
||||
# Should only add 10px padding each side (but scale=1.0 means no scaling)
|
||||
# Actually with scale=1.0, no extra ratios, we get 0 expansion from scaling
|
||||
# Only max_pad=10 applies as a limit, but there's no expansion to limit
|
||||
# So result should be same as original
|
||||
assert result == (100, 200, 200, 250)
|
||||
|
||||
def test_manual_mode_ignores_field_type(self):
|
||||
"""Verify manual_mode ignores field-specific strategies."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
|
||||
# Different fields should give same result in manual_mode
|
||||
result_bankgiro = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
result_ocr = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="ocr_number",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
assert result_bankgiro == result_ocr
|
||||
|
||||
def test_manual_mode_vs_auto_mode_different(self):
|
||||
"""Verify manual_mode produces different results than auto mode."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
|
||||
auto_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro", # Has extra_left_ratio=0.80
|
||||
manual_mode=False,
|
||||
)
|
||||
|
||||
manual_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# Auto mode should expand more (especially to the left for bankgiro)
|
||||
assert auto_result[0] < manual_result[0] # Auto x0 is more left
|
||||
|
||||
def test_manual_mode_clamps_to_image_bounds(self):
|
||||
"""Verify manual_mode still respects image boundaries."""
|
||||
bbox = (5, 5, 50, 50) # Close to top-left corner
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# Should clamp to 0
|
||||
assert result[0] >= 0
|
||||
assert result[1] >= 0
|
||||
192
tests/shared/bbox/test_scale_strategy.py
Normal file
192
tests/shared/bbox/test_scale_strategy.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Tests for ScaleStrategy configuration.
|
||||
|
||||
Tests verify that scale strategies are properly defined, immutable,
|
||||
and cover all required fields.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.bbox import (
|
||||
ScaleStrategy,
|
||||
DEFAULT_STRATEGY,
|
||||
MANUAL_LABEL_STRATEGY,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
)
|
||||
from shared.fields import CLASS_NAMES
|
||||
|
||||
|
||||
class TestScaleStrategyDataclass:
|
||||
"""Tests for ScaleStrategy dataclass behavior."""
|
||||
|
||||
def test_default_strategy_values(self):
|
||||
"""Verify default strategy has expected default values."""
|
||||
strategy = ScaleStrategy()
|
||||
assert strategy.scale_x == 1.15
|
||||
assert strategy.scale_y == 1.15
|
||||
assert strategy.extra_top_ratio == 0.0
|
||||
assert strategy.extra_bottom_ratio == 0.0
|
||||
assert strategy.extra_left_ratio == 0.0
|
||||
assert strategy.extra_right_ratio == 0.0
|
||||
assert strategy.max_pad_x == 50
|
||||
assert strategy.max_pad_y == 50
|
||||
|
||||
def test_scale_strategy_immutability(self):
|
||||
"""Verify ScaleStrategy is frozen (immutable)."""
|
||||
strategy = ScaleStrategy()
|
||||
with pytest.raises(AttributeError):
|
||||
strategy.scale_x = 2.0 # type: ignore
|
||||
|
||||
def test_custom_strategy_values(self):
|
||||
"""Verify custom values are properly set."""
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.5,
|
||||
scale_y=1.8,
|
||||
extra_top_ratio=0.6,
|
||||
extra_left_ratio=0.8,
|
||||
max_pad_x=100,
|
||||
max_pad_y=150,
|
||||
)
|
||||
assert strategy.scale_x == 1.5
|
||||
assert strategy.scale_y == 1.8
|
||||
assert strategy.extra_top_ratio == 0.6
|
||||
assert strategy.extra_left_ratio == 0.8
|
||||
assert strategy.max_pad_x == 100
|
||||
assert strategy.max_pad_y == 150
|
||||
|
||||
|
||||
class TestDefaultStrategy:
|
||||
"""Tests for DEFAULT_STRATEGY constant."""
|
||||
|
||||
def test_default_strategy_is_scale_strategy(self):
|
||||
"""Verify DEFAULT_STRATEGY is a ScaleStrategy instance."""
|
||||
assert isinstance(DEFAULT_STRATEGY, ScaleStrategy)
|
||||
|
||||
def test_default_strategy_matches_default_values(self):
|
||||
"""Verify DEFAULT_STRATEGY has same values as ScaleStrategy()."""
|
||||
expected = ScaleStrategy()
|
||||
assert DEFAULT_STRATEGY == expected
|
||||
|
||||
|
||||
class TestManualLabelStrategy:
|
||||
"""Tests for MANUAL_LABEL_STRATEGY constant."""
|
||||
|
||||
def test_manual_label_strategy_is_scale_strategy(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY is a ScaleStrategy instance."""
|
||||
assert isinstance(MANUAL_LABEL_STRATEGY, ScaleStrategy)
|
||||
|
||||
def test_manual_label_strategy_has_no_scaling(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY has scale factors of 1.0."""
|
||||
assert MANUAL_LABEL_STRATEGY.scale_x == 1.0
|
||||
assert MANUAL_LABEL_STRATEGY.scale_y == 1.0
|
||||
|
||||
def test_manual_label_strategy_has_no_directional_expansion(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY has no directional expansion."""
|
||||
assert MANUAL_LABEL_STRATEGY.extra_top_ratio == 0.0
|
||||
assert MANUAL_LABEL_STRATEGY.extra_bottom_ratio == 0.0
|
||||
assert MANUAL_LABEL_STRATEGY.extra_left_ratio == 0.0
|
||||
assert MANUAL_LABEL_STRATEGY.extra_right_ratio == 0.0
|
||||
|
||||
def test_manual_label_strategy_has_small_max_pad(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY has small max padding."""
|
||||
assert MANUAL_LABEL_STRATEGY.max_pad_x <= 15
|
||||
assert MANUAL_LABEL_STRATEGY.max_pad_y <= 15
|
||||
|
||||
|
||||
class TestFieldScaleStrategies:
|
||||
"""Tests for FIELD_SCALE_STRATEGIES dictionary."""
|
||||
|
||||
def test_all_class_names_have_strategies(self):
|
||||
"""Verify all field class names have defined strategies."""
|
||||
for class_name in CLASS_NAMES:
|
||||
assert class_name in FIELD_SCALE_STRATEGIES, (
|
||||
f"Missing strategy for field: {class_name}"
|
||||
)
|
||||
|
||||
def test_strategies_are_scale_strategy_instances(self):
|
||||
"""Verify all strategies are ScaleStrategy instances."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert isinstance(strategy, ScaleStrategy), (
|
||||
f"Strategy for {field_name} is not a ScaleStrategy"
|
||||
)
|
||||
|
||||
def test_scale_values_are_greater_than_one(self):
|
||||
"""Verify all scale values are >= 1.0 (expansion, not contraction)."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert strategy.scale_x >= 1.0, (
|
||||
f"{field_name} scale_x should be >= 1.0"
|
||||
)
|
||||
assert strategy.scale_y >= 1.0, (
|
||||
f"{field_name} scale_y should be >= 1.0"
|
||||
)
|
||||
|
||||
def test_extra_ratios_are_non_negative(self):
|
||||
"""Verify all extra ratios are >= 0."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert strategy.extra_top_ratio >= 0, (
|
||||
f"{field_name} extra_top_ratio should be >= 0"
|
||||
)
|
||||
assert strategy.extra_bottom_ratio >= 0, (
|
||||
f"{field_name} extra_bottom_ratio should be >= 0"
|
||||
)
|
||||
assert strategy.extra_left_ratio >= 0, (
|
||||
f"{field_name} extra_left_ratio should be >= 0"
|
||||
)
|
||||
assert strategy.extra_right_ratio >= 0, (
|
||||
f"{field_name} extra_right_ratio should be >= 0"
|
||||
)
|
||||
|
||||
def test_max_pad_values_are_positive(self):
|
||||
"""Verify all max_pad values are > 0."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert strategy.max_pad_x > 0, (
|
||||
f"{field_name} max_pad_x should be > 0"
|
||||
)
|
||||
assert strategy.max_pad_y > 0, (
|
||||
f"{field_name} max_pad_y should be > 0"
|
||||
)
|
||||
|
||||
|
||||
class TestSpecificFieldStrategies:
|
||||
"""Tests for specific field strategy configurations."""
|
||||
|
||||
def test_ocr_number_expands_upward(self):
|
||||
"""Verify ocr_number strategy expands upward to capture label."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["ocr_number"]
|
||||
assert strategy.extra_top_ratio > 0.0
|
||||
assert strategy.extra_top_ratio >= 0.5 # Significant upward expansion
|
||||
|
||||
def test_bankgiro_expands_leftward(self):
|
||||
"""Verify bankgiro strategy expands leftward to capture prefix."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["bankgiro"]
|
||||
assert strategy.extra_left_ratio > 0.0
|
||||
assert strategy.extra_left_ratio >= 0.5 # Significant leftward expansion
|
||||
|
||||
def test_plusgiro_expands_leftward(self):
|
||||
"""Verify plusgiro strategy expands leftward to capture prefix."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["plusgiro"]
|
||||
assert strategy.extra_left_ratio > 0.0
|
||||
assert strategy.extra_left_ratio >= 0.5
|
||||
|
||||
def test_amount_expands_rightward(self):
|
||||
"""Verify amount strategy expands rightward for currency symbol."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["amount"]
|
||||
assert strategy.extra_right_ratio > 0.0
|
||||
|
||||
def test_invoice_date_expands_upward(self):
|
||||
"""Verify invoice_date strategy expands upward to capture label."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["invoice_date"]
|
||||
assert strategy.extra_top_ratio > 0.0
|
||||
|
||||
def test_invoice_due_date_expands_upward_and_leftward(self):
|
||||
"""Verify invoice_due_date strategy expands both up and left."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["invoice_due_date"]
|
||||
assert strategy.extra_top_ratio > 0.0
|
||||
assert strategy.extra_left_ratio > 0.0
|
||||
|
||||
def test_payment_line_has_minimal_expansion(self):
|
||||
"""Verify payment_line has conservative expansion (machine code)."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["payment_line"]
|
||||
# Payment line is machine-readable, needs minimal expansion
|
||||
assert strategy.scale_x <= 1.2
|
||||
assert strategy.scale_y <= 1.3
|
||||
@@ -16,6 +16,7 @@ from shared.fields import (
|
||||
FIELD_CLASSES,
|
||||
FIELD_CLASS_IDS,
|
||||
CLASS_TO_FIELD,
|
||||
FIELD_TO_CLASS,
|
||||
CSV_TO_CLASS_MAPPING,
|
||||
TRAINING_FIELD_CLASSES,
|
||||
NUM_CLASSES,
|
||||
@@ -133,6 +134,20 @@ class TestMappingConsistency:
|
||||
assert fd.field_name in TRAINING_FIELD_CLASSES
|
||||
assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id
|
||||
|
||||
def test_field_to_class_is_inverse_of_class_to_field(self):
|
||||
"""Verify FIELD_TO_CLASS and CLASS_TO_FIELD are proper inverses."""
|
||||
for class_name, field_name in CLASS_TO_FIELD.items():
|
||||
assert FIELD_TO_CLASS[field_name] == class_name
|
||||
|
||||
for field_name, class_name in FIELD_TO_CLASS.items():
|
||||
assert CLASS_TO_FIELD[class_name] == field_name
|
||||
|
||||
def test_field_to_class_has_all_fields(self):
|
||||
"""Verify FIELD_TO_CLASS has mapping for all field names."""
|
||||
for fd in FIELD_DEFINITIONS:
|
||||
assert fd.field_name in FIELD_TO_CLASS
|
||||
assert FIELD_TO_CLASS[fd.field_name] == fd.class_name
|
||||
|
||||
|
||||
class TestSpecificFieldDefinitions:
|
||||
"""Tests for specific field definitions to catch common mistakes."""
|
||||
|
||||
@@ -272,12 +272,12 @@ class TestLineItemsExtractorFromPdf:
|
||||
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Create mock table detection result
|
||||
# Create mock table detection result with proper thead/tbody structure
|
||||
mock_table = MagicMock(spec=TableDetectionResult)
|
||||
mock_table.html = """
|
||||
<table>
|
||||
<tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr>
|
||||
<tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr>
|
||||
<thead><tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr></thead>
|
||||
<tbody><tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr></tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
@@ -291,6 +291,78 @@ class TestLineItemsExtractorFromPdf:
|
||||
assert len(result.items) >= 1
|
||||
|
||||
|
||||
class TestPdfPathValidation:
|
||||
"""Tests for PDF path validation."""
|
||||
|
||||
def test_detect_tables_with_nonexistent_path(self):
|
||||
"""Test that non-existent PDF path returns empty results."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Create detector and call _detect_tables_with_parsing with non-existent path
|
||||
from unittest.mock import MagicMock
|
||||
from backend.table.structure_detector import TableDetector
|
||||
|
||||
mock_detector = MagicMock(spec=TableDetector)
|
||||
tables, parsing_res = extractor._detect_tables_with_parsing(
|
||||
mock_detector, "nonexistent.pdf"
|
||||
)
|
||||
|
||||
assert tables == []
|
||||
assert parsing_res == []
|
||||
|
||||
def test_detect_tables_with_directory_path(self, tmp_path):
|
||||
"""Test that directory path (not file) returns empty results."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from backend.table.structure_detector import TableDetector
|
||||
|
||||
mock_detector = MagicMock(spec=TableDetector)
|
||||
|
||||
# tmp_path is a directory, not a file
|
||||
tables, parsing_res = extractor._detect_tables_with_parsing(
|
||||
mock_detector, str(tmp_path)
|
||||
)
|
||||
|
||||
assert tables == []
|
||||
assert parsing_res == []
|
||||
|
||||
def test_detect_tables_validates_file_exists(self, tmp_path):
|
||||
"""Test path validation for file existence.
|
||||
|
||||
This test verifies that the method correctly validates the path exists
|
||||
and is a file before attempting to process it.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Create a real file path that exists
|
||||
fake_pdf = tmp_path / "test.pdf"
|
||||
fake_pdf.write_bytes(b"not a real pdf")
|
||||
|
||||
# Mock render_pdf_to_images to avoid actual PDF processing
|
||||
with patch("shared.pdf.renderer.render_pdf_to_images") as mock_render:
|
||||
# Return empty iterator - simulates file exists but no pages
|
||||
mock_render.return_value = iter([])
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from backend.table.structure_detector import TableDetector
|
||||
|
||||
mock_detector = MagicMock(spec=TableDetector)
|
||||
mock_detector._ensure_initialized = MagicMock()
|
||||
mock_detector._pipeline = MagicMock()
|
||||
|
||||
tables, parsing_res = extractor._detect_tables_with_parsing(
|
||||
mock_detector, str(fake_pdf)
|
||||
)
|
||||
|
||||
# render_pdf_to_images was called (path validation passed)
|
||||
mock_render.assert_called_once()
|
||||
assert tables == []
|
||||
assert parsing_res == []
|
||||
|
||||
|
||||
class TestLineItemsResult:
|
||||
"""Tests for LineItemsResult dataclass."""
|
||||
|
||||
@@ -462,3 +534,246 @@ class TestMergedCellExtraction:
|
||||
assert result.items[0].is_deduction is False
|
||||
assert result.items[1].amount == "-2000"
|
||||
assert result.items[1].is_deduction is True
|
||||
|
||||
|
||||
class TestTextFallbackExtraction:
|
||||
"""Tests for text-based fallback extraction."""
|
||||
|
||||
def test_text_fallback_disabled_by_default(self):
|
||||
"""Test text fallback can be disabled."""
|
||||
extractor = LineItemsExtractor(enable_text_fallback=False)
|
||||
assert extractor.enable_text_fallback is False
|
||||
|
||||
def test_text_fallback_enabled_by_default(self):
|
||||
"""Test text fallback is enabled by default."""
|
||||
extractor = LineItemsExtractor()
|
||||
assert extractor.enable_text_fallback is True
|
||||
|
||||
def test_try_text_fallback_with_valid_parsing_res(self):
|
||||
"""Test text fallback with valid parsing results."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from backend.table.text_line_items_extractor import (
|
||||
TextLineItemsExtractor,
|
||||
TextLineItem,
|
||||
TextLineItemsResult,
|
||||
)
|
||||
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Mock parsing_res_list with text elements
|
||||
parsing_res = [
|
||||
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Product A"},
|
||||
{"label": "text", "bbox": [250, 100, 350, 120], "text": "1 234,56"},
|
||||
{"label": "text", "bbox": [0, 150, 200, 170], "text": "Product B"},
|
||||
{"label": "text", "bbox": [250, 150, 350, 170], "text": "2 345,67"},
|
||||
]
|
||||
|
||||
# Create mock text extraction result
|
||||
mock_text_result = TextLineItemsResult(
|
||||
items=[
|
||||
TextLineItem(row_index=0, description="Product A", amount="1 234,56"),
|
||||
TextLineItem(row_index=1, description="Product B", amount="2 345,67"),
|
||||
],
|
||||
header_row=[],
|
||||
)
|
||||
|
||||
with patch.object(TextLineItemsExtractor, 'extract_from_parsing_res', return_value=mock_text_result):
|
||||
result = extractor._try_text_fallback(parsing_res)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].description == "Product A"
|
||||
assert result.items[1].description == "Product B"
|
||||
|
||||
def test_try_text_fallback_returns_none_on_failure(self):
|
||||
"""Test text fallback returns None when extraction fails."""
|
||||
from unittest.mock import patch
|
||||
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
with patch('backend.table.text_line_items_extractor.TextLineItemsExtractor.extract_from_parsing_res', return_value=None):
|
||||
result = extractor._try_text_fallback([])
|
||||
assert result is None
|
||||
|
||||
def test_extract_from_pdf_uses_text_fallback(self):
|
||||
"""Test extract_from_pdf uses text fallback when no tables found."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from backend.table.text_line_items_extractor import TextLineItem, TextLineItemsResult
|
||||
|
||||
extractor = LineItemsExtractor(enable_text_fallback=True)
|
||||
|
||||
# Mock _detect_tables_with_parsing to return no tables but parsing_res
|
||||
mock_text_result = TextLineItemsResult(
|
||||
items=[
|
||||
TextLineItem(row_index=0, description="Product", amount="100,00"),
|
||||
TextLineItem(row_index=1, description="Product 2", amount="200,00"),
|
||||
],
|
||||
header_row=[],
|
||||
)
|
||||
|
||||
with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
|
||||
mock_detect.return_value = ([], [{"label": "text", "text": "test"}])
|
||||
|
||||
with patch.object(extractor, '_try_text_fallback', return_value=MagicMock(items=[MagicMock()])) as mock_fallback:
|
||||
result = extractor.extract_from_pdf("fake.pdf")
|
||||
|
||||
# Text fallback should be called
|
||||
mock_fallback.assert_called_once()
|
||||
|
||||
def test_extract_from_pdf_skips_fallback_when_disabled(self):
|
||||
"""Test extract_from_pdf skips text fallback when disabled."""
|
||||
from unittest.mock import patch
|
||||
|
||||
extractor = LineItemsExtractor(enable_text_fallback=False)
|
||||
|
||||
with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
|
||||
mock_detect.return_value = ([], [{"label": "text", "text": "test"}])
|
||||
|
||||
result = extractor.extract_from_pdf("fake.pdf")
|
||||
|
||||
# Should return None, not use text fallback
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestVerticallyMergedCellExtraction:
|
||||
"""Tests for vertically merged cell extraction."""
|
||||
|
||||
def test_detects_vertically_merged_cells(self):
|
||||
"""Test detection of vertically merged cells in rows."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Rows with multiple product numbers in single cell
|
||||
rows = [["Produktnr 1457280 1457281 1060381 merged text here"]]
|
||||
assert extractor._has_vertically_merged_cells(rows) is True
|
||||
|
||||
def test_splits_vertically_merged_rows(self):
|
||||
"""Test splitting vertically merged rows."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
rows = [
|
||||
["Produktnr 1234567 1234568", "Antal 2ST 3ST"],
|
||||
]
|
||||
header, data = extractor._split_merged_rows(rows)
|
||||
|
||||
# Should split into header + data rows
|
||||
assert isinstance(header, list)
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
class TestDeductionDetection:
|
||||
"""Tests for deduction/discount detection."""
|
||||
|
||||
def test_detects_deduction_by_keyword_avdrag(self):
|
||||
"""Test detection of deduction by 'avdrag' keyword."""
|
||||
html = """
|
||||
<html><body><table>
|
||||
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||
<tbody>
|
||||
<tr><td>Hyresavdrag januari</td><td>-500,00</td></tr>
|
||||
</tbody>
|
||||
</table></body></html>
|
||||
"""
|
||||
extractor = LineItemsExtractor()
|
||||
result = extractor.extract(html)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].is_deduction is True
|
||||
|
||||
def test_detects_deduction_by_keyword_rabatt(self):
|
||||
"""Test detection of deduction by 'rabatt' keyword."""
|
||||
html = """
|
||||
<html><body><table>
|
||||
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||
<tbody>
|
||||
<tr><td>Rabatt 10%</td><td>-100,00</td></tr>
|
||||
</tbody>
|
||||
</table></body></html>
|
||||
"""
|
||||
extractor = LineItemsExtractor()
|
||||
result = extractor.extract(html)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].is_deduction is True
|
||||
|
||||
def test_detects_deduction_by_negative_amount(self):
|
||||
"""Test detection of deduction by negative amount."""
|
||||
html = """
|
||||
<html><body><table>
|
||||
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||
<tbody>
|
||||
<tr><td>Some credit</td><td>-250,00</td></tr>
|
||||
</tbody>
|
||||
</table></body></html>
|
||||
"""
|
||||
extractor = LineItemsExtractor()
|
||||
result = extractor.extract(html)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].is_deduction is True
|
||||
|
||||
def test_normal_item_not_deduction(self):
|
||||
"""Test normal item is not marked as deduction."""
|
||||
html = """
|
||||
<html><body><table>
|
||||
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||
<tbody>
|
||||
<tr><td>Normal product</td><td>500,00</td></tr>
|
||||
</tbody>
|
||||
</table></body></html>
|
||||
"""
|
||||
extractor = LineItemsExtractor()
|
||||
result = extractor.extract(html)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].is_deduction is False
|
||||
|
||||
|
||||
class TestHeaderDetection:
|
||||
"""Tests for header row detection."""
|
||||
|
||||
def test_detect_header_at_bottom(self):
|
||||
"""Test detecting header at bottom of table (reversed)."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
rows = [
|
||||
["100,00", "Product A", "1"],
|
||||
["200,00", "Product B", "2"],
|
||||
["Belopp", "Beskrivning", "Antal"], # Header at bottom
|
||||
]
|
||||
|
||||
header_idx, header, is_at_end = extractor._detect_header_row(rows)
|
||||
|
||||
assert header_idx == 2
|
||||
assert is_at_end is True
|
||||
assert "Belopp" in header
|
||||
|
||||
def test_detect_header_at_top(self):
|
||||
"""Test detecting header at top of table."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
rows = [
|
||||
["Belopp", "Beskrivning", "Antal"], # Header at top
|
||||
["100,00", "Product A", "1"],
|
||||
["200,00", "Product B", "2"],
|
||||
]
|
||||
|
||||
header_idx, header, is_at_end = extractor._detect_header_row(rows)
|
||||
|
||||
assert header_idx == 0
|
||||
assert is_at_end is False
|
||||
assert "Belopp" in header
|
||||
|
||||
def test_no_header_detected(self):
|
||||
"""Test when no header is detected."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
rows = [
|
||||
["100,00", "Product A", "1"],
|
||||
["200,00", "Product B", "2"],
|
||||
]
|
||||
|
||||
header_idx, header, is_at_end = extractor._detect_header_row(rows)
|
||||
|
||||
assert header_idx == -1
|
||||
assert header == []
|
||||
assert is_at_end is False
|
||||
|
||||
448
tests/table/test_merged_cell_handler.py
Normal file
448
tests/table/test_merged_cell_handler.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""
|
||||
Tests for Merged Cell Handler
|
||||
|
||||
Tests the detection and extraction of data from tables with merged cells,
|
||||
a common issue with PP-StructureV3 OCR output.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from backend.table.merged_cell_handler import MergedCellHandler, MIN_AMOUNT_THRESHOLD
|
||||
from backend.table.html_table_parser import ColumnMapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler():
|
||||
"""Create a MergedCellHandler with default ColumnMapper."""
|
||||
return MergedCellHandler(ColumnMapper())
|
||||
|
||||
|
||||
class TestHasVerticallyMergedCells:
|
||||
"""Tests for has_vertically_merged_cells detection."""
|
||||
|
||||
def test_empty_rows_returns_false(self, handler):
|
||||
"""Test empty rows returns False."""
|
||||
assert handler.has_vertically_merged_cells([]) is False
|
||||
|
||||
def test_short_cells_ignored(self, handler):
|
||||
"""Test cells shorter than 20 chars are ignored."""
|
||||
rows = [["Short cell", "Also short"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_detects_multiple_product_numbers(self, handler):
|
||||
"""Test detection of multiple 7-digit product numbers in cell."""
|
||||
rows = [["Produktnr 1457280 1457281 1060381 and more text here"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is True
|
||||
|
||||
def test_single_product_number_not_merged(self, handler):
|
||||
"""Test single product number doesn't trigger detection."""
|
||||
rows = [["Produktnr 1457280 and more text here for length"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_detects_multiple_prices(self, handler):
|
||||
"""Test detection of 3+ prices in cell (Swedish format)."""
|
||||
rows = [["Pris 127,20 234,56 159,20 total amounts"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is True
|
||||
|
||||
def test_two_prices_not_merged(self, handler):
|
||||
"""Test two prices doesn't trigger detection (needs 3+)."""
|
||||
rows = [["Pris 127,20 234,56 total amount here"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_detects_multiple_quantities(self, handler):
|
||||
"""Test detection of multiple quantity patterns."""
|
||||
rows = [["Antal 6ST 6ST 1ST more text here"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is True
|
||||
|
||||
def test_single_quantity_not_merged(self, handler):
|
||||
"""Test single quantity doesn't trigger detection."""
|
||||
rows = [["Antal 6ST and more text here for length"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_empty_cell_skipped(self, handler):
|
||||
"""Test empty cells are skipped."""
|
||||
rows = [["", None, "Valid but short"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_multiple_rows_checked(self, handler):
|
||||
"""Test all rows are checked for merged content."""
|
||||
rows = [
|
||||
["Normal row with nothing special"],
|
||||
["Produktnr 1457280 1457281 1060381 merged content"],
|
||||
]
|
||||
assert handler.has_vertically_merged_cells(rows) is True
|
||||
|
||||
|
||||
class TestSplitMergedRows:
    """Behaviour of split_merged_rows when cells carry several rows of data."""

    def test_empty_rows_returns_empty(self, handler):
        """No input rows -> empty header and empty data."""
        head, body = handler.split_merged_rows([])
        assert head == []
        assert body == []

    def test_all_empty_rows_returns_original(self, handler):
        """Rows containing only blanks are passed through untouched."""
        table = [["", ""], ["", ""]]
        head, body = handler.split_merged_rows(table)
        assert head == []
        assert body == table

    def test_splits_by_product_numbers(self, handler):
        """A cell listing two product numbers is split into two data rows."""
        table = [
            ["Produktnr 1234567 1234568", "Antal 2ST 3ST", "Pris 100,00 200,00"],
        ]
        head, body = handler.split_merged_rows(table)

        assert len(head) == 3
        assert head[0] == "Produktnr"
        assert len(body) == 2

    def test_splits_by_quantities(self, handler):
        """Quantity patterns alone are enough to drive the split."""
        table = [
            ["Description text", "Antal 5ST 10ST", "Belopp 500,00 1000,00"],
        ]
        _, body = handler.split_merged_rows(table)

        # Two quantities detected -> at least one data row comes back.
        assert len(body) >= 1

    def test_single_row_not_split(self, handler):
        """A row describing a single item is returned unchanged."""
        table = [
            ["Produktnr 1234567", "Antal 2ST", "Pris 100,00"],
        ]
        head, body = handler.split_merged_rows(table)

        # With only one product number there is nothing to split.
        assert head == []
        assert body == table

    def test_handles_missing_columns(self, handler):
        """Rows of uneven width do not crash the splitter."""
        table = [
            ["Produktnr 1234567 1234568", ""],
            ["Antal 2ST 3ST"],
        ]
        head, body = handler.split_merged_rows(table)

        # Graceful handling: both results are still lists.
        assert isinstance(head, list)
        assert isinstance(body, list)
|
||||
|
||||
|
||||
class TestCountExpectedRows:
    """Behaviour of the _count_expected_rows helper."""

    def test_counts_product_numbers(self, handler):
        """Three product numbers in one column -> expected count of 3."""
        assert handler._count_expected_rows(
            ["Produktnr 1234567 1234568 1234569", "Other"]
        ) == 3

    def test_counts_quantities(self, handler):
        """Quantity tokens are counted as well."""
        assert handler._count_expected_rows(
            ["Nothing here", "Antal 5ST 10ST 15ST 20ST"]
        ) == 4

    def test_returns_max_count(self, handler):
        """The largest per-column count wins."""
        cols = [
            "Produktnr 1234567 1234568",  # 2 product numbers
            "Antal 5ST 10ST 15ST",  # 3 quantities
        ]
        assert handler._count_expected_rows(cols) == 3

    def test_empty_columns_return_zero(self, handler):
        """Blank, None, and too-short cells contribute nothing."""
        assert handler._count_expected_rows(["", None, "Short"]) == 0
|
||||
|
||||
|
||||
class TestSplitCellContentForRows:
    """Behaviour of the _split_cell_content_for_rows helper."""

    def test_splits_by_product_numbers(self, handler):
        """Two product numbers split into a header plus two value cells."""
        parts = handler._split_cell_content_for_rows("Produktnr 1234567 1234568", 2)

        assert len(parts) == 3  # header + 2 values
        assert parts[0] == "Produktnr"
        assert "1234567" in parts[1]
        assert "1234568" in parts[2]

    def test_splits_by_quantities(self, handler):
        """Quantity tokens split the same way."""
        parts = handler._split_cell_content_for_rows("Antal 5ST 10ST", 2)

        assert len(parts) == 3  # header + 2 values
        assert parts[0] == "Antal"

    def test_splits_discount_totalsumma(self, handler):
        """Combined discount/total columns keep the 'Totalsumma' header."""
        parts = handler._split_cell_content_for_rows(
            "Rabatt i% Totalsumma 686,88 123,45", 2
        )

        assert parts[0] == "Totalsumma"
        assert "686,88" in parts[1]
        assert "123,45" in parts[2]

    def test_splits_by_prices(self, handler):
        """Price patterns are recognised as split points."""
        parts = handler._split_cell_content_for_rows("Pris 127,20 234,56", 2)

        assert len(parts) >= 2

    def test_fallback_returns_original(self, handler):
        """Cells with no recognised pattern come back unchanged."""
        parts = handler._split_cell_content_for_rows("No patterns here", 2)

        assert parts == ["No patterns here"]

    def test_product_number_with_description(self, handler):
        """Trailing description text stays attached to its product number."""
        parts = handler._split_cell_content_for_rows(
            "Art 1234567 Widget A 1234568 Widget B", 2
        )

        assert len(parts) == 3
|
||||
|
||||
|
||||
class TestSplitCellContent:
    """Behaviour of the public split_cell_content method."""

    def test_splits_by_product_numbers(self, handler):
        """Several product numbers are separated out behind their header."""
        parts = handler.split_cell_content("Produktnr 1234567 1234568 1234569")

        assert parts[0] == "Produktnr"
        for number in ("1234567", "1234568", "1234569"):
            assert number in parts

    def test_splits_by_quantities(self, handler):
        """Multiple quantity tokens produce multiple entries."""
        parts = handler.split_cell_content("Antal 6ST 6ST 1ST")

        assert parts[0] == "Antal"
        assert len(parts) >= 3

    def test_splits_discount_amount_interleaved(self, handler):
        """Interleaved discount/amount pairs yield the amounts only."""
        parts = handler.split_cell_content("Rabatt i% Totalsumma 10,0 686,88 10,0 123,45")

        # Amounts (3+ digits with decimals) are extracted; percentages dropped.
        assert parts[0] == "Totalsumma"
        assert "686,88" in parts
        assert "123,45" in parts

    def test_splits_by_prices(self, handler):
        """Price sequences keep their header in front."""
        parts = handler.split_cell_content("Pris 127,20 127,20 159,20")

        assert parts[0] == "Pris"

    def test_single_value_not_split(self, handler):
        """A cell with one plain value is returned as-is."""
        parts = handler.split_cell_content("Single value")

        assert parts == ["Single value"]

    def test_single_product_not_split(self, handler):
        """A cell with one product number is returned as-is."""
        parts = handler.split_cell_content("Produktnr 1234567")

        assert parts == ["Produktnr 1234567"]
|
||||
|
||||
|
||||
class TestHasMergedHeader:
    """Behaviour of has_merged_header."""

    def test_none_header_returns_false(self, handler):
        """A missing header is never treated as merged."""
        assert handler.has_merged_header(None) is False

    def test_empty_header_returns_false(self, handler):
        """An empty header list is never treated as merged."""
        assert handler.has_merged_header([]) is False

    def test_multiple_non_empty_cells_returns_false(self, handler):
        """A properly split header with several cells is not merged."""
        assert handler.has_merged_header(["Beskrivning", "Antal", "Belopp"]) is False

    def test_single_cell_with_keywords_returns_true(self, handler):
        """One cell carrying several column keywords is flagged as merged."""
        merged = ["Specifikation 0218103-1201 rum och kök Hyra Avdrag"]
        assert handler.has_merged_header(merged) is True

    def test_single_cell_one_keyword_returns_false(self, handler):
        """A single keyword alone does not indicate merging."""
        assert handler.has_merged_header(["Beskrivning only"]) is False

    def test_ignores_empty_trailing_cells(self, handler):
        """Trailing blank cells do not mask a merged first cell."""
        assert handler.has_merged_header(["Specifikation Hyra Avdrag", "", "", ""]) is True
|
||||
|
||||
|
||||
class TestExtractFromMergedCells:
    """Behaviour of extract_from_merged_cells."""

    def test_extracts_single_amount(self, handler):
        """A lone amount is extracted along with header-derived fields."""
        merged_header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        line_items = handler.extract_from_merged_cells(merged_header, [["", "", "", "8159"]])

        assert len(line_items) == 1
        first = line_items[0]
        assert first.amount == "8159"
        assert first.is_deduction is False
        assert first.article_number == "0218103-1201"
        assert first.description == "2 rum och kök"

    def test_extracts_deduction(self, handler):
        """A negative amount is extracted and flagged as a deduction."""
        line_items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "-2000"]]
        )

        assert len(line_items) == 1
        assert line_items[0].amount == "-2000"
        assert line_items[0].is_deduction is True
        # The first item (row_index=0) takes its description from the header;
        # the "Avdrag" label only applies to later deduction items.
        assert line_items[0].description is None

    def test_extracts_multiple_amounts_same_row(self, handler):
        """Two amounts in one cell yield two line items."""
        merged_header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
        line_items = handler.extract_from_merged_cells(
            merged_header, [["", "", "", "8159 -2000"]]
        )

        assert len(line_items) == 2
        assert line_items[0].amount == "8159"
        assert line_items[1].amount == "-2000"

    def test_extracts_amounts_from_multiple_rows(self, handler):
        """Amounts spread over several rows are all collected."""
        data = [
            ["", "", "", "8159"],
            ["", "", "", "-2000"],
        ]
        line_items = handler.extract_from_merged_cells(["Specifikation"], data)

        assert len(line_items) == 2

    def test_skips_small_amounts(self, handler):
        """Values below MIN_AMOUNT_THRESHOLD (100) are ignored."""
        line_items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "50"]]
        )

        assert len(line_items) == 0

    def test_skips_empty_rows(self, handler):
        """Rows with no content produce no items."""
        line_items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", ""]]
        )

        assert len(line_items) == 0

    def test_handles_swedish_format_with_spaces(self, handler):
        """Space thousand-separators are stripped from extracted amounts."""
        line_items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "8 159"]]
        )

        assert len(line_items) == 1
        assert line_items[0].amount == "8159"

    def test_confidence_is_lower_for_merged(self, handler):
        """Merged-cell extraction reports a reduced confidence of 0.7."""
        line_items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "8159"]]
        )

        assert line_items[0].confidence == 0.7

    def test_empty_header_still_extracts(self, handler):
        """Extraction works without a header; derived fields stay None."""
        line_items = handler.extract_from_merged_cells([], [["", "", "", "8159"]])

        assert len(line_items) == 1
        assert line_items[0].description is None
        assert line_items[0].article_number is None

    def test_row_index_increments(self, handler):
        """Each extracted item carries the index of its source row."""
        # One amount per row avoids regex grouping across rows.
        data = [
            ["", "", "", "8159"],
            ["", "", "", "5000"],
            ["", "", "", "-2000"],
        ]
        line_items = handler.extract_from_merged_cells(["Specifikation"], data)

        assert len(line_items) == 3
        assert [item.row_index for item in line_items] == [0, 1, 2]
|
||||
|
||||
|
||||
class TestMinAmountThreshold:
    """Behaviour around the MIN_AMOUNT_THRESHOLD constant."""

    def test_threshold_value(self):
        """The threshold is fixed at 100."""
        assert MIN_AMOUNT_THRESHOLD == 100

    def test_amounts_at_threshold_included(self, handler):
        """An amount exactly equal to the threshold is kept."""
        line_items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "100"]]
        )

        assert len(line_items) == 1
        assert line_items[0].amount == "100"

    def test_amounts_below_threshold_excluded(self, handler):
        """An amount just under the threshold is dropped."""
        line_items = handler.extract_from_merged_cells(
            ["Specifikation"], [["", "", "", "99"]]
        )

        assert len(line_items) == 0
|
||||
157
tests/table/test_models.py
Normal file
157
tests/table/test_models.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Tests for Line Items Data Models
|
||||
|
||||
Tests for LineItem and LineItemsResult dataclasses.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from backend.table.models import LineItem, LineItemsResult
|
||||
|
||||
|
||||
class TestLineItem:
    """Behaviour of the LineItem dataclass."""

    def test_default_values(self):
        """Optional fields default to None/False; confidence defaults to 0.9."""
        item = LineItem(row_index=0)

        assert item.row_index == 0
        for field_name in (
            "description",
            "quantity",
            "unit",
            "unit_price",
            "amount",
            "article_number",
            "vat_rate",
        ):
            assert getattr(item, field_name) is None
        assert item.is_deduction is False
        assert item.confidence == 0.9

    def test_custom_confidence(self):
        """A caller-supplied confidence overrides the default."""
        assert LineItem(row_index=0, confidence=0.7).confidence == 0.7

    def test_is_deduction_true(self):
        """The is_deduction flag can be set explicitly."""
        assert LineItem(row_index=0, is_deduction=True).is_deduction is True
|
||||
|
||||
|
||||
class TestLineItemsResult:
    """Behaviour of the LineItemsResult dataclass, especially total_amount."""

    def test_total_amount_empty_items(self):
        """No items -> no total."""
        assert LineItemsResult(items=[], header_row=[], raw_html="").total_amount is None

    def test_total_amount_single_item(self):
        """A single amount is returned as-is."""
        outcome = LineItemsResult(
            items=[LineItem(row_index=0, amount="100,00")], header_row=[], raw_html=""
        )
        assert outcome.total_amount == "100,00"

    def test_total_amount_multiple_items(self):
        """Amounts are summed in Swedish decimal-comma format."""
        outcome = LineItemsResult(
            items=[
                LineItem(row_index=0, amount="100,00"),
                LineItem(row_index=1, amount="200,50"),
            ],
            header_row=[],
            raw_html="",
        )
        assert outcome.total_amount == "300,50"

    def test_total_amount_with_deduction(self):
        """Negative (deduction) amounts reduce the total."""
        outcome = LineItemsResult(
            items=[
                LineItem(row_index=0, amount="1000,00"),
                LineItem(row_index=1, amount="-200,00", is_deduction=True),
            ],
            header_row=[],
            raw_html="",
        )
        assert outcome.total_amount == "800,00"

    def test_total_amount_swedish_format_with_spaces(self):
        """Space thousand-separators are parsed and re-emitted."""
        outcome = LineItemsResult(
            items=[
                LineItem(row_index=0, amount="1 234,56"),
                LineItem(row_index=1, amount="2 000,00"),
            ],
            header_row=[],
            raw_html="",
        )
        assert outcome.total_amount == "3 234,56"

    def test_total_amount_invalid_amount_skipped(self):
        """Unparseable amounts are ignored rather than raising."""
        outcome = LineItemsResult(
            items=[
                LineItem(row_index=0, amount="100,00"),
                LineItem(row_index=1, amount="invalid"),
                LineItem(row_index=2, amount="200,00"),
            ],
            header_row=[],
            raw_html="",
        )
        assert outcome.total_amount == "300,00"

    def test_total_amount_none_amount_skipped(self):
        """Items without an amount are ignored."""
        outcome = LineItemsResult(
            items=[
                LineItem(row_index=0, amount="100,00"),
                LineItem(row_index=1, amount=None),
            ],
            header_row=[],
            raw_html="",
        )
        assert outcome.total_amount == "100,00"

    def test_total_amount_all_invalid_returns_none(self):
        """When nothing parses, the total is None."""
        outcome = LineItemsResult(
            items=[
                LineItem(row_index=0, amount="invalid"),
                LineItem(row_index=1, amount="also invalid"),
            ],
            header_row=[],
            raw_html="",
        )
        assert outcome.total_amount is None

    def test_total_amount_large_numbers(self):
        """Sums crossing the million boundary are formatted correctly."""
        outcome = LineItemsResult(
            items=[
                LineItem(row_index=0, amount="123 456,78"),
                LineItem(row_index=1, amount="876 543,22"),
            ],
            header_row=[],
            raw_html="",
        )
        assert outcome.total_amount == "1 000 000,00"

    def test_total_amount_decimal_precision(self):
        """Two-decimal precision is preserved in the sum."""
        outcome = LineItemsResult(
            items=[
                LineItem(row_index=0, amount="0,01"),
                LineItem(row_index=1, amount="0,02"),
            ],
            header_row=[],
            raw_html="",
        )
        assert outcome.total_amount == "0,03"

    def test_is_reversed_default_false(self):
        """is_reversed defaults to False."""
        assert LineItemsResult(items=[], header_row=[], raw_html="").is_reversed is False

    def test_is_reversed_can_be_set(self):
        """is_reversed can be set at construction time."""
        outcome = LineItemsResult(items=[], header_row=[], raw_html="", is_reversed=True)
        assert outcome.is_reversed is True

    def test_header_row_preserved(self):
        """The header_row passed in is stored unchanged."""
        columns = ["Beskrivning", "Antal", "Belopp"]
        assert LineItemsResult(items=[], header_row=columns, raw_html="").header_row == columns

    def test_raw_html_preserved(self):
        """The raw_html passed in is stored unchanged."""
        markup = "<table><tr><td>Test</td></tr></table>"
        assert LineItemsResult(items=[], header_row=[], raw_html=markup).raw_html == markup
|
||||
@@ -658,3 +658,245 @@ class TestPaddleX3xAPI:
|
||||
assert len(results) == 1
|
||||
assert results[0].cells == [] # Empty cells list
|
||||
assert results[0].html == "<table></table>"
|
||||
|
||||
def test_parse_paddlex_result_with_dict_ocr_data(self):
    """PaddleX 3.x results whose table_ocr_pred is a dict are parsed into cells."""
    pipeline = MagicMock()
    pipeline.predict.return_value = [
        {
            "table_res_list": [
                {
                    "cell_box_list": [[0, 0, 50, 20], [50, 0, 100, 20]],
                    "pred_html": "<table><tr><td>A</td><td>B</td></tr></table>",
                    "table_ocr_pred": {
                        "rec_texts": ["A", "B"],
                        "rec_scores": [0.99, 0.98],
                    },
                }
            ],
            "parsing_res_list": [
                {"label": "table", "bbox": [10, 20, 200, 300]},
            ],
        }
    ]

    detector = TableDetector(pipeline=pipeline)
    tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

    assert len(tables) == 1
    cells = tables[0].cells
    assert len(cells) == 2
    assert [c["text"] for c in cells] == ["A", "B"]
|
||||
|
||||
def test_parse_paddlex_result_no_bbox_in_parsing_res(self):
    """When parsing_res_list has no 'table' entry a zero bbox is used."""
    pipeline = MagicMock()
    pipeline.predict.return_value = [
        {
            "table_res_list": [
                {
                    "cell_box_list": [[0, 0, 50, 20]],
                    "pred_html": "<table><tr><td>A</td></tr></table>",
                    "table_ocr_pred": ["A"],
                }
            ],
            "parsing_res_list": [
                {"label": "text", "bbox": [10, 20, 200, 300]},  # not a table entry
            ],
        }
    ]

    detector = TableDetector(pipeline=pipeline)
    tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

    assert len(tables) == 1
    # Falls back to the default bbox when no table bbox is found.
    assert tables[0].bbox == (0.0, 0.0, 0.0, 0.0)
|
||||
|
||||
|
||||
class TestIteratorResults:
    """Handling of generator/iterator return values from the pipeline."""

    def test_handles_iterator_results(self):
        """A generator of results is consumed just like a list."""
        pipeline = MagicMock()

        def yield_results():
            table_el = MagicMock()
            table_el.label = "table"
            table_el.bbox = [0, 0, 100, 100]
            table_el.html = "<table></table>"
            table_el.score = 0.9
            table_el.cells = []
            page = MagicMock(spec=["layout_elements"])
            page.layout_elements = [table_el]
            yield page

        pipeline.predict.return_value = yield_results()

        detector = TableDetector(pipeline=pipeline)
        tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        assert len(tables) == 1

    def test_handles_failed_iterator_conversion(self):
        """An iterable that raises on consumption yields an empty result."""
        pipeline = MagicMock()

        class ExplodingIterable:
            # Has __iter__ but blows up as soon as it is consumed.
            def __iter__(self):
                raise RuntimeError("Iterator failed")

        pipeline.predict.return_value = ExplodingIterable()

        detector = TableDetector(pipeline=pipeline)
        tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        # Must degrade gracefully instead of raising.
        assert tables == []
|
||||
|
||||
|
||||
class TestPathConversion:
    """Input-path normalisation before calling the pipeline."""

    def test_converts_path_object_to_string(self):
        """pathlib.Path inputs are handed to the pipeline as plain strings."""
        from pathlib import Path

        pipeline = MagicMock()
        pipeline.predict.return_value = []

        detector = TableDetector(pipeline=pipeline)
        detector.detect(Path("/some/path/to/image.png"))

        # The pipeline must receive a str, never a Path.
        pipeline.predict.assert_called_with("/some/path/to/image.png")
|
||||
|
||||
|
||||
class TestHtmlExtraction:
    """Locating table HTML across the various element layouts."""

    def test_extracts_html_from_res_dict(self):
        """HTML stored under element.res['html'] is picked up as a fallback."""
        pipeline = MagicMock()
        table_el = MagicMock()
        table_el.label = "table"
        table_el.bbox = [0, 0, 100, 100]
        table_el.res = {"html": "<table><tr><td>From res</td></tr></table>"}
        table_el.score = 0.9
        table_el.cells = []
        # Strip the direct html attributes so the res fallback is exercised.
        del table_el.html
        del table_el.table_html

        page = MagicMock(spec=["layout_elements"])
        page.layout_elements = [table_el]
        pipeline.predict.return_value = [page]

        detector = TableDetector(pipeline=pipeline)
        tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        assert len(tables) == 1
        assert tables[0].html == "<table><tr><td>From res</td></tr></table>"

    def test_returns_empty_html_when_not_found(self):
        """With no html source at all, the result carries an empty string."""
        pipeline = MagicMock()
        table_el = MagicMock()
        table_el.label = "table"
        table_el.bbox = [0, 0, 100, 100]
        table_el.score = 0.9
        table_el.cells = []
        # Remove every attribute html could be read from.
        del table_el.html
        del table_el.table_html
        del table_el.res

        page = MagicMock(spec=["layout_elements"])
        page.layout_elements = [table_el]
        pipeline.predict.return_value = [page]

        detector = TableDetector(pipeline=pipeline)
        tables = detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        assert len(tables) == 1
        assert tables[0].html == ""
|
||||
|
||||
|
||||
class TestTableTypeDetection:
    """Classification of detected tables via _get_table_type."""

    def test_detects_borderless_table(self):
        """A 'borderless_table' label maps to the wireless type."""
        element = MagicMock()
        element.label = "borderless_table"
        assert TableDetector()._get_table_type(element) == "wireless"

    def test_detects_wireless_table_label(self):
        """A 'wireless_table' label maps to the wireless type."""
        element = MagicMock()
        element.label = "wireless_table"
        assert TableDetector()._get_table_type(element) == "wireless"

    def test_defaults_to_wired_table(self):
        """A plain 'table' label falls back to the wired type."""
        element = MagicMock()
        element.label = "table"
        assert TableDetector()._get_table_type(element) == "wired"

    def test_type_attribute_instead_of_label(self):
        """When 'label' is absent, the 'type' attribute is consulted."""
        element = MagicMock()
        element.type = "wireless"
        del element.label  # force the fallback path
        assert TableDetector()._get_table_type(element) == "wireless"
|
||||
|
||||
|
||||
class TestPipelineRuntimeError:
    """Error reporting when the underlying pipeline is unavailable."""

    def test_raises_runtime_error_when_pipeline_none(self):
        """detect() raises RuntimeError if the pipeline was never created."""
        detector = TableDetector()
        detector._initialized = True  # skip lazy initialisation
        detector._pipeline = None

        with pytest.raises(RuntimeError) as exc_info:
            detector.detect(np.zeros((100, 100, 3), dtype=np.uint8))

        assert "not initialized" in str(exc_info.value).lower()
|
||||
|
||||
@@ -142,6 +142,33 @@ class TestTextLineItemsExtractor:
|
||||
rows = extractor._group_by_row(elements)
|
||||
assert len(rows) == 2
|
||||
|
||||
def test_group_by_row_varying_heights_uses_average(self, extractor):
    """Row grouping recomputes the row centre as elements are added.

    A tall element whose centre sits slightly below its neighbours must
    still join their row instead of being pushed into the next one.
    """
    elements = [
        TextElement(text="Short", bbox=(0, 100, 100, 110)),  # center_y = 105
        TextElement(text="Tall item", bbox=(150, 100, 250, 130)),  # center_y = 115
        TextElement(text="Next row", bbox=(0, 150, 100, 170)),  # center_y = 160
    ]

    grouped = extractor._group_by_row(elements)

    assert len(grouped) == 2
    assert len(grouped[0]) == 2  # "Short" and "Tall item" share a row
    assert len(grouped[1]) == 1  # "Next row" stands alone
|
||||
|
||||
def test_group_by_row_empty_input(self, extractor):
    """No elements -> no rows."""
    assert extractor._group_by_row([]) == []
|
||||
|
||||
def test_looks_like_line_item_with_amount(self, extractor):
|
||||
"""Test line item detection with amount."""
|
||||
row = [
|
||||
@@ -253,6 +280,67 @@ class TestTextLineItemsExtractor:
|
||||
assert len(elements) == 4
|
||||
|
||||
|
||||
class TestExceptionHandling:
    """Robustness of _extract_text_elements against malformed input."""

    def test_extract_text_elements_handles_missing_bbox(self):
        """Entries lacking a bbox are silently dropped."""
        extractor = TextLineItemsExtractor()
        parsed = [
            {"label": "text", "text": "No bbox"},  # no bbox key at all
            {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
        ]
        elements = extractor._extract_text_elements(parsed)
        assert len(elements) == 1
        assert elements[0].text == "Valid"

    def test_extract_text_elements_handles_invalid_bbox(self):
        """A bbox with fewer than four values is rejected."""
        extractor = TextLineItemsExtractor()
        parsed = [
            {"label": "text", "bbox": [0, 100], "text": "Invalid bbox"},  # 2 of 4 values
            {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
        ]
        elements = extractor._extract_text_elements(parsed)
        assert len(elements) == 1
        assert elements[0].text == "Valid"

    def test_extract_text_elements_handles_none_text(self):
        """Entries whose text is None are skipped."""
        extractor = TextLineItemsExtractor()
        parsed = [
            {"label": "text", "bbox": [0, 100, 200, 120], "text": None},
            {"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
        ]
        elements = extractor._extract_text_elements(parsed)
        assert len(elements) == 1
        assert elements[0].text == "Valid"

    def test_extract_text_elements_handles_empty_string(self):
        """Entries whose text is an empty string are skipped."""
        extractor = TextLineItemsExtractor()
        parsed = [
            {"label": "text", "bbox": [0, 100, 200, 120], "text": ""},
            {"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
        ]
        elements = extractor._extract_text_elements(parsed)
        assert len(elements) == 1
        assert elements[0].text == "Valid"

    def test_extract_text_elements_handles_malformed_element(self):
        """Non-dict entries do not break extraction."""
        extractor = TextLineItemsExtractor()
        parsed = [
            "not a dict",  # plain string entry
            123,  # plain number entry
            {"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
        ]
        elements = extractor._extract_text_elements(parsed)
        assert len(elements) == 1
        assert elements[0].text == "Valid"
|
||||
|
||||
|
||||
class TestConvertTextLineItem:
|
||||
"""Tests for convert_text_line_item function."""
|
||||
|
||||
|
||||
1
tests/training/__init__.py
Normal file
1
tests/training/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for training package."""
|
||||
1
tests/training/yolo/__init__.py
Normal file
1
tests/training/yolo/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for training.yolo module."""
|
||||
342
tests/training/yolo/test_annotation_generator.py
Normal file
342
tests/training/yolo/test_annotation_generator.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Tests for AnnotationGenerator with field-specific bbox expansion.
|
||||
|
||||
Tests verify that annotations are generated correctly using
|
||||
field-specific scale strategies.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import pytest
|
||||
|
||||
from training.yolo.annotation_generator import (
|
||||
AnnotationGenerator,
|
||||
YOLOAnnotation,
|
||||
)
|
||||
from shared.fields import TRAINING_FIELD_CLASSES, CLASS_NAMES
|
||||
|
||||
|
||||
@dataclass
class MockMatch:
    """Lightweight stand-in for a Match result used by these tests."""

    # (x0, y0, x1, y1) bounding box in PDF points
    bbox: tuple[float, float, float, float]
    # matcher confidence score
    score: float
|
||||
|
||||
class TestYOLOAnnotation:
    """Tests for the YOLOAnnotation dataclass."""

    def test_to_string_format(self):
        """to_string() emits 'class x y w h' with six decimal places."""
        annotation = YOLOAnnotation(
            class_id=0,
            x_center=0.5,
            y_center=0.5,
            width=0.1,
            height=0.05,
            confidence=0.9,
        )
        assert annotation.to_string() == "0 0.500000 0.500000 0.100000 0.050000"

    def test_default_confidence(self):
        """Confidence defaults to 1.0 when not supplied."""
        annotation = YOLOAnnotation(
            class_id=0,
            x_center=0.5,
            y_center=0.5,
            width=0.1,
            height=0.05,
        )
        assert annotation.confidence == 1.0
|
||||
|
||||
class TestAnnotationGeneratorInit:
    """Tests for AnnotationGenerator initialization."""

    def test_default_values(self):
        """Constructor defaults: min_confidence=0.7, min_bbox_height_px=30."""
        generator = AnnotationGenerator()
        assert generator.min_confidence == 0.7
        assert generator.min_bbox_height_px == 30

    def test_custom_values(self):
        """Constructor accepts and stores custom thresholds."""
        generator = AnnotationGenerator(min_confidence=0.8, min_bbox_height_px=40)
        assert generator.min_confidence == 0.8
        assert generator.min_bbox_height_px == 40
|
||||
|
||||
class TestGenerateFromMatches:
    """Tests for the generate_from_matches method."""

    def test_generates_annotation_for_valid_match(self):
        """A match above the confidence threshold yields one annotation."""
        generator = AnnotationGenerator(min_confidence=0.5)

        # Match bbox is in PDF points (72 DPI); at 150 DPI pixel
        # coordinates are points * 150/72 ≈ points * 2.083.
        matches = {"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.8)]}

        result = generator.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 1
        annotation = result[0]
        assert annotation.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"]
        assert annotation.confidence == 0.8
        # All geometry is normalized to the 0-1 range.
        assert 0 <= annotation.x_center <= 1
        assert 0 <= annotation.y_center <= 1
        assert 0 < annotation.width <= 1
        assert 0 < annotation.height <= 1

    def test_skips_low_confidence_match(self):
        """Matches below min_confidence produce no annotation."""
        generator = AnnotationGenerator(min_confidence=0.7)
        matches = {"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.5)]}

        result = generator.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 0

    def test_skips_unknown_field(self):
        """Field names without a training class are ignored."""
        generator = AnnotationGenerator(min_confidence=0.5)
        matches = {"UnknownField": [MockMatch(bbox=(100, 200, 200, 230), score=0.9)]}

        result = generator.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 0

    def test_takes_best_match_only(self):
        """When several matches exist, only the highest-scoring one is used."""
        generator = AnnotationGenerator(min_confidence=0.5)
        matches = {
            "InvoiceNumber": [
                MockMatch(bbox=(100, 200, 200, 230), score=0.9),  # best
                MockMatch(bbox=(300, 400, 400, 430), score=0.7),
            ]
        }

        result = generator.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 1
        assert result[0].confidence == 0.9

    def test_handles_empty_matches(self):
        """A field mapped to an empty match list produces no annotation."""
        generator = AnnotationGenerator()
        matches = {"InvoiceNumber": []}

        result = generator.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 0

    def test_applies_field_specific_expansion(self):
        """The same raw bbox expands differently for different fields."""
        generator = AnnotationGenerator(min_confidence=0.5)
        shared_bbox = (100, 200, 200, 230)

        invoice_result = generator.generate_from_matches(
            matches={"InvoiceNumber": [MockMatch(bbox=shared_bbox, score=0.9)]},
            image_width=1000,
            image_height=1000,
            dpi=150,
        )[0]
        bankgiro_result = generator.generate_from_matches(
            matches={"Bankgiro": [MockMatch(bbox=shared_bbox, score=0.9)]},
            image_width=1000,
            image_height=1000,
            dpi=150,
        )[0]

        # Bankgiro has extra_left_ratio=0.80 while invoice_number has
        # extra_top_ratio=0.40 — Bankgiro expands more to the left, so the
        # two annotations must differ in width and/or horizontal center.
        assert (
            bankgiro_result.width != invoice_result.width
            or bankgiro_result.x_center != invoice_result.x_center
        )

    def test_enforces_min_bbox_height(self):
        """A very short bbox is still emitted; height is clamped upward."""
        generator = AnnotationGenerator(min_confidence=0.5, min_bbox_height_px=50)

        # Deliberately tiny (10 pt tall) bbox.
        matches = {"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 210), score=0.9)]}

        result = generator.generate_from_matches(
            matches=matches,
            image_width=1000,
            image_height=1000,
            dpi=72,  # 1:1 point-to-pixel scale
        )

        assert len(result) == 1
        # Height should be at least min_bbox_height_px / image_height.
        # The min-height check runs AFTER the scale-strategy expansion,
        # so the final height is expected to be >= 50/1000 = 0.05.
||||
|
||||
class TestAddPaymentLineAnnotation:
    """Tests for the add_payment_line_annotation method."""

    def test_adds_payment_line_annotation(self):
        """A valid bbox with sufficient confidence is appended."""
        generator = AnnotationGenerator(min_confidence=0.5)

        result = generator.add_payment_line_annotation(
            annotations=[],
            payment_line_bbox=(100, 200, 400, 230),
            confidence=0.9,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 1
        added = result[0]
        assert added.class_id == TRAINING_FIELD_CLASSES["payment_line"]
        assert added.confidence == 0.9

    def test_skips_none_bbox(self):
        """A None bbox leaves the annotation list unchanged."""
        generator = AnnotationGenerator(min_confidence=0.5)

        result = generator.add_payment_line_annotation(
            annotations=[],
            payment_line_bbox=None,
            confidence=0.9,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 0

    def test_skips_low_confidence(self):
        """Confidence below the threshold is rejected."""
        generator = AnnotationGenerator(min_confidence=0.7)

        result = generator.add_payment_line_annotation(
            annotations=[],
            payment_line_bbox=(100, 200, 400, 230),
            confidence=0.5,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 0

    def test_appends_to_existing_annotations(self):
        """payment_line is appended after any pre-existing annotations."""
        generator = AnnotationGenerator(min_confidence=0.5)
        existing = [YOLOAnnotation(0, 0.5, 0.5, 0.1, 0.1, 0.9)]

        result = generator.add_payment_line_annotation(
            annotations=existing,
            payment_line_bbox=(100, 200, 400, 230),
            confidence=0.9,
            image_width=1000,
            image_height=1000,
            dpi=150,
        )

        assert len(result) == 2
        assert result[0].class_id == 0  # original entry preserved first
        assert result[1].class_id == TRAINING_FIELD_CLASSES["payment_line"]
|
||||
|
||||
class TestMultipleFieldsIntegration:
    """Integration tests covering many fields at once."""

    def test_generates_annotations_for_all_field_types(self):
        """Every supported field type yields exactly one annotation."""
        generator = AnnotationGenerator(min_confidence=0.5)

        # All trainable fields except payment_line, which is derived
        # separately rather than matched.
        field_names = [
            "InvoiceNumber",
            "InvoiceDate",
            "InvoiceDueDate",
            "OCR",
            "Bankgiro",
            "Plusgiro",
            "Amount",
            "supplier_organisation_number",
            "customer_number",
        ]

        # Stagger bboxes diagonally so none of them overlap.
        matches = {
            name: [
                MockMatch(
                    bbox=(100 + i * 50, 100 + i * 30, 200 + i * 50, 130 + i * 30),
                    score=0.9,
                )
            ]
            for i, name in enumerate(field_names)
        }

        result = generator.generate_from_matches(
            matches=matches,
            image_width=2000,
            image_height=2000,
            dpi=150,
        )

        assert len(result) == len(field_names)

        # Every expected class id must appear exactly once.
        produced_ids = {annotation.class_id for annotation in result}
        expected_ids = {TRAINING_FIELD_CLASSES[name] for name in field_names}
        assert produced_ids == expected_ids
||||
Reference in New Issue
Block a user