WIP
This commit is contained in:
@@ -289,6 +289,16 @@ class ModelVersion(SQLModel, table=True):
|
||||
is_active: bool = Field(default=False, index=True)
|
||||
# Only one version can be active at a time for inference
|
||||
|
||||
# Model lineage
|
||||
model_type: str = Field(default="base", max_length=20, index=True)
|
||||
# "base" = trained from pretrained YOLO, "finetune" = fine-tuned from base model
|
||||
base_model_version_id: UUID | None = Field(default=None, index=True)
|
||||
# Points to the base model this was fine-tuned from (None for base models)
|
||||
base_training_dataset_id: UUID | None = Field(default=None, index=True)
|
||||
# The dataset used for original base training (for data mixing old samples)
|
||||
gating_status: str = Field(default="pending", max_length=20, index=True)
|
||||
# Deployment gating: pending, pass, review, reject, skipped
|
||||
|
||||
# Training association
|
||||
task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id", index=True)
|
||||
dataset_id: UUID | None = Field(default=None, foreign_key="training_datasets.dataset_id", index=True)
|
||||
@@ -317,6 +327,64 @@ class ModelVersion(SQLModel, table=True):
|
||||
updated_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fine-Tune Pool
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class FineTunePoolEntry(SQLModel, table=True):
    """Document in the fine-tune pool for incremental model improvement."""

    __tablename__ = "finetune_pool_entries"

    entry_id: UUID = Field(default_factory=uuid4, primary_key=True)
    # Document this entry tracks; indexed for per-document pool lookups.
    document_id: UUID = Field(foreign_key="admin_documents.document_id", index=True)
    # Free-form identifier of whoever added the entry.
    added_by: str | None = Field(default=None, max_length=255)
    # Reason: user_reported_failure, manual_addition
    reason: str | None = Field(default=None, max_length=255)
    # Verification workflow: elsewhere only verified entries are counted
    # toward the fine-tuning pool threshold.
    is_verified: bool = Field(default=False, index=True)
    verified_at: datetime | None = Field(default=None)
    verified_by: str | None = Field(default=None, max_length=255)
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Deployment Gating
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class GatingResult(SQLModel, table=True):
    """Model deployment gating validation result."""

    __tablename__ = "gating_results"

    result_id: UUID = Field(default_factory=uuid4, primary_key=True)
    # Model version this validation run belongs to.
    model_version_id: UUID = Field(foreign_key="model_versions.version_id", index=True)
    # Optional link back to the training task that produced the model.
    task_id: UUID | None = Field(default=None, foreign_key="training_tasks.task_id")

    # Gate 1: Regression validation (original test set mAP)
    gate1_status: str = Field(default="pending", max_length=20)
    # pending, pass, review, reject
    gate1_original_mAP: float | None = Field(default=None)
    gate1_new_mAP: float | None = Field(default=None)
    gate1_mAP_drop: float | None = Field(default=None)

    # Gate 2: New sample validation (detection rate on pool docs)
    gate2_status: str = Field(default="pending", max_length=20)
    gate2_detection_rate: float | None = Field(default=None)
    gate2_total_samples: int | None = Field(default=None)
    gate2_detected_samples: int | None = Field(default=None)

    # Overall verdict combining both gates.
    overall_status: str = Field(default="pending", max_length=20)
    # pending, pass, review, reject
    # Human review trail for borderline ("review") outcomes.
    reviewer_notes: str | None = Field(default=None)
    reviewed_by: str | None = Field(default=None, max_length=255)
    reviewed_at: datetime | None = Field(default=None)

    created_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Annotation History (v2)
