This commit is contained in:
Yaojia Wang
2026-01-30 00:44:21 +01:00
parent d2489a97d4
commit 33ada0350d
79 changed files with 9737 additions and 297 deletions

View File

@@ -0,0 +1,24 @@
"""
Document Image Augmentation Module.
Provides augmentation transformations for training data enhancement,
specifically designed for document images (invoices, forms, etc.).
Key features:
- Document-safe augmentations that preserve text readability
- Support for both offline preprocessing and runtime augmentation
- Bbox-aware geometric transforms
- Configurable augmentation pipeline
"""
from shared.augmentation.base import AugmentationResult, BaseAugmentation
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.dataset_augmenter import DatasetAugmenter
__all__ = [
"AugmentationConfig",
"AugmentationParams",
"AugmentationResult",
"BaseAugmentation",
"DatasetAugmenter",
]

View File

@@ -0,0 +1,108 @@
"""
Base classes for augmentation transforms.
Provides abstract base class and result dataclass for all augmentation
implementations.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
import numpy as np
@dataclass
class AugmentationResult:
"""
Result of applying an augmentation.
Attributes:
image: The augmented image as numpy array (H, W, C).
bboxes: Updated bounding boxes if geometric transform was applied.
Format: (N, 5) array with [class_id, x_center, y_center, width, height].
transform_matrix: The transformation matrix if applicable (for bbox adjustment).
applied: Whether the augmentation was actually applied.
metadata: Additional metadata about the augmentation.
"""
image: np.ndarray
bboxes: np.ndarray | None = None
transform_matrix: np.ndarray | None = None
applied: bool = True
metadata: dict[str, Any] | None = None
class BaseAugmentation(ABC):
"""
Abstract base class for all augmentations.
Subclasses must implement:
- _validate_params(): Validate augmentation parameters
- apply(): Apply the augmentation to an image
Class attributes:
name: Human-readable name of the augmentation.
affects_geometry: True if this augmentation modifies bbox coordinates.
"""
name: str = "base"
affects_geometry: bool = False
def __init__(self, params: dict[str, Any]) -> None:
"""
Initialize augmentation with parameters.
Args:
params: Dictionary of augmentation-specific parameters.
"""
self.params = params
self._validate_params()
@abstractmethod
def _validate_params(self) -> None:
"""
Validate augmentation parameters.
Raises:
ValueError: If parameters are invalid.
"""
pass
@abstractmethod
def apply(
self,
image: np.ndarray,
bboxes: np.ndarray | None = None,
rng: np.random.Generator | None = None,
) -> AugmentationResult:
"""
Apply augmentation to image.
IMPORTANT: Implementations must NOT modify the input image or bboxes.
Always create copies before modifying.
Args:
image: Input image as numpy array (H, W, C) with dtype uint8.
bboxes: Optional bounding boxes in YOLO format (N, 5) array.
Each row: [class_id, x_center, y_center, width, height].
Coordinates are normalized to 0-1 range.
rng: Random number generator for reproducibility.
If None, a new generator should be created.
Returns:
AugmentationResult with augmented image and optionally updated bboxes.
"""
pass
def get_preview_params(self) -> dict[str, Any]:
"""
Get parameters optimized for preview display.
Override this method to provide parameters that produce
clearly visible effects for preview/demo purposes.
Returns:
Dictionary of preview parameters.
"""
return dict(self.params)

View File

@@ -0,0 +1,274 @@
"""
Augmentation configuration module.
Provides dataclasses for configuring document image augmentations.
All default values are document-safe (conservative) to preserve text readability.
"""
from dataclasses import dataclass, field
from typing import Any
@dataclass
class AugmentationParams:
"""
Parameters for a single augmentation type.
Attributes:
enabled: Whether this augmentation is enabled.
probability: Probability of applying this augmentation (0.0 to 1.0).
params: Type-specific parameters dictionary.
"""
enabled: bool = False
probability: float = 0.5
params: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
return {
"enabled": self.enabled,
"probability": self.probability,
"params": dict(self.params),
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "AugmentationParams":
"""Create from dictionary."""
return cls(
enabled=data.get("enabled", False),
probability=data.get("probability", 0.5),
params=dict(data.get("params", {})),
)
def _default_perspective_warp() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"max_warp": 0.02}, # Very conservative - 2% max distortion
)
def _default_wrinkle() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"intensity": 0.3, "num_wrinkles": (2, 5)},
)
def _default_edge_damage() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"max_damage_ratio": 0.05}, # Max 5% of edge damaged
)
def _default_stain() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={
"num_stains": (1, 3),
"max_radius_ratio": 0.1,
"opacity": (0.1, 0.3),
},
)
def _default_lighting_variation() -> AugmentationParams:
return AugmentationParams(
enabled=True, # Safe default, commonly needed
probability=0.5,
params={
"brightness_range": (-0.1, 0.1),
"contrast_range": (0.9, 1.1),
},
)
def _default_shadow() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"num_shadows": (1, 2), "opacity": (0.2, 0.4)},
)
def _default_gaussian_blur() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
)
def _default_motion_blur() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"kernel_size": (5, 9), "angle_range": (-45, 45)},
)
def _default_gaussian_noise() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"mean": 0, "std": (5, 15)}, # Conservative noise levels
)
def _default_salt_pepper() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"amount": (0.001, 0.005)}, # Very sparse
)
def _default_paper_texture() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.3,
params={"texture_type": "random", "intensity": (0.05, 0.15)},
)
def _default_scanner_artifacts() -> AugmentationParams:
return AugmentationParams(
enabled=False,
probability=0.2,
params={"line_probability": 0.3, "dust_probability": 0.4},
)
@dataclass
class AugmentationConfig:
"""
Complete augmentation configuration.
All augmentation types have document-safe defaults that preserve
text readability. Only lighting_variation is enabled by default.
Attributes:
perspective_warp: Geometric perspective transform (affects bboxes).
wrinkle: Paper wrinkle/crease simulation.
edge_damage: Damaged/torn edge effects.
stain: Coffee stain/smudge effects.
lighting_variation: Brightness and contrast variation.
shadow: Shadow overlay effects.
gaussian_blur: Gaussian blur for focus issues.
motion_blur: Motion blur simulation.
gaussian_noise: Gaussian noise for sensor noise.
salt_pepper: Salt and pepper noise.
paper_texture: Paper texture overlay.
scanner_artifacts: Scanner line and dust artifacts.
preserve_bboxes: Whether to adjust bboxes for geometric transforms.
seed: Random seed for reproducibility.
"""
# Geometric transforms (affects bboxes)
perspective_warp: AugmentationParams = field(
default_factory=_default_perspective_warp
)
# Degradation effects
wrinkle: AugmentationParams = field(default_factory=_default_wrinkle)
edge_damage: AugmentationParams = field(default_factory=_default_edge_damage)
stain: AugmentationParams = field(default_factory=_default_stain)
# Lighting effects
lighting_variation: AugmentationParams = field(
default_factory=_default_lighting_variation
)
shadow: AugmentationParams = field(default_factory=_default_shadow)
# Blur effects
gaussian_blur: AugmentationParams = field(default_factory=_default_gaussian_blur)
motion_blur: AugmentationParams = field(default_factory=_default_motion_blur)
# Noise effects
gaussian_noise: AugmentationParams = field(default_factory=_default_gaussian_noise)
salt_pepper: AugmentationParams = field(default_factory=_default_salt_pepper)
# Texture effects
paper_texture: AugmentationParams = field(default_factory=_default_paper_texture)
scanner_artifacts: AugmentationParams = field(
default_factory=_default_scanner_artifacts
)
# Global settings
preserve_bboxes: bool = True
seed: int | None = None
# List of all augmentation field names
_AUGMENTATION_FIELDS: tuple[str, ...] = (
"perspective_warp",
"wrinkle",
"edge_damage",
"stain",
"lighting_variation",
"shadow",
"gaussian_blur",
"motion_blur",
"gaussian_noise",
"salt_pepper",
"paper_texture",
"scanner_artifacts",
)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for serialization."""
result: dict[str, Any] = {
"preserve_bboxes": self.preserve_bboxes,
"seed": self.seed,
}
for field_name in self._AUGMENTATION_FIELDS:
params: AugmentationParams = getattr(self, field_name)
result[field_name] = params.to_dict()
return result
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "AugmentationConfig":
"""Create from dictionary."""
kwargs: dict[str, Any] = {
"preserve_bboxes": data.get("preserve_bboxes", True),
"seed": data.get("seed"),
}
for field_name in cls._AUGMENTATION_FIELDS:
if field_name in data:
field_data = data[field_name]
if isinstance(field_data, dict):
kwargs[field_name] = AugmentationParams.from_dict(field_data)
return cls(**kwargs)
def get_enabled_augmentations(self) -> list[str]:
"""Get list of enabled augmentation names."""
enabled = []
for field_name in self._AUGMENTATION_FIELDS:
params: AugmentationParams = getattr(self, field_name)
if params.enabled:
enabled.append(field_name)
return enabled
def validate(self) -> None:
"""
Validate configuration.
Raises:
ValueError: If any configuration value is invalid.
"""
for field_name in self._AUGMENTATION_FIELDS:
params: AugmentationParams = getattr(self, field_name)
if not (0.0 <= params.probability <= 1.0):
raise ValueError(
f"{field_name}.probability must be between 0 and 1, "
f"got {params.probability}"
)

