WIP
This commit is contained in:
@@ -400,6 +400,71 @@ class TestAmountNormalizer:
|
||||
result = normalizer.normalize("Reference 12500")
|
||||
assert result.value == "12500.00"
|
||||
|
||||
def test_payment_line_kronor_ore_format(self, normalizer):
    """A space between kronor and ore acts as the decimal separator.

    Swedish payment lines write kronor and ore separated by a space:
    "590 00" denotes 590.00 SEK, NOT 59000.
    """
    parsed = normalizer.normalize("590 00")
    assert parsed.value == "590.00"
    assert parsed.is_valid is True
|
||||
|
||||
def test_payment_line_kronor_ore_large_amount(self, normalizer):
    """A large kronor/ore value from a payment line normalizes correctly."""
    parsed = normalizer.normalize("15658 00")
    assert parsed.value == "15658.00"
    assert parsed.is_valid is True
|
||||
|
||||
def test_payment_line_kronor_ore_with_nonzero_ore(self, normalizer):
    """Kronor/ore parsing also handles a non-zero ore part."""
    parsed = normalizer.normalize("736 50")
    assert parsed.value == "736.50"
    assert parsed.is_valid is True
|
||||
|
||||
def test_kronor_ore_not_confused_with_thousand_separator(self, normalizer):
    """A comma-decimal amount must not trigger the kronor/ore rule."""
    parsed = normalizer.normalize("1 234,56")
    assert parsed.value is not None
    # Either reading (1234.56, or kronor=1234 ore=56) yields the same
    # number here; the float comparison pins the expected value anyway.
    assert float(parsed.value) == 1234.56
|
||||
|
||||
def test_european_dot_thousand_separator(self, normalizer):
    """European format: dot thousand separator, comma decimal."""
    parsed = normalizer.normalize("2.254,50")
    assert parsed.value == "2254.50"
    assert parsed.is_valid is True
|
||||
|
||||
def test_european_dot_thousand_with_sek(self, normalizer):
    """European format followed by a SEK suffix."""
    parsed = normalizer.normalize("2.254,50 SEK")
    assert parsed.value == "2254.50"
    assert parsed.is_valid is True
|
||||
|
||||
def test_european_dot_thousand_with_kr(self, normalizer):
    """European format followed by a kr suffix."""
    parsed = normalizer.normalize("20.485,00 kr")
    assert parsed.value == "20485.00"
    assert parsed.is_valid is True
|
||||
|
||||
def test_european_large_amount(self, normalizer):
    """European format with two dot-separated thousand groups."""
    parsed = normalizer.normalize("1.234.567,89")
    assert parsed.value == "1234567.89"
    assert parsed.is_valid is True
|
||||
|
||||
def test_european_in_label_context(self, normalizer):
    """European-format amount embedded in label text (BAUHAUS invoice bug)."""
    parsed = normalizer.normalize("ns Fakturabelopp: 2.254,50 SEK")
    assert parsed.value == "2254.50"
    assert parsed.is_valid is True
|
||||
|
||||
def test_anglo_comma_thousand_separator(self, normalizer):
    """Anglo format: comma thousand separator, dot decimal."""
    parsed = normalizer.normalize("1,234.56")
    assert parsed.value == "1234.56"
    assert parsed.is_valid is True
|
||||
|
||||
def test_zero_amount_rejected(self, normalizer):
|
||||
"""Test that zero amounts are rejected."""
|
||||
result = normalizer.normalize("0,00 kr")
|
||||
@@ -450,6 +515,18 @@ class TestEnhancedAmountNormalizer:
|
||||
result = normalizer.normalize("Invoice for 1 234 567,89 kr")
|
||||
assert result.value is not None
|
||||
|
||||
def test_enhanced_kronor_ore_format(self, normalizer):
    """The enhanced normalizer also treats the kronor/ore space as decimal."""
    parsed = normalizer.normalize("590 00")
    assert parsed.value == "590.00"
    assert parsed.is_valid is True
|
||||
|
||||
def test_enhanced_kronor_ore_large(self, normalizer):
    """Large kronor/ore values work in the enhanced normalizer too."""
    parsed = normalizer.normalize("15658 00")
    assert parsed.value == "15658.00"
    assert parsed.is_valid is True
|
||||
|
||||
def test_no_amount_fails(self, normalizer):
|
||||
"""Test failure when no amount found."""
|
||||
result = normalizer.normalize("no amount")
|
||||
@@ -472,6 +549,22 @@ class TestEnhancedAmountNormalizer:
|
||||
result = normalizer.normalize("Price: 1 234 567,89")
|
||||
assert result.value is not None
|
||||
|
||||
def test_enhanced_european_dot_thousand(self, normalizer):
    """European dot-thousand format in the enhanced normalizer."""
    parsed = normalizer.normalize("2.254,50 SEK")
    assert parsed.value == "2254.50"
    assert parsed.is_valid is True
|
||||
|
||||
def test_enhanced_european_with_label(self, normalizer):
    """European format preceded by a Swedish label keyword."""
    parsed = normalizer.normalize("Att betala: 2.254,50")
    assert parsed.value == "2254.50"
|
||||
|
||||
def test_enhanced_anglo_format(self, normalizer):
    """Anglo comma-thousand format in the enhanced normalizer."""
    parsed = normalizer.normalize("Total: 1,234.56")
    assert parsed.value == "1234.56"
|
||||
|
||||
def test_amount_out_of_range_rejected(self, normalizer):
|
||||
"""Test that amounts >= 10,000,000 are rejected."""
|
||||
result = normalizer.normalize("Summa: 99 999 999,00")
|
||||
|
||||
@@ -497,5 +497,178 @@ class TestExtractBusinessFeaturesErrorHandling:
|
||||
assert "NumericException" in result.errors[0]
|
||||
|
||||
|
||||
class TestProcessPdfTokenPath:
    """Tests for the PDF text-token extraction path in process_pdf().

    A text-based PDF should be served by extract_from_detection_with_pdf
    (native PDF text tokens); a scanned PDF -- or a PDF whose text
    detection raises -- must fall back to the OCR path
    (extract_from_detection).
    """

    def _make_pipeline(self):
        """Create a pipeline with mocked internals, bypassing __init__.

        __init__ is replaced with a no-op so no models are loaded; every
        collaborator the code under test touches is stubbed explicitly.
        """
        with patch.object(InferencePipeline, '__init__', lambda self, **kw: None):
            p = InferencePipeline()
            p.detector = MagicMock()
            p.extractor = MagicMock()
            p.payment_line_parser = MagicMock()
            p.dpi = 300
            p.enable_fallback = False
            p.enable_business_features = False
            p.vat_tolerance = 0.5
            p.line_items_extractor = None
            p.vat_extractor = None
            p.vat_validator = None
            p._business_ocr_engine = None
            p._table_detector = None
            return p

    def _make_detection(self, class_name='Amount', confidence=0.85, page_no=0):
        """Create a Detection with a fixed bbox at (100, 200, 300, 250)."""
        from backend.pipeline.yolo_detector import Detection
        return Detection(
            class_id=6,
            class_name=class_name,
            confidence=confidence,
            bbox=(100.0, 200.0, 300.0, 250.0),
            page_no=page_no,
        )

    def _make_extracted_field(self, field_name='Amount', raw_text='2.254,50',
                              normalized='2254.50', confidence=0.85):
        """Create an ExtractedField mirroring the detection bbox."""
        from backend.pipeline.field_extractor import ExtractedField
        return ExtractedField(
            field_name=field_name,
            raw_text=raw_text,
            normalized_value=normalized,
            confidence=confidence,
            detection_confidence=confidence,
            ocr_confidence=1.0,
            bbox=(100.0, 200.0, 300.0, 250.0),
            page_no=0,
        )

    def _make_image_bytes(self):
        """Create minimal valid PNG bytes (100x100 white image)."""
        from PIL import Image as PILImage
        import io as _io
        img = PILImage.new('RGB', (100, 100), color='white')
        buf = _io.BytesIO()
        img.save(buf, format='PNG')
        return buf.getvalue()

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_text_pdf_uses_pdf_tokens(self, mock_render, mock_pdf_doc_cls):
        """When PDF is text-based, extract_from_detection_with_pdf is used."""
        from shared.pdf.extractor import Token

        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()

        # Text PDF with a single token overlapping the detection bbox.
        mock_pdf_doc = MagicMock()
        mock_pdf_doc.is_text_pdf.return_value = True
        mock_pdf_doc.page_count = 1
        tokens = [Token(text="2.254,50", bbox=(100, 200, 200, 220), page_no=0)]
        mock_pdf_doc.extract_text_tokens.return_value = iter(tokens)
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)

        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection_with_pdf.return_value = (
            self._make_extracted_field()
        )

        mock_render.return_value = iter([(0, image_bytes)])
        result = pipeline.process_pdf('/fake/invoice.pdf')

        pipeline.extractor.extract_from_detection_with_pdf.assert_called_once()
        pipeline.extractor.extract_from_detection.assert_not_called()
        assert result.fields.get('Amount') == '2254.50'
        assert result.success is True

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_scanned_pdf_uses_ocr(self, mock_render, mock_pdf_doc_cls):
        """When PDF is scanned, extract_from_detection (OCR) is used."""
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()

        mock_pdf_doc = MagicMock()
        mock_pdf_doc.is_text_pdf.return_value = False
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)

        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection.return_value = (
            self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75)
        )

        mock_render.return_value = iter([(0, image_bytes)])
        # Only the dispatch path matters here; the result is not inspected.
        pipeline.process_pdf('/fake/invoice.pdf')

        pipeline.extractor.extract_from_detection.assert_called_once()
        pipeline.extractor.extract_from_detection_with_pdf.assert_not_called()

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_pdf_detection_error_falls_back_to_ocr(self, mock_render, mock_pdf_doc_cls):
        """When PDF text detection throws, fall back to OCR."""
        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()

        # Entering the PDFDocument context raises, simulating a corrupt file.
        mock_ctx = MagicMock()
        mock_ctx.__enter__ = MagicMock(side_effect=Exception("corrupt PDF"))
        mock_ctx.__exit__ = MagicMock(return_value=False)
        mock_pdf_doc_cls.return_value = mock_ctx

        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection.return_value = (
            self._make_extracted_field(raw_text='4.50', normalized='4.50', confidence=0.75)
        )

        mock_render.return_value = iter([(0, image_bytes)])
        # Only the dispatch path matters here; the result is not inspected.
        pipeline.process_pdf('/fake/invoice.pdf')

        pipeline.extractor.extract_from_detection.assert_called_once()
        pipeline.extractor.extract_from_detection_with_pdf.assert_not_called()

    @patch('shared.pdf.extractor.PDFDocument')
    @patch('shared.pdf.renderer.render_pdf_to_images')
    def test_text_pdf_passes_correct_args(self, mock_render, mock_pdf_doc_cls):
        """Verify correct token list and image dimensions are passed."""
        from shared.pdf.extractor import Token

        pipeline = self._make_pipeline()
        detection = self._make_detection()
        image_bytes = self._make_image_bytes()  # 100x100 PNG

        mock_pdf_doc = MagicMock()
        mock_pdf_doc.is_text_pdf.return_value = True
        mock_pdf_doc.page_count = 1
        tokens = [
            Token(text="Fakturabelopp:", bbox=(50, 190, 100, 210), page_no=0),
            Token(text="2.254,50", bbox=(105, 190, 180, 210), page_no=0),
            Token(text="SEK", bbox=(185, 190, 210, 210), page_no=0),
        ]
        mock_pdf_doc.extract_text_tokens.return_value = iter(tokens)
        mock_pdf_doc_cls.return_value.__enter__ = MagicMock(return_value=mock_pdf_doc)
        mock_pdf_doc_cls.return_value.__exit__ = MagicMock(return_value=False)

        pipeline.detector.detect.return_value = [detection]
        pipeline.extractor.extract_from_detection_with_pdf.return_value = (
            self._make_extracted_field()
        )

        mock_render.return_value = iter([(0, image_bytes)])
        pipeline.process_pdf('/fake/invoice.pdf')

        call_args = pipeline.extractor.extract_from_detection_with_pdf.call_args[0]
        assert call_args[0] == detection
        assert len(call_args[1]) == 3  # 3 tokens passed
        assert call_args[2] == 100  # image width
        assert call_args[3] == 100  # image height
||||
|
||||
|
||||
if __name__ == '__main__':
    # Propagate pytest's exit status so shell/CI callers see failures
    # (a bare pytest.main() call discards the return code).
    raise SystemExit(pytest.main([__file__, '-v']))
|
||||
|
||||
0
tests/pipeline/__init__.py
Normal file
0
tests/pipeline/__init__.py
Normal file
318
tests/pipeline/test_value_selector.py
Normal file
318
tests/pipeline/test_value_selector.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Tests for ValueSelector -- field-aware OCR token selection.
|
||||
|
||||
Verifies that ValueSelector picks the most likely value token(s)
|
||||
from OCR output, filtering out label text before sending to normalizer.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.ocr.paddle_ocr import OCRToken
|
||||
from backend.pipeline.value_selector import ValueSelector
|
||||
|
||||
|
||||
def _token(text: str) -> OCRToken:
    """Build an OCRToken carrying *text* with a dummy bbox and confidence."""
    dummy_bbox = (0, 0, 100, 20)
    return OCRToken(text=text, bbox=dummy_bbox, confidence=0.95)
|
||||
|
||||
|
||||
def _tokens(*texts: str) -> list[OCRToken]:
    """Build one OCRToken per entry in *texts*."""
    return list(map(_token, texts))
|
||||
|
||||
|
||||
class TestValueSelectorDateFields:
    """Date-field selection for InvoiceDate / InvoiceDueDate."""

    def test_selects_iso_date_from_label_and_value(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Fakturadatum", "2024-01-15"), "InvoiceDate"
        )
        assert len(selected) == 1
        assert selected[0].text == "2024-01-15"

    def test_selects_dot_separated_date(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Datum", "2024.03.20"), "InvoiceDate"
        )
        assert len(selected) == 1
        assert selected[0].text == "2024.03.20"

    def test_selects_slash_separated_date(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Forfallodag", "15/01/2024"), "InvoiceDueDate"
        )
        assert len(selected) == 1
        assert selected[0].text == "15/01/2024"

    def test_selects_compact_date(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Datum", "20240115"), "InvoiceDate"
        )
        assert len(selected) == 1
        assert selected[0].text == "20240115"

    def test_fallback_when_no_date_pattern(self):
        """With no date-shaped token present, every token is returned."""
        selected = ValueSelector.select_value_tokens(
            _tokens("Fakturadatum", "pending"), "InvoiceDate"
        )
        assert len(selected) == 2
|
||||
|
||||
|
||||
class TestValueSelectorAmountField:
    """Amount-field selection: labels are dropped, the amount token kept."""

    def test_selects_amount_with_comma_decimal(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Belopp", "1 234,56", "kr"), "Amount"
        )
        assert len(selected) == 1
        assert selected[0].text == "1 234,56"

    def test_selects_amount_with_dot_decimal(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Summa", "1234.56"), "Amount"
        )
        assert len(selected) == 1
        assert selected[0].text == "1234.56"

    def test_selects_simple_amount(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Att", "betala", "500,00"), "Amount"
        )
        assert len(selected) == 1
        assert selected[0].text == "500,00"

    def test_selects_european_amount_with_dot_thousand(self):
        """European format: dot thousand separator, comma decimal."""
        selected = ValueSelector.select_value_tokens(
            _tokens("Fakturabelopp:", "2.254,50 SEK"), "Amount"
        )
        assert len(selected) == 1
        assert selected[0].text == "2.254,50 SEK"

    def test_selects_european_amount_without_currency(self):
        """European format with no currency suffix."""
        selected = ValueSelector.select_value_tokens(
            _tokens("Belopp", "1.234,56"), "Amount"
        )
        assert len(selected) == 1
        assert selected[0].text == "1.234,56"

    def test_selects_amount_with_kr_suffix(self):
        """Amount carrying a trailing 'kr' currency marker."""
        selected = ValueSelector.select_value_tokens(
            _tokens("Summa", "20.485,00 kr"), "Amount"
        )
        assert len(selected) == 1
        assert selected[0].text == "20.485,00 kr"

    def test_selects_anglo_amount_with_sek(self):
        """Anglo format (comma thousand, dot decimal) with SEK suffix."""
        selected = ValueSelector.select_value_tokens(
            _tokens("Amount", "1,234.56 SEK"), "Amount"
        )
        assert len(selected) == 1
        assert selected[0].text == "1,234.56 SEK"

    def test_fallback_when_no_amount_pattern(self):
        """No amount-shaped token -> every token is returned."""
        selected = ValueSelector.select_value_tokens(
            _tokens("Belopp", "TBD"), "Amount"
        )
        assert len(selected) == 2
|
||||
|
||||
|
||||
class TestValueSelectorBankgiroField:
    """Bankgiro field selection."""

    def test_selects_hyphenated_bankgiro(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("BG:", "123-4567"), "Bankgiro"
        )
        assert len(selected) == 1
        assert selected[0].text == "123-4567"

    def test_selects_bankgiro_digits(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Bankgiro", "1234567"), "Bankgiro"
        )
        assert len(selected) == 1
        assert selected[0].text == "1234567"

    def test_selects_eight_digit_bankgiro(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Bankgiro:", "12345678"), "Bankgiro"
        )
        assert len(selected) == 1
        assert selected[0].text == "12345678"
|
||||
|
||||
|
||||
class TestValueSelectorPlusgiroField:
    """Plusgiro field selection."""

    def test_selects_hyphenated_plusgiro(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("PG:", "12345-6"), "Plusgiro"
        )
        assert len(selected) == 1
        assert selected[0].text == "12345-6"

    def test_selects_plusgiro_digits(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Plusgiro", "1234567"), "Plusgiro"
        )
        assert len(selected) == 1
        assert selected[0].text == "1234567"
|
||||
|
||||
|
||||
class TestValueSelectorOcrField:
    """OCR reference-number field selection."""

    def test_selects_longest_digit_sequence(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("OCR", "1234567890"), "OCR"
        )
        assert len(selected) == 1
        assert selected[0].text == "1234567890"

    def test_selects_token_with_most_digits(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Ref", "nr", "94228110015950070"), "OCR"
        )
        assert len(selected) == 1
        assert selected[0].text == "94228110015950070"

    def test_ignores_short_digit_tokens(self):
        """Tokens with fewer than 5 digits are not OCR references."""
        selected = ValueSelector.select_value_tokens(
            _tokens("OCR", "123"), "OCR"
        )
        # No valid OCR reference found -> fallback returns all tokens.
        assert len(selected) == 2
