This commit is contained in:
Yaojia Wang
2026-02-11 23:40:38 +01:00
parent f1a7bfe6b7
commit ad5ed46b4c
117 changed files with 5741 additions and 7669 deletions

View File

@@ -400,6 +400,71 @@ class TestAmountNormalizer:
result = normalizer.normalize("Reference 12500")
assert result.value == "12500.00"
def test_payment_line_kronor_ore_format(self, normalizer):
"""Space between kronor and ore should be treated as decimal separator.
Swedish payment lines use space to separate kronor and ore:
"590 00" means 590.00 SEK, NOT 59000.
"""
result = normalizer.normalize("590 00")
assert result.value == "590.00"
assert result.is_valid is True
def test_payment_line_kronor_ore_large_amount(self, normalizer):
"""Large kronor/ore amount from payment line."""
result = normalizer.normalize("15658 00")
assert result.value == "15658.00"
assert result.is_valid is True
def test_payment_line_kronor_ore_with_nonzero_ore(self, normalizer):
"""Kronor/ore with non-zero ore."""
result = normalizer.normalize("736 50")
assert result.value == "736.50"
assert result.is_valid is True
def test_kronor_ore_not_confused_with_thousand_separator(self, normalizer):
"""Amount with comma decimal should NOT trigger kronor/ore pattern."""
result = normalizer.normalize("1 234,56")
assert result.value is not None
# Should parse as 1234.56, not as kronor=1234 ore=56 (which is same value)
assert float(result.value) == 1234.56
def test_european_dot_thousand_separator(self, normalizer):
"""European format: dot as thousand, comma as decimal."""
result = normalizer.normalize("2.254,50")
assert result.value == "2254.50"
assert result.is_valid is True
def test_european_dot_thousand_with_sek(self, normalizer):
"""European format with SEK suffix."""
result = normalizer.normalize("2.254,50 SEK")
assert result.value == "2254.50"
assert result.is_valid is True
def test_european_dot_thousand_with_kr(self, normalizer):
"""European format with kr suffix."""
result = normalizer.normalize("20.485,00 kr")
assert result.value == "20485.00"
assert result.is_valid is True
def test_european_large_amount(self, normalizer):
"""Large European format amount."""
result = normalizer.normalize("1.234.567,89")
assert result.value == "1234567.89"
assert result.is_valid is True
def test_european_in_label_context(self, normalizer):
"""European format inside label text (like the BAUHAUS invoice bug)."""
result = normalizer.normalize("ns Fakturabelopp: 2.254,50 SEK")
assert result.value == "2254.50"
assert result.is_valid is True
def test_anglo_comma_thousand_separator(self, normalizer):
"""Anglo format: comma as thousand, dot as decimal."""
result = normalizer.normalize("1,234.56")
assert result.value == "1234.56"
assert result.is_valid is True
def test_zero_amount_rejected(self, normalizer):
"""Test that zero amounts are rejected."""
result = normalizer.normalize("0,00 kr")
@@ -450,6 +515,18 @@ class TestEnhancedAmountNormalizer:
result = normalizer.normalize("Invoice for 1 234 567,89 kr")
assert result.value is not None
def test_enhanced_kronor_ore_format(self, normalizer):
"""Space between kronor and ore in enhanced normalizer."""
result = normalizer.normalize("590 00")
assert result.value == "590.00"
assert result.is_valid is True
def test_enhanced_kronor_ore_large(self, normalizer):
"""Large kronor/ore amount in enhanced normalizer."""
result = normalizer.normalize("15658 00")
assert result.value == "15658.00"
assert result.is_valid is True
def test_no_amount_fails(self, normalizer):
"""Test failure when no amount found."""
result = normalizer.normalize("no amount")
@@ -472,6 +549,22 @@ class TestEnhancedAmountNormalizer:
result = normalizer.normalize("Price: 1 234 567,89")
assert result.value is not None
def test_enhanced_european_dot_thousand(self, normalizer):
"""European format in enhanced normalizer."""
result = normalizer.normalize("2.254,50 SEK")
assert result.value == "2254.50"
assert result.is_valid is True
def test_enhanced_european_with_label(self, normalizer):
"""European format with Swedish label keyword."""
result = normalizer.normalize("Att betala: 2.254,50")
assert result.value == "2254.50"
def test_enhanced_anglo_format(self, normalizer):
"""Anglo format in enhanced normalizer."""
result = normalizer.normalize("Total: 1,234.56")
assert result.value == "1234.56"
def test_amount_out_of_range_rejected(self, normalizer):
"""Test that amounts >= 10,000,000 are rejected."""
result = normalizer.normalize("Summa: 99 999 999,00")

View File

@@ -497,5 +497,178 @@ class TestExtractBusinessFeaturesErrorHandling:
assert "NumericException" in result.errors[0]
class TestProcessPdfTokenPath:
    """Tests for PDF text token extraction path in process_pdf().

    process_pdf() is expected to use the PDF's embedded text layer (tokens)
    when the document is text-based, and to fall back to OCR-based extraction
    when the PDF is scanned or the text layer cannot be opened.  These tests
    stub out the detector, extractor, PDF document and renderer to exercise
    only that routing decision.
    """

    def _make_pipeline(self):
        """Create pipeline with mocked internals, bypassing __init__."""
        # Replace __init__ with a no-op so constructing the pipeline loads no
        # models; every collaborator is then assigned by hand below.
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            p = InferencePipeline()
            p.detector = MagicMock()
            p.extractor = MagicMock()
            p.payment_line_parser = MagicMock()
            p.dpi = 300
            # Optional stages are disabled so only the token/OCR path runs.
            p.enable_fallback = False
            p.enable_business_features = False
            p.vat_tolerance = 0.5
            p.line_items_extractor = None
            p.vat_extractor = None
            p.vat_validator = None
            p._business_ocr_engine = None
            p._table_detector = None
            return p

    def _make_detection(self, class_name='Amount', confidence=0.85, page_no=0):
        """Create a Detection object."""
        from backend.pipeline.yolo_detector import Detection
        return Detection(
            class_id=6,
            class_name=class_name,
            confidence=confidence,
            bbox=(100.0, 200.0, 300.0, 250.0),
            page_no=page_no,
        )

    def _make_extracted_field(self, field_name='Amount', raw_text='2.254,50',
                              normalized='2254.50', confidence=0.85):
        """Create an ExtractedField object."""
        from backend.pipeline.field_extractor import ExtractedField
        return ExtractedField(
            field_name=field_name,
            raw_text=raw_text,
            normalized_value=normalized,
            confidence=confidence,
            detection_confidence=confidence,
            ocr_confidence=1.0,
            bbox=(100.0, 200.0, 300.0, 250.0),
            page_no=0,
        )

    def _make_image_bytes(self):
        """Create minimal valid PNG bytes (100x100 white image)."""
        from PIL import Image as PILImage
        import io as _io
        img = PILImage.new('RGB', (100, 100), color='white')
        buf = _io.BytesIO()
        img.save(buf, format='PNG')
        return buf.getvalue()

    # NOTE: decorators apply bottom-up, so the renderer patch binds to
    # mock_render (first mock arg) and PDFDocument to mock_pdf_doc_cls.
    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_text_pdf_uses_pdf_tokens(self, mock_render, mock_pdf_doc_cls):
        """When PDF is text-based, extract_from_detection_with_pdf is used."""
        from shared.pdf.extractor import Token
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()
        # Setup PDFDocument mock - text PDF with tokens
        mock_pdf_doc = MagicMock()
        mock_pdf_doc.is_text_pdf.return_value = True
        mock_pdf_doc.page_count = 1
        tokens = [Token(text="2.254,50", bbox=(100, 200, 200, 220), page_no=0)]
        # Iterator mimics the lazy token stream the real extractor yields.
        mock_pdf_doc.extract_text_tokens.return_value = iter(tokens)
        # PDFDocument is used as a context manager; wire __enter__/__exit__.
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)
        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection_with_pdf.return_value = (
            self._make_extracted_field()
        )
        mock_render.return_value = iter([(0, image_bytes)])
        result = pipeline.process_pdf('/fake/invoice.pdf')
        # Token path taken, OCR path untouched.
        pipeline.extractor.extract_from_detection_with_pdf.assert_called_once()
        pipeline.extractor.extract_from_detection.assert_not_called()
        assert result.fields.get('Amount') == '2254.50'
        assert result.success is True

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_scanned_pdf_uses_ocr(self, mock_render, mock_pdf_doc_cls):
        """When PDF is scanned, extract_from_detection (OCR) is used."""
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()
        mock_pdf_doc = MagicMock()
        # is_text_pdf=False marks the document as scanned.
        mock_pdf_doc.is_text_pdf.return_value = False
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)
        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection.return_value = (
            self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75)
        )
        mock_render.return_value = iter([(0, image_bytes)])
        result = pipeline.process_pdf('/fake/invoice.pdf')
        # OCR path taken, token path untouched.
        pipeline.extractor.extract_from_detection.assert_called_once()
        pipeline.extractor.extract_from_detection_with_pdf.assert_not_called()

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_pdf_detection_error_falls_back_to_ocr(self, mock_render, mock_pdf_doc_cls):
        """When PDF text detection throws, fall back to OCR."""
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()
        # Entering the PDFDocument context raises, simulating a corrupt file.
        mock_ctx = MagicMock()
        mock_ctx.__enter__ = MagicMock(side_effect=Exception("corrupt PDF"))
        mock_ctx.__exit__ = MagicMock(return_value=False)
        mock_pdf_doc_cls.return_value = mock_ctx
        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection.return_value = (
            self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75)
        )
        mock_render.return_value = iter([(0, image_bytes)])
        result = pipeline.process_pdf('/fake/invoice.pdf')
        # The failure must degrade to OCR, not propagate.
        pipeline.extractor.extract_from_detection.assert_called_once()
        pipeline.extractor.extract_from_detection_with_pdf.assert_not_called()

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_text_pdf_passes_correct_args(self, mock_render, mock_pdf_doc_cls):
        """Verify correct token list and image dimensions are passed."""
        from shared.pdf.extractor import Token
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()  # 100x100 PNG
        mock_pdf_doc = MagicMock()
        mock_pdf_doc.is_text_pdf.return_value = True
        mock_pdf_doc.page_count = 1
        tokens = [
            Token(text="Fakturabelopp:", bbox=(50, 190, 100, 210), page_no=0),
            Token(text="2.254,50", bbox=(105, 190, 180, 210), page_no=0),
            Token(text="SEK", bbox=(185, 190, 210, 210), page_no=0),
        ]
        mock_pdf_doc.extract_text_tokens.return_value = iter(tokens)
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)
        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection_with_pdf.return_value = (
            self._make_extracted_field()
        )
        mock_render.return_value = iter([(0, image_bytes)])
        pipeline.process_pdf('/fake/invoice.pdf')
        # Positional args: (detection, tokens, image_width, image_height).
        call_args = pipeline.extractor.extract_from_detection_with_pdf.call_args[0]
        assert call_args[0] == detection
        assert len(call_args[1]) == 3  # 3 tokens passed
        assert call_args[2] == 100  # image width
        assert call_args[3] == 100  # image height
if __name__ == '__main__':
    # Allow running this test module directly (python <file>) with verbose output.
    pytest.main([__file__, '-v'])

View File

View File

@@ -0,0 +1,318 @@
"""
Tests for ValueSelector -- field-aware OCR token selection.
Verifies that ValueSelector picks the most likely value token(s)
from OCR output, filtering out label text before sending to normalizer.
"""
import pytest
from shared.ocr.paddle_ocr import OCRToken
from backend.pipeline.value_selector import ValueSelector
def _token(text: str) -> OCRToken:
    """Build a single OCRToken carrying *text*, with a fixed bbox/confidence."""
    dummy_bbox = (0, 0, 100, 20)
    return OCRToken(text=text, bbox=dummy_bbox, confidence=0.95)
def _tokens(*texts: str) -> list[OCRToken]:
    """Build one OCRToken per text via _token()."""
    return list(map(_token, texts))
class TestValueSelectorDateFields:
    """Date-field value selection (InvoiceDate / InvoiceDueDate)."""

    def test_selects_iso_date_from_label_and_value(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Fakturadatum", "2024-01-15"), "InvoiceDate"
        )
        assert [t.text for t in picked] == ["2024-01-15"]

    def test_selects_dot_separated_date(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Datum", "2024.03.20"), "InvoiceDate"
        )
        assert [t.text for t in picked] == ["2024.03.20"]

    def test_selects_slash_separated_date(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Forfallodag", "15/01/2024"), "InvoiceDueDate"
        )
        assert [t.text for t in picked] == ["15/01/2024"]

    def test_selects_compact_date(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Datum", "20240115"), "InvoiceDate"
        )
        assert [t.text for t in picked] == ["20240115"]

    def test_fallback_when_no_date_pattern(self):
        """No date pattern found -> every token comes back unchanged."""
        picked = ValueSelector.select_value_tokens(
            _tokens("Fakturadatum", "pending"), "InvoiceDate"
        )
        assert len(picked) == 2
class TestValueSelectorAmountField:
    """Amount-field value selection across Swedish/European/Anglo formats."""

    def test_selects_amount_with_comma_decimal(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Belopp", "1 234,56", "kr"), "Amount"
        )
        assert [t.text for t in picked] == ["1 234,56"]

    def test_selects_amount_with_dot_decimal(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Summa", "1234.56"), "Amount"
        )
        assert [t.text for t in picked] == ["1234.56"]

    def test_selects_simple_amount(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Att", "betala", "500,00"), "Amount"
        )
        assert [t.text for t in picked] == ["500,00"]

    def test_selects_european_amount_with_dot_thousand(self):
        """European format: dot thousands, comma decimal, currency suffix."""
        picked = ValueSelector.select_value_tokens(
            _tokens("Fakturabelopp:", "2.254,50 SEK"), "Amount"
        )
        assert [t.text for t in picked] == ["2.254,50 SEK"]

    def test_selects_european_amount_without_currency(self):
        """European format with no currency suffix."""
        picked = ValueSelector.select_value_tokens(
            _tokens("Belopp", "1.234,56"), "Amount"
        )
        assert [t.text for t in picked] == ["1.234,56"]

    def test_selects_amount_with_kr_suffix(self):
        """Amount carrying a 'kr' currency suffix."""
        picked = ValueSelector.select_value_tokens(
            _tokens("Summa", "20.485,00 kr"), "Amount"
        )
        assert [t.text for t in picked] == ["20.485,00 kr"]

    def test_selects_anglo_amount_with_sek(self):
        """Anglo format (comma thousands) with SEK suffix."""
        picked = ValueSelector.select_value_tokens(
            _tokens("Amount", "1,234.56 SEK"), "Amount"
        )
        assert [t.text for t in picked] == ["1,234.56 SEK"]

    def test_fallback_when_no_amount_pattern(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Belopp", "TBD"), "Amount"
        )
        assert len(picked) == 2
class TestValueSelectorBankgiroField:
    """Bankgiro-field value selection."""

    def test_selects_hyphenated_bankgiro(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("BG:", "123-4567"), "Bankgiro"
        )
        assert [t.text for t in picked] == ["123-4567"]

    def test_selects_bankgiro_digits(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Bankgiro", "1234567"), "Bankgiro"
        )
        assert [t.text for t in picked] == ["1234567"]

    def test_selects_eight_digit_bankgiro(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Bankgiro:", "12345678"), "Bankgiro"
        )
        assert [t.text for t in picked] == ["12345678"]
class TestValueSelectorPlusgiroField:
    """Plusgiro-field value selection."""

    def test_selects_hyphenated_plusgiro(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("PG:", "12345-6"), "Plusgiro"
        )
        assert [t.text for t in picked] == ["12345-6"]

    def test_selects_plusgiro_digits(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Plusgiro", "1234567"), "Plusgiro"
        )
        assert [t.text for t in picked] == ["1234567"]
