This commit is contained in:
Yaojia Wang
2026-02-03 22:29:53 +01:00
parent 4c7fc3015c
commit c2c8f2dd04
10 changed files with 786 additions and 68 deletions

BIN
.coverage

Binary file not shown.

View File

@@ -0,0 +1,25 @@
"""
Domain Layer
Business logic separated from technical implementation.
Contains document classification and invoice validation logic.
"""
from backend.domain.document_classifier import (
ClassificationResult,
DocumentClassifier,
)
from backend.domain.invoice_validator import (
InvoiceValidator,
ValidationIssue,
ValidationResult,
)
from backend.domain.utils import has_value
__all__ = [
"ClassificationResult",
"DocumentClassifier",
"InvoiceValidator",
"ValidationIssue",
"ValidationResult",
"has_value",
]

View File

@@ -0,0 +1,108 @@
"""
Document Classifier
Business logic for classifying documents based on extracted fields.
Separates classification logic from inference pipeline.
"""
from __future__ import annotations
from dataclasses import dataclass
from backend.domain.utils import has_value
@dataclass(frozen=True)
class ClassificationResult:
"""
Immutable result of document classification.
Attributes:
document_type: Either "invoice" or "letter"
confidence: Confidence score between 0.0 and 1.0
reason: Human-readable explanation of classification
"""
document_type: str
confidence: float
reason: str
class DocumentClassifier:
"""
Classifies documents as invoice or letter based on extracted fields.
Classification Rules:
1. If payment_line is present -> invoice (high confidence)
2. If 2+ invoice indicators present -> invoice (medium confidence)
3. If 1 invoice indicator present -> invoice (lower confidence)
4. Otherwise -> letter
Invoice indicator fields:
- payment_line (strongest indicator)
- OCR
- Amount
- Bankgiro
- Plusgiro
- InvoiceNumber
"""
INVOICE_INDICATOR_FIELDS: frozenset[str] = frozenset(
{
"payment_line",
"OCR",
"Amount",
"Bankgiro",
"Plusgiro",
"InvoiceNumber",
}
)
def classify(self, fields: dict[str, str | None]) -> ClassificationResult:
"""
Classify document type based on extracted fields.
Args:
fields: Dictionary of field names to extracted values.
Empty strings or whitespace-only strings are treated as missing.
Returns:
Immutable ClassificationResult with type, confidence, and reason.
"""
# Rule 1: payment_line is the strongest indicator
if has_value(fields.get("payment_line")):
return ClassificationResult(
document_type="invoice",
confidence=0.95,
reason="payment_line detected",
)
# Count present invoice indicators (excluding payment_line already checked)
present_indicators = [
field
for field in self.INVOICE_INDICATOR_FIELDS
if field != "payment_line" and has_value(fields.get(field))
]
indicator_count = len(present_indicators)
# Rule 2: Multiple indicators -> invoice with medium-high confidence
if indicator_count >= 2:
return ClassificationResult(
document_type="invoice",
confidence=0.8,
reason=f"{indicator_count} invoice indicators present: {', '.join(present_indicators)}",
)
# Rule 3: Single indicator -> invoice with lower confidence
if indicator_count == 1:
return ClassificationResult(
document_type="invoice",
confidence=0.6,
reason=f"1 invoice indicator present: {present_indicators[0]}",
)
# Rule 4: No indicators -> letter
return ClassificationResult(
document_type="letter",
confidence=0.7,
reason="no invoice indicators found",
)

View File

