This commit is contained in:
Yaojia Wang
2026-02-11 23:40:38 +01:00
parent f1a7bfe6b7
commit ad5ed46b4c
117 changed files with 5741 additions and 7669 deletions

View File

@@ -289,6 +289,16 @@ class ModelVersion(SQLModel, table=True):
is_active: bool = Field(default=False, index=True)
# Only one version can be active at a time for inference
# Model lineage
model_type: str = Field(default="base", max_length=20, index=True)
# "base" = trained from pretrained YOLO, "finetune" = fine-tuned from base model
base_model_version_id: UUID | None = Field(default=None, index=True)
# Points to the base model this was fine-tuned from (None for base models)
base_training_dataset_id: UUID | None = Field(default=None, index=True)
# The dataset used for original base training (for data mixing old samples)
gating_status: str = Field(default="pending", max_length=20, index=True)
# Deployment gating: pending, pass, review, reject, skipped
# Training association
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
@@ -317,6 +327,64 @@ class ModelVersion(SQLModel, table=True):
updated_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Fine-Tune Pool
# =============================================================================
class FineTunePoolEntry(SQLModel, table=True):
    """Document in the fine-tune pool for incremental model improvement.

    Each row links one labeled admin document into the pool. Entries start
    unverified and must be human-verified (is_verified) before they count
    toward fine-tuning readiness.
    """
    __tablename__ = "finetune_pool_entries"
    # Surrogate primary key for the pool entry.
    entry_id: UUID = Field(default_factory=uuid4, primary_key=True)
    # Document being pooled; indexed for lookups by document.
    # NOTE(review): no unique constraint here -- duplicate prevention appears
    # to happen at the API layer; confirm that is intentional.
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    # Admin identity (token) that added the entry, if recorded.
    added_by: str | None = Field(default=None, max_length=255)
    reason: str | None = Field(default=None, max_length=255)
    # Reason: user_reported_failure, manual_addition
    # Verification workflow fields: set when a human reviewer approves the entry.
    is_verified: bool = Field(default=False, index=True)
    verified_at: datetime | None = Field(default=None)
    verified_by: str | None = Field(default=None, max_length=255)
    # Naive UTC timestamp, consistent with the other models in this file.
    created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Deployment Gating
# =============================================================================
class GatingResult(SQLModel, table=True):
    """Model deployment gating validation result.

    Stores the outcome of the two-gate validation for a trained model
    version: Gate 1 compares mAP on the original test set (regression
    check); Gate 2 measures the detection rate on fine-tune pool samples.
    """
    __tablename__ = "gating_results"
    result_id: UUID = Field(default_factory=uuid4, primary_key=True)
    # Model version being gated; indexed for lookups by model.
    model_version_id: UUID = Field(foreign_key="model_versions.version_id", index=True)
    # Training task that produced the model, when known.
    task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id")
    # Gate 1: Regression validation (original test set mAP)
    gate1_status: str = Field(default="pending", max_length=20)
    # pending, pass, review, reject
    gate1_original_mAP: float | None = Field(default=None)
    gate1_new_mAP: float | None = Field(default=None)
    gate1_mAP_drop: float | None = Field(default=None)
    # Gate 2: New sample validation (detection rate on pool docs)
    gate2_status: str = Field(default="pending", max_length=20)
    gate2_detection_rate: float | None = Field(default=None)
    gate2_total_samples: int | None = Field(default=None)
    gate2_detected_samples: int | None = Field(default=None)
    # Overall
    overall_status: str = Field(default="pending", max_length=20)
    # pending, pass, review, reject
    # Manual review trail (filled when a human resolves a "review" outcome).
    reviewer_notes: str | None = Field(default=None)
    reviewed_by: str | None = Field(default=None, max_length=255)
    reviewed_at: datetime | None = Field(default=None)
    created_at: datetime = Field(default_factory=datetime.utcnow)
# =============================================================================
# Annotation History (v2)
# =============================================================================

View File

@@ -13,6 +13,7 @@ from backend.data.repositories.training_task_repository import TrainingTaskRepos
from backend.data.repositories.dataset_repository import DatasetRepository
from backend.data.repositories.model_version_repository import ModelVersionRepository
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
from backend.data.repositories.finetune_pool_repository import FineTunePoolRepository
__all__ = [
"BaseRepository",
@@ -23,4 +24,5 @@ __all__ = [
"DatasetRepository",
"ModelVersionRepository",
"BatchUploadRepository",
"FineTunePoolRepository",
]

View File

@@ -0,0 +1,131 @@
"""
Fine-Tune Pool Repository
Manages the fine-tune pool: accumulated user-reported failure documents
for incremental model improvement.
"""
import logging
from datetime import datetime
from uuid import UUID
from sqlalchemy import func
from sqlmodel import select
from backend.data.database import get_session_context
from backend.data.admin_models import FineTunePoolEntry
from backend.data.repositories.base import BaseRepository
logger = logging.getLogger(__name__)
class FineTunePoolRepository(BaseRepository[FineTunePoolEntry]):
    """Repository for fine-tune pool management.

    Every method opens its own short-lived session via get_session_context()
    and expunges returned ORM instances so callers can safely use them after
    the session closes.
    """

    def add_document(
        self,
        document_id: str | UUID,
        added_by: str | None = None,
        reason: str | None = None,
    ) -> FineTunePoolEntry:
        """Add a document to the fine-tune pool.

        Args:
            document_id: Document to add (str or UUID accepted).
            added_by: Optional identity of whoever added the entry.
            reason: Optional reason tag (e.g. user_reported_failure).

        Returns:
            The persisted, detached pool entry.
        """
        with get_session_context() as session:
            entry = FineTunePoolEntry(
                document_id=UUID(str(document_id)),
                added_by=added_by,
                reason=reason,
            )
            session.add(entry)
            session.commit()
            # Refresh to pick up DB-populated state, then detach so the
            # entry remains usable after the session closes.
            session.refresh(entry)
            session.expunge(entry)
            return entry

    def get_entry(self, entry_id: str | UUID) -> FineTunePoolEntry | None:
        """Get a pool entry by ID, or None if it does not exist."""
        with get_session_context() as session:
            entry = session.get(FineTunePoolEntry, UUID(str(entry_id)))
            if entry:
                session.expunge(entry)
            return entry

    def get_by_document(self, document_id: str | UUID) -> FineTunePoolEntry | None:
        """Get the pool entry for a document, or None if it is not pooled."""
        with get_session_context() as session:
            result = session.exec(
                select(FineTunePoolEntry).where(
                    FineTunePoolEntry.document_id == UUID(str(document_id))
                )
            ).first()
            if result:
                session.expunge(result)
            return result

    def get_paginated(
        self,
        verified_only: bool = False,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[FineTunePoolEntry], int]:
        """List pool entries with pagination.

        Args:
            verified_only: When True, restrict to human-verified entries.
            limit: Page size.
            offset: Number of entries to skip.

        Returns:
            Tuple of (entries ordered newest-first, total matching count).
        """
        with get_session_context() as session:
            query = select(FineTunePoolEntry)
            count_query = select(func.count()).select_from(FineTunePoolEntry)
            if verified_only:
                # "== True" builds a SQLAlchemy column expression; it is not
                # a Python truthiness test -- do not "simplify" it.
                query = query.where(FineTunePoolEntry.is_verified == True)
                count_query = count_query.where(FineTunePoolEntry.is_verified == True)
            total = session.exec(count_query).one()
            entries = session.exec(
                query.order_by(FineTunePoolEntry.created_at.desc())
                .offset(offset)
                .limit(limit)
            ).all()
            # Detach every row before the session closes.
            for e in entries:
                session.expunge(e)
            return list(entries), total

    def get_pool_count(self, verified_only: bool = True) -> int:
        """Get count of entries in the pool (verified-only by default)."""
        with get_session_context() as session:
            query = select(func.count()).select_from(FineTunePoolEntry)
            if verified_only:
                query = query.where(FineTunePoolEntry.is_verified == True)
            return session.exec(query).one()

    def get_all_document_ids(self, verified_only: bool = True) -> list[UUID]:
        """Get all document IDs in the pool (verified-only by default)."""
        with get_session_context() as session:
            query = select(FineTunePoolEntry.document_id)
            if verified_only:
                query = query.where(FineTunePoolEntry.is_verified == True)
            results = session.exec(query).all()
            return list(results)

    def verify_entry(
        self,
        entry_id: str | UUID,
        verified_by: str | None = None,
    ) -> FineTunePoolEntry | None:
        """Mark a pool entry as verified.

        Args:
            entry_id: Entry to verify.
            verified_by: Optional identity of the reviewer.

        Returns:
            The updated, detached entry, or None if not found.
        """
        with get_session_context() as session:
            entry = session.get(FineTunePoolEntry, UUID(str(entry_id)))
            if not entry:
                return None
            entry.is_verified = True
            # Naive UTC timestamp, consistent with the model defaults.
            entry.verified_at = datetime.utcnow()
            entry.verified_by = verified_by
            session.add(entry)
            session.commit()
            session.refresh(entry)
            session.expunge(entry)
            return entry

    def remove_entry(self, entry_id: str | UUID) -> bool:
        """Remove an entry from the pool.

        Returns:
            True if an entry was deleted, False if it did not exist.
        """
        with get_session_context() as session:
            entry = session.get(FineTunePoolEntry, UUID(str(entry_id)))
            if not entry:
                return False
            session.delete(entry)
            session.commit()
            return True

View File

@@ -43,6 +43,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]):
training_config: dict[str, Any] | None = None,
file_size: int | None = None,
trained_at: datetime | None = None,
model_type: str = "base",
base_model_version_id: str | UUID | None = None,
base_training_dataset_id: str | UUID | None = None,
gating_status: str = "pending",
) -> ModelVersion:
"""Create a new model version."""
with get_session_context() as session:
@@ -60,6 +64,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]):
training_config=training_config,
file_size=file_size,
trained_at=trained_at,
model_type=model_type,
base_model_version_id=UUID(str(base_model_version_id)) if base_model_version_id else None,
base_training_dataset_id=UUID(str(base_training_dataset_id)) if base_training_dataset_id else None,
gating_status=gating_status,
)
session.add(model)
session.commit()

