"""Tests for db_dataset.py expand_bbox integration.""" import numpy as np import pytest from unittest.mock import MagicMock, patch from pathlib import Path from training.yolo.db_dataset import DBYOLODataset from training.yolo.annotation_generator import YOLOAnnotation from shared.bbox import UNIFORM_PAD from shared.fields import CLASS_NAMES class TestConvertLabelsWithExpandBbox: """Tests for _convert_labels using uniform expand_bbox.""" def test_convert_labels_uses_expand_bbox(self): """Verify _convert_labels calls expand_bbox with uniform padding.""" dataset = object.__new__(DBYOLODataset) dataset.dpi = 300 dataset.min_bbox_height_px = 30 annotations = [ YOLOAnnotation( class_id=4, # bankgiro x_center=150, y_center=225, width=100, height=50, confidence=0.9 ) ] img_width = 2480 img_height = 3508 labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False) assert labels.shape == (1, 5) assert labels[0, 0] == 4 x_center = labels[0, 1] y_center = labels[0, 2] width = labels[0, 3] height = labels[0, 4] assert 0 <= x_center <= 1 assert 0 <= y_center <= 1 assert 0 < width <= 1 assert 0 < height <= 1 def test_convert_labels_all_fields_get_same_expansion(self): """Verify all field types get the same uniform expansion.""" dataset = object.__new__(DBYOLODataset) dataset.dpi = 300 dataset.min_bbox_height_px = 30 img_width = 2480 img_height = 3508 base_annotation = { 'x_center': 150, 'y_center': 225, 'width': 100, 'height': 50, 'confidence': 0.9 } # All field types should get the same uniform expansion ocr_annotations = [YOLOAnnotation(class_id=3, **base_annotation)] ocr_labels = dataset._convert_labels(ocr_annotations, img_width, img_height, is_scanned=False) bankgiro_annotations = [YOLOAnnotation(class_id=4, **base_annotation)] bankgiro_labels = dataset._convert_labels(bankgiro_annotations, img_width, img_height, is_scanned=False) # x_center and y_center should be the same (uniform padding is symmetric) assert abs(ocr_labels[0, 1] - bankgiro_labels[0, 1]) < 0.001 assert abs(ocr_labels[0, 2] - bankgiro_labels[0, 2]) < 0.001 # width and height should also be the same assert abs(ocr_labels[0, 3] - bankgiro_labels[0, 3]) < 0.001 assert abs(ocr_labels[0, 4] - bankgiro_labels[0, 4]) < 0.001 def test_convert_labels_clamps_to_image_bounds(self): """Verify labels are clamped to image boundaries.""" dataset = object.__new__(DBYOLODataset) dataset.dpi = 300 dataset.min_bbox_height_px = 30 annotations = [ YOLOAnnotation( class_id=4, x_center=30, y_center=50, width=40, height=30, confidence=0.9 ) ] img_width = 2480 img_height = 3508 labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False) assert 0 <= labels[0, 1] <= 1 assert 0 <= labels[0, 2] <= 1 assert 0 < labels[0, 3] <= 1 assert 0 < labels[0, 4] <= 1 def test_convert_labels_empty_annotations(self): """Verify empty annotations return empty array.""" dataset = object.__new__(DBYOLODataset) dataset.dpi = 300 dataset.min_bbox_height_px = 30 labels = dataset._convert_labels([], 2480, 3508, is_scanned=False) assert labels.shape == (0, 5) assert labels.dtype == np.float32 def test_convert_labels_minimum_height(self): """Verify minimum height is enforced after expansion.""" dataset = object.__new__(DBYOLODataset) dataset.dpi = 300 dataset.min_bbox_height_px = 50 annotations = [ YOLOAnnotation( class_id=9, x_center=100, y_center=100, width=200, height=5, confidence=0.9 ) ] labels = dataset._convert_labels(annotations, 2480, 3508, is_scanned=False) min_normalized_height = 50 / 3508 assert labels[0, 4] >= min_normalized_height class TestCreateAnnotationWithClassName: """Tests for _create_annotation storing class_name for expand_bbox lookup.""" def test_create_annotation_stores_class_name(self): """Verify _create_annotation stores class_name for later use.""" dataset = object.__new__(DBYOLODataset) annotation = dataset._create_annotation( field_name="InvoiceNumber", bbox=[100, 200, 200, 250], score=0.9 ) assert annotation.class_id == 0 class TestLoadLabelsFromDbWithClassName: """Tests for _load_labels_from_db preserving field_name for expansion.""" def test_load_labels_maps_field_names_correctly(self): """Verify field names are mapped correctly.""" dataset = object.__new__(DBYOLODataset) dataset.min_confidence = 0.7 mock_db = MagicMock() mock_db.get_documents_batch.return_value = { 'doc1': { 'success': True, 'pdf_type': 'text', 'split': 'train', 'field_results': [ { 'matched': True, 'field_name': 'Bankgiro', 'score': 0.9, 'bbox': [100, 200, 200, 250], 'page_no': 0 }, { 'matched': True, 'field_name': 'supplier_accounts(Plusgiro)', 'score': 0.85, 'bbox': [300, 400, 400, 450], 'page_no': 0 } ] } } dataset.db = mock_db result = dataset._load_labels_from_db(['doc1']) assert 'doc1' in result page_labels, is_scanned, csv_split = result['doc1'] assert 0 in page_labels assert len(page_labels[0]) == 2 assert page_labels[0][0].class_id == 4 assert page_labels[0][1].class_id == 5