restructure project
This commit is contained in:
0
packages/inference/inference/cli/__init__.py
Normal file
0
packages/inference/inference/cli/__init__.py
Normal file
141
packages/inference/inference/cli/infer.py
Normal file
141
packages/inference/inference/cli/infer.py
Normal file
@@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Inference CLI
|
||||
|
||||
Runs inference on new PDFs to extract invoice data.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Extract invoice data from PDFs using trained model'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--model', '-m',
|
||||
required=True,
|
||||
help='Path to trained YOLO model (.pt file)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--input', '-i',
|
||||
required=True,
|
||||
help='Input PDF file or directory'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output', '-o',
|
||||
help='Output JSON file (default: stdout)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--confidence',
|
||||
type=float,
|
||||
default=0.5,
|
||||
help='Detection confidence threshold (default: 0.5)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dpi',
|
||||
type=int,
|
||||
default=DEFAULT_DPI,
|
||||
help=f'DPI for PDF rendering (default: {DEFAULT_DPI}, must match training)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no-fallback',
|
||||
action='store_true',
|
||||
help='Disable fallback OCR'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--lang',
|
||||
default='en',
|
||||
help='OCR language (default: en)'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--gpu',
|
||||
action='store_true',
|
||||
help='Use GPU'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--verbose', '-v',
|
||||
action='store_true',
|
||||
help='Verbose output'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Validate model
|
||||
model_path = Path(args.model)
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model not found: {model_path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Get input files
|
||||
input_path = Path(args.input)
|
||||
if input_path.is_file():
|
||||
pdf_files = [input_path]
|
||||
elif input_path.is_dir():
|
||||
pdf_files = list(input_path.glob('*.pdf'))
|
||||
else:
|
||||
print(f"Error: Input not found: {input_path}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if not pdf_files:
|
||||
print("Error: No PDF files found", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
if args.verbose:
|
||||
print(f"Processing {len(pdf_files)} PDF file(s)")
|
||||
print(f"Model: {model_path}")
|
||||
|
||||
from inference.pipeline import InferencePipeline
|
||||
|
||||
# Initialize pipeline
|
||||
pipeline = InferencePipeline(
|
||||
model_path=model_path,
|
||||
confidence_threshold=args.confidence,
|
||||
ocr_lang=args.lang,
|
||||
use_gpu=args.gpu,
|
||||
dpi=args.dpi,
|
||||
enable_fallback=not args.no_fallback
|
||||
)
|
||||
|
||||
# Process files
|
||||
results = []
|
||||
|
||||
for pdf_path in pdf_files:
|
||||
if args.verbose:
|
||||
print(f"Processing: {pdf_path.name}")
|
||||
|
||||
result = pipeline.process_pdf(pdf_path)
|
||||
results.append(result.to_json())
|
||||
|
||||
if args.verbose:
|
||||
print(f" Success: {result.success}")
|
||||
print(f" Fields: {len(result.fields)}")
|
||||
if result.fallback_used:
|
||||
print(f" Fallback used: Yes")
|
||||
if result.errors:
|
||||
print(f" Errors: {result.errors}")
|
||||
|
||||
# Output results
|
||||
if len(results) == 1:
|
||||
output = results[0]
|
||||
else:
|
||||
output = results
|
||||
|
||||
json_output = json.dumps(output, indent=2, ensure_ascii=False)
|
||||
|
||||
if args.output:
|
||||
with open(args.output, 'w', encoding='utf-8') as f:
|
||||
f.write(json_output)
|
||||
if args.verbose:
|
||||
print(f"\nResults written to: {args.output}")
|
||||
else:
|
||||
print(json_output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
159
packages/inference/inference/cli/serve.py
Normal file
159
packages/inference/inference/cli/serve.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
Web Server CLI
|
||||
|
||||
Command-line interface for starting the web server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
|
||||
|
||||
def setup_logging(debug: bool = False) -> None:
|
||||
"""Configure logging."""
|
||||
level = logging.DEBUG if debug else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s | %(levelname)-8s | %(name)s | %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""Parse command-line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Start the Invoice Field Extraction web server",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--host",
|
||||
type=str,
|
||||
default="0.0.0.0",
|
||||
help="Host to bind to",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Port to listen on",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
type=Path,
|
||||
default=Path("runs/train/invoice_fields/weights/best.pt"),
|
||||
help="Path to YOLO model weights",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--confidence",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Detection confidence threshold",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--dpi",
|
||||
type=int,
|
||||
default=DEFAULT_DPI,
|
||||
help=f"DPI for PDF rendering (default: {DEFAULT_DPI}, must match training DPI)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-gpu",
|
||||
action="store_true",
|
||||
help="Disable GPU acceleration",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--reload",
|
||||
action="store_true",
|
||||
help="Enable auto-reload for development",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of worker processes",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Enable debug mode",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point."""
|
||||
args = parse_args()
|
||||
setup_logging(debug=args.debug)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Validate model path
|
||||
if not args.model.exists():
|
||||
logger.error(f"Model file not found: {args.model}")
|
||||
sys.exit(1)
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("Invoice Field Extraction Web Server")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"Model: {args.model}")
|
||||
logger.info(f"Confidence threshold: {args.confidence}")
|
||||
logger.info(f"GPU enabled: {not args.no_gpu}")
|
||||
logger.info(f"Server: http://{args.host}:{args.port}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Create config
|
||||
from inference.web.config import AppConfig, ModelConfig, ServerConfig, StorageConfig
|
||||
|
||||
config = AppConfig(
|
||||
model=ModelConfig(
|
||||
model_path=args.model,
|
||||
confidence_threshold=args.confidence,
|
||||
use_gpu=not args.no_gpu,
|
||||
dpi=args.dpi,
|
||||
),
|
||||
server=ServerConfig(
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
debug=args.debug,
|
||||
reload=args.reload,
|
||||
workers=args.workers,
|
||||
),
|
||||
storage=StorageConfig(),
|
||||
)
|
||||
|
||||
# Create and run app
|
||||
import uvicorn
|
||||
from inference.web.app import create_app
|
||||
|
||||
app = create_app(config)
|
||||
|
||||
uvicorn.run(
|
||||
app,
|
||||
host=config.server.host,
|
||||
port=config.server.port,
|
||||
reload=config.server.reload,
|
||||
workers=config.server.workers if not config.server.reload else 1,
|
||||
log_level="debug" if config.server.debug else "info",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user