WIP
This commit is contained in:
24
packages/shared/shared/augmentation/__init__.py
Normal file
24
packages/shared/shared/augmentation/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Document Image Augmentation Module.
|
||||
|
||||
Provides augmentation transformations for training data enhancement,
|
||||
specifically designed for document images (invoices, forms, etc.).
|
||||
|
||||
Key features:
|
||||
- Document-safe augmentations that preserve text readability
|
||||
- Support for both offline preprocessing and runtime augmentation
|
||||
- Bbox-aware geometric transforms
|
||||
- Configurable augmentation pipeline
|
||||
"""
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.dataset_augmenter import DatasetAugmenter
|
||||
|
||||
# Public API of the augmentation package, listed alphabetically.
__all__ = [
    "AugmentationConfig",
    "AugmentationParams",
    "AugmentationResult",
    "BaseAugmentation",
    "DatasetAugmenter",
]
|
||||
108
packages/shared/shared/augmentation/base.py
Normal file
108
packages/shared/shared/augmentation/base.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Base classes for augmentation transforms.
|
||||
|
||||
Provides abstract base class and result dataclass for all augmentation
|
||||
implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class AugmentationResult:
|
||||
"""
|
||||
Result of applying an augmentation.
|
||||
|
||||
Attributes:
|
||||
image: The augmented image as numpy array (H, W, C).
|
||||
bboxes: Updated bounding boxes if geometric transform was applied.
|
||||
Format: (N, 5) array with [class_id, x_center, y_center, width, height].
|
||||
transform_matrix: The transformation matrix if applicable (for bbox adjustment).
|
||||
applied: Whether the augmentation was actually applied.
|
||||
metadata: Additional metadata about the augmentation.
|
||||
"""
|
||||
|
||||
image: np.ndarray
|
||||
bboxes: np.ndarray | None = None
|
||||
transform_matrix: np.ndarray | None = None
|
||||
applied: bool = True
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class BaseAugmentation(ABC):
    """
    Common interface for every augmentation transform.

    Concrete subclasses must provide:
      - ``_validate_params``: parameter sanity checking, and
      - ``apply``: the actual image transformation.

    Class attributes:
        name: Human-readable identifier for the augmentation.
        affects_geometry: True when the transform moves pixels and
            therefore requires bbox coordinate updates.
    """

    name: str = "base"
    affects_geometry: bool = False

    def __init__(self, params: dict[str, Any]) -> None:
        """
        Store and immediately validate the augmentation parameters.

        Args:
            params: Dictionary of augmentation-specific parameters.
        """
        self.params = params
        self._validate_params()

    @abstractmethod
    def _validate_params(self) -> None:
        """
        Check ``self.params`` for consistency.

        Raises:
            ValueError: If any parameter is invalid.
        """

    @abstractmethod
    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """
        Produce an augmented version of ``image``.

        IMPORTANT: implementations must never mutate ``image`` or
        ``bboxes`` in place — always work on copies.

        Args:
            image: Input image as a (H, W, C) uint8 numpy array.
            bboxes: Optional YOLO-format (N, 5) boxes, one row per box:
                [class_id, x_center, y_center, width, height], with
                coordinates normalized to the 0-1 range.
            rng: Random generator for reproducibility; implementations
                should create a fresh one when None is given.

        Returns:
            AugmentationResult holding the augmented image and, where
            relevant, the updated boxes.
        """

    def get_preview_params(self) -> dict[str, Any]:
        """
        Parameters tuned for a visible preview rendering.

        The default simply copies ``self.params``; subclasses may
        override to exaggerate their effect for preview/demo purposes.

        Returns:
            A fresh dictionary of preview parameters.
        """
        return dict(self.params)
|
||||
274
packages/shared/shared/augmentation/config.py
Normal file
274
packages/shared/shared/augmentation/config.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""
|
||||
Augmentation configuration module.
|
||||
|
||||
Provides dataclasses for configuring document image augmentations.
|
||||
All default values are document-safe (conservative) to preserve text readability.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass
class AugmentationParams:
    """
    Settings for one augmentation type.

    Attributes:
        enabled: Whether the augmentation participates in the pipeline.
        probability: Chance of application per image, in [0, 1].
        params: Augmentation-specific parameter dictionary.
    """

    enabled: bool = False
    probability: float = 0.5
    params: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dictionary (params are shallow-copied)."""
        serialized: dict[str, Any] = {
            "enabled": self.enabled,
            "probability": self.probability,
            "params": dict(self.params),
        }
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AugmentationParams":
        """Build an instance from a dictionary, filling in defaults."""
        enabled = data.get("enabled", False)
        probability = data.get("probability", 0.5)
        params = dict(data.get("params", {}))
        return cls(enabled=enabled, probability=probability, params=params)
|
||||
|
||||
|
||||
def _make_params(
    enabled: bool, probability: float, **params: Any
) -> AugmentationParams:
    # Shared builder so every default factory below stays a one-liner.
    return AugmentationParams(
        enabled=enabled, probability=probability, params=params
    )


def _default_perspective_warp() -> AugmentationParams:
    # Very conservative — at most 2% distortion.
    return _make_params(False, 0.3, max_warp=0.02)


def _default_wrinkle() -> AugmentationParams:
    return _make_params(False, 0.3, intensity=0.3, num_wrinkles=(2, 5))


def _default_edge_damage() -> AugmentationParams:
    # At most 5% of an edge may be damaged.
    return _make_params(False, 0.2, max_damage_ratio=0.05)


def _default_stain() -> AugmentationParams:
    return _make_params(
        False,
        0.2,
        num_stains=(1, 3),
        max_radius_ratio=0.1,
        opacity=(0.1, 0.3),
    )


def _default_lighting_variation() -> AugmentationParams:
    # The only augmentation enabled by default: safe and commonly needed.
    return _make_params(
        True,
        0.5,
        brightness_range=(-0.1, 0.1),
        contrast_range=(0.9, 1.1),
    )


def _default_shadow() -> AugmentationParams:
    return _make_params(False, 0.3, num_shadows=(1, 2), opacity=(0.2, 0.4))


def _default_gaussian_blur() -> AugmentationParams:
    return _make_params(False, 0.2, kernel_size=(3, 5), sigma=(0.5, 1.5))


def _default_motion_blur() -> AugmentationParams:
    return _make_params(False, 0.2, kernel_size=(5, 9), angle_range=(-45, 45))


def _default_gaussian_noise() -> AugmentationParams:
    # Conservative noise levels.
    return _make_params(False, 0.3, mean=0, std=(5, 15))


def _default_salt_pepper() -> AugmentationParams:
    # Very sparse noise.
    return _make_params(False, 0.2, amount=(0.001, 0.005))


def _default_paper_texture() -> AugmentationParams:
    return _make_params(
        False, 0.3, texture_type="random", intensity=(0.05, 0.15)
    )


def _default_scanner_artifacts() -> AugmentationParams:
    return _make_params(False, 0.2, line_probability=0.3, dust_probability=0.4)
