This commit is contained in:
Yaojia Wang
2026-02-11 23:40:38 +01:00
parent f1a7bfe6b7
commit ad5ed46b4c
117 changed files with 5741 additions and 7669 deletions

View File

@@ -7,27 +7,23 @@ from pathlib import Path
from training.yolo.db_dataset import DBYOLODataset
from training.yolo.annotation_generator import YOLOAnnotation
from shared.bbox import FIELD_SCALE_STRATEGIES, DEFAULT_STRATEGY
from shared.bbox import UNIFORM_PAD
from shared.fields import CLASS_NAMES
class TestConvertLabelsWithExpandBbox:
"""Tests for _convert_labels using expand_bbox instead of fixed padding."""
"""Tests for _convert_labels using uniform expand_bbox."""
def test_convert_labels_uses_expand_bbox(self):
"""Verify _convert_labels calls expand_bbox for field-specific expansion."""
# Create a mock dataset without loading from DB
"""Verify _convert_labels calls expand_bbox with uniform padding."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 30
# Create annotation for bankgiro (has extra_left_ratio)
# bbox in PDF points: x0=100, y0=200, x1=200, y1=250
# center: (150, 225), width: 100, height: 50
annotations = [
YOLOAnnotation(
class_id=4, # bankgiro
x_center=150, # in PDF points
x_center=150,
y_center=225,
width=100,
height=50,
@@ -35,48 +31,26 @@ class TestConvertLabelsWithExpandBbox:
)
]
# Image size in pixels (at 300 DPI)
img_width = 2480 # A4 width at 300 DPI
img_height = 3508 # A4 height at 300 DPI
img_width = 2480
img_height = 3508
# Convert labels
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
# Should have one label
assert labels.shape == (1, 5)
# Check class_id
assert labels[0, 0] == 4
# The bbox should be expanded using bankgiro strategy (extra_left_ratio=0.80)
# Original bbox at 300 DPI:
# x0 = 100 * (300/72) = 416.67
# y0 = 200 * (300/72) = 833.33
# x1 = 200 * (300/72) = 833.33
# y1 = 250 * (300/72) = 1041.67
# width_px = 416.67, height_px = 208.33
# After expand_bbox with bankgiro strategy:
# scale_x=1.45, scale_y=1.35, extra_left_ratio=0.80
# The x_center should shift left due to extra_left_ratio
x_center = labels[0, 1]
y_center = labels[0, 2]
width = labels[0, 3]
height = labels[0, 4]
# Verify normalized values are in valid range
assert 0 <= x_center <= 1
assert 0 <= y_center <= 1
assert 0 < width <= 1
assert 0 < height <= 1
# Width should be larger than original due to scaling and extra_left
# Original normalized width: 416.67 / 2480 = 0.168
# After bankgiro expansion it should be wider
assert width > 0.168
def test_convert_labels_different_field_types(self):
"""Verify different field types use their specific strategies."""
def test_convert_labels_all_fields_get_same_expansion(self):
"""Verify all field types get the same uniform expansion."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 30
@@ -84,7 +58,6 @@ class TestConvertLabelsWithExpandBbox:
img_width = 2480
img_height = 3508
# Same bbox for different field types
base_annotation = {
'x_center': 150,
'y_center': 225,
@@ -93,30 +66,20 @@ class TestConvertLabelsWithExpandBbox:
'confidence': 0.9
}
# OCR number (class_id=3) - has extra_top_ratio=0.60
# All field types should get the same uniform expansion
ocr_annotations = [YOLOAnnotation(class_id=3, **base_annotation)]
ocr_labels = dataset._convert_labels(ocr_annotations, img_width, img_height, is_scanned=False)
# Bankgiro (class_id=4) - has extra_left_ratio=0.80
bankgiro_annotations = [YOLOAnnotation(class_id=4, **base_annotation)]
bankgiro_labels = dataset._convert_labels(bankgiro_annotations, img_width, img_height, is_scanned=False)
# Amount (class_id=6) - has extra_right_ratio=0.30
amount_annotations = [YOLOAnnotation(class_id=6, **base_annotation)]
amount_labels = dataset._convert_labels(amount_annotations, img_width, img_height, is_scanned=False)
# x_center and y_center should be the same (uniform padding is symmetric)
assert abs(ocr_labels[0, 1] - bankgiro_labels[0, 1]) < 0.001
assert abs(ocr_labels[0, 2] - bankgiro_labels[0, 2]) < 0.001
# Each field type should have different expansion
# OCR should expand more vertically (extra_top)
# Bankgiro should expand more to the left
# Amount should expand more to the right
# OCR: extra_top shifts y_center up
# Bankgiro: extra_left shifts x_center left
# So bankgiro x_center < OCR x_center
assert bankgiro_labels[0, 1] < ocr_labels[0, 1]
# OCR has higher scale_y (1.80) than amount (1.35)
assert ocr_labels[0, 4] > amount_labels[0, 4]
# width and height should also be the same
assert abs(ocr_labels[0, 3] - bankgiro_labels[0, 3]) < 0.001
assert abs(ocr_labels[0, 4] - bankgiro_labels[0, 4]) < 0.001
def test_convert_labels_clamps_to_image_bounds(self):
"""Verify labels are clamped to image boundaries."""
@@ -124,11 +87,10 @@ class TestConvertLabelsWithExpandBbox:
dataset.dpi = 300
dataset.min_bbox_height_px = 30
# Annotation near edge of image (in PDF points)
annotations = [
YOLOAnnotation(
class_id=4, # bankgiro - will expand left
x_center=30, # Very close to left edge
class_id=4,
x_center=30,
y_center=50,
width=40,
height=30,
@@ -141,11 +103,10 @@ class TestConvertLabelsWithExpandBbox:
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
# All values should be in valid range
assert 0 <= labels[0, 1] <= 1 # x_center
assert 0 <= labels[0, 2] <= 1 # y_center
assert 0 < labels[0, 3] <= 1 # width
assert 0 < labels[0, 4] <= 1 # height
assert 0 <= labels[0, 1] <= 1
assert 0 <= labels[0, 2] <= 1
assert 0 < labels[0, 3] <= 1
assert 0 < labels[0, 4] <= 1
def test_convert_labels_empty_annotations(self):
"""Verify empty annotations return empty array."""
@@ -162,23 +123,21 @@ class TestConvertLabelsWithExpandBbox:
"""Verify minimum height is enforced after expansion."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 50 # Higher minimum
dataset.min_bbox_height_px = 50
# Very small annotation
annotations = [
YOLOAnnotation(
class_id=9, # payment_line - minimal expansion
class_id=9,
x_center=100,
y_center=100,
width=200,
height=5, # Very small height
height=5,
confidence=0.9
)
]
labels = dataset._convert_labels(annotations, 2480, 3508, is_scanned=False)
# Height should be at least min_bbox_height_px / img_height
min_normalized_height = 50 / 3508
assert labels[0, 4] >= min_normalized_height
@@ -190,25 +149,23 @@ class TestCreateAnnotationWithClassName:
"""Verify _create_annotation stores class_name for later use."""
dataset = object.__new__(DBYOLODataset)
# Create annotation for invoice_number
annotation = dataset._create_annotation(
field_name="InvoiceNumber",
bbox=[100, 200, 200, 250],
score=0.9
)
assert annotation.class_id == 0 # invoice_number class_id
assert annotation.class_id == 0
class TestLoadLabelsFromDbWithClassName:
"""Tests for _load_labels_from_db preserving field_name for expansion."""
def test_load_labels_maps_field_names_correctly(self):
"""Verify field names are mapped correctly for expand_bbox."""
"""Verify field names are mapped correctly."""
dataset = object.__new__(DBYOLODataset)
dataset.min_confidence = 0.7
# Mock database
mock_db = MagicMock()
mock_db.get_documents_batch.return_value = {
'doc1': {
@@ -240,12 +197,7 @@ class TestLoadLabelsFromDbWithClassName:
assert 'doc1' in result
page_labels, is_scanned, csv_split = result['doc1']
# Should have 2 annotations on page 0
assert 0 in page_labels
assert len(page_labels[0]) == 2
# First annotation: Bankgiro (class_id=4)
assert page_labels[0][0].class_id == 4
# Second annotation: Plusgiro mapped from supplier_accounts(Plusgiro) (class_id=5)
assert page_labels[0][1].class_id == 5