111 lines
3.1 KiB
Python
111 lines
3.1 KiB
Python
"""
|
|
GPU Worker Pool for OCR processing.
|
|
|
|
This pool handles GPU-bound tasks like PaddleOCR for scanned PDF processing.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
from typing import Any, Callable, Optional
|
|
|
|
from training.processing.worker_pool import WorkerPool
|
|
|
|
# Per-process OCR state for GPU workers. Both names start out "empty" and are
# rebound by the worker initializer, once per worker process.
_gpu_initialized: bool = False
_ocr_instance: Optional[Any] = None

# Module-scoped logger shared by the initializer and the pool.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _init_gpu_worker(gpu_id: int = 0) -> None:
    """
    Initialize a GPU worker process with PaddleOCR.

    Restricts the process to a single GPU via ``CUDA_VISIBLE_DEVICES``,
    creates one ``PaddleOCR`` instance and stores it in the module-level
    ``_ocr_instance`` (also setting ``_gpu_initialized``), so the instance is
    built once per process rather than once per task.

    Args:
        gpu_id: GPU device ID to use (the physical device index written to
            ``CUDA_VISIBLE_DEVICES``).

    Raises:
        Exception: Any error raised while importing paddle / paddleocr or
            while constructing the OCR engine is logged with its traceback
            and re-raised unchanged.
    """
    global _ocr_instance, _gpu_initialized

    import warnings

    # Set GPU device before importing paddle so that only the requested
    # device is visible to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)

    # Reduce logging noise
    os.environ["GLOG_minloglevel"] = "2"

    # Suppress PaddleX warnings
    warnings.filterwarnings("ignore", message=".*PDX has already been initialized.*")
    warnings.filterwarnings("ignore", message=".*reinitialization.*")

    try:
        # Import PaddleOCR after setting environment
        # PaddleOCR 3.x uses paddle.set_device() for GPU control, not use_gpu param
        import paddle

        # CUDA_VISIBLE_DEVICES was narrowed to exactly one device above, and
        # visible devices are re-enumerated starting at 0. The selected GPU is
        # therefore always "gpu:0" inside this process; using the physical
        # index here would point at a non-existent device for gpu_id != 0.
        paddle.set_device("gpu:0")

        from paddleocr import PaddleOCR

        # PaddleOCR 3.x init - minimal params, GPU controlled via paddle.set_device
        _ocr_instance = PaddleOCR(lang="en")
        _gpu_initialized = True
        logger.info(
            "GPU worker initialized on GPU %s in process %s", gpu_id, os.getpid()
        )
    except Exception as exc:
        # logger.exception keeps the traceback, which the worker process would
        # otherwise lose when the pool surfaces the failure.
        logger.exception("Failed to initialize GPU worker: %s", exc)
        raise
|
|
|
|
|
|
def get_ocr_instance() -> Any:
    """
    Return the OCR engine held by the current worker process.

    Returns:
        The object stored in the module-level ``_ocr_instance``.

    Raises:
        RuntimeError: If ``_ocr_instance`` is still ``None``, i.e. no OCR
            engine has been stored for this process yet.
    """
    # Read-only access to the module global; no ``global`` statement needed.
    engine = _ocr_instance
    if engine is not None:
        return engine
    raise RuntimeError("OCR not initialized. This function must be called from a GPU worker.")
|
|
|
|
|
|
class GPUWorkerPool(WorkerPool):
    """
    Worker pool dedicated to GPU-bound OCR work.

    Used for scanned PDF processing with PaddleOCR on a GPU. The pool is
    typically kept at a single worker to avoid GPU memory conflicts.

    Example:
        with GPUWorkerPool(max_workers=1, gpu_id=0) as pool:
            future = pool.submit(process_scanned_pdf, pdf_path)
            result = future.result()
    """

    def __init__(self, max_workers: int = 1, gpu_id: int = 0) -> None:
        """
        Create the pool in GPU mode.

        Args:
            max_workers: Number of GPU worker processes.
                Defaults to 1 to avoid GPU memory conflicts.
            gpu_id: GPU device ID to use.
        """
        # GPU mode is always on for this subclass; everything else is
        # forwarded to the base pool as given.
        super().__init__(
            max_workers=max_workers,
            use_gpu=True,
            gpu_id=gpu_id,
        )

    def get_initializer(self) -> Optional[Callable[..., None]]:
        """Return the callable used to initialize each GPU worker."""
        initializer = _init_gpu_worker
        return initializer

    def get_init_args(self) -> tuple:
        """Return the positional args for the GPU initializer (the GPU id)."""
        # NOTE(review): self.gpu_id is presumably set by WorkerPool.__init__;
        # it is not assigned anywhere in this class.
        device_id = self.gpu_id
        return (device_id,)
|