class TestValueSelectorOcrField:
    """OCR reference-number field value selection."""

    def test_selects_longest_digit_sequence(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("OCR", "1234567890"), "OCR"
        )
        assert [t.text for t in picked] == ["1234567890"]

    def test_selects_token_with_most_digits(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Ref", "nr", "94228110015950070"), "OCR"
        )
        assert [t.text for t in picked] == ["94228110015950070"]

    def test_ignores_short_digit_tokens(self):
        """Tokens with fewer than 5 digits are not OCR references."""
        picked = ValueSelector.select_value_tokens(_tokens("OCR", "123"), "OCR")
        # No valid OCR reference found -> fall back to returning everything.
        assert len(picked) == 2
class TestValueSelectorInvoiceNumberField:
    """Tests for InvoiceNumber field value selection."""

    def test_removes_swedish_label_keywords(self):
        """The Swedish label token is dropped, leaving only the value."""
        tokens = _tokens("Fakturanummer", "INV-2024-001")
        result = ValueSelector.select_value_tokens(tokens, "InvoiceNumber")
        assert len(result) == 1
        assert result[0].text == "INV-2024-001"

    def test_keeps_non_label_tokens(self):
        """A short label abbreviation is stripped; the value remains."""
        tokens = _tokens("Nr", "12345")
        result = ValueSelector.select_value_tokens(tokens, "InvoiceNumber")
        assert len(result) == 1
        assert result[0].text == "12345"

    def test_multiple_value_tokens_kept(self):
        """Multiple non-label tokens are all kept."""
        tokens = _tokens("Fakturanr", "INV", "2024", "001")
        result = ValueSelector.select_value_tokens(tokens, "InvoiceNumber")
        # "Fakturanr" is a label keyword, the rest are values
        result_texts = [t.text for t in result]
        assert "Fakturanr" not in result_texts
        # Previously only "INV" was checked; assert every value token
        # survives, matching the docstring's "all kept" claim.
        for value in ("INV", "2024", "001"):
            assert value in result_texts
class TestValueSelectorOrgNumberField:
    """supplier_org_number field value selection."""

    def test_selects_org_number_with_hyphen(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Org.nr", "556123-4567"), "supplier_org_number"
        )
        assert [t.text for t in picked] == ["556123-4567"]

    def test_selects_org_number_without_hyphen(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Organisationsnummer", "5561234567"), "supplier_org_number"
        )
        assert [t.text for t in picked] == ["5561234567"]
class TestValueSelectorCustomerNumberField:
    """customer_number field value selection."""

    def test_removes_label_keeps_value(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("Kundnummer", "ABC-123"), "customer_number"
        )
        assert [t.text for t in picked] == ["ABC-123"]
class TestValueSelectorPaymentLineField:
    """payment_line field: nothing is filtered, every token survives."""

    def test_keeps_all_tokens(self):
        raw = ("#", "94228110015950070", "#", "15658", "00", "8", ">", "48666036#14#")
        toks = _tokens(*raw)
        kept = ValueSelector.select_value_tokens(toks, "payment_line")
        assert len(kept) == len(toks)
class TestValueSelectorFallback:
    """Fallback behavior for unknown fields and degenerate inputs."""

    def test_unknown_field_returns_all_tokens(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("some", "unknown", "text"), "unknown_field"
        )
        assert len(picked) == 3

    def test_empty_tokens_returns_empty(self):
        assert ValueSelector.select_value_tokens([], "InvoiceDate") == []

    def test_single_token_returns_it(self):
        picked = ValueSelector.select_value_tokens(
            _tokens("2024-01-15"), "InvoiceDate"
        )
        assert len(picked) == 1

    def test_never_returns_empty_when_tokens_exist(self):
        """Fallback guarantees no data loss: something always comes back."""
        picked = ValueSelector.select_value_tokens(
            _tokens("Fakturadatum", "unknown_format"), "InvoiceDate"
        )
        assert len(picked) > 0

View File

@@ -0,0 +1 @@
"""Tests for backend services."""

View File

@@ -0,0 +1,344 @@
"""
Tests for Data Mixing Service.
Tests cover:
1. get_mixing_ratio boundary values
2. build_mixed_dataset with temp filesystem
3. _find_pool_images matching logic
4. _image_to_label_path conversion
5. Edge cases (empty pool, no old data, cap)
"""
import pytest
from pathlib import Path
from uuid import uuid4
from backend.web.services.data_mixer import (
get_mixing_ratio,
build_mixed_dataset,
_collect_images,
_image_to_label_path,
_find_pool_images,
MIXING_RATIOS,
DEFAULT_MULTIPLIER,
MAX_OLD_SAMPLES,
MIN_POOL_SIZE,
)
# =============================================================================
# Test Constants
# =============================================================================
class TestConstants:
    """Data-mixer module constants keep their documented values."""

    def test_mixing_ratios_defined(self):
        """MIXING_RATIOS carries the four documented (threshold, multiplier) pairs."""
        expected = {0: (10, 50), 1: (50, 20), 2: (200, 10), 3: (500, 5)}
        assert len(MIXING_RATIOS) == 4
        for idx, pair in expected.items():
            assert MIXING_RATIOS[idx] == pair

    def test_default_multiplier(self):
        """DEFAULT_MULTIPLIER is 5."""
        assert DEFAULT_MULTIPLIER == 5

    def test_max_old_samples(self):
        """MAX_OLD_SAMPLES is 3000."""
        assert MAX_OLD_SAMPLES == 3000

    def test_min_pool_size(self):
        """MIN_POOL_SIZE is 50."""
        assert MIN_POOL_SIZE == 50
# =============================================================================
# Test get_mixing_ratio
# =============================================================================
class TestGetMixingRatio:
    """get_mixing_ratio: multiplier tiers and their boundaries."""

    def test_1_sample_returns_50x(self):
        """A single new sample lands in the 50x tier."""
        assert get_mixing_ratio(1) == 50

    def test_10_samples_returns_50x(self):
        """Upper boundary of the 50x tier."""
        assert get_mixing_ratio(10) == 50

    def test_11_samples_returns_20x(self):
        """First value inside the 20x tier."""
        assert get_mixing_ratio(11) == 20

    def test_50_samples_returns_20x(self):
        """Upper boundary of the 20x tier."""
        assert get_mixing_ratio(50) == 20

    def test_51_samples_returns_10x(self):
        """First value inside the 10x tier."""
        assert get_mixing_ratio(51) == 10

    def test_200_samples_returns_10x(self):
        """Upper boundary of the 10x tier."""
        assert get_mixing_ratio(200) == 10

    def test_201_samples_returns_5x(self):
        """First value inside the 5x tier."""
        assert get_mixing_ratio(201) == 5

    def test_500_samples_returns_5x(self):
        """Upper boundary of the 5x tier."""
        assert get_mixing_ratio(500) == 5

    def test_1000_samples_returns_default(self):
        """Beyond every tier, the default multiplier applies."""
        assert get_mixing_ratio(1000) == DEFAULT_MULTIPLIER
# =============================================================================
# Test _collect_images
# =============================================================================
class TestCollectImages:
    """_collect_images: which files count as dataset images."""

    def test_collects_png_files(self, tmp_path: Path):
        """Both .png files in the directory are picked up."""
        for name in ("img1.png", "img2.png"):
            (tmp_path / name).write_bytes(b"fake png")
        found = _collect_images(tmp_path)
        assert len(found) == 2

    def test_collects_jpg_files(self, tmp_path: Path):
        """.jpg files are picked up too."""
        (tmp_path / "img1.jpg").write_bytes(b"fake jpg")
        found = _collect_images(tmp_path)
        assert len(found) == 1

    def test_collects_both_types(self, tmp_path: Path):
        """.png and .jpg are collected together."""
        (tmp_path / "img1.png").write_bytes(b"fake png")
        (tmp_path / "img2.jpg").write_bytes(b"fake jpg")
        found = _collect_images(tmp_path)
        assert len(found) == 2

    def test_ignores_other_files(self, tmp_path: Path):
        """Non-image files in the directory are skipped."""
        (tmp_path / "data.txt").write_text("not an image")
        (tmp_path / "data.yaml").write_text("yaml")
        (tmp_path / "img.png").write_bytes(b"png")
        found = _collect_images(tmp_path)
        assert len(found) == 1

    def test_returns_empty_for_nonexistent_dir(self, tmp_path: Path):
        """A missing directory yields an empty list, not an error."""
        found = _collect_images(tmp_path / "nonexistent")
        assert found == []
# =============================================================================
# Test _image_to_label_path
# =============================================================================
class TestImageToLabelPath:
    """_image_to_label_path: images/<split>/x.ext -> labels/<split>/x.txt."""

    def test_converts_train_image_to_label(self, tmp_path: Path):
        """A train-split .png maps to a .txt under labels/train."""
        src = tmp_path / "dataset" / "images" / "train" / "doc1_page1.png"
        dst = _image_to_label_path(src)
        assert dst.name == "doc1_page1.txt"
        dst_str = str(dst)
        assert "labels" in dst_str
        assert "train" in dst_str

    def test_converts_val_image_to_label(self, tmp_path: Path):
        """A val-split .jpg maps to a .txt under labels/val."""
        src = tmp_path / "dataset" / "images" / "val" / "doc2_page3.jpg"
        dst = _image_to_label_path(src)
        assert dst.name == "doc2_page3.txt"
        dst_str = str(dst)
        assert "labels" in dst_str
        assert "val" in dst_str
# =============================================================================
# Test _find_pool_images
# =============================================================================
class TestFindPoolImages:
    """Tests for _find_pool_images function."""

    def _create_dataset(self, base_path: Path, doc_ids: list[str], split: str = "train") -> None:
        """Helper to create a dataset structure with two page images per doc."""
        images_dir = base_path / "images" / split
        images_dir.mkdir(parents=True, exist_ok=True)
        for doc_id in doc_ids:
            (images_dir / f"{doc_id}_page1.png").write_bytes(b"img")
            (images_dir / f"{doc_id}_page2.png").write_bytes(b"img")

    def test_finds_matching_images(self, tmp_path: Path):
        """Should find images matching pool document IDs."""
        doc_id1 = str(uuid4())
        doc_id2 = str(uuid4())
        self._create_dataset(tmp_path, [doc_id1, doc_id2])
        pool_ids = {doc_id1}
        images = _find_pool_images(tmp_path, pool_ids)
        assert len(images) == 2  # 2 pages for doc_id1
        assert all(doc_id1 in str(img) for img in images)

    def test_ignores_non_pool_images(self, tmp_path: Path):
        """Should not return images for documents not in pool."""
        doc_id1 = str(uuid4())
        doc_id2 = str(uuid4())
        self._create_dataset(tmp_path, [doc_id1, doc_id2])
        pool_ids = {doc_id1}
        images = _find_pool_images(tmp_path, pool_ids)
        # The loop below would pass vacuously on an empty result; pin the
        # expected count first (doc_id1 has exactly two pages in the dataset).
        assert len(images) == 2
        # Only doc_id1 images should be found
        for img in images:
            assert doc_id1 in str(img)
            assert doc_id2 not in str(img)

    def test_searches_all_splits(self, tmp_path: Path):
        """Should search train, val, and test splits."""
        doc_id = str(uuid4())
        for split in ("train", "val", "test"):
            self._create_dataset(tmp_path, [doc_id], split=split)
        images = _find_pool_images(tmp_path, {doc_id})
        assert len(images) == 6  # 2 pages * 3 splits

    def test_empty_pool_returns_empty(self, tmp_path: Path):
        """Should return empty list for empty pool IDs."""
        self._create_dataset(tmp_path, [str(uuid4())])
        images = _find_pool_images(tmp_path, set())
        assert images == []
