"""Tests for db_dataset.py expand_bbox integration."""

from pathlib import Path
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from shared.bbox import DEFAULT_STRATEGY, FIELD_SCALE_STRATEGIES
from shared.fields import CLASS_NAMES
from training.yolo.annotation_generator import YOLOAnnotation
from training.yolo.db_dataset import DBYOLODataset

# A4 page size in pixels when rendered at 300 DPI.
A4_WIDTH_PX_300DPI = 2480
A4_HEIGHT_PX_300DPI = 3508


def _make_dataset(**attrs):
    """Build a DBYOLODataset without running ``__init__`` (no DB loading).

    Each keyword argument is set as an attribute on the bare instance, so a
    test configures only the state the method under test reads.
    """
    dataset = object.__new__(DBYOLODataset)
    for name, value in attrs.items():
        setattr(dataset, name, value)
    return dataset


class TestConvertLabelsWithExpandBbox:
    """Tests for _convert_labels using expand_bbox instead of fixed padding."""

    def test_convert_labels_uses_expand_bbox(self):
        """Verify _convert_labels calls expand_bbox for field-specific expansion."""
        dataset = _make_dataset(dpi=300, min_bbox_height_px=30)

        # Annotation for bankgiro (has extra_left_ratio).
        # bbox in PDF points: x0=100, y0=200, x1=200, y1=250
        # center: (150, 225), width: 100, height: 50
        annotations = [
            YOLOAnnotation(
                class_id=4,  # bankgiro
                x_center=150,  # in PDF points
                y_center=225,
                width=100,
                height=50,
                confidence=0.9,
            )
        ]

        labels = dataset._convert_labels(
            annotations, A4_WIDTH_PX_300DPI, A4_HEIGHT_PX_300DPI, is_scanned=False
        )

        # Should have one label: [class_id, x_center, y_center, width, height]
        assert labels.shape == (1, 5)
        assert labels[0, 0] == 4

        # The bbox should be expanded using bankgiro strategy (extra_left_ratio=0.80)
        # Original bbox at 300 DPI:
        #   x0 = 100 * (300/72) = 416.67
        #   y0 = 200 * (300/72) = 833.33
        #   x1 = 200 * (300/72) = 833.33
        #   y1 = 250 * (300/72) = 1041.67
        #   width_px = 416.67, height_px = 208.33
        # After expand_bbox with bankgiro strategy:
        #   scale_x=1.45, scale_y=1.35, extra_left_ratio=0.80
        # The x_center should shift left due to extra_left_ratio.
        x_center, y_center, width, height = labels[0, 1:]

        # Normalized values must be in the valid range.
        assert 0 <= x_center <= 1
        assert 0 <= y_center <= 1
        assert 0 < width <= 1
        assert 0 < height <= 1

        # Width should be larger than original due to scaling and extra_left.
        # Original normalized width: 416.67 / 2480 = 0.168
        assert width > 0.168

    def test_convert_labels_different_field_types(self):
        """Verify different field types use their specific strategies."""
        dataset = _make_dataset(dpi=300, min_bbox_height_px=30)

        # Same bbox (PDF points) for every field type.
        base_annotation = {
            'x_center': 150,
            'y_center': 225,
            'width': 100,
            'height': 50,
            'confidence': 0.9,
        }

        def convert(class_id):
            annotations = [YOLOAnnotation(class_id=class_id, **base_annotation)]
            return dataset._convert_labels(
                annotations, A4_WIDTH_PX_300DPI, A4_HEIGHT_PX_300DPI, is_scanned=False
            )

        ocr_labels = convert(3)  # OCR number - has extra_top_ratio=0.60
        bankgiro_labels = convert(4)  # Bankgiro - has extra_left_ratio=0.80
        amount_labels = convert(6)  # Amount - has extra_right_ratio=0.30

        # Each field type should expand differently:
        #   OCR expands more vertically (extra_top shifts y_center up),
        #   Bankgiro expands more to the left (extra_left shifts x_center left),
        #   Amount expands more to the right.
        # So bankgiro x_center < OCR x_center.
        assert bankgiro_labels[0, 1] < ocr_labels[0, 1]

        # OCR has higher scale_y (1.80) than amount (1.35).
        assert ocr_labels[0, 4] > amount_labels[0, 4]

    def test_convert_labels_clamps_to_image_bounds(self):
        """Verify labels are clamped to image boundaries."""
        dataset = _make_dataset(dpi=300, min_bbox_height_px=30)

        # Annotation near the left/top edge of the page (in PDF points).
        annotations = [
            YOLOAnnotation(
                class_id=4,  # bankgiro - will expand left
                x_center=30,  # very close to left edge
                y_center=50,
                width=40,
                height=30,
                confidence=0.9,
            )
        ]

        labels = dataset._convert_labels(
            annotations, A4_WIDTH_PX_300DPI, A4_HEIGHT_PX_300DPI, is_scanned=False
        )

        # All values should be in the valid range.
        assert 0 <= labels[0, 1] <= 1  # x_center
        assert 0 <= labels[0, 2] <= 1  # y_center
        assert 0 < labels[0, 3] <= 1  # width
        assert 0 < labels[0, 4] <= 1  # height

    def test_convert_labels_empty_annotations(self):
        """Verify empty annotations return an empty (0, 5) float32 array."""
        dataset = _make_dataset(dpi=300, min_bbox_height_px=30)

        labels = dataset._convert_labels(
            [], A4_WIDTH_PX_300DPI, A4_HEIGHT_PX_300DPI, is_scanned=False
        )

        assert labels.shape == (0, 5)
        assert labels.dtype == np.float32

    def test_convert_labels_minimum_height(self):
        """Verify minimum height is enforced after expansion."""
        min_height_px = 50  # higher than the 30 px used elsewhere
        dataset = _make_dataset(dpi=300, min_bbox_height_px=min_height_px)

        # Very small annotation.
        annotations = [
            YOLOAnnotation(
                class_id=9,  # payment_line - minimal expansion
                x_center=100,
                y_center=100,
                width=200,
                height=5,  # very small height
                confidence=0.9,
            )
        ]

        labels = dataset._convert_labels(
            annotations, A4_WIDTH_PX_300DPI, A4_HEIGHT_PX_300DPI, is_scanned=False
        )

        # Height should be at least min_bbox_height_px / img_height.
        min_normalized_height = min_height_px / A4_HEIGHT_PX_300DPI
        assert labels[0, 4] >= min_normalized_height