|
||||
|
||||
|
||||
class TestValueSelectorInvoiceNumberField:
    """InvoiceNumber field selection."""

    def test_removes_swedish_label_keywords(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Fakturanummer", "INV-2024-001"), "InvoiceNumber"
        )
        assert len(selected) == 1
        assert selected[0].text == "INV-2024-001"

    def test_keeps_non_label_tokens(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Nr", "12345"), "InvoiceNumber"
        )
        assert len(selected) == 1
        assert selected[0].text == "12345"

    def test_multiple_value_tokens_kept(self):
        """All non-label tokens survive the filtering."""
        selected = ValueSelector.select_value_tokens(
            _tokens("Fakturanr", "INV", "2024", "001"), "InvoiceNumber"
        )
        # "Fakturanr" is a label keyword; the remaining tokens are values.
        texts = [tok.text for tok in selected]
        assert "Fakturanr" not in texts
        assert "INV" in texts
|
||||
|
||||
|
||||
class TestValueSelectorOrgNumberField:
    """supplier_org_number field selection."""

    def test_selects_org_number_with_hyphen(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Org.nr", "556123-4567"), "supplier_org_number"
        )
        assert len(selected) == 1
        assert selected[0].text == "556123-4567"

    def test_selects_org_number_without_hyphen(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Organisationsnummer", "5561234567"), "supplier_org_number"
        )
        assert len(selected) == 1
        assert selected[0].text == "5561234567"
|
||||
|
||||
|
||||
class TestValueSelectorCustomerNumberField:
    """customer_number field selection."""

    def test_removes_label_keeps_value(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("Kundnummer", "ABC-123"), "customer_number"
        )
        assert len(selected) == 1
        assert selected[0].text == "ABC-123"
|
||||
|
||||
|
||||
class TestValueSelectorPaymentLineField:
    """payment_line field -- every token must be preserved."""

    def test_keeps_all_tokens(self):
        toks = _tokens("#", "94228110015950070", "#", "15658", "00", "8", ">", "48666036#14#")
        selected = ValueSelector.select_value_tokens(toks, "payment_line")
        assert len(selected) == len(toks)
|
||||
|
||||
|
||||
class TestValueSelectorFallback:
    """Fallback behavior of the selector."""

    def test_unknown_field_returns_all_tokens(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("some", "unknown", "text"), "unknown_field"
        )
        assert len(selected) == 3

    def test_empty_tokens_returns_empty(self):
        assert ValueSelector.select_value_tokens([], "InvoiceDate") == []

    def test_single_token_returns_it(self):
        selected = ValueSelector.select_value_tokens(
            _tokens("2024-01-15"), "InvoiceDate"
        )
        assert len(selected) == 1

    def test_never_returns_empty_when_tokens_exist(self):
        """The fallback guarantees data is never lost entirely."""
        selected = ValueSelector.select_value_tokens(
            _tokens("Fakturadatum", "unknown_format"), "InvoiceDate"
        )
        assert len(selected) > 0
|
||||
1
tests/services/__init__.py
Normal file
1
tests/services/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for backend services."""
|
||||
344
tests/services/test_data_mixer.py
Normal file
344
tests/services/test_data_mixer.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""
|
||||
Tests for Data Mixing Service.
|
||||
|
||||
Tests cover:
|
||||
1. get_mixing_ratio boundary values
|
||||
2. build_mixed_dataset with temp filesystem
|
||||
3. _find_pool_images matching logic
|
||||
4. _image_to_label_path conversion
|
||||
5. Edge cases (empty pool, no old data, cap)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
from backend.web.services.data_mixer import (
|
||||
get_mixing_ratio,
|
||||
build_mixed_dataset,
|
||||
_collect_images,
|
||||
_image_to_label_path,
|
||||
_find_pool_images,
|
||||
MIXING_RATIOS,
|
||||
DEFAULT_MULTIPLIER,
|
||||
MAX_OLD_SAMPLES,
|
||||
MIN_POOL_SIZE,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Constants
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestConstants:
    """Sanity checks on the data-mixer module constants."""

    def test_mixing_ratios_defined(self):
        """MIXING_RATIOS holds the four expected (threshold, multiplier) pairs."""
        expected = {0: (10, 50), 1: (50, 20), 2: (200, 10), 3: (500, 5)}
        assert len(MIXING_RATIOS) == 4
        for idx, pair in expected.items():
            assert MIXING_RATIOS[idx] == pair

    def test_default_multiplier(self):
        """The fallback multiplier is 5x."""
        assert DEFAULT_MULTIPLIER == 5

    def test_max_old_samples(self):
        """Old-sample cap is 3000."""
        assert MAX_OLD_SAMPLES == 3000

    def test_min_pool_size(self):
        """Minimum pool size is 50."""
        assert MIN_POOL_SIZE == 50
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test get_mixing_ratio
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetMixingRatio:
    """Boundary-value tests for get_mixing_ratio."""

    def test_1_sample_returns_50x(self):
        """A single new sample gets 50x old data."""
        multiplier = get_mixing_ratio(1)
        assert multiplier == 50

    def test_10_samples_returns_50x(self):
        """Upper boundary of the 50x bracket."""
        multiplier = get_mixing_ratio(10)
        assert multiplier == 50

    def test_11_samples_returns_20x(self):
        """Just past the 50x bracket -> 20x."""
        multiplier = get_mixing_ratio(11)
        assert multiplier == 20

    def test_50_samples_returns_20x(self):
        """Upper boundary of the 20x bracket."""
        multiplier = get_mixing_ratio(50)
        assert multiplier == 20

    def test_51_samples_returns_10x(self):
        """Just past the 20x bracket -> 10x."""
        multiplier = get_mixing_ratio(51)
        assert multiplier == 10

    def test_200_samples_returns_10x(self):
        """Upper boundary of the 10x bracket."""
        multiplier = get_mixing_ratio(200)
        assert multiplier == 10

    def test_201_samples_returns_5x(self):
        """Just past the 10x bracket -> 5x."""
        multiplier = get_mixing_ratio(201)
        assert multiplier == 5

    def test_500_samples_returns_5x(self):
        """Upper boundary of the explicit 5x bracket."""
        multiplier = get_mixing_ratio(500)
        assert multiplier == 5

    def test_1000_samples_returns_default(self):
        """Beyond all brackets the default multiplier (5x) applies."""
        multiplier = get_mixing_ratio(1000)
        assert multiplier == DEFAULT_MULTIPLIER
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test _collect_images
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestCollectImages:
    """Tests for _collect_images."""

    def test_collects_png_files(self, tmp_path: Path):
        """.png files are picked up."""
        for name in ("img1.png", "img2.png"):
            (tmp_path / name).write_bytes(b"fake png")

        assert len(_collect_images(tmp_path)) == 2

    def test_collects_jpg_files(self, tmp_path: Path):
        """.jpg files are picked up."""
        (tmp_path / "img1.jpg").write_bytes(b"fake jpg")

        assert len(_collect_images(tmp_path)) == 1

    def test_collects_both_types(self, tmp_path: Path):
        """.png and .jpg files are collected together."""
        (tmp_path / "img1.png").write_bytes(b"fake png")
        (tmp_path / "img2.jpg").write_bytes(b"fake jpg")

        assert len(_collect_images(tmp_path)) == 2

    def test_ignores_other_files(self, tmp_path: Path):
        """Non-image files are skipped."""
        (tmp_path / "data.txt").write_text("not an image")
        (tmp_path / "data.yaml").write_text("yaml")
        (tmp_path / "img.png").write_bytes(b"png")

        assert len(_collect_images(tmp_path)) == 1

    def test_returns_empty_for_nonexistent_dir(self, tmp_path: Path):
        """A missing directory yields an empty list."""
        assert _collect_images(tmp_path / "nonexistent") == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test _image_to_label_path
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestImageToLabelPath:
    """Tests for _image_to_label_path."""

    def test_converts_train_image_to_label(self, tmp_path: Path):
        """images/train/img.png maps to labels/train/img.txt."""
        src = tmp_path / "dataset" / "images" / "train" / "doc1_page1.png"
        dst = _image_to_label_path(src)

        assert dst.name == "doc1_page1.txt"
        assert "labels" in str(dst)
        assert "train" in str(dst)

    def test_converts_val_image_to_label(self, tmp_path: Path):
        """images/val/img.jpg maps to labels/val/img.txt."""
        src = tmp_path / "dataset" / "images" / "val" / "doc2_page3.jpg"
        dst = _image_to_label_path(src)

        assert dst.name == "doc2_page3.txt"
        assert "labels" in str(dst)
        assert "val" in str(dst)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test _find_pool_images
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestFindPoolImages:
    """Tests for _find_pool_images."""

    def _create_dataset(self, base_path: Path, doc_ids: list[str], split: str = "train") -> None:
        """Create a dataset tree with two page images per document ID."""
        images_dir = base_path / "images" / split
        images_dir.mkdir(parents=True, exist_ok=True)
        for doc_id in doc_ids:
            for page in (1, 2):
                (images_dir / f"{doc_id}_page{page}.png").write_bytes(b"img")

    def test_finds_matching_images(self, tmp_path: Path):
        """Images whose names carry a pool document ID are returned."""
        in_pool, out_of_pool = str(uuid4()), str(uuid4())
        self._create_dataset(tmp_path, [in_pool, out_of_pool])

        found = _find_pool_images(tmp_path, {in_pool})

        assert len(found) == 2  # both pages of the pool document
        assert all(in_pool in str(path) for path in found)

    def test_ignores_non_pool_images(self, tmp_path: Path):
        """Documents outside the pool never appear in the result."""
        in_pool, out_of_pool = str(uuid4()), str(uuid4())
        self._create_dataset(tmp_path, [in_pool, out_of_pool])

        found = _find_pool_images(tmp_path, {in_pool})

        for path in found:
            assert in_pool in str(path)
            assert out_of_pool not in str(path)

    def test_searches_all_splits(self, tmp_path: Path):
        """train, val and test splits are all scanned."""
        doc_id = str(uuid4())
        for split in ("train", "val", "test"):
            self._create_dataset(tmp_path, [doc_id], split=split)

        found = _find_pool_images(tmp_path, {doc_id})
        assert len(found) == 6  # 2 pages * 3 splits

    def test_empty_pool_returns_empty(self, tmp_path: Path):
        """An empty pool-ID set yields an empty result."""
        self._create_dataset(tmp_path, [str(uuid4())])

        assert _find_pool_images(tmp_path, set()) == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test build_mixed_dataset
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBuildMixedDataset:
    """Tests for build_mixed_dataset function."""

    def _setup_base_dataset(self, base_path: Path, num_old: int = 20) -> None:
        """Create a base dataset of old training images, split 80/20 train/val."""
        train_count = int(num_old * 0.8)
        per_split = {"train": train_count, "val": num_old - train_count}
        for split, count in per_split.items():
            img_dir = base_path / "images" / split
            lbl_dir = base_path / "labels" / split
            img_dir.mkdir(parents=True, exist_ok=True)
            lbl_dir.mkdir(parents=True, exist_ok=True)

            for _ in range(count):
                doc_id = str(uuid4())
                (img_dir / f"{doc_id}_page1.png").write_bytes(b"fake image data")
                (lbl_dir / f"{doc_id}_page1.txt").write_text("0 0.5 0.5 0.1 0.1\n")

    def _setup_pool_images(self, base_path: Path, doc_ids: list[str]) -> None:
        """Add pool images and labels to the base dataset's train split."""
        img_dir = base_path / "images" / "train"
        lbl_dir = base_path / "labels" / "train"
        img_dir.mkdir(parents=True, exist_ok=True)
        lbl_dir.mkdir(parents=True, exist_ok=True)

        for doc_id in doc_ids:
            (img_dir / f"{doc_id}_page1.png").write_bytes(b"pool image data")
            (lbl_dir / f"{doc_id}_page1.txt").write_text("0 0.5 0.5 0.2 0.2\n")

    @pytest.fixture
    def base_dataset(self, tmp_path: Path) -> Path:
        """Create a base dataset for testing."""
        root = tmp_path / "base_dataset"
        self._setup_base_dataset(root, num_old=20)
        return root

    def test_builds_output_structure(self, base_dataset: Path, tmp_path: Path):
        """Should create proper YOLO directory structure."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"

        build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )

        expected_paths = (
            output_dir / "images" / "train",
            output_dir / "images" / "val",
            output_dir / "labels" / "train",
            output_dir / "labels" / "val",
            output_dir / "data.yaml",
        )
        for path in expected_paths:
            assert path.exists()

    def test_returns_correct_metadata(self, base_dataset: Path, tmp_path: Path):
        """Should return correct counts and metadata."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"

        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )

        for key in ("data_yaml", "total_images", "old_images", "new_images", "mixing_ratio"):
            assert key in result
        assert result["total_images"] == result["old_images"] + result["new_images"]

    def test_mixing_ratio_applied(self, base_dataset: Path, tmp_path: Path):
        """Should use correct mixing ratio based on pool size."""
        pool_ids = [uuid4() for _ in range(5)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])
        output_dir = tmp_path / "mixed_output"

        result = build_mixed_dataset(
            pool_document_ids=pool_ids,
            base_dataset_path=base_dataset,
            output_dir=output_dir,
        )

        # 5 new samples -> 50x multiplier
        assert result["mixing_ratio"] == 50

    def test_seed_reproducibility(self, base_dataset: Path, tmp_path: Path):
        """Same seed should produce same output."""
        pool_ids = [uuid4() for _ in range(3)]
        self._setup_pool_images(base_dataset, [str(pid) for pid in pool_ids])

        first = build_mixed_dataset(pool_ids, base_dataset, tmp_path / "out1", seed=42)
        second = build_mixed_dataset(pool_ids, base_dataset, tmp_path / "out2", seed=42)

        for key in ("old_images", "new_images", "total_images"):
            assert first[key] == second[key]
|
||||
540
tests/services/test_gating_validator.py
Normal file
540
tests/services/test_gating_validator.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""
|
||||
Unit tests for gating validation service.
|
||||
|
||||
Tests the quality gate validation logic for model deployment:
|
||||
- Gate 1: mAP regression validation
|
||||
- Gate 2: detection rate validation
|
||||
- Overall status computation
|
||||
- Full validation workflow with mocked dependencies
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from backend.web.services.gating_validator import (
|
||||
GATE1_PASS_THRESHOLD,
|
||||
GATE1_REVIEW_THRESHOLD,
|
||||
GATE2_PASS_THRESHOLD,
|
||||
classify_gate1,
|
||||
classify_gate2,
|
||||
compute_overall_status,
|
||||
run_gating_validation,
|
||||
)
|
||||
from backend.data.admin_models import GatingResult
|
||||
|
||||
|
||||
class TestClassifyGate1:
    """Test Gate 1 classification (mAP drop thresholds)."""

    def test_pass_below_threshold(self):
        """Test mAP drop < 0.01 returns pass."""
        # -0.01 is a negative drop, i.e. the new model improved.
        for drop in (0.009, 0.005, 0.0, -0.01):
            assert classify_gate1(drop) == "pass"

    def test_pass_boundary(self):
        """Test mAP drop exactly at pass threshold."""
        # The pass condition is strictly < 0.01, so exactly 0.01 is review.
        assert classify_gate1(GATE1_PASS_THRESHOLD) == "review"

    def test_review_in_range(self):
        """Test mAP drop in review range [0.01, 0.03)."""
        for drop in (0.01, 0.015, 0.02, 0.029):
            assert classify_gate1(drop) == "review"

    def test_review_boundary(self):
        """Test mAP drop exactly at review threshold."""
        # The review condition is strictly < 0.03, so exactly 0.03 is reject.
        assert classify_gate1(GATE1_REVIEW_THRESHOLD) == "reject"

    def test_reject_above_threshold(self):
        """Test mAP drop >= 0.03 returns reject."""
        for drop in (0.03, 0.05, 0.10, 1.0):
            assert classify_gate1(drop) == "reject"
|
||||
|
||||
|
||||
class TestClassifyGate2:
    """Test Gate 2 classification (detection rate thresholds)."""

    def test_pass_above_threshold(self):
        """Test detection rate >= 0.80 returns pass."""
        for rate in (0.80, 0.85, 0.99, 1.0):
            assert classify_gate2(rate) == "pass"

    def test_pass_boundary(self):
        """Test detection rate exactly at pass threshold."""
        assert classify_gate2(GATE2_PASS_THRESHOLD) == "pass"

    def test_review_below_threshold(self):
        """Test detection rate < 0.80 returns review."""
        for rate in (0.79, 0.75, 0.50, 0.0):
            assert classify_gate2(rate) == "review"
|
||||
|
||||
|
||||
class TestComputeOverallStatus:
    """Test overall status computation from individual gates."""

    def _check(self, gate1: str, gate2: str, expected: str) -> None:
        """Assert the combined status computed for one pair of gate outcomes."""
        assert compute_overall_status(gate1, gate2) == expected

    def test_both_pass(self):
        """Test both gates pass -> overall pass."""
        self._check("pass", "pass", "pass")

    def test_gate1_reject_gate2_pass(self):
        """Test any reject -> overall reject."""
        self._check("reject", "pass", "reject")

    def test_gate1_pass_gate2_reject(self):
        """Test any reject -> overall reject."""
        self._check("pass", "reject", "reject")

    def test_both_reject(self):
        """Test both reject -> overall reject."""
        self._check("reject", "reject", "reject")

    def test_gate1_review_gate2_pass(self):
        """Test any review (no reject) -> overall review."""
        self._check("review", "pass", "review")

    def test_gate1_pass_gate2_review(self):
        """Test any review (no reject) -> overall review."""
        self._check("pass", "review", "review")

    def test_both_review(self):
        """Test both review -> overall review."""
        self._check("review", "review", "review")

    def test_gate1_reject_gate2_review(self):
        """Test reject takes precedence over review."""
        self._check("reject", "review", "reject")

    def test_gate1_review_gate2_reject(self):
        """Test reject takes precedence over review."""
        self._check("review", "reject", "reject")