# =============================================================================
# Test build_mixed_dataset
# =============================================================================
class TestBuildMixedDataset:
    """Tests for build_mixed_dataset function."""

    def _setup_base_dataset(self, base_path: Path, num_old: int = 20) -> None:
        """Create a base dataset with old training images (80/20 train/val)."""
        for split in ("train", "val"):
            img_dir = base_path / "images" / split
            lbl_dir = base_path / "labels" / split
            img_dir.mkdir(parents=True, exist_ok=True)
            lbl_dir.mkdir(parents=True, exist_ok=True)
            # 80% of images go to train; val gets the remainder.
            count = int(num_old * 0.8) if split == "train" else num_old - int(num_old * 0.8)
            for _ in range(count):
                doc_id = str(uuid4())
                img_file = img_dir / f"{doc_id}_page1.png"
                lbl_file = lbl_dir / f"{doc_id}_page1.txt"
                img_file.write_bytes(b"fake image data")
                lbl_file.write_text("0 0.5 0.5 0.1 0.1\n")

    def _setup_pool_images(self, base_path: Path, doc_ids: list[str]) -> None:
        """Add pool (new) images with labels to the base dataset's train split."""
        img_dir = base_path / "images" / "train"
        lbl_dir = base_path / "labels" / "train"
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)
        for doc_id in doc_ids:
            img_file = img_dir / f"{doc_id}_page1.png"
            lbl_file = lbl_dir / f"{doc_id}_page1.txt"
            img_file.write_bytes(b"pool image data")
            lbl_file.write_text("0 0.5 0.5 0.2 0.2\n")

    @pytest.fixture
    def base_dataset(self, tmp_path: Path) -> Path:
        """Create a base dataset with 20 old images for testing."""
        base_path = tmp_path / "base_dataset"
        self._setup_base_dataset(base_path, num_old=20)
        return base_path

    def test_builds_output_structure(self, base_dataset: Path, tmp_path: Path):
        """Should create proper YOLO directory structure."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"
        # Return value not needed here; only the on-disk layout is checked.
        build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )
        assert (output_dir / "images" / "train").exists()
        assert (output_dir / "images" / "val").exists()
        assert (output_dir / "labels" / "train").exists()
        assert (output_dir / "labels" / "val").exists()
        assert (output_dir / "data.yaml").exists()

    def test_returns_correct_metadata(self, base_dataset: Path, tmp_path: Path):
        """Should return correct counts and metadata."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"
        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )
        for key in ("data_yaml", "total_images", "old_images", "new_images", "mixing_ratio"):
            assert key in result
        assert result["total_images"] == result["old_images"] + result["new_images"]

    def test_mixing_ratio_applied(self, base_dataset: Path, tmp_path: Path):
        """Should use correct mixing ratio based on pool size."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"
        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )
        # 5 new samples -> 50x multiplier
        assert result["mixing_ratio"] == 50

    def test_seed_reproducibility(self, base_dataset: Path, tmp_path: Path):
        """Same seed should produce same output counts."""
        pool_ids = [uuid4() for _ in range(3)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        out1 = tmp_path / "out1"
        out2 = tmp_path / "out2"
        r1 = build_mixed_dataset(pool_ids, base_dataset, out1, seed=42)
        r2 = build_mixed_dataset(pool_ids, base_dataset, out2, seed=42)
        assert r1["old_images"] == r2["old_images"]
        assert r1["new_images"] == r2["new_images"]
        assert r1["total_images"] == r2["total_images"]

View File

@@ -0,0 +1,540 @@
"""
Unit tests for gating validation service.
Tests the quality gate validation logic for model deployment:
- Gate 1: mAP regression validation
- Gate 2: detection rate validation
- Overall status computation
- Full validation workflow with mocked dependencies
"""
import pytest
from unittest.mock import MagicMock, Mock, patch
from uuid import UUID, uuid4
from backend.web.services.gating_validator import (
GATE1_PASS_THRESHOLD,
GATE1_REVIEW_THRESHOLD,
GATE2_PASS_THRESHOLD,
classify_gate1,
classify_gate2,
compute_overall_status,
run_gating_validation,
)
from backend.data.admin_models import GatingResult
class TestClassifyGate1:
    """Gate 1 classification: mAP-drop thresholds."""

    def test_pass_below_threshold(self):
        """Drops below 0.01 (including improvements) pass."""
        for drop in (0.009, 0.005, 0.0, -0.01):  # -0.01 = improvement
            assert classify_gate1(drop) == "pass"

    def test_pass_boundary(self):
        """Exactly at the pass threshold: 'review', since the check is strict <."""
        assert classify_gate1(GATE1_PASS_THRESHOLD) == "review"

    def test_review_in_range(self):
        """Drops in [0.01, 0.03) are flagged for review."""
        for drop in (0.01, 0.015, 0.02, 0.029):
            assert classify_gate1(drop) == "review"

    def test_review_boundary(self):
        """Exactly at the review threshold: 'reject', since the check is strict <."""
        assert classify_gate1(GATE1_REVIEW_THRESHOLD) == "reject"

    def test_reject_above_threshold(self):
        """Drops of 0.03 or more are rejected."""
        for drop in (0.03, 0.05, 0.10, 1.0):
            assert classify_gate1(drop) == "reject"
class TestClassifyGate2:
    """Unit tests for Gate 2 classification based on detection-rate thresholds."""

    def test_pass_above_threshold(self):
        """Detection rates at or above 0.80 classify as pass."""
        for rate in (0.80, 0.85, 0.99, 1.0):
            assert classify_gate2(rate) == "pass"

    def test_pass_boundary(self):
        """A rate exactly at the pass threshold still passes (condition is >=)."""
        assert classify_gate2(GATE2_PASS_THRESHOLD) == "pass"

    def test_review_below_threshold(self):
        """Detection rates below 0.80 classify as review."""
        for rate in (0.79, 0.75, 0.50, 0.0):
            assert classify_gate2(rate) == "review"
class TestComputeOverallStatus:
    """Unit tests for merging the two gate statuses into one overall status.

    The expected precedence is: any reject -> reject; otherwise any
    review -> review; otherwise pass.
    """

    def test_both_pass(self):
        """pass + pass yields an overall pass."""
        assert compute_overall_status("pass", "pass") == "pass"

    def test_gate1_reject_gate2_pass(self):
        """A reject on gate 1 alone forces an overall reject."""
        assert compute_overall_status("reject", "pass") == "reject"

    def test_gate1_pass_gate2_reject(self):
        """A reject on gate 2 alone forces an overall reject."""
        assert compute_overall_status("pass", "reject") == "reject"

    def test_both_reject(self):
        """reject + reject yields an overall reject."""
        assert compute_overall_status("reject", "reject") == "reject"

    def test_gate1_review_gate2_pass(self):
        """A review on gate 1 (with no reject) yields an overall review."""
        assert compute_overall_status("review", "pass") == "review"

    def test_gate1_pass_gate2_review(self):
        """A review on gate 2 (with no reject) yields an overall review."""
        assert compute_overall_status("pass", "review") == "review"

    def test_both_review(self):
        """review + review yields an overall review."""
        assert compute_overall_status("review", "review") == "review"

    def test_gate1_reject_gate2_review(self):
        """reject outranks review regardless of which gate produced it."""
        assert compute_overall_status("reject", "review") == "reject"

    def test_gate1_review_gate2_reject(self):
        """reject outranks review regardless of which gate produced it."""
        assert compute_overall_status("review", "reject") == "reject"
class TestRunGatingValidation:
    """Test the full gating validation workflow with mocked dependencies.

    All external collaborators (model repository, DB session context,
    YOLO trainer, and the status-update helper) are patched so the tests
    exercise only the orchestration logic in ``run_gating_validation``.
    """

    @pytest.fixture
    def mock_model_version_id(self):
        """Generate a model version ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_base_model_version_id(self):
        """Generate a base model version ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_task_id(self):
        """Generate a task ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_base_model(self):
        """Create a mock base model with a default mAP of 0.85."""
        model = Mock()
        model.metrics_mAP = 0.85
        return model

    @pytest.fixture
    def mock_new_model(self):
        """Create a mock new model with a default mAP of 0.82."""
        model = Mock()
        model.metrics_mAP = 0.82
        return model

    def test_gate1_pass_gate2_pass(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 lands on review (drop exactly 0.01) while Gate 2 passes.

        NOTE(review): the method name suggests both gates pass, but the
        assertions below expect gate1_status == "review" and an overall
        "review" — consider renaming this test.
        """
        # Setup: base mAP=0.85, validation mAP=0.84 -> drop=0.01 (review boundary)
        # New model's stored mAP=0.82 -> gate2 pass (>= 0.80)
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        mock_val_metrics = {"mAP50": 0.84}
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("shared.training.YOLOTrainer") as MockTrainer, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            # Mock repository: return the base model for the base ID, the new model otherwise
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            # Mock session context manager so DB calls can be asserted
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            # Mock YOLO trainer to return the canned validation metrics
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            # Execute
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            # Verify gate classifications and the recorded metric values
            assert result.gate1_status == "review"  # 0.85 - 0.84 = 0.01
            assert result.gate1_original_mAP == 0.85
            assert result.gate1_new_mAP == 0.84
            assert result.gate1_mAP_drop == pytest.approx(0.01, abs=1e-6)
            assert result.gate2_status == "pass"  # 0.82 >= 0.80
            assert result.gate2_detection_rate == 0.82
            assert result.overall_status == "review"  # Any review -> review
            # Verify the result was persisted and the model status updated
            mock_session.add.assert_called()
            mock_session.commit.assert_called()
            mock_update.assert_called_once_with(str(mock_model_version_id), "review")

    def test_gate1_reject_due_to_large_drop(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 rejects when the mAP drop is >= 0.03, forcing overall reject."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        mock_val_metrics = {"mAP50": 0.81}  # 0.85 - 0.81 = 0.04 (reject)
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("shared.training.YOLOTrainer") as MockTrainer, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "reject"
            assert result.gate1_mAP_drop == pytest.approx(0.04, abs=1e-6)
            assert result.overall_status == "reject"  # Any reject -> reject
            mock_update.assert_called_once_with(str(mock_model_version_id), "reject")

    def test_gate2_review_due_to_low_detection_rate(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 2 reviews when the new model's detection rate is < 0.80."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.75  # Below 0.80 threshold
        mock_val_metrics = {"mAP50": 0.845}  # Gate 1: 0.85 - 0.845 = 0.005 (pass)
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("shared.training.YOLOTrainer") as MockTrainer, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "pass"
            assert result.gate2_status == "review"  # 0.75 < 0.80
            assert result.gate2_detection_rate == 0.75
            assert result.overall_status == "review"
            mock_update.assert_called_once_with(str(mock_model_version_id), "review")

    def test_no_base_model_skips_gate1(
        self,
        mock_model_version_id,
        mock_task_id,
        mock_new_model,
    ):
        """Gate 1 is skipped (treated as pass) when no base model is provided.

        Note: YOLOTrainer is intentionally not patched here — with no base
        model there is nothing to validate against, so the trainer is
        presumably never constructed. TODO confirm against the implementation.
        """
        mock_new_model.metrics_mAP = 0.85
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.return_value = mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=None,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "pass"  # Skipped
            # All gate-1 metric fields stay unset when the gate is skipped
            assert result.gate1_original_mAP is None
            assert result.gate1_new_mAP is None
            assert result.gate1_mAP_drop is None
            assert result.gate2_status == "pass"  # 0.85 >= 0.80
            assert result.overall_status == "pass"
            mock_update.assert_called_once_with(str(mock_model_version_id), "pass")

    def test_base_model_without_metrics_skips_gate1(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 is skipped (treated as pass) when the base model has no mAP metric."""
        mock_base_model.metrics_mAP = None
        mock_new_model.metrics_mAP = 0.85
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "pass"  # Skipped due to no base metrics
            assert result.gate2_status == "pass"
            assert result.overall_status == "pass"

    def test_validation_failure_marks_gate1_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 degrades to review (not reject) when validation raises."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("shared.training.YOLOTrainer") as MockTrainer, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            # Mock trainer to raise exception during validation
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.side_effect = RuntimeError("Validation failed")
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "review"  # Exception -> review
            assert result.gate2_status == "pass"
            assert result.overall_status == "review"
            mock_update.assert_called_once_with(str(mock_model_version_id), "review")

    def test_validation_returns_none_mAP_marks_gate1_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 1 degrades to review when validation yields no mAP value."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        mock_val_metrics = {"mAP50": None}  # No mAP returned
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("shared.training.YOLOTrainer") as MockTrainer, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "review"  # None mAP -> review
            assert result.gate1_new_mAP is None
            assert result.gate2_status == "pass"
            assert result.overall_status == "review"

    def test_gate2_exception_marks_gate2_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Gate 2 degrades to review when fetching the new model raises."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82
        mock_val_metrics = {"mAP50": 0.84}
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("shared.training.YOLOTrainer") as MockTrainer, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value

            # Raise only for the new model's ID so gate 1 still runs normally
            def get_side_effect(id):
                if str(id) == str(mock_base_model_version_id):
                    return mock_base_model
                elif str(id) == str(mock_model_version_id):
                    raise RuntimeError("Cannot fetch new model")
                return None

            mock_repo.get.side_effect = get_side_effect
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            result = run_gating_validation(
                model_version_id=mock_model_version_id,
                new_model_path="/path/to/model.pt",
                base_model_version_id=mock_base_model_version_id,
                data_yaml="/path/to/data.yaml",
                task_id=mock_task_id,
            )
            assert result.gate1_status == "review"  # 0.85 - 0.84 = 0.01
            assert result.gate2_status == "review"  # Exception -> review
            assert result.overall_status == "review"

    def test_string_uuids_accepted(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """String UUIDs are accepted and converted back to UUID objects on the result."""
        # Equal mAPs -> zero drop -> gate1 pass; 0.85 >= 0.80 -> gate2 pass
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.85
        mock_val_metrics = {"mAP50": 0.85}
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as MockRepo, \
             patch("backend.web.services.gating_validator.get_session_context") as mock_session_ctx, \
             patch("shared.training.YOLOTrainer") as MockTrainer, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as mock_update:
            mock_repo = MockRepo.return_value
            mock_repo.get.side_effect = lambda id: mock_base_model if str(id) == str(mock_base_model_version_id) else mock_new_model
            mock_session = MagicMock()
            mock_session_ctx.return_value.__enter__.return_value = mock_session
            mock_trainer = MockTrainer.return_value
            mock_trainer.validate.return_value = mock_val_metrics
            # Pass string UUIDs instead of UUID objects
            result = run_gating_validation(
                model_version_id=str(mock_model_version_id),
                new_model_path="/path/to/model.pt",
                base_model_version_id=str(mock_base_model_version_id),
                data_yaml="/path/to/data.yaml",
                task_id=str(mock_task_id),
            )
            # Result carries UUID objects, proving the strings were converted
            assert result.model_version_id == mock_model_version_id
            assert result.task_id == mock_task_id
            assert result.overall_status == "pass"

View File

@@ -1,556 +1,170 @@
"""
Tests for expand_bbox function.
Tests for expand_bbox function with uniform pixel padding.
Tests verify that bbox expansion works correctly with center-point scaling,
directional compensation, max padding clamping, and image boundary handling.
Verifies that bbox expansion adds uniform padding on all sides,
clamps to image boundaries, and returns integer tuples.
"""
import pytest
from shared.bbox import (
expand_bbox,
ScaleStrategy,
FIELD_SCALE_STRATEGIES,
DEFAULT_STRATEGY,
)
from shared.bbox import expand_bbox
from shared.bbox.scale_strategy import UNIFORM_PAD
class TestExpandBboxCenterScaling:
"""Tests for center-point based scaling."""
class TestExpandBboxUniformPadding:
"""Tests for uniform padding on all sides."""
def test_center_scaling_expands_symmetrically(self):
"""Verify bbox expands symmetrically around center when no extra ratios."""
# 100x50 bbox at (100, 200)
bbox = (100, 200, 200, 250)
strategy = ScaleStrategy(
scale_x=1.2, # 20% wider
scale_y=1.4, # 40% taller
max_pad_x=1000, # Large to avoid clamping
max_pad_y=1000,
)
def test_adds_uniform_pad_on_all_sides(self):
"""Verify default pad is applied equally on all four sides."""
bbox = (100, 200, 300, 250)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Original: width=100, height=50
# New: width=120, height=70
# Center: (150, 225)
# Expected: x0=150-60=90, x1=150+60=210, y0=225-35=190, y1=225+35=260
assert result[0] == 90 # x0
assert result[1] == 190 # y0
assert result[2] == 210 # x1
assert result[3] == 260 # y1
def test_no_scaling_returns_original(self):
"""Verify scale=1.0 with no extras returns original bbox."""
bbox = (100, 200, 200, 250)
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
max_pad_x=1000,
max_pad_y=1000,
assert result == (
100 - UNIFORM_PAD,
200 - UNIFORM_PAD,
300 + UNIFORM_PAD,
250 + UNIFORM_PAD,
)
def test_custom_pad_value(self):
"""Verify custom pad overrides default."""
bbox = (100, 200, 300, 250)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
pad=20,
)
assert result == (100, 200, 200, 250)
assert result == (80, 180, 320, 270)
class TestExpandBboxDirectionalCompensation:
"""Tests for directional compensation (extra ratios)."""
def test_extra_top_expands_upward(self):
"""Verify extra_top_ratio adds expansion toward top."""
bbox = (100, 200, 200, 250) # width=100, height=50
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_top_ratio=0.5, # Add 50% of height to top
max_pad_x=1000,
max_pad_y=1000,
)
def test_zero_pad_returns_original(self):
"""Verify pad=0 returns original bbox as integers."""
bbox = (100, 200, 300, 250)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
pad=0,
)
# extra_top = 50 * 0.5 = 25
assert result[0] == 100 # x0 unchanged
assert result[1] == 175 # y0 = 200 - 25
assert result[2] == 200 # x1 unchanged
assert result[3] == 250 # y1 unchanged
assert result == (100, 200, 300, 250)
def test_extra_left_expands_leftward(self):
"""Verify extra_left_ratio adds expansion toward left."""
bbox = (100, 200, 200, 250) # width=100
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_left_ratio=0.8, # Add 80% of width to left
max_pad_x=1000,
max_pad_y=1000,
)
def test_all_field_types_get_same_padding(self):
"""Verify no field-specific expansion -- same result regardless of field."""
bbox = (100, 200, 300, 250)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
result_a = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
result_b = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
# extra_left = 100 * 0.8 = 80
assert result[0] == 20 # x0 = 100 - 80
assert result[1] == 200 # y0 unchanged
assert result[2] == 200 # x1 unchanged
assert result[3] == 250 # y1 unchanged
def test_extra_right_expands_rightward(self):
"""Verify extra_right_ratio adds expansion toward right."""
bbox = (100, 200, 200, 250) # width=100
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_right_ratio=0.3, # Add 30% of width to right
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# extra_right = 100 * 0.3 = 30
assert result[0] == 100 # x0 unchanged
assert result[1] == 200 # y0 unchanged
assert result[2] == 230 # x1 = 200 + 30
assert result[3] == 250 # y1 unchanged
def test_extra_bottom_expands_downward(self):
"""Verify extra_bottom_ratio adds expansion toward bottom."""
bbox = (100, 200, 200, 250) # height=50
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_bottom_ratio=0.4, # Add 40% of height to bottom
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# extra_bottom = 50 * 0.4 = 20
assert result[0] == 100 # x0 unchanged
assert result[1] == 200 # y0 unchanged
assert result[2] == 200 # x1 unchanged
assert result[3] == 270 # y1 = 250 + 20
def test_combined_scaling_and_directional(self):
"""Verify scale + directional compensation work together."""
bbox = (100, 200, 200, 250) # width=100, height=50
strategy = ScaleStrategy(
scale_x=1.2, # 20% wider -> 120 width
scale_y=1.0, # no height change
extra_left_ratio=0.5, # Add 50% of width to left
max_pad_x=1000,
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Center: x=150
# After scale: width=120 -> x0=150-60=90, x1=150+60=210
# After extra_left: x0 = 90 - (100 * 0.5) = 40
assert result[0] == 40 # x0
assert result[2] == 210 # x1
class TestExpandBboxMaxPadClamping:
"""Tests for max padding clamping."""
def test_max_pad_x_limits_horizontal_expansion(self):
"""Verify max_pad_x limits expansion on left and right."""
bbox = (100, 200, 200, 250) # width=100
strategy = ScaleStrategy(
scale_x=2.0, # Double width (would add 50 each side)
scale_y=1.0,
max_pad_x=30, # Limit to 30 pixels each side
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Scale would make: x0=100, x1=200 -> x0=50, x1=250 (50px each side)
# But max_pad_x=30 limits to: x0=70, x1=230
assert result[0] == 70 # x0 = 100 - 30
assert result[2] == 230 # x1 = 200 + 30
def test_max_pad_y_limits_vertical_expansion(self):
"""Verify max_pad_y limits expansion on top and bottom."""
bbox = (100, 200, 200, 250) # height=50
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=3.0, # Triple height (would add 50 each side)
max_pad_x=1000,
max_pad_y=20, # Limit to 20 pixels each side
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Scale would make: y0=175, y1=275 (50px each side)
# But max_pad_y=20 limits to: y0=180, y1=270
assert result[1] == 180 # y0 = 200 - 20
assert result[3] == 270 # y1 = 250 + 20
def test_max_pad_preserves_asymmetry(self):
"""Verify max_pad clamping preserves asymmetric expansion."""
bbox = (100, 200, 200, 250) # width=100
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_left_ratio=1.0, # 100px left expansion
extra_right_ratio=0.0, # No right expansion
max_pad_x=50, # Limit to 50 pixels
max_pad_y=1000,
)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
# Left would expand 100, clamped to 50
# Right stays at 0
assert result[0] == 50 # x0 = 100 - 50
assert result[2] == 200 # x1 unchanged
assert result_a == result_b
class TestExpandBboxImageBoundaryClamping:
"""Tests for image boundary clamping."""
"""Tests for clamping to image boundaries."""
def test_clamps_to_left_boundary(self):
"""Verify x0 is clamped to 0."""
bbox = (10, 200, 110, 250) # Close to left edge
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_left_ratio=0.5, # Would push x0 below 0
max_pad_x=1000,
max_pad_y=1000,
)
def test_clamps_x0_to_zero(self):
bbox = (5, 200, 100, 250)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
assert result[0] == 0 # Clamped to 0
assert result[0] == 0
def test_clamps_to_top_boundary(self):
"""Verify y0 is clamped to 0."""
bbox = (100, 10, 200, 60) # Close to top edge
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_top_ratio=0.5, # Would push y0 below 0
max_pad_x=1000,
max_pad_y=1000,
)
def test_clamps_y0_to_zero(self):
bbox = (100, 3, 300, 50)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
assert result[1] == 0 # Clamped to 0
assert result[1] == 0
def test_clamps_to_right_boundary(self):
"""Verify x1 is clamped to image_width."""
bbox = (900, 200, 990, 250) # Close to right edge
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_right_ratio=0.5, # Would push x1 beyond image_width
max_pad_x=1000,
max_pad_y=1000,
)
def test_clamps_x1_to_image_width(self):
bbox = (900, 200, 995, 250)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
assert result[2] == 1000 # Clamped to image_width
assert result[2] == 1000
def test_clamps_to_bottom_boundary(self):
"""Verify y1 is clamped to image_height."""
bbox = (100, 940, 200, 990) # Close to bottom edge
strategy = ScaleStrategy(
scale_x=1.0,
scale_y=1.0,
extra_bottom_ratio=0.5, # Would push y1 beyond image_height
max_pad_x=1000,
max_pad_y=1000,
)
def test_clamps_y1_to_image_height(self):
bbox = (100, 900, 300, 995)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test_field",
strategies={"test_field": strategy},
)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
assert result[3] == 1000 # Clamped to image_height
assert result[3] == 1000
def test_corner_bbox_clamps_multiple_sides(self):
"""Bbox near top-left corner clamps both x0 and y0."""
bbox = (2, 3, 50, 60)
class TestExpandBboxUnknownField:
"""Tests for unknown field handling."""
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
def test_unknown_field_uses_default_strategy(self):
"""Verify unknown field types use DEFAULT_STRATEGY."""
bbox = (100, 200, 200, 250)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="unknown_field_xyz",
)
# DEFAULT_STRATEGY: scale_x=1.15, scale_y=1.15
# Original: width=100, height=50
# New: width=115, height=57.5
# Center: (150, 225)
# x0 = 150 - 57.5 = 92.5 -> 92
# x1 = 150 + 57.5 = 207.5 -> 207
# y0 = 225 - 28.75 = 196.25 -> 196
# y1 = 225 + 28.75 = 253.75 -> 253
# But max_pad_x=50 may clamp...
# Left pad = 100 - 92.5 = 7.5 (< 50, ok)
# Right pad = 207.5 - 200 = 7.5 (< 50, ok)
assert result[0] == 92
assert result[2] == 207
class TestExpandBboxWithRealStrategies:
"""Tests using actual FIELD_SCALE_STRATEGIES."""
def test_ocr_number_expands_significantly_upward(self):
"""Verify ocr_number field gets significant upward expansion."""
bbox = (100, 200, 200, 230) # Small height=30
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="ocr_number",
)
# extra_top_ratio=0.60 -> 30 * 0.6 = 18 extra top
# y0 should decrease significantly
assert result[1] < 200 - 10 # At least 10px upward expansion
def test_bankgiro_expands_significantly_leftward(self):
"""Verify bankgiro field gets significant leftward expansion."""
bbox = (200, 200, 300, 230) # width=100
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro",
)
# extra_left_ratio=0.80 -> 100 * 0.8 = 80 extra left
# x0 should decrease significantly
assert result[0] < 200 - 30 # At least 30px leftward expansion
def test_amount_expands_rightward(self):
"""Verify amount field gets rightward expansion for currency."""
bbox = (100, 200, 200, 230) # width=100
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="amount",
)
# extra_right_ratio=0.30 -> 100 * 0.3 = 30 extra right
# x1 should increase
assert result[2] > 200 + 10 # At least 10px rightward expansion
assert result[0] == 0
assert result[1] == 0
assert result[2] == 50 + UNIFORM_PAD
assert result[3] == 60 + UNIFORM_PAD
class TestExpandBboxReturnType:
"""Tests for return type and value format."""
def test_returns_tuple_of_four_ints(self):
"""Verify return type is tuple of 4 integers."""
bbox = (100.5, 200.3, 200.7, 250.9)
bbox = (100.5, 200.3, 300.7, 250.9)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="invoice_number",
)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
assert isinstance(result, tuple)
assert len(result) == 4
assert all(isinstance(v, int) for v in result)
def test_returns_valid_bbox_format(self):
"""Verify returned bbox has x0 < x1 and y0 < y1."""
bbox = (100, 200, 200, 250)
def test_float_bbox_floors_correctly(self):
"""Verify float coordinates are converted to int properly."""
bbox = (100.7, 200.3, 300.2, 250.8)
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="invoice_number",
)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000, pad=0)
# int() truncates toward zero
assert result == (100, 200, 300, 250)
def test_returns_valid_bbox_ordering(self):
"""Verify x0 < x1 and y0 < y1."""
bbox = (100, 200, 300, 250)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
x0, y0, x1, y1 = result
assert x0 < x1, "x0 should be less than x1"
assert y0 < y1, "y0 should be less than y1"
assert x0 < x1
assert y0 < y1
class TestManualLabelMode:
"""Tests for manual_mode parameter."""
class TestExpandBboxEdgeCases:
"""Tests for edge cases."""
def test_manual_mode_uses_minimal_padding(self):
"""Verify manual_mode uses MANUAL_LABEL_STRATEGY with minimal padding."""
bbox = (100, 200, 200, 250) # width=100, height=50
def test_small_bbox_with_large_pad(self):
"""Pad larger than bbox still works correctly."""
bbox = (100, 200, 105, 203) # 5x3 pixel bbox
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro", # Would normally expand left significantly
manual_mode=True,
)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000, pad=50)
# MANUAL_LABEL_STRATEGY: scale=1.0, max_pad=10
# Should only add 10px padding each side (but scale=1.0 means no scaling)
# Actually with scale=1.0, no extra ratios, we get 0 expansion from scaling
# Only max_pad=10 applies as a limit, but there's no expansion to limit
# So result should be same as original
assert result == (100, 200, 200, 250)
assert result == (50, 150, 155, 253)
def test_manual_mode_ignores_field_type(self):
"""Verify manual_mode ignores field-specific strategies."""
bbox = (100, 200, 200, 250)
def test_bbox_at_origin(self):
bbox = (0, 0, 50, 30)
# Different fields should give same result in manual_mode
result_bankgiro = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro",
manual_mode=True,
)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
result_ocr = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="ocr_number",
manual_mode=True,
)
assert result[0] == 0
assert result[1] == 0
assert result_bankgiro == result_ocr
def test_bbox_at_image_edge(self):
bbox = (950, 970, 1000, 1000)
def test_manual_mode_vs_auto_mode_different(self):
"""Verify manual_mode produces different results than auto mode."""
bbox = (100, 200, 200, 250)
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
auto_result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro", # Has extra_left_ratio=0.80
manual_mode=False,
)
manual_result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="bankgiro",
manual_mode=True,
)
# Auto mode should expand more (especially to the left for bankgiro)
assert auto_result[0] < manual_result[0] # Auto x0 is more left
def test_manual_mode_clamps_to_image_bounds(self):
"""Verify manual_mode still respects image boundaries."""
bbox = (5, 5, 50, 50) # Close to top-left corner
result = expand_bbox(
bbox=bbox,
image_width=1000,
image_height=1000,
field_type="test",
manual_mode=True,
)
# Should clamp to 0
assert result[0] >= 0
assert result[1] >= 0
assert result[2] == 1000
assert result[3] == 1000