View File

@@ -40,6 +40,7 @@ from .normalizers import (
EnhancedAmountNormalizer,
EnhancedDateNormalizer,
)
from .value_selector import ValueSelector
@dataclass
@@ -169,13 +170,21 @@ class FieldExtractor:
overlap_ratio = overlap_area / token_area if token_area > 0 else 0
matching_tokens.append((token, overlap_ratio))
# Sort by overlap ratio and combine text
# Sort by overlap ratio
matching_tokens.sort(key=lambda x: -x[1])
raw_text = ' '.join(t[0].text for t in matching_tokens)
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Convert to OCRTokens for value selection, then filter
from shared.ocr.paddle_ocr import OCRToken
pdf_ocr_tokens = [
OCRToken(text=t[0].text, bbox=t[0].bbox, confidence=1.0)
for t in matching_tokens
]
value_tokens = ValueSelector.select_value_tokens(pdf_ocr_tokens, field_name)
raw_text = ' '.join(t.text for t in value_tokens)
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text
@@ -223,13 +232,14 @@ class FieldExtractor:
# Run OCR on region
ocr_tokens = self.ocr_engine.extract_from_image(region)
# Combine all OCR text
raw_text = ' '.join(t.text for t in ocr_tokens)
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
# Get field name
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
# Select value tokens (filter out label text)
value_tokens = ValueSelector.select_value_tokens(ocr_tokens, field_name)
raw_text = ' '.join(t.text for t in value_tokens)
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
# Normalize and validate
normalized_value, is_valid, validation_error = self._normalize_and_validate(
field_name, raw_text

View File

@@ -20,26 +20,98 @@ class AmountNormalizer(BaseNormalizer):
Handles various Swedish amount formats:
- With decimal: 1 234,56 kr
- With SEK suffix: 1234.56 SEK
- Payment line kronor/ore: 590 00 (space = decimal separator)
- Multiple amounts (returns the last one, usually the total)
"""
# Payment line kronor/ore pattern: "590 00" means 590.00 SEK
# Only matches when no comma/dot is present (pure digit-space-2digit format)
_KRONOR_ORE_PATTERN = re.compile(r'^(\d+)\s+(\d{2})$')

@property
def field_name(self) -> str:
    """Name of the invoice field this normalizer handles."""
    return "Amount"
@classmethod
def _try_kronor_ore(cls, text: str) -> NormalizationResult | None:
    """Attempt to interpret *text* as a payment-line kronor/ore amount.

    Swedish payment lines write the ore part after a space instead of a
    decimal separator: "590 00" means 590.00 SEK, "15658 00" means
    15658.00 SEK. The format only applies when the text contains neither
    a comma nor a dot; otherwise it is an ordinary amount with an
    explicit decimal separator.

    Returns:
        A successful NormalizationResult when the format matches and the
        parsed amount is positive, otherwise None.
    """
    # An explicit decimal separator rules out the payment-line format.
    if any(sep in text for sep in (',', '.')):
        return None
    m = cls._KRONOR_ORE_PATTERN.match(text.strip())
    if m is None:
        return None
    try:
        value = float(f"{m.group(1)}.{m.group(2)}")
    except ValueError:
        return None
    if value > 0:
        return NormalizationResult.success(f"{value:.2f}")
    return None
@staticmethod
def _parse_amount_str(match: str) -> float | None:
"""Convert matched amount string to float, detecting European vs Anglo format.
European: 2.254,50 -> 2254.50 (dot=thousand, comma=decimal)
Anglo: 1,234.56 -> 1234.56 (comma=thousand, dot=decimal)
Swedish: 1 234,56 -> 1234.56 (space=thousand, comma=decimal)
"""
has_comma = ',' in match
has_dot = '.' in match
if has_comma and has_dot:
if match.rfind(',') > match.rfind('.'):
# European: 2.254,50
cleaned = match.replace(" ", "").replace(".", "").replace(",", ".")
else:
# Anglo: 1,234.56
cleaned = match.replace(" ", "").replace(",", "")
elif has_comma:
cleaned = match.replace(" ", "").replace(",", ".")
else:
cleaned = match.replace(" ", "")
try:
return float(cleaned)
except ValueError:
return None
def normalize(self, text: str) -> NormalizationResult:
text = text.strip()
if not text:
return NormalizationResult.failure("Empty text")
# Early check: payment line kronor/ore format ("590 00" → 590.00)
kronor_ore_result = self._try_kronor_ore(text)
if kronor_ore_result is not None:
return kronor_ore_result
# Split by newlines and process line by line to get the last valid amount
lines = text.split("\n")
# Collect all valid amounts from all lines
all_amounts: list[float] = []
# Pattern for Swedish amount format (with decimals)
amount_pattern = r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?"
# Separate patterns for European and Anglo formats
# (?!\d) lookahead prevents partial matches (e.g. "1,23" in "1,234.56")
# European: dot=thousand, comma=decimal (2.254,50 or 1 234,56)
# Anglo: comma=thousand, dot=decimal (1,234.56 or 1234.56)
amount_pattern = (
r"(\d[\d\s.]*,\d{2})(?!\d)\s*(?:kr|SEK)?"
r"|"
r"(\d[\d\s,]*\.\d{2})(?!\d)\s*(?:kr|SEK)?"
)
for line in lines:
line = line.strip()
@@ -47,15 +119,13 @@ class AmountNormalizer(BaseNormalizer):
continue
# Find all amounts in this line
matches = re.findall(amount_pattern, line, re.IGNORECASE)
for match in matches:
amount_str = match.replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if amount > 0:
all_amounts.append(amount)
except ValueError:
for m in re.finditer(amount_pattern, line, re.IGNORECASE):
match = m.group(1) or m.group(2)
if not match:
continue
amount = self._parse_amount_str(match)
if amount is not None and amount > 0:
all_amounts.append(amount)
# Return the last amount found (usually the total)
if all_amounts:
@@ -122,31 +192,33 @@ class EnhancedAmountNormalizer(AmountNormalizer):
if not text:
return NormalizationResult.failure("Empty text")
# Early check: payment line kronor/ore format ("590 00" → 590.00)
kronor_ore_result = self._try_kronor_ore(text)
if kronor_ore_result is not None:
return kronor_ore_result
# Strategy 1: Apply OCR corrections first
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
# Strategy 2: Look for labeled amounts (highest priority)
# Use two capture groups: group(1) = European, group(2) = Anglo
labeled_patterns = [
# Swedish patterns
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", 1.0),
# Swedish patterns ((?!\d) prevents partial matches like "1,23" in "1,234.56")
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))", 1.0),
(
r"(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})",
r"(?:moms|vat)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))",
0.8,
), # Lower priority for VAT
# Generic pattern
(r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?", 0.7),
(r"(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))\s*(?:kr|sek|kronor)?", 0.7),
]
candidates: list[tuple[float, float, int]] = []
for pattern, priority in labeled_patterns:
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
amount_str = match.group(1).replace(" ", "").replace(",", ".")
try:
amount = float(amount_str)
if 0 < amount < 10_000_000: # Reasonable range
candidates.append((amount, priority, match.start()))
except ValueError:
continue
amount = self._parse_amount_str(match.group(1))
if amount is not None and 0 < amount < 10_000_000:
candidates.append((amount, priority, match.start()))
if candidates:
# Sort by priority (desc), then by position (later is usually total)

View File

@@ -301,6 +301,27 @@ class InferencePipeline:
all_extracted = []
all_ocr_text = [] # Collect OCR text for VAT extraction
# Check if PDF has readable text layer (avoids OCR for text PDFs)
from shared.pdf.extractor import PDFDocument
is_text_pdf = False
pdf_tokens_by_page: dict[int, list] = {}
try:
with PDFDocument(pdf_path) as pdf_doc:
is_text_pdf = pdf_doc.is_text_pdf()
if is_text_pdf:
for pg in range(pdf_doc.page_count):
pdf_tokens_by_page[pg] = list(
pdf_doc.extract_text_tokens(pg)
)
logger.info(
"Text PDF detected, extracted %d tokens from %d pages",
sum(len(t) for t in pdf_tokens_by_page.values()),
len(pdf_tokens_by_page),
)
except Exception as e:
logger.warning("PDF text detection failed, falling back to OCR: %s", e)
is_text_pdf = False
# Process each page
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
# Convert to numpy array
@@ -313,7 +334,17 @@ class InferencePipeline:
# Extract fields from detections
for detection in detections:
extracted = self.extractor.extract_from_detection(detection, image_array)
if is_text_pdf and page_no in pdf_tokens_by_page:
extracted = self.extractor.extract_from_detection_with_pdf(
detection,
pdf_tokens_by_page[page_no],
image.width,
image.height,
)
else:
extracted = self.extractor.extract_from_detection(
detection, image_array
)
all_extracted.append(extracted)
# Collect full-page OCR text for VAT extraction (only if business features enabled)

View File

@@ -0,0 +1,172 @@
"""
Value Selector Module.
Selects the most likely value token(s) from OCR output per field type,
filtering out label text before sending to normalizer.
Stateless and pure -- easy to test, no side effects.
"""
import re
from collections.abc import Callable
from typing import Final

from shared.ocr.paddle_ocr import OCRToken
# Swedish label keywords commonly found near field values
# (lower-cased; compared against tokens after stripping a trailing colon).
LABEL_KEYWORDS: Final[frozenset[str]] = frozenset({
    "fakturanummer", "fakturanr", "fakturadatum", "forfallodag", "forfalldatum",
    "bankgiro", "plusgiro", "bg", "pg", "ocr", "belopp", "summa",
    "total", "att", "betala", "kundnummer", "organisationsnummer",
    "org", "nr", "datum", "nummer", "ref", "referens",
    "momsreg", "vat", "moms", "sek", "kr",
    "org.nr", "kund", "faktura", "invoice",
})

# Patterns
# These regexes are applied with .search() to individual token text
# (see _select_by_pattern), so anchors are written explicitly where needed.
_DATE_PATTERN = re.compile(
    r"\d{4}[-./]\d{2}[-./]\d{2}"  # 2024-01-15, 2024.01.15
    r"|"
    r"\d{2}[-./]\d{2}[-./]\d{4}"  # 15/01/2024
    r"|"
    r"\d{8}"  # 20240115
)
_AMOUNT_PATTERN = re.compile(
    r"\d[\d\s.]*,\d{2}(?:\s*(?:kr|SEK))?$"  # European: 2.254,50 SEK
    r"|"
    r"\d[\d\s,]*\.\d{2}(?:\s*(?:kr|SEK))?$"  # Anglo: 1,234.56 SEK
)
_BANKGIRO_PATTERN = re.compile(
    r"^\d{3,4}-\d{4}$"  # 123-4567
    r"|"
    r"^\d{7,8}$"  # 1234567 or 12345678
)
_PLUSGIRO_PATTERN = re.compile(
    r"^\d+-\d$"  # 12345-6
    r"|"
    r"^\d{2,8}$"  # 1234567
)
_ORG_NUMBER_PATTERN = re.compile(
    r"\d{6}-?\d{4}"  # 556123-4567 or 5561234567
)
def _is_label(text: str) -> bool:
    """Return True when the token text is a known Swedish label keyword.

    The token is lower-cased, stripped of a trailing colon and surrounding
    whitespace before the membership test.
    """
    return text.lower().rstrip(":").strip() in LABEL_KEYWORDS
def _count_digits(text: str) -> int:
    """Return the number of digit characters in *text*."""
    return len([c for c in text if c.isdigit()])
class ValueSelector:
    """Selects value token(s) from OCR output, filtering label text.

    Pure static methods -- no state, no side effects. If no field-specific
    pattern matches, the full token list is returned unchanged, so selection
    can never make results worse than the previous join-everything behavior.
    """

    @staticmethod
    def select_value_tokens(
        tokens: list[OCRToken],
        field_name: str,
    ) -> list[OCRToken]:
        """Select the most likely value token(s) for a given field.

        Args:
            tokens: OCR tokens from the detected region.
            field_name: Normalized field name (e.g. "InvoiceDate", "Amount").

        Returns:
            Filtered list of value tokens. Falls back to all tokens
            if no field-specific pattern matches.
        """
        if not tokens:
            return []
        # Dispatch on field name; unknown fields keep every token.
        chooser = _FIELD_SELECTORS.get(field_name, _fallback_selector)
        chosen = chooser(tokens)
        # Safety net: never hand back an empty result for non-empty input.
        return chosen if chosen else list(tokens)

    @staticmethod
    def _select_date(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens containing a date-like string."""
        return _select_by_pattern(tokens, _DATE_PATTERN)

    @staticmethod
    def _select_amount(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens containing an amount-like string."""
        return _select_by_pattern(tokens, _AMOUNT_PATTERN)

    @staticmethod
    def _select_bankgiro(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens matching the Bankgiro number shape."""
        return _select_by_pattern(tokens, _BANKGIRO_PATTERN)

    @staticmethod
    def _select_plusgiro(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens matching the Plusgiro number shape."""
        return _select_by_pattern(tokens, _PLUSGIRO_PATTERN)

    @staticmethod
    def _select_org_number(tokens: list[OCRToken]) -> list[OCRToken]:
        """Keep tokens matching the organisation-number shape."""
        return _select_by_pattern(tokens, _ORG_NUMBER_PATTERN)

    @staticmethod
    def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]:
        """Select token with the longest digit sequence (min 5 digits)."""
        # Filter first, then pick the first token with the maximal digit
        # count: max() keeps the earliest on ties, matching a strict ">".
        eligible = [t for t in tokens if _count_digits(t.text) >= 5]
        if not eligible:
            return []
        return [max(eligible, key=lambda t: _count_digits(t.text))]

    @staticmethod
    def _select_by_label_exclusion(tokens: list[OCRToken]) -> list[OCRToken]:
        """Remove label keywords, keep remaining tokens."""
        kept: list[OCRToken] = []
        for token in tokens:
            if not _is_label(token.text):
                kept.append(token)
        return kept

    @staticmethod
    def _select_payment_line(tokens: list[OCRToken]) -> list[OCRToken]:
        """Payment line keeps all tokens (the full text is needed)."""
        return tokens[:]
def _select_by_pattern(
    tokens: list[OCRToken],
    pattern: re.Pattern[str],
) -> list[OCRToken]:
    """Return the tokens whose stripped text contains a match of *pattern*."""
    selected = []
    for token in tokens:
        if pattern.search(token.text.strip()) is not None:
            selected.append(token)
    return selected
def _fallback_selector(tokens: list[OCRToken]) -> list[OCRToken]:
    """Default selector: hand back a shallow copy of all tokens."""
    return tokens[:]
# Map field names to selector functions. Fields without an entry fall back
# to _fallback_selector (all tokens kept unchanged).
# Fix: the builtin `callable` is not a type; use Callable[..., ...] so type
# checkers can verify the selector signatures.
_FIELD_SELECTORS: Final[dict[str, Callable[[list[OCRToken]], list[OCRToken]]]] = {
    "InvoiceDate": ValueSelector._select_date,
    "InvoiceDueDate": ValueSelector._select_date,
    "Amount": ValueSelector._select_amount,
    "Bankgiro": ValueSelector._select_bankgiro,
    "Plusgiro": ValueSelector._select_plusgiro,
    "OCR": ValueSelector._select_ocr_number,
    "InvoiceNumber": ValueSelector._select_by_label_exclusion,
    "supplier_org_number": ValueSelector._select_org_number,
    "customer_number": ValueSelector._select_by_label_exclusion,
    "payment_line": ValueSelector._select_payment_line,
}

View File

@@ -12,6 +12,7 @@ from .documents import register_document_routes
from .export import register_export_routes
from .datasets import register_dataset_routes
from .models import register_model_routes
from .pool import register_pool_routes
def create_training_router() -> APIRouter:
@@ -23,6 +24,7 @@ def create_training_router() -> APIRouter:
register_export_routes(router)
register_dataset_routes(router)
register_model_routes(router)
register_pool_routes(router)
return router

View File

@@ -12,6 +12,7 @@ from backend.web.core.auth import (
AnnotationRepoDep,
ModelVersionRepoDep,
TrainingTaskRepoDep,
FineTunePoolRepoDep,
)
from backend.web.schemas.admin import (
DatasetCreateRequest,
@@ -233,6 +234,7 @@ def register_dataset_routes(router: APIRouter) -> None:
datasets_repo: DatasetRepoDep,
models: ModelVersionRepoDep,
tasks: TrainingTaskRepoDep,
pool: FineTunePoolRepoDep,
) -> TrainingTaskResponse:
"""Create a training task from a dataset.
@@ -261,13 +263,39 @@ def register_dataset_routes(router: APIRouter) -> None:
status_code=404,
detail=f"Base model version not found: {base_model_version_id}",
)
# Chain prevention: never fine-tune from a fine-tuned model
if getattr(base_model, "model_type", "base") == "finetune":
original_base_id = getattr(base_model, "base_model_version_id", None)
raise HTTPException(
status_code=400,
detail=(
f"Cannot chain fine-tunes. Model {base_model.version} is already "
f"a fine-tuned model. Select the original base model instead"
f"{f' (base_model_version_id: {original_base_id})' if original_base_id else ''}."
),
)
# Pool threshold: require minimum verified pool entries for fine-tuning
from backend.web.services.data_mixer import MIN_POOL_SIZE
verified_count = pool.get_pool_count(verified_only=True)
if verified_count < MIN_POOL_SIZE:
raise HTTPException(
status_code=400,
detail=(
f"Fine-tuning requires at least {MIN_POOL_SIZE} verified pool entries "
f"(currently {verified_count}). Add more documents to the fine-tune "
f"pool and verify them before starting fine-tuning."
),
)
# Store the resolved model path for the training worker
config_dict["base_model_path"] = base_model.model_path
config_dict["base_model_version"] = base_model.version
logger.info(
"Incremental training: using model %s (%s) as base",
"Fine-tuning: using base model %s (%s) with %d verified pool entries",
base_model.version,
base_model.model_path,
verified_count,
)
task_id = tasks.create(

View File

@@ -124,16 +124,11 @@ def register_export_routes(router: APIRouter) -> None:
x1 = ann.x_center * img_width + half_w
y1 = ann.y_center * img_height + half_h
# Use manual_mode for manual/imported annotations
manual_mode = ann.source in ("manual", "imported")
# Apply field-specific bbox expansion
# Apply uniform bbox expansion
ex0, ey0, ex1, ey1 = expand_bbox(
bbox=(x0, y0, x1, y1),
image_width=img_width,
image_height=img_height,
field_type=ann.class_name,
manual_mode=manual_mode,
)
# Convert back to normalized YOLO format

View File

@@ -88,6 +88,9 @@ def register_model_routes(router: APIRouter) -> None:
name=m.name,
status=m.status,
is_active=m.is_active,
model_type=getattr(m, "model_type", "base"),
base_model_version_id=str(m.base_model_version_id) if getattr(m, "base_model_version_id", None) else None,
gating_status=getattr(m, "gating_status", "pending"),
metrics_mAP=m.metrics_mAP,
document_count=m.document_count,
trained_at=m.trained_at,
@@ -121,6 +124,9 @@ def register_model_routes(router: APIRouter) -> None:
name=model.name,
status=model.status,
is_active=model.is_active,
model_type=getattr(model, "model_type", "base"),
base_model_version_id=str(model.base_model_version_id) if getattr(model, "base_model_version_id", None) else None,
gating_status=getattr(model, "gating_status", "pending"),
metrics_mAP=model.metrics_mAP,
document_count=model.document_count,
trained_at=model.trained_at,
@@ -153,6 +159,10 @@ def register_model_routes(router: APIRouter) -> None:
model_path=model.model_path,
status=model.status,
is_active=model.is_active,
model_type=getattr(model, "model_type", "base"),
base_model_version_id=str(model.base_model_version_id) if getattr(model, "base_model_version_id", None) else None,
base_training_dataset_id=str(model.base_training_dataset_id) if getattr(model, "base_training_dataset_id", None) else None,
gating_status=getattr(model, "gating_status", "pending"),
task_id=str(model.task_id) if model.task_id else None,
dataset_id=str(model.dataset_id) if model.dataset_id else None,
metrics_mAP=model.metrics_mAP,
@@ -209,6 +219,25 @@ def register_model_routes(router: APIRouter) -> None:
) -> ModelVersionResponse:
"""Activate a model version for inference."""
_validate_uuid(version_id, "version_id")
# Check gating status before activation (for fine-tuned models)
pre_check = models.get(version_id)
if not pre_check:
raise HTTPException(status_code=404, detail="Model version not found")
model_type = getattr(pre_check, "model_type", "base")
gating_status = getattr(pre_check, "gating_status", "skipped")
if model_type == "finetune":
if gating_status == "reject":
raise HTTPException(
status_code=400,
detail="Model failed deployment gating validation. Cannot activate a rejected model.",
)
if gating_status == "pending":
raise HTTPException(
status_code=400,
detail="Model gating validation not yet completed. Wait for validation to finish.",
)
model = models.activate(version_id)
if not model:
raise HTTPException(status_code=404, detail="Model version not found")
@@ -227,6 +256,8 @@ def register_model_routes(router: APIRouter) -> None:
message = "Model version activated for inference"
if model_reloaded:
message += " (model reloaded)"
if gating_status == "review":
message += " (WARNING: gating status is REVIEW - manual inspection recommended)"
return ModelVersionResponse(
version_id=str(model.version_id),

View File

@@ -0,0 +1,159 @@
"""Fine-Tune Pool Endpoints."""
import logging
from typing import Annotated
from fastapi import APIRouter, HTTPException, Query
from backend.web.core.auth import AdminTokenDep, FineTunePoolRepoDep, DocumentRepoDep
from backend.web.schemas.admin.pool import (
PoolAddRequest,
PoolEntryItem,
PoolEntryResponse,
PoolListResponse,
PoolStatsResponse,
)
from ._utils import _validate_uuid
logger = logging.getLogger(__name__)
def register_pool_routes(router: APIRouter) -> None:
    """Register fine-tune pool endpoints on the router.

    Endpoints (all require an admin token):
        POST   /pool                    add a labeled document to the pool
        GET    /pool                    list pool entries (paginated)
        GET    /pool/stats              pool statistics and readiness flag
        POST   /pool/{entry_id}/verify  mark an entry as human-verified
        DELETE /pool/{entry_id}         remove an entry from the pool
    """

    @router.post(
        "/pool",
        response_model=PoolEntryResponse,
        summary="Add document to fine-tune pool",
        description="Add a labeled document to the fine-tune pool for future fine-tuning.",
    )
    async def add_to_pool(
        request: PoolAddRequest,
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
        docs: DocumentRepoDep,
    ) -> PoolEntryResponse:
        """Add a document to the fine-tune pool.

        Raises:
            HTTPException: 404 if the document does not exist,
                409 if it is already in the pool.
        """
        _validate_uuid(request.document_id, "document_id")
        # Verify document exists
        doc = docs.get(request.document_id)
        if not doc:
            raise HTTPException(status_code=404, detail="Document not found")
        # Reject duplicates: a document may appear in the pool at most once.
        existing = pool.get_by_document(request.document_id)
        if existing:
            raise HTTPException(
                status_code=409,
                detail=f"Document already in fine-tune pool (entry_id: {existing.entry_id})",
            )
        entry = pool.add_document(
            document_id=request.document_id,
            added_by=admin_token,
            reason=request.reason,
        )
        return PoolEntryResponse(
            entry_id=str(entry.entry_id),
            message="Document added to fine-tune pool",
        )

    @router.get(
        "/pool",
        response_model=PoolListResponse,
        summary="List fine-tune pool entries",
    )
    async def list_pool_entries(
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
        verified_only: Annotated[bool, Query(description="Filter to verified only")] = False,
        limit: Annotated[int, Query(ge=1, le=100)] = 20,
        offset: Annotated[int, Query(ge=0)] = 0,
    ) -> PoolListResponse:
        """List entries in the fine-tune pool, newest first."""
        entries, total = pool.get_paginated(
            verified_only=verified_only,
            limit=limit,
            offset=offset,
        )
        return PoolListResponse(
            total=total,
            limit=limit,
            offset=offset,
            entries=[
                PoolEntryItem(
                    entry_id=str(e.entry_id),
                    document_id=str(e.document_id),
                    added_by=e.added_by,
                    reason=e.reason,
                    is_verified=e.is_verified,
                    verified_at=e.verified_at,
                    verified_by=e.verified_by,
                    created_at=e.created_at,
                )
                for e in entries
            ],
        )

    @router.get(
        "/pool/stats",
        response_model=PoolStatsResponse,
        summary="Get fine-tune pool statistics",
    )
    async def get_pool_stats(
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
    ) -> PoolStatsResponse:
        """Get statistics about the fine-tune pool."""
        # Use the same readiness threshold the fine-tune task creation
        # endpoint enforces, instead of a duplicated magic number (50).
        from backend.web.services.data_mixer import MIN_POOL_SIZE

        total = pool.get_pool_count(verified_only=False)
        verified = pool.get_pool_count(verified_only=True)
        return PoolStatsResponse(
            total_entries=total,
            verified_entries=verified,
            unverified_entries=total - verified,
            is_ready=verified >= MIN_POOL_SIZE,
        )

    @router.post(
        "/pool/{entry_id}/verify",
        response_model=PoolEntryResponse,
        summary="Verify a pool entry",
        description="Mark a pool entry as verified (human-reviewed).",
    )
    async def verify_pool_entry(
        entry_id: str,
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
    ) -> PoolEntryResponse:
        """Mark a pool entry as verified.

        Raises:
            HTTPException: 404 if the entry does not exist.
        """
        _validate_uuid(entry_id, "entry_id")
        entry = pool.verify_entry(entry_id, verified_by=admin_token)
        if not entry:
            raise HTTPException(status_code=404, detail="Pool entry not found")
        return PoolEntryResponse(
            entry_id=str(entry.entry_id),
            message="Pool entry verified",
        )

    @router.delete(
        "/pool/{entry_id}",
        summary="Remove from fine-tune pool",
    )
    async def remove_from_pool(
        entry_id: str,
        admin_token: AdminTokenDep,
        pool: FineTunePoolRepoDep,
    ) -> dict:
        """Remove a document from the fine-tune pool.

        Raises:
            HTTPException: 404 if the entry does not exist.
        """
        _validate_uuid(entry_id, "entry_id")
        success = pool.remove_entry(entry_id)
        if not success:
            raise HTTPException(status_code=404, detail="Pool entry not found")
        return {"message": "Entry removed from fine-tune pool"}

View File

@@ -17,6 +17,7 @@ from backend.data.repositories import (
TrainingTaskRepository,
ModelVersionRepository,
BatchUploadRepository,
FineTunePoolRepository,
)
@@ -95,6 +96,12 @@ def get_batch_upload_repository() -> BatchUploadRepository:
return BatchUploadRepository()
@lru_cache(maxsize=1)
def get_finetune_pool_repository() -> FineTunePoolRepository:
    """Return the process-wide FineTunePoolRepository singleton.

    lru_cache(maxsize=1) memoizes the first construction, so every caller
    (including FastAPI dependency resolution) receives the same instance.
    """
    repository = FineTunePoolRepository()
    return repository
def reset_all_repositories() -> None:
"""Reset all repository instances (for testing)."""
get_token_repository.cache_clear()
@@ -104,6 +111,7 @@ def reset_all_repositories() -> None:
get_training_task_repository.cache_clear()
get_model_version_repository.cache_clear()
get_batch_upload_repository.cache_clear()
get_finetune_pool_repository.cache_clear()
# Repository dependency type aliases
@@ -113,3 +121,4 @@ DatasetRepoDep = Annotated[DatasetRepository, Depends(get_dataset_repository)]
TrainingTaskRepoDep = Annotated[TrainingTaskRepository, Depends(get_training_task_repository)]
ModelVersionRepoDep = Annotated[ModelVersionRepository, Depends(get_model_version_repository)]
BatchUploadRepoDep = Annotated[BatchUploadRepository, Depends(get_batch_upload_repository)]
FineTunePoolRepoDep = Annotated[FineTunePoolRepository, Depends(get_finetune_pool_repository)]

View File

@@ -16,6 +16,7 @@ from backend.data.repositories import (
ModelVersionRepository,
DocumentRepository,
AnnotationRepository,
FineTunePoolRepository,
)
from backend.web.core.task_interface import TaskRunner, TaskStatus
from backend.web.services.storage_helpers import get_storage_helper
@@ -47,6 +48,7 @@ class TrainingScheduler(TaskRunner):
self._model_versions = ModelVersionRepository()
self._documents = DocumentRepository()
self._annotations = AnnotationRepository()
self._pool = FineTunePoolRepository()
@property
def name(self) -> str:
@@ -168,7 +170,7 @@ class TrainingScheduler(TaskRunner):
try:
# Get training configuration
model_name = config.get("model_name", "yolo11n.pt")
model_name = config.get("model_name", "yolo26s.pt")
base_model_path = config.get("base_model_path") # For incremental training
epochs = config.get("epochs", 100)
batch_size = config.get("batch_size", 16)
@@ -182,14 +184,19 @@ class TrainingScheduler(TaskRunner):
augmentation_multiplier = config.get("augmentation_multiplier", 2)
# Determine which model to use as base
if base_model_path:
# Incremental training: use existing trained model
is_finetune = bool(base_model_path)
if is_finetune:
# Fine-tuning: use existing trained model as base
if not Path(base_model_path).exists():
raise ValueError(f"Base model not found: {base_model_path}")
effective_model = base_model_path
# Override parameters for fine-tuning (best practices)
epochs = config.get("epochs", 10)
learning_rate = config.get("learning_rate", 0.001)
self._training_tasks.add_log(
task_id, "INFO",
f"Incremental training from: {base_model_path}",
f"Fine-tuning from: {base_model_path} "
f"(epochs={epochs}, freeze=10, cos_lr=true)",
)
else:
# Train from pretrained model
@@ -229,10 +236,16 @@ class TrainingScheduler(TaskRunner):
f"(total: {aug_result['total_images']})",
)
# Build mixed dataset for fine-tuning (pool samples + old data)
if is_finetune and dataset_id:
data_yaml, dataset_path = self._build_mixed_finetune_dataset(
task_id, dataset_path, data_yaml,
)
# Run YOLO training
result = self._run_yolo_training(
task_id=task_id,
model_name=effective_model, # Use base model or pretrained model
model_name=effective_model,
data_yaml=data_yaml,
epochs=epochs,
batch_size=batch_size,
@@ -240,6 +253,8 @@ class TrainingScheduler(TaskRunner):
learning_rate=learning_rate,
device=device,
project_name=project_name,
freeze=10 if is_finetune else 0,
cos_lr=is_finetune,
)
# Update task with results
@@ -261,13 +276,23 @@ class TrainingScheduler(TaskRunner):
)
# Auto-create model version for the completed training
self._create_model_version_from_training(
model_version = self._create_model_version_from_training(
task_id=task_id,
config=config,
dataset_id=dataset_id,
result=result,
)
# Auto-run gating validation for fine-tuned models
if is_finetune and model_version:
self._run_gating_after_finetune(
task_id=task_id,
model_version=model_version,
config=config,
data_yaml=data_yaml,
result=result,
)
except Exception as e:
logger.error(f"Training task {task_id} failed: {e}")
self._training_tasks.add_log(task_id, "ERROR", f"Training failed: {e}")
@@ -286,13 +311,16 @@ class TrainingScheduler(TaskRunner):
config: dict[str, Any],
dataset_id: str | None,
result: dict[str, Any],
) -> None:
"""Create a model version entry from completed training."""
) -> Any | None:
"""Create a model version entry from completed training.
Returns the created model version, or None on failure.
"""
try:
model_path = result.get("model_path")
if not model_path:
logger.warning(f"No model path in training result for task {task_id}")
return
return None
# Get task info for name
task = self._training_tasks.get(task_id)
@@ -322,6 +350,12 @@ class TrainingScheduler(TaskRunner):
if dataset:
document_count = dataset.total_documents
# Determine model lineage
is_finetune = bool(config.get("base_model_path"))
model_type = "finetune" if is_finetune else "base"
base_model_version_id = config.get("base_model_version_id") if is_finetune else None
gating_status = "pending" if is_finetune else "skipped"
# Create model version
model_version = self._model_versions.create(
version=version,
@@ -337,6 +371,10 @@ class TrainingScheduler(TaskRunner):
training_config=config,
file_size=file_size,
trained_at=datetime.utcnow(),
model_type=model_type,
base_model_version_id=base_model_version_id,
base_training_dataset_id=dataset_id if not is_finetune else None,
gating_status=gating_status,
)
logger.info(
@@ -349,12 +387,105 @@ class TrainingScheduler(TaskRunner):
f"Model version {version} created (mAP: {mAP_display})",
)
return model_version
except Exception as e:
logger.error(f"Failed to create model version for task {task_id}: {e}")
self._training_tasks.add_log(
task_id, "WARNING",
f"Failed to auto-create model version: {e}",
)
return None
def _build_mixed_finetune_dataset(
    self,
    task_id: str,
    base_dataset_path: Path,
    original_data_yaml: str,
) -> tuple[str, Path]:
    """Build a mixed dataset for fine-tuning.

    Combines verified pool samples with randomly sampled old training data
    via backend.web.services.data_mixer.build_mixed_dataset.

    Args:
        task_id: Training task UUID (used for logging and output naming).
        base_dataset_path: Directory of the exported training dataset.
        original_data_yaml: data.yaml path of the unmixed dataset.

    Returns:
        Tuple of (data_yaml path, dataset_path) for the mixed dataset.
        Falls back to the original pair if mixing fails or the pool has
        no verified entries.
    """
    try:
        # Imported lazily inside the method.
        # NOTE(review): presumably to keep scheduler import time low or to
        # avoid a circular import — confirm.
        from backend.web.services.data_mixer import build_mixed_dataset

        # Only human-verified pool documents participate in fine-tuning.
        pool_doc_ids = self._pool.get_all_document_ids(verified_only=True)
        if not pool_doc_ids:
            self._training_tasks.add_log(
                task_id, "INFO",
                "No verified pool entries found, using original dataset",
            )
            return original_data_yaml, base_dataset_path

        # Mixed output lives next to the base dataset, keyed by task prefix.
        mixed_output = base_dataset_path.parent / f"mixed_{task_id[:8]}"
        mix_result = build_mixed_dataset(
            pool_document_ids=pool_doc_ids,
            base_dataset_path=base_dataset_path,
            output_dir=mixed_output,
        )
        self._training_tasks.add_log(
            task_id, "INFO",
            f"Data mixing: {mix_result['new_images']} new + "
            f"{mix_result['old_images']} old = {mix_result['total_images']} total "
            f"(ratio: {mix_result['mixing_ratio']}x)",
        )
        return mix_result["data_yaml"], mixed_output
    except Exception as e:
        # Mixing is best-effort: any failure falls back to the original
        # dataset so the training run itself is not aborted.
        logger.error(f"Data mixing failed for task {task_id}: {e}")
        self._training_tasks.add_log(
            task_id, "WARNING",
            f"Data mixing failed: {e}. Using original dataset.",
        )
        return original_data_yaml, base_dataset_path
