This commit is contained in:
Yaojia Wang
2026-02-11 23:40:38 +01:00
parent f1a7bfe6b7
commit ad5ed46b4c
117 changed files with 5741 additions and 7669 deletions

View File

@@ -289,6 +289,16 @@ class ModelVersion(SQLModel, table=True):
is_active: bool = Field(default=False, index=True)
# Only one version can be active at a time for inference
# Model lineage
model_type: str = Field(default="base", max_length=20, index=True)
# "base" = trained from pretrained YOLO, "finetune" = fine-tuned from base model
base_model_version_id: UUID | None = Field(default=None, index=True)
# Points to the base model this was fine-tuned from (None for base models)
base_training_dataset_id: UUID | None = Field(default=None, index=True)
# The dataset used for original base training (for data mixing old samples)
gating_status: str = Field(default="pending", max_length=20, index=True)
# Deployment gating: pending, pass, review, reject, skipped
# Training association
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
@@ -317,6 +327,64 @@ class ModelVersion(SQLModel, table=True):
updated_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Fine-Tune Pool
# =============================================================================
class FineTunePoolEntry(SQLModel, table=True):
    """Document in the fine-tune pool for incremental model improvement.

    Entries accumulate user-reported failure documents; they must be
    human-verified (is_verified) before they count toward fine-tuning.
    """
    __tablename__ = "finetune_pool_entries"
    # Primary key for the pool entry.
    entry_id: UUID = Field(default_factory=uuid4, primary_key=True)
    # The labeled document this entry refers to.
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    # Who added the entry (the pool routes pass the admin token here).
    added_by: str | None = Field(default=None, max_length=255)
    reason: str | None = Field(default=None, max_length=255)
    # Reason: user_reported_failure, manual_addition
    # Verification workflow fields: set by verify_entry() in the repository.
    is_verified: bool = Field(default=False, index=True)
    verified_at: datetime | None = Field(default=None)
    verified_by: str | None = Field(default=None, max_length=255)
    created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Deployment Gating
# =============================================================================
class GatingResult(SQLModel, table=True):
    """Model deployment gating validation result.

    Records the two-gate validation run for a model version: Gate 1 checks
    for mAP regression on the original test set, Gate 2 checks detection
    rate on fine-tune pool documents. overall_status drives whether a
    fine-tuned model may be activated.
    """
    __tablename__ = "gating_results"
    result_id: UUID = Field(default_factory=uuid4, primary_key=True)
    # The model version that was validated.
    model_version_id: UUID = Field(foreign_key="model_versions.version_id", index=True)
    # Training task that produced the model (optional).
    task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id")
    # Gate 1: Regression validation (original test set mAP)
    gate1_status: str = Field(default="pending", max_length=20)
    # pending, pass, review, reject
    gate1_original_mAP: float | None = Field(default=None)
    gate1_new_mAP: float | None = Field(default=None)
    gate1_mAP_drop: float | None = Field(default=None)
    # Gate 2: New sample validation (detection rate on pool docs)
    gate2_status: str = Field(default="pending", max_length=20)
    gate2_detection_rate: float | None = Field(default=None)
    gate2_total_samples: int | None = Field(default=None)
    gate2_detected_samples: int | None = Field(default=None)
    # Overall
    overall_status: str = Field(default="pending", max_length=20)
    # pending, pass, review, reject
    # Manual-review audit trail (filled in when a human resolves "review").
    reviewer_notes: str | None = Field(default=None)
    reviewed_by: str | None = Field(default=None, max_length=255)
    reviewed_at: datetime | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Annotation History (v2)
# =============================================================================

View File

@@ -13,6 +13,7 @@ from backend.data.repositories.training_task_repository import TrainingTaskRepos
from backend.data.repositories.dataset_repository import DatasetRepository
from backend.data.repositories.model_version_repository import ModelVersionRepository
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
from backend.data.repositories.finetune_pool_repository import FineTunePoolRepository
__all__ = [
"BaseRepository",
@@ -23,4 +24,5 @@ __all__ = [
"DatasetRepository",
"ModelVersionRepository",
"BatchUploadRepository",
"FineTunePoolRepository",
]

View File

@@ -0,0 +1,131 @@
"""
Fine-Tune Pool Repository
Manages the fine-tune pool: accumulated user-reported failure documents
for incremental model improvement.
"""
import logging
from datetime import datetime
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import FineTunePoolEntry
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class FineTunePoolRepository(BaseRepository[FineTunePoolEntry]):
    """Repository for fine-tune pool management.

    Every method opens its own session; returned ORM objects are expunged
    (detached) so they remain usable after the session closes.
    """

    def add_document(
        self,
        document_id: str | UUID,
        added_by: str | None = None,
        reason: str | None = None,
    ) -> FineTunePoolEntry:
        """Insert a new pool entry for the given document and return it."""
        with get_session_context() as session:
            new_entry = FineTunePoolEntry(
                document_id=UUID(str(document_id)),
                added_by=added_by,
                reason=reason,
            )
            session.add(new_entry)
            session.commit()
            session.refresh(new_entry)
            # Detach so the caller can use the entry outside the session.
            session.expunge(new_entry)
            return new_entry

    def get_entry(self, entry_id: str | UUID) -> FineTunePoolEntry | None:
        """Fetch a pool entry by primary key, or None when absent."""
        with get_session_context() as session:
            found = session.get(FineTunePoolEntry, UUID(str(entry_id)))
            if found is not None:
                session.expunge(found)
            return found

    def get_by_document(self, document_id: str | UUID) -> FineTunePoolEntry | None:
        """Fetch the pool entry referencing a document, or None."""
        doc_uuid = UUID(str(document_id))
        with get_session_context() as session:
            stmt = select(FineTunePoolEntry).where(
                FineTunePoolEntry.document_id == doc_uuid
            )
            hit = session.exec(stmt).first()
            if hit is not None:
                session.expunge(hit)
            return hit

    def get_paginated(
        self,
        verified_only: bool = False,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[FineTunePoolEntry], int]:
        """Return (entries, total) ordered newest-first, with pagination."""
        with get_session_context() as session:
            stmt = select(FineTunePoolEntry)
            count_stmt = select(func.count()).select_from(FineTunePoolEntry)
            if verified_only:
                # `== True` is intentional: it builds a SQL expression.
                stmt = stmt.where(FineTunePoolEntry.is_verified == True)  # noqa: E712
                count_stmt = count_stmt.where(FineTunePoolEntry.is_verified == True)  # noqa: E712
            total = session.exec(count_stmt).one()
            page = session.exec(
                stmt.order_by(FineTunePoolEntry.created_at.desc())
                .offset(offset)
                .limit(limit)
            ).all()
            for row in page:
                session.expunge(row)
            return list(page), total

    def get_pool_count(self, verified_only: bool = True) -> int:
        """Count pool entries, optionally restricted to verified ones."""
        with get_session_context() as session:
            count_stmt = select(func.count()).select_from(FineTunePoolEntry)
            if verified_only:
                count_stmt = count_stmt.where(FineTunePoolEntry.is_verified == True)  # noqa: E712
            return session.exec(count_stmt).one()

    def get_all_document_ids(self, verified_only: bool = True) -> list[UUID]:
        """Return document IDs of all (optionally verified) pool entries."""
        with get_session_context() as session:
            stmt = select(FineTunePoolEntry.document_id)
            if verified_only:
                stmt = stmt.where(FineTunePoolEntry.is_verified == True)  # noqa: E712
            return list(session.exec(stmt).all())

    def verify_entry(
        self,
        entry_id: str | UUID,
        verified_by: str | None = None,
    ) -> FineTunePoolEntry | None:
        """Mark a pool entry as verified; returns None if it doesn't exist."""
        with get_session_context() as session:
            target = session.get(FineTunePoolEntry, UUID(str(entry_id)))
            if target is None:
                return None
            target.is_verified = True
            target.verified_at = datetime.utcnow()
            target.verified_by = verified_by
            session.add(target)
            session.commit()
            session.refresh(target)
            session.expunge(target)
            return target

    def remove_entry(self, entry_id: str | UUID) -> bool:
        """Delete a pool entry; returns True when a row was removed."""
        with get_session_context() as session:
            doomed = session.get(FineTunePoolEntry, UUID(str(entry_id)))
            if doomed is None:
                return False
            session.delete(doomed)
            session.commit()
            return True

View File

@@ -43,6 +43,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]):
training_config: dict[str, Any] | None = None,
file_size: int | None = None,
trained_at: datetime | None = None,
model_type: str = "base",
base_model_version_id: str | UUID | None = None,
base_training_dataset_id: str | UUID | None = None,
gating_status: str = "pending",
) -> ModelVersion:
"""Create a new model version."""
with get_session_context() as session:
@@ -60,6 +64,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]):
training_config=training_config,
file_size=file_size,
trained_at=trained_at,
model_type=model_type,
base_model_version_id=UUID(str(base_model_version_id)) if base_model_version_id else None,
base_training_dataset_id=UUID(str(base_training_dataset_id)) if base_training_dataset_id else None,
gating_status=gating_status,
)
session.add(model)
session.commit()

View File

@@ -40,6 +40,7 @@ from .normalizers import (
EnhancedAmountNormalizer,
EnhancedDateNormalizer,
)
from .value_selector import ValueSelector
@dataclass
@@ -169,13 +170,21 @@ class FieldExtractor:
overlap_ratio = overlap_area / token_area if token_area > 0 else 0
matching_tokens.append((token, overlap_ratio))
# Sort by overlap ratio and combine text
# Sort by overlap ratio
matching_tokens.sort(key=lambda x: -x[1])
raw_text = ' '.join(t[0].text for t in matching_tokens)
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Convert to OCRTokens for value selection, then filter
from shared.ocr.paddle_ocr import OCRToken
pdf_ocr_tokens = [
OCRToken(text=t[0].text, bbox=t[0].bbox, confidence=1.0)
for t in matching_tokens
]
value_tokens = ValueSelector.select_value_tokens(pdf_ocr_tokens, field_name)
raw_text = ' '.join(t.text for t in value_tokens)
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text
@@ -223,13 +232,14 @@ class FieldExtractor:
# Run OCR on region
ocr_tokens = self.ocr_engine.extract_from_image(region)
# Combine all OCR text
raw_text = ' '.join(t.text for t in ocr_tokens)
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Select value tokens (filter out label text)
value_tokens = ValueSelector.select_value_tokens(ocr_tokens, field_name)
raw_text = ' '.join(t.text for t in value_tokens)
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text