View File

@@ -1,192 +1,24 @@
"""
Tests for ScaleStrategy configuration.
Tests for simplified scale strategy configuration.
Tests verify that scale strategies are properly defined, immutable,
and cover all required fields.
Verifies that UNIFORM_PAD constant is properly defined
and replaces the old field-specific strategies.
"""
import pytest
from shared.bbox import (
ScaleStrategy,
DEFAULT_STRATEGY,
MANUAL_LABEL_STRATEGY,
FIELD_SCALE_STRATEGIES,
)
from shared.fields import CLASS_NAMES
from shared.bbox.scale_strategy import UNIFORM_PAD
class TestScaleStrategyDataclass:
"""Tests for ScaleStrategy dataclass behavior."""
class TestUniformPad:
"""Tests for UNIFORM_PAD constant."""
def test_default_strategy_values(self):
    """A freshly constructed ScaleStrategy carries the documented defaults."""
    s = ScaleStrategy()
    # Uniform 1.15x scaling on both axes by default.
    assert s.scale_x == 1.15
    assert s.scale_y == 1.15
    # No directional expansion unless explicitly configured.
    for ratio in (
        s.extra_top_ratio,
        s.extra_bottom_ratio,
        s.extra_left_ratio,
        s.extra_right_ratio,
    ):
        assert ratio == 0.0
    # Padding is capped at 50px on each axis.
    assert s.max_pad_x == 50
    assert s.max_pad_y == 50
def test_uniform_pad_is_integer(self):
    """UNIFORM_PAD is consumed as a pixel count, so it must be an int."""
    assert isinstance(UNIFORM_PAD, int)
def test_scale_strategy_immutability(self):
    """ScaleStrategy is a frozen dataclass; attribute writes must raise."""
    frozen = ScaleStrategy()
    with pytest.raises(AttributeError):
        frozen.scale_x = 2.0  # type: ignore
def test_uniform_pad_value_is_15(self):
    """15px at 150 DPI provides ~2.5mm real-world padding."""
    expected_pixels = 15
    assert UNIFORM_PAD == expected_pixels
def test_custom_strategy_values(self):
    """Constructor keyword arguments override the matching defaults."""
    overrides = {
        "scale_x": 1.5,
        "scale_y": 1.8,
        "extra_top_ratio": 0.6,
        "extra_left_ratio": 0.8,
        "max_pad_x": 100,
        "max_pad_y": 150,
    }
    strategy = ScaleStrategy(**overrides)
    # Every supplied value must be reflected on the instance.
    for attr, expected in overrides.items():
        assert getattr(strategy, attr) == expected
class TestDefaultStrategy:
    """Tests for the DEFAULT_STRATEGY module constant."""

    def test_default_strategy_is_scale_strategy(self):
        """DEFAULT_STRATEGY must be a ScaleStrategy instance."""
        assert isinstance(DEFAULT_STRATEGY, ScaleStrategy)

    def test_default_strategy_matches_default_values(self):
        """DEFAULT_STRATEGY equals a ScaleStrategy built with no arguments."""
        assert DEFAULT_STRATEGY == ScaleStrategy()
class TestManualLabelStrategy:
    """Tests for the MANUAL_LABEL_STRATEGY module constant."""

    def test_manual_label_strategy_is_scale_strategy(self):
        """MANUAL_LABEL_STRATEGY must be a ScaleStrategy instance."""
        assert isinstance(MANUAL_LABEL_STRATEGY, ScaleStrategy)

    def test_manual_label_strategy_has_no_scaling(self):
        """Scale factors of 1.0 mean boxes are never grown multiplicatively."""
        assert MANUAL_LABEL_STRATEGY.scale_x == 1.0
        assert MANUAL_LABEL_STRATEGY.scale_y == 1.0

    def test_manual_label_strategy_has_no_directional_expansion(self):
        """All four extra_*_ratio values must be zero."""
        ratios = (
            MANUAL_LABEL_STRATEGY.extra_top_ratio,
            MANUAL_LABEL_STRATEGY.extra_bottom_ratio,
            MANUAL_LABEL_STRATEGY.extra_left_ratio,
            MANUAL_LABEL_STRATEGY.extra_right_ratio,
        )
        assert all(r == 0.0 for r in ratios)

    def test_manual_label_strategy_has_small_max_pad(self):
        """Manual labels tolerate at most 15px of padding per axis."""
        assert MANUAL_LABEL_STRATEGY.max_pad_x <= 15
        assert MANUAL_LABEL_STRATEGY.max_pad_y <= 15
class TestFieldScaleStrategies:
    """Tests for FIELD_SCALE_STRATEGIES dictionary.

    Validates the invariants every per-field strategy must satisfy:
    full coverage of CLASS_NAMES, correct type, and sane numeric ranges.
    """

    def test_all_class_names_have_strategies(self):
        """Verify all field class names have defined strategies."""
        # Every YOLO class must have a configured expansion strategy.
        for class_name in CLASS_NAMES:
            assert class_name in FIELD_SCALE_STRATEGIES, (
                f"Missing strategy for field: {class_name}"
            )

    def test_strategies_are_scale_strategy_instances(self):
        """Verify all strategies are ScaleStrategy instances."""
        for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
            assert isinstance(strategy, ScaleStrategy), (
                f"Strategy for {field_name} is not a ScaleStrategy"
            )

    def test_scale_values_are_greater_than_one(self):
        """Verify all scale values are >= 1.0 (expansion, not contraction)."""
        for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
            assert strategy.scale_x >= 1.0, (
                f"{field_name} scale_x should be >= 1.0"
            )
            assert strategy.scale_y >= 1.0, (
                f"{field_name} scale_y should be >= 1.0"
            )

    def test_extra_ratios_are_non_negative(self):
        """Verify all extra ratios are >= 0."""
        # A negative ratio would shrink the bbox instead of expanding it.
        for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
            assert strategy.extra_top_ratio >= 0, (
                f"{field_name} extra_top_ratio should be >= 0"
            )
            assert strategy.extra_bottom_ratio >= 0, (
                f"{field_name} extra_bottom_ratio should be >= 0"
            )
            assert strategy.extra_left_ratio >= 0, (
                f"{field_name} extra_left_ratio should be >= 0"
            )
            assert strategy.extra_right_ratio >= 0, (
                f"{field_name} extra_right_ratio should be >= 0"
            )

    def test_max_pad_values_are_positive(self):
        """Verify all max_pad values are > 0."""
        for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
            assert strategy.max_pad_x > 0, (
                f"{field_name} max_pad_x should be > 0"
            )
            assert strategy.max_pad_y > 0, (
                f"{field_name} max_pad_y should be > 0"
            )