def _run_gating_after_finetune(
    self,
    task_id: str,
    model_version: Any,
    config: dict[str, Any],
    data_yaml: str,
    result: dict[str, Any],
) -> None:
    """Run gating validation after a fine-tune training completes.

    Calls the gating validator with the freshly trained weights and the
    base-model reference from the training config, and logs the outcome
    to the task's log stream. Failures are logged but never re-raised:
    the new model version then simply keeps its "pending" gating status.

    Args:
        task_id: Training task UUID.
        model_version: The newly created model version record.
        config: Training configuration (provides base_model_version_id).
        data_yaml: Dataset YAML used for validation.
        result: Training result dict (provides model_path).
    """
    try:
        # Imported lazily inside the method (same pattern as data mixing).
        from backend.web.services.gating_validator import run_gating_validation

        model_path = result.get("model_path")
        base_model_version_id = config.get("base_model_version_id")
        version_id = str(model_version.version_id)
        self._training_tasks.add_log(
            task_id, "INFO", "Running deployment gating validation...",
        )
        gating_result = run_gating_validation(
            model_version_id=version_id,
            new_model_path=model_path,
            base_model_version_id=base_model_version_id,
            data_yaml=data_yaml,
            task_id=task_id,
        )
        self._training_tasks.add_log(
            task_id, "INFO",
            f"Gating result: {gating_result.overall_status} "
            f"(gate1={gating_result.gate1_status}, gate2={gating_result.gate2_status})",
        )
    except Exception as e:
        # Best-effort: gating problems must not fail the training task.
        logger.error(f"Gating validation failed for task {task_id}: {e}")
        self._training_tasks.add_log(
            task_id, "WARNING",
            f"Gating validation failed: {e}. Model remains in 'pending' state.",
        )
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
"""Export training data for a task."""
@@ -456,6 +587,8 @@ names: {list(FIELD_CLASSES.values())}
learning_rate: float,
device: str,
project_name: str,
freeze: int = 0,
cos_lr: bool = False,
) -> dict[str, Any]:
"""Run YOLO training using shared trainer."""
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
@@ -479,6 +612,8 @@ names: {list(FIELD_CLASSES.values())}
project="runs/train",
name=f"{project_name}/task_{task_id[:8]}",
workers=0,
freeze=freeze,
cos_lr=cos_lr,
)
# Run training using shared trainer