class TestCreateAnnotationWithClassName:
    """Tests for _create_annotation storing class_name for expand_bbox lookup."""

    def test_create_annotation_stores_class_name(self):
        """Verify _create_annotation maps a field name to its class_id.

        NOTE(review): despite the test name, only ``class_id`` is asserted.
        Add a ``class_name`` check once it is confirmed that YOLOAnnotation
        exposes such an attribute.
        """
        dataset = _make_dataset()

        annotation = dataset._create_annotation(
            field_name="InvoiceNumber",
            bbox=[100, 200, 200, 250],
            score=0.9,
        )

        assert annotation.class_id == 0  # invoice_number class_id


class TestLoadLabelsFromDbWithClassName:
    """Tests for _load_labels_from_db preserving field_name for expansion."""

    def test_load_labels_maps_field_names_correctly(self):
        """Verify field names are mapped correctly for expand_bbox."""
        mock_db = MagicMock()
        mock_db.get_documents_batch.return_value = {
            'doc1': {
                'success': True,
                'pdf_type': 'text',
                'split': 'train',
                'field_results': [
                    {
                        'matched': True,
                        'field_name': 'Bankgiro',
                        'score': 0.9,
                        'bbox': [100, 200, 200, 250],
                        'page_no': 0,
                    },
                    {
                        'matched': True,
                        'field_name': 'supplier_accounts(Plusgiro)',
                        'score': 0.85,
                        'bbox': [300, 400, 400, 450],
                        'page_no': 0,
                    },
                ],
            }
        }
        dataset = _make_dataset(min_confidence=0.7, db=mock_db)

        result = dataset._load_labels_from_db(['doc1'])

        assert 'doc1' in result
        # Only the per-page labels are checked here; the scanned flag and
        # split are intentionally ignored.
        page_labels, _is_scanned, _csv_split = result['doc1']

        # Should have 2 annotations on page 0.
        assert 0 in page_labels
        assert len(page_labels[0]) == 2

        # First annotation: Bankgiro (class_id=4).
        assert page_labels[0][0].class_id == 4

        # Second annotation: Plusgiro mapped from supplier_accounts(Plusgiro) (class_id=5).
        assert page_labels[0][1].class_id == 5