View File

@@ -20,26 +20,98 @@ class AmountNormalizer(BaseNormalizer):
Handles various Swedish amount formats:
- With decimal: 1 234,56 kr
- With SEK suffix: 1234.56 SEK
- Payment line kronor/ore: 590 00 (space = decimal separator)
- Multiple amounts (returns the last one, usually the total)
"""
# Payment line kronor/ore pattern: "590 00" means 590.00 SEK
# Only matches when no comma/dot is present (pure digit-space-2digit format)
_KRONOR_ORE_PATTERN = re.compile(r'^(\d+)\s+(\d{2})$')
@property
def field_name(self) -> str:
    # Identifier for the field this normalizer handles.
    return "Amount"
@classmethod
def _try_kronor_ore(cls, text: str) -> NormalizationResult | None:
    """Attempt to parse *text* as a payment-line kronor/ore amount.

    Swedish payment lines separate kronor and ore with a space instead of
    a decimal separator: "590 00" = 590.00 SEK, "15658 00" = 15658.00 SEK.
    The format is only considered when the text contains neither ',' nor
    '.' (those imply a normal amount with an explicit decimal separator).

    Returns:
        A successful NormalizationResult, or None when the text does not
        match the format (or the parsed amount is not positive).
    """
    # Any explicit decimal separator rules this format out entirely.
    if ',' in text or '.' in text:
        return None
    matched = cls._KRONOR_ORE_PATTERN.match(text.strip())
    if matched is None:
        return None
    try:
        value = float(f"{matched.group(1)}.{matched.group(2)}")
    except ValueError:
        return None
    if value > 0:
        return NormalizationResult.success(f"{value:.2f}")
    return None
@staticmethod
def _parse_amount_str(match: str) -> float | None:
"""Convert matched amount string to float, detecting European vs Anglo format.
European: 2.254,50 -> 2254.50 (dot=thousand, comma=decimal)
Anglo: 1,234.56 -> 1234.56 (comma=thousand, dot=decimal)
Swedish: 1 234,56 -> 1234.56 (space=thousand, comma=decimal)
"""
has_comma = ',' in match
has_dot = '.' in match
if has_comma and has_dot:
if match.rfind(',') > match.rfind('.'):
# European: 2.254,50
cleaned = match.replace(" ", "").replace(".", "").replace(",", ".")
else:
# Anglo: 1,234.56
cleaned = match.replace(" ", "").replace(",", "")
elif has_comma:
cleaned = match.replace(" ", "").replace(",", ".")
else:
cleaned = match.replace(" ", "")
try:
return float(cleaned)
except ValueError:
return None
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Early check: payment line kronor/ore format ("590 00" → 590.00)
kronor_ore_result = self._try_kronor_ore(text)
if kronor_ore_result is not None:
return kronor_ore_result
# Split by newlines and process line by line to get the last valid amount
lines = text.split("\n")
# Collect all valid amounts from all lines
all_amounts: list[float] = []
# Pattern for Swedish amount format (with decimals)
amount_pattern = r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?"
# Separate patterns for European and Anglo formats
# (?!\d) lookahead prevents partial matches (e.g. "1,23" in "1,234.56")
# European: dot=thousand, comma=decimal (2.254,50 or 1 234,56)
# Anglo: comma=thousand, dot=decimal (1,234.56 or 1234.56)
amount_pattern = (
r"(\d[\d\s.]*,\d{2})(?!\d)\s*(?:kr|SEK)?"
r"|"
r"(\d[\d\s,]*\.\d{2})(?!\d)\s*(?:kr|SEK)?"
)
for line in lines:
line = line.strip()
@@ -47,15 +119,13 @@ class AmountNormalizer(BaseNormalizer):
continue
# Find all amounts in this line
matches = re.findall(amount_pattern, line, re.IGNORECASE)
for match in matches:
amount_str = match.replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
all_amounts.append(amount)
except ValueError:
for m in re.finditer(amount_pattern, line, re.IGNORECASE):
match = m.group(1) or m.group(2)
if not match:
continue
amount = self._parse_amount_str(match)
if amount is not None and amount > 0:
all_amounts.append(amount)
# Return the last amount found (usually the total)
if all_amounts:
@@ -122,31 +192,33 @@ class EnhancedAmountNormalizer(AmountNormalizer):
if not text:
return NormalizationResult.failure("Empty text")
# Early check: payment line kronor/ore format ("590 00" → 590.00)
kronor_ore_result = self._try_kronor_ore(text)
if kronor_ore_result is not None:
return kronor_ore_result
# Strategy 1: Apply OCR corrections first
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Strategy 2: Look for labeled amounts (highest priority)
# Use two capture groups: group(1) = European, group(2) = Anglo
labeled_patterns = [
# Swedish patterns
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", 1.0),
# Swedish patterns ((?!\d) prevents partial matches like "1,23" in "1,234.56")
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))", 1.0),
(
r"(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})",
r"(?:moms|vat)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))",
0.8,
), # Lower priority for VAT
# Generic pattern
(r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?", 0.7),
(r"(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))\s*(?:kr|sek|kronor)?", 0.7),
]
candidates: list[tuple[float, float, int]] = []
for pattern, priority in labeled_patterns:
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
amount_str = match.group(1).replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if 0 < amount < 10_000_000: # Reasonable range
candidates.append((amount, priority, match.start()))
except ValueError:
continue
amount = self._parse_amount_str(match.group(1))
if amount is not None and 0 < amount < 10_000_000:
candidates.append((amount, priority, match.start()))
if candidates:
# Sort by priority (desc), then by position (later is usually total)

View File

@@ -301,6 +301,27 @@ class InferencePipeline:
all_extracted = []
all_ocr_text = [] # Collect OCR text for VAT extraction
# Check if PDF has readable text layer (avoids OCR for text PDFs)
from shared.pdf.extractor import PDFDocument
is_text_pdf = False
pdf_tokens_by_page: dict[int, list] = {}
try:
with PDFDocument(pdf_path) as pdf_doc:
is_text_pdf = pdf_doc.is_text_pdf()
if is_text_pdf:
for pg in range(pdf_doc.page_count):
pdf_tokens_by_page[pg] = list(
pdf_doc.extract_text_tokens(pg)
)
logger.info(
"Text PDF detected, extracted %d tokens from %d pages",
sum(len(t) for t in pdf_tokens_by_page.values()),
len(pdf_tokens_by_page),
)
except Exception as e:
logger.warning("PDF text detection failed, falling back to OCR: %s", e)
is_text_pdf = False
# Process each page
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
# Convert to numpy array
@@ -313,7 +334,17 @@ class InferencePipeline:
# Extract fields from detections
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
if is_text_pdf and page_no in pdf_tokens_by_page:
extracted = self.extractor.extract_from_detection_with_pdf(
detection,
pdf_tokens_by_page[page_no],
image.width,
image.height,
)
else:
extracted = self.extractor.extract_from_detection(
detection, image_array
)
all_extracted.append(extracted)
# Collect full-page OCR text for VAT extraction (only if business features enabled)

View File

@@ -0,0 +1,172 @@
"""
Value Selector Module.
Selects the most likely value token(s) from OCR output per field type,
filtering out label text before sending to normalizer.
Stateless and pure -- easy to test, no side effects.
"""
import re
from collections.abc import Callable
from typing import Final

from shared.ocr.paddle_ocr import OCRToken
# Swedish label keywords commonly found near field values.
# A token equal to one of these (lower-cased, trailing colon removed) is
# treated as a label rather than a value and can be filtered out.
LABEL_KEYWORDS: Final[frozenset[str]] = frozenset({
    "fakturanummer", "fakturanr", "fakturadatum", "forfallodag", "forfalldatum",
    "bankgiro", "plusgiro", "bg", "pg", "ocr", "belopp", "summa",
    "total", "att", "betala", "kundnummer", "organisationsnummer",
    "org", "nr", "datum", "nummer", "ref", "referens",
    "momsreg", "vat", "moms", "sek", "kr",
    "org.nr", "kund", "faktura", "invoice",
})

# Regex patterns describing the expected value shape per field type.
# Used via .search() on stripped token text (see _select_by_pattern).
_DATE_PATTERN = re.compile(
    r"\d{4}[-./]\d{2}[-./]\d{2}"  # 2024-01-15, 2024.01.15
    r"|"
    r"\d{2}[-./]\d{2}[-./]\d{4}"  # 15/01/2024
    r"|"
    r"\d{8}"  # 20240115 -- NOTE: also matches any other 8-digit run
)
_AMOUNT_PATTERN = re.compile(
    r"\d[\d\s.]*,\d{2}(?:\s*(?:kr|SEK))?$"  # European: 2.254,50 SEK
    r"|"
    r"\d[\d\s,]*\.\d{2}(?:\s*(?:kr|SEK))?$"  # Anglo: 1,234.56 SEK
)
_BANKGIRO_PATTERN = re.compile(
    r"^\d{3,4}-\d{4}$"  # 123-4567
    r"|"
    r"^\d{7,8}$"  # 1234567 or 12345678
)
_PLUSGIRO_PATTERN = re.compile(
    r"^\d+-\d$"  # 12345-6
    r"|"
    r"^\d{2,8}$"  # 1234567
)
_ORG_NUMBER_PATTERN = re.compile(
    r"\d{6}-?\d{4}"  # 556123-4567 or 5561234567
)
def _is_label(text: str) -> bool:
    """Check if token text is a known Swedish label keyword.

    Whitespace is stripped *before* removing a trailing colon: the previous
    order (rstrip(":") then strip()) left the colon in place for tokens with
    trailing whitespace such as "Belopp: ", so they were never recognized.
    A final strip() handles "belopp : " (space before the colon).
    """
    cleaned = text.strip().lower().rstrip(":").strip()
    return cleaned in LABEL_KEYWORDS
def _count_digits(text: str) -> int:
"""Count digit characters in text."""
return sum(c.isdigit() for c in text)
class ValueSelector:
    """Selects value token(s) from OCR output, filtering label text.

    All methods are static and side-effect free. As a safety net the
    public entry point always falls back to the full token list when a
    field-specific filter yields nothing, so selection can never make
    results worse than using the raw tokens.
    """

    @staticmethod
    def select_value_tokens(
        tokens: list[OCRToken],
        field_name: str,
    ) -> list[OCRToken]:
        """Pick the most likely value token(s) for *field_name*.

        Args:
            tokens: OCR tokens from the detected region.
            field_name: Normalized field name (e.g. "InvoiceDate", "Amount").

        Returns:
            Filtered token list; falls back to all input tokens when the
            field-specific filter matches nothing.
        """
        if not tokens:
            return []
        chooser = _FIELD_SELECTORS.get(field_name, _fallback_selector)
        picked = chooser(tokens)
        # Never hand back an empty result when input tokens existed.
        return picked if picked else list(tokens)

    @staticmethod
    def _select_date(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens that look like a date."""
        return _select_by_pattern(tokens, _DATE_PATTERN)

    @staticmethod
    def _select_amount(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens that look like a monetary amount."""
        return _select_by_pattern(tokens, _AMOUNT_PATTERN)

    @staticmethod
    def _select_bankgiro(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens matching the bankgiro number format."""
        return _select_by_pattern(tokens, _BANKGIRO_PATTERN)

    @staticmethod
    def _select_plusgiro(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens matching the plusgiro number format."""
        return _select_by_pattern(tokens, _PLUSGIRO_PATTERN)

    @staticmethod
    def _select_org_number(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens matching the organisation number format."""
        return _select_by_pattern(tokens, _ORG_NUMBER_PATTERN)

    @staticmethod
    def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]:
        """Select the token with the longest digit sequence (min 5 digits)."""
        scored = [(tok, _count_digits(tok.text)) for tok in tokens]
        eligible = [(tok, n) for tok, n in scored if n >= 5]
        if not eligible:
            return []
        # max() keeps the first maximum, preserving first-wins tie-breaking.
        winner, _ = max(eligible, key=lambda pair: pair[1])
        return [winner]

    @staticmethod
    def _select_by_label_exclusion(tokens: list[OCRToken]) -> list[OCRToken]:
        """Drop known label keywords; keep every remaining token."""
        return [tok for tok in tokens if not _is_label(tok.text)]

    @staticmethod
    def _select_payment_line(tokens: list[OCRToken]) -> list[OCRToken]:
        """Payment lines need their full text, so keep every token."""
        return list(tokens)
def _select_by_pattern(
    tokens: list[OCRToken],
    pattern: re.Pattern[str],
) -> list[OCRToken]:
    """Keep only the tokens whose stripped text matches *pattern*."""
    return [candidate for candidate in tokens
            if pattern.search(candidate.text.strip()) is not None]
def _fallback_selector(tokens: list[OCRToken]) -> list[OCRToken]:
    """Identity selector: hand back a shallow copy of all tokens."""
    return tokens[:]
# Map field names to selector functions.
# Fix: the annotation previously used the *builtin* `callable` as a type
# argument (`dict[str, callable]`), which is not a valid type; the precise
# selector signature is spelled out with collections.abc.Callable instead.
# NOTE(review): the mixed PascalCase / snake_case keys must match the field
# names produced by CLASS_TO_FIELD in the extractor — confirm against it.
_FIELD_SELECTORS: Final[dict[str, Callable[[list[OCRToken]], list[OCRToken]]]] = {
    "InvoiceDate": ValueSelector._select_date,
    "InvoiceDueDate": ValueSelector._select_date,
    "Amount": ValueSelector._select_amount,
    "Bankgiro": ValueSelector._select_bankgiro,
    "Plusgiro": ValueSelector._select_plusgiro,
    "OCR": ValueSelector._select_ocr_number,
    "InvoiceNumber": ValueSelector._select_by_label_exclusion,
    "supplier_org_number": ValueSelector._select_org_number,
    "customer_number": ValueSelector._select_by_label_exclusion,
    "payment_line": ValueSelector._select_payment_line,
}

View File

@@ -12,6 +12,7 @@ from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
from .models import register_model_routes
from .pool import register_pool_routes
def create_training_router() -> APIRouter:
@@ -23,6 +24,7 @@ def create_training_router() -> APIRouter:
register_export_routes(router)
register_dataset_routes(router)
register_model_routes(router)
register_pool_routes(router)
return router

View File

@@ -12,6 +12,7 @@ from backend.web.core.auth import (
AnnotationRepoDep,
ModelVersionRepoDep,
TrainingTaskRepoDep,
FineTunePoolRepoDep,
)
from backend.web.schemas.admin import (
DatasetCreateRequest,
@@ -233,6 +234,7 @@ def register_dataset_routes(router: APIRouter) -> None:
datasets_repo: DatasetRepoDep,
models: ModelVersionRepoDep,
tasks: TrainingTaskRepoDep,
pool: FineTunePoolRepoDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset.
@@ -261,13 +263,39 @@ def register_dataset_routes(router: APIRouter) -> None:
status_code=404,
detail=f"Base model version not found: {base_model_version_id}",
)
# Chain prevention: never fine-tune from a fine-tuned model
if getattr(base_model, "model_type", "base") == "finetune":
original_base_id = getattr(base_model, "base_model_version_id", None)
raise HTTPException(
status_code=400,
detail=(
f"Cannot chain fine-tunes. Model {base_model.version} is already "
f"a fine-tuned model. Select the original base model instead"
f"{f' (base_model_version_id: {original_base_id})' if original_base_id else ''}."
),
)
# Pool threshold: require minimum verified pool entries for fine-tuning
from backend.web.services.data_mixer import MIN_POOL_SIZE
verified_count = pool.get_pool_count(verified_only=True)
if verified_count < MIN_POOL_SIZE:
raise HTTPException(
status_code=400,
detail=(
f"Fine-tuning requires at least {MIN_POOL_SIZE} verified pool entries "
f"(currently {verified_count}). Add more documents to the fine-tune "
f"pool and verify them before starting fine-tuning."
),
)
# Store the resolved model path for the training worker
config_dict["base_model_path"] = base_model.model_path
config_dict["base_model_version"] = base_model.version
logger.info(
"Incremental training: using model %s (%s) as base",
"Fine-tuning: using base model %s (%s) with %d verified pool entries",
base_model.version,
base_model.model_path,
verified_count,
)
task_id = tasks.create(

View File

@@ -124,16 +124,11 @@ def register_export_routes(router: APIRouter) -> None:
x1 = ann.x_center * img_width + half_w
y1 = ann.y_center * img_height + half_h
# Use manual_mode for manual/imported annotations
manual_mode = ann.source in ("manual", "imported")
# Apply field-specific bbox expansion
# Apply uniform bbox expansion
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=ann.class_name,
manual_mode=manual_mode,
)
# Convert back to normalized YOLO format

View File

@@ -88,6 +88,9 @@ def register_model_routes(router: APIRouter) -> None:
name=m.name,
status=m.status,
is_active=m.is_active,
model_type=getattr(m, "model_type", "base"),
base_model_version_id=str(m.base_model_version_id) if getattr(m, "base_model_version_id", None) else None,
gating_status=getattr(m, "gating_status", "pending"),
metrics_mAP=m.metrics_mAP,
document_count=m.document_count,
trained_at=m.trained_at,
@@ -121,6 +124,9 @@ def register_model_routes(router: APIRouter) -> None:
name=model.name,
status=model.status,
is_active=model.is_active,
model_type=getattr(model, "model_type", "base"),
base_model_version_id=str(model.base_model_version_id) if getattr(model, "base_model_version_id", None) else None,
gating_status=getattr(model, "gating_status", "pending"),
metrics_mAP=model.metrics_mAP,
document_count=model.document_count,
trained_at=model.trained_at,
@@ -153,6 +159,10 @@ def register_model_routes(router: APIRouter) -> None:
model_path=model.model_path,
status=model.status,
is_active=model.is_active,
model_type=getattr(model, "model_type", "base"),
base_model_version_id=str(model.base_model_version_id) if getattr(model, "base_model_version_id", None) else None,
base_training_dataset_id=str(model.base_training_dataset_id) if getattr(model, "base_training_dataset_id", None) else None,
gating_status=getattr(model, "gating_status", "pending"),
task_id=str(model.task_id) if model.task_id else None,
dataset_id=str(model.dataset_id) if model.dataset_id else None,
metrics_mAP=model.metrics_mAP,
@@ -209,6 +219,25 @@ def register_model_routes(router: APIRouter) -> None:
) -> ModelVersionResponse:
"""Activate a model version for inference."""
_validate_uuid(version_id, "version_id")
# Check gating status before activation (for fine-tuned models)
pre_check = models.get(version_id)
if not pre_check:
raise HTTPException(status_code=404, detail="Model version not found")
model_type = getattr(pre_check, "model_type", "base")
gating_status = getattr(pre_check, "gating_status", "skipped")
if model_type == "finetune":
if gating_status == "reject":
raise HTTPException(
status_code=400,
detail="Model failed deployment gating validation. Cannot activate a rejected model.",
)
if gating_status == "pending":
raise HTTPException(
status_code=400,
detail="Model gating validation not yet completed. Wait for validation to finish.",
)
model = models.activate(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
@@ -227,6 +256,8 @@ def register_model_routes(router: APIRouter) -> None:
message = "Model version activated for inference"
if model_reloaded:
message += " (model reloaded)"
if gating_status == "review":
message += " (WARNING: gating status is REVIEW - manual inspection recommended)"
return ModelVersionResponse(
version_id=str(model.version_id),

View File

@@ -0,0 +1,159 @@
"""Fine-Tune Pool Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from backend.web.core.auth import AdminTokenDep, FineTunePoolRepoDep, DocumentRepoDep
from backend.web.schemas.admin.pool import (
PoolAddRequest,
PoolEntryItem,
PoolEntryResponse,
PoolListResponse,
PoolStatsResponse,
)
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_pool_routes(router: APIRouter) -> None:
    """Register fine-tune pool endpoints on the router.

    Endpoints: add a document to the pool, list entries, pool statistics,
    verify an entry, and remove an entry.
    """

    @router.post(
        "/pool",
        response_model=PoolEntryResponse,
        summary="Add document to fine-tune pool",
        description="Add a labeled document to the fine-tune pool for future fine-tuning.",
    )
    async def add_to_pool(
        request: PoolAddRequest,
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
        docs: DocumentRepoDep,
    ) -> PoolEntryResponse:
        """Add a document to the fine-tune pool.

        Raises 404 when the document does not exist and 409 when it is
        already in the pool.
        """
        _validate_uuid(request.document_id, "document_id")
        # Verify document exists before creating a pool entry for it.
        doc = docs.get(request.document_id)
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")
        # Check if already in pool (best-effort duplicate guard).
        existing = pool.get_by_document(request.document_id)
        if existing:
            raise HTTPException(
                status_code=409,
                detail=f"Document already in fine-tune pool (entry_id: {existing.entry_id})",
            )
        entry = pool.add_document(
            document_id=request.document_id,
            added_by=admin_token,
            reason=request.reason,
        )
        return PoolEntryResponse(
            entry_id=str(entry.entry_id),
            message="Document added to fine-tune pool",
        )

    @router.get(
        "/pool",
        response_model=PoolListResponse,
        summary="List fine-tune pool entries",
    )
    async def list_pool_entries(
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
        verified_only: Annotated[bool, Query(description="Filter to verified only")] = False,
        limit: Annotated[int, Query(ge=1, le=100)] = 20,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> PoolListResponse:
        """List entries in the fine-tune pool (newest first, paginated)."""
        entries, total = pool.get_paginated(
            verified_only=verified_only,
            limit=limit,
            offset=offset,
        )
        return PoolListResponse(
            total=total,
            limit=limit,
            offset=offset,
            entries=[
                PoolEntryItem(
                    entry_id=str(e.entry_id),
                    document_id=str(e.document_id),
                    added_by=e.added_by,
                    reason=e.reason,
                    is_verified=e.is_verified,
                    verified_at=e.verified_at,
                    verified_by=e.verified_by,
                    created_at=e.created_at,
                )
                for e in entries
            ],
        )

    @router.get(
        "/pool/stats",
        response_model=PoolStatsResponse,
        summary="Get fine-tune pool statistics",
    )
    async def get_pool_stats(
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
    ) -> PoolStatsResponse:
        """Get statistics about the fine-tune pool.

        `is_ready` uses the same MIN_POOL_SIZE threshold that the training
        endpoint enforces, so the readiness flag and the fine-tuning gate
        can never drift apart (it was previously hard-coded to 50).
        """
        # Local import mirrors the dataset routes and avoids a module-level
        # dependency on the training service.
        from backend.web.services.data_mixer import MIN_POOL_SIZE

        total = pool.get_pool_count(verified_only=False)
        verified = pool.get_pool_count(verified_only=True)
        return PoolStatsResponse(
            total_entries=total,
            verified_entries=verified,
            unverified_entries=total - verified,
            is_ready=verified >= MIN_POOL_SIZE,
        )

    @router.post(
        "/pool/{entry_id}/verify",
        response_model=PoolEntryResponse,
        summary="Verify a pool entry",
        description="Mark a pool entry as verified (human-reviewed).",
    )
    async def verify_pool_entry(
        entry_id: str,
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
    ) -> PoolEntryResponse:
        """Mark a pool entry as verified; 404 when the entry is missing."""
        _validate_uuid(entry_id, "entry_id")
        entry = pool.verify_entry(entry_id, verified_by=admin_token)
        if not entry:
            raise HTTPException(status_code=404, detail="Pool entry not found")
        return PoolEntryResponse(
            entry_id=str(entry.entry_id),
            message="Pool entry verified",
        )

    @router.delete(
        "/pool/{entry_id}",
        summary="Remove from fine-tune pool",
    )
    async def remove_from_pool(
        entry_id: str,
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
    ) -> dict:
        """Remove a document from the fine-tune pool; 404 when missing."""
        _validate_uuid(entry_id, "entry_id")
        success = pool.remove_entry(entry_id)
        if not success:
            raise HTTPException(status_code=404, detail="Pool entry not found")
        return {"message": "Entry removed from fine-tune pool"}

View File

@@ -17,6 +17,7 @@ from backend.data.repositories import (
TrainingTaskRepository,
ModelVersionRepository,
BatchUploadRepository,
FineTunePoolRepository,
)
@@ -95,6 +96,12 @@ def get_batch_upload_repository() -> BatchUploadRepository:
return BatchUploadRepository()
@lru_cache(maxsize=1)
def get_finetune_pool_repository() -> FineTunePoolRepository:
    """Get the FineTunePoolRepository instance (thread-safe singleton).

    lru_cache(maxsize=1) memoizes the first constructed instance, so every
    caller (and every FastAPI dependency resolution) shares one repository.
    Tests reset it via reset_all_repositories() -> cache_clear().
    """
    return FineTunePoolRepository()
def reset_all_repositories() -> None:
"""Reset all repository instances (for testing)."""
get_token_repository.cache_clear()
@@ -104,6 +111,7 @@ def reset_all_repositories() -> None:
get_training_task_repository.cache_clear()
get_model_version_repository.cache_clear()
get_batch_upload_repository.cache_clear()
get_finetune_pool_repository.cache_clear()
# Repository dependency type aliases
@@ -113,3 +121,4 @@ DatasetRepoDep = Annotated[DatasetRepository, Depends(get_dataset_repository)]
TrainingTaskRepoDep = Annotated[TrainingTaskRepository, Depends(get_training_task_repository)]
ModelVersionRepoDep = Annotated[ModelVersionRepository, Depends(get_model_version_repository)]
BatchUploadRepoDep = Annotated[BatchUploadRepository, Depends(get_batch_upload_repository)]
FineTunePoolRepoDep = Annotated[FineTunePoolRepository, Depends(get_finetune_pool_repository)]

View File

@@ -16,6 +16,7 @@ from backend.data.repositories import (
ModelVersionRepository,
DocumentRepository,
AnnotationRepository,
FineTunePoolRepository,
)
from backend.web.core.task_interface import TaskRunner, TaskStatus
from backend.web.services.storage_helpers import get_storage_helper
@@ -47,6 +48,7 @@ class TrainingScheduler(TaskRunner):
self._model_versions = ModelVersionRepository()
self._documents = DocumentRepository()
self._annotations = AnnotationRepository()
self._pool = FineTunePoolRepository()
@property
def name(self) -> str:
@@ -168,7 +170,7 @@ class TrainingScheduler(TaskRunner):
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
model_name = config.get("model_name", "yolo26s.pt")
base_model_path = config.get("base_model_path") # For incremental training
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
@@ -182,14 +184,19 @@ class TrainingScheduler(TaskRunner):
augmentation_multiplier = config.get("augmentation_multiplier", 2)
# Determine which model to use as base
if base_model_path:
# Incremental training: use existing trained model
is_finetune = bool(base_model_path)
if is_finetune:
# Fine-tuning: use existing trained model as base
if not Path(base_model_path).exists():
raise ValueError(f"Base model not found: {base_model_path}")
effective_model = base_model_path
# Override parameters for fine-tuning (best practices)
epochs = config.get("epochs", 10)
learning_rate = config.get("learning_rate", 0.001)
self._training_tasks.add_log(
task_id, "INFO",
f"Incremental training from: {base_model_path}",
f"Fine-tuning from: {base_model_path} "
f"(epochs={epochs}, freeze=10, cos_lr=true)",
)
else:
# Train from pretrained model
@@ -229,10 +236,16 @@ class TrainingScheduler(TaskRunner):
f"(total: {aug_result['total_images']})",
)
# Build mixed dataset for fine-tuning (pool samples + old data)
if is_finetune and dataset_id:
data_yaml, dataset_path = self._build_mixed_finetune_dataset(
task_id, dataset_path, data_yaml,
)
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=effective_model, # Use base model or pretrained model
model_name=effective_model,
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
@@ -240,6 +253,8 @@ class TrainingScheduler(TaskRunner):
learning_rate=learning_rate,
device=device,
project_name=project_name,
freeze=10 if is_finetune else 0,
cos_lr=is_finetune,
)
# Update task with results
@@ -261,13 +276,23 @@ class TrainingScheduler(TaskRunner):
)
# Auto-create model version for the completed training
self._create_model_version_from_training(
model_version = self._create_model_version_from_training(
task_id=task_id,
config=config,
dataset_id=dataset_id,
result=result,
)
# Auto-run gating validation for fine-tuned models
if is_finetune and model_version:
self._run_gating_after_finetune(
task_id=task_id,
model_version=model_version,
config=config,
data_yaml=data_yaml,
result=result,
)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._training_tasks.add_log(task_id, "ERROR", f"Training failed: {e}")
@@ -286,13 +311,16 @@ class TrainingScheduler(TaskRunner):
config: dict[str, Any],
dataset_id: str | None,
result: dict[str, Any],
) -> None:
"""Create a model version entry from completed training."""
) -> Any | None:
"""Create a model version entry from completed training.
Returns the created model version, or None on failure.
"""
try:
model_path = result.get("model_path")
if not model_path:
logger.warning(f"No model path in training result for task {task_id}")
return
return None
# Get task info for name
task = self._training_tasks.get(task_id)
@@ -322,6 +350,12 @@ class TrainingScheduler(TaskRunner):
if dataset:
document_count = dataset.total_documents
# Determine model lineage
is_finetune = bool(config.get("base_model_path"))
model_type = "finetune" if is_finetune else "base"
base_model_version_id = config.get("base_model_version_id") if is_finetune else None
gating_status = "pending" if is_finetune else "skipped"
# Create model version
model_version = self._model_versions.create(
version=version,
@@ -337,6 +371,10 @@ class TrainingScheduler(TaskRunner):
training_config=config,
file_size=file_size,
trained_at=datetime.utcnow(),
model_type=model_type,
base_model_version_id=base_model_version_id,
base_training_dataset_id=dataset_id if not is_finetune else None,
gating_status=gating_status,
)
logger.info(
@@ -349,12 +387,105 @@ class TrainingScheduler(TaskRunner):
f"Model version {version} created (mAP: {mAP_display})",
)
return model_version
except Exception as e:
logger.error(f"Failed to create model version for task {task_id}: {e}")
self._training_tasks.add_log(
task_id, "WARNING",
f"Failed to auto-create model version: {e}",
)
return None
def _build_mixed_finetune_dataset(
    self,
    task_id: str,
    base_dataset_path: Path,
    original_data_yaml: str,
) -> tuple[str, Path]:
    """Build a mixed dataset for fine-tuning.

    Combines verified pool samples with randomly sampled old training data.

    Args:
        task_id: Training task UUID (also keys the mixed-output directory).
        base_dataset_path: Root directory of the exported training dataset.
        original_data_yaml: Path to the unmixed dataset's data.yaml.

    Returns:
        Tuple of (data_yaml path, dataset_path) for the mixed dataset.
        Falls back to original if mixing fails or pool is empty.
    """
    try:
        # Lazy import keeps the scheduler importable without the mixer's deps.
        from backend.web.services.data_mixer import build_mixed_dataset

        # Only human-verified pool entries participate in fine-tuning.
        pool_doc_ids = self._pool.get_all_document_ids(verified_only=True)
        if not pool_doc_ids:
            self._training_tasks.add_log(
                task_id, "INFO",
                "No verified pool entries found, using original dataset",
            )
            return original_data_yaml, base_dataset_path

        # Mixed dataset lives beside the base dataset, keyed by task-id prefix.
        mixed_output = base_dataset_path.parent / f"mixed_{task_id[:8]}"
        mix_result = build_mixed_dataset(
            pool_document_ids=pool_doc_ids,
            base_dataset_path=base_dataset_path,
            output_dir=mixed_output,
        )

        self._training_tasks.add_log(
            task_id, "INFO",
            f"Data mixing: {mix_result['new_images']} new + "
            f"{mix_result['old_images']} old = {mix_result['total_images']} total "
            f"(ratio: {mix_result['mixing_ratio']}x)",
        )
        return mix_result["data_yaml"], mixed_output
    except Exception as e:
        # Mixing is best-effort: any failure falls back to the unmixed dataset
        # so training still proceeds.
        logger.error(f"Data mixing failed for task {task_id}: {e}")
        self._training_tasks.add_log(
            task_id, "WARNING",
            f"Data mixing failed: {e}. Using original dataset.",
        )
        return original_data_yaml, base_dataset_path
def _run_gating_after_finetune(
    self,
    task_id: str,
    model_version: Any,
    config: dict[str, Any],
    data_yaml: str,
    result: dict[str, Any],
) -> None:
    """Run gating validation after a fine-tune training completes.

    Best-effort: failures are logged but never propagate, so the model
    version simply remains in its initial 'pending' gating state.

    Args:
        task_id: Training task UUID (for logs).
        model_version: The freshly created ModelVersion record.
        config: Training config; supplies base_model_version_id.
        data_yaml: Dataset YAML used for the fine-tune (validation runs on it).
        result: Training result dict; supplies model_path.
    """
    try:
        # Lazy import: gating pulls in the YOLO validation stack.
        from backend.web.services.gating_validator import run_gating_validation

        model_path = result.get("model_path")
        base_model_version_id = config.get("base_model_version_id")
        version_id = str(model_version.version_id)

        self._training_tasks.add_log(
            task_id, "INFO", "Running deployment gating validation...",
        )

        gating_result = run_gating_validation(
            model_version_id=version_id,
            new_model_path=model_path,
            base_model_version_id=base_model_version_id,
            data_yaml=data_yaml,
            task_id=task_id,
        )

        self._training_tasks.add_log(
            task_id, "INFO",
            f"Gating result: {gating_result.overall_status} "
            f"(gate1={gating_result.gate1_status}, gate2={gating_result.gate2_status})",
        )
    except Exception as e:
        logger.error(f"Gating validation failed for task {task_id}: {e}")
        self._training_tasks.add_log(
            task_id, "WARNING",
            f"Gating validation failed: {e}. Model remains in 'pending' state.",
        )
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
@@ -456,6 +587,8 @@ names: {list(FIELD_CLASSES.values())}
learning_rate: float,
device: str,
project_name: str,
freeze: int = 0,
cos_lr: bool = False,
) -> dict[str, Any]:
"""Run YOLO training using shared trainer."""
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
@@ -479,6 +612,8 @@ names: {list(FIELD_CLASSES.values())}
project="runs/train",
name=f"{project_name}/task_{task_id[:8]}",
workers=0,
freeze=freeze,
cos_lr=cos_lr,
)
# Run training using shared trainer

View File

@@ -11,6 +11,7 @@ from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403
from .pool import * # noqa: F401, F403
from .dashboard import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse

View File

@@ -40,6 +40,9 @@ class ModelVersionItem(BaseModel):
name: str = Field(..., description="Model name")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
model_type: str = Field(default="base", description="Model type (base or finetune)")
base_model_version_id: str | None = Field(None, description="Base model version UUID (for fine-tuned models)")
gating_status: str = Field(default="pending", description="Deployment gating status")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
document_count: int = Field(..., description="Documents used in training")
trained_at: datetime | None = Field(None, description="Training completion time")
@@ -66,6 +69,10 @@ class ModelVersionDetailResponse(BaseModel):
model_path: str = Field(..., description="Path to model file")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
model_type: str = Field(default="base", description="Model type (base or finetune)")
base_model_version_id: str | None = Field(None, description="Base model version UUID")
base_training_dataset_id: str | None = Field(None, description="Base training dataset UUID")
gating_status: str = Field(default="pending", description="Deployment gating status")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")

View File

@@ -0,0 +1,72 @@
"""Admin Fine-Tune Pool Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
class PoolAddRequest(BaseModel):
    """Request to add a document to the fine-tune pool."""

    document_id: str = Field(..., description="Document UUID to add")
    reason: str = Field(
        default="user_reported_failure",
        description="Reason: user_reported_failure, manual_addition",
    )


class PoolEntryItem(BaseModel):
    """Fine-tune pool entry (mirrors FineTunePoolEntry rows)."""

    entry_id: str = Field(..., description="Entry UUID")
    document_id: str = Field(..., description="Document UUID")
    added_by: str | None = Field(None, description="Who added this entry")
    reason: str | None = Field(None, description="Reason for adding")
    is_verified: bool = Field(..., description="Whether entry has been verified")
    verified_at: datetime | None = Field(None, description="Verification timestamp")
    verified_by: str | None = Field(None, description="Who verified")
    created_at: datetime = Field(..., description="Creation timestamp")


class PoolListResponse(BaseModel):
    """Paginated pool entry list."""

    total: int = Field(..., ge=0, description="Total entries")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    entries: list[PoolEntryItem] = Field(default_factory=list, description="Pool entries")


class PoolStatsResponse(BaseModel):
    """Pool statistics and fine-tune readiness."""

    total_entries: int = Field(..., ge=0, description="Total pool entries")
    verified_entries: int = Field(..., ge=0, description="Verified entries")
    unverified_entries: int = Field(..., ge=0, description="Unverified entries")
    is_ready: bool = Field(..., description="Whether pool has >= 50 verified entries for fine-tuning")
    # Informational constant for clients; keep in sync with the is_ready threshold.
    min_required: int = Field(default=50, description="Minimum verified entries required")


class PoolEntryResponse(BaseModel):
    """Response for pool entry operation."""

    entry_id: str = Field(..., description="Entry UUID")
    message: str = Field(..., description="Status message")


class GatingResultItem(BaseModel):
    """Gating validation result (two-gate deployment check)."""

    result_id: str = Field(..., description="Result UUID")
    model_version_id: str = Field(..., description="Model version UUID")
    # Gate 1: regression vs the base model on the validation set.
    gate1_status: str = Field(..., description="Gate 1 status")
    gate1_original_mAP: float | None = Field(None, description="Original model mAP")
    gate1_new_mAP: float | None = Field(None, description="New model mAP")
    gate1_mAP_drop: float | None = Field(None, description="mAP drop percentage")
    # Gate 2: detection quality on the new (pool) samples.
    gate2_status: str = Field(..., description="Gate 2 status")
    gate2_detection_rate: float | None = Field(None, description="Detection rate on new samples")
    gate2_total_samples: int | None = Field(None, description="Total new samples tested")
    gate2_detected_samples: int | None = Field(None, description="Samples correctly detected")
    overall_status: str = Field(..., description="Overall gating status")
    reviewer_notes: str | None = Field(None, description="Reviewer notes")
    created_at: datetime = Field(..., description="Creation timestamp")

View File

@@ -12,7 +12,7 @@ from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
model_name: str = Field(default="yolo26s.pt", description="Base model name (used if no base_model_version_id)")
base_model_version_id: str | None = Field(
default=None,
description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",

View File

@@ -0,0 +1,199 @@
"""
Data Mixing Service
Mixes fine-tune pool samples with randomly sampled old training data
following best practices for incremental YOLO fine-tuning.
"""
import logging
import random
import shutil
from pathlib import Path
from typing import Any
from uuid import UUID
logger = logging.getLogger(__name__)
# Mixing ratios: (max_new_samples, old_data_multiplier)
# Fewer new samples require higher old data ratio to prevent catastrophic forgetting
MIXING_RATIOS: list[tuple[int, int]] = [
    (10, 50),   # <= 10 new samples: 50x old
    (50, 20),   # <= 50 new samples: 20x old
    (200, 10),  # <= 200 new samples: 10x old
    (500, 5),   # <= 500 new samples: 5x old
]

# Default multiplier for 500+ samples
DEFAULT_MULTIPLIER = 5

# Maximum old samples to include (cap for performance)
MAX_OLD_SAMPLES = 3000

# Minimum pool size for fine-tuning
MIN_POOL_SIZE = 50


def get_mixing_ratio(new_sample_count: int) -> int:
    """Return the old-data multiplier for a given number of new samples.

    Smaller pools get a proportionally larger share of old data to guard
    against catastrophic forgetting during fine-tuning.

    Args:
        new_sample_count: Number of new samples in the fine-tune pool.

    Returns:
        Multiplier for old data sampling.
    """
    return next(
        (
            multiplier
            for threshold, multiplier in MIXING_RATIOS
            if new_sample_count <= threshold
        ),
        DEFAULT_MULTIPLIER,
    )
def build_mixed_dataset(
    pool_document_ids: list[UUID],
    base_dataset_path: Path,
    output_dir: Path,
    seed: int = 42,
) -> dict[str, Any]:
    """Build a mixed dataset for fine-tuning.

    Combines ALL fine-tune pool samples with randomly sampled old data
    from the base training dataset. Pool images are excluded from the
    old-data candidates so a pool sample can never be drawn as "old":
    previously that could consume the old-sample budget, miscount the
    new/old split, and assign pool images to an arbitrary split.

    Args:
        pool_document_ids: Document IDs from the fine-tune pool.
        base_dataset_path: Path to the base training dataset directory.
        output_dir: Output directory for the mixed dataset.
        seed: Random seed for reproducible sampling.

    Returns:
        Dictionary with dataset info (data_yaml path, counts).
    """
    new_count = len(pool_document_ids)
    multiplier = get_mixing_ratio(new_count)
    old_target = min(new_count * multiplier, MAX_OLD_SAMPLES)

    logger.info(
        "Building mixed dataset: %d new samples, %dx multiplier, "
        "targeting %d old samples",
        new_count, multiplier, old_target,
    )

    # Create output directory structure
    output_dir.mkdir(parents=True, exist_ok=True)
    for split in ("train", "val"):
        (output_dir / "images" / split).mkdir(parents=True, exist_ok=True)
        (output_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

    rng = random.Random(seed)

    # Locate the new (pool) images first so they can be excluded from the
    # old-data candidates below.
    pool_id_strs = {str(doc_id) for doc_id in pool_document_ids}
    new_images = _find_pool_images(base_dataset_path, pool_id_strs)
    new_image_set = set(new_images)

    # Collect old training images from base dataset, minus pool images
    old_candidates = [
        img
        for img in (
            _collect_images(base_dataset_path / "images" / "train")
            + _collect_images(base_dataset_path / "images" / "val")
        )
        if img not in new_image_set
    ]

    # Randomly sample old data
    if len(old_candidates) > old_target:
        sampled_old = rng.sample(old_candidates, old_target)
    else:
        sampled_old = list(old_candidates)

    # Split old samples: 80% train, 20% val
    rng.shuffle(sampled_old)
    old_train_count = int(len(sampled_old) * 0.8)
    old_copied = _copy_image_label_pairs(
        output_dir,
        train=sampled_old[:old_train_count],
        val=sampled_old[old_train_count:],
    )

    # New pool samples: same 80/20 train/val split
    rng.shuffle(new_images)
    new_train_count = int(len(new_images) * 0.8)
    new_copied = _copy_image_label_pairs(
        output_dir,
        train=new_images[:new_train_count],
        val=new_images[new_train_count:],
    )

    # Generate data.yaml (lazy import: shared.fields only needed here)
    from shared.fields import FIELD_CLASSES

    yaml_path = output_dir / "data.yaml"
    yaml_content = (
        f"path: {output_dir.absolute()}\n"
        f"train: images/train\n"
        f"val: images/val\n"
        f"\n"
        f"nc: {len(FIELD_CLASSES)}\n"
        f"names: {list(FIELD_CLASSES.values())}\n"
    )
    yaml_path.write_text(yaml_content)

    total_images = old_copied + new_copied
    logger.info(
        "Mixed dataset built: %d old + %d new = %d total images",
        old_copied, new_copied, total_images,
    )
    return {
        "data_yaml": str(yaml_path),
        "total_images": total_images,
        "old_images": old_copied,
        "new_images": new_copied,
        "mixing_ratio": multiplier,
    }


def _copy_image_label_pairs(
    output_dir: Path,
    train: list[Path],
    val: list[Path],
) -> int:
    """Copy image/label pairs into the mixed dataset; return images copied.

    Files already present at the destination are skipped (defensive dedup),
    and a missing label file simply yields an image-only copy.
    """
    copied = 0
    for split_name, images in (("train", train), ("val", val)):
        for img_path in images:
            label_path = _image_to_label_path(img_path)
            dst_img = output_dir / "images" / split_name / img_path.name
            dst_label = output_dir / "labels" / split_name / label_path.name
            if img_path.exists() and not dst_img.exists():
                shutil.copy2(img_path, dst_img)
                copied += 1
            if label_path.exists() and not dst_label.exists():
                shutil.copy2(label_path, dst_label)
    return copied
def _collect_images(images_dir: Path) -> list[Path]:
"""Collect all image files from a directory."""
if not images_dir.exists():
return []
return sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
def _image_to_label_path(image_path: Path) -> Path:
"""Convert image path to corresponding label path."""
labels_dir = image_path.parent.parent.parent / "labels" / image_path.parent.name
return labels_dir / image_path.with_suffix(".txt").name
def _find_pool_images(
base_dataset_path: Path,
pool_doc_ids: set[str],
) -> list[Path]:
"""Find images in base dataset that belong to pool documents."""
images: list[Path] = []
for split in ("train", "val", "test"):
split_dir = base_dataset_path / "images" / split
if not split_dir.exists():
continue
for img_path in split_dir.iterdir():
if not img_path.is_file():
continue
# Image filenames are like: {doc_id}_page{N}.png
doc_id = img_path.stem.rsplit("_page", 1)[0]
if doc_id in pool_doc_ids:
images.append(img_path)
return images

View File

@@ -0,0 +1,198 @@
"""
Gating Validation Service
Validates fine-tuned models before deployment using quality gates:
- Gate 1: Regression validation (mAP drop on original test set)
- Gate 2: New sample validation (detection rate on pool documents)
"""
import logging
from pathlib import Path
from uuid import UUID
from backend.data.admin_models import GatingResult
from backend.data.database import get_session_context
from backend.data.repositories.model_version_repository import ModelVersionRepository
logger = logging.getLogger(__name__)
# Gate 1 thresholds (mAP drop)
GATE1_PASS_THRESHOLD = 0.01  # < 1% drop = PASS
GATE1_REVIEW_THRESHOLD = 0.03  # 1-3% drop = REVIEW, > 3% = REJECT

# Gate 2 thresholds (detection rate)
GATE2_PASS_THRESHOLD = 0.80  # > 80% detection rate = PASS


def classify_gate1(mAP_drop: float) -> str:
    """Classify Gate 1 (regression) from the absolute mAP drop.

    Args:
        mAP_drop: Absolute mAP drop (positive means degradation).

    Returns:
        "pass" (drop below 1%), "review" (1-3%), or "reject" (3%+).
    """
    if mAP_drop >= GATE1_REVIEW_THRESHOLD:
        return "reject"
    return "review" if mAP_drop >= GATE1_PASS_THRESHOLD else "pass"


def classify_gate2(detection_rate: float) -> str:
    """Classify Gate 2 from the new-sample detection rate (0-1).

    Returns:
        "pass" at or above the threshold, otherwise "review".
    """
    return "pass" if detection_rate >= GATE2_PASS_THRESHOLD else "review"


def compute_overall_status(gate1_status: str, gate2_status: str) -> str:
    """Combine gate statuses into the overall gating verdict.

    Rules: any "reject" wins, then any "review", otherwise "pass".
    """
    gates = (gate1_status, gate2_status)
    if "reject" in gates:
        return "reject"
    if "review" in gates:
        return "review"
    return "pass"
def run_gating_validation(
    model_version_id: str | UUID,
    new_model_path: str,
    base_model_version_id: str | UUID | None,
    data_yaml: str,
    task_id: str | UUID | None = None,
) -> GatingResult:
    """Run deployment gating validation for a fine-tuned model.

    Gate 1 (regression): validates the new weights on the val split of
    ``data_yaml`` and compares mAP50 against the base model's stored
    metrics_mAP. Gate 2 (new samples): currently uses the new model's
    stored metrics_mAP as a proxy for detection rate. The combined
    verdict is persisted as a GatingResult row and mirrored onto the
    model version's gating_status.

    NOTE(review): ``data_yaml`` is whatever dataset the caller passes
    (typically the mixed fine-tune set), not the original base test set —
    confirm that is the intended regression baseline for Gate 1.

    Args:
        model_version_id: The fine-tuned model version to validate.
        new_model_path: Path to the new model weights.
        base_model_version_id: The base model version to compare against.
        data_yaml: Path to the dataset YAML (for validation).
        task_id: Optional training task ID.

    Returns:
        GatingResult with gate statuses.
    """
    model_versions = ModelVersionRepository()

    # Gate 1: Regression validation.
    # Starts "pending"; exceptions downgrade to "review" instead of blocking.
    gate1_status = "pending"
    gate1_original_mAP = None
    gate1_new_mAP = None
    gate1_mAP_drop = None
    try:
        if base_model_version_id:
            base_model = model_versions.get(str(base_model_version_id))
            if base_model and base_model.metrics_mAP is not None:
                gate1_original_mAP = base_model.metrics_mAP
                # Run validation with new model (lazy import: heavy YOLO stack)
                from shared.training import YOLOTrainer, TrainingConfig

                val_config = TrainingConfig(
                    model_path=new_model_path,
                    data_yaml=data_yaml,
                )
                trainer = YOLOTrainer(config=val_config)
                val_metrics = trainer.validate(split="val")
                gate1_new_mAP = val_metrics.get("mAP50")

                if gate1_new_mAP is not None:
                    # Positive drop = degradation relative to the base model.
                    gate1_mAP_drop = gate1_original_mAP - gate1_new_mAP
                    gate1_status = classify_gate1(gate1_mAP_drop)
                    logger.info(
                        "Gate 1: original_mAP=%.4f, new_mAP=%.4f, drop=%.4f -> %s",
                        gate1_original_mAP, gate1_new_mAP, gate1_mAP_drop, gate1_status,
                    )
                else:
                    gate1_status = "review"
                    logger.warning("Gate 1: Could not compute new mAP, marking as review")
            else:
                # No stored metrics on the base model: nothing to regress against.
                gate1_status = "pass"
                logger.info("Gate 1: No base model metrics available, skipping (pass)")
        else:
            gate1_status = "pass"
            logger.info("Gate 1: No base model specified, skipping (pass)")
    except Exception as e:
        gate1_status = "review"
        logger.error("Gate 1 failed: %s", e)

    # Gate 2: New sample validation
    # For now, we use the training metrics as a proxy
    # Full implementation would run inference on pool documents
    # (gate2_total_samples / gate2_detected_samples stay None until then).
    gate2_status = "pass"  # best-effort default when no metrics are available
    gate2_detection_rate = None
    gate2_total_samples = None
    gate2_detected_samples = None
    try:
        new_model = model_versions.get(str(model_version_id))
        if new_model and new_model.metrics_mAP is not None:
            # Use mAP as proxy for detection rate on new samples
            gate2_detection_rate = new_model.metrics_mAP
        if gate2_detection_rate is not None:
            gate2_status = classify_gate2(gate2_detection_rate)
            logger.info(
                "Gate 2: detection_rate=%.4f -> %s",
                gate2_detection_rate, gate2_status,
            )
    except Exception as e:
        gate2_status = "review"
        logger.error("Gate 2 failed: %s", e)

    # Compute overall status
    overall_status = compute_overall_status(gate1_status, gate2_status)
    logger.info("Gating overall: %s (gate1=%s, gate2=%s)", overall_status, gate1_status, gate2_status)

    # Save result (expunge so the row stays usable after the session closes)
    with get_session_context() as session:
        result = GatingResult(
            model_version_id=UUID(str(model_version_id)),
            task_id=UUID(str(task_id)) if task_id else None,
            gate1_status=gate1_status,
            gate1_original_mAP=gate1_original_mAP,
            gate1_new_mAP=gate1_new_mAP,
            gate1_mAP_drop=gate1_mAP_drop,
            gate2_status=gate2_status,
            gate2_detection_rate=gate2_detection_rate,
            gate2_total_samples=gate2_total_samples,
            gate2_detected_samples=gate2_detected_samples,
            overall_status=overall_status,
        )
        session.add(result)
        session.commit()
        session.refresh(result)
        session.expunge(result)

    # Update model version gating status
    _update_model_gating_status(str(model_version_id), overall_status)

    return result
def _update_model_gating_status(version_id: str, status: str) -> None:
    """Persist a gating outcome onto the ModelVersion row (no-op if absent)."""
    from backend.data.admin_models import ModelVersion

    with get_session_context() as session:
        model_version = session.get(ModelVersion, UUID(version_id))
        if model_version is None:
            return
        model_version.gating_status = status
        session.add(model_version)
        session.commit()

View File

@@ -3,6 +3,6 @@ fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
sqlmodel>=0.0.22
ultralytics>=8.1.0
ultralytics>=8.4.0
httpx>=0.25.0
openai>=1.0.0

View File

@@ -1,37 +1,20 @@
"""
BBox Scale Strategy Module.
BBox Expansion Module.
Provides field-specific bounding box expansion strategies for YOLO training data.
Expands bboxes using center-point scaling with directional compensation to capture
field labels that typically appear above or to the left of field values.
Two modes are supported:
- Auto-label: Field-specific scale strategies with directional compensation
- Manual-label: Minimal padding only to prevent edge clipping
Provides uniform bounding box expansion for YOLO training data.
Usage:
from shared.bbox import expand_bbox, ScaleStrategy, FIELD_SCALE_STRATEGIES
from shared.bbox import expand_bbox, UNIFORM_PAD
Available exports:
- ScaleStrategy: Dataclass for scale strategy configuration
- DEFAULT_STRATEGY: Default strategy for unknown fields (auto-label)
- MANUAL_LABEL_STRATEGY: Minimal padding strategy for manual labels
- FIELD_SCALE_STRATEGIES: dict[str, ScaleStrategy] - field-specific strategies
- expand_bbox: Function to expand bbox using field-specific strategy
- UNIFORM_PAD: Default uniform pixel padding (15px at 150 DPI)
- expand_bbox: Function to expand bbox with uniform padding
"""
from .scale_strategy import (
ScaleStrategy,
DEFAULT_STRATEGY,
MANUAL_LABEL_STRATEGY,
FIELD_SCALE_STRATEGIES,
)
from .scale_strategy import UNIFORM_PAD
from .expander import expand_bbox
__all__ = [
"ScaleStrategy",
"DEFAULT_STRATEGY",
"MANUAL_LABEL_STRATEGY",
"FIELD_SCALE_STRATEGIES",
"UNIFORM_PAD",
"expand_bbox",
]

View File

@@ -1,101 +1,35 @@
"""
BBox Expander Module.
Provides functions to expand bounding boxes using field-specific strategies.
Expansion is center-point based with directional compensation.
Two modes:
- Auto-label (default): Field-specific scale strategies
- Manual-label: Minimal padding only to prevent edge clipping
Expands bounding boxes by a uniform pixel padding on all sides,
clamped to image boundaries. No field-specific or directional logic.
"""
from .scale_strategy import (
ScaleStrategy,
DEFAULT_STRATEGY,
MANUAL_LABEL_STRATEGY,
FIELD_SCALE_STRATEGIES,
)
from .scale_strategy import UNIFORM_PAD
def expand_bbox(
bbox: tuple[float, float, float, float],
image_width: float,
image_height: float,
field_type: str,
strategies: dict[str, ScaleStrategy] | None = None,
manual_mode: bool = False,
pad: int = UNIFORM_PAD,
) -> tuple[int, int, int, int]:
"""
Expand bbox using field-specific scale strategy.
The expansion follows these steps:
1. Scale bbox around center point (scale_x, scale_y)
2. Apply directional compensation (extra_*_ratio)
3. Clamp expansion to max_pad limits
4. Clamp to image boundaries
"""Expand bbox by uniform pixel padding, clamped to image bounds.
Args:
bbox: (x0, y0, x1, y1) in pixels
image_width: Image width for boundary clamping
image_height: Image height for boundary clamping
field_type: Field class_name (e.g., "ocr_number")
strategies: Custom strategies dict, defaults to FIELD_SCALE_STRATEGIES
manual_mode: If True, use MANUAL_LABEL_STRATEGY (minimal padding only)
bbox: (x0, y0, x1, y1) in pixels.
image_width: Image width for boundary clamping.
image_height: Image height for boundary clamping.
pad: Uniform pixel padding on all sides (default: UNIFORM_PAD).
Returns:
Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds
Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds.
"""
x0, y0, x1, y1 = bbox
w = x1 - x0
h = y1 - y0
# Get strategy based on mode
if manual_mode:
strategy = MANUAL_LABEL_STRATEGY
elif strategies is None:
strategy = FIELD_SCALE_STRATEGIES.get(field_type, DEFAULT_STRATEGY)
else:
strategy = strategies.get(field_type, DEFAULT_STRATEGY)
# Step 1: Scale around center point
cx = (x0 + x1) / 2
cy = (y0 + y1) / 2
new_w = w * strategy.scale_x
new_h = h * strategy.scale_y
nx0 = cx - new_w / 2
nx1 = cx + new_w / 2
ny0 = cy - new_h / 2
ny1 = cy + new_h / 2
# Step 2: Apply directional compensation
nx0 -= w * strategy.extra_left_ratio
nx1 += w * strategy.extra_right_ratio
ny0 -= h * strategy.extra_top_ratio
ny1 += h * strategy.extra_bottom_ratio
# Step 3: Clamp expansion to max_pad limits (preserve asymmetry)
left_pad = min(x0 - nx0, strategy.max_pad_x)
right_pad = min(nx1 - x1, strategy.max_pad_x)
top_pad = min(y0 - ny0, strategy.max_pad_y)
bottom_pad = min(ny1 - y1, strategy.max_pad_y)
# Ensure pads are non-negative (in case of contraction)
left_pad = max(0, left_pad)
right_pad = max(0, right_pad)
top_pad = max(0, top_pad)
bottom_pad = max(0, bottom_pad)
nx0 = x0 - left_pad
nx1 = x1 + right_pad
ny0 = y0 - top_pad
ny1 = y1 + bottom_pad
# Step 4: Clamp to image boundaries
nx0 = max(0, int(nx0))
ny0 = max(0, int(ny0))
nx1 = min(int(image_width), int(nx1))
ny1 = min(int(image_height), int(ny1))
nx0 = max(0, int(x0 - pad))
ny0 = max(0, int(y0 - pad))
nx1 = min(int(image_width), int(x1 + pad))
ny1 = min(int(image_height), int(y1 + pad))
return (nx0, ny0, nx1, ny1)

View File

@@ -1,140 +1,12 @@
"""
Scale Strategy Configuration.
Defines field-specific bbox expansion strategies for YOLO training data.
Each strategy controls how bboxes are expanded around field values to
capture contextual information like labels.
Defines uniform bbox expansion padding for YOLO training data.
All fields use the same fixed-pixel padding -- no layout assumptions.
"""
from dataclasses import dataclass
from typing import Final
@dataclass(frozen=True)
class ScaleStrategy:
    """Immutable configuration for expanding a field bbox.

    A strategy first scales the box about its center (``scale_x`` /
    ``scale_y``), then pushes individual edges further outward by ratios
    of the original width/height (``extra_*_ratio``), and finally caps
    how far any edge may move via ``max_pad_x`` / ``max_pad_y`` (pixels).

    With the defaults, a box grows 15% in each dimension with no
    directional compensation, and no edge moves more than 50 px.
    """

    # Center scaling factors (1.0 = keep original size).
    scale_x: float = 1.15
    scale_y: float = 1.15
    # Per-edge expansion as a fraction of the original width/height,
    # e.g. a top ratio pulls the box upward to capture a label above.
    extra_top_ratio: float = 0.0
    extra_bottom_ratio: float = 0.0
    extra_left_ratio: float = 0.0
    extra_right_ratio: float = 0.0
    # Hard pixel limits on how far each edge may be pushed out.
    max_pad_x: int = 50
    max_pad_y: int = 50
# Fallback for field types without a dedicated entry (auto-label mode).
DEFAULT_STRATEGY: Final[ScaleStrategy] = ScaleStrategy()

# Human-drawn boxes are already placed deliberately: apply no scaling and
# no directional compensation, only a tiny uniform pad so that glyph edges
# are not clipped. The extra_*_ratio fields stay at their 0.0 defaults.
MANUAL_LABEL_STRATEGY: Final[ScaleStrategy] = ScaleStrategy(
    scale_x=1.0,
    scale_y=1.0,
    max_pad_x=10,
    max_pad_y=10,
)
# Field-specific strategies tuned to Swedish invoice field characteristics.
# On these layouts a field's label typically sits above or to the left of
# its value, so most entries add extra_top_ratio / extra_left_ratio.
FIELD_SCALE_STRATEGIES: Final[dict[str, ScaleStrategy]] = {
    # OCR number - label "OCR" or "Referens" typically above the value
    "ocr_number": ScaleStrategy(
        scale_x=1.15,
        scale_y=1.80,
        extra_top_ratio=0.60,
        max_pad_x=50,
        max_pad_y=140,
    ),
    # Bankgiro - prefix "Bankgiro:" or "BG:" typically to the left
    "bankgiro": ScaleStrategy(
        scale_x=1.45,
        scale_y=1.35,
        extra_left_ratio=0.80,
        max_pad_x=160,
        max_pad_y=90,
    ),
    # Plusgiro - prefix "Plusgiro:" or "PG:" typically to the left
    "plusgiro": ScaleStrategy(
        scale_x=1.45,
        scale_y=1.35,
        extra_left_ratio=0.80,
        max_pad_x=160,
        max_pad_y=90,
    ),
    # Invoice date - label "Fakturadatum" typically above
    "invoice_date": ScaleStrategy(
        scale_x=1.25,
        scale_y=1.55,
        extra_top_ratio=0.40,
        max_pad_x=80,
        max_pad_y=110,
    ),
    # Due date - label "Forfalldatum" typically above, sometimes to the left
    "invoice_due_date": ScaleStrategy(
        scale_x=1.30,
        scale_y=1.65,
        extra_top_ratio=0.45,
        extra_left_ratio=0.35,
        max_pad_x=100,
        max_pad_y=120,
    ),
    # Amount - currency suffix "SEK" or "kr" may appear to the right
    "amount": ScaleStrategy(
        scale_x=1.20,
        scale_y=1.35,
        extra_right_ratio=0.30,
        max_pad_x=70,
        max_pad_y=80,
    ),
    # Invoice number - label "Fakturanummer" typically above
    "invoice_number": ScaleStrategy(
        scale_x=1.20,
        scale_y=1.50,
        extra_top_ratio=0.40,
        max_pad_x=80,
        max_pad_y=100,
    ),
    # Supplier org number - label "Org.nr" typically above or to the left
    "supplier_org_number": ScaleStrategy(
        scale_x=1.25,
        scale_y=1.40,
        extra_top_ratio=0.30,
        extra_left_ratio=0.20,
        max_pad_x=90,
        max_pad_y=90,
    ),
    # Customer number - label "Kundnummer" typically above or to the left
    "customer_number": ScaleStrategy(
        scale_x=1.25,
        scale_y=1.45,
        extra_top_ratio=0.35,
        extra_left_ratio=0.25,
        max_pad_x=90,
        max_pad_y=100,
    ),
    # Payment line - machine-readable code, already self-contained, so
    # only minimal symmetric expansion is needed
    "payment_line": ScaleStrategy(
        scale_x=1.10,
        scale_y=1.20,
        max_pad_x=40,
        max_pad_y=40,
    ),
}
# Uniform fixed-pixel padding: 15 px at 150 DPI is roughly 2.5 mm of
# real-world margin -- a safe OCR buffer around the text that still
# avoids swallowing neighboring label text.
UNIFORM_PAD: Final[int] = 15

View File

@@ -17,7 +17,7 @@ class TrainingConfig:
"""Training configuration."""
# Model settings
model_path: str = "yolo11n.pt" # Base model or path to trained model
model_path: str = "yolo26s.pt" # Base model or path to trained model
data_yaml: str = "" # Path to data.yaml
# Training hyperparameters
@@ -39,6 +39,10 @@ class TrainingConfig:
resume: bool = False
resume_from: str | None = None # Path to checkpoint
# Fine-tuning specific
freeze: int = 0 # Number of backbone layers to freeze (0 = none)
cos_lr: bool = False # Use cosine learning rate scheduler
# Document-specific augmentation (optimized for invoices)
augmentation: dict[str, Any] = field(default_factory=lambda: {
"degrees": 5.0,
@@ -106,7 +110,7 @@ class YOLOTrainer:
# Check model path
model_path = Path(self.config.model_path)
if not model_path.suffix == ".pt":
# Could be a model name like "yolo11n.pt" which is downloaded
# Could be a model name like "yolo26s.pt" which is downloaded
if not model_path.name.startswith("yolo"):
return False, f"Invalid model: {self.config.model_path}"
elif not model_path.exists():
@@ -147,6 +151,10 @@ class YOLOTrainer:
self._log("INFO", f" Epochs: {self.config.epochs}")
self._log("INFO", f" Batch size: {self.config.batch_size}")
self._log("INFO", f" Image size: {self.config.image_size}")
if self.config.freeze > 0:
self._log("INFO", f" Freeze layers: {self.config.freeze}")
if self.config.cos_lr:
self._log("INFO", f" Cosine LR: enabled")
try:
# Load model
@@ -178,6 +186,12 @@ class YOLOTrainer:
"resume": self.config.resume and self.config.resume_from is not None,
}
# Add fine-tuning settings
if self.config.freeze > 0:
train_args["freeze"] = self.config.freeze
if self.config.cos_lr:
train_args["cos_lr"] = True
# Add augmentation settings
train_args.update(self.config.augmentation)

View File

@@ -1,4 +1,4 @@
-e ../shared
ultralytics>=8.1.0
ultralytics>=8.4.0
tqdm>=4.65.0
torch>=2.0.0

View File

@@ -34,7 +34,7 @@ def execute_training_task(db: TrainingTaskDB, task: dict) -> None:
result = run_training(
epochs=config.get("epochs", 100),
batch=config.get("batch_size", 16),
model=config.get("base_model", "yolo11n.pt"),
model=config.get("base_model", "yolo26s.pt"),
imgsz=config.get("imgsz", 1280),
name=config.get("name", f"training_{task_id[:8]}"),
)

View File

@@ -28,8 +28,8 @@ def main():
)
parser.add_argument(
'--model', '-m',
default='yolov8s.pt',
help='Base model (default: yolov8s.pt)'
default='yolo26s.pt',
help='Base model (default: yolo26s.pt)'
)
parser.add_argument(
'--epochs', '-e',

View File

@@ -100,12 +100,11 @@ class AnnotationGenerator:
x0, y0, x1, y1 = best_match.bbox
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
# Apply field-specific bbox expansion strategy
# Apply uniform bbox expansion
x0, y0, x1, y1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=image_width,
image_height=image_height,
field_type=class_name,
)
# Ensure minimum height
@@ -173,12 +172,11 @@ class AnnotationGenerator:
x0, y0, x1, y1 = payment_line_bbox
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
# Apply field-specific bbox expansion strategy for payment_line
# Apply uniform bbox expansion
x0, y0, x1, y1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=image_width,
image_height=image_height,
field_type="payment_line",
)
# Convert to YOLO format (normalized center + size)

View File

@@ -585,15 +585,11 @@ class DBYOLODataset:
x1_px = x1_pdf * scale
y1_px = y1_pdf * scale
# Get class name for field-specific expansion
class_name = CLASS_NAMES[ann.class_id]
# Apply field-specific bbox expansion
# Apply uniform bbox expansion
x0, y0, x1, y1 = expand_bbox(
bbox=(x0_px, y0_px, x1_px, y1_px),
image_width=img_width,
image_height=img_height,
field_type=class_name,
)
# Ensure minimum height