View File

@@ -11,6 +11,7 @@ from .annotations import * # noqa: F401, F403
from .training import * # noqa: F401, F403
from .datasets import * # noqa: F401, F403
from .models import * # noqa: F401, F403
from .pool import * # noqa: F401, F403
from .dashboard import * # noqa: F401, F403
# Resolve forward references for DocumentDetailResponse

View File

@@ -40,6 +40,9 @@ class ModelVersionItem(BaseModel):
name: str = Field(..., description="Model name")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
model_type: str = Field(default="base", description="Model type (base or finetune)")
base_model_version_id: str | None = Field(None, description="Base model version UUID (for fine-tuned models)")
gating_status: str = Field(default="pending", description="Deployment gating status")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
document_count: int = Field(..., description="Documents used in training")
trained_at: datetime | None = Field(None, description="Training completion time")
@@ -66,6 +69,10 @@ class ModelVersionDetailResponse(BaseModel):
model_path: str = Field(..., description="Path to model file")
status: str = Field(..., description="Status (active, inactive, archived)")
is_active: bool = Field(..., description="Is currently active for inference")
model_type: str = Field(default="base", description="Model type (base or finetune)")
base_model_version_id: str | None = Field(None, description="Base model version UUID")
base_training_dataset_id: str | None = Field(None, description="Base training dataset UUID")
gating_status: str = Field(default="pending", description="Deployment gating status")
task_id: str | None = Field(None, description="Training task UUID")
dataset_id: str | None = Field(None, description="Dataset UUID")
metrics_mAP: float | None = Field(None, description="Mean Average Precision")