View File

@@ -0,0 +1,206 @@
"""
Dataset Augmenter Module.
Applies augmentation pipeline to YOLO datasets,
creating new augmented images and label files.
"""
import logging
from pathlib import Path
from typing import Any
import numpy as np
from PIL import Image
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.pipeline import AugmentationPipeline
logger = logging.getLogger(__name__)
class DatasetAugmenter:
"""
Augments YOLO datasets by creating new images and label files.
Reads images from dataset/images/train/ and labels from dataset/labels/train/,
applies augmentation pipeline, and saves augmented versions with "_augN" suffix.
"""
def __init__(
self,
config: dict[str, Any],
seed: int | None = None,
) -> None:
"""
Initialize augmenter with configuration.
Args:
config: Dictionary mapping augmentation names to their settings.
Each augmentation should have 'enabled', 'probability', and 'params'.
seed: Random seed for reproducibility.
"""
self._config_dict = config
self._seed = seed
self._config = self._build_config(config, seed)
def _build_config(
self,
config_dict: dict[str, Any],
seed: int | None,
) -> AugmentationConfig:
"""Build AugmentationConfig from dictionary."""
kwargs: dict[str, Any] = {"seed": seed, "preserve_bboxes": True}
for aug_name, aug_settings in config_dict.items():
if aug_name in AugmentationConfig._AUGMENTATION_FIELDS:
kwargs[aug_name] = AugmentationParams(
enabled=aug_settings.get("enabled", False),
probability=aug_settings.get("probability", 0.5),
params=aug_settings.get("params", {}),
)
return AugmentationConfig(**kwargs)
def augment_dataset(
self,
dataset_path: Path,
multiplier: int = 1,
split: str = "train",
) -> dict[str, int]:
"""
Augment a YOLO dataset.
Args:
dataset_path: Path to dataset root (containing images/ and labels/).
multiplier: Number of augmented copies per original image.
split: Which split to augment (default: "train").
Returns:
Summary dict with original_images, augmented_images, total_images.
"""
images_dir = dataset_path / "images" / split
labels_dir = dataset_path / "labels" / split
if not images_dir.exists():
raise ValueError(f"Images directory not found: {images_dir}")
# Find all images
image_extensions = ("*.png", "*.jpg", "*.jpeg")
image_files: list[Path] = []
for ext in image_extensions:
image_files.extend(images_dir.glob(ext))
original_count = len(image_files)
augmented_count = 0
if multiplier <= 0:
return {
"original_images": original_count,
"augmented_images": 0,
"total_images": original_count,
}
# Process each image
for img_path in image_files:
# Load image
pil_image = Image.open(img_path).convert("RGB")
image = np.array(pil_image)
# Load corresponding label
label_path = labels_dir / f"{img_path.stem}.txt"
bboxes = self._load_bboxes(label_path) if label_path.exists() else None
# Create multiple augmented versions
for aug_idx in range(multiplier):
# Create pipeline with adjusted seed for each augmentation
aug_seed = None
if self._seed is not None:
aug_seed = self._seed + aug_idx + hash(img_path.stem) % 10000
pipeline = AugmentationPipeline(
self._build_config(self._config_dict, aug_seed)
)
# Apply augmentation
result = pipeline.apply(image, bboxes)
# Save augmented image
aug_name = f"{img_path.stem}_aug{aug_idx}{img_path.suffix}"
aug_img_path = images_dir / aug_name
aug_pil = Image.fromarray(result.image)
aug_pil.save(aug_img_path)
# Save augmented label
aug_label_path = labels_dir / f"{img_path.stem}_aug{aug_idx}.txt"
self._save_bboxes(aug_label_path, result.bboxes)
augmented_count += 1
logger.info(
"Dataset augmentation complete: %d original, %d augmented",
original_count,
augmented_count,
)
return {
"original_images": original_count,
"augmented_images": augmented_count,
"total_images": original_count + augmented_count,
}
def _load_bboxes(self, label_path: Path) -> np.ndarray | None:
"""
Load bounding boxes from YOLO label file.
Args:
label_path: Path to label file.
Returns:
Array of shape (N, 5) with class_id, x_center, y_center, width, height.
Returns None if file is empty or doesn't exist.
"""
if not label_path.exists():
return None
content = label_path.read_text().strip()
if not content:
return None
bboxes = []
for line in content.split("\n"):
parts = line.strip().split()
if len(parts) == 5:
class_id = int(parts[0])
x_center = float(parts[1])
y_center = float(parts[2])
width = float(parts[3])
height = float(parts[4])
bboxes.append([class_id, x_center, y_center, width, height])
if not bboxes:
return None
return np.array(bboxes, dtype=np.float32)
def _save_bboxes(self, label_path: Path, bboxes: np.ndarray | None) -> None:
"""
Save bounding boxes to YOLO label file.
Args:
label_path: Path to save label file.
bboxes: Array of shape (N, 5) or None for empty labels.
"""
if bboxes is None or len(bboxes) == 0:
label_path.write_text("")
return
lines = []
for bbox in bboxes:
class_id = int(bbox[0])
x_center = bbox[1]
y_center = bbox[2]
width = bbox[3]
height = bbox[4]
lines.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
label_path.write_text("\n".join(lines))

View File