class TestSpecificFieldStrategies:
    """Tests for specific field strategy configurations.

    Each test pins the directional-expansion choice made for one field
    (e.g. labels sit above OCR numbers, prefixes sit left of giro numbers).
    """

    def test_ocr_number_expands_upward(self):
        """Verify ocr_number strategy expands upward to capture label."""
        strategy = FIELD_SCALE_STRATEGIES["ocr_number"]
        assert strategy.extra_top_ratio > 0.0
        assert strategy.extra_top_ratio >= 0.5  # Significant upward expansion

    def test_bankgiro_expands_leftward(self):
        """Verify bankgiro strategy expands leftward to capture prefix."""
        strategy = FIELD_SCALE_STRATEGIES["bankgiro"]
        assert strategy.extra_left_ratio > 0.0
        assert strategy.extra_left_ratio >= 0.5  # Significant leftward expansion

    def test_plusgiro_expands_leftward(self):
        """Verify plusgiro strategy expands leftward to capture prefix."""
        strategy = FIELD_SCALE_STRATEGIES["plusgiro"]
        assert strategy.extra_left_ratio > 0.0
        assert strategy.extra_left_ratio >= 0.5

    def test_amount_expands_rightward(self):
        """Verify amount strategy expands rightward for currency symbol."""
        strategy = FIELD_SCALE_STRATEGIES["amount"]
        assert strategy.extra_right_ratio > 0.0

    def test_invoice_date_expands_upward(self):
        """Verify invoice_date strategy expands upward to capture label."""
        strategy = FIELD_SCALE_STRATEGIES["invoice_date"]
        assert strategy.extra_top_ratio > 0.0

    def test_invoice_due_date_expands_upward_and_leftward(self):
        """Verify invoice_due_date strategy expands both up and left."""
        strategy = FIELD_SCALE_STRATEGIES["invoice_due_date"]
        assert strategy.extra_top_ratio > 0.0
        assert strategy.extra_left_ratio > 0.0

    def test_payment_line_has_minimal_expansion(self):
        """Verify payment_line has conservative expansion (machine code)."""
        strategy = FIELD_SCALE_STRATEGIES["payment_line"]
        # Payment line is machine-readable, needs minimal expansion
        assert strategy.scale_x <= 1.2
        assert strategy.scale_y <= 1.3
def test_uniform_pad_is_positive(self):
    """Padding must expand boxes, never shrink them."""
    assert UNIFORM_PAD > 0

View File

@@ -171,8 +171,8 @@ class TestGenerateFromMatches:
assert len(annotations) == 0
def test_applies_field_specific_expansion(self):
"""Verify different fields get different expansion."""
def test_applies_uniform_expansion(self):
"""Verify all fields get the same uniform expansion."""
gen = AnnotationGenerator(min_confidence=0.5)
# Same bbox, different fields
@@ -199,10 +199,11 @@ class TestGenerateFromMatches:
dpi=150
)[0]
# Bankgiro has extra_left_ratio=0.80, invoice_number has extra_top_ratio=0.40
# They should have different widths due to different expansion
# Bankgiro expands more to the left
assert ann_bankgiro.width != ann_invoice.width or ann_bankgiro.x_center != ann_invoice.x_center
# Uniform expansion: same bbox -> same dimensions (only class_id differs)
assert ann_bankgiro.width == ann_invoice.width
assert ann_bankgiro.height == ann_invoice.height
assert ann_bankgiro.x_center == ann_invoice.x_center
assert ann_bankgiro.y_center == ann_invoice.y_center
def test_enforces_min_bbox_height(self):
"""Verify minimum bbox height is enforced."""

View File