|
||||
|
||||
|
||||
@dataclass
class AugmentationConfig:
    """
    Complete augmentation configuration.

    All augmentation types have document-safe defaults that preserve
    text readability. Only lighting_variation is enabled by default.

    Attributes:
        perspective_warp: Geometric perspective transform (affects bboxes).
        wrinkle: Paper wrinkle/crease simulation.
        edge_damage: Damaged/torn edge effects.
        stain: Coffee stain/smudge effects.
        lighting_variation: Brightness and contrast variation.
        shadow: Shadow overlay effects.
        gaussian_blur: Gaussian blur for focus issues.
        motion_blur: Motion blur simulation.
        gaussian_noise: Gaussian noise for sensor noise.
        salt_pepper: Salt and pepper noise.
        paper_texture: Paper texture overlay.
        scanner_artifacts: Scanner line and dust artifacts.
        preserve_bboxes: Whether to adjust bboxes for geometric transforms.
        seed: Random seed for reproducibility.
    """

    # Geometric transforms (affect bboxes)
    perspective_warp: AugmentationParams = field(
        default_factory=_default_perspective_warp
    )

    # Degradation effects
    wrinkle: AugmentationParams = field(default_factory=_default_wrinkle)
    edge_damage: AugmentationParams = field(default_factory=_default_edge_damage)
    stain: AugmentationParams = field(default_factory=_default_stain)

    # Lighting effects
    lighting_variation: AugmentationParams = field(
        default_factory=_default_lighting_variation
    )
    shadow: AugmentationParams = field(default_factory=_default_shadow)

    # Blur effects
    gaussian_blur: AugmentationParams = field(default_factory=_default_gaussian_blur)
    motion_blur: AugmentationParams = field(default_factory=_default_motion_blur)

    # Noise effects
    gaussian_noise: AugmentationParams = field(default_factory=_default_gaussian_noise)
    salt_pepper: AugmentationParams = field(default_factory=_default_salt_pepper)

    # Texture effects
    paper_texture: AugmentationParams = field(default_factory=_default_paper_texture)
    scanner_artifacts: AugmentationParams = field(
        default_factory=_default_scanner_artifacts
    )

    # Global settings
    preserve_bboxes: bool = True
    seed: int | None = None

    # Names of every augmentation field above, in declaration order.
    # BUGFIX: deliberately left WITHOUT a type annotation. Annotating a
    # class-level name in a dataclass turns it into an instance field and an
    # __init__ parameter; unannotated, it stays a shared class constant
    # (tuple[str, ...]) as intended.
    _AUGMENTATION_FIELDS = (
        "perspective_warp",
        "wrinkle",
        "edge_damage",
        "stain",
        "lighting_variation",
        "shadow",
        "gaussian_blur",
        "motion_blur",
        "gaussian_noise",
        "salt_pepper",
        "paper_texture",
        "scanner_artifacts",
    )

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization."""
        result: dict[str, Any] = {
            "preserve_bboxes": self.preserve_bboxes,
            "seed": self.seed,
        }
        for field_name in self._AUGMENTATION_FIELDS:
            params: AugmentationParams = getattr(self, field_name)
            result[field_name] = params.to_dict()
        return result

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AugmentationConfig":
        """
        Create from dictionary.

        Augmentation entries that are missing or not dictionaries keep
        their document-safe defaults; unknown keys are ignored.
        """
        kwargs: dict[str, Any] = {
            "preserve_bboxes": data.get("preserve_bboxes", True),
            "seed": data.get("seed"),
        }
        for field_name in cls._AUGMENTATION_FIELDS:
            field_data = data.get(field_name)
            if isinstance(field_data, dict):
                kwargs[field_name] = AugmentationParams.from_dict(field_data)
        return cls(**kwargs)

    def get_enabled_augmentations(self) -> list[str]:
        """Get the names of all enabled augmentations, in declaration order."""
        return [
            field_name
            for field_name in self._AUGMENTATION_FIELDS
            if getattr(self, field_name).enabled
        ]

    def validate(self) -> None:
        """
        Validate configuration.

        Raises:
            ValueError: If any augmentation probability lies outside [0, 1].
        """
        for field_name in self._AUGMENTATION_FIELDS:
            params: AugmentationParams = getattr(self, field_name)
            if not (0.0 <= params.probability <= 1.0):
                raise ValueError(
                    f"{field_name}.probability must be between 0 and 1, "
                    f"got {params.probability}"
                )
|
||||
206
packages/shared/shared/augmentation/dataset_augmenter.py
Normal file
206
packages/shared/shared/augmentation/dataset_augmenter.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""
|
||||
Dataset Augmenter Module.
|
||||
|
||||
Applies augmentation pipeline to YOLO datasets,
|
||||
creating new augmented images and label files.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.pipeline import AugmentationPipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetAugmenter:
    """
    Augments YOLO datasets by creating new images and label files.

    Reads images from dataset/images/<split>/ and labels from
    dataset/labels/<split>/, applies the augmentation pipeline, and saves
    augmented versions next to the originals with an "_augN" suffix.
    """

    def __init__(
        self,
        config: dict[str, Any],
        seed: int | None = None,
    ) -> None:
        """
        Initialize augmenter with configuration.

        Args:
            config: Dictionary mapping augmentation names to their settings.
                Each augmentation should have 'enabled', 'probability', and 'params'.
            seed: Random seed for reproducibility.
        """
        self._config_dict = config
        self._seed = seed
        self._config = self._build_config(config, seed)

    def _build_config(
        self,
        config_dict: dict[str, Any],
        seed: int | None,
    ) -> AugmentationConfig:
        """Build AugmentationConfig from dictionary, ignoring unknown names."""
        kwargs: dict[str, Any] = {"seed": seed, "preserve_bboxes": True}
        for aug_name, aug_settings in config_dict.items():
            if aug_name in AugmentationConfig._AUGMENTATION_FIELDS:
                kwargs[aug_name] = AugmentationParams(
                    enabled=aug_settings.get("enabled", False),
                    probability=aug_settings.get("probability", 0.5),
                    params=aug_settings.get("params", {}),
                )
        return AugmentationConfig(**kwargs)

    @staticmethod
    def _is_augmented_stem(stem: str) -> bool:
        """Return True for file stems this augmenter produced ("<name>_augN")."""
        _, sep, tail = stem.rpartition("_aug")
        return bool(sep) and tail.isdigit()

    def augment_dataset(
        self,
        dataset_path: Path,
        multiplier: int = 1,
        split: str = "train",
    ) -> dict[str, int]:
        """
        Augment a YOLO dataset.

        Previously generated augmentation outputs (stems ending "_augN")
        are skipped, so repeated runs do not compound augmentations.

        Args:
            dataset_path: Path to dataset root (containing images/ and labels/).
            multiplier: Number of augmented copies per original image.
            split: Which split to augment (default: "train").

        Returns:
            Summary dict with original_images, augmented_images, total_images.

        Raises:
            ValueError: If the images directory does not exist.
        """
        import zlib  # local import: stable checksum for per-image seeding

        images_dir = dataset_path / "images" / split
        labels_dir = dataset_path / "labels" / split

        if not images_dir.exists():
            raise ValueError(f"Images directory not found: {images_dir}")

        # Collect source images, skipping earlier augmentation outputs.
        # Sorted so the processing (and thus seeding) order is deterministic.
        image_extensions = ("*.png", "*.jpg", "*.jpeg")
        image_files: list[Path] = sorted(
            path
            for ext in image_extensions
            for path in images_dir.glob(ext)
            if not self._is_augmented_stem(path.stem)
        )

        original_count = len(image_files)
        augmented_count = 0

        if multiplier <= 0:
            return {
                "original_images": original_count,
                "augmented_images": 0,
                "total_images": original_count,
            }

        for img_path in image_files:
            # Load image as an RGB array.
            pil_image = Image.open(img_path).convert("RGB")
            image = np.array(pil_image)

            # Load the matching label file (None when absent/empty).
            label_path = labels_dir / f"{img_path.stem}.txt"
            bboxes = self._load_bboxes(label_path) if label_path.exists() else None

            for aug_idx in range(multiplier):
                # Derive a per-copy seed. BUGFIX: zlib.crc32 — unlike the
                # builtin hash(), which is salted per process — is stable
                # across interpreter runs, so a fixed seed now really does
                # reproduce the same augmentations.
                aug_seed = None
                if self._seed is not None:
                    stem_checksum = zlib.crc32(img_path.stem.encode("utf-8"))
                    aug_seed = self._seed + aug_idx + stem_checksum % 10000

                pipeline = AugmentationPipeline(
                    self._build_config(self._config_dict, aug_seed)
                )

                result = pipeline.apply(image, bboxes)

                # Save the augmented image next to the original.
                aug_name = f"{img_path.stem}_aug{aug_idx}{img_path.suffix}"
                aug_img_path = images_dir / aug_name
                Image.fromarray(result.image).save(aug_img_path)

                # Save the (possibly transformed) label alongside it.
                aug_label_path = labels_dir / f"{img_path.stem}_aug{aug_idx}.txt"
                self._save_bboxes(aug_label_path, result.bboxes)

                augmented_count += 1

        logger.info(
            "Dataset augmentation complete: %d original, %d augmented",
            original_count,
            augmented_count,
        )

        return {
            "original_images": original_count,
            "augmented_images": augmented_count,
            "total_images": original_count + augmented_count,
        }

    def _load_bboxes(self, label_path: Path) -> np.ndarray | None:
        """
        Load bounding boxes from a YOLO label file.

        Args:
            label_path: Path to label file.

        Returns:
            (N, 5) float32 array of [class_id, x_center, y_center, width,
            height] rows, or None when the file is missing, empty, or has
            no well-formed rows. Malformed lines are silently skipped.
        """
        if not label_path.exists():
            return None

        content = label_path.read_text().strip()
        if not content:
            return None

        rows: list[list[float]] = []
        for line in content.split("\n"):
            parts = line.strip().split()
            if len(parts) == 5:
                rows.append(
                    [
                        int(parts[0]),
                        float(parts[1]),
                        float(parts[2]),
                        float(parts[3]),
                        float(parts[4]),
                    ]
                )

        if not rows:
            return None

        return np.array(rows, dtype=np.float32)

    def _save_bboxes(self, label_path: Path, bboxes: np.ndarray | None) -> None:
        """
        Save bounding boxes to a YOLO label file.

        Args:
            label_path: Path to save label file.
            bboxes: (N, 5) array, or None/empty for an empty label file.
        """
        if bboxes is None or len(bboxes) == 0:
            label_path.write_text("")
            return

        lines = [
            f"{int(b[0])} {b[1]:.6f} {b[2]:.6f} {b[3]:.6f} {b[4]:.6f}"
            for b in bboxes
        ]
        label_path.write_text("\n".join(lines))
|
||||
184
packages/shared/shared/augmentation/pipeline.py
Normal file
184
packages/shared/shared/augmentation/pipeline.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
Augmentation pipeline module.
|
||||
|
||||
Orchestrates multiple augmentations with proper ordering and
|
||||
provides preview functionality.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
from shared.augmentation.transforms.blur import GaussianBlur, MotionBlur
|
||||
from shared.augmentation.transforms.degradation import EdgeDamage, Stain, Wrinkle
|
||||
from shared.augmentation.transforms.geometric import PerspectiveWarp
|
||||
from shared.augmentation.transforms.lighting import LightingVariation, Shadow
|
||||
from shared.augmentation.transforms.noise import GaussianNoise, SaltPepper
|
||||
from shared.augmentation.transforms.texture import PaperTexture, ScannerArtifacts
|
||||
|
||||
# Registry of augmentation classes, keyed by the AugmentationConfig field
# name that configures each transform.
AUGMENTATION_REGISTRY: dict[str, type[BaseAugmentation]] = {
    "perspective_warp": PerspectiveWarp,
    "wrinkle": Wrinkle,
    "edge_damage": EdgeDamage,
    "stain": Stain,
    "lighting_variation": LightingVariation,
    "shadow": Shadow,
    "gaussian_blur": GaussianBlur,
    "motion_blur": MotionBlur,
    "gaussian_noise": GaussianNoise,
    "salt_pepper": SaltPepper,
    "paper_texture": PaperTexture,
    "scanner_artifacts": ScannerArtifacts,
}
|
||||
|
||||
|
||||
class AugmentationPipeline:
    """
    Orchestrates multiple augmentations with proper ordering.

    Augmentations are applied in the following order:
    1. Geometric (perspective_warp) - affects bboxes
    2. Degradation (wrinkle, edge_damage, stain) - visual artifacts
    3. Lighting (lighting_variation, shadow)
    4. Texture (paper_texture, scanner_artifacts)
    5. Blur (gaussian_blur, motion_blur)
    6. Noise (gaussian_noise, salt_pepper) - applied last
    """

    # Pipeline stages, in application order.
    STAGE_ORDER = [
        "geometric",
        "degradation",
        "lighting",
        "texture",
        "blur",
        "noise",
    ]

    # Augmentation name -> pipeline stage it belongs to.
    STAGE_MAPPING = {
        "perspective_warp": "geometric",
        "wrinkle": "degradation",
        "edge_damage": "degradation",
        "stain": "degradation",
        "lighting_variation": "lighting",
        "shadow": "lighting",
        "paper_texture": "texture",
        "scanner_artifacts": "texture",
        "gaussian_blur": "blur",
        "motion_blur": "blur",
        "gaussian_noise": "noise",
        "salt_pepper": "noise",
    }

    def __init__(self, config: AugmentationConfig) -> None:
        """
        Initialize pipeline with configuration.

        Args:
            config: Augmentation configuration; its seed drives all
                randomness used by apply().
        """
        self.config = config
        self._rng = np.random.default_rng(config.seed)
        self._augmentations = self._build_augmentations()

    def _build_augmentations(
        self,
    ) -> list[tuple[str, BaseAugmentation, float]]:
        """Build the stage-ordered list of (name, augmentation, probability)."""
        augmentations: list[tuple[str, BaseAugmentation, float]] = []

        for aug_name, aug_class in AUGMENTATION_REGISTRY.items():
            params: AugmentationParams = getattr(self.config, aug_name)
            if params.enabled:
                augmentations.append(
                    (aug_name, aug_class(params.params), params.probability)
                )

        # sorted() is stable, so registry order is kept within each stage.
        def sort_key(item: tuple[str, BaseAugmentation, float]) -> int:
            name, _, _ = item
            return self.STAGE_ORDER.index(self.STAGE_MAPPING[name])

        return sorted(augmentations, key=sort_key)

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
    ) -> AugmentationResult:
        """
        Apply the augmentation pipeline to an image.

        Each enabled augmentation runs with its configured probability;
        inputs are copied up front so callers' arrays are never mutated.

        Args:
            image: Input image (H, W, C) as numpy array with dtype uint8.
            bboxes: Optional bounding boxes in YOLO format (N, 5).

        Returns:
            AugmentationResult with the augmented image, optionally adjusted
            bboxes, and metadata listing which augmentations were applied.
        """
        current_image = image.copy()
        current_bboxes = bboxes.copy() if bboxes is not None else None
        applied_augmentations: list[str] = []

        for name, aug, probability in self._augmentations:
            if self._rng.random() < probability:
                result = aug.apply(current_image, current_bboxes, self._rng)
                current_image = result.image
                # Only carry forward bbox updates when configured to do so.
                if result.bboxes is not None and self.config.preserve_bboxes:
                    current_bboxes = result.bboxes
                applied_augmentations.append(name)

        return AugmentationResult(
            image=current_image,
            bboxes=current_bboxes,
            metadata={"applied_augmentations": applied_augmentations},
        )

    def preview(
        self,
        image: np.ndarray,
        augmentation_name: str,
    ) -> np.ndarray:
        """
        Preview a single augmentation deterministically.

        Args:
            image: Input image.
            augmentation_name: Name of augmentation to preview.

        Returns:
            Augmented image.

        Raises:
            ValueError: If augmentation_name is not recognized.
        """
        if augmentation_name not in AUGMENTATION_REGISTRY:
            raise ValueError(f"Unknown augmentation: {augmentation_name}")

        params: AugmentationParams = getattr(self.config, augmentation_name)
        aug_class = AUGMENTATION_REGISTRY[augmentation_name]
        aug = aug_class(params.params)

        # FIX: honor BaseAugmentation.get_preview_params(), which exists so
        # transforms can expose exaggerated settings for preview/demo. The
        # base implementation returns the configured params unchanged, so
        # transforms without an override behave exactly as before.
        preview_params = aug.get_preview_params()
        if preview_params != params.params:
            aug = aug_class(preview_params)

        # Fixed-seed RNG keeps the preview deterministic.
        preview_rng = np.random.default_rng(42)
        result = aug.apply(image.copy(), rng=preview_rng)
        return result.image
|
||||
|
||||
|
||||
def get_available_augmentations() -> list[dict[str, Any]]:
    """
    Describe every registered augmentation.

    Returns:
        One dictionary per augmentation, carrying its registry name,
        docstring-derived description, geometry flag, and pipeline stage.
    """
    return [
        {
            "name": name,
            "description": aug_class.__doc__ or "",
            "affects_geometry": aug_class.affects_geometry,
            "stage": AugmentationPipeline.STAGE_MAPPING[name],
        }
        for name, aug_class in AUGMENTATION_REGISTRY.items()
    ]
|
||||
212
packages/shared/shared/augmentation/presets.py
Normal file
212
packages/shared/shared/augmentation/presets.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Predefined augmentation presets for common document scenarios.
|
||||
|
||||
Presets provide ready-to-use configurations optimized for different
|
||||
use cases, from conservative (preserves text readability) to aggressive
|
||||
(simulates poor document quality).
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from shared.augmentation.config import AugmentationConfig, AugmentationParams
|
||||
|
||||
|
||||
# Preset table: name -> {"description": str, "config": nested settings dict}.
# Treated as immutable; get_preset_config() hands out copies only.
PRESETS: dict[str, dict[str, Any]] = {
    "conservative": {
        "description": "Safe augmentations that preserve text readability",
        "config": {
            "lighting_variation": {
                "enabled": True,
                "probability": 0.5,
                "params": {
                    "brightness_range": (-0.1, 0.1),
                    "contrast_range": (0.9, 1.1),
                },
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.3,
                "params": {"std": (3, 10)},
            },
        },
    },
    "moderate": {
        "description": "Balanced augmentations for typical document degradation",
        "config": {
            "lighting_variation": {
                "enabled": True,
                "probability": 0.5,
                "params": {
                    "brightness_range": (-0.15, 0.15),
                    "contrast_range": (0.85, 1.15),
                },
            },
            "shadow": {
                "enabled": True,
                "probability": 0.3,
                "params": {"num_shadows": (1, 2), "opacity": (0.2, 0.35)},
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.3,
                "params": {"std": (5, 12)},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 0.2,
                "params": {"kernel_size": (3, 5), "sigma": (0.5, 1.0)},
            },
            "paper_texture": {
                "enabled": True,
                "probability": 0.3,
                "params": {"intensity": (0.05, 0.12)},
            },
        },
    },
    "aggressive": {
        "description": "Heavy augmentations simulating poor scan quality",
        "config": {
            "perspective_warp": {
                "enabled": True,
                "probability": 0.3,
                "params": {"max_warp": 0.02},
            },
            "wrinkle": {
                "enabled": True,
                "probability": 0.4,
                "params": {"intensity": 0.3, "num_wrinkles": (2, 4)},
            },
            "stain": {
                "enabled": True,
                "probability": 0.3,
                "params": {
                    "num_stains": (1, 2),
                    "max_radius_ratio": 0.08,
                    "opacity": (0.1, 0.25),
                },
            },
            "lighting_variation": {
                "enabled": True,
                "probability": 0.6,
                "params": {
                    "brightness_range": (-0.2, 0.2),
                    "contrast_range": (0.8, 1.2),
                },
            },
            "shadow": {
                "enabled": True,
                "probability": 0.4,
                "params": {"num_shadows": (1, 2), "opacity": (0.25, 0.4)},
            },
            "gaussian_blur": {
                "enabled": True,
                "probability": 0.3,
                "params": {"kernel_size": (3, 5), "sigma": (0.5, 1.5)},
            },
            "motion_blur": {
                "enabled": True,
                "probability": 0.2,
                "params": {"kernel_size": (5, 7), "angle_range": (-30, 30)},
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.4,
                "params": {"std": (8, 18)},
            },
            "paper_texture": {
                "enabled": True,
                "probability": 0.4,
                "params": {"intensity": (0.08, 0.15)},
            },
            "scanner_artifacts": {
                "enabled": True,
                "probability": 0.3,
                "params": {"line_probability": 0.4, "dust_probability": 0.5},
            },
            "edge_damage": {
                "enabled": True,
                "probability": 0.2,
                "params": {"max_damage_ratio": 0.04},
            },
        },
    },
    "scanned_document": {
        "description": "Simulates typical scanned document artifacts",
        "config": {
            "scanner_artifacts": {
                "enabled": True,
                "probability": 0.5,
                "params": {"line_probability": 0.4, "dust_probability": 0.5},
            },
            "paper_texture": {
                "enabled": True,
                "probability": 0.4,
                "params": {"intensity": (0.05, 0.12)},
            },
            "lighting_variation": {
                "enabled": True,
                "probability": 0.3,
                "params": {
                    "brightness_range": (-0.1, 0.1),
                    "contrast_range": (0.9, 1.1),
                },
            },
            "gaussian_noise": {
                "enabled": True,
                "probability": 0.3,
                "params": {"std": (5, 12)},
            },
        },
    },
}


def get_preset_config(preset_name: str) -> dict[str, Any]:
    """
    Get the configuration dictionary for a preset.

    Args:
        preset_name: Name of the preset.

    Returns:
        A deep copy of the preset's configuration dictionary. BUGFIX:
        previously the live nested dict from PRESETS was returned, so any
        caller mutation silently corrupted the preset for every later
        caller; copying makes the returned dict safe to modify.

    Raises:
        ValueError: If preset is not found.
    """
    from copy import deepcopy  # local import keeps this fix self-contained

    if preset_name not in PRESETS:
        raise ValueError(
            f"Unknown preset: {preset_name}. "
            f"Available presets: {list(PRESETS.keys())}"
        )
    return deepcopy(PRESETS[preset_name]["config"])
|
||||
|
||||
|
||||
def create_config_from_preset(preset_name: str) -> AugmentationConfig:
    """
    Build an AugmentationConfig instance from a named preset.

    Args:
        preset_name: Name of the preset.

    Returns:
        AugmentationConfig instance.

    Raises:
        ValueError: If preset is not found.
    """
    # Delegate lookup (and the unknown-preset error) to get_preset_config.
    return AugmentationConfig.from_dict(get_preset_config(preset_name))
|
||||
|
||||
|
||||
def list_presets() -> list[dict[str, str]]:
    """
    List all available presets.

    Returns:
        List of dictionaries with name and description.
    """
    entries: list[dict[str, str]] = []
    for preset_name, preset in PRESETS.items():
        entries.append(
            {"name": preset_name, "description": preset["description"]}
        )
    return entries
|
||||
13
packages/shared/shared/augmentation/transforms/__init__.py
Normal file
13
packages/shared/shared/augmentation/transforms/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Augmentation transform implementations.
|
||||
|
||||
Each module contains related augmentation classes:
|
||||
- geometric.py: Perspective warp and other geometric transforms
|
||||
- degradation.py: Wrinkle, edge damage, stain effects
|
||||
- lighting.py: Lighting variation and shadow effects
|
||||
- blur.py: Gaussian and motion blur
|
||||
- noise.py: Gaussian and salt-pepper noise
|
||||
- texture.py: Paper texture and scanner artifacts
|
||||
"""
|
||||
|
||||
# Will be populated as transforms are implemented
|
||||
144
packages/shared/shared/augmentation/transforms/blur.py
Normal file
144
packages/shared/shared/augmentation/transforms/blur.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Blur augmentation transforms.
|
||||
|
||||
Provides blur effects for document image augmentation:
|
||||
- GaussianBlur: Simulates out-of-focus capture
|
||||
- MotionBlur: Simulates camera/document movement during capture
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class GaussianBlur(BaseAugmentation):
    """
    Applies Gaussian blur to the image.

    Simulates out-of-focus capture or low-quality optics.
    Conservative defaults to preserve text readability.

    Parameters:
        kernel_size: Blur kernel size, int or (min, max) tuple (default: (3, 5)).
        sigma: Blur sigma, float or (min, max) tuple (default: (0.5, 1.5)).
    """

    name = "gaussian_blur"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Validate kernel_size; non-int/tuple values pass through unchecked."""
        kernel_size = self.params.get("kernel_size", (3, 5))
        if isinstance(kernel_size, int):
            if kernel_size < 1 or kernel_size % 2 == 0:
                raise ValueError("kernel_size must be a positive odd integer")
        elif isinstance(kernel_size, tuple):
            if kernel_size[0] < 1 or kernel_size[1] < kernel_size[0]:
                raise ValueError("kernel_size tuple must be (min, max) with min >= 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Blur *image* with a randomly sampled odd kernel and sigma."""
        if rng is None:
            rng = np.random.default_rng()

        k = self.params.get("kernel_size", (3, 5))
        sigma = self.params.get("sigma", (0.5, 1.5))

        if isinstance(k, tuple):
            # Sample a random odd size from [min, max]; OpenCV requires
            # odd kernels for GaussianBlur.
            lo, hi = k
            odd_sizes = [c for c in range(lo, hi + 1) if c % 2 == 1]
            if not odd_sizes:
                odd_sizes = [lo if lo % 2 == 1 else lo + 1]
            k = rng.choice(odd_sizes)

        if isinstance(sigma, tuple):
            sigma = rng.uniform(*sigma)

        # Defensive: force oddness if a fixed even size slipped through.
        if k % 2 == 0:
            k += 1

        blurred = cv2.GaussianBlur(image, (k, k), sigma)

        copied_bboxes = None if bboxes is None else bboxes.copy()
        return AugmentationResult(
            image=blurred,
            bboxes=copied_bboxes,
            metadata={"kernel_size": k, "sigma": sigma},
        )

    def get_preview_params(self) -> dict:
        return {"kernel_size": 5, "sigma": 1.5}