@@ -0,0 +1,184 @@
"""
Augmentation pipeline module.
Orchestrates multiple augmentations with proper ordering and
provides preview functionality.
"""
from typing import Any
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
from shared.augmentation.config import AugmentationConfig, AugmentationParams
from shared.augmentation.transforms.blur import GaussianBlur, MotionBlur
from shared.augmentation.transforms.degradation import EdgeDamage, Stain, Wrinkle
from shared.augmentation.transforms.geometric import PerspectiveWarp
from shared.augmentation.transforms.lighting import LightingVariation, Shadow
from shared.augmentation.transforms.noise import GaussianNoise, SaltPepper
from shared.augmentation.transforms.texture import PaperTexture, ScannerArtifacts
# Registry of augmentation classes.
#
# Maps AugmentationConfig field names to the transform class implementing
# each augmentation. NOTE: dict insertion order is significant here —
# _build_augmentations() orders entries with Python's stable sort, so this
# order is the application order *within* each pipeline stage.
AUGMENTATION_REGISTRY: dict[str, type[BaseAugmentation]] = {
    "perspective_warp": PerspectiveWarp,
    "wrinkle": Wrinkle,
    "edge_damage": EdgeDamage,
    "stain": Stain,
    "lighting_variation": LightingVariation,
    "shadow": Shadow,
    "gaussian_blur": GaussianBlur,
    "motion_blur": MotionBlur,
    "gaussian_noise": GaussianNoise,
    "salt_pepper": SaltPepper,
    "paper_texture": PaperTexture,
    "scanner_artifacts": ScannerArtifacts,
}
class AugmentationPipeline:
    """
    Runs the configured augmentations in a fixed stage order.

    Stages, applied first to last:
      1. geometric (perspective_warp) — the only stage that moves bboxes
      2. degradation (wrinkle, edge_damage, stain)
      3. lighting (lighting_variation, shadow)
      4. texture (paper_texture, scanner_artifacts)
      5. blur (gaussian_blur, motion_blur)
      6. noise (gaussian_noise, salt_pepper)
    """

    STAGE_ORDER = [
        "geometric",
        "degradation",
        "lighting",
        "texture",
        "blur",
        "noise",
    ]

    STAGE_MAPPING = {
        "perspective_warp": "geometric",
        "wrinkle": "degradation",
        "edge_damage": "degradation",
        "stain": "degradation",
        "lighting_variation": "lighting",
        "shadow": "lighting",
        "paper_texture": "texture",
        "scanner_artifacts": "texture",
        "gaussian_blur": "blur",
        "motion_blur": "blur",
        "gaussian_noise": "noise",
        "salt_pepper": "noise",
    }

    def __init__(self, config: AugmentationConfig) -> None:
        """
        Build the pipeline from a configuration.

        Args:
            config: Augmentation configuration; its seed (possibly None)
                initializes the pipeline's random generator.
        """
        self.config = config
        self._rng = np.random.default_rng(config.seed)
        self._augmentations = self._build_augmentations()

    def _build_augmentations(
        self,
    ) -> list[tuple[str, BaseAugmentation, float]]:
        """Instantiate enabled augmentations, ordered by pipeline stage."""
        selected: list[tuple[str, BaseAugmentation, float]] = []
        for aug_name, aug_cls in AUGMENTATION_REGISTRY.items():
            settings: AugmentationParams = getattr(self.config, aug_name)
            if not settings.enabled:
                continue
            selected.append((aug_name, aug_cls(settings.params), settings.probability))
        # Stable sort: registry order is preserved within each stage.
        return sorted(
            selected,
            key=lambda entry: self.STAGE_ORDER.index(self.STAGE_MAPPING[entry[0]]),
        )

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
    ) -> AugmentationResult:
        """
        Run the pipeline on one image.

        Each enabled augmentation fires independently with its configured
        probability; the input image and bboxes are never mutated.

        Args:
            image: Input image (H, W, C) as numpy array with dtype uint8.
            bboxes: Optional bounding boxes in YOLO format (N, 5).

        Returns:
            AugmentationResult whose metadata lists the augmentations that
            actually ran under "applied_augmentations".
        """
        work_image = image.copy()
        work_bboxes = None if bboxes is None else bboxes.copy()
        fired: list[str] = []
        for aug_name, aug, probability in self._augmentations:
            if self._rng.random() >= probability:
                continue
            step = aug.apply(work_image, work_bboxes, self._rng)
            work_image = step.image
            if self.config.preserve_bboxes and step.bboxes is not None:
                work_bboxes = step.bboxes
            fired.append(aug_name)
        return AugmentationResult(
            image=work_image,
            bboxes=work_bboxes,
            metadata={"applied_augmentations": fired},
        )

    def preview(
        self,
        image: np.ndarray,
        augmentation_name: str,
    ) -> np.ndarray:
        """
        Render a single augmentation deterministically for display.

        Args:
            image: Input image.
            augmentation_name: Name of augmentation to preview.

        Returns:
            Augmented image.

        Raises:
            ValueError: If augmentation_name is not recognized.
        """
        if augmentation_name not in AUGMENTATION_REGISTRY:
            raise ValueError(f"Unknown augmentation: {augmentation_name}")
        settings: AugmentationParams = getattr(self.config, augmentation_name)
        # NOTE(review): this uses the configured params rather than
        # BaseAugmentation.get_preview_params() — confirm that is intended.
        transform = AUGMENTATION_REGISTRY[augmentation_name](settings.params)
        # Fixed seed so repeated previews look identical.
        preview_rng = np.random.default_rng(42)
        return transform.apply(image.copy(), rng=preview_rng).image
def get_available_augmentations() -> list[dict[str, Any]]:
    """
    Describe every registered augmentation.

    Returns:
        One dict per registry entry with its name, description (taken from
        the class docstring), affects_geometry flag, and pipeline stage.
    """
    return [
        {
            "name": reg_name,
            "description": reg_cls.__doc__ or "",
            "affects_geometry": reg_cls.affects_geometry,
            "stage": AugmentationPipeline.STAGE_MAPPING[reg_name],
        }
        for reg_name, reg_cls in AUGMENTATION_REGISTRY.items()
    ]

View File