@@ -7,27 +7,23 @@ from pathlib import Path
from training.yolo.db_dataset import DBYOLODataset
from training.yolo.annotation_generator import YOLOAnnotation
from shared.bbox import FIELD_SCALE_STRATEGIES, DEFAULT_STRATEGY
from shared.bbox import UNIFORM_PAD
from shared.fields import CLASS_NAMES
class TestConvertLabelsWithExpandBbox:
"""Tests for _convert_labels using expand_bbox instead of fixed padding."""
"""Tests for _convert_labels using uniform expand_bbox."""
def test_convert_labels_uses_expand_bbox(self):
"""Verify _convert_labels calls expand_bbox for field-specific expansion."""
# Create a mock dataset without loading from DB
"""Verify _convert_labels calls expand_bbox with uniform padding."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 30
# Create annotation for bankgiro (has extra_left_ratio)
# bbox in PDF points: x0=100, y0=200, x1=200, y1=250
# center: (150, 225), width: 100, height: 50
annotations = [
YOLOAnnotation(
class_id=4, # bankgiro
x_center=150, # in PDF points
x_center=150,
y_center=225,
width=100,
height=50,
@@ -35,48 +31,26 @@ class TestConvertLabelsWithExpandBbox:
)
]
# Image size in pixels (at 300 DPI)
img_width = 2480 # A4 width at 300 DPI
img_height = 3508 # A4 height at 300 DPI
img_width = 2480
img_height = 3508
# Convert labels
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
# Should have one label
assert labels.shape == (1, 5)
# Check class_id
assert labels[0, 0] == 4
# The bbox should be expanded using bankgiro strategy (extra_left_ratio=0.80)
# Original bbox at 300 DPI:
# x0 = 100 * (300/72) = 416.67
# y0 = 200 * (300/72) = 833.33
# x1 = 200 * (300/72) = 833.33
# y1 = 250 * (300/72) = 1041.67
# width_px = 416.67, height_px = 208.33
# After expand_bbox with bankgiro strategy:
# scale_x=1.45, scale_y=1.35, extra_left_ratio=0.80
# The x_center should shift left due to extra_left_ratio
x_center = labels[0, 1]
y_center = labels[0, 2]
width = labels[0, 3]
height = labels[0, 4]
# Verify normalized values are in valid range
assert 0 <= x_center <= 1
assert 0 <= y_center <= 1
assert 0 < width <= 1
assert 0 < height <= 1
# Width should be larger than original due to scaling and extra_left
# Original normalized width: 416.67 / 2480 = 0.168
# After bankgiro expansion it should be wider
assert width > 0.168
def test_convert_labels_different_field_types(self):
"""Verify different field types use their specific strategies."""
def test_convert_labels_all_fields_get_same_expansion(self):
"""Verify all field types get the same uniform expansion."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 30
@@ -84,7 +58,6 @@ class TestConvertLabelsWithExpandBbox:
img_width = 2480
img_height = 3508
# Same bbox for different field types
base_annotation = {
'x_center': 150,
'y_center': 225,
@@ -93,30 +66,20 @@ class TestConvertLabelsWithExpandBbox:
'confidence': 0.9
}
# OCR number (class_id=3) - has extra_top_ratio=0.60
# All field types should get the same uniform expansion
ocr_annotations = [YOLOAnnotation(class_id=3, **base_annotation)]
ocr_labels = dataset._convert_labels(ocr_annotations, img_width, img_height, is_scanned=False)
# Bankgiro (class_id=4) - has extra_left_ratio=0.80
bankgiro_annotations = [YOLOAnnotation(class_id=4, **base_annotation)]
bankgiro_labels = dataset._convert_labels(bankgiro_annotations, img_width, img_height, is_scanned=False)
# Amount (class_id=6) - has extra_right_ratio=0.30
amount_annotations = [YOLOAnnotation(class_id=6, **base_annotation)]
amount_labels = dataset._convert_labels(amount_annotations, img_width, img_height, is_scanned=False)
# x_center and y_center should be the same (uniform padding is symmetric)
assert abs(ocr_labels[0, 1] - bankgiro_labels[0, 1]) < 0.001
assert abs(ocr_labels[0, 2] - bankgiro_labels[0, 2]) < 0.001
# Each field type should have different expansion
# OCR should expand more vertically (extra_top)
# Bankgiro should expand more to the left
# Amount should expand more to the right
# OCR: extra_top shifts y_center up
# Bankgiro: extra_left shifts x_center left
# So bankgiro x_center < OCR x_center
assert bankgiro_labels[0, 1] < ocr_labels[0, 1]
# OCR has higher scale_y (1.80) than amount (1.35)
assert ocr_labels[0, 4] > amount_labels[0, 4]
# width and height should also be the same
assert abs(ocr_labels[0, 3] - bankgiro_labels[0, 3]) < 0.001
assert abs(ocr_labels[0, 4] - bankgiro_labels[0, 4]) < 0.001
def test_convert_labels_clamps_to_image_bounds(self):
"""Verify labels are clamped to image boundaries."""
@@ -124,11 +87,10 @@ class TestConvertLabelsWithExpandBbox:
dataset.dpi = 300
dataset.min_bbox_height_px = 30
# Annotation near edge of image (in PDF points)
annotations = [
YOLOAnnotation(
class_id=4, # bankgiro - will expand left
x_center=30, # Very close to left edge
class_id=4,
x_center=30,
y_center=50,
width=40,
height=30,
@@ -141,11 +103,10 @@ class TestConvertLabelsWithExpandBbox:
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
# All values should be in valid range
assert 0 <= labels[0, 1] <= 1 # x_center
assert 0 <= labels[0, 2] <= 1 # y_center
assert 0 < labels[0, 3] <= 1 # width
assert 0 < labels[0, 4] <= 1 # height
assert 0 <= labels[0, 1] <= 1
assert 0 <= labels[0, 2] <= 1
assert 0 < labels[0, 3] <= 1
assert 0 < labels[0, 4] <= 1
def test_convert_labels_empty_annotations(self):
"""Verify empty annotations return empty array."""
@@ -162,23 +123,21 @@ class TestConvertLabelsWithExpandBbox:
"""Verify minimum height is enforced after expansion."""
dataset = object.__new__(DBYOLODataset)
dataset.dpi = 300
dataset.min_bbox_height_px = 50 # Higher minimum
dataset.min_bbox_height_px = 50
# Very small annotation
annotations = [
YOLOAnnotation(
class_id=9, # payment_line - minimal expansion
class_id=9,
x_center=100,
y_center=100,
width=200,
height=5, # Very small height
height=5,
confidence=0.9
)
]
labels = dataset._convert_labels(annotations, 2480, 3508, is_scanned=False)
# Height should be at least min_bbox_height_px / img_height
min_normalized_height = 50 / 3508
assert labels[0, 4] >= min_normalized_height
@@ -190,25 +149,23 @@ class TestCreateAnnotationWithClassName:
"""Verify _create_annotation stores class_name for later use."""
dataset = object.__new__(DBYOLODataset)
# Create annotation for invoice_number
annotation = dataset._create_annotation(
field_name="InvoiceNumber",
bbox=[100, 200, 200, 250],
score=0.9
)
assert annotation.class_id == 0 # invoice_number class_id
assert annotation.class_id == 0
class TestLoadLabelsFromDbWithClassName:
"""Tests for _load_labels_from_db preserving field_name for expansion."""
def test_load_labels_maps_field_names_correctly(self):
"""Verify field names are mapped correctly for expand_bbox."""
"""Verify field names are mapped correctly."""
dataset = object.__new__(DBYOLODataset)
dataset.min_confidence = 0.7
# Mock database
mock_db = MagicMock()
mock_db.get_documents_batch.return_value = {
'doc1': {
@@ -240,12 +197,7 @@ class TestLoadLabelsFromDbWithClassName:
assert 'doc1' in result
page_labels, is_scanned, csv_split = result['doc1']
# Should have 2 annotations on page 0
assert 0 in page_labels
assert len(page_labels[0]) == 2
# First annotation: Bankgiro (class_id=4)
assert page_labels[0][0].class_id == 4
# Second annotation: Plusgiro mapped from supplier_accounts(Plusgiro) (class_id=5)
assert page_labels[0][1].class_id == 5

View File

@@ -53,7 +53,7 @@ class TestTrainingConfigSchema:
"""Test default training configuration."""
config = TrainingConfig()
assert config.model_name == "yolo11n.pt"
assert config.model_name == "yolo26s.pt"
assert config.epochs == 100
assert config.batch_size == 16
assert config.image_size == 640
@@ -63,7 +63,7 @@ class TestTrainingConfigSchema:
def test_custom_config(self):
"""Test custom training configuration."""
config = TrainingConfig(
model_name="yolo11s.pt",
model_name="yolo26s.pt",
epochs=50,
batch_size=8,
image_size=416,
@@ -71,7 +71,7 @@ class TestTrainingConfigSchema:
device="cpu",
)
assert config.model_name == "yolo11s.pt"
assert config.model_name == "yolo26s.pt"
assert config.epochs == 50
assert config.batch_size == 8
@@ -136,7 +136,7 @@ class TestTrainingTaskModel:
def test_task_with_config(self):
"""Test task with configuration."""
config = {
"model_name": "yolo11n.pt",
"model_name": "yolo26s.pt",
"epochs": 100,
}
task = TrainingTask(

View File

@@ -0,0 +1,784 @@
"""
Comprehensive unit tests for Data Mixing Service.
Tests the data mixing service functions for YOLO fine-tuning:
- Mixing ratio calculation based on sample counts
- Dataset building with old/new sample mixing
- Image collection and path conversion
- Pool document matching
"""
from pathlib import Path
from uuid import UUID, uuid4
import pytest
from backend.web.services.data_mixer import (
get_mixing_ratio,
build_mixed_dataset,
_collect_images,
_image_to_label_path,
_find_pool_images,
MIXING_RATIOS,
DEFAULT_MULTIPLIER,
MAX_OLD_SAMPLES,
MIN_POOL_SIZE,
)
class TestGetMixingRatio:
    """Tests for get_mixing_ratio function.

    The multiplier steps down as the new-sample count grows; each test
    probes one bucket at its lower edge, interior, and upper boundary.
    """

    def test_mixing_ratio_at_first_threshold(self):
        """Test mixing ratio at first threshold boundary (10 samples)."""
        assert get_mixing_ratio(1) == 50
        assert get_mixing_ratio(5) == 50
        assert get_mixing_ratio(10) == 50

    def test_mixing_ratio_at_second_threshold(self):
        """Test mixing ratio at second threshold boundary (50 samples)."""
        assert get_mixing_ratio(11) == 20
        assert get_mixing_ratio(30) == 20
        assert get_mixing_ratio(50) == 20

    def test_mixing_ratio_at_third_threshold(self):
        """Test mixing ratio at third threshold boundary (200 samples)."""
        assert get_mixing_ratio(51) == 10
        assert get_mixing_ratio(100) == 10
        assert get_mixing_ratio(200) == 10

    def test_mixing_ratio_at_fourth_threshold(self):
        """Test mixing ratio at fourth threshold boundary (500 samples)."""
        assert get_mixing_ratio(201) == 5
        assert get_mixing_ratio(350) == 5
        assert get_mixing_ratio(500) == 5

    def test_mixing_ratio_above_all_thresholds(self):
        """Test mixing ratio for samples above all thresholds."""
        assert get_mixing_ratio(501) == DEFAULT_MULTIPLIER
        assert get_mixing_ratio(1000) == DEFAULT_MULTIPLIER
        assert get_mixing_ratio(10000) == DEFAULT_MULTIPLIER

    def test_mixing_ratio_boundary_values(self):
        """Test exact threshold boundaries match expected ratios."""
        # Verify threshold boundaries from MIXING_RATIOS
        # (assumes MIXING_RATIOS is a list of (threshold, multiplier)
        # tuples sorted by threshold — matches the tests above).
        for threshold, expected_multiplier in MIXING_RATIOS:
            assert get_mixing_ratio(threshold) == expected_multiplier
            # One above threshold should give next ratio
            if threshold < MIXING_RATIOS[-1][0]:
                next_idx = MIXING_RATIOS.index((threshold, expected_multiplier)) + 1
                next_multiplier = MIXING_RATIOS[next_idx][1]
                assert get_mixing_ratio(threshold + 1) == next_multiplier
class TestCollectImages:
    """Tests for _collect_images function."""

    @staticmethod
    def _populated_dir(tmp_path, filenames):
        """Create an images/ directory holding empty files named *filenames*."""
        directory = tmp_path / "images"
        directory.mkdir()
        for name in filenames:
            (directory / name).touch()
        return directory

    def test_collect_images_empty_directory(self, tmp_path):
        """An existing but empty directory yields no images."""
        assert _collect_images(self._populated_dir(tmp_path, [])) == []

    def test_collect_images_nonexistent_directory(self, tmp_path):
        """A missing directory yields an empty list rather than raising."""
        assert _collect_images(tmp_path / "nonexistent") == []

    def test_collect_png_images(self, tmp_path):
        """PNG files are all found and come back in sorted order."""
        images_dir = self._populated_dir(
            tmp_path, ["img1.png", "img2.png", "img3.png"]
        )
        found = _collect_images(images_dir)
        assert len(found) == 3
        assert all(path.suffix == ".png" for path in found)
        assert found == sorted(found)

    def test_collect_jpg_images(self, tmp_path):
        """JPG files are collected as well."""
        found = _collect_images(
            self._populated_dir(tmp_path, ["img1.jpg", "img2.jpg"])
        )
        assert len(found) == 2
        assert all(path.suffix == ".jpg" for path in found)

    def test_collect_mixed_image_types(self, tmp_path):
        """PNG and JPG files are both collected from one directory."""
        found = _collect_images(
            self._populated_dir(
                tmp_path, ["img1.png", "img2.jpg", "img3.png", "img4.jpg"]
            )
        )
        assert len(found) == 4
        # Each extension contributes its own sorted group.
        assert sum(1 for path in found if path.suffix == ".png") == 2
        assert sum(1 for path in found if path.suffix == ".jpg") == 2

    def test_collect_images_ignores_other_files(self, tmp_path):
        """Files that are neither PNG nor JPG are skipped."""
        found = _collect_images(
            self._populated_dir(
                tmp_path,
                ["img1.png", "img2.jpg", "doc.txt", "data.json", "notes.md"],
            )
        )
        assert len(found) == 2
        assert all(path.suffix in (".png", ".jpg") for path in found)
class TestImageToLabelPath:
    """Tests for _image_to_label_path function."""

    @staticmethod
    def _convert(tmp_path, split, filename):
        """Map dataset/images/<split>/<filename> through _image_to_label_path."""
        image = tmp_path / "dataset" / "images" / split / filename
        return _image_to_label_path(image)

    def test_image_to_label_path_train(self, tmp_path):
        """images/train/*.png maps to labels/train/*.txt."""
        label = self._convert(tmp_path, "train", "doc123_page1.png")
        assert label == tmp_path / "dataset" / "labels" / "train" / "doc123_page1.txt"

    def test_image_to_label_path_val(self, tmp_path):
        """images/val/*.jpg maps to labels/val/*.txt."""
        label = self._convert(tmp_path, "val", "doc456_page2.jpg")
        assert label == tmp_path / "dataset" / "labels" / "val" / "doc456_page2.txt"

    def test_image_to_label_path_test(self, tmp_path):
        """images/test/*.png maps to labels/test/*.txt."""
        label = self._convert(tmp_path, "test", "doc789_page3.png")
        assert label == tmp_path / "dataset" / "labels" / "test" / "doc789_page3.txt"

    def test_image_to_label_path_preserves_filename(self, tmp_path):
        """The stem survives conversion; only the suffix becomes .txt."""
        label = self._convert(tmp_path, "train", "complex_filename_123_page5.png")
        assert label.stem == "complex_filename_123_page5"
        assert label.suffix == ".txt"

    def test_image_to_label_path_jpg_to_txt(self, tmp_path):
        """A JPG extension is rewritten to .txt as well."""
        assert self._convert(tmp_path, "train", "image.jpg").suffix == ".txt"
class TestFindPoolImages:
    """Tests for _find_pool_images function.

    Pool documents are identified by UUID-prefixed filenames; the helper
    is expected to search the train/val/test splits under images/.
    """

    def test_find_pool_images_in_train(self, tmp_path):
        """Test finding pool images in train split."""
        base = tmp_path / "dataset"
        train_dir = base / "images" / "train"
        train_dir.mkdir(parents=True)
        doc_id = str(uuid4())
        pool_doc_ids = {doc_id}
        # Create images
        (train_dir / f"{doc_id}_page1.png").touch()
        (train_dir / f"{doc_id}_page2.png").touch()
        (train_dir / "other_doc_page1.png").touch()
        result = _find_pool_images(base, pool_doc_ids)
        # Only the two pages of the pool document match.
        assert len(result) == 2
        assert all(doc_id in str(img) for img in result)

    def test_find_pool_images_in_val(self, tmp_path):
        """Test finding pool images in val split."""
        base = tmp_path / "dataset"
        val_dir = base / "images" / "val"
        val_dir.mkdir(parents=True)
        doc_id = str(uuid4())
        pool_doc_ids = {doc_id}
        # Create images
        (val_dir / f"{doc_id}_page1.png").touch()
        result = _find_pool_images(base, pool_doc_ids)
        assert len(result) == 1
        assert doc_id in str(result[0])

    def test_find_pool_images_across_splits(self, tmp_path):
        """Test finding pool images across train, val, and test splits."""
        base = tmp_path / "dataset"
        doc_id1 = str(uuid4())
        doc_id2 = str(uuid4())
        pool_doc_ids = {doc_id1, doc_id2}
        # Create images in different splits
        train_dir = base / "images" / "train"
        val_dir = base / "images" / "val"
        test_dir = base / "images" / "test"
        train_dir.mkdir(parents=True)
        val_dir.mkdir(parents=True)
        test_dir.mkdir(parents=True)
        (train_dir / f"{doc_id1}_page1.png").touch()
        (val_dir / f"{doc_id1}_page2.png").touch()
        (test_dir / f"{doc_id2}_page1.png").touch()
        (train_dir / "other_doc_page1.png").touch()
        result = _find_pool_images(base, pool_doc_ids)
        # Three pool pages total, spread across the three splits.
        assert len(result) == 3
        doc1_images = [img for img in result if doc_id1 in str(img)]
        doc2_images = [img for img in result if doc_id2 in str(img)]
        assert len(doc1_images) == 2
        assert len(doc2_images) == 1

    def test_find_pool_images_empty_pool(self, tmp_path):
        """Test finding images with empty pool."""
        base = tmp_path / "dataset"
        train_dir = base / "images" / "train"
        train_dir.mkdir(parents=True)
        (train_dir / "doc123_page1.png").touch()
        result = _find_pool_images(base, set())
        assert len(result) == 0

    def test_find_pool_images_no_matches(self, tmp_path):
        """Test finding images when no documents match pool."""
        base = tmp_path / "dataset"
        train_dir = base / "images" / "train"
        train_dir.mkdir(parents=True)
        pool_doc_ids = {str(uuid4())}
        (train_dir / "other_doc_page1.png").touch()
        (train_dir / "another_doc_page1.png").touch()
        result = _find_pool_images(base, pool_doc_ids)
        assert len(result) == 0

    def test_find_pool_images_multiple_pages(self, tmp_path):
        """Test finding multiple pages for same document."""
        base = tmp_path / "dataset"
        train_dir = base / "images" / "train"
        train_dir.mkdir(parents=True)
        doc_id = str(uuid4())
        pool_doc_ids = {doc_id}
        # Create multiple pages
        for i in range(1, 6):
            (train_dir / f"{doc_id}_page{i}.png").touch()
        result = _find_pool_images(base, pool_doc_ids)
        assert len(result) == 5

    def test_find_pool_images_ignores_non_files(self, tmp_path):
        """Test that directories are ignored."""
        base = tmp_path / "dataset"
        train_dir = base / "images" / "train"
        train_dir.mkdir(parents=True)
        doc_id = str(uuid4())
        pool_doc_ids = {doc_id}
        (train_dir / f"{doc_id}_page1.png").touch()
        # A subdirectory must not be returned as an image.
        (train_dir / "subdir").mkdir()
        result = _find_pool_images(base, pool_doc_ids)
        assert len(result) == 1

    def test_find_pool_images_nonexistent_splits(self, tmp_path):
        """Test handling non-existent split directories."""
        base = tmp_path / "dataset"
        # Don't create any directories
        pool_doc_ids = {str(uuid4())}
        result = _find_pool_images(base, pool_doc_ids)
        assert len(result) == 0
class TestBuildMixedDataset:
"""Tests for build_mixed_dataset function."""
@pytest.fixture
def setup_base_dataset(self, tmp_path):
    """Create a base dataset populated with "old" training data.

    Layout: images/{train,val} and labels/{train,val}; 10 train pages and
    5 val pages, each with a single one-line YOLO label.

    Returns:
        Path to the base dataset root.
    """
    base = tmp_path / "base_dataset"
    # Create directory structure
    for split in ("train", "val"):
        (base / "images" / split).mkdir(parents=True)
        (base / "labels" / split).mkdir(parents=True)
    # Create old training images and labels
    for i in range(1, 11):
        img_path = base / "images" / "train" / f"old_doc_{i}_page1.png"
        label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt"
        img_path.write_text(f"image {i}")
        # Plain string: the label line contains no interpolated values
        # (the original used a pointless f-prefix, ruff F541).
        label_path.write_text("0 0.5 0.5 0.1 0.1")
    for i in range(1, 6):
        img_path = base / "images" / "val" / f"old_doc_val_{i}_page1.png"
        label_path = base / "labels" / "val" / f"old_doc_val_{i}_page1.txt"
        img_path.write_text(f"val image {i}")
        label_path.write_text("0 0.5 0.5 0.1 0.1")
    return base
@pytest.fixture
def setup_pool_documents(self, tmp_path, setup_base_dataset):
    """Add five pool documents to the base dataset's train split.

    Returns:
        Tuple of (base dataset path, list of generated pool UUIDs).
    """
    base = setup_base_dataset
    pool_ids = [uuid4() for _ in range(5)]
    # Add pool documents to train split
    for doc_id in pool_ids:
        img_path = base / "images" / "train" / f"{doc_id}_page1.png"
        label_path = base / "labels" / "train" / f"{doc_id}_page1.txt"
        img_path.write_text(f"pool image {doc_id}")
        # Plain string: no interpolation in the label line
        # (the original used a pointless f-prefix, ruff F541).
        label_path.write_text("1 0.5 0.5 0.2 0.2")
    return base, pool_ids
def test_build_mixed_dataset_basic(self, tmp_path, setup_pool_documents):
    """Test basic mixed dataset building.

    Checks the stats dict keys, the count invariants, the on-disk
    directory layout, and the generated data.yaml contents.
    """
    base, pool_ids = setup_pool_documents
    output_dir = tmp_path / "mixed_output"
    result = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=output_dir,
        seed=42,
    )
    # Verify result structure
    assert "data_yaml" in result
    assert "total_images" in result
    assert "old_images" in result
    assert "new_images" in result
    assert "mixing_ratio" in result
    # Verify counts - new images should be > 0 (at least some were copied)
    # Note: new images are split 80/20 and copied without overwriting
    assert result["new_images"] > 0
    assert result["old_images"] > 0
    assert result["total_images"] == result["old_images"] + result["new_images"]
    # Verify output structure
    assert output_dir.exists()
    assert (output_dir / "images" / "train").exists()
    assert (output_dir / "images" / "val").exists()
    assert (output_dir / "labels" / "train").exists()
    assert (output_dir / "labels" / "val").exists()
    # Verify data.yaml exists
    yaml_path = Path(result["data_yaml"])
    assert yaml_path.exists()
    yaml_content = yaml_path.read_text()
    # Substring checks keep the test independent of a YAML parser.
    assert "train: images/train" in yaml_content
    assert "val: images/val" in yaml_content
    assert "nc:" in yaml_content
    assert "names:" in yaml_content
def test_build_mixed_dataset_respects_mixing_ratio(self, tmp_path, setup_pool_documents):
    """Test that mixing ratio is correctly applied."""
    base, pool_ids = setup_pool_documents
    output_dir = tmp_path / "mixed_output"
    # With 5 pool documents, get_mixing_ratio(5) returns 50
    # (because 5 <= 10, first threshold)
    # So target old_samples = 5 * 50 = 250
    # But limited by available data: 10 old train + 5 old val + 5 pool = 20 total
    result = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=output_dir,
        seed=42,
    )
    # Pool images are in the base dataset, so they can be sampled as "old"
    # Total available: 20 images (15 pure old + 5 pool images)
    assert result["old_images"] <= 20  # Can't exceed available in base dataset
    assert result["old_images"] > 0  # Should have some old data
    assert result["mixing_ratio"] == 50  # Correct ratio for 5 samples
def test_build_mixed_dataset_max_old_samples_limit(self, tmp_path):
    """Test that the MAX_OLD_SAMPLES cap on sampled old data is applied."""
    base = tmp_path / "base_dataset"
    # Create directory structure
    for split in ("train", "val"):
        (base / "images" / split).mkdir(parents=True)
        (base / "labels" / split).mkdir(parents=True)
    # Create MORE than MAX_OLD_SAMPLES old images
    for i in range(MAX_OLD_SAMPLES + 500):
        img_path = base / "images" / "train" / f"old_doc_{i}_page1.png"
        label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt"
        img_path.write_text(f"image {i}")
        # Plain string: no interpolation (original had a stray f-prefix, F541).
        label_path.write_text("0 0.5 0.5 0.1 0.1")
    # 100 pool samples; whatever the ratio yields, the sampled old data
    # must never exceed MAX_OLD_SAMPLES.
    pool_ids = [uuid4() for _ in range(100)]
    for doc_id in pool_ids:
        img_path = base / "images" / "train" / f"{doc_id}_page1.png"
        label_path = base / "labels" / "train" / f"{doc_id}_page1.txt"
        img_path.write_text(f"pool image {doc_id}")
        label_path.write_text("1 0.5 0.5 0.2 0.2")
    output_dir = tmp_path / "mixed_output"
    result = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=output_dir,
        seed=42,
    )
    # Should be capped at MAX_OLD_SAMPLES
    assert result["old_images"] <= MAX_OLD_SAMPLES
def test_build_mixed_dataset_empty_pool(self, tmp_path, setup_base_dataset):
    """An empty pool should yield a dataset with zero images everywhere."""
    result = build_mixed_dataset(
        pool_document_ids=[],
        base_dataset_path=setup_base_dataset,
        output_dir=tmp_path / "mixed_output",
        seed=42,
    )
    # With no new samples there is nothing to mix, so every count is 0.
    for key in ("new_images", "old_images", "total_images"):
        assert result[key] == 0
def test_build_mixed_dataset_no_old_data(self, tmp_path):
    """Build a dataset from pool documents only, with no separate old data."""
    base = tmp_path / "base_dataset"
    for split in ("train", "val"):
        (base / "images" / split).mkdir(parents=True)
        (base / "labels" / split).mkdir(parents=True)
    # Pool documents live in the base train split, so the sampler picks
    # them up as "old" data first and then skips them as "new".
    pool_ids = [uuid4() for _ in range(5)]
    for doc_id in pool_ids:
        stem = f"{doc_id}_page1"
        (base / "images" / "train" / f"{stem}.png").write_text(f"pool image {doc_id}")
        (base / "labels" / "train" / f"{stem}.txt").write_text("1 0.5 0.5 0.2 0.2")
    result = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=tmp_path / "mixed_output",
        seed=42,
    )
    # Expect old_images > 0, new_images possibly 0, and a consistent total.
    assert result["total_images"] > 0
    assert result["total_images"] == result["old_images"] + result["new_images"]
def test_build_mixed_dataset_train_val_split(self, tmp_path, setup_pool_documents):
    """Images should be split roughly 80/20 between train and val."""
    base, pool_ids = setup_pool_documents
    output_dir = tmp_path / "mixed_output"
    result = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=output_dir,
        seed=42,
    )
    n_train = len(list((output_dir / "images" / "train").glob("*.png")))
    n_val = len(list((output_dir / "images" / "val").glob("*.png")))
    total = n_train + n_val
    # Every copied image must be accounted for in the reported total.
    assert total == result["total_images"]
    if total:
        # Small sample sizes make the split noisy, so accept a wide band.
        assert 0.6 <= n_train / total <= 0.9