|
||||
|
||||
|
||||
class MotionBlur(BaseAugmentation):
    """
    Applies motion blur to the image.

    Simulates camera shake or document movement during capture.

    Parameters:
        kernel_size: Blur kernel size, int or (min, max) tuple (default: (5, 9)).
        angle_range: Motion angle range in degrees (default: (-45, 45)).
    """

    name = "motion_blur"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Reject kernel sizes too small to produce visible motion blur."""
        kernel_size = self.params.get("kernel_size", (5, 9))
        if isinstance(kernel_size, int):
            if kernel_size < 3:
                raise ValueError("kernel_size must be at least 3")
        elif isinstance(kernel_size, tuple):
            if kernel_size[0] < 3:
                raise ValueError("kernel_size min must be at least 3")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Convolve *image* with a line kernel at a random angle.

        The kernel is a rasterized line through the kernel center; its
        direction is sampled from ``angle_range``. Bboxes are returned
        unchanged (copied) since blur does not move content.
        """
        rng = rng or np.random.default_rng()

        kernel_size = self.params.get("kernel_size", (5, 9))
        angle_range = self.params.get("angle_range", (-45, 45))

        if isinstance(kernel_size, tuple):
            kernel_size = rng.integers(kernel_size[0], kernel_size[1] + 1)

        angle = rng.uniform(angle_range[0], angle_range[1])

        # Create motion blur kernel
        kernel = np.zeros((kernel_size, kernel_size), dtype=np.float32)

        # Draw a line in the center of the kernel
        center = kernel_size // 2
        angle_rad = np.deg2rad(angle)

        # Rasterize one point per offset step along the motion direction;
        # steep angles may leave gaps, but the center pixel (i == center)
        # is always set, so the kernel is never empty.
        for i in range(kernel_size):
            offset = i - center
            x = int(center + offset * np.cos(angle_rad))
            y = int(center + offset * np.sin(angle_rad))
            if 0 <= x < kernel_size and 0 <= y < kernel_size:
                kernel[y, x] = 1.0

        # Normalize kernel so overall image brightness is preserved
        kernel = kernel / kernel.sum() if kernel.sum() > 0 else kernel

        # Apply motion blur (ddepth=-1 keeps the input dtype)
        blurred = cv2.filter2D(image, -1, kernel)

        return AugmentationResult(
            image=blurred,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"kernel_size": kernel_size, "angle": angle},
        )

    def get_preview_params(self) -> dict:
        """Fixed, representative parameters for UI previews."""
        return {"kernel_size": 7, "angle_range": (-30, 30)}