@@ -0,0 +1,212 @@
"""
Predefined augmentation presets for common document scenarios.
Presets provide ready-to-use configurations optimized for different
use cases, from conservative (preserves text readability) to aggressive
(simulates poor document quality).
"""
from typing import Any
from shared.augmentation.config import AugmentationConfig, AugmentationParams
# Preset table. Each entry has a human-readable "description" and a "config"
# dict in the shape AugmentationConfig.from_dict() expects. Values use
# tuples for ranges; they survive in-memory use, but JSON serialization
# would turn them into lists — confirm callers tolerate that.
PRESETS: dict[str, dict[str, Any]] = {
    "conservative": {
        "description": "Safe augmentations that preserve text readability",
        "config": {
            "lighting_variation": {
                "enabled": True,
                "probability": 0.5,
                "params": {
                    "brightness_range": (-0.1, 0.1),
                    "contrast_range": (0.9, 1.1),
                },
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.3,
                "params": {"std": (3, 10)},
            },
        },
    },
    "moderate": {
        "description": "Balanced augmentations for typical document degradation",
        "config": {
            "lighting_variation": {
                "enabled": True,
                "probability": 0.5,
                "params": {
                    "brightness_range": (-0.15, 0.15),
                    "contrast_range": (0.85, 1.15),
                },
            },
            "shadow": {
                "enabled": True,
                "probability": 0.3,
                "params": {"num_shadows": (1, 2), "opacity": (0.2, 0.35)},
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.3,
                "params": {"std": (5, 12)},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 0.2,
                "params": {"kernel_size": (3, 5), "sigma": (0.5, 1.0)},
            },
            "paper_texture": {
                "enabled": True,
                "probability": 0.3,
                "params": {"intensity": (0.05, 0.12)},
            },
        },
    },
    "aggressive": {
        "description": "Heavy augmentations simulating poor scan quality",
        "config": {
            "perspective_warp": {
                "enabled": True,
                "probability": 0.3,
                "params": {"max_warp": 0.02},
            },
            "wrinkle": {
                "enabled": True,
                "probability": 0.4,
                "params": {"intensity": 0.3, "num_wrinkles": (2, 4)},
            },
            "stain": {
                "enabled": True,
                "probability": 0.3,
                "params": {
                    "num_stains": (1, 2),
                    "max_radius_ratio": 0.08,
                    "opacity": (0.1, 0.25),
                },
            },
            "lighting_variation": {
                "enabled": True,
                "probability": 0.6,
                "params": {
                    "brightness_range": (-0.2, 0.2),
                    "contrast_range": (0.8, 1.2),
                },
            },
            "shadow": {
                "enabled": True,
                "probability": 0.4,
                "params": {"num_shadows": (1, 2), "opacity": (0.25, 0.4)},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 0.3,
                "params": {"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
            },
            "motion_blur": {
                "enabled": True,
                "probability": 0.2,
                "params": {"kernel_size": (5, 7), "angle_range": (-30, 30)},
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.4,
                "params": {"std": (8, 18)},
            },
            "paper_texture": {
                "enabled": True,
                "probability": 0.4,
                "params": {"intensity": (0.08, 0.15)},
            },
            "scanner_artifacts": {
                "enabled": True,
                "probability": 0.3,
                "params": {"line_probability": 0.4, "dust_probability": 0.5},
            },
            "edge_damage": {
                "enabled": True,
                "probability": 0.2,
                "params": {"max_damage_ratio": 0.04},
            },
        },
    },
    "scanned_document": {
        "description": "Simulates typical scanned document artifacts",
        "config": {
            "scanner_artifacts": {
                "enabled": True,
                "probability": 0.5,
                "params": {"line_probability": 0.4, "dust_probability": 0.5},
            },
            "paper_texture": {
                "enabled": True,
                "probability": 0.4,
                "params": {"intensity": (0.05, 0.12)},
            },
            "lighting_variation": {
                "enabled": True,
                "probability": 0.3,
                "params": {
                    "brightness_range": (-0.1, 0.1),
                    "contrast_range": (0.9, 1.1),
                },
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.3,
                "params": {"std": (5, 12)},
            },
        },
    },
}


def get_preset_config(preset_name: str) -> dict[str, Any]:
    """
    Get the configuration dictionary for a preset.

    A deep copy is returned so callers may freely modify the result
    without corrupting the shared PRESETS table.

    Args:
        preset_name: Name of the preset.

    Returns:
        Configuration dictionary (independent copy).

    Raises:
        ValueError: If preset is not found.
    """
    from copy import deepcopy  # local import keeps this edit self-contained

    if preset_name not in PRESETS:
        raise ValueError(
            f"Unknown preset: {preset_name}. "
            f"Available presets: {list(PRESETS.keys())}"
        )
    return deepcopy(PRESETS[preset_name]["config"])


def create_config_from_preset(preset_name: str) -> "AugmentationConfig":
    """
    Create an AugmentationConfig from a preset.

    Args:
        preset_name: Name of the preset.

    Returns:
        AugmentationConfig instance.

    Raises:
        ValueError: If preset is not found.
    """
    config_dict = get_preset_config(preset_name)
    return AugmentationConfig.from_dict(config_dict)


def list_presets() -> list[dict[str, str]]:
    """
    List all available presets.

    Returns:
        One dict per preset with its name and description.
    """
    return [
        {"name": name, "description": preset["description"]}
        for name, preset in PRESETS.items()
    ]

View File

@@ -0,0 +1,13 @@
"""
Augmentation transform implementations.
Each module contains related augmentation classes:
- geometric.py: Perspective warp and other geometric transforms
- degradation.py: Wrinkle, edge damage, stain effects
- lighting.py: Lighting variation and shadow effects
- blur.py: Gaussian and motion blur
- noise.py: Gaussian and salt-pepper noise
- texture.py: Paper texture and scanner artifacts
"""
# Will be populated as transforms are implemented

View File

@@ -0,0 +1,144 @@
"""
Blur augmentation transforms.
Provides blur effects for document image augmentation:
- GaussianBlur: Simulates out-of-focus capture
- MotionBlur: Simulates camera/document movement during capture
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class GaussianBlur(BaseAugmentation):
    """
    Gaussian blur, emulating out-of-focus capture or low-quality optics.

    Defaults are mild so document text stays legible.

    Parameters:
        kernel_size: Blur kernel size, int or (min, max) tuple (default: (3, 5)).
        sigma: Blur sigma, float or (min, max) tuple (default: (0.5, 1.5)).
    """

    name = "gaussian_blur"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Reject invalid kernel sizes (non-odd ints, inverted tuples)."""
        size = self.params.get("kernel_size", (3, 5))
        if isinstance(size, int):
            if size < 1 or size % 2 == 0:
                raise ValueError("kernel_size must be a positive odd integer")
        elif isinstance(size, tuple):
            if size[0] < 1 or size[1] < size[0]:
                raise ValueError("kernel_size tuple must be (min, max) with min >= 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Blur a copy of *image* with a randomly drawn odd kernel and sigma."""
        if rng is None:
            rng = np.random.default_rng()
        size = self.params.get("kernel_size", (3, 5))
        sigma = self.params.get("sigma", (0.5, 1.5))
        if isinstance(size, tuple):
            # Sample one of the odd sizes in [min, max].
            lo, hi = size[0], size[1]
            odd_sizes = [k for k in range(lo, hi + 1) if k % 2 == 1]
            if not odd_sizes:
                # Range held no odd value: round min up to the next odd size.
                odd_sizes = [lo if lo % 2 == 1 else lo + 1]
            size = rng.choice(odd_sizes)
        if isinstance(sigma, tuple):
            sigma = rng.uniform(sigma[0], sigma[1])
        # cv2 requires an odd kernel size.
        if size % 2 == 0:
            size += 1
        blurred = cv2.GaussianBlur(image, (size, size), sigma)
        return AugmentationResult(
            image=blurred,
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={"kernel_size": size, "sigma": sigma},
        )

    def get_preview_params(self) -> dict:
        """Exaggerated settings for a clearly visible preview."""
        return {"kernel_size": 5, "sigma": 1.5}
class MotionBlur(BaseAugmentation):
    """
    Directional blur, emulating camera shake or document movement
    during capture.

    Parameters:
        kernel_size: Blur kernel size, int or (min, max) tuple (default: (5, 9)).
        angle_range: Motion angle range in degrees (default: (-45, 45)).
    """

    name = "motion_blur"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Reject kernels too small to produce a visible streak."""
        size = self.params.get("kernel_size", (5, 9))
        if isinstance(size, int):
            if size < 3:
                raise ValueError("kernel_size must be at least 3")
        elif isinstance(size, tuple):
            if size[0] < 3:
                raise ValueError("kernel_size min must be at least 3")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Convolve a copy of *image* with a 1-pixel line kernel at a random angle."""
        if rng is None:
            rng = np.random.default_rng()
        size = self.params.get("kernel_size", (5, 9))
        angle_range = self.params.get("angle_range", (-45, 45))
        if isinstance(size, tuple):
            size = rng.integers(size[0], size[1] + 1)
        angle = rng.uniform(angle_range[0], angle_range[1])
        # Build the kernel: a one-pixel-wide line through the center
        # at the chosen angle.
        kernel = np.zeros((size, size), dtype=np.float32)
        center = size // 2
        theta = np.deg2rad(angle)
        cos_t = np.cos(theta)
        sin_t = np.sin(theta)
        for step in range(size):
            offset = step - center
            col = int(center + offset * cos_t)
            row = int(center + offset * sin_t)
            if 0 <= col < size and 0 <= row < size:
                kernel[row, col] = 1.0
        # Normalize so overall brightness is preserved (the center pixel is
        # always set, so the sum is positive in practice).
        total = kernel.sum()
        if total > 0:
            kernel = kernel / total
        blurred = cv2.filter2D(image, -1, kernel)
        return AugmentationResult(
            image=blurred,
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={"kernel_size": size, "angle": angle},
        )

    def get_preview_params(self) -> dict:
        """Exaggerated settings for a clearly visible preview."""
        return {"kernel_size": 7, "angle_range": (-30, 30)}

View File

@@ -0,0 +1,259 @@
"""
Degradation augmentation transforms.
Provides degradation effects for document image augmentation:
- Wrinkle: Paper wrinkle/crease simulation
- EdgeDamage: Damaged/torn edge effects
- Stain: Coffee stain/smudge effects
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class Wrinkle(BaseAugmentation):
    """
    Simulates paper wrinkles/creases using displacement mapping.

    Each wrinkle is a random line segment; pixels near it are displaced
    perpendicular to the line with a Gaussian falloff, then lightly shaded
    so the crease is visible. Displacements are small to keep text readable.

    Parameters:
        intensity: Wrinkle intensity in (0, 1] (default: 0.3).
        num_wrinkles: Number of wrinkles, int or (min, max) tuple (default: (2, 5)).
    """

    name = "wrinkle"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Reject out-of-range intensity (must be in (0, 1])."""
        intensity = self.params.get("intensity", 0.3)
        if not (0 < intensity <= 1):
            raise ValueError("intensity must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """
        Return a wrinkled copy of *image*; bboxes pass through unchanged
        (the displacement is too small to move box coordinates meaningfully).
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        intensity = self.params.get("intensity", 0.3)
        num_wrinkles = self.params.get("num_wrinkles", (2, 5))
        if isinstance(num_wrinkles, tuple):
            num_wrinkles = rng.integers(num_wrinkles[0], num_wrinkles[1] + 1)
        # Coordinate grids are loop-invariant: build them once rather than
        # once per wrinkle (the original recomputed them inside the loop).
        xx, yy = np.meshgrid(np.arange(w), np.arange(h))
        displacement_x = np.zeros((h, w), dtype=np.float32)
        displacement_y = np.zeros((h, w), dtype=np.float32)
        for _ in range(num_wrinkles):
            # Random line segment: orientation, anchor, extent, thickness.
            angle = rng.uniform(0, np.pi)
            x0 = rng.uniform(0, w)
            y0 = rng.uniform(0, h)
            length = rng.uniform(0.3, 0.8) * min(h, w)
            width = rng.uniform(0.02, 0.05) * min(h, w)
            # Rotate coordinates into the wrinkle frame:
            # dx runs along the wrinkle, dy perpendicular to it.
            dx = (xx - x0) * np.cos(angle) + (yy - y0) * np.sin(angle)
            dy = -(xx - x0) * np.sin(angle) + (yy - y0) * np.cos(angle)
            # Gaussian falloff perpendicular to the wrinkle, limited to its length.
            mask = np.exp(-dy**2 / (2 * width**2))
            mask *= (np.abs(dx) < length / 2).astype(np.float32)
            # Push pixels perpendicular to the wrinkle line.
            disp_amount = intensity * rng.uniform(2, 8)
            displacement_x += mask * disp_amount * np.sin(angle)
            displacement_y += mask * disp_amount * np.cos(angle)
        # Resample the image through the accumulated displacement field.
        map_x = (np.arange(w)[np.newaxis, :] + displacement_x).astype(np.float32)
        map_y = (np.arange(h)[:, np.newaxis] + displacement_y).astype(np.float32)
        augmented = cv2.remap(
            image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT
        )
        # Darken displaced areas slightly so creases read as shading.
        max_disp = np.max(np.abs(displacement_y)) + 1e-6  # epsilon avoids /0
        shading = 1 - 0.1 * intensity * np.abs(displacement_y) / max_disp
        shading = shading[:, :, np.newaxis]
        augmented = (augmented.astype(np.float32) * shading).astype(np.uint8)
        return AugmentationResult(
            image=augmented,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_wrinkles": num_wrinkles, "intensity": intensity},
        )

    def get_preview_params(self) -> dict:
        """Exaggerated settings for a clearly visible preview."""
        return {"intensity": 0.5, "num_wrinkles": 3}
class EdgeDamage(BaseAugmentation):
    """
    Adds damaged/torn edge effects to the image.

    Simulates worn or torn document edges by overwriting a random-depth
    strip of pixels along one randomly chosen edge.

    Parameters:
        max_damage_ratio: Maximum proportion of edge to damage (default: 0.05).
        edges: Which edges to potentially damage (default: all).
    """

    name = "edge_damage"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Ensure max_damage_ratio is in (0, 0.2]."""
        max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
        if not (0 < max_damage_ratio <= 0.2):
            raise ValueError("max_damage_ratio must be between 0 and 0.2")

    @staticmethod
    def _random_damage_color(rng: np.random.Generator) -> int:
        """Pick a light (paper-like) or darker (tear/shadow) gray value."""
        # Condition is evaluated first, then exactly one integers() call —
        # same RNG draw order as the original inline expression.
        return int(rng.integers(200, 255) if rng.random() > 0.5 else rng.integers(100, 150))

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Damage one randomly selected edge of the image.

        Args:
            image: Input image (H, W[, C]), assumed uint8.
            bboxes: Optional boxes; returned unchanged (copy) since this
                transform does not affect geometry.
            rng: Optional random generator for reproducibility.

        Returns:
            AugmentationResult with the damaged image and metadata recording
            which edge was hit and the maximum strip depth in pixels.
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
        edges = self.params.get("edges", ["top", "bottom", "left", "right"])
        output = image.copy()
        # Select random edge to damage
        edge = rng.choice(edges)
        damage_size = int(max_damage_ratio * min(h, w))
        # Horizontal edges iterate over columns, vertical edges over rows;
        # each position gets an independent random strip depth.
        span = w if edge in ("top", "bottom") else h
        for pos in range(span):
            depth = rng.integers(0, damage_size + 1)
            if depth == 0:
                continue
            color = self._random_damage_color(rng)
            if edge == "top":
                output[:depth, pos] = color
            elif edge == "bottom":
                output[h - depth:, pos] = color
            elif edge == "left":
                output[pos, :depth] = color
            else:  # right
                output[pos, w - depth:] = color
        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"edge": edge, "damage_size": damage_size},
        )

    def get_preview_params(self) -> dict:
        """Stronger damage ratio so the effect is visible in previews."""
        return {"max_damage_ratio": 0.08}
class Stain(BaseAugmentation):
    """
    Adds coffee stain/smudge effects to the image.

    Simulates accidental stains on documents.

    Parameters:
        num_stains: Number of stains, int or (min, max) tuple (default: (1, 3)).
        max_radius_ratio: Maximum stain radius as ratio of image size (default: 0.1).
        opacity: Stain opacity, float or (min, max) tuple (default: (0.1, 0.3)).
    """

    name = "stain"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Validate opacity (scalar or tuple) and max_radius_ratio."""
        opacity = self.params.get("opacity", (0.1, 0.3))
        if isinstance(opacity, (int, float)):
            if not (0 < opacity <= 1):
                raise ValueError("opacity must be between 0 and 1")
        elif isinstance(opacity, tuple):
            if len(opacity) != 2 or not (0 < opacity[0] <= opacity[1] <= 1):
                raise ValueError("opacity tuple must be (min, max) in range (0, 1]")
        # Ratios >= 0.5 make stain centers impossible to place:
        # apply() draws rng.integers(max_radius, w - max_radius), which
        # requires low < high.
        max_radius_ratio = self.params.get("max_radius_ratio", 0.1)
        if not (0 < max_radius_ratio < 0.5):
            raise ValueError("max_radius_ratio must be between 0 and 0.5")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Blend one or more soft-edged, irregular stains into the image.

        Args:
            image: Input image (H, W, C), assumed uint8.
            bboxes: Optional boxes; returned unchanged (copy).
            rng: Optional random generator for reproducibility.

        Returns:
            AugmentationResult with the stained image.
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        num_stains = self.params.get("num_stains", (1, 3))
        max_radius_ratio = self.params.get("max_radius_ratio", 0.1)
        opacity = self.params.get("opacity", (0.1, 0.3))
        if isinstance(num_stains, tuple):
            num_stains = rng.integers(num_stains[0], num_stains[1] + 1)
        if isinstance(opacity, tuple):
            opacity = rng.uniform(opacity[0], opacity[1])
        output = image.astype(np.float32)
        max_radius = int(max_radius_ratio * min(h, w))
        for _ in range(num_stains):
            # Random stain position and size (kept away from the border so
            # the full radius fits).
            cx = rng.integers(max_radius, w - max_radius)
            cy = rng.integers(max_radius, h - max_radius)
            radius = rng.integers(max_radius // 3, max_radius)
            # Distance field from the stain center.
            yy, xx = np.ogrid[:h, :w]
            dist = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)
            # Per-pixel radius jitter makes the stain outline irregular.
            noise = rng.uniform(0.8, 1.2, (h, w))
            mask = (dist < radius * noise).astype(np.float32)
            # Blur for soft edges
            mask = cv2.GaussianBlur(mask, (21, 21), 0)
            # Brownish/yellowish stain color; channel order assumed RGB —
            # TODO confirm against the image loader (OpenCV reads BGR).
            stain_color = np.array([
                rng.integers(180, 220),  # R
                rng.integers(160, 200),  # G
                rng.integers(120, 160),  # B
            ], dtype=np.float32)
            # Alpha-blend the stain color under the soft mask.
            mask_3d = mask[:, :, np.newaxis]
            output = output * (1 - mask_3d * opacity) + stain_color * mask_3d * opacity
        output = np.clip(output, 0, 255).astype(np.uint8)
        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_stains": num_stains, "opacity": opacity},
        )

    def get_preview_params(self) -> dict:
        """Visible-but-moderate stain settings for previews."""
        return {"num_stains": 2, "max_radius_ratio": 0.1, "opacity": 0.25}

View File

@@ -0,0 +1,145 @@
"""
Geometric augmentation transforms.
Provides geometric transforms for document image augmentation:
- PerspectiveWarp: Subtle perspective distortion
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class PerspectiveWarp(BaseAugmentation):
    """
    Applies subtle perspective transformation to the image.

    Simulates viewing document at slight angle. Very conservative
    by default to preserve text readability.

    IMPORTANT: This transform affects bounding box coordinates.

    Parameters:
        max_warp: Maximum warp as proportion of image size (default: 0.02).
    """

    name = "perspective_warp"
    affects_geometry = True

    def _validate_params(self) -> None:
        """Keep max_warp small enough that text stays readable."""
        max_warp = self.params.get("max_warp", 0.02)
        if not (0 < max_warp <= 0.1):
            raise ValueError("max_warp must be between 0 and 0.1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Warp the image by independently jittering its four corners.

        Args:
            image: Input image (H, W[, C]).
            bboxes: Optional YOLO-format boxes (N, 5):
                [class_id, x_center, y_center, width, height], normalized.
            rng: Optional random generator for reproducibility.

        Returns:
            AugmentationResult with the warped image, transformed boxes and
            the 3x3 perspective matrix used.
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        max_warp = self.params.get("max_warp", 0.02)
        # Original corners (TL, TR, BR, BL)
        src_pts = np.float32([
            [0, 0],
            [w, 0],
            [w, h],
            [0, h],
        ])
        # Jitter each corner by up to max_offset pixels in x and y.
        max_offset = max_warp * min(h, w)
        dst_pts = src_pts.copy()
        for i in range(4):
            dst_pts[i, 0] += rng.uniform(-max_offset, max_offset)
            dst_pts[i, 1] += rng.uniform(-max_offset, max_offset)
        # Compute perspective transform matrix
        transform_matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)
        # BORDER_REPLICATE avoids black wedges appearing at the borders.
        warped = cv2.warpPerspective(
            image, transform_matrix, (w, h),
            borderMode=cv2.BORDER_REPLICATE
        )
        # Transform bounding boxes if present
        transformed_bboxes = None
        if bboxes is not None:
            transformed_bboxes = self._transform_bboxes(
                bboxes, transform_matrix, w, h
            )
        return AugmentationResult(
            image=warped,
            bboxes=transformed_bboxes,
            transform_matrix=transform_matrix,
            metadata={"max_warp": max_warp},
        )

    def _transform_bboxes(
        self,
        bboxes: np.ndarray,
        transform_matrix: np.ndarray,
        w: int,
        h: int,
    ) -> np.ndarray:
        """Transform bounding boxes using the perspective matrix.

        Each box's four corners are warped and re-enclosed in an
        axis-aligned box. The warped corners are clipped to the image
        bounds BEFORE converting back to normalized center format, so the
        resulting box always lies fully inside [0, 1]. (Clamping center
        and size independently, as before, could leave a box partially
        outside the image.)
        """
        if len(bboxes) == 0:
            return bboxes.copy()
        transformed = []
        for bbox in bboxes:
            class_id, x_center, y_center, width, height = bbox
            # Convert normalized coords to pixel coords
            x_center_px = x_center * w
            y_center_px = y_center * h
            width_px = width * w
            height_px = height * h
            # Corner points of the axis-aligned input box
            x1 = x_center_px - width_px / 2
            y1 = y_center_px - height_px / 2
            x2 = x_center_px + width_px / 2
            y2 = y_center_px + height_px / 2
            # Transform all 4 corners
            corners = np.float32([
                [x1, y1],
                [x2, y1],
                [x2, y2],
                [x1, y2],
            ]).reshape(-1, 1, 2)
            warped_corners = cv2.perspectiveTransform(corners, transform_matrix)
            warped_corners = warped_corners.reshape(-1, 2)
            # Axis-aligned envelope of the warped corners, clipped to the
            # image so the box cannot extend outside the frame.
            new_x1 = np.clip(np.min(warped_corners[:, 0]), 0, w)
            new_y1 = np.clip(np.min(warped_corners[:, 1]), 0, h)
            new_x2 = np.clip(np.max(warped_corners[:, 0]), 0, w)
            new_y2 = np.clip(np.max(warped_corners[:, 1]), 0, h)
            # Convert back to normalized center format
            new_width = (new_x2 - new_x1) / w
            new_height = (new_y2 - new_y1) / h
            new_x_center = ((new_x1 + new_x2) / 2) / w
            new_y_center = ((new_y1 + new_y2) / 2) / h
            transformed.append([class_id, new_x_center, new_y_center, new_width, new_height])
        return np.array(transformed, dtype=np.float32)

    def get_preview_params(self) -> dict:
        """Slightly stronger warp so the preview effect is visible."""
        return {"max_warp": 0.03}

View File

@@ -0,0 +1,167 @@
"""
Lighting augmentation transforms.
Provides lighting effects for document image augmentation:
- LightingVariation: Adjusts brightness and contrast
- Shadow: Adds shadow overlay effects
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class LightingVariation(BaseAugmentation):
    """
    Adjusts image brightness and contrast.

    Simulates different lighting conditions during document capture.
    Safe for documents with conservative default parameters.

    Parameters:
        brightness_range: (min, max) brightness adjustment (default: (-0.1, 0.1)).
        contrast_range: (min, max) contrast multiplier (default: (0.9, 1.1)).
    """

    name = "lighting_variation"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Validate both ranges are ordered (min, max) tuples."""
        brightness = self.params.get("brightness_range", (-0.1, 0.1))
        contrast = self.params.get("contrast_range", (0.9, 1.1))
        if not isinstance(brightness, tuple) or len(brightness) != 2:
            raise ValueError("brightness_range must be a (min, max) tuple")
        if brightness[0] > brightness[1]:
            raise ValueError("brightness_range must satisfy min <= max")
        if not isinstance(contrast, tuple) or len(contrast) != 2:
            raise ValueError("contrast_range must be a (min, max) tuple")
        if contrast[0] > contrast[1]:
            raise ValueError("contrast_range must satisfy min <= max")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Apply random brightness/contrast within the configured ranges.

        Contrast scales pixel values around the image mean (preserving
        average luminance); brightness then adds a uniform offset expressed
        as a fraction of the full 0-255 range. Result is clipped to uint8.
        """
        rng = rng or np.random.default_rng()
        brightness_range = self.params.get("brightness_range", (-0.1, 0.1))
        contrast_range = self.params.get("contrast_range", (0.9, 1.1))
        # Sample one brightness offset and one contrast factor per call.
        brightness = rng.uniform(brightness_range[0], brightness_range[1])
        contrast = rng.uniform(contrast_range[0], contrast_range[1])
        adjusted = image.astype(np.float32)
        # Contrast adjustment (multiply around mean)
        mean = adjusted.mean()
        adjusted = (adjusted - mean) * contrast + mean
        # Brightness adjustment (add offset)
        adjusted = adjusted + brightness * 255
        # Clip and convert back
        adjusted = np.clip(adjusted, 0, 255).astype(np.uint8)
        return AugmentationResult(
            image=adjusted,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"brightness": brightness, "contrast": contrast},
        )

    def get_preview_params(self) -> dict:
        """Wider ranges so the preview effect is clearly visible."""
        return {"brightness_range": (-0.15, 0.15), "contrast_range": (0.85, 1.15)}
class Shadow(BaseAugmentation):
    """
    Adds shadow overlay effects to the image.

    Simulates shadows cast by objects or hands during document capture.

    Parameters:
        num_shadows: Number of shadow regions, int or (min, max) tuple (default: (1, 2)).
        opacity: Shadow darkness, float or (min, max) tuple (default: (0.2, 0.4)).
    """

    name = "shadow"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Check that opacity (scalar or (min, max) tuple) lies in [0, 1]."""
        opacity = self.params.get("opacity", (0.2, 0.4))
        if isinstance(opacity, (int, float)):
            if not (0 <= opacity <= 1):
                raise ValueError("opacity must be between 0 and 1")
        elif isinstance(opacity, tuple):
            if not (0 <= opacity[0] <= opacity[1] <= 1):
                raise ValueError("opacity tuple must be in range [0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Darken one or more soft-edged polygonal regions of the image."""
        rng = rng or np.random.default_rng()
        shadow_count = self.params.get("num_shadows", (1, 2))
        shade = self.params.get("opacity", (0.2, 0.4))
        if isinstance(shadow_count, tuple):
            shadow_count = rng.integers(shadow_count[0], shadow_count[1] + 1)
        if isinstance(shade, tuple):
            shade = rng.uniform(shade[0], shade[1])
        h, w = image.shape[:2]
        result = image.astype(np.float32)
        for _ in range(shadow_count):
            vertex_count = rng.integers(3, 6)
            # Anchor the polygon on a random image border so the shadow
            # looks like it falls in from outside the frame.
            border = rng.integers(0, 4)
            if border == 0:  # top
                anchor = (rng.integers(0, w), 0)
            elif border == 1:  # right
                anchor = (w, rng.integers(0, h))
            elif border == 2:  # bottom
                anchor = (rng.integers(0, w), h)
            else:  # left
                anchor = (0, rng.integers(0, h))
            polygon = [anchor]
            # Remaining vertices are uniform over the image (x drawn before
            # y, matching the original sampling order).
            for _ in range(vertex_count - 1):
                polygon.append((rng.integers(0, w), rng.integers(0, h)))
            # Rasterize the polygon into a binary stencil.
            stencil = np.zeros((h, w), dtype=np.float32)
            pts = np.array(polygon, dtype=np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(stencil, [pts], 1.0)
            # Feather the stencil with an odd-sized Gaussian kernel.
            kernel = max(31, min(h, w) // 10)
            kernel += 1 - kernel % 2
            stencil = cv2.GaussianBlur(stencil, (kernel, kernel), 0)
            # Darken under the feathered stencil.
            result *= 1 - shade * stencil[:, :, np.newaxis]
        result = np.clip(result, 0, 255).astype(np.uint8)
        return AugmentationResult(
            image=result,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_shadows": shadow_count, "opacity": shade},
        )

    def get_preview_params(self) -> dict:
        """A single clearly visible shadow for previews."""
        return dict(num_shadows=1, opacity=0.3)

View File

@@ -0,0 +1,142 @@
"""
Noise augmentation transforms.
Provides noise effects for document image augmentation:
- GaussianNoise: Adds Gaussian noise to simulate sensor noise
- SaltPepper: Adds salt and pepper noise for impulse noise effects
"""
from typing import Any
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class GaussianNoise(BaseAugmentation):
    """
    Adds Gaussian noise to the image.

    Simulates sensor noise from cameras or scanners. Document-safe with
    conservative default parameters.

    Parameters:
        mean: Mean of the Gaussian noise (default: 0).
        std: Standard deviation, can be int or (min, max) tuple (default: (5, 15)).
    """

    name = "gaussian_noise"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Reject negative or mis-ordered std settings."""
        std = self.params.get("std", (5, 15))
        if isinstance(std, (int, float)):
            if std < 0:
                raise ValueError("std must be non-negative")
        elif isinstance(std, tuple):
            if len(std) != 2 or std[0] < 0 or std[1] < std[0]:
                raise ValueError("std tuple must be (min, max) with min <= max >= 0")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Add Gaussian noise to every channel and clip back to uint8."""
        rng = rng or np.random.default_rng()
        mu = self.params.get("mean", 0)
        sigma = self.params.get("std", (5, 15))
        if isinstance(sigma, tuple):
            sigma = rng.uniform(*sigma)
        # Work in float to avoid uint8 wrap-around, then clip back.
        noise = rng.normal(mu, sigma, image.shape).astype(np.float32)
        perturbed = np.clip(image.astype(np.float32) + noise, 0, 255)
        return AugmentationResult(
            image=perturbed.astype(np.uint8),
            bboxes=None if bboxes is None else bboxes.copy(),
            metadata={"applied_std": sigma},
        )

    def get_preview_params(self) -> dict[str, Any]:
        """Strong noise settings for a visible preview."""
        return {"mean": 0, "std": 15}
class SaltPepper(BaseAugmentation):
    """
    Adds salt and pepper (impulse) noise to the image.

    Simulates defects from damaged sensors or transmission errors.
    Very sparse by default to preserve document readability.

    Parameters:
        amount: Proportion of pixels to affect, can be float or (min, max) tuple.
            Default: (0.001, 0.005) for very sparse noise.
        salt_vs_pepper: Ratio of salt to pepper (default: 0.5 for equal amounts).
    """

    name = "salt_pepper"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Validate amount (scalar or tuple) and salt_vs_pepper in [0, 1]."""
        amount = self.params.get("amount", (0.001, 0.005))
        if isinstance(amount, (int, float)):
            if not (0 <= amount <= 1):
                raise ValueError("amount must be between 0 and 1")
        elif isinstance(amount, tuple):
            if len(amount) != 2 or not (0 <= amount[0] <= amount[1] <= 1):
                raise ValueError("amount tuple must be (min, max) in range [0, 1]")
        # Out-of-range ratios would silently skew the salt/pepper split
        # (e.g. a negative pepper count) — reject them up front.
        salt_vs_pepper = self.params.get("salt_vs_pepper", 0.5)
        if not (0 <= salt_vs_pepper <= 1):
            raise ValueError("salt_vs_pepper must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Set a random sparse subset of pixels to pure white or black.

        Args:
            image: Input image (H, W[, C]), assumed uint8.
            bboxes: Optional boxes; returned unchanged (copy).
            rng: Optional random generator for reproducibility.

        Returns:
            AugmentationResult with the noisy image; metadata records the
            sampled amount actually applied.
        """
        rng = rng or np.random.default_rng()
        amount = self.params.get("amount", (0.001, 0.005))
        salt_vs_pepper = self.params.get("salt_vs_pepper", 0.5)
        if isinstance(amount, tuple):
            amount = rng.uniform(amount[0], amount[1])
        # Copy image
        output = image.copy()
        h, w = image.shape[:2]
        total_pixels = h * w
        # Split the affected pixel budget between salt and pepper.
        num_salt = int(total_pixels * amount * salt_vs_pepper)
        num_pepper = int(total_pixels * amount * (1 - salt_vs_pepper))
        # Add salt (white pixels); coordinates may repeat, which only
        # means slightly fewer distinct pixels are hit.
        if num_salt > 0:
            salt_coords = (
                rng.integers(0, h, num_salt),
                rng.integers(0, w, num_salt),
            )
            output[salt_coords] = 255
        # Add pepper (black pixels)
        if num_pepper > 0:
            pepper_coords = (
                rng.integers(0, h, num_pepper),
                rng.integers(0, w, num_pepper),
            )
            output[pepper_coords] = 0
        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"applied_amount": amount},
        )

    def get_preview_params(self) -> dict[str, Any]:
        """Denser noise so the preview effect is visible."""
        return {"amount": 0.01, "salt_vs_pepper": 0.5}

View File

@@ -0,0 +1,159 @@
"""
Texture augmentation transforms.
Provides texture effects for document image augmentation:
- PaperTexture: Adds paper grain/texture
- ScannerArtifacts: Adds scanner line and dust artifacts
"""
import cv2
import numpy as np
from shared.augmentation.base import AugmentationResult, BaseAugmentation
class PaperTexture(BaseAugmentation):
    """
    Adds paper texture/grain to the image.

    Simulates different paper types and ages.

    Parameters:
        texture_type: Type of texture ("random", "fine", "coarse") (default: "random").
        intensity: Texture intensity, float or (min, max) tuple (default: (0.05, 0.15)).
    """

    name = "paper_texture"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Validate intensity as a scalar or ordered (min, max) tuple."""
        intensity = self.params.get("intensity", (0.05, 0.15))
        if isinstance(intensity, (int, float)):
            if not (0 < intensity <= 1):
                raise ValueError("intensity must be between 0 and 1")
        elif isinstance(intensity, tuple):
            if len(intensity) != 2 or not (0 < intensity[0] <= intensity[1] <= 1):
                raise ValueError("intensity tuple must be (min, max) in range (0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Add fine or coarse grain noise scaled by intensity.

        Args:
            image: Input image (H, W[, C]), assumed uint8.
            bboxes: Optional boxes; returned unchanged (copy).
            rng: Optional random generator for reproducibility.

        Returns:
            AugmentationResult with the textured image.
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        texture_type = self.params.get("texture_type", "random")
        intensity = self.params.get("intensity", (0.05, 0.15))
        if texture_type == "random":
            texture_type = rng.choice(["fine", "coarse"])
        if isinstance(intensity, tuple):
            intensity = rng.uniform(intensity[0], intensity[1])
        # Generate base noise
        if texture_type == "fine":
            # Fine grain: per-pixel noise, lightly blurred.
            noise = rng.uniform(-1, 1, (h, w)).astype(np.float32)
            noise = cv2.GaussianBlur(noise, (3, 3), 0)
        else:
            # Coarse texture: generate at lower resolution and upscale.
            # Guard with max(1, ...) so images smaller than 4px don't
            # produce a zero-sized noise array (cv2.resize would fail).
            small_h, small_w = max(1, h // 4), max(1, w // 4)
            noise = rng.uniform(-1, 1, (small_h, small_w)).astype(np.float32)
            noise = cv2.resize(noise, (w, h), interpolation=cv2.INTER_LINEAR)
            noise = cv2.GaussianBlur(noise, (5, 5), 0)
        # Apply texture: noise in [-1, 1] scaled to a fraction of the
        # 0-255 range, broadcast over channels.
        output = image.astype(np.float32)
        noise_3d = noise[:, :, np.newaxis] * intensity * 255
        output = output + noise_3d
        output = np.clip(output, 0, 255).astype(np.uint8)
        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"texture_type": texture_type, "intensity": intensity},
        )

    def get_preview_params(self) -> dict:
        """Coarse, strong texture for a clearly visible preview."""
        return {"texture_type": "coarse", "intensity": 0.15}
class ScannerArtifacts(BaseAugmentation):
    """
    Adds scanner artifacts to the image.

    Simulates scanner imperfections like lines and dust spots.

    Parameters:
        line_probability: Probability of adding scan lines (default: 0.3).
        dust_probability: Probability of adding dust spots (default: 0.4).
    """

    name = "scanner_artifacts"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Validate both probabilities lie in [0, 1]."""
        line_prob = self.params.get("line_probability", 0.3)
        dust_prob = self.params.get("dust_probability", 0.4)
        if not (0 <= line_prob <= 1):
            raise ValueError("line_probability must be between 0 and 1")
        if not (0 <= dust_prob <= 1):
            raise ValueError("dust_probability must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Overlay horizontal scan lines and/or dark dust spots.

        Args:
            image: Input image (H, W[, C]), assumed uint8.
            bboxes: Optional boxes; returned unchanged (copy).
            rng: Optional random generator for reproducibility.

        Returns:
            AugmentationResult with the artifact-laden image.
        """
        rng = rng or np.random.default_rng()
        h, w = image.shape[:2]
        line_probability = self.params.get("line_probability", 0.3)
        dust_probability = self.params.get("dust_probability", 0.4)
        output = image.copy()
        # Add scan lines (full-width, alpha-blended)
        if rng.random() < line_probability:
            num_lines = rng.integers(1, 4)
            for _ in range(num_lines):
                y = rng.integers(0, h)
                thickness = rng.integers(1, 3)
                # Light or dark line
                color = rng.integers(200, 240) if rng.random() > 0.5 else rng.integers(50, 100)
                # Make line partially transparent
                alpha = rng.uniform(0.3, 0.6)
                for dy in range(thickness):
                    if y + dy < h:
                        output[y + dy, :] = (
                            output[y + dy, :].astype(np.float32) * (1 - alpha) +
                            color * alpha
                        ).astype(np.uint8)
        # Add dust spots
        if rng.random() < dust_probability:
            num_dust = rng.integers(5, 20)
            for _ in range(num_dust):
                x = int(rng.integers(0, w))
                y = int(rng.integers(0, h))
                radius = int(rng.integers(1, 3))
                # Dark dust spot. cv2.circle expands a scalar color to
                # (v, 0, 0), which would tint multi-channel images instead
                # of drawing gray — spell out every channel explicitly.
                gray = int(rng.integers(50, 120))
                spot = (gray,) * image.shape[2] if image.ndim == 3 else gray
                cv2.circle(output, (x, y), radius, spot, -1)
        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={
                "line_probability": line_probability,
                "dust_probability": dust_probability,
            },
        )

    def get_preview_params(self) -> dict:
        """High probabilities so both artifact types show in previews."""
        return {"line_probability": 0.8, "dust_probability": 0.8}

View File

@@ -0,0 +1,5 @@
"""Shared training utilities."""
from .yolo_trainer import YOLOTrainer, TrainingConfig, TrainingResult
__all__ = ["YOLOTrainer", "TrainingConfig", "TrainingResult"]

View File

@@ -0,0 +1,239 @@
"""
Shared YOLO Training Module
Unified training logic for both CLI and Web API.
"""
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
logger = logging.getLogger(__name__)
@dataclass
class TrainingConfig:
    """Training configuration.

    Bundles every knob for a YOLO fine-tuning run; consumed by YOLOTrainer.
    Augmentation defaults are conservative and chosen for document images:
    no flips or mosaic/mixup, only tiny rotation/translation/scale jitter.
    """

    # Model settings
    model_path: str = "yolo11n.pt"  # Base model name (auto-downloaded by ultralytics) or path to trained weights
    data_yaml: str = ""  # Path to the dataset's data.yaml (required)
    # Training hyperparameters
    epochs: int = 100
    batch_size: int = 16
    image_size: int = 640  # Square training resolution (passed as imgsz)
    learning_rate: float = 0.01  # Initial learning rate (passed as lr0)
    device: str = "0"  # Device string forwarded to ultralytics — presumably a GPU index or "cpu"; confirm against caller
    # Output settings
    project: str = "runs/train"  # Parent directory for run outputs
    name: str = "invoice_fields"  # Run (sub-directory) name
    # Performance settings
    workers: int = 4  # Dataloader worker processes
    cache: bool = False  # Cache images for faster epochs (more RAM)
    # Resume settings
    resume: bool = False  # Whether to resume a previous run
    resume_from: str | None = None  # Path to checkpoint
    # Document-specific augmentation (optimized for invoices); merged into
    # the train() arguments and overrides ultralytics defaults.
    augmentation: dict[str, Any] = field(default_factory=lambda: {
        "degrees": 5.0,
        "translate": 0.05,
        "scale": 0.2,
        "shear": 0.0,
        "perspective": 0.0,
        "flipud": 0.0,
        "fliplr": 0.0,
        "mosaic": 0.0,
        "mixup": 0.0,
        "hsv_h": 0.0,
        "hsv_s": 0.1,
        "hsv_v": 0.2,
    })
@dataclass
class TrainingResult:
    """Training result.

    Returned by YOLOTrainer.train(); ``success`` is False when training
    could not run or failed, with ``error`` holding the reason.
    """

    success: bool
    model_path: str | None = None  # Path to best.pt if it was produced
    metrics: dict[str, float] = field(default_factory=dict)  # mAP50, mAP50-95, precision, recall
    error: str | None = None  # Human-readable failure reason (None on success)
    save_dir: str | None = None  # Ultralytics run directory for this training
class YOLOTrainer:
"""Unified YOLO trainer for CLI and Web API."""
def __init__(
self,
config: TrainingConfig,
log_callback: Callable[[str, str], None] | None = None,
):
"""
Initialize trainer.
Args:
config: Training configuration
log_callback: Optional callback for logging (level, message)
"""
self.config = config
self._log_callback = log_callback
def _log(self, level: str, message: str) -> None:
"""Log a message."""
if self._log_callback:
self._log_callback(level, message)
if level == "INFO":
logger.info(message)
elif level == "ERROR":
logger.error(message)
elif level == "WARNING":
logger.warning(message)
def validate_config(self) -> tuple[bool, str | None]:
"""
Validate training configuration.
Returns:
Tuple of (is_valid, error_message)
"""
# Check model path
model_path = Path(self.config.model_path)
if not model_path.suffix == ".pt":
# Could be a model name like "yolo11n.pt" which is downloaded
if not model_path.name.startswith("yolo"):
return False, f"Invalid model: {self.config.model_path}"
elif not model_path.exists():
return False, f"Model file not found: {self.config.model_path}"
# Check data.yaml
if not self.config.data_yaml:
return False, "data_yaml is required"
data_yaml = Path(self.config.data_yaml)
if not data_yaml.exists():
return False, f"data.yaml not found: {self.config.data_yaml}"
return True, None
def train(self) -> TrainingResult:
"""
Run YOLO training.
Returns:
TrainingResult with model path and metrics
"""
try:
from ultralytics import YOLO
except ImportError:
return TrainingResult(
success=False,
error="Ultralytics (YOLO) not installed. Install with: pip install ultralytics",
)
# Validate config
is_valid, error = self.validate_config()
if not is_valid:
return TrainingResult(success=False, error=error)
self._log("INFO", f"Starting YOLO training")
self._log("INFO", f" Model: {self.config.model_path}")
self._log("INFO", f" Data: {self.config.data_yaml}")
self._log("INFO", f" Epochs: {self.config.epochs}")
self._log("INFO", f" Batch size: {self.config.batch_size}")
self._log("INFO", f" Image size: {self.config.image_size}")
try:
# Load model
if self.config.resume and self.config.resume_from:
resume_path = Path(self.config.resume_from)
if resume_path.exists():
self._log("INFO", f"Resuming from: {resume_path}")
model = YOLO(str(resume_path))
else:
model = YOLO(self.config.model_path)
else:
model = YOLO(self.config.model_path)
# Build training arguments
train_args = {
"data": str(Path(self.config.data_yaml).absolute()),
"epochs": self.config.epochs,
"batch": self.config.batch_size,
"imgsz": self.config.image_size,
"lr0": self.config.learning_rate,
"device": self.config.device,
"project": self.config.project,
"name": self.config.name,
"exist_ok": True,
"pretrained": True,
"verbose": True,
"workers": self.config.workers,
"cache": self.config.cache,
"resume": self.config.resume and self.config.resume_from is not None,
}
# Add augmentation settings
train_args.update(self.config.augmentation)
# Train
results = model.train(**train_args)
# Get best model path
best_model = Path(results.save_dir) / "weights" / "best.pt"
# Extract metrics
metrics = {}
if hasattr(results, "results_dict"):
metrics = {
"mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
"mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
"precision": results.results_dict.get("metrics/precision(B)", 0),
"recall": results.results_dict.get("metrics/recall(B)", 0),
}
self._log("INFO", f"Training completed successfully")
self._log("INFO", f" Best model: {best_model}")
self._log("INFO", f" mAP@0.5: {metrics.get('mAP50', 'N/A')}")
return TrainingResult(
success=True,
model_path=str(best_model) if best_model.exists() else None,
metrics=metrics,
save_dir=str(results.save_dir),
)
except Exception as e:
self._log("ERROR", f"Training failed: {e}")
return TrainingResult(success=False, error=str(e))
    def validate(self, split: str = "val") -> dict[str, float]:
        """
        Run validation on trained model.

        NOTE(review): this loads the model from ``config.model_path`` — not
        the best.pt produced by train() — so point model_path at the trained
        weights before calling; confirm against callers.

        Args:
            split: Dataset split to validate on ("val" or "test")

        Returns:
            Validation metrics, or an empty dict if validation failed
            (failures are logged, including a missing ultralytics install).
        """
        try:
            from ultralytics import YOLO
            model = YOLO(self.config.model_path)
            metrics = model.val(data=self.config.data_yaml, split=split)
            # Box-detection metrics as reported by ultralytics.
            return {
                "mAP50": metrics.box.map50,
                "mAP50-95": metrics.box.map,
                "precision": metrics.box.mp,
                "recall": metrics.box.mr,
            }
        except Exception as e:
            self._log("ERROR", f"Validation failed: {e}")
            return {}