def test_build_mixed_dataset_reproducible_with_seed(self, tmp_path, setup_pool_documents):
    """Identical seeds must produce identical sampling and file layout."""
    base, pool_ids = setup_pool_documents
    out_dirs = (tmp_path / "mixed_output1", tmp_path / "mixed_output2")
    results = [
        build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base,
            output_dir=out,
            seed=123,
        )
        for out in out_dirs
    ]
    # Counts must match between the two runs.
    assert results[0]["old_images"] == results[1]["old_images"]
    assert results[0]["new_images"] == results[1]["new_images"]
    # The exact set of files placed in train must match as well.
    name_sets = [
        {f.name for f in (out / "images" / "train").glob("*.png")}
        for out in out_dirs
    ]
    assert name_sets[0] == name_sets[1]
def test_build_mixed_dataset_different_seeds(self, tmp_path, setup_pool_documents):
    """Different seeds should still build valid datasets with the same ratio."""
    base, pool_ids = setup_pool_documents
    out1 = tmp_path / "mixed_output1"
    out2 = tmp_path / "mixed_output2"
    first = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=out1,
        seed=123,
    )
    second = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=out2,
        seed=456,
    )
    assert first["total_images"] > 0
    assert second["total_images"] > 0
    # The mixing ratio depends only on pool size, never on the seed.
    assert first["mixing_ratio"] == second["mixing_ratio"]
    # Shuffling may (but need not) differ between seeds, so we only
    # require that both runs produced a non-empty train split.
    for out in (out1, out2):
        assert len(list((out / "images" / "train").glob("*.png"))) > 0
def test_build_mixed_dataset_copies_labels(self, tmp_path, setup_pool_documents):
    """Label files should accompany copied images (and never outnumber them)."""
    base, pool_ids = setup_pool_documents
    output_dir = tmp_path / "mixed_output"
    build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=output_dir,
        seed=42,
    )
    for split in ("train", "val"):
        labels = list((output_dir / "labels" / split).glob("*.txt"))
        images = list((output_dir / "images" / split).glob("*.png"))
        # Some labels may be missing, but there can never be more labels
        # than images in a split.
        assert len(labels) <= len(images)
def test_build_mixed_dataset_skips_duplicate_files(self, tmp_path, setup_pool_documents):
    """Re-running the build over an existing output should overwrite files.

    Fixes a latent NameError in the original test: ``test_file`` was
    assigned only inside ``if len(train_images) > 0:`` but then read
    unconditionally after the second build, crashing whenever the first
    build produced no train images. The tamper-and-rebuild check now
    only runs when there is actually a file to tamper with.
    """
    base, pool_ids = setup_pool_documents
    output_dir = tmp_path / "mixed_output"
    # First build.
    result1 = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=output_dir,
        seed=42,
    )
    assert result1["total_images"] >= 0
    train_images = list((output_dir / "images" / "train").glob("*.png"))
    if not train_images:
        # Nothing to tamper with; overwrite behavior cannot be observed.
        return
    test_file = train_images[0]
    test_file.write_text("modified content")
    # Second build with the same seed should copy over the tampered file.
    result2 = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=output_dir,
        seed=42,
    )
    assert result2["total_images"] >= 0
    # The implementation uses shutil.copy2, which overwrites by default,
    # so the tampered file is restored to its original content.
    assert test_file.read_text() != "modified content"
def test_build_mixed_dataset_handles_jpg_images(self, tmp_path):
    """JPG images should flow through the mixer just like PNGs."""
    base = tmp_path / "base_dataset"
    for split in ("train", "val"):
        (base / "images" / split).mkdir(parents=True)
        (base / "labels" / split).mkdir(parents=True)
    # Old data stored as JPG.
    for idx in range(1, 6):
        stem = f"old_doc_{idx}_page1"
        (base / "images" / "train" / f"{stem}.jpg").write_text(f"jpg image {idx}")
        (base / "labels" / "train" / f"{stem}.txt").write_text("0 0.5 0.5 0.1 0.1")
    # One pool document with several JPG pages so at least one is copied.
    pool_ids = [uuid4()]
    doc_id = pool_ids[0]
    for page_num in range(1, 4):
        stem = f"{doc_id}_page{page_num}"
        (base / "images" / "train" / f"{stem}.jpg").write_text(
            f"pool jpg {doc_id} page {page_num}"
        )
        (base / "labels" / "train" / f"{stem}.txt").write_text("1 0.5 0.5 0.2 0.2")
    output_dir = tmp_path / "mixed_output"
    result = build_mixed_dataset(
        pool_document_ids=pool_ids,
        base_dataset_path=base,
        output_dir=output_dir,
        seed=42,
    )
    # Both new (pool) and old JPG images should have been mixed in.
    assert result["new_images"] > 0
    assert result["old_images"] > 0
    # And JPG files must actually exist in the output splits.
    copied_jpgs = [
        *(output_dir / "images" / "train").glob("*.jpg"),
        *(output_dir / "images" / "val").glob("*.jpg"),
    ]
    assert len(copied_jpgs) > 0
class TestConstants:
    """Tests for module constants."""

    def test_mixing_ratios_structure(self):
        """MIXING_RATIOS must be a list of (threshold, multiplier) int pairs."""
        assert isinstance(MIXING_RATIOS, list)
        assert len(MIXING_RATIOS) == 4
        for entry in MIXING_RATIOS:
            assert isinstance(entry, tuple)
            assert len(entry) == 2
            threshold, multiplier = entry
            assert isinstance(threshold, int)
            assert isinstance(multiplier, int)
        # Thresholds climb while multipliers fall.
        thresholds = [threshold for threshold, _ in MIXING_RATIOS]
        multipliers = [multiplier for _, multiplier in MIXING_RATIOS]
        assert thresholds == sorted(thresholds)
        assert multipliers == sorted(multipliers, reverse=True)

    def test_default_multiplier(self):
        """DEFAULT_MULTIPLIER matches the last MIXING_RATIOS multiplier."""
        assert DEFAULT_MULTIPLIER == 5
        assert DEFAULT_MULTIPLIER == MIXING_RATIOS[-1][1]

    def test_max_old_samples(self):
        """MAX_OLD_SAMPLES is the positive hard cap of 3000."""
        assert MAX_OLD_SAMPLES == 3000
        assert MAX_OLD_SAMPLES > 0

    def test_min_pool_size(self):
        """MIN_POOL_SIZE is the positive fine-tune threshold of 50."""
        assert MIN_POOL_SIZE == 50
        assert MIN_POOL_SIZE > 0

View File

@@ -310,7 +310,7 @@ class TestSchedulerDatasetStatusUpdates:
try:
scheduler._execute_task(
task_id=task_id,
config={"model_name": "yolo11n.pt"},
config={"model_name": "yolo26s.pt"},
dataset_id=dataset_id,
)
except Exception:

View File

@@ -0,0 +1,467 @@
"""
Tests for Fine-Tune Pool feature.
Tests cover:
1. FineTunePoolEntry database model
2. PoolAddRequest/PoolStatsResponse schemas
3. Chain prevention logic
4. Pool threshold enforcement
5. Model lineage fields on ModelVersion
6. Gating enforcement on model activation
"""
import pytest
from datetime import datetime
from unittest.mock import MagicMock, patch
from uuid import uuid4, UUID
# =============================================================================
# Test Database Models
# =============================================================================
class TestFineTunePoolEntryModel:
    """Tests for FineTunePoolEntry model."""

    def test_creates_with_defaults(self):
        """A bare entry should carry the documented defaults."""
        from backend.data.admin_models import FineTunePoolEntry

        entry = FineTunePoolEntry(document_id=uuid4())
        assert entry.entry_id is not None
        assert entry.is_verified is False
        # Every optional field starts unset.
        for attr in ("verified_at", "verified_by", "added_by", "reason"):
            assert getattr(entry, attr) is None

    def test_creates_with_all_fields(self):
        """Every field passed to the constructor should round-trip."""
        from backend.data.admin_models import FineTunePoolEntry

        doc_id = uuid4()
        entry = FineTunePoolEntry(
            document_id=doc_id,
            added_by="admin",
            reason="user_reported_failure",
            is_verified=True,
            verified_by="reviewer",
        )
        assert entry.document_id == doc_id
        assert entry.added_by == "admin"
        assert entry.reason == "user_reported_failure"
        assert entry.is_verified is True
        assert entry.verified_by == "reviewer"
class TestGatingResultModel:
    """Tests for GatingResult model."""

    def test_creates_with_defaults(self):
        """Metric columns should stay None when only statuses are given."""
        from backend.data.admin_models import GatingResult

        mv_id = uuid4()
        outcome = GatingResult(
            model_version_id=mv_id,
            gate1_status="pass",
            gate2_status="pass",
            overall_status="pass",
        )
        assert outcome.result_id is not None
        assert outcome.model_version_id == mv_id
        assert outcome.gate1_status == "pass"
        assert outcome.gate2_status == "pass"
        assert outcome.overall_status == "pass"
        # No metrics were supplied, so none should be present.
        assert outcome.gate1_mAP_drop is None
        assert outcome.gate2_detection_rate is None

    def test_creates_with_full_metrics(self):
        """All gate metrics supplied at construction should be stored."""
        from backend.data.admin_models import GatingResult

        outcome = GatingResult(
            model_version_id=uuid4(),
            gate1_status="review",
            gate1_original_mAP=0.95,
            gate1_new_mAP=0.93,
            gate1_mAP_drop=0.02,
            gate2_status="pass",
            gate2_detection_rate=0.85,
            gate2_total_samples=100,
            gate2_detected_samples=85,
            overall_status="review",
        )
        assert outcome.gate1_original_mAP == 0.95
        assert outcome.gate1_new_mAP == 0.93
        assert outcome.gate1_mAP_drop == 0.02
        assert outcome.gate2_detection_rate == 0.85
class TestModelVersionLineage:
    """Tests for ModelVersion lineage fields."""

    def test_default_model_type_is_base(self):
        """A plain ModelVersion is a 'base' model with no lineage set."""
        from backend.data.admin_models import ModelVersion

        mv = ModelVersion(
            version="v1.0",
            name="test-model",
            model_path="/path/to/model.pt",
        )
        assert mv.model_type == "base"
        assert mv.base_model_version_id is None
        assert mv.base_training_dataset_id is None
        assert mv.gating_status == "pending"

    def test_finetune_model_type(self):
        """A 'finetune' ModelVersion should keep its lineage references."""
        from backend.data.admin_models import ModelVersion

        parent_id = uuid4()
        dataset_id = uuid4()
        mv = ModelVersion(
            version="v2.0",
            name="finetuned-model",
            model_path="/path/to/ft_model.pt",
            model_type="finetune",
            base_model_version_id=parent_id,
            base_training_dataset_id=dataset_id,
            gating_status="pending",
        )
        assert mv.model_type == "finetune"
        assert mv.base_model_version_id == parent_id
        assert mv.base_training_dataset_id == dataset_id
        assert mv.gating_status == "pending"
# =============================================================================
# Test Schemas
# =============================================================================
class TestPoolSchemas:
    """Tests for pool Pydantic schemas."""

    def test_pool_add_request_defaults(self):
        """Omitting the reason should fall back to 'user_reported_failure'."""
        from backend.web.schemas.admin.pool import PoolAddRequest

        doc_id = "550e8400-e29b-41d4-a716-446655440001"
        req = PoolAddRequest(document_id=doc_id)
        assert req.document_id == doc_id
        assert req.reason == "user_reported_failure"

    def test_pool_add_request_custom_reason(self):
        """An explicit reason should override the default."""
        from backend.web.schemas.admin.pool import PoolAddRequest

        req = PoolAddRequest(
            document_id="550e8400-e29b-41d4-a716-446655440001",
            reason="manual_addition",
        )
        assert req.reason == "manual_addition"

    def test_pool_stats_response(self):
        """PoolStatsResponse should report readiness and the 50-entry minimum."""
        from backend.web.schemas.admin.pool import PoolStatsResponse

        below_threshold = PoolStatsResponse(
            total_entries=30,
            verified_entries=20,
            unverified_entries=10,
            is_ready=False,
        )
        assert below_threshold.is_ready is False
        assert below_threshold.min_required == 50

        above_threshold = PoolStatsResponse(
            total_entries=80,
            verified_entries=60,
            unverified_entries=20,
            is_ready=True,
        )
        assert above_threshold.is_ready is True

    def test_pool_entry_item(self):
        """PoolEntryItem should serialize verification fields correctly."""
        from backend.web.schemas.admin.pool import PoolEntryItem

        now = datetime.utcnow()
        entry = PoolEntryItem(
            entry_id="entry-uuid",
            document_id="doc-uuid",
            is_verified=True,
            verified_at=now,
            verified_by="admin",
            created_at=now,
        )
        assert entry.is_verified is True
        assert entry.verified_by == "admin"

    def test_gating_result_item(self):
        """GatingResultItem should serialize all gate fields."""
        from backend.web.schemas.admin.pool import GatingResultItem

        item = GatingResultItem(
            result_id="result-uuid",
            model_version_id="model-uuid",
            gate1_status="pass",
            gate1_original_mAP=0.95,
            gate1_new_mAP=0.94,
            gate1_mAP_drop=0.01,
            gate2_status="pass",
            gate2_detection_rate=0.90,
            gate2_total_samples=50,
            gate2_detected_samples=45,
            overall_status="pass",
            created_at=datetime.utcnow(),
        )
        assert item.gate1_status == "pass"
        assert item.overall_status == "pass"
# =============================================================================
# Test Chain Prevention
# =============================================================================
class TestChainPrevention:
    """Tests for fine-tune chain prevention logic."""

    def test_rejects_finetune_from_finetune_model(self):
        """Should reject training when base model is already a fine-tune."""
        # Mirrors the chain-prevention check in datasets.py.
        parent_type = "finetune"
        parent_version_id = str(uuid4())
        # A fine-tuned parent must trigger rejection.
        assert parent_type == "finetune"

    def test_allows_finetune_from_base_model(self):
        """Should allow training when base model is a base model."""
        parent_type = "base"
        assert parent_type != "finetune"

    def test_allows_fresh_training(self):
        """Should allow fresh training (no base model)."""
        parent_version_id = None
        # No parent means the chain check is skipped entirely.
        assert parent_version_id is None
# =============================================================================
# Test Pool Threshold
# =============================================================================
class TestPoolThreshold:
    """Tests for minimum pool size enforcement."""

    def test_min_pool_size_constant(self):
        """MIN_POOL_SIZE should be 50."""
        from backend.web.services.data_mixer import MIN_POOL_SIZE

        assert MIN_POOL_SIZE == 50

    def test_pool_below_threshold_blocks_finetune(self):
        """Pool with fewer than 50 verified entries should block fine-tuning."""
        from backend.web.services.data_mixer import MIN_POOL_SIZE

        verified_count = 30
        assert verified_count < MIN_POOL_SIZE

    def test_pool_at_threshold_allows_finetune(self):
        """Pool with exactly 50 verified entries should allow fine-tuning."""
        from backend.web.services.data_mixer import MIN_POOL_SIZE

        verified_count = 50
        assert verified_count >= MIN_POOL_SIZE