@@ -0,0 +1,141 @@
"""
Invoice Validator
Business logic for validating extracted invoice fields.
Checks for required fields, format validity, and confidence thresholds.
"""
from __future__ import annotations
from dataclasses import dataclass
from backend.domain.utils import has_value
@dataclass(frozen=True)
class ValidationIssue:
"""
Single validation issue.
Attributes:
field: Name of the field with the issue
severity: One of "error", "warning", "info"
message: Human-readable description of the issue
"""
field: str
severity: str
message: str
@dataclass(frozen=True)
class ValidationResult:
"""
Immutable result of invoice validation.
Attributes:
is_valid: True if no errors (warnings are allowed)
issues: Tuple of validation issues found
confidence: Average confidence score of validated fields
"""
is_valid: bool
issues: tuple[ValidationIssue, ...]
confidence: float
class InvoiceValidator:
"""
Validates extracted invoice fields for completeness and consistency.
Validation Rules:
1. Required fields must be present (Amount)
2. At least one payment reference should be present (warning if missing)
3. Field confidence should be above threshold (warning if below)
Required fields:
- Amount
Payment reference fields (at least one expected):
- OCR
- Bankgiro
- Plusgiro
- payment_line
"""
REQUIRED_FIELDS: tuple[str, ...] = ("Amount",)
PAYMENT_REF_FIELDS: tuple[str, ...] = ("OCR", "Bankgiro", "Plusgiro", "payment_line")
DEFAULT_MIN_CONFIDENCE: float = 0.5
def __init__(self, min_confidence: float = DEFAULT_MIN_CONFIDENCE) -> None:
"""
Initialize validator.
Args:
min_confidence: Minimum confidence threshold for valid fields.
Fields below this threshold produce warnings.
"""
self._min_confidence = min_confidence
def validate(
self,
fields: dict[str, str | None],
confidence: dict[str, float],
) -> ValidationResult:
"""
Validate extracted invoice fields.
Args:
fields: Dictionary of field names to extracted values
confidence: Dictionary of field names to confidence scores
Returns:
Immutable ValidationResult with validity status and issues
"""
issues: list[ValidationIssue] = []
# Check required fields
for field in self.REQUIRED_FIELDS:
if not has_value(fields.get(field)):
issues.append(
ValidationIssue(
field=field,
severity="error",
message=f"Required field '{field}' is missing",
)
)
# Check payment reference (at least one expected)
has_payment_ref = any(
has_value(fields.get(f)) for f in self.PAYMENT_REF_FIELDS
)
if not has_payment_ref:
issues.append(
ValidationIssue(
field="payment_reference",
severity="warning",
message="No payment reference (OCR, Bankgiro, Plusgiro, or payment_line)",
)
)
# Check confidence thresholds
for field, conf in confidence.items():
if conf < self._min_confidence:
issues.append(
ValidationIssue(
field=field,
severity="warning",
message=f"Low confidence ({conf:.2f}) for field '{field}'",
)
)
# Calculate overall validity
has_errors = any(i.severity == "error" for i in issues)
avg_confidence = (
sum(confidence.values()) / len(confidence) if confidence else 0.0
)
return ValidationResult(
is_valid=not has_errors,
issues=tuple(issues),
confidence=avg_confidence,
)

View File

@@ -0,0 +1,23 @@
"""
Domain Layer Utilities
Shared helper functions for domain layer classes.
"""
from __future__ import annotations
def has_value(value: str | None) -> bool:
"""
Check if a field value is present and non-empty.
Args:
value: Field value to check
Returns:
True if value is a non-empty, non-whitespace string
"""
if value is None:
return False
if not isinstance(value, str):
return bool(value)
return bool(value.strip())

View File