|
||||
259
packages/shared/shared/augmentation/transforms/degradation.py
Normal file
259
packages/shared/shared/augmentation/transforms/degradation.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Degradation augmentation transforms.
|
||||
|
||||
Provides degradation effects for document image augmentation:
|
||||
- Wrinkle: Paper wrinkle/crease simulation
|
||||
- EdgeDamage: Damaged/torn edge effects
|
||||
- Stain: Coffee stain/smudge effects
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class Wrinkle(BaseAugmentation):
    """
    Simulates paper wrinkles/creases using displacement mapping.

    Document-friendly: Uses subtle displacement to preserve text readability.

    Parameters:
        intensity: Wrinkle intensity (0-1) (default: 0.3).
        num_wrinkles: Number of wrinkles, int or (min, max) tuple (default: (2, 5)).
    """

    name = "wrinkle"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Ensure intensity lies in (0, 1]."""
        intensity = self.params.get("intensity", 0.3)
        if not (0 < intensity <= 1):
            raise ValueError("intensity must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Apply random wrinkles via pixel displacement plus subtle shading.

        Args:
            image: Input image (H, W, C), uint8.
            bboxes: Optional bboxes; returned as an unchanged copy — the
                displacement is small enough not to move boxes meaningfully.
            rng: Optional random generator for reproducibility.

        Returns:
            AugmentationResult with the wrinkled image.
        """
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        intensity = self.params.get("intensity", 0.3)
        num_wrinkles = self.params.get("num_wrinkles", (2, 5))

        if isinstance(num_wrinkles, tuple):
            num_wrinkles = rng.integers(num_wrinkles[0], num_wrinkles[1] + 1)

        # Coordinate grids are loop-invariant: build them once instead of
        # re-allocating two (H, W) arrays on every wrinkle iteration.
        xx, yy = np.meshgrid(np.arange(w), np.arange(h))

        # Accumulated displacement maps
        displacement_x = np.zeros((h, w), dtype=np.float32)
        displacement_y = np.zeros((h, w), dtype=np.float32)

        for _ in range(num_wrinkles):
            # Random wrinkle line: orientation, anchor point, extent, width
            angle = rng.uniform(0, np.pi)
            x0 = rng.uniform(0, w)
            y0 = rng.uniform(0, h)
            length = rng.uniform(0.3, 0.8) * min(h, w)
            width = rng.uniform(0.02, 0.05) * min(h, w)

            # Signed distance along (dx) and across (dy) the wrinkle line
            dx = (xx - x0) * np.cos(angle) + (yy - y0) * np.sin(angle)
            dy = -(xx - x0) * np.sin(angle) + (yy - y0) * np.cos(angle)

            # Gaussian falloff perpendicular to wrinkle, limited to its length
            mask = np.exp(-dy**2 / (2 * width**2))
            mask *= (np.abs(dx) < length / 2).astype(np.float32)

            # Displacement perpendicular to wrinkle
            disp_amount = intensity * rng.uniform(2, 8)
            displacement_x += mask * disp_amount * np.sin(angle)
            displacement_y += mask * disp_amount * np.cos(angle)

        # Create remap coordinates (identity grid plus displacement)
        map_x = (np.arange(w)[np.newaxis, :] + displacement_x).astype(np.float32)
        map_y = (np.arange(h)[:, np.newaxis] + displacement_y).astype(np.float32)

        # Apply displacement; BORDER_REFLECT avoids black edges
        augmented = cv2.remap(
            image, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT
        )

        # Add subtle shading along wrinkles so they read as creases;
        # the epsilon guards against division by zero when no displacement
        # was accumulated.
        max_disp = np.max(np.abs(displacement_y)) + 1e-6
        shading = 1 - 0.1 * intensity * np.abs(displacement_y) / max_disp
        shading = shading[:, :, np.newaxis]
        augmented = (augmented.astype(np.float32) * shading).astype(np.uint8)

        return AugmentationResult(
            image=augmented,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_wrinkles": num_wrinkles, "intensity": intensity},
        )

    def get_preview_params(self) -> dict:
        return {"intensity": 0.5, "num_wrinkles": 3}