# =============================================================================
# Test Gating Enforcement on Activation
# =============================================================================
class TestGatingEnforcement:
"""Tests for gating enforcement when activating models."""
def test_base_model_skips_gating(self):
"""Base models should have gating_status 'skipped'."""
from backend.data.admin_models import ModelVersion
mv = ModelVersion(
version="v1.0",
name="base",
model_path="/model.pt",
model_type="base",
)
# Base models skip gating - activation should work
assert mv.model_type == "base"
# Gating should not block base model activation
def test_finetune_model_rejected_blocks_activation(self):
"""Fine-tuned models with 'reject' gating should block activation."""
model_type = "finetune"
gating_status = "reject"
# Simulates the check in models.py activation endpoint
should_block = model_type == "finetune" and gating_status == "reject"
assert should_block is True
def test_finetune_model_pending_blocks_activation(self):
"""Fine-tuned models with 'pending' gating should block activation."""
model_type = "finetune"
gating_status = "pending"
should_block = model_type == "finetune" and gating_status == "pending"
assert should_block is True
def test_finetune_model_pass_allows_activation(self):
"""Fine-tuned models with 'pass' gating should allow activation."""
model_type = "finetune"
gating_status = "pass"
should_block_reject = model_type == "finetune" and gating_status == "reject"
should_block_pending = model_type == "finetune" and gating_status == "pending"
assert should_block_reject is False
assert should_block_pending is False
def test_finetune_model_review_allows_with_warning(self):
"""Fine-tuned models with 'review' gating should allow but warn."""
model_type = "finetune"
gating_status = "review"
should_block_reject = model_type == "finetune" and gating_status == "reject"
should_block_pending = model_type == "finetune" and gating_status == "pending"
assert should_block_reject is False
assert should_block_pending is False
# Should include warning in message
# =============================================================================
# Test Pool API Route Registration
# =============================================================================
class TestPoolRouteRegistration:
    """Tests for pool route registration."""

    def test_pool_routes_registered(self):
        """The training router should expose /pool and /pool/stats routes."""
        from backend.web.api.v1.admin.training import create_training_router

        registered = [route.path for route in create_training_router().routes]
        for fragment in ("/pool", "/pool/stats"):
            assert any(fragment in path for path in registered)
# =============================================================================
# Test Scheduler Fine-Tune Parameter Override
# =============================================================================
class TestSchedulerFineTuneParams:
    """Tests for scheduler fine-tune parameter overrides."""

    def test_finetune_detected_from_base_model_path(self):
        """Scheduler should detect fine-tune mode from base_model_path."""
        cfg = {"base_model_path": "/path/to/base_model.pt"}
        assert bool(cfg.get("base_model_path")) is True

    def test_fresh_training_not_finetune(self):
        """Scheduler should not enable fine-tune for fresh training."""
        cfg = {"model_name": "yolo26s.pt"}
        assert bool(cfg.get("base_model_path")) is False

    def test_finetune_defaults_correct_epochs(self):
        """Fine-tune should default to 10 epochs and lr 0.001."""
        cfg = {"base_model_path": "/path/to/model.pt"}
        if bool(cfg.get("base_model_path")):
            epochs = cfg.get("epochs", 10)
            learning_rate = cfg.get("learning_rate", 0.001)
        else:
            epochs = cfg.get("epochs", 100)
            learning_rate = cfg.get("learning_rate", 0.01)
        assert epochs == 10
        assert learning_rate == 0.001

    def test_model_lineage_set_for_finetune(self):
        """Fine-tune runs record type, parent id, and pending gating."""
        cfg = {
            "base_model_path": "/path/to/model.pt",
            "base_model_version_id": str(uuid4()),
        }
        finetuning = bool(cfg.get("base_model_path"))
        assert ("finetune" if finetuning else "base") == "finetune"
        assert (cfg.get("base_model_version_id") if finetuning else None) is not None
        assert ("pending" if finetuning else "skipped") == "pending"

    def test_model_lineage_skipped_for_base(self):
        """Fresh training records type 'base' and skips gating."""
        cfg = {"model_name": "yolo26s.pt"}
        finetuning = bool(cfg.get("base_model_path"))
        assert ("finetune" if finetuning else "base") == "base"
        assert ("pending" if finetuning else "skipped") == "skipped"
# =============================================================================
# Test TrainingConfig freeze/cos_lr
# =============================================================================
class TestTrainingConfigFineTuneFields:
    """Tests for freeze and cos_lr fields in shared TrainingConfig."""

    def test_default_freeze_is_zero(self):
        """TrainingConfig freeze should default to 0."""
        from shared.training import TrainingConfig

        cfg = TrainingConfig(model_path="test.pt", data_yaml="data.yaml")
        assert cfg.freeze == 0

    def test_default_cos_lr_is_false(self):
        """TrainingConfig cos_lr should default to False."""
        from shared.training import TrainingConfig

        cfg = TrainingConfig(model_path="test.pt", data_yaml="data.yaml")
        assert cfg.cos_lr is False

    def test_finetune_config(self):
        """TrainingConfig should accept fine-tune parameters."""
        from shared.training import TrainingConfig

        cfg = TrainingConfig(
            model_path="base_model.pt",
            data_yaml="data.yaml",
            epochs=10,
            learning_rate=0.001,
            freeze=10,
            cos_lr=True,
        )
        assert cfg.freeze == 10
        assert cfg.cos_lr is True
        assert cfg.epochs == 10
        assert cfg.learning_rate == 0.001

View File

@@ -1,14 +1,14 @@
"""
Tests for Training Export with expand_bbox integration.
Tests for Training Export with uniform expand_bbox integration.
Tests the export endpoint's integration with field-specific bbox expansion.
Tests the export endpoint's integration with uniform bbox expansion.
"""
import pytest
from unittest.mock import MagicMock, patch
from uuid import uuid4
from shared.bbox import expand_bbox
from shared.bbox import expand_bbox, UNIFORM_PAD
from shared.fields import CLASS_NAMES, FIELD_CLASS_IDS
@@ -17,149 +17,87 @@ class TestExpandBboxForExport:
def test_expand_bbox_converts_normalized_to_pixel_and_back(self):
"""Verify expand_bbox works with pixel-to-normalized conversion."""
# Annotation stored as normalized coords
x_center_norm = 0.5
y_center_norm = 0.5
width_norm = 0.1
height_norm = 0.05
# Image dimensions
img_width = 2480 # A4 at 300 DPI
img_width = 2480
img_height = 3508
# Convert to pixel coords
x_center_px = x_center_norm * img_width
y_center_px = y_center_norm * img_height
width_px = width_norm * img_width
height_px = height_norm * img_height
# Convert to corner coords
x0 = x_center_px - width_px / 2
y0 = y_center_px - height_px / 2
x1 = x_center_px + width_px / 2
y1 = y_center_px + height_px / 2
# Apply expansion
class_name = "invoice_number"
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=class_name,
)
# Verify expanded bbox is larger
assert ex0 < x0 # Left expanded
assert ey0 < y0 # Top expanded
assert ex1 > x1 # Right expanded
assert ey1 > y1 # Bottom expanded
assert ex0 < x0
assert ey0 < y0
assert ex1 > x1
assert ey1 > y1
# Convert back to normalized
new_x_center = (ex0 + ex1) / 2 / img_width
new_y_center = (ey0 + ey1) / 2 / img_height
new_width = (ex1 - ex0) / img_width
new_height = (ey1 - ey0) / img_height
# Verify valid normalized coords
assert 0 <= new_x_center <= 1
assert 0 <= new_y_center <= 1
assert 0 <= new_width <= 1
assert 0 <= new_height <= 1
def test_expand_bbox_manual_mode_minimal_expansion(self):
"""Verify manual annotations use minimal expansion."""
# Small bbox
def test_expand_bbox_uniform_for_all_sources(self):
"""Verify all annotation sources get the same uniform expansion."""
bbox = (100, 100, 200, 150)
img_width = 2480
img_height = 3508
# Auto mode (field-specific expansion)
auto_result = expand_bbox(
# All sources now get the same uniform expansion
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="invoice_number",
manual_mode=False,
)
# Manual mode (minimal expansion)
manual_result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="invoice_number",
manual_mode=True,
expected = (
100 - UNIFORM_PAD,
100 - UNIFORM_PAD,
200 + UNIFORM_PAD,
150 + UNIFORM_PAD,
)
# Auto expansion should be larger than manual
auto_width = auto_result[2] - auto_result[0]
manual_width = manual_result[2] - manual_result[0]
assert auto_width > manual_width
auto_height = auto_result[3] - auto_result[1]
manual_height = manual_result[3] - manual_result[1]
assert auto_height > manual_height
def test_expand_bbox_different_sources_use_correct_mode(self):
"""Verify different annotation sources use correct expansion mode."""
bbox = (100, 100, 200, 150)
img_width = 2480
img_height = 3508
# Define source to manual_mode mapping
source_mode_mapping = {
"manual": True, # Manual annotations -> minimal expansion
"auto": False, # Auto-labeled -> field-specific expansion
"imported": True, # Imported (from CSV) -> minimal expansion
}
results = {}
for source, manual_mode in source_mode_mapping.items():
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type="ocr_number",
manual_mode=manual_mode,
)
results[source] = result
# Auto should have largest expansion
auto_area = (results["auto"][2] - results["auto"][0]) * \
(results["auto"][3] - results["auto"][1])
manual_area = (results["manual"][2] - results["manual"][0]) * \
(results["manual"][3] - results["manual"][1])
imported_area = (results["imported"][2] - results["imported"][0]) * \
(results["imported"][3] - results["imported"][1])
assert auto_area > manual_area
assert auto_area > imported_area
# Manual and imported should be the same (both use minimal mode)
assert manual_area == imported_area
assert result == expected
def test_expand_bbox_all_field_types_work(self):
"""Verify expand_bbox works for all field types."""
"""Verify expand_bbox works for all field types (same result)."""
bbox = (100, 100, 200, 150)
img_width = 2480
img_height = 3508
for class_name in CLASS_NAMES:
result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
field_type=class_name,
)
# All fields should produce the same result with uniform padding
first_result = expand_bbox(
bbox=bbox,
image_width=img_width,
image_height=img_height,
)
# Verify result is a valid bbox
assert len(result) == 4
x0, y0, x1, y1 = result
assert x0 >= 0
assert y0 >= 0
assert x1 <= img_width
assert y1 <= img_height
assert x1 > x0
assert y1 > y0
assert len(first_result) == 4
x0, y0, x1, y1 = first_result
assert x0 >= 0
assert y0 >= 0
assert x1 <= img_width
assert y1 <= img_height
assert x1 > x0
assert y1 > y0
class TestExportAnnotationExpansion:
@@ -167,7 +105,6 @@ class TestExportAnnotationExpansion:
def test_annotation_bbox_conversion_workflow(self):
"""Test full annotation bbox conversion workflow."""
# Simulate stored annotation (normalized coords)
class MockAnnotation:
class_id = FIELD_CLASS_IDS["invoice_number"]
class_name = "invoice_number"
@@ -181,7 +118,6 @@ class TestExportAnnotationExpansion:
img_width = 2480
img_height = 3508
# Step 1: Convert normalized to pixel corner coords
half_w = (ann.width * img_width) / 2
half_h = (ann.height * img_height) / 2
x0 = ann.x_center * img_width - half_w
@@ -189,38 +125,27 @@ class TestExportAnnotationExpansion:
x1 = ann.x_center * img_width + half_w
y1 = ann.y_center * img_height + half_h
# Step 2: Determine manual_mode based on source
manual_mode = ann.source in ("manual", "imported")
# Step 3: Apply expand_bbox
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=ann.class_name,
manual_mode=manual_mode,
)
# Step 4: Convert back to normalized
new_x_center = (ex0 + ex1) / 2 / img_width
new_y_center = (ey0 + ey1) / 2 / img_height
new_width = (ex1 - ex0) / img_width
new_height = (ey1 - ey0) / img_height
# Verify expansion happened (auto mode)
assert new_width > ann.width
assert new_height > ann.height
# Verify valid YOLO format
assert 0 <= new_x_center <= 1
assert 0 <= new_y_center <= 1
assert 0 < new_width <= 1
assert 0 < new_height <= 1
def test_export_applies_expansion_to_each_annotation(self):
"""Test that export applies expansion to each annotation."""
# Simulate multiple annotations with different sources
# Use smaller bboxes so manual mode padding has visible effect
def test_export_applies_uniform_expansion_to_all_annotations(self):
"""Test that export applies uniform expansion to all annotations."""
annotations = [
{"class_name": "invoice_number", "source": "auto", "x_center": 0.3, "y_center": 0.2, "width": 0.05, "height": 0.02},
{"class_name": "ocr_number", "source": "manual", "x_center": 0.5, "y_center": 0.8, "width": 0.05, "height": 0.02},
@@ -232,7 +157,6 @@ class TestExportAnnotationExpansion:
expanded_annotations = []
for ann in annotations:
# Convert to pixel coords
half_w = (ann["width"] * img_width) / 2
half_h = (ann["height"] * img_height) / 2
x0 = ann["x_center"] * img_width - half_w
@@ -240,19 +164,12 @@ class TestExportAnnotationExpansion:
x1 = ann["x_center"] * img_width + half_w
y1 = ann["y_center"] * img_height + half_h
# Determine manual_mode
manual_mode = ann["source"] in ("manual", "imported")
# Apply expansion
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=ann["class_name"],
manual_mode=manual_mode,
)
# Convert back to normalized
expanded_annotations.append({
"class_name": ann["class_name"],
"source": ann["source"],
@@ -262,106 +179,48 @@ class TestExportAnnotationExpansion:
"height": (ey1 - ey0) / img_height,
})
# Verify auto-labeled annotation expanded more than manual/imported
auto_ann = next(a for a in expanded_annotations if a["source"] == "auto")
manual_ann = next(a for a in expanded_annotations if a["source"] == "manual")
# Auto mode should expand more than manual mode
# (auto has larger scale factors and max_pad)
assert auto_ann["width"] > manual_ann["width"]
assert auto_ann["height"] > manual_ann["height"]
# All annotations should be expanded (at least slightly for manual mode)
# Allow small precision loss (< 1%) due to integer conversion in expand_bbox
for i, (orig, exp) in enumerate(zip(annotations, expanded_annotations)):
# Width and height should be >= original (expansion or equal, with small tolerance)
tolerance = 0.01 # 1% tolerance for integer rounding
assert exp["width"] >= orig["width"] * (1 - tolerance), \
f"Annotation {i} width unexpectedly smaller: {exp['width']} < {orig['width']}"
assert exp["height"] >= orig["height"] * (1 - tolerance), \
f"Annotation {i} height unexpectedly smaller: {exp['height']} < {orig['height']}"
# All annotations get the same expansion
tolerance = 0.01
for orig, exp in zip(annotations, expanded_annotations):
assert exp["width"] >= orig["width"] * (1 - tolerance)
assert exp["height"] >= orig["height"] * (1 - tolerance)
class TestExpandBboxEdgeCases:
    """Tests for edge cases in export bbox expansion.

    Each bbox is (x0, y0, x1, y1) in pixels on a 2480x3508 page (A4 at
    300 DPI). ``expand_bbox`` must clamp the expanded box to the image
    bounds and keep it non-degenerate.

    NOTE(review): the previous version contained leftover merge residue —
    each test called ``expand_bbox`` twice, once with the removed
    ``field_type=`` keyword (which would raise ``TypeError``) and once with
    the current signature. Only the current-signature call is kept.
    """

    def test_bbox_at_image_edge_left(self):
        """Test bbox at left edge of image."""
        bbox = (0, 100, 50, 150)
        result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
        # Left edge should be clamped to 0
        assert result[0] >= 0

    def test_bbox_at_image_edge_right(self):
        """Test bbox at right edge of image."""
        bbox = (2400, 100, 2480, 150)
        result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
        # Right edge should be clamped to image width
        assert result[2] <= 2480

    def test_bbox_at_image_edge_top(self):
        """Test bbox at top edge of image."""
        bbox = (100, 0, 200, 50)
        result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
        # Top edge should be clamped to 0
        assert result[1] >= 0

    def test_bbox_at_image_edge_bottom(self):
        """Test bbox at bottom edge of image."""
        bbox = (100, 3400, 200, 3508)
        result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
        # Bottom edge should be clamped to image height
        assert result[3] <= 3508

    def test_very_small_bbox(self):
        """Test very small bbox gets expanded."""
        bbox = (100, 100, 105, 105)  # 5x5 pixel bbox
        result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
        # Should still produce a valid (non-degenerate) expanded bbox
        assert result[2] > result[0]
        assert result[3] > result[1]