@@ -1,21 +1,24 @@
"""
Inference Service
Inference Service (Adapter Layer)
Business logic for invoice field extraction.
Orchestrates technical pipeline and business domain logic.
Acts as adapter between API layer and internal components.
"""
from __future__ import annotations
import io
import logging
import time
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Callable, Generator
import numpy as np
from PIL import Image
from backend.domain.document_classifier import DocumentClassifier
from backend.web.services.storage_helpers import get_storage_helper
if TYPE_CHECKING:
@@ -50,9 +53,12 @@ class ServiceResult:
class InferenceService:
"""
Service for running invoice field extraction.
Service for running invoice field extraction (Adapter Pattern).
Orchestrates:
- Technical layer: InferencePipeline, YOLODetector
- Business layer: DocumentClassifier
Encapsulates YOLO detection and OCR extraction logic.
Supports dynamic model loading from database.
"""
@@ -61,6 +67,7 @@ class InferenceService:
model_config: ModelConfig,
storage_config: StorageConfig,
model_path_resolver: ModelPathResolver | None = None,
document_classifier: DocumentClassifier | None = None,
) -> None:
"""
Initialize inference service.
@@ -71,12 +78,19 @@ class InferenceService:
model_path_resolver: Optional function to resolve model path from database.
If provided, will be called to get active model path.
If returns None, falls back to model_config.model_path.
document_classifier: Optional custom classifier (uses default if None)
"""
self.model_config = model_config
self.storage_config = storage_config
self._model_path_resolver = model_path_resolver
# Technical layer (lazy initialized)
self._pipeline = None
self._detector = None
# Business layer (eagerly initialized, no heavy resources)
self._classifier = document_classifier or DocumentClassifier()
self._is_initialized = False
self._current_model_path: Path | None = None
self._business_features_enabled = False
@@ -219,22 +233,12 @@ class InferenceService:
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Business layer: classify document type
classification = self._classifier.classify(result.fields)
result.document_type = classification.document_type
# Get raw detections for visualization
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
result.detections = self._format_detections(pipeline_result.raw_detections)
# Save visualization if requested
if save_visualization and pipeline_result.raw_detections:
@@ -293,22 +297,12 @@ class InferenceService:
result.success = pipeline_result.success
result.errors = pipeline_result.errors
# Determine document type based on payment_line presence
# If no payment_line found, it's likely a letter, not an invoice
if not result.fields.get('payment_line'):
result.document_type = "letter"
else:
result.document_type = "invoice"
# Business layer: classify document type
classification = self._classifier.classify(result.fields)
result.document_type = classification.document_type
# Get raw detections
result.detections = [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in pipeline_result.raw_detections
]
result.detections = self._format_detections(pipeline_result.raw_detections)
# Include business features if extracted
if extract_line_items:
@@ -329,10 +323,19 @@ class InferenceService:
result.processing_time_ms = (time.time() - start_time) * 1000
return result
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
"""Save visualization image with detections."""
from ultralytics import YOLO
def _format_detections(self, raw_detections: list) -> list[dict]:
"""Format raw detections for response."""
return [
{
"field": d.class_name,
"confidence": d.confidence,
"bbox": list(d.bbox),
}
for d in raw_detections
]
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
"""Save visualization image with detections using existing detector."""
# Get storage helper for results directory
storage = get_storage_helper()
results_dir = storage.get_results_base_path()
@@ -340,9 +343,8 @@ class InferenceService:
logger.warning("Cannot save visualization: local storage not available")
return None
# Load model and run prediction with visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(image_path), verbose=False)
# Reuse self._detector instead of creating new YOLO instance
results = self._detector.model.predict(str(image_path), verbose=False)
# Save annotated image
output_path = results_dir / f"{doc_id}_result.png"
@@ -351,11 +353,20 @@ class InferenceService:
return output_path
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
"""Save visualization for PDF (first page)."""
@contextmanager
def _temp_image_file(
self, results_dir: Path, doc_id: str
) -> Generator[Path, None, None]:
"""Context manager for temporary image file with guaranteed cleanup."""
temp_path = results_dir / f"{doc_id}_temp.png"
try:
yield temp_path
finally:
temp_path.unlink(missing_ok=True)
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
"""Save visualization for PDF (first page) using existing detector."""
from shared.pdf.renderer import render_pdf_to_images
from ultralytics import YOLO
import io
# Get storage helper for results directory
storage = get_storage_helper()
@@ -369,20 +380,19 @@ class InferenceService:
pdf_path, dpi=self.model_config.dpi
):
image = Image.open(io.BytesIO(image_bytes))
temp_path = results_dir / f"{doc_id}_temp.png"
image.save(temp_path)
# Run YOLO and save visualization
model = YOLO(str(self.model_config.model_path))
results = model.predict(str(temp_path), verbose=False)
# Use context manager for temp file to guarantee cleanup
with self._temp_image_file(results_dir, doc_id) as temp_path:
image.save(temp_path)
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
# Reuse self._detector instead of creating new YOLO instance
results = self._detector.model.predict(str(temp_path), verbose=False)
# Cleanup temp file
temp_path.unlink(missing_ok=True)
return output_path
output_path = results_dir / f"{doc_id}_result.png"
for r in results:
r.save(filename=str(output_path))
return output_path
# If no pages rendered
return None

1
tests/domain/__init__.py Normal file
View File

@@ -0,0 +1 @@
# Domain layer tests

View File