|
||||
|
||||
|
||||
class EdgeDamage(BaseAugmentation):
    """
    Adds damaged/torn edge effects to the image.

    Simulates worn or torn document edges.

    Parameters:
        max_damage_ratio: Maximum proportion of edge to damage (default: 0.05).
        edges: Which edges to potentially damage (default: all).
    """

    name = "edge_damage"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Keep damage small enough that document content survives."""
        max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
        if not (0 < max_damage_ratio <= 0.2):
            raise ValueError("max_damage_ratio must be between 0 and 0.2")

    @staticmethod
    def _random_profile(
        rng: np.random.Generator, n: int, damage_size: int
    ) -> tuple[np.ndarray, np.ndarray]:
        """Random damage depth and fill color for each of *n* edge lines.

        Depths are uniform in [0, damage_size]. Fill color is either light
        (torn-paper white, 200-254) or dark (shadowed tear, 100-149) with
        equal probability, matching the original per-pixel behavior but
        sampled in one vectorized pass instead of two rng calls per pixel.
        """
        depths = rng.integers(0, damage_size + 1, size=n)
        light = rng.integers(200, 255, size=n)
        dark = rng.integers(100, 150, size=n)
        colors = np.where(rng.random(n) > 0.5, light, dark)
        return depths, colors

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Damage one randomly chosen edge with an irregular depth profile.

        Bboxes are returned unchanged (copied); damage only overwrites a
        thin border strip.
        """
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        max_damage_ratio = self.params.get("max_damage_ratio", 0.05)
        edges = self.params.get("edges", ["top", "bottom", "left", "right"])

        output = image.copy()

        # One edge per application; str() unwraps numpy's str_ scalar so the
        # metadata stays a plain Python string.
        edge = str(rng.choice(edges))
        damage_size = int(max_damage_ratio * min(h, w))

        # The four edges previously had four duplicated per-pixel loops;
        # collapse them into two symmetric branches over a shared profile.
        if edge in ("top", "bottom"):
            depths, colors = self._random_profile(rng, w, damage_size)
            for x in range(w):
                depth = int(depths[x])
                if depth > 0:
                    if edge == "top":
                        output[:depth, x] = colors[x]
                    else:
                        output[h - depth:, x] = colors[x]
        else:  # "left" / "right"
            depths, colors = self._random_profile(rng, h, damage_size)
            for y in range(h):
                depth = int(depths[y])
                if depth > 0:
                    if edge == "left":
                        output[y, :depth] = colors[y]
                    else:
                        output[y, w - depth:] = colors[y]

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"edge": edge, "damage_size": damage_size},
        )

    def get_preview_params(self) -> dict:
        return {"max_damage_ratio": 0.08}