View File

@@ -0,0 +1,72 @@
"""Admin Fine-Tune Pool Schemas."""
from datetime import datetime
from pydantic import BaseModel, Field
class PoolAddRequest(BaseModel):
    """Request body for enrolling a document in the fine-tune pool."""

    # UUID string of the admin document to enroll.
    document_id: str = Field(..., description="Document UUID to add")
    # Provenance tag; known values are listed in the description.
    reason: str = Field(
        default="user_reported_failure",
        description="Reason: user_reported_failure, manual_addition",
    )
class PoolEntryItem(BaseModel):
    """One fine-tune pool entry as returned by the admin list endpoint."""

    entry_id: str = Field(..., description="Entry UUID")
    document_id: str = Field(..., description="Document UUID")
    # Audit fields: who enrolled the document and why.
    added_by: str | None = Field(None, description="Who added this entry")
    reason: str | None = Field(None, description="Reason for adding")
    # Verification workflow: only verified entries count toward the pool's
    # fine-tune readiness (see the pool stats endpoint).
    is_verified: bool = Field(..., description="Whether entry has been verified")
    verified_at: datetime | None = Field(None, description="Verification timestamp")
    verified_by: str | None = Field(None, description="Who verified")
    created_at: datetime = Field(..., description="Creation timestamp")