@@ -0,0 +1,176 @@
"""
Tests for DocumentClassifier - TDD RED phase.
Test document type classification based on extracted fields.
"""
import pytest
from backend.domain.document_classifier import DocumentClassifier, ClassificationResult
class TestDocumentClassifier:
"""Test document classification logic."""
@pytest.fixture
def classifier(self) -> DocumentClassifier:
"""Create classifier instance."""
return DocumentClassifier()
# ==================== Invoice Detection Tests ====================
def test_classify_with_payment_line_returns_invoice(
self, classifier: DocumentClassifier
) -> None:
"""Payment line is the strongest invoice indicator."""
fields = {"payment_line": "# 123456 # 100 00 5 > 308-2963#"}
result = classifier.classify(fields)
assert result.document_type == "invoice"
assert result.confidence >= 0.9
assert "payment_line" in result.reason
def test_classify_with_multiple_indicators_returns_invoice(
self, classifier: DocumentClassifier
) -> None:
"""Multiple invoice indicators -> invoice with medium confidence."""
fields = {
"Amount": "1200.00",
"Bankgiro": "123-4567",
"payment_line": None,
}
result = classifier.classify(fields)
assert result.document_type == "invoice"
assert result.confidence >= 0.7
def test_classify_with_ocr_and_amount_returns_invoice(
self, classifier: DocumentClassifier
) -> None:
"""OCR + Amount is typical invoice pattern."""
fields = {
"OCR": "123456789012",
"Amount": "500.00",
}
result = classifier.classify(fields)
assert result.document_type == "invoice"
assert result.confidence >= 0.7
def test_classify_with_single_indicator_returns_invoice_lower_confidence(
self, classifier: DocumentClassifier
) -> None:
"""Single indicator -> invoice but lower confidence."""
fields = {"Amount": "100.00"}
result = classifier.classify(fields)
assert result.document_type == "invoice"
assert 0.5 <= result.confidence < 0.8
def test_classify_with_invoice_number_only(
self, classifier: DocumentClassifier
) -> None:
"""Invoice number alone suggests invoice."""
fields = {"InvoiceNumber": "INV-2024-001"}
result = classifier.classify(fields)
assert result.document_type == "invoice"
# ==================== Letter Detection Tests ====================
def test_classify_with_no_indicators_returns_letter(
self, classifier: DocumentClassifier
) -> None:
"""No invoice indicators -> letter."""
fields: dict[str, str | None] = {}
result = classifier.classify(fields)
assert result.document_type == "letter"
assert result.confidence >= 0.5
def test_classify_with_empty_fields_returns_letter(
self, classifier: DocumentClassifier
) -> None:
"""All fields empty or None -> letter."""
fields = {
"payment_line": None,
"OCR": None,
"Amount": None,
"Bankgiro": None,
}
result = classifier.classify(fields)
assert result.document_type == "letter"
def test_classify_with_only_non_indicator_fields_returns_letter(
self, classifier: DocumentClassifier
) -> None:
"""Fields that don't indicate invoice -> letter."""
fields = {
"CustomerNumber": "C12345",
"SupplierOrgNumber": "556677-8899",
}
result = classifier.classify(fields)
assert result.document_type == "letter"
# ==================== Edge Cases ====================
def test_classify_with_empty_string_fields_returns_letter(
self, classifier: DocumentClassifier
) -> None:
"""Empty strings should be treated as missing."""
fields = {
"payment_line": "",
"Amount": "",
}
result = classifier.classify(fields)
assert result.document_type == "letter"
def test_classify_with_whitespace_only_fields_returns_letter(
self, classifier: DocumentClassifier
) -> None:
"""Whitespace-only strings should be treated as missing."""
fields = {
"payment_line": " ",
"Amount": "\t\n",
}
result = classifier.classify(fields)
assert result.document_type == "letter"
# ==================== ClassificationResult Immutability ====================
def test_classification_result_is_immutable(
self, classifier: DocumentClassifier
) -> None:
"""ClassificationResult should be a frozen dataclass."""
fields = {"payment_line": "test"}
result = classifier.classify(fields)
with pytest.raises((AttributeError, TypeError)):
result.document_type = "modified" # type: ignore
def test_classification_result_has_required_fields(
self, classifier: DocumentClassifier
) -> None:
"""ClassificationResult must have document_type, confidence, reason."""
fields = {"Amount": "100.00"}
result = classifier.classify(fields)
assert hasattr(result, "document_type")
assert hasattr(result, "confidence")
assert hasattr(result, "reason")
assert isinstance(result.document_type, str)
assert isinstance(result.confidence, float)
assert isinstance(result.reason, str)