|
||||
|
||||
|
||||
class Stain(BaseAugmentation):
    """
    Adds coffee stain/smudge effects to the image.

    Simulates accidental stains on documents.

    Parameters:
        num_stains: Number of stains, int or (min, max) tuple (default: (1, 3)).
        max_radius_ratio: Maximum stain radius as ratio of image size (default: 0.1).
        opacity: Stain opacity, float or (min, max) tuple (default: (0.1, 0.3)).
    """

    name = "stain"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Validate a scalar opacity; tuple opacity is sampled at apply time."""
        opacity = self.params.get("opacity", (0.1, 0.3))
        if isinstance(opacity, (int, float)):
            if not (0 < opacity <= 1):
                raise ValueError("opacity must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Blend one or more soft-edged brownish stains into the image.

        Bboxes are returned unchanged (copied); stains are an overlay only.
        """
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        num_stains = self.params.get("num_stains", (1, 3))
        max_radius_ratio = self.params.get("max_radius_ratio", 0.1)
        opacity = self.params.get("opacity", (0.1, 0.3))

        if isinstance(num_stains, tuple):
            num_stains = rng.integers(num_stains[0], num_stains[1] + 1)
        if isinstance(opacity, tuple):
            opacity = rng.uniform(opacity[0], opacity[1])

        output = image.astype(np.float32)

        # Clamp the radius so a stain center can always be placed strictly
        # inside the image: the previous unclamped value made
        # rng.integers(max_radius, dim - max_radius) raise ValueError for
        # small images or large ratios, and a radius of 0 made the
        # rng.integers(0, 0) call below raise as well.
        max_radius = int(max_radius_ratio * min(h, w))
        max_radius = max(1, min(max_radius, (min(h, w) - 1) // 2))

        for _ in range(num_stains):
            # Random stain position and size (center kept away from borders)
            cx = rng.integers(max_radius, w - max_radius)
            cy = rng.integers(max_radius, h - max_radius)
            # Lower bound of at least 1 keeps the range non-degenerate when
            # max_radius // 3 == 0.
            min_radius = max(1, max_radius // 3)
            radius = rng.integers(min_radius, max(min_radius + 1, max_radius))

            # Create stain mask with irregular edges
            yy, xx = np.ogrid[:h, :w]
            dist = np.sqrt((xx - cx) ** 2 + (yy - cy) ** 2)

            # Add noise to make edges irregular
            noise = rng.uniform(0.8, 1.2, (h, w))
            mask = (dist < radius * noise).astype(np.float32)

            # Blur for soft edges
            mask = cv2.GaussianBlur(mask, (21, 21), 0)

            # Random stain color (brownish/yellowish)
            stain_color = np.array([
                rng.integers(180, 220),  # R
                rng.integers(160, 200),  # G
                rng.integers(120, 160),  # B
            ], dtype=np.float32)

            # Alpha-blend the stain over the current output
            mask_3d = mask[:, :, np.newaxis]
            output = output * (1 - mask_3d * opacity) + stain_color * mask_3d * opacity

        output = np.clip(output, 0, 255).astype(np.uint8)

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_stains": num_stains, "opacity": opacity},
        )

    def get_preview_params(self) -> dict:
        return {"num_stains": 2, "max_radius_ratio": 0.1, "opacity": 0.25}
|
||||
145
packages/shared/shared/augmentation/transforms/geometric.py
Normal file
145
packages/shared/shared/augmentation/transforms/geometric.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""
|
||||
Geometric augmentation transforms.
|
||||
|
||||
Provides geometric transforms for document image augmentation:
|
||||
- PerspectiveWarp: Subtle perspective distortion
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class PerspectiveWarp(BaseAugmentation):
    """
    Applies subtle perspective transformation to the image.

    Simulates viewing document at slight angle. Very conservative
    by default to preserve text readability.

    IMPORTANT: This transform affects bounding box coordinates.

    Parameters:
        max_warp: Maximum warp as proportion of image size (default: 0.02).
    """

    name = "perspective_warp"
    affects_geometry = True

    def _validate_params(self) -> None:
        """Reject warp magnitudes that would make text unreadable."""
        max_warp = self.params.get("max_warp", 0.02)
        if not (0 < max_warp <= 0.1):
            raise ValueError("max_warp must be between 0 and 0.1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Warp the image with a random perspective transform.

        Args:
            image: Input image (H, W, C).
            bboxes: Optional (N, 5) array [class_id, x_c, y_c, w, h] in
                normalized coordinates; transformed to follow the warp.
            rng: Optional random generator for reproducibility.

        Returns:
            AugmentationResult with the warped image, transformed bboxes
            and the 3x3 perspective matrix.
        """
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        max_warp = self.params.get("max_warp", 0.02)

        # Original corners, clockwise from top-left
        src_pts = np.float32([
            [0, 0],
            [w, 0],
            [w, h],
            [0, h],
        ])

        # Jitter each corner independently by up to max_warp * min(h, w)
        max_offset = max_warp * min(h, w)
        dst_pts = src_pts.copy()
        for i in range(4):
            dst_pts[i, 0] += rng.uniform(-max_offset, max_offset)
            dst_pts[i, 1] += rng.uniform(-max_offset, max_offset)

        # Compute perspective transform matrix
        transform_matrix = cv2.getPerspectiveTransform(src_pts, dst_pts)

        # BORDER_REPLICATE avoids black wedges where the warp exposes
        # pixels outside the source image.
        warped = cv2.warpPerspective(
            image, transform_matrix, (w, h),
            borderMode=cv2.BORDER_REPLICATE
        )

        # Transform bounding boxes if present
        transformed_bboxes = None
        if bboxes is not None:
            transformed_bboxes = self._transform_bboxes(
                bboxes, transform_matrix, w, h
            )

        return AugmentationResult(
            image=warped,
            bboxes=transformed_bboxes,
            transform_matrix=transform_matrix,
            metadata={"max_warp": max_warp},
        )

    def _transform_bboxes(
        self,
        bboxes: np.ndarray,
        transform_matrix: np.ndarray,
        w: int,
        h: int,
    ) -> np.ndarray:
        """Transform bounding boxes using perspective matrix.

        Each box's four corners are warped and re-enclosed in an
        axis-aligned box, which is then clipped to the image bounds.
        """
        if len(bboxes) == 0:
            return bboxes.copy()

        transformed = []
        for bbox in bboxes:
            class_id, x_center, y_center, width, height = bbox

            # Normalized -> pixel coordinates
            x_center_px = x_center * w
            y_center_px = y_center * h
            width_px = width * w
            height_px = height * h

            # Corner points of the axis-aligned box
            x1 = x_center_px - width_px / 2
            y1 = y_center_px - height_px / 2
            x2 = x_center_px + width_px / 2
            y2 = y_center_px + height_px / 2

            # Transform all 4 corners
            corners = np.float32([
                [x1, y1],
                [x2, y1],
                [x2, y2],
                [x1, y2],
            ]).reshape(-1, 1, 2)

            transformed_corners = cv2.perspectiveTransform(corners, transform_matrix)
            transformed_corners = transformed_corners.reshape(-1, 2)

            # Axis-aligned bounds of the warped corners, clipped to the
            # image BEFORE converting back to center format. (The previous
            # code clipped center and size independently, which could
            # leave a box that still extended past the image border.)
            new_x1 = float(np.clip(np.min(transformed_corners[:, 0]), 0, w))
            new_y1 = float(np.clip(np.min(transformed_corners[:, 1]), 0, h))
            new_x2 = float(np.clip(np.max(transformed_corners[:, 0]), 0, w))
            new_y2 = float(np.clip(np.max(transformed_corners[:, 1]), 0, h))

            # Pixel -> normalized center format; all values are guaranteed
            # to lie in [0, 1] because the corners were clipped above.
            new_width = (new_x2 - new_x1) / w
            new_height = (new_y2 - new_y1) / h
            new_x_center = ((new_x1 + new_x2) / 2) / w
            new_y_center = ((new_y1 + new_y2) / 2) / h

            transformed.append([class_id, new_x_center, new_y_center, new_width, new_height])

        return np.array(transformed, dtype=np.float32)

    def get_preview_params(self) -> dict:
        return {"max_warp": 0.03}
|
||||
167
packages/shared/shared/augmentation/transforms/lighting.py
Normal file
167
packages/shared/shared/augmentation/transforms/lighting.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Lighting augmentation transforms.
|
||||
|
||||
Provides lighting effects for document image augmentation:
|
||||
- LightingVariation: Adjusts brightness and contrast
|
||||
- Shadow: Adds shadow overlay effects
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class LightingVariation(BaseAugmentation):
    """
    Adjusts image brightness and contrast.

    Simulates different lighting conditions during document capture.
    Safe for documents with conservative default parameters.

    Parameters:
        brightness_range: (min, max) brightness adjustment (default: (-0.1, 0.1)).
        contrast_range: (min, max) contrast multiplier (default: (0.9, 1.1)).
    """

    name = "lighting_variation"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Both ranges must be two-element tuples."""
        brightness = self.params.get("brightness_range", (-0.1, 0.1))
        contrast = self.params.get("contrast_range", (0.9, 1.1))

        if not isinstance(brightness, tuple) or len(brightness) != 2:
            raise ValueError("brightness_range must be a (min, max) tuple")
        if not isinstance(contrast, tuple) or len(contrast) != 2:
            raise ValueError("contrast_range must be a (min, max) tuple")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Scale contrast around the global mean, then shift brightness."""
        if rng is None:
            rng = np.random.default_rng()

        b_lo, b_hi = self.params.get("brightness_range", (-0.1, 0.1))
        c_lo, c_hi = self.params.get("contrast_range", (0.9, 1.1))

        brightness = rng.uniform(b_lo, b_hi)
        contrast = rng.uniform(c_lo, c_hi)

        pixels = image.astype(np.float32)
        mean = pixels.mean()

        # Contrast multiplies the deviation from the mean; brightness adds
        # a uniform offset in [0, 255] units.
        pixels = (pixels - mean) * contrast + mean + brightness * 255
        adjusted = np.clip(pixels, 0, 255).astype(np.uint8)

        copied_bboxes = None if bboxes is None else bboxes.copy()
        return AugmentationResult(
            image=adjusted,
            bboxes=copied_bboxes,
            metadata={"brightness": brightness, "contrast": contrast},
        )

    def get_preview_params(self) -> dict:
        return {"brightness_range": (-0.15, 0.15), "contrast_range": (0.85, 1.15)}
|
||||
|
||||
|
||||
class Shadow(BaseAugmentation):
    """
    Adds shadow overlay effects to the image.

    Simulates shadows from objects or hands during document capture.

    Parameters:
        num_shadows: Number of shadow regions, int or (min, max) tuple (default: (1, 2)).
        opacity: Shadow darkness, float or (min, max) tuple (default: (0.2, 0.4)).
    """

    name = "shadow"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Check opacity (scalar or range) lies within [0, 1]."""
        opacity = self.params.get("opacity", (0.2, 0.4))
        if isinstance(opacity, (int, float)):
            if not (0 <= opacity <= 1):
                raise ValueError("opacity must be between 0 and 1")
        elif isinstance(opacity, tuple):
            if not (0 <= opacity[0] <= opacity[1] <= 1):
                raise ValueError("opacity tuple must be in range [0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Darken the image under one or more random soft-edged polygons.

        Each shadow is a polygon anchored on a random image edge, filled,
        blurred, and multiplied into the image. Bboxes are returned as an
        unchanged copy — shadows do not move content.
        """
        rng = rng or np.random.default_rng()

        num_shadows = self.params.get("num_shadows", (1, 2))
        opacity = self.params.get("opacity", (0.2, 0.4))

        if isinstance(num_shadows, tuple):
            num_shadows = rng.integers(num_shadows[0], num_shadows[1] + 1)
        if isinstance(opacity, tuple):
            opacity = rng.uniform(opacity[0], opacity[1])

        h, w = image.shape[:2]
        output = image.astype(np.float32)

        for _ in range(num_shadows):
            # Generate random shadow polygon (3-5 vertices)
            num_vertices = rng.integers(3, 6)
            vertices = []

            # Start from a random edge so the shadow plausibly comes from
            # outside the frame (hand, object, page edge)
            edge = rng.integers(0, 4)
            if edge == 0:  # Top
                start = (rng.integers(0, w), 0)
            elif edge == 1:  # Right
                start = (w, rng.integers(0, h))
            elif edge == 2:  # Bottom
                start = (rng.integers(0, w), h)
            else:  # Left
                start = (0, rng.integers(0, h))

            vertices.append(start)

            # Add random vertices anywhere in the image
            for _ in range(num_vertices - 1):
                x = rng.integers(0, w)
                y = rng.integers(0, h)
                vertices.append((x, y))

            # Create shadow mask (1.0 inside the polygon, 0.0 outside)
            mask = np.zeros((h, w), dtype=np.float32)
            pts = np.array(vertices, dtype=np.int32).reshape((-1, 1, 2))
            cv2.fillPoly(mask, [pts], 1.0)

            # Blur the mask for soft edges; kernel scales with image size
            # and is forced odd for GaussianBlur
            blur_size = max(31, min(h, w) // 10)
            if blur_size % 2 == 0:
                blur_size += 1
            mask = cv2.GaussianBlur(mask, (blur_size, blur_size), 0)

            # Apply shadow: multiplicative darkening weighted by the mask
            shadow_factor = 1 - opacity * mask[:, :, np.newaxis]
            output = output * shadow_factor

        output = np.clip(output, 0, 255).astype(np.uint8)

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"num_shadows": num_shadows, "opacity": opacity},
        )

    def get_preview_params(self) -> dict:
        """Fixed, representative parameters for UI previews."""
        return {"num_shadows": 1, "opacity": 0.3}
|
||||
142
packages/shared/shared/augmentation/transforms/noise.py
Normal file
142
packages/shared/shared/augmentation/transforms/noise.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Noise augmentation transforms.
|
||||
|
||||
Provides noise effects for document image augmentation:
|
||||
- GaussianNoise: Adds Gaussian noise to simulate sensor noise
|
||||
- SaltPepper: Adds salt and pepper noise for impulse noise effects
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class GaussianNoise(BaseAugmentation):
    """
    Adds Gaussian noise to the image.

    Simulates sensor noise from cameras or scanners.
    Document-safe with conservative default parameters.

    Parameters:
        mean: Mean of the Gaussian noise (default: 0).
        std: Standard deviation, can be int or (min, max) tuple (default: (5, 15)).
    """

    name = "gaussian_noise"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Check that the configured std (scalar or range) is sane."""
        std = self.params.get("std", (5, 15))
        if isinstance(std, (int, float)):
            if std < 0:
                raise ValueError("std must be non-negative")
        elif isinstance(std, tuple):
            if len(std) != 2 or std[0] < 0 or std[1] < std[0]:
                raise ValueError("std tuple must be (min, max) with min <= max >= 0")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Add Gaussian noise to every channel and clip back to uint8."""
        if rng is None:
            rng = np.random.default_rng()

        mean = self.params.get("mean", 0)
        std = self.params.get("std", (5, 15))
        if isinstance(std, tuple):
            std = rng.uniform(*std)

        # Work in float so negative noise is not wrapped by uint8 math.
        perturbed = image.astype(np.float32)
        perturbed += rng.normal(mean, std, image.shape).astype(np.float32)
        noisy = np.clip(perturbed, 0, 255).astype(np.uint8)

        copied_bboxes = None if bboxes is None else bboxes.copy()
        return AugmentationResult(
            image=noisy,
            bboxes=copied_bboxes,
            metadata={"applied_std": std},
        )

    def get_preview_params(self) -> dict[str, Any]:
        return {"mean": 0, "std": 15}
|
||||
|
||||
|
||||
class SaltPepper(BaseAugmentation):
    """
    Adds salt and pepper (impulse) noise to the image.

    Simulates defects from damaged sensors or transmission errors.
    Very sparse by default to preserve document readability.

    Parameters:
        amount: Proportion of pixels to affect, can be float or (min, max) tuple.
            Default: (0.001, 0.005) for very sparse noise.
        salt_vs_pepper: Ratio of salt to pepper (default: 0.5 for equal amounts).
    """

    name = "salt_pepper"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Check that amount (scalar or range) lies within [0, 1]."""
        amount = self.params.get("amount", (0.001, 0.005))
        if isinstance(amount, (int, float)):
            if not (0 <= amount <= 1):
                raise ValueError("amount must be between 0 and 1")
        elif isinstance(amount, tuple):
            if len(amount) != 2 or not (0 <= amount[0] <= amount[1] <= 1):
                raise ValueError("amount tuple must be (min, max) in range [0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """Set a sparse random subset of pixels to pure white or black."""
        if rng is None:
            rng = np.random.default_rng()

        amount = self.params.get("amount", (0.001, 0.005))
        salt_vs_pepper = self.params.get("salt_vs_pepper", 0.5)

        if isinstance(amount, tuple):
            amount = rng.uniform(amount[0], amount[1])

        out = image.copy()
        h, w = image.shape[:2]
        total_pixels = h * w

        # Split the pixel budget between salt (white) and pepper (black).
        num_salt = int(total_pixels * amount * salt_vs_pepper)
        num_pepper = int(total_pixels * amount * (1 - salt_vs_pepper))

        # Salt first, then pepper (same order as before), each drawing
        # rows then columns.
        for count, value in ((num_salt, 255), (num_pepper, 0)):
            if count > 0:
                rows = rng.integers(0, h, count)
                cols = rng.integers(0, w, count)
                out[rows, cols] = value

        copied_bboxes = None if bboxes is None else bboxes.copy()
        return AugmentationResult(
            image=out,
            bboxes=copied_bboxes,
            metadata={"applied_amount": amount},
        )

    def get_preview_params(self) -> dict[str, Any]:
        return {"amount": 0.01, "salt_vs_pepper": 0.5}
|
||||
159
packages/shared/shared/augmentation/transforms/texture.py
Normal file
159
packages/shared/shared/augmentation/transforms/texture.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Texture augmentation transforms.
|
||||
|
||||
Provides texture effects for document image augmentation:
|
||||
- PaperTexture: Adds paper grain/texture
|
||||
- ScannerArtifacts: Adds scanner line and dust artifacts
|
||||
"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from shared.augmentation.base import AugmentationResult, BaseAugmentation
|
||||
|
||||
|
||||
class PaperTexture(BaseAugmentation):
    """
    Adds paper texture/grain to the image.

    Simulates different paper types and ages by overlaying smoothed random
    noise on the document image.

    Parameters:
        texture_type: Type of texture ("random", "fine", "coarse") (default: "random").
        intensity: Texture intensity, float or (min, max) tuple (default: (0.05, 0.15)).
    """

    name = "paper_texture"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Validate intensity; accepts a scalar or a (min, max) tuple."""
        intensity = self.params.get("intensity", (0.05, 0.15))
        if isinstance(intensity, (int, float)):
            if not (0 < intensity <= 1):
                raise ValueError("intensity must be between 0 and 1")
        elif isinstance(intensity, tuple):
            # Tuples were previously accepted without any checking; validate
            # here so a bad config fails fast instead of deep inside apply().
            if len(intensity) != 2 or not all(0 < v <= 1 for v in intensity):
                raise ValueError("intensity tuple must be (min, max) in range (0, 1]")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """
        Overlay paper-grain noise on the image.

        Args:
            image: Input image as (H, W, C) uint8 array.
            bboxes: Optional bounding boxes; returned unchanged (copied).
            rng: Optional random generator (a default one is created if None).

        Returns:
            AugmentationResult with the textured image; metadata records the
            resolved texture_type and intensity actually applied.
        """
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        texture_type = self.params.get("texture_type", "random")
        intensity = self.params.get("intensity", (0.05, 0.15))

        if texture_type == "random":
            texture_type = rng.choice(["fine", "coarse"])

        if isinstance(intensity, tuple):
            intensity = rng.uniform(intensity[0], intensity[1])

        # Generate base noise
        if texture_type == "fine":
            # Fine grain: per-pixel noise, lightly blurred.
            noise = rng.uniform(-1, 1, (h, w)).astype(np.float32)
            noise = cv2.GaussianBlur(noise, (3, 3), 0)
        else:
            # Coarse texture: generate at lower resolution and upscale.
            # max(1, ...) keeps the noise grid non-empty for images smaller
            # than 4 pixels per side (cv2.resize rejects empty inputs).
            small_h, small_w = max(1, h // 4), max(1, w // 4)
            noise = rng.uniform(-1, 1, (small_h, small_w)).astype(np.float32)
            noise = cv2.resize(noise, (w, h), interpolation=cv2.INTER_LINEAR)
            noise = cv2.GaussianBlur(noise, (5, 5), 0)

        # Blend the noise in float space, then clip back to uint8.
        output = image.astype(np.float32)
        noise_3d = noise[:, :, np.newaxis] * intensity * 255
        output = np.clip(output + noise_3d, 0, 255).astype(np.uint8)

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={"texture_type": texture_type, "intensity": intensity},
        )

    def get_preview_params(self) -> dict:
        """Return fixed, strong parameters for preview rendering."""
        return {"texture_type": "coarse", "intensity": 0.15}
|
||||
|
||||
|
||||
class ScannerArtifacts(BaseAugmentation):
    """
    Adds scanner artifacts to the image.

    Simulates scanner imperfections like horizontal scan lines and dust spots.

    Parameters:
        line_probability: Probability of adding scan lines (default: 0.3).
        dust_probability: Probability of adding dust spots (default: 0.4).
    """

    name = "scanner_artifacts"
    affects_geometry = False

    def _validate_params(self) -> None:
        """Ensure both probabilities lie in [0, 1]."""
        line_prob = self.params.get("line_probability", 0.3)
        dust_prob = self.params.get("dust_probability", 0.4)
        if not (0 <= line_prob <= 1):
            raise ValueError("line_probability must be between 0 and 1")
        if not (0 <= dust_prob <= 1):
            raise ValueError("dust_probability must be between 0 and 1")

    def apply(
        self,
        image: np.ndarray,
        bboxes: np.ndarray | None = None,
        rng: np.random.Generator | None = None,
    ) -> AugmentationResult:
        """
        Add scan lines and/or dust spots to the image.

        Args:
            image: Input image as (H, W[, C]) uint8 array.
            bboxes: Optional bounding boxes; returned unchanged (copied).
            rng: Optional random generator (a default one is created if None).

        Returns:
            AugmentationResult with the degraded image.
        """
        rng = rng or np.random.default_rng()

        h, w = image.shape[:2]
        line_probability = self.params.get("line_probability", 0.3)
        dust_probability = self.params.get("dust_probability", 0.4)

        output = image.copy()

        # Add horizontal scan lines, alpha-blended so text stays readable.
        if rng.random() < line_probability:
            num_lines = rng.integers(1, 4)
            for _ in range(num_lines):
                y = rng.integers(0, h)
                thickness = rng.integers(1, 3)
                # Light or dark line
                color = rng.integers(200, 240) if rng.random() > 0.5 else rng.integers(50, 100)

                # Make line partially transparent
                alpha = rng.uniform(0.3, 0.6)
                for dy in range(thickness):
                    if y + dy < h:
                        output[y + dy, :] = (
                            output[y + dy, :].astype(np.float32) * (1 - alpha) +
                            color * alpha
                        ).astype(np.uint8)

        # Add dust spots
        if rng.random() < dust_probability:
            num_dust = rng.integers(5, 20)
            for _ in range(num_dust):
                x = rng.integers(0, w)
                y = rng.integers(0, h)
                radius = rng.integers(1, 3)

                # Dark dust spot. cv2.circle interprets a scalar colour as
                # (value, 0, 0) on multi-channel images, which painted blue
                # dust in BGR; replicate the grey level across all channels.
                color = int(rng.integers(50, 120))
                fill = (color,) * output.shape[2] if output.ndim == 3 else color
                cv2.circle(output, (int(x), int(y)), int(radius), fill, -1)

        return AugmentationResult(
            image=output,
            bboxes=bboxes.copy() if bboxes is not None else None,
            metadata={
                "line_probability": line_probability,
                "dust_probability": dust_probability,
            },
        )

    def get_preview_params(self) -> dict:
        """Return high-probability parameters for preview rendering."""
        return {"line_probability": 0.8, "dust_probability": 0.8}
|
||||
5
packages/shared/shared/training/__init__.py
Normal file
5
packages/shared/shared/training/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Shared training utilities."""
|
||||
|
||||
from .yolo_trainer import YOLOTrainer, TrainingConfig, TrainingResult
|
||||
|
||||
__all__ = ["YOLOTrainer", "TrainingConfig", "TrainingResult"]
|
||||
239
packages/shared/shared/training/yolo_trainer.py
Normal file
239
packages/shared/shared/training/yolo_trainer.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Shared YOLO Training Module
|
||||
|
||||
Unified training logic for both CLI and Web API.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig:
|
||||
"""Training configuration."""
|
||||
|
||||
# Model settings
|
||||
model_path: str = "yolo11n.pt" # Base model or path to trained model
|
||||
data_yaml: str = "" # Path to data.yaml
|
||||
|
||||
# Training hyperparameters
|
||||
epochs: int = 100
|
||||
batch_size: int = 16
|
||||
image_size: int = 640
|
||||
learning_rate: float = 0.01
|
||||
device: str = "0"
|
||||
|
||||
# Output settings
|
||||
project: str = "runs/train"
|
||||
name: str = "invoice_fields"
|
||||
|
||||
# Performance settings
|
||||
workers: int = 4
|
||||
cache: bool = False
|
||||
|
||||
# Resume settings
|
||||
resume: bool = False
|
||||
resume_from: str | None = None # Path to checkpoint
|
||||
|
||||
# Document-specific augmentation (optimized for invoices)
|
||||
augmentation: dict[str, Any] = field(default_factory=lambda: {
|
||||
"degrees": 5.0,
|
||||
"translate": 0.05,
|
||||
"scale": 0.2,
|
||||
"shear": 0.0,
|
||||
"perspective": 0.0,
|
||||
"flipud": 0.0,
|
||||
"fliplr": 0.0,
|
||||
"mosaic": 0.0,
|
||||
"mixup": 0.0,
|
||||
"hsv_h": 0.0,
|
||||
"hsv_s": 0.1,
|
||||
"hsv_v": 0.2,
|
||||
})
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingResult:
|
||||
"""Training result."""
|
||||
|
||||
success: bool
|
||||
model_path: str | None = None
|
||||
metrics: dict[str, float] = field(default_factory=dict)
|
||||
error: str | None = None
|
||||
save_dir: str | None = None
|
||||
|
||||
|
||||
class YOLOTrainer:
    """Unified YOLO trainer for CLI and Web API."""

    def __init__(
        self,
        config: TrainingConfig,
        log_callback: Callable[[str, str], None] | None = None,
    ):
        """
        Initialize trainer.

        Args:
            config: Training configuration
            log_callback: Optional callback for logging (level, message)
        """
        self.config = config
        self._log_callback = log_callback

    def _log(self, level: str, message: str) -> None:
        """Forward a message to the callback (if any) and the module logger."""
        if self._log_callback:
            self._log_callback(level, message)
        if level == "INFO":
            logger.info(message)
        elif level == "ERROR":
            logger.error(message)
        elif level == "WARNING":
            logger.warning(message)

    def validate_config(self) -> tuple[bool, str | None]:
        """
        Validate training configuration.

        Returns:
            Tuple of (is_valid, error_message)
        """
        # Check model path. Names starting with "yolo" (e.g. "yolo11n.pt")
        # refer to base models that ultralytics downloads on demand, so they
        # are valid even when no local file exists. (Previously any ".pt"
        # path was required to exist, which rejected the default base model.)
        model_path = Path(self.config.model_path)
        is_base_model = model_path.name.startswith("yolo")
        if model_path.suffix != ".pt" and not is_base_model:
            return False, f"Invalid model: {self.config.model_path}"
        if not model_path.exists() and not is_base_model:
            return False, f"Model file not found: {self.config.model_path}"

        # Check data.yaml
        if not self.config.data_yaml:
            return False, "data_yaml is required"
        data_yaml = Path(self.config.data_yaml)
        if not data_yaml.exists():
            return False, f"data.yaml not found: {self.config.data_yaml}"

        return True, None

    def train(self) -> TrainingResult:
        """
        Run YOLO training.

        Returns:
            TrainingResult with model path and metrics
        """
        # Import lazily so the module can be used (e.g. for validation of
        # configs) without ultralytics installed.
        try:
            from ultralytics import YOLO
        except ImportError:
            return TrainingResult(
                success=False,
                error="Ultralytics (YOLO) not installed. Install with: pip install ultralytics",
            )

        # Validate config
        is_valid, error = self.validate_config()
        if not is_valid:
            return TrainingResult(success=False, error=error)

        self._log("INFO", "Starting YOLO training")
        self._log("INFO", f"  Model: {self.config.model_path}")
        self._log("INFO", f"  Data: {self.config.data_yaml}")
        self._log("INFO", f"  Epochs: {self.config.epochs}")
        self._log("INFO", f"  Batch size: {self.config.batch_size}")
        self._log("INFO", f"  Image size: {self.config.image_size}")

        try:
            # Load model: resume from a checkpoint when requested and the
            # checkpoint exists, otherwise fall back to the configured model.
            if self.config.resume and self.config.resume_from:
                resume_path = Path(self.config.resume_from)
                if resume_path.exists():
                    self._log("INFO", f"Resuming from: {resume_path}")
                    model = YOLO(str(resume_path))
                else:
                    model = YOLO(self.config.model_path)
            else:
                model = YOLO(self.config.model_path)

            # Build training arguments
            train_args = {
                "data": str(Path(self.config.data_yaml).absolute()),
                "epochs": self.config.epochs,
                "batch": self.config.batch_size,
                "imgsz": self.config.image_size,
                "lr0": self.config.learning_rate,
                "device": self.config.device,
                "project": self.config.project,
                "name": self.config.name,
                "exist_ok": True,
                "pretrained": True,
                "verbose": True,
                "workers": self.config.workers,
                "cache": self.config.cache,
                "resume": self.config.resume and self.config.resume_from is not None,
            }

            # Add augmentation settings
            train_args.update(self.config.augmentation)

            # Train
            results = model.train(**train_args)

            # Get best model path
            best_model = Path(results.save_dir) / "weights" / "best.pt"

            # Extract metrics
            metrics = {}
            if hasattr(results, "results_dict"):
                metrics = {
                    "mAP50": results.results_dict.get("metrics/mAP50(B)", 0),
                    "mAP50-95": results.results_dict.get("metrics/mAP50-95(B)", 0),
                    "precision": results.results_dict.get("metrics/precision(B)", 0),
                    "recall": results.results_dict.get("metrics/recall(B)", 0),
                }

            self._log("INFO", "Training completed successfully")
            self._log("INFO", f"  Best model: {best_model}")
            self._log("INFO", f"  mAP@0.5: {metrics.get('mAP50', 'N/A')}")

            return TrainingResult(
                success=True,
                model_path=str(best_model) if best_model.exists() else None,
                metrics=metrics,
                save_dir=str(results.save_dir),
            )

        except Exception as e:
            self._log("ERROR", f"Training failed: {e}")
            return TrainingResult(success=False, error=str(e))

    def validate(self, split: str = "val") -> dict[str, float]:
        """
        Run validation on trained model.

        Args:
            split: Dataset split to validate on ("val" or "test")

        Returns:
            Validation metrics (empty dict on failure).
        """
        try:
            from ultralytics import YOLO

            # NOTE(review): this loads config.model_path, which is the base
            # model unless the caller points it at the trained weights first
            # — confirm callers update model_path after train().
            model = YOLO(self.config.model_path)
            metrics = model.val(data=self.config.data_yaml, split=split)

            return {
                "mAP50": metrics.box.map50,
                "mAP50-95": metrics.box.map,
                "precision": metrics.box.mp,
                "recall": metrics.box.mr,
            }
        except Exception as e:
            self._log("ERROR", f"Validation failed: {e}")
            return {}
|
||||
Reference in New Issue
Block a user