Files
invoice-master-poc-v2/tests/training/yolo/test_db_dataset.py
Yaojia Wang ad5ed46b4c WIP
2026-02-11 23:40:38 +01:00

204 lines
6.4 KiB
Python

"""Tests for db_dataset.py expand_bbox integration."""
import numpy as np
import pytest
from unittest.mock import MagicMock, patch
from pathlib import Path
from training.yolo.db_dataset import DBYOLODataset
from training.yolo.annotation_generator import YOLOAnnotation
from shared.bbox import UNIFORM_PAD
from shared.fields import CLASS_NAMES
class TestConvertLabelsWithExpandBbox:
    """Tests for _convert_labels using uniform expand_bbox."""

    # Page size in pixels used by these tests (an A4 page rendered at 300 DPI,
    # matching the dpi set by _make_dataset).
    IMG_WIDTH = 2480
    IMG_HEIGHT = 3508

    # Tolerance for comparisons involving float32 label values, which can
    # round a hair past an exact float64 boundary.
    FLOAT_TOL = 1e-6

    @staticmethod
    def _make_dataset(min_bbox_height_px=30) -> "DBYOLODataset":
        """Return a bare DBYOLODataset with only the attributes these tests set.

        ``__init__`` is bypassed via ``object.__new__`` so whatever setup it
        performs is skipped; only ``dpi`` and ``min_bbox_height_px`` are
        assigned before calling ``_convert_labels``.
        """
        dataset = object.__new__(DBYOLODataset)
        dataset.dpi = 300
        dataset.min_bbox_height_px = min_bbox_height_px
        return dataset

    def test_convert_labels_uses_expand_bbox(self):
        """Verify _convert_labels emits one normalized YOLO row per annotation.

        The row keeps the annotation's class id, and its center/size values
        are normalized to [0, 1]. NOTE: only the output is inspected; this
        test does not assert that expand_bbox itself was invoked.
        """
        dataset = self._make_dataset()
        annotations = [
            YOLOAnnotation(
                class_id=4,  # bankgiro
                x_center=150,
                y_center=225,
                width=100,
                height=50,
                confidence=0.9,
            )
        ]

        labels = dataset._convert_labels(
            annotations, self.IMG_WIDTH, self.IMG_HEIGHT, is_scanned=False
        )

        # Each row is [class_id, x_center, y_center, width, height].
        assert labels.shape == (1, 5)
        assert labels[0, 0] == 4
        x_center, y_center, width, height = labels[0, 1:5]
        assert 0 <= x_center <= 1
        assert 0 <= y_center <= 1
        assert 0 < width <= 1
        assert 0 < height <= 1

    def test_convert_labels_all_fields_get_same_expansion(self):
        """Verify all field types get the same uniform expansion."""
        dataset = self._make_dataset()
        base_annotation = {
            'x_center': 150,
            'y_center': 225,
            'width': 100,
            'height': 50,
            'confidence': 0.9,
        }

        # Same geometry, different classes: the resulting boxes must match.
        ocr_labels = dataset._convert_labels(
            [YOLOAnnotation(class_id=3, **base_annotation)],
            self.IMG_WIDTH, self.IMG_HEIGHT, is_scanned=False,
        )
        bankgiro_labels = dataset._convert_labels(
            [YOLOAnnotation(class_id=4, **base_annotation)],
            self.IMG_WIDTH, self.IMG_HEIGHT, is_scanned=False,
        )

        # Columns 1-4 are x_center, y_center, width, height.
        for column in range(1, 5):
            assert abs(ocr_labels[0, column] - bankgiro_labels[0, column]) < 0.001

    def test_convert_labels_clamps_to_image_bounds(self):
        """Verify a box near the top-left corner stays inside the image.

        Checking only that center and size lie in [0, 1] is not enough: a box
        can have an in-range center while its edge sticks out of the image.
        The box edges derived from center +/- half size are checked as well.
        """
        dataset = self._make_dataset()
        annotations = [
            YOLOAnnotation(
                class_id=4,
                x_center=30,
                y_center=50,
                width=40,
                height=30,
                confidence=0.9,
            )
        ]

        labels = dataset._convert_labels(
            annotations, self.IMG_WIDTH, self.IMG_HEIGHT, is_scanned=False
        )

        x_center, y_center, width, height = (float(v) for v in labels[0, 1:5])
        assert 0 <= x_center <= 1
        assert 0 <= y_center <= 1
        assert 0 < width <= 1
        assert 0 < height <= 1

        tol = self.FLOAT_TOL
        assert x_center - width / 2 >= -tol
        assert x_center + width / 2 <= 1 + tol
        assert y_center - height / 2 >= -tol
        assert y_center + height / 2 <= 1 + tol

    def test_convert_labels_empty_annotations(self):
        """Verify empty annotations return an empty (0, 5) float32 array."""
        dataset = self._make_dataset()

        labels = dataset._convert_labels(
            [], self.IMG_WIDTH, self.IMG_HEIGHT, is_scanned=False
        )

        assert labels.shape == (0, 5)
        assert labels.dtype == np.float32

    def test_convert_labels_minimum_height(self):
        """Verify minimum height is enforced after expansion."""
        min_height_px = 50
        dataset = self._make_dataset(min_bbox_height_px=min_height_px)
        annotations = [
            YOLOAnnotation(
                class_id=9,
                x_center=100,
                y_center=100,
                width=200,
                height=5,  # far below the minimum height
                confidence=0.9,
            )
        ]

        labels = dataset._convert_labels(
            annotations, self.IMG_WIDTH, self.IMG_HEIGHT, is_scanned=False
        )

        # The label is float32; if the height lands exactly on the minimum,
        # float32 rounding may put it fractionally below the float64 threshold.
        min_normalized_height = min_height_px / self.IMG_HEIGHT
        assert labels[0, 4] >= min_normalized_height - self.FLOAT_TOL