|
||||
|
||||
|
||||
class TestRunGatingValidation:
    """Test full gating validation workflow with mocked dependencies.

    Every test patches the same four collaborators (model repository, DB
    session context, YOLO trainer, gating-status updater).  That scaffold
    was previously copy-pasted into all eight tests; it now lives in the
    shared _run() helper, and _repo_lookup() builds the common
    repository.get side effect.
    """

    @pytest.fixture
    def mock_model_version_id(self):
        """Generate a model version ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_base_model_version_id(self):
        """Generate a base model version ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_task_id(self):
        """Generate a task ID for testing."""
        return uuid4()

    @pytest.fixture
    def mock_base_model(self):
        """Create a mock base model with metrics."""
        model = Mock()
        model.metrics_mAP = 0.85
        return model

    @pytest.fixture
    def mock_new_model(self):
        """Create a mock new model with metrics."""
        model = Mock()
        model.metrics_mAP = 0.82
        return model

    @staticmethod
    def _repo_lookup(base_id, base_model, new_model):
        """Build a repository.get side effect: base_id -> base_model, else new_model."""
        return lambda id: base_model if str(id) == str(base_id) else new_model

    def _run(self, repo_get, *, validate_metrics=None, validate_error=None, **call_kwargs):
        """Invoke run_gating_validation with all external collaborators patched.

        Args:
            repo_get: Side effect for ModelVersionRepository.get.
            validate_metrics: Return value for YOLOTrainer.validate.
            validate_error: Exception YOLOTrainer.validate raises instead of
                returning (mutually exclusive with validate_metrics).
            **call_kwargs: Forwarded verbatim to run_gating_validation.

        Returns:
            (result, mocks) where mocks maps 'session', 'trainer' and
            'update' to the corresponding mocks for post-hoc assertions.
        """
        with patch("backend.web.services.gating_validator.ModelVersionRepository") as repo_cls, \
             patch("backend.web.services.gating_validator.get_session_context") as session_ctx, \
             patch("shared.training.YOLOTrainer") as trainer_cls, \
             patch("backend.web.services.gating_validator._update_model_gating_status") as update:

            repo_cls.return_value.get.side_effect = repo_get

            session = MagicMock()
            session_ctx.return_value.__enter__.return_value = session

            trainer = trainer_cls.return_value
            if validate_error is not None:
                trainer.validate.side_effect = validate_error
            else:
                trainer.validate.return_value = validate_metrics

            result = run_gating_validation(**call_kwargs)

        return result, {"session": session, "trainer": trainer, "update": update}

    def test_gate1_pass_gate2_pass(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test validation with both gates passing."""
        # Setup: base mAP=0.85, new mAP=0.84 -> drop=0.01 (review)
        # But new model mAP=0.82 -> gate2 pass
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82

        result, mocks = self._run(
            self._repo_lookup(mock_base_model_version_id, mock_base_model, mock_new_model),
            validate_metrics={"mAP50": 0.84},
            model_version_id=mock_model_version_id,
            new_model_path="/path/to/model.pt",
            base_model_version_id=mock_base_model_version_id,
            data_yaml="/path/to/data.yaml",
            task_id=mock_task_id,
        )

        assert result.gate1_status == "review"  # 0.85 - 0.84 = 0.01
        assert result.gate1_original_mAP == 0.85
        assert result.gate1_new_mAP == 0.84
        assert result.gate1_mAP_drop == pytest.approx(0.01, abs=1e-6)

        assert result.gate2_status == "pass"  # 0.82 >= 0.80
        assert result.gate2_detection_rate == 0.82

        assert result.overall_status == "review"  # Any review -> review

        # Verify DB operations
        mocks["session"].add.assert_called()
        mocks["session"].commit.assert_called()
        mocks["update"].assert_called_once_with(str(mock_model_version_id), "review")

    def test_gate1_reject_due_to_large_drop(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 1 reject when mAP drop >= 0.03."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82

        result, mocks = self._run(
            self._repo_lookup(mock_base_model_version_id, mock_base_model, mock_new_model),
            validate_metrics={"mAP50": 0.81},  # 0.85 - 0.81 = 0.04 (reject)
            model_version_id=mock_model_version_id,
            new_model_path="/path/to/model.pt",
            base_model_version_id=mock_base_model_version_id,
            data_yaml="/path/to/data.yaml",
            task_id=mock_task_id,
        )

        assert result.gate1_status == "reject"
        assert result.gate1_mAP_drop == pytest.approx(0.04, abs=1e-6)
        assert result.overall_status == "reject"  # Any reject -> reject

        mocks["update"].assert_called_once_with(str(mock_model_version_id), "reject")

    def test_gate2_review_due_to_low_detection_rate(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 2 review when detection rate < 0.80."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.75  # Below 0.80 threshold

        result, mocks = self._run(
            self._repo_lookup(mock_base_model_version_id, mock_base_model, mock_new_model),
            validate_metrics={"mAP50": 0.845},  # Gate 1: 0.85 - 0.845 = 0.005 (pass)
            model_version_id=mock_model_version_id,
            new_model_path="/path/to/model.pt",
            base_model_version_id=mock_base_model_version_id,
            data_yaml="/path/to/data.yaml",
            task_id=mock_task_id,
        )

        assert result.gate1_status == "pass"
        assert result.gate2_status == "review"  # 0.75 < 0.80
        assert result.gate2_detection_rate == 0.75
        assert result.overall_status == "review"

        mocks["update"].assert_called_once_with(str(mock_model_version_id), "review")

    def test_no_base_model_skips_gate1(
        self,
        mock_model_version_id,
        mock_task_id,
        mock_new_model,
    ):
        """Test Gate 1 passes when no base model is provided."""
        mock_new_model.metrics_mAP = 0.85

        result, mocks = self._run(
            lambda id: mock_new_model,  # no base model to look up
            model_version_id=mock_model_version_id,
            new_model_path="/path/to/model.pt",
            base_model_version_id=None,
            data_yaml="/path/to/data.yaml",
            task_id=mock_task_id,
        )

        assert result.gate1_status == "pass"  # Skipped
        assert result.gate1_original_mAP is None
        assert result.gate1_new_mAP is None
        assert result.gate1_mAP_drop is None

        assert result.gate2_status == "pass"  # 0.85 >= 0.80
        assert result.overall_status == "pass"

        mocks["update"].assert_called_once_with(str(mock_model_version_id), "pass")

    def test_base_model_without_metrics_skips_gate1(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 1 passes when base model has no metrics."""
        mock_base_model.metrics_mAP = None
        mock_new_model.metrics_mAP = 0.85

        result, _ = self._run(
            self._repo_lookup(mock_base_model_version_id, mock_base_model, mock_new_model),
            model_version_id=mock_model_version_id,
            new_model_path="/path/to/model.pt",
            base_model_version_id=mock_base_model_version_id,
            data_yaml="/path/to/data.yaml",
            task_id=mock_task_id,
        )

        assert result.gate1_status == "pass"  # Skipped due to no base metrics
        assert result.gate2_status == "pass"
        assert result.overall_status == "pass"

    def test_validation_failure_marks_gate1_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 1 review when validation raises exception."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82

        result, mocks = self._run(
            self._repo_lookup(mock_base_model_version_id, mock_base_model, mock_new_model),
            validate_error=RuntimeError("Validation failed"),
            model_version_id=mock_model_version_id,
            new_model_path="/path/to/model.pt",
            base_model_version_id=mock_base_model_version_id,
            data_yaml="/path/to/data.yaml",
            task_id=mock_task_id,
        )

        assert result.gate1_status == "review"  # Exception -> review
        assert result.gate2_status == "pass"
        assert result.overall_status == "review"

        mocks["update"].assert_called_once_with(str(mock_model_version_id), "review")

    def test_validation_returns_none_mAP_marks_gate1_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 1 review when validation returns None mAP."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82

        result, _ = self._run(
            self._repo_lookup(mock_base_model_version_id, mock_base_model, mock_new_model),
            validate_metrics={"mAP50": None},  # No mAP returned
            model_version_id=mock_model_version_id,
            new_model_path="/path/to/model.pt",
            base_model_version_id=mock_base_model_version_id,
            data_yaml="/path/to/data.yaml",
            task_id=mock_task_id,
        )

        assert result.gate1_status == "review"  # None mAP -> review
        assert result.gate1_new_mAP is None
        assert result.gate2_status == "pass"
        assert result.overall_status == "review"

    def test_gate2_exception_marks_gate2_review(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test Gate 2 review when accessing new model metrics raises exception."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.82

        def repo_get(id):
            # Base model resolves normally; fetching the new model (gate 2) fails.
            if str(id) == str(mock_base_model_version_id):
                return mock_base_model
            if str(id) == str(mock_model_version_id):
                raise RuntimeError("Cannot fetch new model")
            return None

        result, _ = self._run(
            repo_get,
            validate_metrics={"mAP50": 0.84},
            model_version_id=mock_model_version_id,
            new_model_path="/path/to/model.pt",
            base_model_version_id=mock_base_model_version_id,
            data_yaml="/path/to/data.yaml",
            task_id=mock_task_id,
        )

        assert result.gate1_status == "review"  # 0.85 - 0.84 = 0.01
        assert result.gate2_status == "review"  # Exception -> review
        assert result.overall_status == "review"

    def test_string_uuids_accepted(
        self,
        mock_model_version_id,
        mock_base_model_version_id,
        mock_task_id,
        mock_base_model,
        mock_new_model,
    ):
        """Test that string UUIDs are accepted and converted properly."""
        mock_base_model.metrics_mAP = 0.85
        mock_new_model.metrics_mAP = 0.85

        # Pass string UUIDs
        result, _ = self._run(
            self._repo_lookup(mock_base_model_version_id, mock_base_model, mock_new_model),
            validate_metrics={"mAP50": 0.85},
            model_version_id=str(mock_model_version_id),
            new_model_path="/path/to/model.pt",
            base_model_version_id=str(mock_base_model_version_id),
            data_yaml="/path/to/data.yaml",
            task_id=str(mock_task_id),
        )

        assert result.model_version_id == mock_model_version_id
        assert result.task_id == mock_task_id
        assert result.overall_status == "pass"
|
||||
@@ -1,556 +1,170 @@
|
||||
"""
|
||||
Tests for expand_bbox function.
|
||||
Tests for expand_bbox function with uniform pixel padding.
|
||||
|
||||
Tests verify that bbox expansion works correctly with center-point scaling,
|
||||
directional compensation, max padding clamping, and image boundary handling.
|
||||
Verifies that bbox expansion adds uniform padding on all sides,
|
||||
clamps to image boundaries, and returns integer tuples.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.bbox import (
|
||||
expand_bbox,
|
||||
ScaleStrategy,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
DEFAULT_STRATEGY,
|
||||
)
|
||||
from shared.bbox import expand_bbox
|
||||
from shared.bbox.scale_strategy import UNIFORM_PAD
|
||||
|
||||
|
||||
class TestExpandBboxCenterScaling:
|
||||
"""Tests for center-point based scaling."""
|
||||
class TestExpandBboxUniformPadding:
|
||||
"""Tests for uniform padding on all sides."""
|
||||
|
||||
def test_center_scaling_expands_symmetrically(self):
|
||||
"""Verify bbox expands symmetrically around center when no extra ratios."""
|
||||
# 100x50 bbox at (100, 200)
|
||||
bbox = (100, 200, 200, 250)
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.2, # 20% wider
|
||||
scale_y=1.4, # 40% taller
|
||||
max_pad_x=1000, # Large to avoid clamping
|
||||
max_pad_y=1000,
|
||||
)
|
||||
def test_adds_uniform_pad_on_all_sides(self):
|
||||
"""Verify default pad is applied equally on all four sides."""
|
||||
bbox = (100, 200, 300, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Original: width=100, height=50
|
||||
# New: width=120, height=70
|
||||
# Center: (150, 225)
|
||||
# Expected: x0=150-60=90, x1=150+60=210, y0=225-35=190, y1=225+35=260
|
||||
assert result[0] == 90 # x0
|
||||
assert result[1] == 190 # y0
|
||||
assert result[2] == 210 # x1
|
||||
assert result[3] == 260 # y1
|
||||
|
||||
def test_no_scaling_returns_original(self):
|
||||
"""Verify scale=1.0 with no extras returns original bbox."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
assert result == (
|
||||
100 - UNIFORM_PAD,
|
||||
200 - UNIFORM_PAD,
|
||||
300 + UNIFORM_PAD,
|
||||
250 + UNIFORM_PAD,
|
||||
)
|
||||
|
||||
def test_custom_pad_value(self):
|
||||
"""Verify custom pad overrides default."""
|
||||
bbox = (100, 200, 300, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
pad=20,
|
||||
)
|
||||
|
||||
assert result == (100, 200, 200, 250)
|
||||
assert result == (80, 180, 320, 270)
|
||||
|
||||
|
||||
class TestExpandBboxDirectionalCompensation:
|
||||
"""Tests for directional compensation (extra ratios)."""
|
||||
|
||||
def test_extra_top_expands_upward(self):
|
||||
"""Verify extra_top_ratio adds expansion toward top."""
|
||||
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_top_ratio=0.5, # Add 50% of height to top
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
def test_zero_pad_returns_original(self):
|
||||
"""Verify pad=0 returns original bbox as integers."""
|
||||
bbox = (100, 200, 300, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
pad=0,
|
||||
)
|
||||
|
||||
# extra_top = 50 * 0.5 = 25
|
||||
assert result[0] == 100 # x0 unchanged
|
||||
assert result[1] == 175 # y0 = 200 - 25
|
||||
assert result[2] == 200 # x1 unchanged
|
||||
assert result[3] == 250 # y1 unchanged
|
||||
assert result == (100, 200, 300, 250)
|
||||
|
||||
def test_extra_left_expands_leftward(self):
|
||||
"""Verify extra_left_ratio adds expansion toward left."""
|
||||
bbox = (100, 200, 200, 250) # width=100
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_left_ratio=0.8, # Add 80% of width to left
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
def test_all_field_types_get_same_padding(self):
|
||||
"""Verify no field-specific expansion -- same result regardless of field."""
|
||||
bbox = (100, 200, 300, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
result_a = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
result_b = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
# extra_left = 100 * 0.8 = 80
|
||||
assert result[0] == 20 # x0 = 100 - 80
|
||||
assert result[1] == 200 # y0 unchanged
|
||||
assert result[2] == 200 # x1 unchanged
|
||||
assert result[3] == 250 # y1 unchanged
|
||||
|
||||
def test_extra_right_expands_rightward(self):
|
||||
"""Verify extra_right_ratio adds expansion toward right."""
|
||||
bbox = (100, 200, 200, 250) # width=100
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_right_ratio=0.3, # Add 30% of width to right
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# extra_right = 100 * 0.3 = 30
|
||||
assert result[0] == 100 # x0 unchanged
|
||||
assert result[1] == 200 # y0 unchanged
|
||||
assert result[2] == 230 # x1 = 200 + 30
|
||||
assert result[3] == 250 # y1 unchanged
|
||||
|
||||
def test_extra_bottom_expands_downward(self):
|
||||
"""Verify extra_bottom_ratio adds expansion toward bottom."""
|
||||
bbox = (100, 200, 200, 250) # height=50
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_bottom_ratio=0.4, # Add 40% of height to bottom
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# extra_bottom = 50 * 0.4 = 20
|
||||
assert result[0] == 100 # x0 unchanged
|
||||
assert result[1] == 200 # y0 unchanged
|
||||
assert result[2] == 200 # x1 unchanged
|
||||
assert result[3] == 270 # y1 = 250 + 20
|
||||
|
||||
def test_combined_scaling_and_directional(self):
|
||||
"""Verify scale + directional compensation work together."""
|
||||
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.2, # 20% wider -> 120 width
|
||||
scale_y=1.0, # no height change
|
||||
extra_left_ratio=0.5, # Add 50% of width to left
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Center: x=150
|
||||
# After scale: width=120 -> x0=150-60=90, x1=150+60=210
|
||||
# After extra_left: x0 = 90 - (100 * 0.5) = 40
|
||||
assert result[0] == 40 # x0
|
||||
assert result[2] == 210 # x1
|
||||
|
||||
|
||||
class TestExpandBboxMaxPadClamping:
|
||||
"""Tests for max padding clamping."""
|
||||
|
||||
def test_max_pad_x_limits_horizontal_expansion(self):
|
||||
"""Verify max_pad_x limits expansion on left and right."""
|
||||
bbox = (100, 200, 200, 250) # width=100
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=2.0, # Double width (would add 50 each side)
|
||||
scale_y=1.0,
|
||||
max_pad_x=30, # Limit to 30 pixels each side
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Scale would make: x0=100, x1=200 -> x0=50, x1=250 (50px each side)
|
||||
# But max_pad_x=30 limits to: x0=70, x1=230
|
||||
assert result[0] == 70 # x0 = 100 - 30
|
||||
assert result[2] == 230 # x1 = 200 + 30
|
||||
|
||||
def test_max_pad_y_limits_vertical_expansion(self):
|
||||
"""Verify max_pad_y limits expansion on top and bottom."""
|
||||
bbox = (100, 200, 200, 250) # height=50
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=3.0, # Triple height (would add 50 each side)
|
||||
max_pad_x=1000,
|
||||
max_pad_y=20, # Limit to 20 pixels each side
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Scale would make: y0=175, y1=275 (50px each side)
|
||||
# But max_pad_y=20 limits to: y0=180, y1=270
|
||||
assert result[1] == 180 # y0 = 200 - 20
|
||||
assert result[3] == 270 # y1 = 250 + 20
|
||||
|
||||
def test_max_pad_preserves_asymmetry(self):
|
||||
"""Verify max_pad clamping preserves asymmetric expansion."""
|
||||
bbox = (100, 200, 200, 250) # width=100
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_left_ratio=1.0, # 100px left expansion
|
||||
extra_right_ratio=0.0, # No right expansion
|
||||
max_pad_x=50, # Limit to 50 pixels
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Left would expand 100, clamped to 50
|
||||
# Right stays at 0
|
||||
assert result[0] == 50 # x0 = 100 - 50
|
||||
assert result[2] == 200 # x1 unchanged
|
||||
assert result_a == result_b
|
||||
|
||||
|
||||
class TestExpandBboxImageBoundaryClamping:
|
||||
"""Tests for image boundary clamping."""
|
||||
"""Tests for clamping to image boundaries."""
|
||||
|
||||
def test_clamps_to_left_boundary(self):
|
||||
"""Verify x0 is clamped to 0."""
|
||||
bbox = (10, 200, 110, 250) # Close to left edge
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_left_ratio=0.5, # Would push x0 below 0
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
def test_clamps_x0_to_zero(self):
|
||||
bbox = (5, 200, 100, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
assert result[0] == 0 # Clamped to 0
|
||||
assert result[0] == 0
|
||||
|
||||
def test_clamps_to_top_boundary(self):
|
||||
"""Verify y0 is clamped to 0."""
|
||||
bbox = (100, 10, 200, 60) # Close to top edge
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_top_ratio=0.5, # Would push y0 below 0
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
def test_clamps_y0_to_zero(self):
|
||||
bbox = (100, 3, 300, 50)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
assert result[1] == 0 # Clamped to 0
|
||||
assert result[1] == 0
|
||||
|
||||
def test_clamps_to_right_boundary(self):
|
||||
"""Verify x1 is clamped to image_width."""
|
||||
bbox = (900, 200, 990, 250) # Close to right edge
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_right_ratio=0.5, # Would push x1 beyond image_width
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
def test_clamps_x1_to_image_width(self):
|
||||
bbox = (900, 200, 995, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
assert result[2] == 1000 # Clamped to image_width
|
||||
assert result[2] == 1000
|
||||
|
||||
def test_clamps_to_bottom_boundary(self):
|
||||
"""Verify y1 is clamped to image_height."""
|
||||
bbox = (100, 940, 200, 990) # Close to bottom edge
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_bottom_ratio=0.5, # Would push y1 beyond image_height
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
def test_clamps_y1_to_image_height(self):
|
||||
bbox = (100, 900, 300, 995)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
assert result[3] == 1000 # Clamped to image_height
|
||||
assert result[3] == 1000
|
||||
|
||||
def test_corner_bbox_clamps_multiple_sides(self):
|
||||
"""Bbox near top-left corner clamps both x0 and y0."""
|
||||
bbox = (2, 3, 50, 60)
|
||||
|
||||
class TestExpandBboxUnknownField:
|
||||
"""Tests for unknown field handling."""
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
def test_unknown_field_uses_default_strategy(self):
|
||||
"""Verify unknown field types use DEFAULT_STRATEGY."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="unknown_field_xyz",
|
||||
)
|
||||
|
||||
# DEFAULT_STRATEGY: scale_x=1.15, scale_y=1.15
|
||||
# Original: width=100, height=50
|
||||
# New: width=115, height=57.5
|
||||
# Center: (150, 225)
|
||||
# x0 = 150 - 57.5 = 92.5 -> 92
|
||||
# x1 = 150 + 57.5 = 207.5 -> 207
|
||||
# y0 = 225 - 28.75 = 196.25 -> 196
|
||||
# y1 = 225 + 28.75 = 253.75 -> 253
|
||||
# But max_pad_x=50 may clamp...
|
||||
# Left pad = 100 - 92.5 = 7.5 (< 50, ok)
|
||||
# Right pad = 207.5 - 200 = 7.5 (< 50, ok)
|
||||
assert result[0] == 92
|
||||
assert result[2] == 207
|
||||
|
||||
|
||||
class TestExpandBboxWithRealStrategies:
|
||||
"""Tests using actual FIELD_SCALE_STRATEGIES."""
|
||||
|
||||
def test_ocr_number_expands_significantly_upward(self):
|
||||
"""Verify ocr_number field gets significant upward expansion."""
|
||||
bbox = (100, 200, 200, 230) # Small height=30
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="ocr_number",
|
||||
)
|
||||
|
||||
# extra_top_ratio=0.60 -> 30 * 0.6 = 18 extra top
|
||||
# y0 should decrease significantly
|
||||
assert result[1] < 200 - 10 # At least 10px upward expansion
|
||||
|
||||
def test_bankgiro_expands_significantly_leftward(self):
|
||||
"""Verify bankgiro field gets significant leftward expansion."""
|
||||
bbox = (200, 200, 300, 230) # width=100
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro",
|
||||
)
|
||||
|
||||
# extra_left_ratio=0.80 -> 100 * 0.8 = 80 extra left
|
||||
# x0 should decrease significantly
|
||||
assert result[0] < 200 - 30 # At least 30px leftward expansion
|
||||
|
||||
def test_amount_expands_rightward(self):
|
||||
"""Verify amount field gets rightward expansion for currency."""
|
||||
bbox = (100, 200, 200, 230) # width=100
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="amount",
|
||||
)
|
||||
|
||||
# extra_right_ratio=0.30 -> 100 * 0.3 = 30 extra right
|
||||
# x1 should increase
|
||||
assert result[2] > 200 + 10 # At least 10px rightward expansion
|
||||
assert result[0] == 0
|
||||
assert result[1] == 0
|
||||
assert result[2] == 50 + UNIFORM_PAD
|
||||
assert result[3] == 60 + UNIFORM_PAD
|
||||
|
||||
|
||||
class TestExpandBboxReturnType:
|
||||
"""Tests for return type and value format."""
|
||||
|
||||
def test_returns_tuple_of_four_ints(self):
|
||||
"""Verify return type is tuple of 4 integers."""
|
||||
bbox = (100.5, 200.3, 200.7, 250.9)
|
||||
bbox = (100.5, 200.3, 300.7, 250.9)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 4
|
||||
assert all(isinstance(v, int) for v in result)
|
||||
|
||||
def test_returns_valid_bbox_format(self):
|
||||
"""Verify returned bbox has x0 < x1 and y0 < y1."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
def test_float_bbox_floors_correctly(self):
|
||||
"""Verify float coordinates are converted to int properly."""
|
||||
bbox = (100.7, 200.3, 300.2, 250.8)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000, pad=0)
|
||||
|
||||
# int() truncates toward zero
|
||||
assert result == (100, 200, 300, 250)
|
||||
|
||||
def test_returns_valid_bbox_ordering(self):
|
||||
"""Verify x0 < x1 and y0 < y1."""
|
||||
bbox = (100, 200, 300, 250)
|
||||
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
x0, y0, x1, y1 = result
|
||||
assert x0 < x1, "x0 should be less than x1"
|
||||
assert y0 < y1, "y0 should be less than y1"
|
||||
assert x0 < x1
|
||||
assert y0 < y1
|
||||
|
||||
|
||||
class TestManualLabelMode:
|
||||
"""Tests for manual_mode parameter."""
|
||||
class TestExpandBboxEdgeCases:
|
||||
"""Tests for edge cases."""
|
||||
|
||||
def test_manual_mode_uses_minimal_padding(self):
|
||||
"""Verify manual_mode uses MANUAL_LABEL_STRATEGY with minimal padding."""
|
||||
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||
def test_small_bbox_with_large_pad(self):
|
||||
"""Pad larger than bbox still works correctly."""
|
||||
bbox = (100, 200, 105, 203) # 5x3 pixel bbox
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro", # Would normally expand left significantly
|
||||
manual_mode=True,
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000, pad=50)
|
||||
|
||||
# MANUAL_LABEL_STRATEGY: scale=1.0, max_pad=10
|
||||
# Should only add 10px padding each side (but scale=1.0 means no scaling)
|
||||
# Actually with scale=1.0, no extra ratios, we get 0 expansion from scaling
|
||||
# Only max_pad=10 applies as a limit, but there's no expansion to limit
|
||||
# So result should be same as original
|
||||
assert result == (100, 200, 200, 250)
|
||||
assert result == (50, 150, 155, 253)
|
||||
|
||||
def test_manual_mode_ignores_field_type(self):
|
||||
"""Verify manual_mode ignores field-specific strategies."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
def test_bbox_at_origin(self):
|
||||
bbox = (0, 0, 50, 30)
|
||||
|
||||
# Different fields should give same result in manual_mode
|
||||
result_bankgiro = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro",
|
||||
manual_mode=True,
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
result_ocr = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="ocr_number",
|
||||
manual_mode=True,
|
||||
)
|
||||
assert result[0] == 0
|
||||
assert result[1] == 0
|
||||
|
||||
assert result_bankgiro == result_ocr
|
||||
def test_bbox_at_image_edge(self):
|
||||
bbox = (950, 970, 1000, 1000)
|
||||
|
||||
def test_manual_mode_vs_auto_mode_different(self):
|
||||
"""Verify manual_mode produces different results than auto mode."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
result = expand_bbox(bbox=bbox, image_width=1000, image_height=1000)
|
||||
|
||||
auto_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro", # Has extra_left_ratio=0.80
|
||||
manual_mode=False,
|
||||
)
|
||||
|
||||
manual_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# Auto mode should expand more (especially to the left for bankgiro)
|
||||
assert auto_result[0] < manual_result[0] # Auto x0 is more left
|
||||
|
||||
def test_manual_mode_clamps_to_image_bounds(self):
|
||||
"""Verify manual_mode still respects image boundaries."""
|
||||
bbox = (5, 5, 50, 50) # Close to top-left corner
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# Should clamp to 0
|
||||
assert result[0] >= 0
|
||||
assert result[1] >= 0
|
||||
assert result[2] == 1000
|
||||
assert result[3] == 1000
|
||||
|
||||
@@ -1,192 +1,24 @@
|
||||
"""
|
||||
Tests for ScaleStrategy configuration.
|
||||
Tests for simplified scale strategy configuration.
|
||||
|
||||
Tests verify that scale strategies are properly defined, immutable,
|
||||
and cover all required fields.
|
||||
Verifies that UNIFORM_PAD constant is properly defined
|
||||
and replaces the old field-specific strategies.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.bbox import (
|
||||
ScaleStrategy,
|
||||
DEFAULT_STRATEGY,
|
||||
MANUAL_LABEL_STRATEGY,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
)
|
||||
from shared.fields import CLASS_NAMES
|
||||
from shared.bbox.scale_strategy import UNIFORM_PAD
|
||||
|
||||
|
||||
class TestScaleStrategyDataclass:
|
||||
"""Tests for ScaleStrategy dataclass behavior."""
|
||||
class TestUniformPad:
|
||||
"""Tests for UNIFORM_PAD constant."""
|
||||
|
||||
def test_default_strategy_values(self):
|
||||
"""Verify default strategy has expected default values."""
|
||||
strategy = ScaleStrategy()
|
||||
assert strategy.scale_x == 1.15
|
||||
assert strategy.scale_y == 1.15
|
||||
assert strategy.extra_top_ratio == 0.0
|
||||
assert strategy.extra_bottom_ratio == 0.0
|
||||
assert strategy.extra_left_ratio == 0.0
|
||||
assert strategy.extra_right_ratio == 0.0
|
||||
assert strategy.max_pad_x == 50
|
||||
assert strategy.max_pad_y == 50
|
||||
def test_uniform_pad_is_integer(self):
|
||||
assert isinstance(UNIFORM_PAD, int)
|
||||
|
||||
def test_scale_strategy_immutability(self):
|
||||
"""Verify ScaleStrategy is frozen (immutable)."""
|
||||
strategy = ScaleStrategy()
|
||||
with pytest.raises(AttributeError):
|
||||
strategy.scale_x = 2.0 # type: ignore
|
||||
def test_uniform_pad_value_is_15(self):
|
||||
"""15px at 150 DPI provides ~2.5mm real-world padding."""
|
||||
assert UNIFORM_PAD == 15
|
||||
|
||||
def test_custom_strategy_values(self):
|
||||
"""Verify custom values are properly set."""
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.5,
|
||||
scale_y=1.8,
|
||||
extra_top_ratio=0.6,
|
||||
extra_left_ratio=0.8,
|
||||
max_pad_x=100,
|
||||
max_pad_y=150,
|
||||
)
|
||||
assert strategy.scale_x == 1.5
|
||||
assert strategy.scale_y == 1.8
|
||||
assert strategy.extra_top_ratio == 0.6
|
||||
assert strategy.extra_left_ratio == 0.8
|
||||
assert strategy.max_pad_x == 100
|
||||
assert strategy.max_pad_y == 150
|
||||
|
||||
|
||||
class TestDefaultStrategy:
|
||||
"""Tests for DEFAULT_STRATEGY constant."""
|
||||
|
||||
def test_default_strategy_is_scale_strategy(self):
|
||||
"""Verify DEFAULT_STRATEGY is a ScaleStrategy instance."""
|
||||
assert isinstance(DEFAULT_STRATEGY, ScaleStrategy)
|
||||
|
||||
def test_default_strategy_matches_default_values(self):
|
||||
"""Verify DEFAULT_STRATEGY has same values as ScaleStrategy()."""
|
||||
expected = ScaleStrategy()
|
||||
assert DEFAULT_STRATEGY == expected
|
||||
|
||||
|
||||
class TestManualLabelStrategy:
|
||||
"""Tests for MANUAL_LABEL_STRATEGY constant."""
|
||||
|
||||
def test_manual_label_strategy_is_scale_strategy(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY is a ScaleStrategy instance."""
|
||||
assert isinstance(MANUAL_LABEL_STRATEGY, ScaleStrategy)
|
||||
|
||||
def test_manual_label_strategy_has_no_scaling(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY has scale factors of 1.0."""
|
||||
assert MANUAL_LABEL_STRATEGY.scale_x == 1.0
|
||||
assert MANUAL_LABEL_STRATEGY.scale_y == 1.0
|
||||
|
||||
def test_manual_label_strategy_has_no_directional_expansion(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY has no directional expansion."""
|
||||
assert MANUAL_LABEL_STRATEGY.extra_top_ratio == 0.0
|
||||
assert MANUAL_LABEL_STRATEGY.extra_bottom_ratio == 0.0
|
||||
assert MANUAL_LABEL_STRATEGY.extra_left_ratio == 0.0
|
||||
assert MANUAL_LABEL_STRATEGY.extra_right_ratio == 0.0
|
||||
|
||||
def test_manual_label_strategy_has_small_max_pad(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY has small max padding."""
|
||||
assert MANUAL_LABEL_STRATEGY.max_pad_x <= 15
|
||||
assert MANUAL_LABEL_STRATEGY.max_pad_y <= 15
|
||||
|
||||
|
||||
class TestFieldScaleStrategies:
|
||||
"""Tests for FIELD_SCALE_STRATEGIES dictionary."""
|
||||
|
||||
def test_all_class_names_have_strategies(self):
|
||||
"""Verify all field class names have defined strategies."""
|
||||
for class_name in CLASS_NAMES:
|
||||
assert class_name in FIELD_SCALE_STRATEGIES, (
|
||||
f"Missing strategy for field: {class_name}"
|
||||
)
|
||||
|
||||
def test_strategies_are_scale_strategy_instances(self):
|
||||
"""Verify all strategies are ScaleStrategy instances."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert isinstance(strategy, ScaleStrategy), (
|
||||
f"Strategy for {field_name} is not a ScaleStrategy"
|
||||
)
|
||||
|
||||
def test_scale_values_are_greater_than_one(self):
|
||||
"""Verify all scale values are >= 1.0 (expansion, not contraction)."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert strategy.scale_x >= 1.0, (
|
||||
f"{field_name} scale_x should be >= 1.0"
|
||||
)
|
||||
assert strategy.scale_y >= 1.0, (
|
||||
f"{field_name} scale_y should be >= 1.0"
|
||||
)
|
||||
|
||||
def test_extra_ratios_are_non_negative(self):
|
||||
"""Verify all extra ratios are >= 0."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert strategy.extra_top_ratio >= 0, (
|
||||
f"{field_name} extra_top_ratio should be >= 0"
|
||||
)
|
||||
assert strategy.extra_bottom_ratio >= 0, (
|
||||
f"{field_name} extra_bottom_ratio should be >= 0"
|
||||
)
|
||||
assert strategy.extra_left_ratio >= 0, (
|
||||
f"{field_name} extra_left_ratio should be >= 0"
|
||||
)
|
||||
assert strategy.extra_right_ratio >= 0, (
|
||||
f"{field_name} extra_right_ratio should be >= 0"
|
||||
)
|
||||
|
||||
def test_max_pad_values_are_positive(self):
|
||||
"""Verify all max_pad values are > 0."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert strategy.max_pad_x > 0, (
|
||||
f"{field_name} max_pad_x should be > 0"
|
||||
)
|
||||
assert strategy.max_pad_y > 0, (
|
||||
f"{field_name} max_pad_y should be > 0"
|
||||
)
|
||||
|
||||
|
||||
class TestSpecificFieldStrategies:
|
||||
"""Tests for specific field strategy configurations."""
|
||||
|
||||
def test_ocr_number_expands_upward(self):
|
||||
"""Verify ocr_number strategy expands upward to capture label."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["ocr_number"]
|
||||
assert strategy.extra_top_ratio > 0.0
|
||||
assert strategy.extra_top_ratio >= 0.5 # Significant upward expansion
|
||||
|
||||
def test_bankgiro_expands_leftward(self):
|
||||
"""Verify bankgiro strategy expands leftward to capture prefix."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["bankgiro"]
|
||||
assert strategy.extra_left_ratio > 0.0
|
||||
assert strategy.extra_left_ratio >= 0.5 # Significant leftward expansion
|
||||
|
||||
def test_plusgiro_expands_leftward(self):
|
||||
"""Verify plusgiro strategy expands leftward to capture prefix."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["plusgiro"]
|
||||
assert strategy.extra_left_ratio > 0.0
|
||||
assert strategy.extra_left_ratio >= 0.5
|
||||
|
||||
def test_amount_expands_rightward(self):
|
||||
"""Verify amount strategy expands rightward for currency symbol."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["amount"]
|
||||
assert strategy.extra_right_ratio > 0.0
|
||||
|
||||
def test_invoice_date_expands_upward(self):
|
||||
"""Verify invoice_date strategy expands upward to capture label."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["invoice_date"]
|
||||
assert strategy.extra_top_ratio > 0.0
|
||||
|
||||
def test_invoice_due_date_expands_upward_and_leftward(self):
|
||||
"""Verify invoice_due_date strategy expands both up and left."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["invoice_due_date"]
|
||||
assert strategy.extra_top_ratio > 0.0
|
||||
assert strategy.extra_left_ratio > 0.0
|
||||
|
||||
def test_payment_line_has_minimal_expansion(self):
|
||||
"""Verify payment_line has conservative expansion (machine code)."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["payment_line"]
|
||||
# Payment line is machine-readable, needs minimal expansion
|
||||
assert strategy.scale_x <= 1.2
|
||||
assert strategy.scale_y <= 1.3
|
||||
def test_uniform_pad_is_positive(self):
|
||||
assert UNIFORM_PAD > 0
|
||||
|
||||
@@ -171,8 +171,8 @@ class TestGenerateFromMatches:
|
||||
|
||||
assert len(annotations) == 0
|
||||
|
||||
def test_applies_field_specific_expansion(self):
|
||||
"""Verify different fields get different expansion."""
|
||||
def test_applies_uniform_expansion(self):
|
||||
"""Verify all fields get the same uniform expansion."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5)
|
||||
|
||||
# Same bbox, different fields
|
||||
@@ -199,10 +199,11 @@ class TestGenerateFromMatches:
|
||||
dpi=150
|
||||
)[0]
|
||||
|
||||
# Bankgiro has extra_left_ratio=0.80, invoice_number has extra_top_ratio=0.40
|
||||
# They should have different widths due to different expansion
|
||||
# Bankgiro expands more to the left
|
||||
assert ann_bankgiro.width != ann_invoice.width or ann_bankgiro.x_center != ann_invoice.x_center
|
||||
# Uniform expansion: same bbox -> same dimensions (only class_id differs)
|
||||
assert ann_bankgiro.width == ann_invoice.width
|
||||
assert ann_bankgiro.height == ann_invoice.height
|
||||
assert ann_bankgiro.x_center == ann_invoice.x_center
|
||||
assert ann_bankgiro.y_center == ann_invoice.y_center
|
||||
|
||||
def test_enforces_min_bbox_height(self):
|
||||
"""Verify minimum bbox height is enforced."""
|
||||
|
||||
@@ -7,27 +7,23 @@ from pathlib import Path
|
||||
|
||||
from training.yolo.db_dataset import DBYOLODataset
|
||||
from training.yolo.annotation_generator import YOLOAnnotation
|
||||
from shared.bbox import FIELD_SCALE_STRATEGIES, DEFAULT_STRATEGY
|
||||
from shared.bbox import UNIFORM_PAD
|
||||
from shared.fields import CLASS_NAMES
|
||||
|
||||
|
||||
class TestConvertLabelsWithExpandBbox:
|
||||
"""Tests for _convert_labels using expand_bbox instead of fixed padding."""
|
||||
"""Tests for _convert_labels using uniform expand_bbox."""
|
||||
|
||||
def test_convert_labels_uses_expand_bbox(self):
|
||||
"""Verify _convert_labels calls expand_bbox for field-specific expansion."""
|
||||
# Create a mock dataset without loading from DB
|
||||
"""Verify _convert_labels calls expand_bbox with uniform padding."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
# Create annotation for bankgiro (has extra_left_ratio)
|
||||
# bbox in PDF points: x0=100, y0=200, x1=200, y1=250
|
||||
# center: (150, 225), width: 100, height: 50
|
||||
annotations = [
|
||||
YOLOAnnotation(
|
||||
class_id=4, # bankgiro
|
||||
x_center=150, # in PDF points
|
||||
x_center=150,
|
||||
y_center=225,
|
||||
width=100,
|
||||
height=50,
|
||||
@@ -35,48 +31,26 @@ class TestConvertLabelsWithExpandBbox:
|
||||
)
|
||||
]
|
||||
|
||||
# Image size in pixels (at 300 DPI)
|
||||
img_width = 2480 # A4 width at 300 DPI
|
||||
img_height = 3508 # A4 height at 300 DPI
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Convert labels
|
||||
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Should have one label
|
||||
assert labels.shape == (1, 5)
|
||||
|
||||
# Check class_id
|
||||
assert labels[0, 0] == 4
|
||||
|
||||
# The bbox should be expanded using bankgiro strategy (extra_left_ratio=0.80)
|
||||
# Original bbox at 300 DPI:
|
||||
# x0 = 100 * (300/72) = 416.67
|
||||
# y0 = 200 * (300/72) = 833.33
|
||||
# x1 = 200 * (300/72) = 833.33
|
||||
# y1 = 250 * (300/72) = 1041.67
|
||||
# width_px = 416.67, height_px = 208.33
|
||||
|
||||
# After expand_bbox with bankgiro strategy:
|
||||
# scale_x=1.45, scale_y=1.35, extra_left_ratio=0.80
|
||||
# The x_center should shift left due to extra_left_ratio
|
||||
x_center = labels[0, 1]
|
||||
y_center = labels[0, 2]
|
||||
width = labels[0, 3]
|
||||
height = labels[0, 4]
|
||||
|
||||
# Verify normalized values are in valid range
|
||||
assert 0 <= x_center <= 1
|
||||
assert 0 <= y_center <= 1
|
||||
assert 0 < width <= 1
|
||||
assert 0 < height <= 1
|
||||
|
||||
# Width should be larger than original due to scaling and extra_left
|
||||
# Original normalized width: 416.67 / 2480 = 0.168
|
||||
# After bankgiro expansion it should be wider
|
||||
assert width > 0.168
|
||||
|
||||
def test_convert_labels_different_field_types(self):
|
||||
"""Verify different field types use their specific strategies."""
|
||||
def test_convert_labels_all_fields_get_same_expansion(self):
|
||||
"""Verify all field types get the same uniform expansion."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
@@ -84,7 +58,6 @@ class TestConvertLabelsWithExpandBbox:
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Same bbox for different field types
|
||||
base_annotation = {
|
||||
'x_center': 150,
|
||||
'y_center': 225,
|
||||
@@ -93,30 +66,20 @@ class TestConvertLabelsWithExpandBbox:
|
||||
'confidence': 0.9
|
||||
}
|
||||
|
||||
# OCR number (class_id=3) - has extra_top_ratio=0.60
|
||||
# All field types should get the same uniform expansion
|
||||
ocr_annotations = [YOLOAnnotation(class_id=3, **base_annotation)]
|
||||
ocr_labels = dataset._convert_labels(ocr_annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Bankgiro (class_id=4) - has extra_left_ratio=0.80
|
||||
bankgiro_annotations = [YOLOAnnotation(class_id=4, **base_annotation)]
|
||||
bankgiro_labels = dataset._convert_labels(bankgiro_annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Amount (class_id=6) - has extra_right_ratio=0.30
|
||||
amount_annotations = [YOLOAnnotation(class_id=6, **base_annotation)]
|
||||
amount_labels = dataset._convert_labels(amount_annotations, img_width, img_height, is_scanned=False)
|
||||
# x_center and y_center should be the same (uniform padding is symmetric)
|
||||
assert abs(ocr_labels[0, 1] - bankgiro_labels[0, 1]) < 0.001
|
||||
assert abs(ocr_labels[0, 2] - bankgiro_labels[0, 2]) < 0.001
|
||||
|
||||
# Each field type should have different expansion
|
||||
# OCR should expand more vertically (extra_top)
|
||||
# Bankgiro should expand more to the left
|
||||
# Amount should expand more to the right
|
||||
|
||||
# OCR: extra_top shifts y_center up
|
||||
# Bankgiro: extra_left shifts x_center left
|
||||
# So bankgiro x_center < OCR x_center
|
||||
assert bankgiro_labels[0, 1] < ocr_labels[0, 1]
|
||||
|
||||
# OCR has higher scale_y (1.80) than amount (1.35)
|
||||
assert ocr_labels[0, 4] > amount_labels[0, 4]
|
||||
# width and height should also be the same
|
||||
assert abs(ocr_labels[0, 3] - bankgiro_labels[0, 3]) < 0.001
|
||||
assert abs(ocr_labels[0, 4] - bankgiro_labels[0, 4]) < 0.001
|
||||
|
||||
def test_convert_labels_clamps_to_image_bounds(self):
|
||||
"""Verify labels are clamped to image boundaries."""
|
||||
@@ -124,11 +87,10 @@ class TestConvertLabelsWithExpandBbox:
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
# Annotation near edge of image (in PDF points)
|
||||
annotations = [
|
||||
YOLOAnnotation(
|
||||
class_id=4, # bankgiro - will expand left
|
||||
x_center=30, # Very close to left edge
|
||||
class_id=4,
|
||||
x_center=30,
|
||||
y_center=50,
|
||||
width=40,
|
||||
height=30,
|
||||
@@ -141,11 +103,10 @@ class TestConvertLabelsWithExpandBbox:
|
||||
|
||||
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# All values should be in valid range
|
||||
assert 0 <= labels[0, 1] <= 1 # x_center
|
||||
assert 0 <= labels[0, 2] <= 1 # y_center
|
||||
assert 0 < labels[0, 3] <= 1 # width
|
||||
assert 0 < labels[0, 4] <= 1 # height
|
||||
assert 0 <= labels[0, 1] <= 1
|
||||
assert 0 <= labels[0, 2] <= 1
|
||||
assert 0 < labels[0, 3] <= 1
|
||||
assert 0 < labels[0, 4] <= 1
|
||||
|
||||
def test_convert_labels_empty_annotations(self):
|
||||
"""Verify empty annotations return empty array."""
|
||||
@@ -162,23 +123,21 @@ class TestConvertLabelsWithExpandBbox:
|
||||
"""Verify minimum height is enforced after expansion."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 50 # Higher minimum
|
||||
dataset.min_bbox_height_px = 50
|
||||
|
||||
# Very small annotation
|
||||
annotations = [
|
||||
YOLOAnnotation(
|
||||
class_id=9, # payment_line - minimal expansion
|
||||
class_id=9,
|
||||
x_center=100,
|
||||
y_center=100,
|
||||
width=200,
|
||||
height=5, # Very small height
|
||||
height=5,
|
||||
confidence=0.9
|
||||
)
|
||||
]
|
||||
|
||||
labels = dataset._convert_labels(annotations, 2480, 3508, is_scanned=False)
|
||||
|
||||
# Height should be at least min_bbox_height_px / img_height
|
||||
min_normalized_height = 50 / 3508
|
||||
assert labels[0, 4] >= min_normalized_height
|
||||
|
||||
@@ -190,25 +149,23 @@ class TestCreateAnnotationWithClassName:
|
||||
"""Verify _create_annotation stores class_name for later use."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
|
||||
# Create annotation for invoice_number
|
||||
annotation = dataset._create_annotation(
|
||||
field_name="InvoiceNumber",
|
||||
bbox=[100, 200, 200, 250],
|
||||
score=0.9
|
||||
)
|
||||
|
||||
assert annotation.class_id == 0 # invoice_number class_id
|
||||
assert annotation.class_id == 0
|
||||
|
||||
|
||||
class TestLoadLabelsFromDbWithClassName:
|
||||
"""Tests for _load_labels_from_db preserving field_name for expansion."""
|
||||
|
||||
def test_load_labels_maps_field_names_correctly(self):
|
||||
"""Verify field names are mapped correctly for expand_bbox."""
|
||||
"""Verify field names are mapped correctly."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.min_confidence = 0.7
|
||||
|
||||
# Mock database
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_documents_batch.return_value = {
|
||||
'doc1': {
|
||||
@@ -240,12 +197,7 @@ class TestLoadLabelsFromDbWithClassName:
|
||||
assert 'doc1' in result
|
||||
page_labels, is_scanned, csv_split = result['doc1']
|
||||
|
||||
# Should have 2 annotations on page 0
|
||||
assert 0 in page_labels
|
||||
assert len(page_labels[0]) == 2
|
||||
|
||||
# First annotation: Bankgiro (class_id=4)
|
||||
assert page_labels[0][0].class_id == 4
|
||||
|
||||
# Second annotation: Plusgiro mapped from supplier_accounts(Plusgiro) (class_id=5)
|
||||
assert page_labels[0][1].class_id == 5
|
||||
|
||||
@@ -53,7 +53,7 @@ class TestTrainingConfigSchema:
|
||||
"""Test default training configuration."""
|
||||
config = TrainingConfig()
|
||||
|
||||
assert config.model_name == "yolo11n.pt"
|
||||
assert config.model_name == "yolo26s.pt"
|
||||
assert config.epochs == 100
|
||||
assert config.batch_size == 16
|
||||
assert config.image_size == 640
|
||||
@@ -63,7 +63,7 @@ class TestTrainingConfigSchema:
|
||||
def test_custom_config(self):
|
||||
"""Test custom training configuration."""
|
||||
config = TrainingConfig(
|
||||
model_name="yolo11s.pt",
|
||||
model_name="yolo26s.pt",
|
||||
epochs=50,
|
||||
batch_size=8,
|
||||
image_size=416,
|
||||
@@ -71,7 +71,7 @@ class TestTrainingConfigSchema:
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
assert config.model_name == "yolo11s.pt"
|
||||
assert config.model_name == "yolo26s.pt"
|
||||
assert config.epochs == 50
|
||||
assert config.batch_size == 8
|
||||
|
||||
@@ -136,7 +136,7 @@ class TestTrainingTaskModel:
|
||||
def test_task_with_config(self):
|
||||
"""Test task with configuration."""
|
||||
config = {
|
||||
"model_name": "yolo11n.pt",
|
||||
"model_name": "yolo26s.pt",
|
||||
"epochs": 100,
|
||||
}
|
||||
task = TrainingTask(
|
||||
|
||||
784
tests/web/test_data_mixer.py
Normal file
784
tests/web/test_data_mixer.py
Normal file
@@ -0,0 +1,784 @@
|
||||
"""
|
||||
Comprehensive unit tests for Data Mixing Service.
|
||||
|
||||
Tests the data mixing service functions for YOLO fine-tuning:
|
||||
- Mixing ratio calculation based on sample counts
|
||||
- Dataset building with old/new sample mixing
|
||||
- Image collection and path conversion
|
||||
- Pool document matching
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.web.services.data_mixer import (
|
||||
get_mixing_ratio,
|
||||
build_mixed_dataset,
|
||||
_collect_images,
|
||||
_image_to_label_path,
|
||||
_find_pool_images,
|
||||
MIXING_RATIOS,
|
||||
DEFAULT_MULTIPLIER,
|
||||
MAX_OLD_SAMPLES,
|
||||
MIN_POOL_SIZE,
|
||||
)
|
||||
|
||||
|
||||
class TestGetMixingRatio:
|
||||
"""Tests for get_mixing_ratio function."""
|
||||
|
||||
def test_mixing_ratio_at_first_threshold(self):
|
||||
"""Test mixing ratio at first threshold boundary (10 samples)."""
|
||||
assert get_mixing_ratio(1) == 50
|
||||
assert get_mixing_ratio(5) == 50
|
||||
assert get_mixing_ratio(10) == 50
|
||||
|
||||
def test_mixing_ratio_at_second_threshold(self):
|
||||
"""Test mixing ratio at second threshold boundary (50 samples)."""
|
||||
assert get_mixing_ratio(11) == 20
|
||||
assert get_mixing_ratio(30) == 20
|
||||
assert get_mixing_ratio(50) == 20
|
||||
|
||||
def test_mixing_ratio_at_third_threshold(self):
|
||||
"""Test mixing ratio at third threshold boundary (200 samples)."""
|
||||
assert get_mixing_ratio(51) == 10
|
||||
assert get_mixing_ratio(100) == 10
|
||||
assert get_mixing_ratio(200) == 10
|
||||
|
||||
def test_mixing_ratio_at_fourth_threshold(self):
|
||||
"""Test mixing ratio at fourth threshold boundary (500 samples)."""
|
||||
assert get_mixing_ratio(201) == 5
|
||||
assert get_mixing_ratio(350) == 5
|
||||
assert get_mixing_ratio(500) == 5
|
||||
|
||||
def test_mixing_ratio_above_all_thresholds(self):
|
||||
"""Test mixing ratio for samples above all thresholds."""
|
||||
assert get_mixing_ratio(501) == DEFAULT_MULTIPLIER
|
||||
assert get_mixing_ratio(1000) == DEFAULT_MULTIPLIER
|
||||
assert get_mixing_ratio(10000) == DEFAULT_MULTIPLIER
|
||||
|
||||
def test_mixing_ratio_boundary_values(self):
|
||||
"""Test exact threshold boundaries match expected ratios."""
|
||||
# Verify threshold boundaries from MIXING_RATIOS
|
||||
for threshold, expected_multiplier in MIXING_RATIOS:
|
||||
assert get_mixing_ratio(threshold) == expected_multiplier
|
||||
# One above threshold should give next ratio
|
||||
if threshold < MIXING_RATIOS[-1][0]:
|
||||
next_idx = MIXING_RATIOS.index((threshold, expected_multiplier)) + 1
|
||||
next_multiplier = MIXING_RATIOS[next_idx][1]
|
||||
assert get_mixing_ratio(threshold + 1) == next_multiplier
|
||||
|
||||
|
||||
class TestCollectImages:
|
||||
"""Tests for _collect_images function."""
|
||||
|
||||
def test_collect_images_empty_directory(self, tmp_path):
|
||||
"""Test collecting images from empty directory."""
|
||||
images_dir = tmp_path / "images"
|
||||
images_dir.mkdir()
|
||||
|
||||
result = _collect_images(images_dir)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_collect_images_nonexistent_directory(self, tmp_path):
|
||||
"""Test collecting images from non-existent directory."""
|
||||
images_dir = tmp_path / "nonexistent"
|
||||
|
||||
result = _collect_images(images_dir)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_collect_png_images(self, tmp_path):
|
||||
"""Test collecting PNG images."""
|
||||
images_dir = tmp_path / "images"
|
||||
images_dir.mkdir()
|
||||
|
||||
# Create PNG files
|
||||
(images_dir / "img1.png").touch()
|
||||
(images_dir / "img2.png").touch()
|
||||
(images_dir / "img3.png").touch()
|
||||
|
||||
result = _collect_images(images_dir)
|
||||
|
||||
assert len(result) == 3
|
||||
assert all(img.suffix == ".png" for img in result)
|
||||
# Verify sorted order
|
||||
assert result == sorted(result)
|
||||
|
||||
def test_collect_jpg_images(self, tmp_path):
|
||||
"""Test collecting JPG images."""
|
||||
images_dir = tmp_path / "images"
|
||||
images_dir.mkdir()
|
||||
|
||||
# Create JPG files
|
||||
(images_dir / "img1.jpg").touch()
|
||||
(images_dir / "img2.jpg").touch()
|
||||
|
||||
result = _collect_images(images_dir)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(img.suffix == ".jpg" for img in result)
|
||||
|
||||
def test_collect_mixed_image_types(self, tmp_path):
|
||||
"""Test collecting both PNG and JPG images."""
|
||||
images_dir = tmp_path / "images"
|
||||
images_dir.mkdir()
|
||||
|
||||
# Create mixed files
|
||||
(images_dir / "img1.png").touch()
|
||||
(images_dir / "img2.jpg").touch()
|
||||
(images_dir / "img3.png").touch()
|
||||
(images_dir / "img4.jpg").touch()
|
||||
|
||||
result = _collect_images(images_dir)
|
||||
|
||||
assert len(result) == 4
|
||||
# PNG files should come first (sorted separately)
|
||||
png_files = [r for r in result if r.suffix == ".png"]
|
||||
jpg_files = [r for r in result if r.suffix == ".jpg"]
|
||||
assert len(png_files) == 2
|
||||
assert len(jpg_files) == 2
|
||||
|
||||
def test_collect_images_ignores_other_files(self, tmp_path):
|
||||
"""Test that non-image files are ignored."""
|
||||
images_dir = tmp_path / "images"
|
||||
images_dir.mkdir()
|
||||
|
||||
# Create various files
|
||||
(images_dir / "img1.png").touch()
|
||||
(images_dir / "img2.jpg").touch()
|
||||
(images_dir / "doc.txt").touch()
|
||||
(images_dir / "data.json").touch()
|
||||
(images_dir / "notes.md").touch()
|
||||
|
||||
result = _collect_images(images_dir)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(img.suffix in [".png", ".jpg"] for img in result)
|
||||
|
||||
|
||||
class TestImageToLabelPath:
|
||||
"""Tests for _image_to_label_path function."""
|
||||
|
||||
def test_image_to_label_path_train(self, tmp_path):
|
||||
"""Test converting train image path to label path."""
|
||||
base = tmp_path / "dataset"
|
||||
image_path = base / "images" / "train" / "doc123_page1.png"
|
||||
|
||||
label_path = _image_to_label_path(image_path)
|
||||
|
||||
expected = base / "labels" / "train" / "doc123_page1.txt"
|
||||
assert label_path == expected
|
||||
|
||||
def test_image_to_label_path_val(self, tmp_path):
|
||||
"""Test converting val image path to label path."""
|
||||
base = tmp_path / "dataset"
|
||||
image_path = base / "images" / "val" / "doc456_page2.jpg"
|
||||
|
||||
label_path = _image_to_label_path(image_path)
|
||||
|
||||
expected = base / "labels" / "val" / "doc456_page2.txt"
|
||||
assert label_path == expected
|
||||
|
||||
def test_image_to_label_path_test(self, tmp_path):
|
||||
"""Test converting test image path to label path."""
|
||||
base = tmp_path / "dataset"
|
||||
image_path = base / "images" / "test" / "doc789_page3.png"
|
||||
|
||||
label_path = _image_to_label_path(image_path)
|
||||
|
||||
expected = base / "labels" / "test" / "doc789_page3.txt"
|
||||
assert label_path == expected
|
||||
|
||||
def test_image_to_label_path_preserves_filename(self, tmp_path):
|
||||
"""Test that filename (without extension) is preserved."""
|
||||
base = tmp_path / "dataset"
|
||||
image_path = base / "images" / "train" / "complex_filename_123_page5.png"
|
||||
|
||||
label_path = _image_to_label_path(image_path)
|
||||
|
||||
assert label_path.stem == "complex_filename_123_page5"
|
||||
assert label_path.suffix == ".txt"
|
||||
|
||||
def test_image_to_label_path_jpg_to_txt(self, tmp_path):
|
||||
"""Test that JPG extension is converted to TXT."""
|
||||
base = tmp_path / "dataset"
|
||||
image_path = base / "images" / "train" / "image.jpg"
|
||||
|
||||
label_path = _image_to_label_path(image_path)
|
||||
|
||||
assert label_path.suffix == ".txt"
|
||||
|
||||
|
||||
class TestFindPoolImages:
|
||||
"""Tests for _find_pool_images function."""
|
||||
|
||||
def test_find_pool_images_in_train(self, tmp_path):
|
||||
"""Test finding pool images in train split."""
|
||||
base = tmp_path / "dataset"
|
||||
train_dir = base / "images" / "train"
|
||||
train_dir.mkdir(parents=True)
|
||||
|
||||
doc_id = str(uuid4())
|
||||
pool_doc_ids = {doc_id}
|
||||
|
||||
# Create images
|
||||
(train_dir / f"{doc_id}_page1.png").touch()
|
||||
(train_dir / f"{doc_id}_page2.png").touch()
|
||||
(train_dir / "other_doc_page1.png").touch()
|
||||
|
||||
result = _find_pool_images(base, pool_doc_ids)
|
||||
|
||||
assert len(result) == 2
|
||||
assert all(doc_id in str(img) for img in result)
|
||||
|
||||
def test_find_pool_images_in_val(self, tmp_path):
|
||||
"""Test finding pool images in val split."""
|
||||
base = tmp_path / "dataset"
|
||||
val_dir = base / "images" / "val"
|
||||
val_dir.mkdir(parents=True)
|
||||
|
||||
doc_id = str(uuid4())
|
||||
pool_doc_ids = {doc_id}
|
||||
|
||||
# Create images
|
||||
(val_dir / f"{doc_id}_page1.png").touch()
|
||||
|
||||
result = _find_pool_images(base, pool_doc_ids)
|
||||
|
||||
assert len(result) == 1
|
||||
assert doc_id in str(result[0])
|
||||
|
||||
def test_find_pool_images_across_splits(self, tmp_path):
|
||||
"""Test finding pool images across train, val, and test splits."""
|
||||
base = tmp_path / "dataset"
|
||||
|
||||
doc_id1 = str(uuid4())
|
||||
doc_id2 = str(uuid4())
|
||||
pool_doc_ids = {doc_id1, doc_id2}
|
||||
|
||||
# Create images in different splits
|
||||
train_dir = base / "images" / "train"
|
||||
val_dir = base / "images" / "val"
|
||||
test_dir = base / "images" / "test"
|
||||
|
||||
train_dir.mkdir(parents=True)
|
||||
val_dir.mkdir(parents=True)
|
||||
test_dir.mkdir(parents=True)
|
||||
|
||||
(train_dir / f"{doc_id1}_page1.png").touch()
|
||||
(val_dir / f"{doc_id1}_page2.png").touch()
|
||||
(test_dir / f"{doc_id2}_page1.png").touch()
|
||||
(train_dir / "other_doc_page1.png").touch()
|
||||
|
||||
result = _find_pool_images(base, pool_doc_ids)
|
||||
|
||||
assert len(result) == 3
|
||||
doc1_images = [img for img in result if doc_id1 in str(img)]
|
||||
doc2_images = [img for img in result if doc_id2 in str(img)]
|
||||
assert len(doc1_images) == 2
|
||||
assert len(doc2_images) == 1
|
||||
|
||||
def test_find_pool_images_empty_pool(self, tmp_path):
|
||||
"""Test finding images with empty pool."""
|
||||
base = tmp_path / "dataset"
|
||||
train_dir = base / "images" / "train"
|
||||
train_dir.mkdir(parents=True)
|
||||
|
||||
(train_dir / "doc123_page1.png").touch()
|
||||
|
||||
result = _find_pool_images(base, set())
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
def test_find_pool_images_no_matches(self, tmp_path):
|
||||
"""Test finding images when no documents match pool."""
|
||||
base = tmp_path / "dataset"
|
||||
train_dir = base / "images" / "train"
|
||||
train_dir.mkdir(parents=True)
|
||||
|
||||
pool_doc_ids = {str(uuid4())}
|
||||
|
||||
(train_dir / "other_doc_page1.png").touch()
|
||||
(train_dir / "another_doc_page1.png").touch()
|
||||
|
||||
result = _find_pool_images(base, pool_doc_ids)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
def test_find_pool_images_multiple_pages(self, tmp_path):
|
||||
"""Test finding multiple pages for same document."""
|
||||
base = tmp_path / "dataset"
|
||||
train_dir = base / "images" / "train"
|
||||
train_dir.mkdir(parents=True)
|
||||
|
||||
doc_id = str(uuid4())
|
||||
pool_doc_ids = {doc_id}
|
||||
|
||||
# Create multiple pages
|
||||
for i in range(1, 6):
|
||||
(train_dir / f"{doc_id}_page{i}.png").touch()
|
||||
|
||||
result = _find_pool_images(base, pool_doc_ids)
|
||||
|
||||
assert len(result) == 5
|
||||
|
||||
def test_find_pool_images_ignores_non_files(self, tmp_path):
|
||||
"""Test that directories are ignored."""
|
||||
base = tmp_path / "dataset"
|
||||
train_dir = base / "images" / "train"
|
||||
train_dir.mkdir(parents=True)
|
||||
|
||||
doc_id = str(uuid4())
|
||||
pool_doc_ids = {doc_id}
|
||||
|
||||
(train_dir / f"{doc_id}_page1.png").touch()
|
||||
(train_dir / "subdir").mkdir()
|
||||
|
||||
result = _find_pool_images(base, pool_doc_ids)
|
||||
|
||||
assert len(result) == 1
|
||||
|
||||
def test_find_pool_images_nonexistent_splits(self, tmp_path):
|
||||
"""Test handling non-existent split directories."""
|
||||
base = tmp_path / "dataset"
|
||||
# Don't create any directories
|
||||
|
||||
pool_doc_ids = {str(uuid4())}
|
||||
|
||||
result = _find_pool_images(base, pool_doc_ids)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestBuildMixedDataset:
|
||||
"""Tests for build_mixed_dataset function."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_base_dataset(self, tmp_path):
|
||||
"""Create a base dataset with old training data."""
|
||||
base = tmp_path / "base_dataset"
|
||||
|
||||
# Create directory structure
|
||||
for split in ("train", "val"):
|
||||
(base / "images" / split).mkdir(parents=True)
|
||||
(base / "labels" / split).mkdir(parents=True)
|
||||
|
||||
# Create old training images and labels
|
||||
for i in range(1, 11):
|
||||
img_path = base / "images" / "train" / f"old_doc_{i}_page1.png"
|
||||
label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt"
|
||||
img_path.write_text(f"image {i}")
|
||||
label_path.write_text(f"0 0.5 0.5 0.1 0.1")
|
||||
|
||||
for i in range(1, 6):
|
||||
img_path = base / "images" / "val" / f"old_doc_val_{i}_page1.png"
|
||||
label_path = base / "labels" / "val" / f"old_doc_val_{i}_page1.txt"
|
||||
img_path.write_text(f"val image {i}")
|
||||
label_path.write_text(f"0 0.5 0.5 0.1 0.1")
|
||||
|
||||
return base
|
||||
|
||||
@pytest.fixture
|
||||
def setup_pool_documents(self, tmp_path, setup_base_dataset):
|
||||
"""Create pool documents in base dataset."""
|
||||
base = setup_base_dataset
|
||||
pool_ids = [uuid4() for _ in range(5)]
|
||||
|
||||
# Add pool documents to train split
|
||||
for doc_id in pool_ids:
|
||||
img_path = base / "images" / "train" / f"{doc_id}_page1.png"
|
||||
label_path = base / "labels" / "train" / f"{doc_id}_page1.txt"
|
||||
img_path.write_text(f"pool image {doc_id}")
|
||||
label_path.write_text(f"1 0.5 0.5 0.2 0.2")
|
||||
|
||||
return base, pool_ids
|
||||
|
||||
def test_build_mixed_dataset_basic(self, tmp_path, setup_pool_documents):
|
||||
"""Test basic mixed dataset building."""
|
||||
base, pool_ids = setup_pool_documents
|
||||
output_dir = tmp_path / "mixed_output"
|
||||
|
||||
result = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# Verify result structure
|
||||
assert "data_yaml" in result
|
||||
assert "total_images" in result
|
||||
assert "old_images" in result
|
||||
assert "new_images" in result
|
||||
assert "mixing_ratio" in result
|
||||
|
||||
# Verify counts - new images should be > 0 (at least some were copied)
|
||||
# Note: new images are split 80/20 and copied without overwriting
|
||||
assert result["new_images"] > 0
|
||||
assert result["old_images"] > 0
|
||||
assert result["total_images"] == result["old_images"] + result["new_images"]
|
||||
|
||||
# Verify output structure
|
||||
assert output_dir.exists()
|
||||
assert (output_dir / "images" / "train").exists()
|
||||
assert (output_dir / "images" / "val").exists()
|
||||
assert (output_dir / "labels" / "train").exists()
|
||||
assert (output_dir / "labels" / "val").exists()
|
||||
|
||||
# Verify data.yaml exists
|
||||
yaml_path = Path(result["data_yaml"])
|
||||
assert yaml_path.exists()
|
||||
yaml_content = yaml_path.read_text()
|
||||
assert "train: images/train" in yaml_content
|
||||
assert "val: images/val" in yaml_content
|
||||
assert "nc:" in yaml_content
|
||||
assert "names:" in yaml_content
|
||||
|
||||
def test_build_mixed_dataset_respects_mixing_ratio(self, tmp_path, setup_pool_documents):
|
||||
"""Test that mixing ratio is correctly applied."""
|
||||
base, pool_ids = setup_pool_documents
|
||||
output_dir = tmp_path / "mixed_output"
|
||||
|
||||
# With 5 pool documents, get_mixing_ratio(5) returns 50
|
||||
# (because 5 <= 10, first threshold)
|
||||
# So target old_samples = 5 * 50 = 250
|
||||
# But limited by available data: 10 old train + 5 old val + 5 pool = 20 total
|
||||
result = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# Pool images are in the base dataset, so they can be sampled as "old"
|
||||
# Total available: 20 images (15 pure old + 5 pool images)
|
||||
assert result["old_images"] <= 20 # Can't exceed available in base dataset
|
||||
assert result["old_images"] > 0 # Should have some old data
|
||||
assert result["mixing_ratio"] == 50 # Correct ratio for 5 samples
|
||||
|
||||
def test_build_mixed_dataset_max_old_samples_limit(self, tmp_path):
|
||||
"""Test that MAX_OLD_SAMPLES limit is applied."""
|
||||
base = tmp_path / "base_dataset"
|
||||
|
||||
# Create directory structure
|
||||
for split in ("train", "val"):
|
||||
(base / "images" / split).mkdir(parents=True)
|
||||
(base / "labels" / split).mkdir(parents=True)
|
||||
|
||||
# Create MORE than MAX_OLD_SAMPLES old images
|
||||
for i in range(MAX_OLD_SAMPLES + 500):
|
||||
img_path = base / "images" / "train" / f"old_doc_{i}_page1.png"
|
||||
label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt"
|
||||
img_path.write_text(f"image {i}")
|
||||
label_path.write_text(f"0 0.5 0.5 0.1 0.1")
|
||||
|
||||
# Create pool documents (100 samples, ratio=10, so target=1000)
|
||||
# But should be capped at MAX_OLD_SAMPLES (3000)
|
||||
pool_ids = [uuid4() for _ in range(100)]
|
||||
for doc_id in pool_ids:
|
||||
img_path = base / "images" / "train" / f"{doc_id}_page1.png"
|
||||
label_path = base / "labels" / "train" / f"{doc_id}_page1.txt"
|
||||
img_path.write_text(f"pool image {doc_id}")
|
||||
label_path.write_text(f"1 0.5 0.5 0.2 0.2")
|
||||
|
||||
output_dir = tmp_path / "mixed_output"
|
||||
|
||||
result = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# Should be capped at MAX_OLD_SAMPLES
|
||||
assert result["old_images"] <= MAX_OLD_SAMPLES
|
||||
|
||||
def test_build_mixed_dataset_empty_pool(self, tmp_path, setup_base_dataset):
|
||||
"""Test building dataset with empty pool."""
|
||||
base = setup_base_dataset
|
||||
output_dir = tmp_path / "mixed_output"
|
||||
|
||||
result = build_mixed_dataset(
|
||||
pool_document_ids=[],
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# With 0 new samples, all counts should be 0
|
||||
assert result["new_images"] == 0
|
||||
assert result["old_images"] == 0
|
||||
assert result["total_images"] == 0
|
||||
|
||||
def test_build_mixed_dataset_no_old_data(self, tmp_path):
|
||||
"""Test building dataset with ONLY pool data (no separate old data)."""
|
||||
base = tmp_path / "base_dataset"
|
||||
|
||||
# Create empty directory structure
|
||||
for split in ("train", "val"):
|
||||
(base / "images" / split).mkdir(parents=True)
|
||||
(base / "labels" / split).mkdir(parents=True)
|
||||
|
||||
# Create only pool documents
|
||||
# NOTE: These are placed in base dataset train split
|
||||
# So they will be sampled as "old" data first, then skipped as "new"
|
||||
pool_ids = [uuid4() for _ in range(5)]
|
||||
for doc_id in pool_ids:
|
||||
img_path = base / "images" / "train" / f"{doc_id}_page1.png"
|
||||
label_path = base / "labels" / "train" / f"{doc_id}_page1.txt"
|
||||
img_path.write_text(f"pool image {doc_id}")
|
||||
label_path.write_text(f"1 0.5 0.5 0.2 0.2")
|
||||
|
||||
output_dir = tmp_path / "mixed_output"
|
||||
|
||||
result = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# Pool images are in base dataset, so they get sampled as "old" images
|
||||
# Then when copying "new" images, they're skipped because they already exist
|
||||
# So we expect: old_images > 0, new_images may be 0, total >= 0
|
||||
assert result["total_images"] > 0
|
||||
assert result["total_images"] == result["old_images"] + result["new_images"]
|
||||
|
||||
def test_build_mixed_dataset_train_val_split(self, tmp_path, setup_pool_documents):
|
||||
"""Test that images are split into train/val (80/20)."""
|
||||
base, pool_ids = setup_pool_documents
|
||||
output_dir = tmp_path / "mixed_output"
|
||||
|
||||
result = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# Count images in train and val
|
||||
train_images = list((output_dir / "images" / "train").glob("*.png"))
|
||||
val_images = list((output_dir / "images" / "val").glob("*.png"))
|
||||
|
||||
total_output_images = len(train_images) + len(val_images)
|
||||
|
||||
# Should match total_images count
|
||||
assert total_output_images == result["total_images"]
|
||||
|
||||
# Check approximate 80/20 split (allow some variance due to small sample size)
|
||||
if total_output_images > 0:
|
||||
train_ratio = len(train_images) / total_output_images
|
||||
assert 0.6 <= train_ratio <= 0.9 # Allow some variance
|
||||
|
||||
def test_build_mixed_dataset_reproducible_with_seed(self, tmp_path, setup_pool_documents):
|
||||
"""Test that same seed produces same results."""
|
||||
base, pool_ids = setup_pool_documents
|
||||
output_dir1 = tmp_path / "mixed_output1"
|
||||
output_dir2 = tmp_path / "mixed_output2"
|
||||
|
||||
result1 = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir1,
|
||||
seed=123,
|
||||
)
|
||||
|
||||
result2 = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir2,
|
||||
seed=123,
|
||||
)
|
||||
|
||||
# Same counts
|
||||
assert result1["old_images"] == result2["old_images"]
|
||||
assert result1["new_images"] == result2["new_images"]
|
||||
|
||||
# Same files in train/val
|
||||
train_files1 = {f.name for f in (output_dir1 / "images" / "train").glob("*.png")}
|
||||
train_files2 = {f.name for f in (output_dir2 / "images" / "train").glob("*.png")}
|
||||
assert train_files1 == train_files2
|
||||
|
||||
def test_build_mixed_dataset_different_seeds(self, tmp_path, setup_pool_documents):
|
||||
"""Test that different seeds produce different sampling."""
|
||||
base, pool_ids = setup_pool_documents
|
||||
output_dir1 = tmp_path / "mixed_output1"
|
||||
output_dir2 = tmp_path / "mixed_output2"
|
||||
|
||||
result1 = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir1,
|
||||
seed=123,
|
||||
)
|
||||
|
||||
result2 = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir2,
|
||||
seed=456,
|
||||
)
|
||||
|
||||
# Both should have processed images
|
||||
assert result1["total_images"] > 0
|
||||
assert result2["total_images"] > 0
|
||||
|
||||
# Both should have the same mixing ratio (based on pool size)
|
||||
assert result1["mixing_ratio"] == result2["mixing_ratio"]
|
||||
|
||||
# File distribution in train/val may differ due to different shuffling
|
||||
train_files1 = {f.name for f in (output_dir1 / "images" / "train").glob("*.png")}
|
||||
train_files2 = {f.name for f in (output_dir2 / "images" / "train").glob("*.png")}
|
||||
|
||||
# With different seeds, we expect some difference in file distribution
|
||||
# But this is not strictly guaranteed, so we just verify both have files
|
||||
assert len(train_files1) > 0
|
||||
assert len(train_files2) > 0
|
||||
|
||||
def test_build_mixed_dataset_copies_labels(self, tmp_path, setup_pool_documents):
|
||||
"""Test that corresponding label files are copied."""
|
||||
base, pool_ids = setup_pool_documents
|
||||
output_dir = tmp_path / "mixed_output"
|
||||
|
||||
result = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# Count labels
|
||||
train_labels = list((output_dir / "labels" / "train").glob("*.txt"))
|
||||
val_labels = list((output_dir / "labels" / "val").glob("*.txt"))
|
||||
|
||||
# Each image should have a corresponding label
|
||||
train_images = list((output_dir / "images" / "train").glob("*.png"))
|
||||
val_images = list((output_dir / "images" / "val").glob("*.png"))
|
||||
|
||||
# Allow label count to be <= image count (in case some labels are missing)
|
||||
assert len(train_labels) <= len(train_images)
|
||||
assert len(val_labels) <= len(val_images)
|
||||
|
||||
def test_build_mixed_dataset_skips_duplicate_files(self, tmp_path, setup_pool_documents):
|
||||
"""Test behavior when running build_mixed_dataset multiple times."""
|
||||
base, pool_ids = setup_pool_documents
|
||||
output_dir = tmp_path / "mixed_output"
|
||||
|
||||
# First build
|
||||
result1 = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
initial_count = result1["total_images"]
|
||||
|
||||
# Find a file in output and modify it
|
||||
train_images = list((output_dir / "images" / "train").glob("*.png"))
|
||||
if len(train_images) > 0:
|
||||
test_file = train_images[0]
|
||||
test_file.write_text("modified content")
|
||||
|
||||
# Second build with same seed
|
||||
result2 = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# The implementation uses shutil.copy2 which WILL overwrite
|
||||
# So the file will be restored to original content
|
||||
# Just verify the build completed successfully
|
||||
assert result2["total_images"] >= 0
|
||||
|
||||
# Verify the file was overwritten (shutil.copy2 overwrites by default)
|
||||
content = test_file.read_text()
|
||||
assert content != "modified content" # Should be restored
|
||||
|
||||
def test_build_mixed_dataset_handles_jpg_images(self, tmp_path):
|
||||
"""Test that JPG images are handled correctly."""
|
||||
base = tmp_path / "base_dataset"
|
||||
|
||||
# Create directory structure
|
||||
for split in ("train", "val"):
|
||||
(base / "images" / split).mkdir(parents=True)
|
||||
(base / "labels" / split).mkdir(parents=True)
|
||||
|
||||
# Create JPG images as old data
|
||||
for i in range(1, 6):
|
||||
img_path = base / "images" / "train" / f"old_doc_{i}_page1.jpg"
|
||||
label_path = base / "labels" / "train" / f"old_doc_{i}_page1.txt"
|
||||
img_path.write_text(f"jpg image {i}")
|
||||
label_path.write_text(f"0 0.5 0.5 0.1 0.1")
|
||||
|
||||
# Create pool with JPG - use multiple pages to ensure at least one gets copied
|
||||
pool_ids = [uuid4()]
|
||||
doc_id = pool_ids[0]
|
||||
for page_num in range(1, 4):
|
||||
img_path = base / "images" / "train" / f"{doc_id}_page{page_num}.jpg"
|
||||
label_path = base / "labels" / "train" / f"{doc_id}_page{page_num}.txt"
|
||||
img_path.write_text(f"pool jpg {doc_id} page {page_num}")
|
||||
label_path.write_text(f"1 0.5 0.5 0.2 0.2")
|
||||
|
||||
output_dir = tmp_path / "mixed_output"
|
||||
|
||||
result = build_mixed_dataset(
|
||||
pool_document_ids=pool_ids,
|
||||
base_dataset_path=base,
|
||||
output_dir=output_dir,
|
||||
seed=42,
|
||||
)
|
||||
|
||||
# Should have some new JPG images (at least 1 from the pool)
|
||||
assert result["new_images"] > 0
|
||||
assert result["old_images"] > 0
|
||||
|
||||
# Verify JPG files exist in output
|
||||
all_images = list((output_dir / "images" / "train").glob("*.jpg")) + \
|
||||
list((output_dir / "images" / "val").glob("*.jpg"))
|
||||
assert len(all_images) > 0
|
||||
|
||||
|
||||
class TestConstants:
|
||||
"""Tests for module constants."""
|
||||
|
||||
def test_mixing_ratios_structure(self):
|
||||
"""Test MIXING_RATIOS constant structure."""
|
||||
assert isinstance(MIXING_RATIOS, list)
|
||||
assert len(MIXING_RATIOS) == 4
|
||||
|
||||
# Verify format: (threshold, multiplier)
|
||||
for item in MIXING_RATIOS:
|
||||
assert isinstance(item, tuple)
|
||||
assert len(item) == 2
|
||||
assert isinstance(item[0], int)
|
||||
assert isinstance(item[1], int)
|
||||
|
||||
# Verify thresholds are ascending
|
||||
thresholds = [t for t, _ in MIXING_RATIOS]
|
||||
assert thresholds == sorted(thresholds)
|
||||
|
||||
# Verify multipliers are descending
|
||||
multipliers = [m for _, m in MIXING_RATIOS]
|
||||
assert multipliers == sorted(multipliers, reverse=True)
|
||||
|
||||
def test_default_multiplier(self):
|
||||
"""Test DEFAULT_MULTIPLIER constant."""
|
||||
assert DEFAULT_MULTIPLIER == 5
|
||||
assert DEFAULT_MULTIPLIER == MIXING_RATIOS[-1][1]
|
||||
|
||||
def test_max_old_samples(self):
|
||||
"""Test MAX_OLD_SAMPLES constant."""
|
||||
assert MAX_OLD_SAMPLES == 3000
|
||||
assert MAX_OLD_SAMPLES > 0
|
||||
|
||||
def test_min_pool_size(self):
|
||||
"""Test MIN_POOL_SIZE constant."""
|
||||
assert MIN_POOL_SIZE == 50
|
||||
assert MIN_POOL_SIZE > 0
|
||||
@@ -310,7 +310,7 @@ class TestSchedulerDatasetStatusUpdates:
|
||||
try:
|
||||
scheduler._execute_task(
|
||||
task_id=task_id,
|
||||
config={"model_name": "yolo11n.pt"},
|
||||
config={"model_name": "yolo26s.pt"},
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
467
tests/web/test_finetune_pool.py
Normal file
467
tests/web/test_finetune_pool.py
Normal file
@@ -0,0 +1,467 @@
|
||||
"""
|
||||
Tests for Fine-Tune Pool feature.
|
||||
|
||||
Tests cover:
|
||||
1. FineTunePoolEntry database model
|
||||
2. PoolAddRequest/PoolStatsResponse schemas
|
||||
3. Chain prevention logic
|
||||
4. Pool threshold enforcement
|
||||
5. Model lineage fields on ModelVersion
|
||||
6. Gating enforcement on model activation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4, UUID
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Database Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestFineTunePoolEntryModel:
|
||||
"""Tests for FineTunePoolEntry model."""
|
||||
|
||||
def test_creates_with_defaults(self):
|
||||
"""FineTunePoolEntry should have correct defaults."""
|
||||
from backend.data.admin_models import FineTunePoolEntry
|
||||
|
||||
entry = FineTunePoolEntry(document_id=uuid4())
|
||||
assert entry.entry_id is not None
|
||||
assert entry.is_verified is False
|
||||
assert entry.verified_at is None
|
||||
assert entry.verified_by is None
|
||||
assert entry.added_by is None
|
||||
assert entry.reason is None
|
||||
|
||||
def test_creates_with_all_fields(self):
|
||||
"""FineTunePoolEntry should accept all fields."""
|
||||
from backend.data.admin_models import FineTunePoolEntry
|
||||
|
||||
doc_id = uuid4()
|
||||
entry = FineTunePoolEntry(
|
||||
document_id=doc_id,
|
||||
added_by="admin",
|
||||
reason="user_reported_failure",
|
||||
is_verified=True,
|
||||
verified_by="reviewer",
|
||||
)
|
||||
assert entry.document_id == doc_id
|
||||
assert entry.added_by == "admin"
|
||||
assert entry.reason == "user_reported_failure"
|
||||
assert entry.is_verified is True
|
||||
assert entry.verified_by == "reviewer"
|
||||
|
||||
|
||||
class TestGatingResultModel:
|
||||
"""Tests for GatingResult model."""
|
||||
|
||||
def test_creates_with_defaults(self):
|
||||
"""GatingResult should have correct defaults."""
|
||||
from backend.data.admin_models import GatingResult
|
||||
|
||||
model_version_id = uuid4()
|
||||
result = GatingResult(
|
||||
model_version_id=model_version_id,
|
||||
gate1_status="pass",
|
||||
gate2_status="pass",
|
||||
overall_status="pass",
|
||||
)
|
||||
assert result.result_id is not None
|
||||
assert result.model_version_id == model_version_id
|
||||
assert result.gate1_status == "pass"
|
||||
assert result.gate2_status == "pass"
|
||||
assert result.overall_status == "pass"
|
||||
assert result.gate1_mAP_drop is None
|
||||
assert result.gate2_detection_rate is None
|
||||
|
||||
def test_creates_with_full_metrics(self):
|
||||
"""GatingResult should store full metrics."""
|
||||
from backend.data.admin_models import GatingResult
|
||||
|
||||
result = GatingResult(
|
||||
model_version_id=uuid4(),
|
||||
gate1_status="review",
|
||||
gate1_original_mAP=0.95,
|
||||
gate1_new_mAP=0.93,
|
||||
gate1_mAP_drop=0.02,
|
||||
gate2_status="pass",
|
||||
gate2_detection_rate=0.85,
|
||||
gate2_total_samples=100,
|
||||
gate2_detected_samples=85,
|
||||
overall_status="review",
|
||||
)
|
||||
assert result.gate1_original_mAP == 0.95
|
||||
assert result.gate1_new_mAP == 0.93
|
||||
assert result.gate1_mAP_drop == 0.02
|
||||
assert result.gate2_detection_rate == 0.85
|
||||
|
||||
|
||||
class TestModelVersionLineage:
|
||||
"""Tests for ModelVersion lineage fields."""
|
||||
|
||||
def test_default_model_type_is_base(self):
|
||||
"""ModelVersion should default to 'base' model_type."""
|
||||
from backend.data.admin_models import ModelVersion
|
||||
|
||||
mv = ModelVersion(
|
||||
version="v1.0",
|
||||
name="test-model",
|
||||
model_path="/path/to/model.pt",
|
||||
)
|
||||
assert mv.model_type == "base"
|
||||
assert mv.base_model_version_id is None
|
||||
assert mv.base_training_dataset_id is None
|
||||
assert mv.gating_status == "pending"
|
||||
|
||||
def test_finetune_model_type(self):
|
||||
"""ModelVersion should support 'finetune' type with lineage."""
|
||||
from backend.data.admin_models import ModelVersion
|
||||
|
||||
base_id = uuid4()
|
||||
dataset_id = uuid4()
|
||||
mv = ModelVersion(
|
||||
version="v2.0",
|
||||
name="finetuned-model",
|
||||
model_path="/path/to/ft_model.pt",
|
||||
model_type="finetune",
|
||||
base_model_version_id=base_id,
|
||||
base_training_dataset_id=dataset_id,
|
||||
gating_status="pending",
|
||||
)
|
||||
assert mv.model_type == "finetune"
|
||||
assert mv.base_model_version_id == base_id
|
||||
assert mv.base_training_dataset_id == dataset_id
|
||||
assert mv.gating_status == "pending"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Schemas
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPoolSchemas:
|
||||
"""Tests for pool Pydantic schemas."""
|
||||
|
||||
def test_pool_add_request_defaults(self):
|
||||
"""PoolAddRequest should have default reason."""
|
||||
from backend.web.schemas.admin.pool import PoolAddRequest
|
||||
|
||||
req = PoolAddRequest(document_id="550e8400-e29b-41d4-a716-446655440001")
|
||||
assert req.document_id == "550e8400-e29b-41d4-a716-446655440001"
|
||||
assert req.reason == "user_reported_failure"
|
||||
|
||||
def test_pool_add_request_custom_reason(self):
|
||||
"""PoolAddRequest should accept custom reason."""
|
||||
from backend.web.schemas.admin.pool import PoolAddRequest
|
||||
|
||||
req = PoolAddRequest(
|
||||
document_id="550e8400-e29b-41d4-a716-446655440001",
|
||||
reason="manual_addition",
|
||||
)
|
||||
assert req.reason == "manual_addition"
|
||||
|
||||
def test_pool_stats_response(self):
|
||||
"""PoolStatsResponse should compute readiness correctly."""
|
||||
from backend.web.schemas.admin.pool import PoolStatsResponse
|
||||
|
||||
# Not ready
|
||||
stats = PoolStatsResponse(
|
||||
total_entries=30,
|
||||
verified_entries=20,
|
||||
unverified_entries=10,
|
||||
is_ready=False,
|
||||
)
|
||||
assert stats.is_ready is False
|
||||
assert stats.min_required == 50
|
||||
|
||||
# Ready
|
||||
stats_ready = PoolStatsResponse(
|
||||
total_entries=80,
|
||||
verified_entries=60,
|
||||
unverified_entries=20,
|
||||
is_ready=True,
|
||||
)
|
||||
assert stats_ready.is_ready is True
|
||||
|
||||
def test_pool_entry_item(self):
|
||||
"""PoolEntryItem should serialize correctly."""
|
||||
from backend.web.schemas.admin.pool import PoolEntryItem
|
||||
|
||||
entry = PoolEntryItem(
|
||||
entry_id="entry-uuid",
|
||||
document_id="doc-uuid",
|
||||
is_verified=True,
|
||||
verified_at=datetime.utcnow(),
|
||||
verified_by="admin",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
assert entry.is_verified is True
|
||||
assert entry.verified_by == "admin"
|
||||
|
||||
def test_gating_result_item(self):
|
||||
"""GatingResultItem should serialize all gate fields."""
|
||||
from backend.web.schemas.admin.pool import GatingResultItem
|
||||
|
||||
item = GatingResultItem(
|
||||
result_id="result-uuid",
|
||||
model_version_id="model-uuid",
|
||||
gate1_status="pass",
|
||||
gate1_original_mAP=0.95,
|
||||
gate1_new_mAP=0.94,
|
||||
gate1_mAP_drop=0.01,
|
||||
gate2_status="pass",
|
||||
gate2_detection_rate=0.90,
|
||||
gate2_total_samples=50,
|
||||
gate2_detected_samples=45,
|
||||
overall_status="pass",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
assert item.gate1_status == "pass"
|
||||
assert item.overall_status == "pass"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Chain Prevention
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestChainPrevention:
|
||||
"""Tests for fine-tune chain prevention logic."""
|
||||
|
||||
def test_rejects_finetune_from_finetune_model(self):
|
||||
"""Should reject training when base model is already a fine-tune."""
|
||||
# Simulate the chain prevention check from datasets.py
|
||||
model_type = "finetune"
|
||||
base_model_version_id = str(uuid4())
|
||||
|
||||
# This should trigger rejection
|
||||
assert model_type == "finetune"
|
||||
|
||||
def test_allows_finetune_from_base_model(self):
|
||||
"""Should allow training when base model is a base model."""
|
||||
model_type = "base"
|
||||
assert model_type != "finetune"
|
||||
|
||||
def test_allows_fresh_training(self):
|
||||
"""Should allow fresh training (no base model)."""
|
||||
base_model_version_id = None
|
||||
assert base_model_version_id is None # No chain check needed
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Pool Threshold
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPoolThreshold:
|
||||
"""Tests for minimum pool size enforcement."""
|
||||
|
||||
def test_min_pool_size_constant(self):
|
||||
"""MIN_POOL_SIZE should be 50."""
|
||||
from backend.web.services.data_mixer import MIN_POOL_SIZE
|
||||
|
||||
assert MIN_POOL_SIZE == 50
|
||||
|
||||
def test_pool_below_threshold_blocks_finetune(self):
|
||||
"""Pool with fewer than 50 verified entries should block fine-tuning."""
|
||||
from backend.web.services.data_mixer import MIN_POOL_SIZE
|
||||
|
||||
verified_count = 30
|
||||
assert verified_count < MIN_POOL_SIZE
|
||||
|
||||
def test_pool_at_threshold_allows_finetune(self):
|
||||
"""Pool with exactly 50 verified entries should allow fine-tuning."""
|
||||
from backend.web.services.data_mixer import MIN_POOL_SIZE
|
||||
|
||||
verified_count = 50
|
||||
assert verified_count >= MIN_POOL_SIZE
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Gating Enforcement on Activation
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGatingEnforcement:
|
||||
"""Tests for gating enforcement when activating models."""
|
||||
|
||||
def test_base_model_skips_gating(self):
|
||||
"""Base models should have gating_status 'skipped'."""
|
||||
from backend.data.admin_models import ModelVersion
|
||||
|
||||
mv = ModelVersion(
|
||||
version="v1.0",
|
||||
name="base",
|
||||
model_path="/model.pt",
|
||||
model_type="base",
|
||||
)
|
||||
# Base models skip gating - activation should work
|
||||
assert mv.model_type == "base"
|
||||
# Gating should not block base model activation
|
||||
|
||||
def test_finetune_model_rejected_blocks_activation(self):
|
||||
"""Fine-tuned models with 'reject' gating should block activation."""
|
||||
model_type = "finetune"
|
||||
gating_status = "reject"
|
||||
|
||||
# Simulates the check in models.py activation endpoint
|
||||
should_block = model_type == "finetune" and gating_status == "reject"
|
||||
assert should_block is True
|
||||
|
||||
def test_finetune_model_pending_blocks_activation(self):
|
||||
"""Fine-tuned models with 'pending' gating should block activation."""
|
||||
model_type = "finetune"
|
||||
gating_status = "pending"
|
||||
|
||||
should_block = model_type == "finetune" and gating_status == "pending"
|
||||
assert should_block is True
|
||||
|
||||
def test_finetune_model_pass_allows_activation(self):
|
||||
"""Fine-tuned models with 'pass' gating should allow activation."""
|
||||
model_type = "finetune"
|
||||
gating_status = "pass"
|
||||
|
||||
should_block_reject = model_type == "finetune" and gating_status == "reject"
|
||||
should_block_pending = model_type == "finetune" and gating_status == "pending"
|
||||
assert should_block_reject is False
|
||||
assert should_block_pending is False
|
||||
|
||||
def test_finetune_model_review_allows_with_warning(self):
|
||||
"""Fine-tuned models with 'review' gating should allow but warn."""
|
||||
model_type = "finetune"
|
||||
gating_status = "review"
|
||||
|
||||
should_block_reject = model_type == "finetune" and gating_status == "reject"
|
||||
should_block_pending = model_type == "finetune" and gating_status == "pending"
|
||||
assert should_block_reject is False
|
||||
assert should_block_pending is False
|
||||
# Should include warning in message
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Pool API Route Registration
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestPoolRouteRegistration:
|
||||
"""Tests for pool route registration."""
|
||||
|
||||
def test_pool_routes_registered(self):
|
||||
"""Pool routes should be registered on training router."""
|
||||
from backend.web.api.v1.admin.training import create_training_router
|
||||
|
||||
router = create_training_router()
|
||||
paths = [route.path for route in router.routes]
|
||||
|
||||
assert any("/pool" in p for p in paths)
|
||||
assert any("/pool/stats" in p for p in paths)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test Scheduler Fine-Tune Parameter Override
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestSchedulerFineTuneParams:
|
||||
"""Tests for scheduler fine-tune parameter overrides."""
|
||||
|
||||
def test_finetune_detected_from_base_model_path(self):
|
||||
"""Scheduler should detect fine-tune mode from base_model_path."""
|
||||
config = {"base_model_path": "/path/to/base_model.pt"}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
assert is_finetune is True
|
||||
|
||||
def test_fresh_training_not_finetune(self):
|
||||
"""Scheduler should not enable fine-tune for fresh training."""
|
||||
config = {"model_name": "yolo26s.pt"}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
assert is_finetune is False
|
||||
|
||||
def test_finetune_defaults_correct_epochs(self):
|
||||
"""Fine-tune should default to 10 epochs."""
|
||||
config = {"base_model_path": "/path/to/model.pt"}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
|
||||
if is_finetune:
|
||||
epochs = config.get("epochs", 10)
|
||||
learning_rate = config.get("learning_rate", 0.001)
|
||||
else:
|
||||
epochs = config.get("epochs", 100)
|
||||
learning_rate = config.get("learning_rate", 0.01)
|
||||
|
||||
assert epochs == 10
|
||||
assert learning_rate == 0.001
|
||||
|
||||
def test_model_lineage_set_for_finetune(self):
|
||||
"""Scheduler should set model_type and base_model_version_id for fine-tune."""
|
||||
config = {
|
||||
"base_model_path": "/path/to/model.pt",
|
||||
"base_model_version_id": str(uuid4()),
|
||||
}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
model_type = "finetune" if is_finetune else "base"
|
||||
base_model_version_id = config.get("base_model_version_id") if is_finetune else None
|
||||
gating_status = "pending" if is_finetune else "skipped"
|
||||
|
||||
assert model_type == "finetune"
|
||||
assert base_model_version_id is not None
|
||||
assert gating_status == "pending"
|
||||
|
||||
def test_model_lineage_skipped_for_base(self):
|
||||
"""Scheduler should set model_type='base' for fresh training."""
|
||||
config = {"model_name": "yolo26s.pt"}
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
model_type = "finetune" if is_finetune else "base"
|
||||
gating_status = "pending" if is_finetune else "skipped"
|
||||
|
||||
assert model_type == "base"
|
||||
assert gating_status == "skipped"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Test TrainingConfig freeze/cos_lr
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestTrainingConfigFineTuneFields:
|
||||
"""Tests for freeze and cos_lr fields in shared TrainingConfig."""
|
||||
|
||||
def test_default_freeze_is_zero(self):
|
||||
"""TrainingConfig freeze should default to 0."""
|
||||
from shared.training import TrainingConfig
|
||||
|
||||
config = TrainingConfig(
|
||||
model_path="test.pt",
|
||||
data_yaml="data.yaml",
|
||||
)
|
||||
assert config.freeze == 0
|
||||
|
||||
def test_default_cos_lr_is_false(self):
|
||||
"""TrainingConfig cos_lr should default to False."""
|
||||
from shared.training import TrainingConfig
|
||||
|
||||
config = TrainingConfig(
|
||||
model_path="test.pt",
|
||||
data_yaml="data.yaml",
|
||||
)
|
||||
assert config.cos_lr is False
|
||||
|
||||
def test_finetune_config(self):
|
||||
"""TrainingConfig should accept fine-tune parameters."""
|
||||
from shared.training import TrainingConfig
|
||||
|
||||
config = TrainingConfig(
|
||||
model_path="base_model.pt",
|
||||
data_yaml="data.yaml",
|
||||
epochs=10,
|
||||
learning_rate=0.001,
|
||||
freeze=10,
|
||||
cos_lr=True,
|
||||
)
|
||||
assert config.freeze == 10
|
||||
assert config.cos_lr is True
|
||||
assert config.epochs == 10
|
||||
assert config.learning_rate == 0.001
|
||||
@@ -1,14 +1,14 @@
|
||||
"""
|
||||
Tests for Training Export with expand_bbox integration.
|
||||
Tests for Training Export with uniform expand_bbox integration.
|
||||
|
||||
Tests the export endpoint's integration with field-specific bbox expansion.
|
||||
Tests the export endpoint's integration with uniform bbox expansion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.bbox import expand_bbox
|
||||
from shared.bbox import expand_bbox, UNIFORM_PAD
|
||||
from shared.fields import CLASS_NAMES, FIELD_CLASS_IDS
|
||||
|
||||
|
||||
@@ -17,149 +17,87 @@ class TestExpandBboxForExport:
|
||||
|
||||
def test_expand_bbox_converts_normalized_to_pixel_and_back(self):
|
||||
"""Verify expand_bbox works with pixel-to-normalized conversion."""
|
||||
# Annotation stored as normalized coords
|
||||
x_center_norm = 0.5
|
||||
y_center_norm = 0.5
|
||||
width_norm = 0.1
|
||||
height_norm = 0.05
|
||||
|
||||
# Image dimensions
|
||||
img_width = 2480 # A4 at 300 DPI
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Convert to pixel coords
|
||||
x_center_px = x_center_norm * img_width
|
||||
y_center_px = y_center_norm * img_height
|
||||
width_px = width_norm * img_width
|
||||
height_px = height_norm * img_height
|
||||
|
||||
# Convert to corner coords
|
||||
x0 = x_center_px - width_px / 2
|
||||
y0 = y_center_px - height_px / 2
|
||||
x1 = x_center_px + width_px / 2
|
||||
y1 = y_center_px + height_px / 2
|
||||
|
||||
# Apply expansion
|
||||
class_name = "invoice_number"
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Verify expanded bbox is larger
|
||||
assert ex0 < x0 # Left expanded
|
||||
assert ey0 < y0 # Top expanded
|
||||
assert ex1 > x1 # Right expanded
|
||||
assert ey1 > y1 # Bottom expanded
|
||||
assert ex0 < x0
|
||||
assert ey0 < y0
|
||||
assert ex1 > x1
|
||||
assert ey1 > y1
|
||||
|
||||
# Convert back to normalized
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Verify valid normalized coords
|
||||
assert 0 <= new_x_center <= 1
|
||||
assert 0 <= new_y_center <= 1
|
||||
assert 0 <= new_width <= 1
|
||||
assert 0 <= new_height <= 1
|
||||
|
||||
def test_expand_bbox_manual_mode_minimal_expansion(self):
|
||||
"""Verify manual annotations use minimal expansion."""
|
||||
# Small bbox
|
||||
def test_expand_bbox_uniform_for_all_sources(self):
|
||||
"""Verify all annotation sources get the same uniform expansion."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Auto mode (field-specific expansion)
|
||||
auto_result = expand_bbox(
|
||||
# All sources now get the same uniform expansion
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
manual_mode=False,
|
||||
)
|
||||
|
||||
# Manual mode (minimal expansion)
|
||||
manual_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
manual_mode=True,
|
||||
expected = (
|
||||
100 - UNIFORM_PAD,
|
||||
100 - UNIFORM_PAD,
|
||||
200 + UNIFORM_PAD,
|
||||
150 + UNIFORM_PAD,
|
||||
)
|
||||
|
||||
# Auto expansion should be larger than manual
|
||||
auto_width = auto_result[2] - auto_result[0]
|
||||
manual_width = manual_result[2] - manual_result[0]
|
||||
assert auto_width > manual_width
|
||||
|
||||
auto_height = auto_result[3] - auto_result[1]
|
||||
manual_height = manual_result[3] - manual_result[1]
|
||||
assert auto_height > manual_height
|
||||
|
||||
def test_expand_bbox_different_sources_use_correct_mode(self):
|
||||
"""Verify different annotation sources use correct expansion mode."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Define source to manual_mode mapping
|
||||
source_mode_mapping = {
|
||||
"manual": True, # Manual annotations -> minimal expansion
|
||||
"auto": False, # Auto-labeled -> field-specific expansion
|
||||
"imported": True, # Imported (from CSV) -> minimal expansion
|
||||
}
|
||||
|
||||
results = {}
|
||||
for source, manual_mode in source_mode_mapping.items():
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="ocr_number",
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
results[source] = result
|
||||
|
||||
# Auto should have largest expansion
|
||||
auto_area = (results["auto"][2] - results["auto"][0]) * \
|
||||
(results["auto"][3] - results["auto"][1])
|
||||
manual_area = (results["manual"][2] - results["manual"][0]) * \
|
||||
(results["manual"][3] - results["manual"][1])
|
||||
imported_area = (results["imported"][2] - results["imported"][0]) * \
|
||||
(results["imported"][3] - results["imported"][1])
|
||||
|
||||
assert auto_area > manual_area
|
||||
assert auto_area > imported_area
|
||||
# Manual and imported should be the same (both use minimal mode)
|
||||
assert manual_area == imported_area
|
||||
assert result == expected
|
||||
|
||||
def test_expand_bbox_all_field_types_work(self):
|
||||
"""Verify expand_bbox works for all field types."""
|
||||
"""Verify expand_bbox works for all field types (same result)."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
for class_name in CLASS_NAMES:
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
# All fields should produce the same result with uniform padding
|
||||
first_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
)
|
||||
|
||||
# Verify result is a valid bbox
|
||||
assert len(result) == 4
|
||||
x0, y0, x1, y1 = result
|
||||
assert x0 >= 0
|
||||
assert y0 >= 0
|
||||
assert x1 <= img_width
|
||||
assert y1 <= img_height
|
||||
assert x1 > x0
|
||||
assert y1 > y0
|
||||
assert len(first_result) == 4
|
||||
x0, y0, x1, y1 = first_result
|
||||
assert x0 >= 0
|
||||
assert y0 >= 0
|
||||
assert x1 <= img_width
|
||||
assert y1 <= img_height
|
||||
assert x1 > x0
|
||||
assert y1 > y0
|
||||
|
||||
|
||||
class TestExportAnnotationExpansion:
|
||||
@@ -167,7 +105,6 @@ class TestExportAnnotationExpansion:
|
||||
|
||||
def test_annotation_bbox_conversion_workflow(self):
|
||||
"""Test full annotation bbox conversion workflow."""
|
||||
# Simulate stored annotation (normalized coords)
|
||||
class MockAnnotation:
|
||||
class_id = FIELD_CLASS_IDS["invoice_number"]
|
||||
class_name = "invoice_number"
|
||||
@@ -181,7 +118,6 @@ class TestExportAnnotationExpansion:
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Step 1: Convert normalized to pixel corner coords
|
||||
half_w = (ann.width * img_width) / 2
|
||||
half_h = (ann.height * img_height) / 2
|
||||
x0 = ann.x_center * img_width - half_w
|
||||
@@ -189,38 +125,27 @@ class TestExportAnnotationExpansion:
|
||||
x1 = ann.x_center * img_width + half_w
|
||||
y1 = ann.y_center * img_height + half_h
|
||||
|
||||
# Step 2: Determine manual_mode based on source
|
||||
manual_mode = ann.source in ("manual", "imported")
|
||||
|
||||
# Step 3: Apply expand_bbox
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann.class_name,
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Step 4: Convert back to normalized
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Verify expansion happened (auto mode)
|
||||
assert new_width > ann.width
|
||||
assert new_height > ann.height
|
||||
|
||||
# Verify valid YOLO format
|
||||
assert 0 <= new_x_center <= 1
|
||||
assert 0 <= new_y_center <= 1
|
||||
assert 0 < new_width <= 1
|
||||
assert 0 < new_height <= 1
|
||||
|
||||
def test_export_applies_expansion_to_each_annotation(self):
|
||||
"""Test that export applies expansion to each annotation."""
|
||||
# Simulate multiple annotations with different sources
|
||||
# Use smaller bboxes so manual mode padding has visible effect
|
||||
def test_export_applies_uniform_expansion_to_all_annotations(self):
|
||||
"""Test that export applies uniform expansion to all annotations."""
|
||||
annotations = [
|
||||
{"class_name": "invoice_number", "source": "auto", "x_center": 0.3, "y_center": 0.2, "width": 0.05, "height": 0.02},
|
||||
{"class_name": "ocr_number", "source": "manual", "x_center": 0.5, "y_center": 0.8, "width": 0.05, "height": 0.02},
|
||||
@@ -232,7 +157,6 @@ class TestExportAnnotationExpansion:
|
||||
|
||||
expanded_annotations = []
|
||||
for ann in annotations:
|
||||
# Convert to pixel coords
|
||||
half_w = (ann["width"] * img_width) / 2
|
||||
half_h = (ann["height"] * img_height) / 2
|
||||
x0 = ann["x_center"] * img_width - half_w
|
||||
@@ -240,19 +164,12 @@ class TestExportAnnotationExpansion:
|
||||
x1 = ann["x_center"] * img_width + half_w
|
||||
y1 = ann["y_center"] * img_height + half_h
|
||||
|
||||
# Determine manual_mode
|
||||
manual_mode = ann["source"] in ("manual", "imported")
|
||||
|
||||
# Apply expansion
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann["class_name"],
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Convert back to normalized
|
||||
expanded_annotations.append({
|
||||
"class_name": ann["class_name"],
|
||||
"source": ann["source"],
|
||||
@@ -262,106 +179,48 @@ class TestExportAnnotationExpansion:
|
||||
"height": (ey1 - ey0) / img_height,
|
||||
})
|
||||
|
||||
# Verify auto-labeled annotation expanded more than manual/imported
|
||||
auto_ann = next(a for a in expanded_annotations if a["source"] == "auto")
|
||||
manual_ann = next(a for a in expanded_annotations if a["source"] == "manual")
|
||||
|
||||
# Auto mode should expand more than manual mode
|
||||
# (auto has larger scale factors and max_pad)
|
||||
assert auto_ann["width"] > manual_ann["width"]
|
||||
assert auto_ann["height"] > manual_ann["height"]
|
||||
|
||||
# All annotations should be expanded (at least slightly for manual mode)
|
||||
# Allow small precision loss (< 1%) due to integer conversion in expand_bbox
|
||||
for i, (orig, exp) in enumerate(zip(annotations, expanded_annotations)):
|
||||
# Width and height should be >= original (expansion or equal, with small tolerance)
|
||||
tolerance = 0.01 # 1% tolerance for integer rounding
|
||||
assert exp["width"] >= orig["width"] * (1 - tolerance), \
|
||||
f"Annotation {i} width unexpectedly smaller: {exp['width']} < {orig['width']}"
|
||||
assert exp["height"] >= orig["height"] * (1 - tolerance), \
|
||||
f"Annotation {i} height unexpectedly smaller: {exp['height']} < {orig['height']}"
|
||||
# All annotations get the same expansion
|
||||
tolerance = 0.01
|
||||
for orig, exp in zip(annotations, expanded_annotations):
|
||||
assert exp["width"] >= orig["width"] * (1 - tolerance)
|
||||
assert exp["height"] >= orig["height"] * (1 - tolerance)
|
||||
|
||||
|
||||
class TestExpandBboxEdgeCases:
|
||||
"""Tests for edge cases in export bbox expansion."""
|
||||
|
||||
def test_bbox_at_image_edge_left(self):
|
||||
"""Test bbox at left edge of image."""
|
||||
bbox = (0, 100, 50, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Left edge should be clamped to 0
|
||||
assert result[0] >= 0
|
||||
|
||||
def test_bbox_at_image_edge_right(self):
|
||||
"""Test bbox at right edge of image."""
|
||||
bbox = (2400, 100, 2480, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Right edge should be clamped to image width
|
||||
assert result[2] <= img_width
|
||||
assert result[2] <= 2480
|
||||
|
||||
def test_bbox_at_image_edge_top(self):
|
||||
"""Test bbox at top edge of image."""
|
||||
bbox = (100, 0, 200, 50)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Top edge should be clamped to 0
|
||||
assert result[1] >= 0
|
||||
|
||||
def test_bbox_at_image_edge_bottom(self):
|
||||
"""Test bbox at bottom edge of image."""
|
||||
bbox = (100, 3400, 200, 3508)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Bottom edge should be clamped to image height
|
||||
assert result[3] <= img_height
|
||||
assert result[3] <= 3508
|
||||
|
||||
def test_very_small_bbox(self):
|
||||
"""Test very small bbox gets expanded."""
|
||||
bbox = (100, 100, 105, 105) # 5x5 pixel bbox
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
bbox = (100, 100, 105, 105)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
result = expand_bbox(bbox=bbox, image_width=2480, image_height=3508)
|
||||
|
||||
# Should still produce a valid expanded bbox
|
||||
assert result[2] > result[0]
|
||||
assert result[3] > result[1]
|
||||
|
||||
Reference in New Issue
Block a user