View File

@@ -0,0 +1,232 @@
"""
Tests for InvoiceValidator - TDD RED phase.
Test invoice field validation logic.
"""
import pytest
from backend.domain.invoice_validator import (
InvoiceValidator,
ValidationResult,
ValidationIssue,
)
class TestInvoiceValidator:
"""Test invoice validation logic."""
@pytest.fixture
def validator(self) -> InvoiceValidator:
"""Create validator instance with default settings."""
return InvoiceValidator()
@pytest.fixture
def validator_strict(self) -> InvoiceValidator:
"""Create validator with strict confidence threshold."""
return InvoiceValidator(min_confidence=0.8)
# ==================== Valid Invoice Tests ====================
def test_validate_complete_invoice_is_valid(
self, validator: InvoiceValidator
) -> None:
"""Complete invoice with all required fields is valid."""
fields = {
"Amount": "1200.00",
"OCR": "123456789012",
"Bankgiro": "123-4567",
}
confidence = {
"Amount": 0.95,
"OCR": 0.90,
"Bankgiro": 0.85,
}
result = validator.validate(fields, confidence)
assert result.is_valid is True
assert len([i for i in result.issues if i.severity == "error"]) == 0
def test_validate_invoice_with_payment_line_is_valid(
self, validator: InvoiceValidator
) -> None:
"""Invoice with payment_line as payment reference is valid."""
fields = {
"Amount": "500.00",
"payment_line": "# 123 # 500 00 5 > 308#",
}
confidence = {"Amount": 0.9, "payment_line": 0.85}
result = validator.validate(fields, confidence)
assert result.is_valid is True
# ==================== Invalid Invoice Tests ====================
def test_validate_missing_amount_is_invalid(
self, validator: InvoiceValidator
) -> None:
"""Missing Amount field should produce error."""
fields = {
"OCR": "123456789012",
"Bankgiro": "123-4567",
}
confidence = {"OCR": 0.9, "Bankgiro": 0.85}
result = validator.validate(fields, confidence)
assert result.is_valid is False
error_fields = [i.field for i in result.issues if i.severity == "error"]
assert "Amount" in error_fields
def test_validate_missing_payment_reference_produces_warning(
self, validator: InvoiceValidator
) -> None:
"""Missing all payment references should produce warning."""
fields = {"Amount": "1200.00"}
confidence = {"Amount": 0.9}
result = validator.validate(fields, confidence)
# Missing payment ref is warning, not error
warning_fields = [i.field for i in result.issues if i.severity == "warning"]
assert "payment_reference" in warning_fields
# ==================== Confidence Threshold Tests ====================
def test_validate_low_confidence_produces_warning(
self, validator: InvoiceValidator
) -> None:
"""Fields below confidence threshold should produce warning."""
fields = {
"Amount": "1200.00",
"OCR": "123456789012",
}
confidence = {
"Amount": 0.9,
"OCR": 0.3, # Below default threshold of 0.5
}
result = validator.validate(fields, confidence)
low_conf_warnings = [
i for i in result.issues
if i.severity == "warning" and "confidence" in i.message.lower()
]
assert len(low_conf_warnings) > 0
def test_validate_strict_threshold_more_warnings(
self, validator_strict: InvoiceValidator
) -> None:
"""Strict validator should produce more warnings."""
fields = {
"Amount": "1200.00",
"OCR": "123456789012",
}
confidence = {
"Amount": 0.7, # Below 0.8 threshold
"OCR": 0.6, # Below 0.8 threshold
}
result = validator_strict.validate(fields, confidence)
low_conf_warnings = [
i for i in result.issues
if i.severity == "warning" and "confidence" in i.message.lower()
]
assert len(low_conf_warnings) >= 2
# ==================== Edge Cases ====================
def test_validate_empty_fields_is_invalid(
self, validator: InvoiceValidator
) -> None:
"""Empty fields dict should be invalid."""
fields: dict[str, str | None] = {}
confidence: dict[str, float] = {}
result = validator.validate(fields, confidence)
assert result.is_valid is False
def test_validate_none_field_values_treated_as_missing(
self, validator: InvoiceValidator
) -> None:
"""None values should be treated as missing."""
fields = {
"Amount": None,
"OCR": "123456789012",
}
confidence = {"OCR": 0.9}
result = validator.validate(fields, confidence)
assert result.is_valid is False
error_fields = [i.field for i in result.issues if i.severity == "error"]
assert "Amount" in error_fields
def test_validate_empty_string_treated_as_missing(
self, validator: InvoiceValidator
) -> None:
"""Empty string should be treated as missing."""
fields = {
"Amount": "",
"OCR": "123456789012",
}
confidence = {"OCR": 0.9}
result = validator.validate(fields, confidence)
assert result.is_valid is False
# ==================== ValidationResult Properties ====================
def test_validation_result_is_immutable(
self, validator: InvoiceValidator
) -> None:
"""ValidationResult should be a frozen dataclass."""
fields = {"Amount": "100.00", "OCR": "123"}
confidence = {"Amount": 0.9, "OCR": 0.9}
result = validator.validate(fields, confidence)
with pytest.raises((AttributeError, TypeError)):
result.is_valid = False # type: ignore
def test_validation_result_issues_is_tuple(
self, validator: InvoiceValidator
) -> None:
"""Issues should be a tuple (immutable)."""
fields = {"Amount": "100.00"}
confidence = {"Amount": 0.9}
result = validator.validate(fields, confidence)
assert isinstance(result.issues, tuple)
def test_validation_result_has_confidence(
self, validator: InvoiceValidator
) -> None:
"""ValidationResult should have confidence score."""
fields = {"Amount": "100.00", "OCR": "123"}
confidence = {"Amount": 0.9, "OCR": 0.8}
result = validator.validate(fields, confidence)
assert hasattr(result, "confidence")
assert 0.0 <= result.confidence <= 1.0
# ==================== ValidationIssue Tests ====================
def test_validation_issue_has_required_fields(
self, validator: InvoiceValidator
) -> None:
"""ValidationIssue must have field, severity, message."""
fields: dict[str, str | None] = {}
confidence: dict[str, float] = {}
result = validator.validate(fields, confidence)
assert len(result.issues) > 0
issue = result.issues[0]
assert hasattr(issue, "field")
assert hasattr(issue, "severity")
assert hasattr(issue, "message")
assert issue.severity in ("error", "warning", "info")