|
||||
# =============================================================================
|
||||
|
||||
@@ -13,6 +13,7 @@ from backend.data.repositories.training_task_repository import TrainingTaskRepos
|
||||
from backend.data.repositories.dataset_repository import DatasetRepository
|
||||
from backend.data.repositories.model_version_repository import ModelVersionRepository
|
||||
from backend.data.repositories.batch_upload_repository import BatchUploadRepository
|
||||
from backend.data.repositories.finetune_pool_repository import FineTunePoolRepository
|
||||
|
||||
__all__ = [
|
||||
"BaseRepository",
|
||||
@@ -23,4 +24,5 @@ __all__ = [
|
||||
"DatasetRepository",
|
||||
"ModelVersionRepository",
|
||||
"BatchUploadRepository",
|
||||
"FineTunePoolRepository",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Fine-Tune Pool Repository
|
||||
|
||||
Manages the fine-tune pool: accumulated user-reported failure documents
|
||||
for incremental model improvement.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func
|
||||
from sqlmodel import select
|
||||
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.admin_models import FineTunePoolEntry
|
||||
from backend.data.repositories.base import BaseRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FineTunePoolRepository(BaseRepository[FineTunePoolEntry]):
    """Repository for fine-tune pool management.

    Every method opens a short-lived session via ``get_session_context`` and
    expunges returned objects so callers can keep using them after the
    session closes.
    """

    def add_document(
        self,
        document_id: str | UUID,
        added_by: str | None = None,
        reason: str | None = None,
    ) -> FineTunePoolEntry:
        """Add a document to the fine-tune pool.

        Args:
            document_id: Admin document to add (str or UUID accepted).
            added_by: Optional identifier of the user adding the document.
            reason: Optional reason string (e.g. user_reported_failure,
                manual_addition).

        Returns:
            The persisted entry, detached from the session.
        """
        with get_session_context() as session:
            entry = FineTunePoolEntry(
                document_id=UUID(str(document_id)),
                added_by=added_by,
                reason=reason,
            )
            session.add(entry)
            session.commit()
            # Refresh to pick up DB-generated values, then detach so the
            # object stays usable once the session is gone.
            session.refresh(entry)
            session.expunge(entry)
            return entry

    def get_entry(self, entry_id: str | UUID) -> FineTunePoolEntry | None:
        """Get a pool entry by ID, or None if it does not exist."""
        with get_session_context() as session:
            entry = session.get(FineTunePoolEntry, UUID(str(entry_id)))
            if entry:
                session.expunge(entry)
            return entry

    def get_by_document(self, document_id: str | UUID) -> FineTunePoolEntry | None:
        """Get the pool entry for a document, or None if not in the pool."""
        with get_session_context() as session:
            result = session.exec(
                select(FineTunePoolEntry).where(
                    FineTunePoolEntry.document_id == UUID(str(document_id))
                )
            ).first()
            if result:
                session.expunge(result)
            return result

    def get_paginated(
        self,
        verified_only: bool = False,
        limit: int = 20,
        offset: int = 0,
    ) -> tuple[list[FineTunePoolEntry], int]:
        """List pool entries with pagination.

        Args:
            verified_only: When True, restrict to verified entries.
            limit: Page size.
            offset: Number of entries to skip.

        Returns:
            Tuple of (entries ordered newest-first, total matching count).
        """
        with get_session_context() as session:
            query = select(FineTunePoolEntry)
            count_query = select(func.count()).select_from(FineTunePoolEntry)
            if verified_only:
                # `== True` (not `is True`) so SQLAlchemy builds a SQL
                # comparison expression instead of a Python boolean.
                query = query.where(FineTunePoolEntry.is_verified == True)  # noqa: E712
                count_query = count_query.where(FineTunePoolEntry.is_verified == True)  # noqa: E712
            total = session.exec(count_query).one()
            entries = session.exec(
                query.order_by(FineTunePoolEntry.created_at.desc())
                .offset(offset)
                .limit(limit)
            ).all()
            for e in entries:
                session.expunge(e)
            return list(entries), total

    def get_pool_count(self, verified_only: bool = True) -> int:
        """Get count of entries in the pool (verified-only by default)."""
        with get_session_context() as session:
            query = select(func.count()).select_from(FineTunePoolEntry)
            if verified_only:
                query = query.where(FineTunePoolEntry.is_verified == True)  # noqa: E712
            return session.exec(query).one()

    def get_all_document_ids(self, verified_only: bool = True) -> list[UUID]:
        """Get all document IDs currently in the pool."""
        with get_session_context() as session:
            query = select(FineTunePoolEntry.document_id)
            if verified_only:
                query = query.where(FineTunePoolEntry.is_verified == True)  # noqa: E712
            results = session.exec(query).all()
            return list(results)

    def verify_entry(
        self,
        entry_id: str | UUID,
        verified_by: str | None = None,
    ) -> FineTunePoolEntry | None:
        """Mark a pool entry as verified.

        Args:
            entry_id: Entry to verify.
            verified_by: Optional identifier of the verifying user.

        Returns:
            The updated, detached entry, or None if the entry does not exist.
        """
        with get_session_context() as session:
            entry = session.get(FineTunePoolEntry, UUID(str(entry_id)))
            if not entry:
                return None
            entry.is_verified = True
            entry.verified_at = datetime.utcnow()
            entry.verified_by = verified_by
            session.add(entry)
            session.commit()
            session.refresh(entry)
            session.expunge(entry)
            return entry

    def remove_entry(self, entry_id: str | UUID) -> bool:
        """Remove an entry from the pool.

        Returns:
            True if the entry existed and was deleted, False otherwise.
        """
        with get_session_context() as session:
            entry = session.get(FineTunePoolEntry, UUID(str(entry_id)))
            if not entry:
                return False
            session.delete(entry)
            session.commit()
            return True
|
||||
@@ -43,6 +43,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]):
|
||||
training_config: dict[str, Any] | None = None,
|
||||
file_size: int | None = None,
|
||||
trained_at: datetime | None = None,
|
||||
model_type: str = "base",
|
||||
base_model_version_id: str | UUID | None = None,
|
||||
base_training_dataset_id: str | UUID | None = None,
|
||||
gating_status: str = "pending",
|
||||
) -> ModelVersion:
|
||||
"""Create a new model version."""
|
||||
with get_session_context() as session:
|
||||
@@ -60,6 +64,10 @@ class ModelVersionRepository(BaseRepository[ModelVersion]):
|
||||
training_config=training_config,
|
||||
file_size=file_size,
|
||||
trained_at=trained_at,
|
||||
model_type=model_type,
|
||||
base_model_version_id=UUID(str(base_model_version_id)) if base_model_version_id else None,
|
||||
base_training_dataset_id=UUID(str(base_training_dataset_id)) if base_training_dataset_id else None,
|
||||
gating_status=gating_status,
|
||||
)
|
||||
session.add(model)
|
||||
session.commit()
|
||||
|
||||
@@ -40,6 +40,7 @@ from .normalizers import (
|
||||
EnhancedAmountNormalizer,
|
||||
EnhancedDateNormalizer,
|
||||
)
|
||||
from .value_selector import ValueSelector
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -169,13 +170,21 @@ class FieldExtractor:
|
||||
overlap_ratio = overlap_area / token_area if token_area > 0 else 0
|
||||
matching_tokens.append((token, overlap_ratio))
|
||||
|
||||
# Sort by overlap ratio and combine text
|
||||
# Sort by overlap ratio
|
||||
matching_tokens.sort(key=lambda x: -x[1])
|
||||
raw_text = ' '.join(t[0].text for t in matching_tokens)
|
||||
|
||||
# Get field name
|
||||
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
|
||||
|
||||
# Convert to OCRTokens for value selection, then filter
|
||||
from shared.ocr.paddle_ocr import OCRToken
|
||||
pdf_ocr_tokens = [
|
||||
OCRToken(text=t[0].text, bbox=t[0].bbox, confidence=1.0)
|
||||
for t in matching_tokens
|
||||
]
|
||||
value_tokens = ValueSelector.select_value_tokens(pdf_ocr_tokens, field_name)
|
||||
raw_text = ' '.join(t.text for t in value_tokens)
|
||||
|
||||
# Normalize and validate
|
||||
normalized_value, is_valid, validation_error = self._normalize_and_validate(
|
||||
field_name, raw_text
|
||||
@@ -223,13 +232,14 @@ class FieldExtractor:
|
||||
# Run OCR on region
|
||||
ocr_tokens = self.ocr_engine.extract_from_image(region)
|
||||
|
||||
# Combine all OCR text
|
||||
raw_text = ' '.join(t.text for t in ocr_tokens)
|
||||
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
|
||||
|
||||
# Get field name
|
||||
field_name = CLASS_TO_FIELD.get(detection.class_name, detection.class_name)
|
||||
|
||||
# Select value tokens (filter out label text)
|
||||
value_tokens = ValueSelector.select_value_tokens(ocr_tokens, field_name)
|
||||
raw_text = ' '.join(t.text for t in value_tokens)
|
||||
ocr_confidence = sum(t.confidence for t in ocr_tokens) / len(ocr_tokens) if ocr_tokens else 0.0
|
||||
|
||||
# Normalize and validate
|
||||
normalized_value, is_valid, validation_error = self._normalize_and_validate(
|
||||
field_name, raw_text
|
||||
|
||||
@@ -20,26 +20,98 @@ class AmountNormalizer(BaseNormalizer):
|
||||
Handles various Swedish amount formats:
|
||||
- With decimal: 1 234,56 kr
|
||||
- With SEK suffix: 1234.56 SEK
|
||||
- Payment line kronor/ore: 590 00 (space = decimal separator)
|
||||
- Multiple amounts (returns the last one, usually the total)
|
||||
"""
|
||||
|
||||
    # Payment line kronor/ore pattern: "590 00" means 590.00 SEK
    # Only matches when no comma/dot is present (pure digit-space-2digit format)
    _KRONOR_ORE_PATTERN = re.compile(r'^(\d+)\s+(\d{2})$')

    @property
    def field_name(self) -> str:
        """Name of the field this normalizer handles."""
        return "Amount"
|
||||
|
||||
@classmethod
|
||||
def _try_kronor_ore(cls, text: str) -> NormalizationResult | None:
|
||||
"""Try to parse as payment line kronor/ore format.
|
||||
|
||||
Swedish payment lines separate kronor and ore with a space:
|
||||
"590 00" = 590.00 SEK, "15658 00" = 15658.00 SEK
|
||||
|
||||
Only applies when text has no comma or dot (otherwise it's
|
||||
a normal amount format with explicit decimal separator).
|
||||
|
||||
Returns NormalizationResult on success, None if not matched.
|
||||
"""
|
||||
if ',' in text or '.' in text:
|
||||
return None
|
||||
|
||||
match = cls._KRONOR_ORE_PATTERN.match(text.strip())
|
||||
if not match:
|
||||
return None
|
||||
|
||||
kronor = match.group(1)
|
||||
ore = match.group(2)
|
||||
try:
|
||||
amount = float(f"{kronor}.{ore}")
|
||||
if amount > 0:
|
||||
return NormalizationResult.success(f"{amount:.2f}")
|
||||
except ValueError:
|
||||
pass
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _parse_amount_str(match: str) -> float | None:
|
||||
"""Convert matched amount string to float, detecting European vs Anglo format.
|
||||
|
||||
European: 2.254,50 -> 2254.50 (dot=thousand, comma=decimal)
|
||||
Anglo: 1,234.56 -> 1234.56 (comma=thousand, dot=decimal)
|
||||
Swedish: 1 234,56 -> 1234.56 (space=thousand, comma=decimal)
|
||||
"""
|
||||
has_comma = ',' in match
|
||||
has_dot = '.' in match
|
||||
if has_comma and has_dot:
|
||||
if match.rfind(',') > match.rfind('.'):
|
||||
# European: 2.254,50
|
||||
cleaned = match.replace(" ", "").replace(".", "").replace(",", ".")
|
||||
else:
|
||||
# Anglo: 1,234.56
|
||||
cleaned = match.replace(" ", "").replace(",", "")
|
||||
elif has_comma:
|
||||
cleaned = match.replace(" ", "").replace(",", ".")
|
||||
else:
|
||||
cleaned = match.replace(" ", "")
|
||||
try:
|
||||
return float(cleaned)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def normalize(self, text: str) -> NormalizationResult:
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# Early check: payment line kronor/ore format ("590 00" → 590.00)
|
||||
kronor_ore_result = self._try_kronor_ore(text)
|
||||
if kronor_ore_result is not None:
|
||||
return kronor_ore_result
|
||||
|
||||
# Split by newlines and process line by line to get the last valid amount
|
||||
lines = text.split("\n")
|
||||
|
||||
# Collect all valid amounts from all lines
|
||||
all_amounts: list[float] = []
|
||||
|
||||
# Pattern for Swedish amount format (with decimals)
|
||||
amount_pattern = r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|SEK)?"
|
||||
# Separate patterns for European and Anglo formats
|
||||
# (?!\d) lookahead prevents partial matches (e.g. "1,23" in "1,234.56")
|
||||
# European: dot=thousand, comma=decimal (2.254,50 or 1 234,56)
|
||||
# Anglo: comma=thousand, dot=decimal (1,234.56 or 1234.56)
|
||||
amount_pattern = (
|
||||
r"(\d[\d\s.]*,\d{2})(?!\d)\s*(?:kr|SEK)?"
|
||||
r"|"
|
||||
r"(\d[\d\s,]*\.\d{2})(?!\d)\s*(?:kr|SEK)?"
|
||||
)
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
@@ -47,15 +119,13 @@ class AmountNormalizer(BaseNormalizer):
|
||||
continue
|
||||
|
||||
# Find all amounts in this line
|
||||
matches = re.findall(amount_pattern, line, re.IGNORECASE)
|
||||
for match in matches:
|
||||
amount_str = match.replace(" ", "").replace(",", ".")
|
||||
try:
|
||||
amount = float(amount_str)
|
||||
if amount > 0:
|
||||
all_amounts.append(amount)
|
||||
except ValueError:
|
||||
for m in re.finditer(amount_pattern, line, re.IGNORECASE):
|
||||
match = m.group(1) or m.group(2)
|
||||
if not match:
|
||||
continue
|
||||
amount = self._parse_amount_str(match)
|
||||
if amount is not None and amount > 0:
|
||||
all_amounts.append(amount)
|
||||
|
||||
# Return the last amount found (usually the total)
|
||||
if all_amounts:
|
||||
@@ -122,31 +192,33 @@ class EnhancedAmountNormalizer(AmountNormalizer):
|
||||
if not text:
|
||||
return NormalizationResult.failure("Empty text")
|
||||
|
||||
# Early check: payment line kronor/ore format ("590 00" → 590.00)
|
||||
kronor_ore_result = self._try_kronor_ore(text)
|
||||
if kronor_ore_result is not None:
|
||||
return kronor_ore_result
|
||||
|
||||
# Strategy 1: Apply OCR corrections first
|
||||
corrected_text = OCRCorrections.correct_digits(text, aggressive=False).corrected
|
||||
|
||||
# Strategy 2: Look for labeled amounts (highest priority)
|
||||
# Use two capture groups: group(1) = European, group(2) = Anglo
|
||||
labeled_patterns = [
|
||||
# Swedish patterns
|
||||
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})", 1.0),
|
||||
# Swedish patterns ((?!\d) prevents partial matches like "1,23" in "1,234.56")
|
||||
(r"(?:att\s+betala|summa|total|belopp)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))", 1.0),
|
||||
(
|
||||
r"(?:moms|vat)\s*[:\s]*(\d[\d\s]*[,\.]\d{2})",
|
||||
r"(?:moms|vat)\s*[:\s]*(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))",
|
||||
0.8,
|
||||
), # Lower priority for VAT
|
||||
# Generic pattern
|
||||
(r"(\d[\d\s]*[,\.]\d{2})\s*(?:kr|sek|kronor)?", 0.7),
|
||||
(r"(\d[\d\s.]*,\d{2}(?!\d)|\d[\d\s,]*\.\d{2}(?!\d))\s*(?:kr|sek|kronor)?", 0.7),
|
||||
]
|
||||
|
||||
candidates: list[tuple[float, float, int]] = []
|
||||
for pattern, priority in labeled_patterns:
|
||||
for match in re.finditer(pattern, corrected_text, re.IGNORECASE):
|
||||
amount_str = match.group(1).replace(" ", "").replace(",", ".")
|
||||
try:
|
||||
amount = float(amount_str)
|
||||
if 0 < amount < 10_000_000: # Reasonable range
|
||||
candidates.append((amount, priority, match.start()))
|
||||
except ValueError:
|
||||
continue
|
||||
amount = self._parse_amount_str(match.group(1))
|
||||
if amount is not None and 0 < amount < 10_000_000:
|
||||
candidates.append((amount, priority, match.start()))
|
||||
|
||||
if candidates:
|
||||
# Sort by priority (desc), then by position (later is usually total)
|
||||
|
||||
@@ -301,6 +301,27 @@ class InferencePipeline:
|
||||
all_extracted = []
|
||||
all_ocr_text = [] # Collect OCR text for VAT extraction
|
||||
|
||||
# Check if PDF has readable text layer (avoids OCR for text PDFs)
|
||||
from shared.pdf.extractor import PDFDocument
|
||||
is_text_pdf = False
|
||||
pdf_tokens_by_page: dict[int, list] = {}
|
||||
try:
|
||||
with PDFDocument(pdf_path) as pdf_doc:
|
||||
is_text_pdf = pdf_doc.is_text_pdf()
|
||||
if is_text_pdf:
|
||||
for pg in range(pdf_doc.page_count):
|
||||
pdf_tokens_by_page[pg] = list(
|
||||
pdf_doc.extract_text_tokens(pg)
|
||||
)
|
||||
logger.info(
|
||||
"Text PDF detected, extracted %d tokens from %d pages",
|
||||
sum(len(t) for t in pdf_tokens_by_page.values()),
|
||||
len(pdf_tokens_by_page),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("PDF text detection failed, falling back to OCR: %s", e)
|
||||
is_text_pdf = False
|
||||
|
||||
# Process each page
|
||||
for page_no, image_bytes in render_pdf_to_images(pdf_path, dpi=self.dpi):
|
||||
# Convert to numpy array
|
||||
@@ -313,7 +334,17 @@ class InferencePipeline:
|
||||
|
||||
# Extract fields from detections
|
||||
for detection in detections:
|
||||
extracted = self.extractor.extract_from_detection(detection, image_array)
|
||||
if is_text_pdf and page_no in pdf_tokens_by_page:
|
||||
extracted = self.extractor.extract_from_detection_with_pdf(
|
||||
detection,
|
||||
pdf_tokens_by_page[page_no],
|
||||
image.width,
|
||||
image.height,
|
||||
)
|
||||
else:
|
||||
extracted = self.extractor.extract_from_detection(
|
||||
detection, image_array
|
||||
)
|
||||
all_extracted.append(extracted)
|
||||
|
||||
# Collect full-page OCR text for VAT extraction (only if business features enabled)
|
||||
|
||||
packages/backend/backend/pipeline/value_selector.py (new file, 172 lines)
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Value Selector Module.
|
||||
|
||||
Selects the most likely value token(s) from OCR output per field type,
|
||||
filtering out label text before sending to normalizer.
|
||||
|
||||
Stateless and pure -- easy to test, no side effects.
|
||||
"""
|
||||
|
||||
import re
from collections.abc import Callable
from typing import Final

from shared.ocr.paddle_ocr import OCRToken
|
||||
|
||||
|
||||
# Swedish label keywords commonly found near field values.
# NOTE(review): stored ASCII-folded (e.g. "forfallodag" for "förfallodag") --
# presumably an upstream step strips diacritics; confirm before adding
# accented variants.
LABEL_KEYWORDS: Final[frozenset[str]] = frozenset({
    "fakturanummer", "fakturanr", "fakturadatum", "forfallodag", "forfalldatum",
    "bankgiro", "plusgiro", "bg", "pg", "ocr", "belopp", "summa",
    "total", "att", "betala", "kundnummer", "organisationsnummer",
    "org", "nr", "datum", "nummer", "ref", "referens",
    "momsreg", "vat", "moms", "sek", "kr",
    "org.nr", "kund", "faktura", "invoice",
})
|
||||
|
||||
# Patterns used by the field-specific selectors below.

# Dates: ISO-ish, day-first, or compact 8-digit forms.
_DATE_PATTERN = re.compile(
    r"\d{4}[-./]\d{2}[-./]\d{2}"  # 2024-01-15, 2024.01.15
    r"|"
    r"\d{2}[-./]\d{2}[-./]\d{4}"  # 15/01/2024
    r"|"
    r"\d{8}"  # 20240115
)

# Amounts with a two-digit decimal part and optional kr/SEK suffix;
# anchored at end-of-text so trailing garbage is rejected.
_AMOUNT_PATTERN = re.compile(
    r"\d[\d\s.]*,\d{2}(?:\s*(?:kr|SEK))?$"  # European: 2.254,50 SEK
    r"|"
    r"\d[\d\s,]*\.\d{2}(?:\s*(?:kr|SEK))?$"  # Anglo: 1,234.56 SEK
)

# Bankgiro numbers (whole-token match).
_BANKGIRO_PATTERN = re.compile(
    r"^\d{3,4}-\d{4}$"  # 123-4567
    r"|"
    r"^\d{7,8}$"  # 1234567 or 12345678
)

# Plusgiro numbers (whole-token match).
_PLUSGIRO_PATTERN = re.compile(
    r"^\d+-\d$"  # 12345-6
    r"|"
    r"^\d{2,8}$"  # 1234567
)

# Swedish organisation numbers, with or without the dash.
_ORG_NUMBER_PATTERN = re.compile(
    r"\d{6}-?\d{4}"  # 556123-4567 or 5561234567
)
|
||||
|
||||
|
||||
def _is_label(text: str) -> bool:
    """Check whether *text* is a known Swedish label keyword.

    Normalization: lowercase, trim surrounding whitespace, then drop any
    trailing colon before the keyword lookup.
    """
    # Strip whitespace BEFORE rstrip(":"): the previous ordering failed on
    # inputs like "Belopp: " where trailing whitespace shielded the colon,
    # so the token never matched the keyword set.
    cleaned = text.strip().lower().rstrip(":").strip()
    return cleaned in LABEL_KEYWORDS
|
||||
|
||||
|
||||
def _count_digits(text: str) -> int:
|
||||
"""Count digit characters in text."""
|
||||
return sum(c.isdigit() for c in text)
|
||||
|
||||
|
||||
class ValueSelector:
    """Pick out the value token(s) from an OCR'd region, dropping label text.

    Every method is a pure static function -- no instance state, no side
    effects. The selection can only narrow the token list: whenever a
    field-specific filter matches nothing, the full token list is returned
    unchanged, so results are never worse than passing everything through.
    """

    @staticmethod
    def select_value_tokens(
        tokens: list[OCRToken],
        field_name: str,
    ) -> list[OCRToken]:
        """Return the most plausible value token(s) for *field_name*.

        Args:
            tokens: OCR tokens found inside the detected region.
            field_name: Normalized field name (e.g. "InvoiceDate", "Amount").

        Returns:
            A filtered token list; falls back to every input token when the
            field-specific selector matches nothing.
        """
        if not tokens:
            return []

        pick = _FIELD_SELECTORS.get(field_name, _fallback_selector)
        chosen = pick(tokens)

        # Never collapse to an empty result when we were given input.
        return chosen if chosen else list(tokens)

    @staticmethod
    def _select_date(tokens: list[OCRToken]) -> list[OCRToken]:
        return _select_by_pattern(tokens, _DATE_PATTERN)

    @staticmethod
    def _select_amount(tokens: list[OCRToken]) -> list[OCRToken]:
        return _select_by_pattern(tokens, _AMOUNT_PATTERN)

    @staticmethod
    def _select_bankgiro(tokens: list[OCRToken]) -> list[OCRToken]:
        return _select_by_pattern(tokens, _BANKGIRO_PATTERN)

    @staticmethod
    def _select_plusgiro(tokens: list[OCRToken]) -> list[OCRToken]:
        return _select_by_pattern(tokens, _PLUSGIRO_PATTERN)

    @staticmethod
    def _select_org_number(tokens: list[OCRToken]) -> list[OCRToken]:
        return _select_by_pattern(tokens, _ORG_NUMBER_PATTERN)

    @staticmethod
    def _select_ocr_number(tokens: list[OCRToken]) -> list[OCRToken]:
        """Pick the token carrying the longest digit run (at least 5 digits)."""
        winner: OCRToken | None = None
        winner_digits = 0
        for candidate in tokens:
            n_digits = _count_digits(candidate.text)
            if n_digits >= 5 and n_digits > winner_digits:
                winner, winner_digits = candidate, n_digits
        return [winner] if winner else []

    @staticmethod
    def _select_by_label_exclusion(tokens: list[OCRToken]) -> list[OCRToken]:
        """Drop known label keywords; everything else is kept."""
        return [tok for tok in tokens if not _is_label(tok.text)]

    @staticmethod
    def _select_payment_line(tokens: list[OCRToken]) -> list[OCRToken]:
        """The payment line needs its full text, so keep every token."""
        return list(tokens)
|
||||
|
||||
|
||||
def _select_by_pattern(
    tokens: list[OCRToken],
    pattern: re.Pattern[str],
) -> list[OCRToken]:
    """Return only the tokens whose stripped text matches *pattern*."""
    return list(filter(lambda tok: pattern.search(tok.text.strip()), tokens))
|
||||
|
||||
|
||||
def _fallback_selector(tokens: list[OCRToken]) -> list[OCRToken]:
    """Default selector: hand back a shallow copy of every token."""
    return [*tokens]
|
||||
|
||||
|
||||
# Map field names to selector functions.
# Defined after the class body so the staticmethod attributes already exist.
# Fix: the annotation previously used the builtin `callable`, which is a
# predicate function, not a type; use collections.abc.Callable instead.
_FIELD_SELECTORS: Final[
    dict[str, Callable[[list[OCRToken]], list[OCRToken]]]
] = {
    "InvoiceDate": ValueSelector._select_date,
    "InvoiceDueDate": ValueSelector._select_date,
    "Amount": ValueSelector._select_amount,
    "Bankgiro": ValueSelector._select_bankgiro,
    "Plusgiro": ValueSelector._select_plusgiro,
    "OCR": ValueSelector._select_ocr_number,
    "InvoiceNumber": ValueSelector._select_by_label_exclusion,
    "supplier_org_number": ValueSelector._select_org_number,
    "customer_number": ValueSelector._select_by_label_exclusion,
    "payment_line": ValueSelector._select_payment_line,
}
|
||||
@@ -12,6 +12,7 @@ from .documents import register_document_routes
|
||||
from .export import register_export_routes
|
||||
from .datasets import register_dataset_routes
|
||||
from .models import register_model_routes
|
||||
from .pool import register_pool_routes
|
||||
|
||||
|
||||
def create_training_router() -> APIRouter:
|
||||
@@ -23,6 +24,7 @@ def create_training_router() -> APIRouter:
|
||||
register_export_routes(router)
|
||||
register_dataset_routes(router)
|
||||
register_model_routes(router)
|
||||
register_pool_routes(router)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from backend.web.core.auth import (
|
||||
AnnotationRepoDep,
|
||||
ModelVersionRepoDep,
|
||||
TrainingTaskRepoDep,
|
||||
FineTunePoolRepoDep,
|
||||
)
|
||||
from backend.web.schemas.admin import (
|
||||
DatasetCreateRequest,
|
||||
@@ -233,6 +234,7 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
datasets_repo: DatasetRepoDep,
|
||||
models: ModelVersionRepoDep,
|
||||
tasks: TrainingTaskRepoDep,
|
||||
pool: FineTunePoolRepoDep,
|
||||
) -> TrainingTaskResponse:
|
||||
"""Create a training task from a dataset.
|
||||
|
||||
@@ -261,13 +263,39 @@ def register_dataset_routes(router: APIRouter) -> None:
|
||||
status_code=404,
|
||||
detail=f"Base model version not found: {base_model_version_id}",
|
||||
)
|
||||
# Chain prevention: never fine-tune from a fine-tuned model
|
||||
if getattr(base_model, "model_type", "base") == "finetune":
|
||||
original_base_id = getattr(base_model, "base_model_version_id", None)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Cannot chain fine-tunes. Model {base_model.version} is already "
|
||||
f"a fine-tuned model. Select the original base model instead"
|
||||
f"{f' (base_model_version_id: {original_base_id})' if original_base_id else ''}."
|
||||
),
|
||||
)
|
||||
# Pool threshold: require minimum verified pool entries for fine-tuning
|
||||
from backend.web.services.data_mixer import MIN_POOL_SIZE
|
||||
|
||||
verified_count = pool.get_pool_count(verified_only=True)
|
||||
if verified_count < MIN_POOL_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Fine-tuning requires at least {MIN_POOL_SIZE} verified pool entries "
|
||||
f"(currently {verified_count}). Add more documents to the fine-tune "
|
||||
f"pool and verify them before starting fine-tuning."
|
||||
),
|
||||
)
|
||||
|
||||
# Store the resolved model path for the training worker
|
||||
config_dict["base_model_path"] = base_model.model_path
|
||||
config_dict["base_model_version"] = base_model.version
|
||||
logger.info(
|
||||
"Incremental training: using model %s (%s) as base",
|
||||
"Fine-tuning: using base model %s (%s) with %d verified pool entries",
|
||||
base_model.version,
|
||||
base_model.model_path,
|
||||
verified_count,
|
||||
)
|
||||
|
||||
task_id = tasks.create(
|
||||
|
||||
@@ -124,16 +124,11 @@ def register_export_routes(router: APIRouter) -> None:
|
||||
x1 = ann.x_center * img_width + half_w
|
||||
y1 = ann.y_center * img_height + half_h
|
||||
|
||||
# Use manual_mode for manual/imported annotations
|
||||
manual_mode = ann.source in ("manual", "imported")
|
||||
|
||||
# Apply field-specific bbox expansion
|
||||
# Apply uniform bbox expansion
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann.class_name,
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Convert back to normalized YOLO format
|
||||
|
||||
@@ -88,6 +88,9 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
name=m.name,
|
||||
status=m.status,
|
||||
is_active=m.is_active,
|
||||
model_type=getattr(m, "model_type", "base"),
|
||||
base_model_version_id=str(m.base_model_version_id) if getattr(m, "base_model_version_id", None) else None,
|
||||
gating_status=getattr(m, "gating_status", "pending"),
|
||||
metrics_mAP=m.metrics_mAP,
|
||||
document_count=m.document_count,
|
||||
trained_at=m.trained_at,
|
||||
@@ -121,6 +124,9 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
name=model.name,
|
||||
status=model.status,
|
||||
is_active=model.is_active,
|
||||
model_type=getattr(model, "model_type", "base"),
|
||||
base_model_version_id=str(model.base_model_version_id) if getattr(model, "base_model_version_id", None) else None,
|
||||
gating_status=getattr(model, "gating_status", "pending"),
|
||||
metrics_mAP=model.metrics_mAP,
|
||||
document_count=model.document_count,
|
||||
trained_at=model.trained_at,
|
||||
@@ -153,6 +159,10 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
model_path=model.model_path,
|
||||
status=model.status,
|
||||
is_active=model.is_active,
|
||||
model_type=getattr(model, "model_type", "base"),
|
||||
base_model_version_id=str(model.base_model_version_id) if getattr(model, "base_model_version_id", None) else None,
|
||||
base_training_dataset_id=str(model.base_training_dataset_id) if getattr(model, "base_training_dataset_id", None) else None,
|
||||
gating_status=getattr(model, "gating_status", "pending"),
|
||||
task_id=str(model.task_id) if model.task_id else None,
|
||||
dataset_id=str(model.dataset_id) if model.dataset_id else None,
|
||||
metrics_mAP=model.metrics_mAP,
|
||||
@@ -209,6 +219,25 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
) -> ModelVersionResponse:
|
||||
"""Activate a model version for inference."""
|
||||
_validate_uuid(version_id, "version_id")
|
||||
|
||||
# Check gating status before activation (for fine-tuned models)
|
||||
pre_check = models.get(version_id)
|
||||
if not pre_check:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
model_type = getattr(pre_check, "model_type", "base")
|
||||
gating_status = getattr(pre_check, "gating_status", "skipped")
|
||||
if model_type == "finetune":
|
||||
if gating_status == "reject":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model failed deployment gating validation. Cannot activate a rejected model.",
|
||||
)
|
||||
if gating_status == "pending":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Model gating validation not yet completed. Wait for validation to finish.",
|
||||
)
|
||||
|
||||
model = models.activate(version_id)
|
||||
if not model:
|
||||
raise HTTPException(status_code=404, detail="Model version not found")
|
||||
@@ -227,6 +256,8 @@ def register_model_routes(router: APIRouter) -> None:
|
||||
message = "Model version activated for inference"
|
||||
if model_reloaded:
|
||||
message += " (model reloaded)"
|
||||
if gating_status == "review":
|
||||
message += " (WARNING: gating status is REVIEW - manual inspection recommended)"
|
||||
|
||||
return ModelVersionResponse(
|
||||
version_id=str(model.version_id),
|
||||
|
||||
159
packages/backend/backend/web/api/v1/admin/training/pool.py
Normal file
159
packages/backend/backend/web/api/v1/admin/training/pool.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""Fine-Tune Pool Endpoints."""
|
||||
|
||||
import logging
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from backend.web.core.auth import AdminTokenDep, FineTunePoolRepoDep, DocumentRepoDep
|
||||
from backend.web.schemas.admin.pool import (
|
||||
PoolAddRequest,
|
||||
PoolEntryItem,
|
||||
PoolEntryResponse,
|
||||
PoolListResponse,
|
||||
PoolStatsResponse,
|
||||
)
|
||||
|
||||
from ._utils import _validate_uuid
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def register_pool_routes(router: APIRouter) -> None:
|
||||
"""Register fine-tune pool endpoints on the router."""
|
||||
|
||||
@router.post(
|
||||
"/pool",
|
||||
response_model=PoolEntryResponse,
|
||||
summary="Add document to fine-tune pool",
|
||||
description="Add a labeled document to the fine-tune pool for future fine-tuning.",
|
||||
)
|
||||
async def add_to_pool(
|
||||
request: PoolAddRequest,
|
||||
admin_token: AdminTokenDep,
|
||||
pool: FineTunePoolRepoDep,
|
||||
docs: DocumentRepoDep,
|
||||
) -> PoolEntryResponse:
|
||||
"""Add a document to the fine-tune pool."""
|
||||
_validate_uuid(request.document_id, "document_id")
|
||||
|
||||
# Verify document exists
|
||||
doc = docs.get(request.document_id)
|
||||
if not doc:
|
||||
raise HTTPException(status_code=404, detail="Document not found")
|
||||
|
||||
# Check if already in pool
|
||||
existing = pool.get_by_document(request.document_id)
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"Document already in fine-tune pool (entry_id: {existing.entry_id})",
|
||||
)
|
||||
|
||||
entry = pool.add_document(
|
||||
document_id=request.document_id,
|
||||
added_by=admin_token,
|
||||
reason=request.reason,
|
||||
)
|
||||
|
||||
return PoolEntryResponse(
|
||||
entry_id=str(entry.entry_id),
|
||||
message="Document added to fine-tune pool",
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/pool",
|
||||
response_model=PoolListResponse,
|
||||
summary="List fine-tune pool entries",
|
||||
)
|
||||
async def list_pool_entries(
|
||||
admin_token: AdminTokenDep,
|
||||
pool: FineTunePoolRepoDep,
|
||||
verified_only: Annotated[bool, Query(description="Filter to verified only")] = False,
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 20,
|
||||
offset: Annotated[int, Query(ge=0)] = 0,
|
||||
) -> PoolListResponse:
|
||||
"""List entries in the fine-tune pool."""
|
||||
entries, total = pool.get_paginated(
|
||||
verified_only=verified_only,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return PoolListResponse(
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
entries=[
|
||||
PoolEntryItem(
|
||||
entry_id=str(e.entry_id),
|
||||
document_id=str(e.document_id),
|
||||
added_by=e.added_by,
|
||||
reason=e.reason,
|
||||
is_verified=e.is_verified,
|
||||
verified_at=e.verified_at,
|
||||
verified_by=e.verified_by,
|
||||
created_at=e.created_at,
|
||||
)
|
||||
for e in entries
|
||||
],
|
||||
)
|
||||
|
||||
@router.get(
|
||||
"/pool/stats",
|
||||
response_model=PoolStatsResponse,
|
||||
summary="Get fine-tune pool statistics",
|
||||
)
|
||||
async def get_pool_stats(
|
||||
admin_token: AdminTokenDep,
|
||||
pool: FineTunePoolRepoDep,
|
||||
) -> PoolStatsResponse:
|
||||
"""Get statistics about the fine-tune pool."""
|
||||
total = pool.get_pool_count(verified_only=False)
|
||||
verified = pool.get_pool_count(verified_only=True)
|
||||
|
||||
return PoolStatsResponse(
|
||||
total_entries=total,
|
||||
verified_entries=verified,
|
||||
unverified_entries=total - verified,
|
||||
is_ready=verified >= 50,
|
||||
)
|
||||
|
||||
@router.post(
|
||||
"/pool/{entry_id}/verify",
|
||||
response_model=PoolEntryResponse,
|
||||
summary="Verify a pool entry",
|
||||
description="Mark a pool entry as verified (human-reviewed).",
|
||||
)
|
||||
async def verify_pool_entry(
|
||||
entry_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
pool: FineTunePoolRepoDep,
|
||||
) -> PoolEntryResponse:
|
||||
"""Mark a pool entry as verified."""
|
||||
_validate_uuid(entry_id, "entry_id")
|
||||
entry = pool.verify_entry(entry_id, verified_by=admin_token)
|
||||
if not entry:
|
||||
raise HTTPException(status_code=404, detail="Pool entry not found")
|
||||
|
||||
return PoolEntryResponse(
|
||||
entry_id=str(entry.entry_id),
|
||||
message="Pool entry verified",
|
||||
)
|
||||
|
||||
@router.delete(
|
||||
"/pool/{entry_id}",
|
||||
summary="Remove from fine-tune pool",
|
||||
)
|
||||
async def remove_from_pool(
|
||||
entry_id: str,
|
||||
admin_token: AdminTokenDep,
|
||||
pool: FineTunePoolRepoDep,
|
||||
) -> dict:
|
||||
"""Remove a document from the fine-tune pool."""
|
||||
_validate_uuid(entry_id, "entry_id")
|
||||
success = pool.remove_entry(entry_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Pool entry not found")
|
||||
|
||||
return {"message": "Entry removed from fine-tune pool"}
|
||||
@@ -17,6 +17,7 @@ from backend.data.repositories import (
|
||||
TrainingTaskRepository,
|
||||
ModelVersionRepository,
|
||||
BatchUploadRepository,
|
||||
FineTunePoolRepository,
|
||||
)
|
||||
|
||||
|
||||
@@ -95,6 +96,12 @@ def get_batch_upload_repository() -> BatchUploadRepository:
|
||||
return BatchUploadRepository()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_finetune_pool_repository() -> FineTunePoolRepository:
|
||||
"""Get the FineTunePoolRepository instance (thread-safe singleton)."""
|
||||
return FineTunePoolRepository()
|
||||
|
||||
|
||||
def reset_all_repositories() -> None:
|
||||
"""Reset all repository instances (for testing)."""
|
||||
get_token_repository.cache_clear()
|
||||
@@ -104,6 +111,7 @@ def reset_all_repositories() -> None:
|
||||
get_training_task_repository.cache_clear()
|
||||
get_model_version_repository.cache_clear()
|
||||
get_batch_upload_repository.cache_clear()
|
||||
get_finetune_pool_repository.cache_clear()
|
||||
|
||||
|
||||
# Repository dependency type aliases
|
||||
@@ -113,3 +121,4 @@ DatasetRepoDep = Annotated[DatasetRepository, Depends(get_dataset_repository)]
|
||||
TrainingTaskRepoDep = Annotated[TrainingTaskRepository, Depends(get_training_task_repository)]
|
||||
ModelVersionRepoDep = Annotated[ModelVersionRepository, Depends(get_model_version_repository)]
|
||||
BatchUploadRepoDep = Annotated[BatchUploadRepository, Depends(get_batch_upload_repository)]
|
||||
FineTunePoolRepoDep = Annotated[FineTunePoolRepository, Depends(get_finetune_pool_repository)]
|
||||
|
||||
@@ -16,6 +16,7 @@ from backend.data.repositories import (
|
||||
ModelVersionRepository,
|
||||
DocumentRepository,
|
||||
AnnotationRepository,
|
||||
FineTunePoolRepository,
|
||||
)
|
||||
from backend.web.core.task_interface import TaskRunner, TaskStatus
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
@@ -47,6 +48,7 @@ class TrainingScheduler(TaskRunner):
|
||||
self._model_versions = ModelVersionRepository()
|
||||
self._documents = DocumentRepository()
|
||||
self._annotations = AnnotationRepository()
|
||||
self._pool = FineTunePoolRepository()
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -168,7 +170,7 @@ class TrainingScheduler(TaskRunner):
|
||||
|
||||
try:
|
||||
# Get training configuration
|
||||
model_name = config.get("model_name", "yolo11n.pt")
|
||||
model_name = config.get("model_name", "yolo26s.pt")
|
||||
base_model_path = config.get("base_model_path") # For incremental training
|
||||
epochs = config.get("epochs", 100)
|
||||
batch_size = config.get("batch_size", 16)
|
||||
@@ -182,14 +184,19 @@ class TrainingScheduler(TaskRunner):
|
||||
augmentation_multiplier = config.get("augmentation_multiplier", 2)
|
||||
|
||||
# Determine which model to use as base
|
||||
if base_model_path:
|
||||
# Incremental training: use existing trained model
|
||||
is_finetune = bool(base_model_path)
|
||||
if is_finetune:
|
||||
# Fine-tuning: use existing trained model as base
|
||||
if not Path(base_model_path).exists():
|
||||
raise ValueError(f"Base model not found: {base_model_path}")
|
||||
effective_model = base_model_path
|
||||
# Override parameters for fine-tuning (best practices)
|
||||
epochs = config.get("epochs", 10)
|
||||
learning_rate = config.get("learning_rate", 0.001)
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Incremental training from: {base_model_path}",
|
||||
f"Fine-tuning from: {base_model_path} "
|
||||
f"(epochs={epochs}, freeze=10, cos_lr=true)",
|
||||
)
|
||||
else:
|
||||
# Train from pretrained model
|
||||
@@ -229,10 +236,16 @@ class TrainingScheduler(TaskRunner):
|
||||
f"(total: {aug_result['total_images']})",
|
||||
)
|
||||
|
||||
# Build mixed dataset for fine-tuning (pool samples + old data)
|
||||
if is_finetune and dataset_id:
|
||||
data_yaml, dataset_path = self._build_mixed_finetune_dataset(
|
||||
task_id, dataset_path, data_yaml,
|
||||
)
|
||||
|
||||
# Run YOLO training
|
||||
result = self._run_yolo_training(
|
||||
task_id=task_id,
|
||||
model_name=effective_model, # Use base model or pretrained model
|
||||
model_name=effective_model,
|
||||
data_yaml=data_yaml,
|
||||
epochs=epochs,
|
||||
batch_size=batch_size,
|
||||
@@ -240,6 +253,8 @@ class TrainingScheduler(TaskRunner):
|
||||
learning_rate=learning_rate,
|
||||
device=device,
|
||||
project_name=project_name,
|
||||
freeze=10 if is_finetune else 0,
|
||||
cos_lr=is_finetune,
|
||||
)
|
||||
|
||||
# Update task with results
|
||||
@@ -261,13 +276,23 @@ class TrainingScheduler(TaskRunner):
|
||||
)
|
||||
|
||||
# Auto-create model version for the completed training
|
||||
self._create_model_version_from_training(
|
||||
model_version = self._create_model_version_from_training(
|
||||
task_id=task_id,
|
||||
config=config,
|
||||
dataset_id=dataset_id,
|
||||
result=result,
|
||||
)
|
||||
|
||||
# Auto-run gating validation for fine-tuned models
|
||||
if is_finetune and model_version:
|
||||
self._run_gating_after_finetune(
|
||||
task_id=task_id,
|
||||
model_version=model_version,
|
||||
config=config,
|
||||
data_yaml=data_yaml,
|
||||
result=result,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Training task {task_id} failed: {e}")
|
||||
self._training_tasks.add_log(task_id, "ERROR", f"Training failed: {e}")
|
||||
@@ -286,13 +311,16 @@ class TrainingScheduler(TaskRunner):
|
||||
config: dict[str, Any],
|
||||
dataset_id: str | None,
|
||||
result: dict[str, Any],
|
||||
) -> None:
|
||||
"""Create a model version entry from completed training."""
|
||||
) -> Any | None:
|
||||
"""Create a model version entry from completed training.
|
||||
|
||||
Returns the created model version, or None on failure.
|
||||
"""
|
||||
try:
|
||||
model_path = result.get("model_path")
|
||||
if not model_path:
|
||||
logger.warning(f"No model path in training result for task {task_id}")
|
||||
return
|
||||
return None
|
||||
|
||||
# Get task info for name
|
||||
task = self._training_tasks.get(task_id)
|
||||
@@ -322,6 +350,12 @@ class TrainingScheduler(TaskRunner):
|
||||
if dataset:
|
||||
document_count = dataset.total_documents
|
||||
|
||||
# Determine model lineage
|
||||
is_finetune = bool(config.get("base_model_path"))
|
||||
model_type = "finetune" if is_finetune else "base"
|
||||
base_model_version_id = config.get("base_model_version_id") if is_finetune else None
|
||||
gating_status = "pending" if is_finetune else "skipped"
|
||||
|
||||
# Create model version
|
||||
model_version = self._model_versions.create(
|
||||
version=version,
|
||||
@@ -337,6 +371,10 @@ class TrainingScheduler(TaskRunner):
|
||||
training_config=config,
|
||||
file_size=file_size,
|
||||
trained_at=datetime.utcnow(),
|
||||
model_type=model_type,
|
||||
base_model_version_id=base_model_version_id,
|
||||
base_training_dataset_id=dataset_id if not is_finetune else None,
|
||||
gating_status=gating_status,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -349,12 +387,105 @@ class TrainingScheduler(TaskRunner):
|
||||
f"Model version {version} created (mAP: {mAP_display})",
|
||||
)
|
||||
|
||||
return model_version
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create model version for task {task_id}: {e}")
|
||||
self._training_tasks.add_log(
|
||||
task_id, "WARNING",
|
||||
f"Failed to auto-create model version: {e}",
|
||||
)
|
||||
return None
|
||||
|
||||
def _build_mixed_finetune_dataset(
|
||||
self,
|
||||
task_id: str,
|
||||
base_dataset_path: Path,
|
||||
original_data_yaml: str,
|
||||
) -> tuple[str, Path]:
|
||||
"""Build a mixed dataset for fine-tuning.
|
||||
|
||||
Combines verified pool samples with randomly sampled old training data.
|
||||
|
||||
Returns:
|
||||
Tuple of (data_yaml path, dataset_path) for the mixed dataset.
|
||||
Falls back to original if mixing fails or pool is empty.
|
||||
"""
|
||||
try:
|
||||
from backend.web.services.data_mixer import build_mixed_dataset
|
||||
|
||||
pool_doc_ids = self._pool.get_all_document_ids(verified_only=True)
|
||||
if not pool_doc_ids:
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
"No verified pool entries found, using original dataset",
|
||||
)
|
||||
return original_data_yaml, base_dataset_path
|
||||
|
||||
mixed_output = base_dataset_path.parent / f"mixed_{task_id[:8]}"
|
||||
mix_result = build_mixed_dataset(
|
||||
pool_document_ids=pool_doc_ids,
|
||||
base_dataset_path=base_dataset_path,
|
||||
output_dir=mixed_output,
|
||||
)
|
||||
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Data mixing: {mix_result['new_images']} new + "
|
||||
f"{mix_result['old_images']} old = {mix_result['total_images']} total "
|
||||
f"(ratio: {mix_result['mixing_ratio']}x)",
|
||||
)
|
||||
|
||||
return mix_result["data_yaml"], mixed_output
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data mixing failed for task {task_id}: {e}")
|
||||
self._training_tasks.add_log(
|
||||
task_id, "WARNING",
|
||||
f"Data mixing failed: {e}. Using original dataset.",
|
||||
)
|
||||
return original_data_yaml, base_dataset_path
|
||||
|
||||
def _run_gating_after_finetune(
|
||||
self,
|
||||
task_id: str,
|
||||
model_version: Any,
|
||||
config: dict[str, Any],
|
||||
data_yaml: str,
|
||||
result: dict[str, Any],
|
||||
) -> None:
|
||||
"""Run gating validation after a fine-tune training completes."""
|
||||
try:
|
||||
from backend.web.services.gating_validator import run_gating_validation
|
||||
|
||||
model_path = result.get("model_path")
|
||||
base_model_version_id = config.get("base_model_version_id")
|
||||
version_id = str(model_version.version_id)
|
||||
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO", "Running deployment gating validation...",
|
||||
)
|
||||
|
||||
gating_result = run_gating_validation(
|
||||
model_version_id=version_id,
|
||||
new_model_path=model_path,
|
||||
base_model_version_id=base_model_version_id,
|
||||
data_yaml=data_yaml,
|
||||
task_id=task_id,
|
||||
)
|
||||
|
||||
self._training_tasks.add_log(
|
||||
task_id, "INFO",
|
||||
f"Gating result: {gating_result.overall_status} "
|
||||
f"(gate1={gating_result.gate1_status}, gate2={gating_result.gate2_status})",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Gating validation failed for task {task_id}: {e}")
|
||||
self._training_tasks.add_log(
|
||||
task_id, "WARNING",
|
||||
f"Gating validation failed: {e}. Model remains in 'pending' state.",
|
||||
)
|
||||
|
||||
def _export_training_data(self, task_id: str) -> dict[str, Any] | None:
|
||||
"""Export training data for a task."""
|
||||
@@ -456,6 +587,8 @@ names: {list(FIELD_CLASSES.values())}
|
||||
learning_rate: float,
|
||||
device: str,
|
||||
project_name: str,
|
||||
freeze: int = 0,
|
||||
cos_lr: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Run YOLO training using shared trainer."""
|
||||
from shared.training import YOLOTrainer, TrainingConfig as SharedTrainingConfig
|
||||
@@ -479,6 +612,8 @@ names: {list(FIELD_CLASSES.values())}
|
||||
project="runs/train",
|
||||
name=f"{project_name}/task_{task_id[:8]}",
|
||||
workers=0,
|
||||
freeze=freeze,
|
||||
cos_lr=cos_lr,
|
||||
)
|
||||
|
||||
# Run training using shared trainer
|
||||
|
||||
@@ -11,6 +11,7 @@ from .annotations import * # noqa: F401, F403
|
||||
from .training import * # noqa: F401, F403
|
||||
from .datasets import * # noqa: F401, F403
|
||||
from .models import * # noqa: F401, F403
|
||||
from .pool import * # noqa: F401, F403
|
||||
from .dashboard import * # noqa: F401, F403
|
||||
|
||||
# Resolve forward references for DocumentDetailResponse
|
||||
|
||||
@@ -40,6 +40,9 @@ class ModelVersionItem(BaseModel):
|
||||
name: str = Field(..., description="Model name")
|
||||
status: str = Field(..., description="Status (active, inactive, archived)")
|
||||
is_active: bool = Field(..., description="Is currently active for inference")
|
||||
model_type: str = Field(default="base", description="Model type (base or finetune)")
|
||||
base_model_version_id: str | None = Field(None, description="Base model version UUID (for fine-tuned models)")
|
||||
gating_status: str = Field(default="pending", description="Deployment gating status")
|
||||
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
|
||||
document_count: int = Field(..., description="Documents used in training")
|
||||
trained_at: datetime | None = Field(None, description="Training completion time")
|
||||
@@ -66,6 +69,10 @@ class ModelVersionDetailResponse(BaseModel):
|
||||
model_path: str = Field(..., description="Path to model file")
|
||||
status: str = Field(..., description="Status (active, inactive, archived)")
|
||||
is_active: bool = Field(..., description="Is currently active for inference")
|
||||
model_type: str = Field(default="base", description="Model type (base or finetune)")
|
||||
base_model_version_id: str | None = Field(None, description="Base model version UUID")
|
||||
base_training_dataset_id: str | None = Field(None, description="Base training dataset UUID")
|
||||
gating_status: str = Field(default="pending", description="Deployment gating status")
|
||||
task_id: str | None = Field(None, description="Training task UUID")
|
||||
dataset_id: str | None = Field(None, description="Dataset UUID")
|
||||
metrics_mAP: float | None = Field(None, description="Mean Average Precision")
|
||||
|
||||
72
packages/backend/backend/web/schemas/admin/pool.py
Normal file
72
packages/backend/backend/web/schemas/admin/pool.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Admin Fine-Tune Pool Schemas."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PoolAddRequest(BaseModel):
|
||||
"""Request to add a document to the fine-tune pool."""
|
||||
|
||||
document_id: str = Field(..., description="Document UUID to add")
|
||||
reason: str = Field(
|
||||
default="user_reported_failure",
|
||||
description="Reason: user_reported_failure, manual_addition",
|
||||
)
|
||||
|
||||
|
||||
class PoolEntryItem(BaseModel):
|
||||
"""Fine-tune pool entry."""
|
||||
|
||||
entry_id: str = Field(..., description="Entry UUID")
|
||||
document_id: str = Field(..., description="Document UUID")
|
||||
added_by: str | None = Field(None, description="Who added this entry")
|
||||
reason: str | None = Field(None, description="Reason for adding")
|
||||
is_verified: bool = Field(..., description="Whether entry has been verified")
|
||||
verified_at: datetime | None = Field(None, description="Verification timestamp")
|
||||
verified_by: str | None = Field(None, description="Who verified")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
|
||||
|
||||
class PoolListResponse(BaseModel):
|
||||
"""Paginated pool entry list."""
|
||||
|
||||
total: int = Field(..., ge=0, description="Total entries")
|
||||
limit: int = Field(..., ge=1, description="Page size")
|
||||
offset: int = Field(..., ge=0, description="Current offset")
|
||||
entries: list[PoolEntryItem] = Field(default_factory=list, description="Pool entries")
|
||||
|
||||
|
||||
class PoolStatsResponse(BaseModel):
|
||||
"""Pool statistics."""
|
||||
|
||||
total_entries: int = Field(..., ge=0, description="Total pool entries")
|
||||
verified_entries: int = Field(..., ge=0, description="Verified entries")
|
||||
unverified_entries: int = Field(..., ge=0, description="Unverified entries")
|
||||
is_ready: bool = Field(..., description="Whether pool has >= 50 verified entries for fine-tuning")
|
||||
min_required: int = Field(default=50, description="Minimum verified entries required")
|
||||
|
||||
|
||||
class PoolEntryResponse(BaseModel):
|
||||
"""Response for pool entry operation."""
|
||||
|
||||
entry_id: str = Field(..., description="Entry UUID")
|
||||
message: str = Field(..., description="Status message")
|
||||
|
||||
|
||||
class GatingResultItem(BaseModel):
|
||||
"""Gating validation result."""
|
||||
|
||||
result_id: str = Field(..., description="Result UUID")
|
||||
model_version_id: str = Field(..., description="Model version UUID")
|
||||
gate1_status: str = Field(..., description="Gate 1 status")
|
||||
gate1_original_mAP: float | None = Field(None, description="Original model mAP")
|
||||
gate1_new_mAP: float | None = Field(None, description="New model mAP")
|
||||
gate1_mAP_drop: float | None = Field(None, description="mAP drop percentage")
|
||||
gate2_status: str = Field(..., description="Gate 2 status")
|
||||
gate2_detection_rate: float | None = Field(None, description="Detection rate on new samples")
|
||||
gate2_total_samples: int | None = Field(None, description="Total new samples tested")
|
||||
gate2_detected_samples: int | None = Field(None, description="Samples correctly detected")
|
||||
overall_status: str = Field(..., description="Overall gating status")
|
||||
reviewer_notes: str | None = Field(None, description="Reviewer notes")
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
@@ -12,7 +12,7 @@ from .enums import TrainingStatus, TrainingType
|
||||
class TrainingConfig(BaseModel):
|
||||
"""Training configuration."""
|
||||
|
||||
model_name: str = Field(default="yolo11n.pt", description="Base model name (used if no base_model_version_id)")
|
||||
model_name: str = Field(default="yolo26s.pt", description="Base model name (used if no base_model_version_id)")
|
||||
base_model_version_id: str | None = Field(
|
||||
default=None,
|
||||
description="Model version UUID to use as base for incremental training. If set, uses this model instead of model_name.",
|
||||
|
||||
199
packages/backend/backend/web/services/data_mixer.py
Normal file
199
packages/backend/backend/web/services/data_mixer.py
Normal file
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Data Mixing Service
|
||||
|
||||
Mixes fine-tune pool samples with randomly sampled old training data
|
||||
following best practices for incremental YOLO fine-tuning.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Mixing ratios: (max_new_samples, old_data_multiplier)
|
||||
# Fewer new samples require higher old data ratio to prevent catastrophic forgetting
|
||||
MIXING_RATIOS: list[tuple[int, int]] = [
|
||||
(10, 50), # <= 10 new samples: 50x old
|
||||
(50, 20), # <= 50 new samples: 20x old
|
||||
(200, 10), # <= 200 new samples: 10x old
|
||||
(500, 5), # <= 500 new samples: 5x old
|
||||
]
|
||||
|
||||
# Default multiplier for 500+ samples
|
||||
DEFAULT_MULTIPLIER = 5
|
||||
|
||||
# Maximum old samples to include (cap for performance)
|
||||
MAX_OLD_SAMPLES = 3000
|
||||
|
||||
# Minimum pool size for fine-tuning
|
||||
MIN_POOL_SIZE = 50
|
||||
|
||||
|
||||
def get_mixing_ratio(new_sample_count: int) -> int:
|
||||
"""Determine old data multiplier based on new sample count.
|
||||
|
||||
Args:
|
||||
new_sample_count: Number of new samples in the fine-tune pool.
|
||||
|
||||
Returns:
|
||||
Multiplier for old data sampling.
|
||||
"""
|
||||
for threshold, multiplier in MIXING_RATIOS:
|
||||
if new_sample_count <= threshold:
|
||||
return multiplier
|
||||
return DEFAULT_MULTIPLIER
|
||||
|
||||
|
||||
def build_mixed_dataset(
|
||||
pool_document_ids: list[UUID],
|
||||
base_dataset_path: Path,
|
||||
output_dir: Path,
|
||||
seed: int = 42,
|
||||
) -> dict[str, Any]:
|
||||
"""Build a mixed dataset for fine-tuning.
|
||||
|
||||
Combines ALL fine-tune pool samples with randomly sampled old data
|
||||
from the base training dataset.
|
||||
|
||||
Args:
|
||||
pool_document_ids: Document IDs from the fine-tune pool.
|
||||
base_dataset_path: Path to the base training dataset directory.
|
||||
output_dir: Output directory for the mixed dataset.
|
||||
seed: Random seed for reproducible sampling.
|
||||
|
||||
Returns:
|
||||
Dictionary with dataset info (data_yaml path, counts).
|
||||
"""
|
||||
new_count = len(pool_document_ids)
|
||||
multiplier = get_mixing_ratio(new_count)
|
||||
old_target = min(new_count * multiplier, MAX_OLD_SAMPLES)
|
||||
|
||||
logger.info(
|
||||
"Building mixed dataset: %d new samples, %dx multiplier, "
|
||||
"targeting %d old samples",
|
||||
new_count, multiplier, old_target,
|
||||
)
|
||||
|
||||
# Create output directory structure
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
for split in ("train", "val"):
|
||||
(output_dir / "images" / split).mkdir(parents=True, exist_ok=True)
|
||||
(output_dir / "labels" / split).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Collect old training images from base dataset
|
||||
old_train_images = _collect_images(base_dataset_path / "images" / "train")
|
||||
old_val_images = _collect_images(base_dataset_path / "images" / "val")
|
||||
|
||||
# Randomly sample old data
|
||||
rng = random.Random(seed)
|
||||
all_old_images = old_train_images + old_val_images
|
||||
if len(all_old_images) > old_target:
|
||||
sampled_old = rng.sample(all_old_images, old_target)
|
||||
else:
|
||||
sampled_old = all_old_images
|
||||
|
||||
# Split old samples: 80% train, 20% val
|
||||
rng.shuffle(sampled_old)
|
||||
old_train_count = int(len(sampled_old) * 0.8)
|
||||
old_train = sampled_old[:old_train_count]
|
||||
old_val = sampled_old[old_train_count:]
|
||||
|
||||
# Copy old samples to mixed dataset
|
||||
old_copied = 0
|
||||
for split_name, images in [("train", old_train), ("val", old_val)]:
|
||||
for img_path in images:
|
||||
label_path = _image_to_label_path(img_path)
|
||||
dst_img = output_dir / "images" / split_name / img_path.name
|
||||
dst_label = output_dir / "labels" / split_name / label_path.name
|
||||
if img_path.exists():
|
||||
shutil.copy2(img_path, dst_img)
|
||||
old_copied += 1
|
||||
if label_path.exists():
|
||||
shutil.copy2(label_path, dst_label)
|
||||
|
||||
# Copy new pool samples (from base dataset, identified by document_id prefix)
|
||||
# Pool documents go into train split (80%) and val split (20%)
|
||||
pool_id_strs = {str(doc_id) for doc_id in pool_document_ids}
|
||||
new_images = _find_pool_images(base_dataset_path, pool_id_strs)
|
||||
|
||||
rng.shuffle(new_images)
|
||||
new_train_count = int(len(new_images) * 0.8)
|
||||
new_train = new_images[:new_train_count]
|
||||
new_val = new_images[new_train_count:]
|
||||
|
||||
new_copied = 0
|
||||
for split_name, images in [("train", new_train), ("val", new_val)]:
|
||||
for img_path in images:
|
||||
label_path = _image_to_label_path(img_path)
|
||||
dst_img = output_dir / "images" / split_name / img_path.name
|
||||
dst_label = output_dir / "labels" / split_name / label_path.name
|
||||
if img_path.exists() and not dst_img.exists():
|
||||
shutil.copy2(img_path, dst_img)
|
||||
new_copied += 1
|
||||
if label_path.exists() and not dst_label.exists():
|
||||
shutil.copy2(label_path, dst_label)
|
||||
|
||||
# Generate data.yaml
|
||||
from shared.fields import FIELD_CLASSES
|
||||
|
||||
yaml_path = output_dir / "data.yaml"
|
||||
yaml_content = (
|
||||
f"path: {output_dir.absolute()}\n"
|
||||
f"train: images/train\n"
|
||||
f"val: images/val\n"
|
||||
f"\n"
|
||||
f"nc: {len(FIELD_CLASSES)}\n"
|
||||
f"names: {list(FIELD_CLASSES.values())}\n"
|
||||
)
|
||||
yaml_path.write_text(yaml_content)
|
||||
|
||||
total_images = old_copied + new_copied
|
||||
logger.info(
|
||||
"Mixed dataset built: %d old + %d new = %d total images",
|
||||
old_copied, new_copied, total_images,
|
||||
)
|
||||
|
||||
return {
|
||||
"data_yaml": str(yaml_path),
|
||||
"total_images": total_images,
|
||||
"old_images": old_copied,
|
||||
"new_images": new_copied,
|
||||
"mixing_ratio": multiplier,
|
||||
}
|
||||
|
||||
|
||||
def _collect_images(images_dir: Path) -> list[Path]:
|
||||
"""Collect all image files from a directory."""
|
||||
if not images_dir.exists():
|
||||
return []
|
||||
return sorted(images_dir.glob("*.png")) + sorted(images_dir.glob("*.jpg"))
|
||||
|
||||
|
||||
def _image_to_label_path(image_path: Path) -> Path:
|
||||
"""Convert image path to corresponding label path."""
|
||||
labels_dir = image_path.parent.parent.parent / "labels" / image_path.parent.name
|
||||
return labels_dir / image_path.with_suffix(".txt").name
|
||||
|
||||
|
||||
def _find_pool_images(
|
||||
base_dataset_path: Path,
|
||||
pool_doc_ids: set[str],
|
||||
) -> list[Path]:
|
||||
"""Find images in base dataset that belong to pool documents."""
|
||||
images: list[Path] = []
|
||||
for split in ("train", "val", "test"):
|
||||
split_dir = base_dataset_path / "images" / split
|
||||
if not split_dir.exists():
|
||||
continue
|
||||
for img_path in split_dir.iterdir():
|
||||
if not img_path.is_file():
|
||||
continue
|
||||
# Image filenames are like: {doc_id}_page{N}.png
|
||||
doc_id = img_path.stem.rsplit("_page", 1)[0]
|
||||
if doc_id in pool_doc_ids:
|
||||
images.append(img_path)
|
||||
return images
|
||||
198
packages/backend/backend/web/services/gating_validator.py
Normal file
198
packages/backend/backend/web/services/gating_validator.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Gating Validation Service
|
||||
|
||||
Validates fine-tuned models before deployment using quality gates:
|
||||
- Gate 1: Regression validation (mAP drop on original test set)
|
||||
- Gate 2: New sample validation (detection rate on pool documents)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
from backend.data.admin_models import GatingResult
|
||||
from backend.data.database import get_session_context
|
||||
from backend.data.repositories.model_version_repository import ModelVersionRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Gate 1 thresholds (mAP drop)
GATE1_PASS_THRESHOLD = 0.01  # < 1% drop = PASS
GATE1_REVIEW_THRESHOLD = 0.03  # 1-3% drop = REVIEW, > 3% = REJECT

# Gate 2 thresholds (detection rate)
GATE2_PASS_THRESHOLD = 0.80  # > 80% detection rate = PASS


def classify_gate1(mAP_drop: float) -> str:
    """Classify the Gate 1 (regression) result from the mAP drop.

    Args:
        mAP_drop: Absolute mAP drop (positive means degradation).

    Returns:
        "pass" (drop below GATE1_PASS_THRESHOLD), "review" (below
        GATE1_REVIEW_THRESHOLD), or "reject" otherwise.
    """
    # Check from worst to best so each threshold is tested once.
    if mAP_drop >= GATE1_REVIEW_THRESHOLD:
        return "reject"
    if mAP_drop >= GATE1_PASS_THRESHOLD:
        return "review"
    return "pass"


def classify_gate2(detection_rate: float) -> str:
    """Classify the Gate 2 (new-sample) result from the detection rate.

    Args:
        detection_rate: Fraction of new samples correctly detected (0-1).

    Returns:
        "pass" when the rate meets GATE2_PASS_THRESHOLD, else "review".
    """
    return "pass" if detection_rate >= GATE2_PASS_THRESHOLD else "review"
|
||||
|
||||
|
||||
def compute_overall_status(gate1_status: str, gate2_status: str) -> str:
    """Combine per-gate results into an overall gating status.

    Rules:
        - Any "reject" -> overall "reject"
        - Any "review" (and no reject) -> overall "review"
        - All "pass" -> overall "pass"

    Args:
        gate1_status: Result of Gate 1 ("pass", "review", or "reject").
        gate2_status: Result of Gate 2 ("pass", "review", or "reject").

    Returns:
        The most severe of the two statuses.
    """
    statuses = (gate1_status, gate2_status)
    # Severity order: reject > review > pass.
    for escalation in ("reject", "review"):
        if escalation in statuses:
            return escalation
    return "pass"
|
||||
|
||||
|
||||
def run_gating_validation(
    model_version_id: str | UUID,
    new_model_path: str,
    base_model_version_id: str | UUID | None,
    data_yaml: str,
    task_id: str | UUID | None = None,
) -> GatingResult:
    """Run deployment gating validation for a fine-tuned model.

    Runs Gate 1 (regression check against the base model's recorded mAP)
    and Gate 2 (new-sample check), persists a GatingResult row, and writes
    the overall status back onto the ModelVersion.

    Args:
        model_version_id: The fine-tuned model version to validate.
        new_model_path: Path to the new model weights.
        base_model_version_id: The base model version to compare against.
        data_yaml: Path to the dataset YAML (for validation).
        task_id: Optional training task ID.

    Returns:
        GatingResult with gate statuses.
    """
    model_versions = ModelVersionRepository()

    # Gate 1: Regression validation.
    # Compares the base model's stored mAP against the new model's mAP on
    # the validation split; any failure downgrades to "review", never crashes.
    gate1_status = "pending"
    gate1_original_mAP = None
    gate1_new_mAP = None
    gate1_mAP_drop = None

    try:
        if base_model_version_id:
            base_model = model_versions.get(str(base_model_version_id))
            if base_model and base_model.metrics_mAP is not None:
                gate1_original_mAP = base_model.metrics_mAP

                # Run validation with new model.
                # Imported lazily so the module can load without the training
                # stack installed.
                from shared.training import YOLOTrainer, TrainingConfig

                val_config = TrainingConfig(
                    model_path=new_model_path,
                    data_yaml=data_yaml,
                )
                trainer = YOLOTrainer(config=val_config)
                val_metrics = trainer.validate(split="val")
                # NOTE(review): assumes validate() returns a dict containing
                # "mAP50" — confirm against YOLOTrainer.validate.
                gate1_new_mAP = val_metrics.get("mAP50")

                if gate1_new_mAP is not None:
                    # Positive drop means the new model regressed.
                    gate1_mAP_drop = gate1_original_mAP - gate1_new_mAP
                    gate1_status = classify_gate1(gate1_mAP_drop)
                    logger.info(
                        "Gate 1: original_mAP=%.4f, new_mAP=%.4f, drop=%.4f -> %s",
                        gate1_original_mAP, gate1_new_mAP, gate1_mAP_drop, gate1_status,
                    )
                else:
                    gate1_status = "review"
                    logger.warning("Gate 1: Could not compute new mAP, marking as review")
            else:
                # No recorded baseline to regress against: nothing to compare.
                gate1_status = "pass"
                logger.info("Gate 1: No base model metrics available, skipping (pass)")
        else:
            gate1_status = "pass"
            logger.info("Gate 1: No base model specified, skipping (pass)")
    except Exception as e:
        # Deliberate best-effort: validation errors flag for human review
        # rather than aborting the whole gating run.
        gate1_status = "review"
        logger.error("Gate 1 failed: %s", e)

    # Gate 2: New sample validation
    # For now, we use the training metrics as a proxy
    # Full implementation would run inference on pool documents
    gate2_status = "pass"
    gate2_detection_rate = None
    gate2_total_samples = None
    gate2_detected_samples = None

    try:
        new_model = model_versions.get(str(model_version_id))
        if new_model and new_model.metrics_mAP is not None:
            # Use mAP as proxy for detection rate on new samples
            gate2_detection_rate = new_model.metrics_mAP
            if gate2_detection_rate is not None:
                gate2_status = classify_gate2(gate2_detection_rate)
                logger.info(
                    "Gate 2: detection_rate=%.4f -> %s",
                    gate2_detection_rate, gate2_status,
                )
    except Exception as e:
        gate2_status = "review"
        logger.error("Gate 2 failed: %s", e)

    # Compute overall status (worst of the two gates).
    overall_status = compute_overall_status(gate1_status, gate2_status)
    logger.info("Gating overall: %s (gate1=%s, gate2=%s)", overall_status, gate1_status, gate2_status)

    # Save result. The row is refreshed then expunged so the returned
    # object stays usable after the session closes.
    with get_session_context() as session:
        result = GatingResult(
            model_version_id=UUID(str(model_version_id)),
            task_id=UUID(str(task_id)) if task_id else None,
            gate1_status=gate1_status,
            gate1_original_mAP=gate1_original_mAP,
            gate1_new_mAP=gate1_new_mAP,
            gate1_mAP_drop=gate1_mAP_drop,
            gate2_status=gate2_status,
            gate2_detection_rate=gate2_detection_rate,
            gate2_total_samples=gate2_total_samples,
            gate2_detected_samples=gate2_detected_samples,
            overall_status=overall_status,
        )
        session.add(result)
        session.commit()
        session.refresh(result)
        session.expunge(result)

    # Update model version gating status
    _update_model_gating_status(str(model_version_id), overall_status)

    return result
|
||||
|
||||
|
||||
def _update_model_gating_status(version_id: str, status: str) -> None:
    """Write *status* into the gating_status field of a ModelVersion.

    Silently does nothing when the version ID does not exist.

    Args:
        version_id: ModelVersion primary key as a string UUID.
        status: Gating status to store (e.g. "pass", "review", "reject").
    """
    # Imported here to avoid a circular import at module load time.
    from backend.data.admin_models import ModelVersion

    with get_session_context() as session:
        model = session.get(ModelVersion, UUID(version_id))
        if not model:
            return
        model.gating_status = status
        session.add(model)
        session.commit()
|
||||
@@ -3,6 +3,6 @@ fastapi>=0.104.0
|
||||
uvicorn[standard]>=0.24.0
|
||||
python-multipart>=0.0.6
|
||||
sqlmodel>=0.0.22
|
||||
ultralytics>=8.1.0
|
||||
ultralytics>=8.4.0
|
||||
httpx>=0.25.0
|
||||
openai>=1.0.0
|
||||
|
||||
@@ -1,37 +1,20 @@
|
||||
"""
|
||||
BBox Scale Strategy Module.
|
||||
BBox Expansion Module.
|
||||
|
||||
Provides field-specific bounding box expansion strategies for YOLO training data.
|
||||
Expands bboxes using center-point scaling with directional compensation to capture
|
||||
field labels that typically appear above or to the left of field values.
|
||||
|
||||
Two modes are supported:
|
||||
- Auto-label: Field-specific scale strategies with directional compensation
|
||||
- Manual-label: Minimal padding only to prevent edge clipping
|
||||
Provides uniform bounding box expansion for YOLO training data.
|
||||
|
||||
Usage:
|
||||
from shared.bbox import expand_bbox, ScaleStrategy, FIELD_SCALE_STRATEGIES
|
||||
from shared.bbox import expand_bbox, UNIFORM_PAD
|
||||
|
||||
Available exports:
|
||||
- ScaleStrategy: Dataclass for scale strategy configuration
|
||||
- DEFAULT_STRATEGY: Default strategy for unknown fields (auto-label)
|
||||
- MANUAL_LABEL_STRATEGY: Minimal padding strategy for manual labels
|
||||
- FIELD_SCALE_STRATEGIES: dict[str, ScaleStrategy] - field-specific strategies
|
||||
- expand_bbox: Function to expand bbox using field-specific strategy
|
||||
- UNIFORM_PAD: Default uniform pixel padding (15px at 150 DPI)
|
||||
- expand_bbox: Function to expand bbox with uniform padding
|
||||
"""
|
||||
|
||||
from .scale_strategy import (
|
||||
ScaleStrategy,
|
||||
DEFAULT_STRATEGY,
|
||||
MANUAL_LABEL_STRATEGY,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
)
|
||||
from .scale_strategy import UNIFORM_PAD
|
||||
from .expander import expand_bbox
|
||||
|
||||
__all__ = [
|
||||
"ScaleStrategy",
|
||||
"DEFAULT_STRATEGY",
|
||||
"MANUAL_LABEL_STRATEGY",
|
||||
"FIELD_SCALE_STRATEGIES",
|
||||
"UNIFORM_PAD",
|
||||
"expand_bbox",
|
||||
]
|
||||
|
||||
@@ -1,101 +1,35 @@
|
||||
"""
|
||||
BBox Expander Module.
|
||||
|
||||
Provides functions to expand bounding boxes using field-specific strategies.
|
||||
Expansion is center-point based with directional compensation.
|
||||
|
||||
Two modes:
|
||||
- Auto-label (default): Field-specific scale strategies
|
||||
- Manual-label: Minimal padding only to prevent edge clipping
|
||||
Expands bounding boxes by a uniform pixel padding on all sides,
|
||||
clamped to image boundaries. No field-specific or directional logic.
|
||||
"""
|
||||
|
||||
from .scale_strategy import (
|
||||
ScaleStrategy,
|
||||
DEFAULT_STRATEGY,
|
||||
MANUAL_LABEL_STRATEGY,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
)
|
||||
from .scale_strategy import UNIFORM_PAD
|
||||
|
||||
|
||||
def expand_bbox(
|
||||
bbox: tuple[float, float, float, float],
|
||||
image_width: float,
|
||||
image_height: float,
|
||||
field_type: str,
|
||||
strategies: dict[str, ScaleStrategy] | None = None,
|
||||
manual_mode: bool = False,
|
||||
pad: int = UNIFORM_PAD,
|
||||
) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
Expand bbox using field-specific scale strategy.
|
||||
|
||||
The expansion follows these steps:
|
||||
1. Scale bbox around center point (scale_x, scale_y)
|
||||
2. Apply directional compensation (extra_*_ratio)
|
||||
3. Clamp expansion to max_pad limits
|
||||
4. Clamp to image boundaries
|
||||
"""Expand bbox by uniform pixel padding, clamped to image bounds.
|
||||
|
||||
Args:
|
||||
bbox: (x0, y0, x1, y1) in pixels
|
||||
image_width: Image width for boundary clamping
|
||||
image_height: Image height for boundary clamping
|
||||
field_type: Field class_name (e.g., "ocr_number")
|
||||
strategies: Custom strategies dict, defaults to FIELD_SCALE_STRATEGIES
|
||||
manual_mode: If True, use MANUAL_LABEL_STRATEGY (minimal padding only)
|
||||
bbox: (x0, y0, x1, y1) in pixels.
|
||||
image_width: Image width for boundary clamping.
|
||||
image_height: Image height for boundary clamping.
|
||||
pad: Uniform pixel padding on all sides (default: UNIFORM_PAD).
|
||||
|
||||
Returns:
|
||||
Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds
|
||||
Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds.
|
||||
"""
|
||||
x0, y0, x1, y1 = bbox
|
||||
w = x1 - x0
|
||||
h = y1 - y0
|
||||
|
||||
# Get strategy based on mode
|
||||
if manual_mode:
|
||||
strategy = MANUAL_LABEL_STRATEGY
|
||||
elif strategies is None:
|
||||
strategy = FIELD_SCALE_STRATEGIES.get(field_type, DEFAULT_STRATEGY)
|
||||
else:
|
||||
strategy = strategies.get(field_type, DEFAULT_STRATEGY)
|
||||
|
||||
# Step 1: Scale around center point
|
||||
cx = (x0 + x1) / 2
|
||||
cy = (y0 + y1) / 2
|
||||
|
||||
new_w = w * strategy.scale_x
|
||||
new_h = h * strategy.scale_y
|
||||
|
||||
nx0 = cx - new_w / 2
|
||||
nx1 = cx + new_w / 2
|
||||
ny0 = cy - new_h / 2
|
||||
ny1 = cy + new_h / 2
|
||||
|
||||
# Step 2: Apply directional compensation
|
||||
nx0 -= w * strategy.extra_left_ratio
|
||||
nx1 += w * strategy.extra_right_ratio
|
||||
ny0 -= h * strategy.extra_top_ratio
|
||||
ny1 += h * strategy.extra_bottom_ratio
|
||||
|
||||
# Step 3: Clamp expansion to max_pad limits (preserve asymmetry)
|
||||
left_pad = min(x0 - nx0, strategy.max_pad_x)
|
||||
right_pad = min(nx1 - x1, strategy.max_pad_x)
|
||||
top_pad = min(y0 - ny0, strategy.max_pad_y)
|
||||
bottom_pad = min(ny1 - y1, strategy.max_pad_y)
|
||||
|
||||
# Ensure pads are non-negative (in case of contraction)
|
||||
left_pad = max(0, left_pad)
|
||||
right_pad = max(0, right_pad)
|
||||
top_pad = max(0, top_pad)
|
||||
bottom_pad = max(0, bottom_pad)
|
||||
|
||||
nx0 = x0 - left_pad
|
||||
nx1 = x1 + right_pad
|
||||
ny0 = y0 - top_pad
|
||||
ny1 = y1 + bottom_pad
|
||||
|
||||
# Step 4: Clamp to image boundaries
|
||||
nx0 = max(0, int(nx0))
|
||||
ny0 = max(0, int(ny0))
|
||||
nx1 = min(int(image_width), int(nx1))
|
||||
ny1 = min(int(image_height), int(ny1))
|
||||
nx0 = max(0, int(x0 - pad))
|
||||
ny0 = max(0, int(y0 - pad))
|
||||
nx1 = min(int(image_width), int(x1 + pad))
|
||||
ny1 = min(int(image_height), int(y1 + pad))
|
||||
|
||||
return (nx0, ny0, nx1, ny1)
|
||||
|
||||
@@ -1,140 +1,12 @@
|
||||
"""
|
||||
Scale Strategy Configuration.
|
||||
|
||||
Defines field-specific bbox expansion strategies for YOLO training data.
|
||||
Each strategy controls how bboxes are expanded around field values to
|
||||
capture contextual information like labels.
|
||||
Defines uniform bbox expansion padding for YOLO training data.
|
||||
All fields use the same fixed-pixel padding -- no layout assumptions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScaleStrategy:
|
||||
"""Immutable scale strategy for bbox expansion.
|
||||
|
||||
Attributes:
|
||||
scale_x: Horizontal scale factor (1.0 = no scaling)
|
||||
scale_y: Vertical scale factor (1.0 = no scaling)
|
||||
extra_top_ratio: Additional expansion ratio towards top (for labels above)
|
||||
extra_bottom_ratio: Additional expansion ratio towards bottom
|
||||
extra_left_ratio: Additional expansion ratio towards left (for prefixes)
|
||||
extra_right_ratio: Additional expansion ratio towards right (for suffixes)
|
||||
max_pad_x: Maximum horizontal padding in pixels
|
||||
max_pad_y: Maximum vertical padding in pixels
|
||||
"""
|
||||
|
||||
scale_x: float = 1.15
|
||||
scale_y: float = 1.15
|
||||
extra_top_ratio: float = 0.0
|
||||
extra_bottom_ratio: float = 0.0
|
||||
extra_left_ratio: float = 0.0
|
||||
extra_right_ratio: float = 0.0
|
||||
max_pad_x: int = 50
|
||||
max_pad_y: int = 50
|
||||
|
||||
|
||||
# Default strategy for unknown fields (auto-label mode)
|
||||
DEFAULT_STRATEGY: Final[ScaleStrategy] = ScaleStrategy()
|
||||
|
||||
# Manual label strategy - minimal padding to prevent edge clipping
|
||||
# No scaling, no directional compensation, just small uniform padding
|
||||
MANUAL_LABEL_STRATEGY: Final[ScaleStrategy] = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_top_ratio=0.0,
|
||||
extra_bottom_ratio=0.0,
|
||||
extra_left_ratio=0.0,
|
||||
extra_right_ratio=0.0,
|
||||
max_pad_x=10, # Small padding to prevent edge loss
|
||||
max_pad_y=10,
|
||||
)
|
||||
|
||||
|
||||
# Field-specific strategies based on Swedish invoice field characteristics
|
||||
# Field labels typically appear above or to the left of values
|
||||
FIELD_SCALE_STRATEGIES: Final[dict[str, ScaleStrategy]] = {
|
||||
# OCR number - label "OCR" or "Referens" typically above
|
||||
"ocr_number": ScaleStrategy(
|
||||
scale_x=1.15,
|
||||
scale_y=1.80,
|
||||
extra_top_ratio=0.60,
|
||||
max_pad_x=50,
|
||||
max_pad_y=140,
|
||||
),
|
||||
# Bankgiro - prefix "Bankgiro:" or "BG:" typically to the left
|
||||
"bankgiro": ScaleStrategy(
|
||||
scale_x=1.45,
|
||||
scale_y=1.35,
|
||||
extra_left_ratio=0.80,
|
||||
max_pad_x=160,
|
||||
max_pad_y=90,
|
||||
),
|
||||
# Plusgiro - prefix "Plusgiro:" or "PG:" typically to the left
|
||||
"plusgiro": ScaleStrategy(
|
||||
scale_x=1.45,
|
||||
scale_y=1.35,
|
||||
extra_left_ratio=0.80,
|
||||
max_pad_x=160,
|
||||
max_pad_y=90,
|
||||
),
|
||||
# Invoice date - label "Fakturadatum" typically above
|
||||
"invoice_date": ScaleStrategy(
|
||||
scale_x=1.25,
|
||||
scale_y=1.55,
|
||||
extra_top_ratio=0.40,
|
||||
max_pad_x=80,
|
||||
max_pad_y=110,
|
||||
),
|
||||
# Due date - label "Forfalldatum" typically above, sometimes left
|
||||
"invoice_due_date": ScaleStrategy(
|
||||
scale_x=1.30,
|
||||
scale_y=1.65,
|
||||
extra_top_ratio=0.45,
|
||||
extra_left_ratio=0.35,
|
||||
max_pad_x=100,
|
||||
max_pad_y=120,
|
||||
),
|
||||
# Amount - currency symbol "SEK" or "kr" may be to the right
|
||||
"amount": ScaleStrategy(
|
||||
scale_x=1.20,
|
||||
scale_y=1.35,
|
||||
extra_right_ratio=0.30,
|
||||
max_pad_x=70,
|
||||
max_pad_y=80,
|
||||
),
|
||||
# Invoice number - label "Fakturanummer" typically above
|
||||
"invoice_number": ScaleStrategy(
|
||||
scale_x=1.20,
|
||||
scale_y=1.50,
|
||||
extra_top_ratio=0.40,
|
||||
max_pad_x=80,
|
||||
max_pad_y=100,
|
||||
),
|
||||
# Supplier org number - label "Org.nr" typically above or left
|
||||
"supplier_org_number": ScaleStrategy(
|
||||
scale_x=1.25,
|
||||
scale_y=1.40,
|
||||
extra_top_ratio=0.30,
|
||||
extra_left_ratio=0.20,
|
||||
max_pad_x=90,
|
||||
max_pad_y=90,
|
||||
),
|
||||
# Customer number - label "Kundnummer" typically above or left
|
||||
"customer_number": ScaleStrategy(
|
||||
scale_x=1.25,
|
||||
scale_y=1.45,
|
||||
extra_top_ratio=0.35,
|
||||
extra_left_ratio=0.25,
|
||||
max_pad_x=90,
|
||||
max_pad_y=100,
|
||||
),
|
||||
# Payment line - machine-readable code, minimal expansion needed
|
||||
"payment_line": ScaleStrategy(
|
||||
scale_x=1.10,
|
||||
scale_y=1.20,
|
||||
max_pad_x=40,
|
||||
max_pad_y=40,
|
||||
),
|
||||
}
|
||||
# 15px at 150 DPI = ~2.5mm real-world padding around text.
|
||||
# Enough for OCR safety margin without capturing neighboring label text.
|
||||
UNIFORM_PAD: Final[int] = 15
|
||||
|
||||
@@ -17,7 +17,7 @@ class TrainingConfig:
|
||||
"""Training configuration."""
|
||||
|
||||
# Model settings
|
||||
model_path: str = "yolo11n.pt" # Base model or path to trained model
|
||||
model_path: str = "yolo26s.pt" # Base model or path to trained model
|
||||
data_yaml: str = "" # Path to data.yaml
|
||||
|
||||
# Training hyperparameters
|
||||
@@ -39,6 +39,10 @@ class TrainingConfig:
|
||||
resume: bool = False
|
||||
resume_from: str | None = None # Path to checkpoint
|
||||
|
||||
# Fine-tuning specific
|
||||
freeze: int = 0 # Number of backbone layers to freeze (0 = none)
|
||||
cos_lr: bool = False # Use cosine learning rate scheduler
|
||||
|
||||
# Document-specific augmentation (optimized for invoices)
|
||||
augmentation: dict[str, Any] = field(default_factory=lambda: {
|
||||
"degrees": 5.0,
|
||||
@@ -106,7 +110,7 @@ class YOLOTrainer:
|
||||
# Check model path
|
||||
model_path = Path(self.config.model_path)
|
||||
if not model_path.suffix == ".pt":
|
||||
# Could be a model name like "yolo11n.pt" which is downloaded
|
||||
# Could be a model name like "yolo26s.pt" which is downloaded
|
||||
if not model_path.name.startswith("yolo"):
|
||||
return False, f"Invalid model: {self.config.model_path}"
|
||||
elif not model_path.exists():
|
||||
@@ -147,6 +151,10 @@ class YOLOTrainer:
|
||||
self._log("INFO", f" Epochs: {self.config.epochs}")
|
||||
self._log("INFO", f" Batch size: {self.config.batch_size}")
|
||||
self._log("INFO", f" Image size: {self.config.image_size}")
|
||||
if self.config.freeze > 0:
|
||||
self._log("INFO", f" Freeze layers: {self.config.freeze}")
|
||||
if self.config.cos_lr:
|
||||
self._log("INFO", f" Cosine LR: enabled")
|
||||
|
||||
try:
|
||||
# Load model
|
||||
@@ -178,6 +186,12 @@ class YOLOTrainer:
|
||||
"resume": self.config.resume and self.config.resume_from is not None,
|
||||
}
|
||||
|
||||
# Add fine-tuning settings
|
||||
if self.config.freeze > 0:
|
||||
train_args["freeze"] = self.config.freeze
|
||||
if self.config.cos_lr:
|
||||
train_args["cos_lr"] = True
|
||||
|
||||
# Add augmentation settings
|
||||
train_args.update(self.config.augmentation)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
-e ../shared
|
||||
ultralytics>=8.1.0
|
||||
ultralytics>=8.4.0
|
||||
tqdm>=4.65.0
|
||||
torch>=2.0.0
|
||||
|
||||
@@ -34,7 +34,7 @@ def execute_training_task(db: TrainingTaskDB, task: dict) -> None:
|
||||
result = run_training(
|
||||
epochs=config.get("epochs", 100),
|
||||
batch=config.get("batch_size", 16),
|
||||
model=config.get("base_model", "yolo11n.pt"),
|
||||
model=config.get("base_model", "yolo26s.pt"),
|
||||
imgsz=config.get("imgsz", 1280),
|
||||
name=config.get("name", f"training_{task_id[:8]}"),
|
||||
)
|
||||
|
||||
@@ -28,8 +28,8 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model', '-m',
|
||||
default='yolov8s.pt',
|
||||
help='Base model (default: yolov8s.pt)'
|
||||
default='yolo26s.pt',
|
||||
help='Base model (default: yolo26s.pt)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--epochs', '-e',
|
||||
|
||||
@@ -100,12 +100,11 @@ class AnnotationGenerator:
|
||||
x0, y0, x1, y1 = best_match.bbox
|
||||
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
||||
|
||||
# Apply field-specific bbox expansion strategy
|
||||
# Apply uniform bbox expansion
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Ensure minimum height
|
||||
@@ -173,12 +172,11 @@ class AnnotationGenerator:
|
||||
x0, y0, x1, y1 = payment_line_bbox
|
||||
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
||||
|
||||
# Apply field-specific bbox expansion strategy for payment_line
|
||||
# Apply uniform bbox expansion
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
field_type="payment_line",
|
||||
)
|
||||
|
||||
# Convert to YOLO format (normalized center + size)
|
||||
|
||||
@@ -585,15 +585,11 @@ class DBYOLODataset:
|
||||
x1_px = x1_pdf * scale
|
||||
y1_px = y1_pdf * scale
|
||||
|
||||
# Get class name for field-specific expansion
|
||||
class_name = CLASS_NAMES[ann.class_id]
|
||||
|
||||
# Apply field-specific bbox expansion
|
||||
# Apply uniform bbox expansion
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0_px, y0_px, x1_px, y1_px),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Ensure minimum height
|
||||
|
||||
Reference in New Issue
Block a user