class PoolListResponse(BaseModel):
    """Paginated pool entry list."""

    # Pagination envelope echoes the request's limit/offset alongside the
    # total row count so clients can page through the pool.
    total: int = Field(..., ge=0, description="Total entries")
    limit: int = Field(..., ge=1, description="Page size")
    offset: int = Field(..., ge=0, description="Current offset")
    entries: list[PoolEntryItem] = Field(default_factory=list, description="Pool entries")
class PoolStatsResponse(BaseModel):
    """Pool statistics."""

    total_entries: int = Field(..., ge=0, description="Total pool entries")
    verified_entries: int = Field(..., ge=0, description="Verified entries")
    unverified_entries: int = Field(..., ge=0, description="Unverified entries")
    # NOTE(review): is_ready is computed by the stats route (verified >= 50);
    # keep min_required's default in sync with that threshold.
    is_ready: bool = Field(..., description="Whether pool has >= 50 verified entries for fine-tuning")
    min_required: int = Field(default=50, description="Minimum verified entries required")
class PoolEntryResponse(BaseModel):
    """Acknowledgement for a single-entry pool operation (e.g. verify)."""

    entry_id: str = Field(..., description="Entry UUID")
    message: str = Field(..., description="Status message")
class GatingResultItem(BaseModel):
    """Gating validation result for a fine-tuned model version."""

    result_id: str = Field(..., description="Result UUID")
    model_version_id: str = Field(..., description="Model version UUID")
    # Gate 1: regression check — compares the new model's mAP against the
    # base model's recorded mAP.
    gate1_status: str = Field(..., description="Gate 1 status")
    gate1_original_mAP: float | None = Field(None, description="Original model mAP")
    gate1_new_mAP: float | None = Field(None, description="New model mAP")
    gate1_mAP_drop: float | None = Field(None, description="mAP drop percentage")
    # Gate 2: detection-rate check on the newly pooled samples.
    gate2_status: str = Field(..., description="Gate 2 status")
    gate2_detection_rate: float | None = Field(None, description="Detection rate on new samples")
    gate2_total_samples: int | None = Field(None, description="Total new samples tested")
    gate2_detected_samples: int | None = Field(None, description="Samples correctly detected")
    # Combined decision (worst gate wins) plus optional human annotations.
    overall_status: str = Field(..., description="Overall gating status")
    reviewer_notes: str | None = Field(None, description="Reviewer notes")
    created_at: datetime = Field(..., description="Creation timestamp")