View File

@@ -232,10 +232,8 @@ class TestInferenceServicePDFRendering:
@patch('backend.pipeline.pipeline.InferencePipeline')
@patch('backend.pipeline.yolo_detector.YOLODetector')
@patch('shared.pdf.renderer.render_pdf_to_images')
@patch('ultralytics.YOLO')
def test_pdf_visualization_imports_correctly(
self,
mock_yolo_class,
mock_render_pdf,
mock_yolo_detector,
mock_pipeline,
@@ -248,12 +246,22 @@ class TestInferenceServicePDFRendering:
This catches the import error we had with:
from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct)
"""
# Setup mocks
# Setup mocks for detector
mock_detector_instance = Mock()
mock_pipeline_instance = Mock()
mock_model = Mock()
mock_result = Mock()
mock_result.save = Mock()
mock_model.predict.return_value = [mock_result]
mock_detector_instance.model = mock_model
mock_yolo_detector.return_value = mock_detector_instance
# Setup mock for pipeline
mock_pipeline_instance = Mock()
mock_pipeline.return_value = mock_pipeline_instance
# Initialize service to setup _detector
inference_service.initialize()
# Create a fake PDF path
pdf_path = tmp_path / "test.pdf"
pdf_path.touch()
@@ -264,18 +272,12 @@ class TestInferenceServicePDFRendering:
img.save(image_bytes, format='PNG')
mock_render_pdf.return_value = [(1, image_bytes.getvalue())]
# Mock YOLO
mock_model_instance = Mock()
mock_result = Mock()
mock_result.save = Mock()
mock_model_instance.predict.return_value = [mock_result]
mock_yolo_class.return_value = mock_model_instance
# This should not raise ImportError
# This should not raise ImportError and should use self._detector.model
result_path = inference_service._save_pdf_visualization(pdf_path, "test123")
# Verify import was successful
# Verify import was successful and detector.model was used
mock_render_pdf.assert_called_once()
mock_model.predict.assert_called_once()
assert result_path is not None