class TestCreateAnnotationWithClassName:
    """Tests that _create_annotation maps a field name to its class id."""

    def test_create_annotation_stores_class_name(self):
        """An 'InvoiceNumber' field should be annotated with class id 0.

        Only the resulting class_id is checked here.
        """
        # Bare instance: __init__ is skipped entirely.
        bare = object.__new__(DBYOLODataset)

        result = bare._create_annotation(
            field_name="InvoiceNumber",
            bbox=[100, 200, 200, 250],
            score=0.9,
        )

        expected_class_id = 0
        assert result.class_id == expected_class_id
class TestLoadLabelsFromDbWithClassName:
    """Tests for _load_labels_from_db preserving field_name for expansion."""

    def test_load_labels_maps_field_names_correctly(self):
        """Matched DB field results become annotations with the expected class ids."""
        bankgiro_field = {
            'matched': True,
            'field_name': 'Bankgiro',
            'score': 0.9,
            'bbox': [100, 200, 200, 250],
            'page_no': 0,
        }
        plusgiro_field = {
            'matched': True,
            'field_name': 'supplier_accounts(Plusgiro)',
            'score': 0.85,
            'bbox': [300, 400, 400, 450],
            'page_no': 0,
        }

        # Stand-in DB returning one text-based document with both fields.
        fake_db = MagicMock()
        fake_db.get_documents_batch.return_value = {
            'doc1': {
                'success': True,
                'pdf_type': 'text',
                'split': 'train',
                'field_results': [bankgiro_field, plusgiro_field],
            }
        }

        loader = object.__new__(DBYOLODataset)
        loader.min_confidence = 0.7
        loader.db = fake_db

        labels_by_doc = loader._load_labels_from_db(['doc1'])

        assert 'doc1' in labels_by_doc
        page_labels, _is_scanned, _csv_split = labels_by_doc['doc1']
        assert 0 in page_labels

        first_page = page_labels[0]
        assert len(first_page) == 2
        assert first_page[0].class_id == 4  # Bankgiro
        assert first_page[1].class_id == 5  # supplier_accounts(Plusgiro)