View File

@@ -12,7 +12,7 @@ from .enums import TrainingStatus, TrainingType
class TrainingConfig(BaseModel):
"""Training configuration."""
model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
model_name: str = Field(default="yolo26s.pt", description="Base model name (used if no base_model_version_id)")
base_model_version_id: str | None = Field(
default=None,
description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",

View File

@@ -0,0 +1,199 @@
"""
Data Mixing Service
Mixes fine-tune pool samples with randomly sampled old training data
following best practices for incremental YOLO fine-tuning.
"""
import logging
import random
import shutil
from pathlib import Path
from typing import Any
from uuid import UUID
logger = logging.getLogger(__name__)
# Mixing ratios: (max_new_samples, old_data_multiplier)
# Fewer new samples require higher old data ratio to prevent catastrophic forgetting
MIXING_RATIOS: list[tuple[int, int]] = [
(10, 50), # <= 10 new samples: 50x old
(50, 20), # <= 50 new samples: 20x old
(200, 10), # <= 200 new samples: 10x old
(500, 5), # <= 500 new samples: 5x old
]
# Default multiplier for 500+ samples
DEFAULT_MULTIPLIER = 5
# Maximum old samples to include (cap for performance)
MAX_OLD_SAMPLES = 3000
# Minimum pool size for fine-tuning
MIN_POOL_SIZE = 50
def get_mixing_ratio(new_sample_count: int) -> int:
    """Return the old-data multiplier for a given pool size.

    Scans MIXING_RATIOS (ascending thresholds) and picks the first bracket
    the count fits into; pools larger than every threshold fall back to
    DEFAULT_MULTIPLIER.

    Args:
        new_sample_count: Number of new samples in the fine-tune pool.

    Returns:
        Multiplier for old data sampling.
    """
    return next(
        (mult for threshold, mult in MIXING_RATIOS if new_sample_count <= threshold),
        DEFAULT_MULTIPLIER,
    )
def build_mixed_dataset(
    pool_document_ids: list[UUID],
    base_dataset_path: Path,
    output_dir: Path,
    seed: int = 42,
) -> dict[str, Any]:
    """Build a mixed dataset for fine-tuning.

    Combines ALL fine-tune pool samples with randomly sampled old data
    from the base training dataset. Pool documents are excluded from the
    old-data sample so no image can appear in both roles.

    Args:
        pool_document_ids: Document IDs from the fine-tune pool.
        base_dataset_path: Path to the base training dataset directory.
        output_dir: Output directory for the mixed dataset.
        seed: Random seed for reproducible sampling.

    Returns:
        Dictionary with dataset info (data_yaml path, counts).
    """
    new_count = len(pool_document_ids)
    multiplier = get_mixing_ratio(new_count)
    old_target = min(new_count * multiplier, MAX_OLD_SAMPLES)

    logger.info(
        "Building mixed dataset: %d new samples, %dx multiplier, "
        "targeting %d old samples",
        new_count, multiplier, old_target,
    )

    # Create output directory structure
    output_dir.mkdir(parents=True, exist_ok=True)
    for split in ("train", "val"):
        (output_dir / "images" / split).mkdir(parents=True, exist_ok=True)
        (output_dir / "labels" / split).mkdir(parents=True, exist_ok=True)

    # Identify pool-owned images FIRST so they can be excluded from the
    # old-data candidates below.
    pool_id_strs = {str(doc_id) for doc_id in pool_document_ids}
    new_images = _find_pool_images(base_dataset_path, pool_id_strs)

    # Collect old training images from the base dataset, excluding pool
    # documents.
    # BUG FIX: previously a pool image could also be sampled as "old" data
    # and copied into one split (e.g. train) while its "new" copy landed in
    # the other split (val) — identical images leaking across train/val
    # inflates validation metrics.
    old_train_images = _collect_images(base_dataset_path / "images" / "train")
    old_val_images = _collect_images(base_dataset_path / "images" / "val")
    all_old_images = [
        img for img in old_train_images + old_val_images
        if img.stem.rsplit("_page", 1)[0] not in pool_id_strs
    ]

    # Randomly sample old data (reproducible via seed)
    rng = random.Random(seed)
    if len(all_old_images) > old_target:
        sampled_old = rng.sample(all_old_images, old_target)
    else:
        sampled_old = all_old_images

    # Split old samples: 80% train, 20% val
    rng.shuffle(sampled_old)
    old_train_count = int(len(sampled_old) * 0.8)
    old_train = sampled_old[:old_train_count]
    old_val = sampled_old[old_train_count:]

    # Copy old samples to mixed dataset
    old_copied = 0
    for split_name, images in [("train", old_train), ("val", old_val)]:
        for img_path in images:
            label_path = _image_to_label_path(img_path)
            dst_img = output_dir / "images" / split_name / img_path.name
            dst_label = output_dir / "labels" / split_name / label_path.name
            if img_path.exists():
                shutil.copy2(img_path, dst_img)
                old_copied += 1
            if label_path.exists():
                shutil.copy2(label_path, dst_label)

    # Copy new pool samples: 80% train, 20% val
    rng.shuffle(new_images)
    new_train_count = int(len(new_images) * 0.8)
    new_train = new_images[:new_train_count]
    new_val = new_images[new_train_count:]

    new_copied = 0
    for split_name, images in [("train", new_train), ("val", new_val)]:
        for img_path in images:
            label_path = _image_to_label_path(img_path)
            dst_img = output_dir / "images" / split_name / img_path.name
            dst_label = output_dir / "labels" / split_name / label_path.name
            # The skip-existing guard is kept as a safety net against
            # filename collisions within a split.
            if img_path.exists() and not dst_img.exists():
                shutil.copy2(img_path, dst_img)
                new_copied += 1
            if label_path.exists() and not dst_label.exists():
                shutil.copy2(label_path, dst_label)

    # Generate data.yaml (class list comes from the shared field registry)
    from shared.fields import FIELD_CLASSES

    yaml_path = output_dir / "data.yaml"
    yaml_content = (
        f"path: {output_dir.absolute()}\n"
        f"train: images/train\n"
        f"val: images/val\n"
        f"\n"
        f"nc: {len(FIELD_CLASSES)}\n"
        f"names: {list(FIELD_CLASSES.values())}\n"
    )
    yaml_path.write_text(yaml_content)

    total_images = old_copied + new_copied
    logger.info(
        "Mixed dataset built: %d old + %d new = %d total images",
        old_copied, new_copied, total_images,
    )
    return {
        "data_yaml": str(yaml_path),
        "total_images": total_images,
        "old_images": old_copied,
        "new_images": new_copied,
        "mixing_ratio": multiplier,
    }
def _collect_images(images_dir: Path) -> list[Path]:
    """Return PNG then JPG files in *images_dir*, each group sorted by name.

    Yields an empty list when the directory does not exist.
    """
    if not images_dir.exists():
        return []
    found: list[Path] = []
    for pattern in ("*.png", "*.jpg"):
        found.extend(sorted(images_dir.glob(pattern)))
    return found
def _image_to_label_path(image_path: Path) -> Path:
    """Map an image file to its YOLO label file.

    Mirrors the dataset layout: images/<split>/x.png -> labels/<split>/x.txt.
    """
    split = image_path.parent.name
    dataset_root = image_path.parents[2]
    return dataset_root / "labels" / split / (image_path.stem + ".txt")
def _find_pool_images(
    base_dataset_path: Path,
    pool_doc_ids: set[str],
) -> list[Path]:
    """Locate base-dataset images owned by pool documents.

    Ownership is derived from the filename convention
    "{doc_id}_page{N}.{ext}": the prefix before the last "_page" is the
    document ID. All three splits (train/val/test) are scanned.
    """
    matches: list[Path] = []
    for split in ("train", "val", "test"):
        split_dir = base_dataset_path / "images" / split
        if not split_dir.exists():
            continue
        for candidate in split_dir.iterdir():
            if not candidate.is_file():
                continue
            owner_id = candidate.stem.rsplit("_page", 1)[0]
            if owner_id in pool_doc_ids:
                matches.append(candidate)
    return matches

View File

@@ -0,0 +1,198 @@
"""
Gating Validation Service
Validates fine-tuned models before deployment using quality gates:
- Gate 1: Regression validation (mAP drop on original test set)
- Gate 2: New sample validation (detection rate on pool documents)
"""
import logging
from pathlib import Path
from uuid import UUID
from backend.data.admin_models import GatingResult
from backend.data.database import get_session_context
from backend.data.repositories.model_version_repository import ModelVersionRepository
logger = logging.getLogger(__name__)
# Gate 1 thresholds (mAP drop)
GATE1_PASS_THRESHOLD = 0.01 # < 1% drop = PASS
GATE1_REVIEW_THRESHOLD = 0.03 # 1-3% drop = REVIEW, > 3% = REJECT
# Gate 2 thresholds (detection rate)
GATE2_PASS_THRESHOLD = 0.80 # > 80% detection rate = PASS
def classify_gate1(mAP_drop: float) -> str:
    """Classify Gate 1 (regression) by absolute mAP drop.

    Below GATE1_PASS_THRESHOLD -> "pass"; below GATE1_REVIEW_THRESHOLD ->
    "review"; anything worse -> "reject". A negative drop (the new model
    improved) always passes.

    Args:
        mAP_drop: Absolute mAP drop (positive means degradation).

    Returns:
        "pass", "review", or "reject"
    """
    if mAP_drop >= GATE1_REVIEW_THRESHOLD:
        return "reject"
    if mAP_drop >= GATE1_PASS_THRESHOLD:
        return "review"
    return "pass"
def classify_gate2(detection_rate: float) -> str:
    """Classify Gate 2 by detection rate on new samples.

    At or above GATE2_PASS_THRESHOLD the gate passes; anything lower is
    routed to human review (this gate never auto-rejects).

    Args:
        detection_rate: Fraction of new samples correctly detected (0-1).

    Returns:
        "pass" or "review"
    """
    return "pass" if detection_rate >= GATE2_PASS_THRESHOLD else "review"
def compute_overall_status(gate1_status: str, gate2_status: str) -> str:
    """Combine per-gate outcomes into one deployment decision.

    Severity order is reject > review > pass: the worst individual gate
    outcome wins.
    """
    statuses = {gate1_status, gate2_status}
    if "reject" in statuses:
        return "reject"
    if "review" in statuses:
        return "review"
    return "pass"
def run_gating_validation(
    model_version_id: str | UUID,
    new_model_path: str,
    base_model_version_id: str | UUID | None,
    data_yaml: str,
    task_id: str | UUID | None = None,
) -> GatingResult:
    """Run deployment gating validation for a fine-tuned model.

    Gate 1 re-validates the new weights on the given dataset's val split
    and compares mAP50 against the base model's stored metrics. Gate 2
    currently uses the new model's stored mAP as a proxy detection rate.
    The combined outcome is persisted as a GatingResult row and mirrored
    onto the model version's gating_status field.

    Args:
        model_version_id: The fine-tuned model version to validate.
        new_model_path: Path to the new model weights.
        base_model_version_id: The base model version to compare against.
        data_yaml: Path to the dataset YAML (for validation).
        task_id: Optional training task ID.

    Returns:
        GatingResult with gate statuses.
    """
    model_versions = ModelVersionRepository()

    # Gate 1: Regression validation — did fine-tuning degrade the model?
    gate1_status = "pending"
    gate1_original_mAP = None
    gate1_new_mAP = None
    gate1_mAP_drop = None
    try:
        if base_model_version_id:
            base_model = model_versions.get(str(base_model_version_id))
            if base_model and base_model.metrics_mAP is not None:
                gate1_original_mAP = base_model.metrics_mAP
                # Run validation with new model.
                # NOTE(review): this validates on `data_yaml`'s val split —
                # confirm that matches the base model's original test set
                # rather than the mixed fine-tune dataset.
                from shared.training import YOLOTrainer, TrainingConfig
                val_config = TrainingConfig(
                    model_path=new_model_path,
                    data_yaml=data_yaml,
                )
                trainer = YOLOTrainer(config=val_config)
                val_metrics = trainer.validate(split="val")
                gate1_new_mAP = val_metrics.get("mAP50")
                if gate1_new_mAP is not None:
                    # Positive drop means the new model is worse.
                    gate1_mAP_drop = gate1_original_mAP - gate1_new_mAP
                    gate1_status = classify_gate1(gate1_mAP_drop)
                    logger.info(
                        "Gate 1: original_mAP=%.4f, new_mAP=%.4f, drop=%.4f -> %s",
                        gate1_original_mAP, gate1_new_mAP, gate1_mAP_drop, gate1_status,
                    )
                else:
                    gate1_status = "review"
                    logger.warning("Gate 1: Could not compute new mAP, marking as review")
            else:
                # Base model unknown or has no recorded metrics: nothing to
                # regress against, so the gate is waved through.
                gate1_status = "pass"
                logger.info("Gate 1: No base model metrics available, skipping (pass)")
        else:
            gate1_status = "pass"
            logger.info("Gate 1: No base model specified, skipping (pass)")
    except Exception as e:
        # Validation errors demote to human review rather than hard reject.
        gate1_status = "review"
        logger.error("Gate 1 failed: %s", e)

    # Gate 2: New sample validation
    # For now, we use the training metrics as a proxy
    # Full implementation would run inference on pool documents
    gate2_status = "pass"
    gate2_detection_rate = None
    gate2_total_samples = None
    gate2_detected_samples = None
    try:
        new_model = model_versions.get(str(model_version_id))
        if new_model and new_model.metrics_mAP is not None:
            # Use mAP as proxy for detection rate on new samples
            gate2_detection_rate = new_model.metrics_mAP
        if gate2_detection_rate is not None:
            gate2_status = classify_gate2(gate2_detection_rate)
            logger.info(
                "Gate 2: detection_rate=%.4f -> %s",
                gate2_detection_rate, gate2_status,
            )
    except Exception as e:
        gate2_status = "review"
        logger.error("Gate 2 failed: %s", e)

    # Compute overall status: the worst of the two gates wins.
    overall_status = compute_overall_status(gate1_status, gate2_status)
    logger.info("Gating overall: %s (gate1=%s, gate2=%s)", overall_status, gate1_status, gate2_status)

    # Save result
    with get_session_context() as session:
        result = GatingResult(
            model_version_id=UUID(str(model_version_id)),
            task_id=UUID(str(task_id)) if task_id else None,
            gate1_status=gate1_status,
            gate1_original_mAP=gate1_original_mAP,
            gate1_new_mAP=gate1_new_mAP,
            gate1_mAP_drop=gate1_mAP_drop,
            gate2_status=gate2_status,
            gate2_detection_rate=gate2_detection_rate,
            gate2_total_samples=gate2_total_samples,
            gate2_detected_samples=gate2_detected_samples,
            overall_status=overall_status,
        )
        session.add(result)
        session.commit()
        session.refresh(result)
        # Detach so the row remains usable after the session closes.
        session.expunge(result)

    # Update model version gating status
    _update_model_gating_status(str(model_version_id), overall_status)
    return result
def _update_model_gating_status(version_id: str, status: str) -> None:
    """Persist the gating outcome onto the ModelVersion row.

    Silently does nothing when the version id does not resolve to a row.
    """
    from backend.data.admin_models import ModelVersion

    with get_session_context() as session:
        model = session.get(ModelVersion, UUID(version_id))
        if model is None:
            return
        model.gating_status = status
        session.add(model)
        session.commit()

View File

@@ -3,6 +3,6 @@ fastapi>=0.104.0
uvicorn[standard]>=0.24.0
python-multipart>=0.0.6
sqlmodel>=0.0.22
ultralytics>=8.1.0
ultralytics>=8.4.0
httpx>=0.25.0
openai>=1.0.0