Compare commits
7 Commits
feature/pa
...
f1a7bfe6b7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f1a7bfe6b7 | ||
|
|
0990239e9c | ||
|
|
8723ef4653 | ||
|
|
c2c8f2dd04 | ||
|
|
4c7fc3015c | ||
|
|
183d3503ef | ||
|
|
729d96f59e |
@@ -29,57 +29,38 @@ wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoic
|
||||
- No print() in production - use logging
|
||||
- Run tests: `pytest --cov=src`
|
||||
|
||||
## File Structure
|
||||
## Critical Rules
|
||||
|
||||
```
|
||||
src/
|
||||
├── cli/ # autolabel, train, infer, serve
|
||||
├── pdf/ # extractor, renderer, detector
|
||||
├── ocr/ # PaddleOCR wrapper, machine_code_parser
|
||||
├── inference/ # pipeline, yolo_detector, field_extractor
|
||||
├── normalize/ # Per-field normalizers
|
||||
├── matcher/ # Exact, substring, fuzzy strategies
|
||||
├── processing/ # CPU/GPU pool architecture
|
||||
├── web/ # FastAPI app, routes, services, schemas
|
||||
├── utils/ # validators, text_cleaner, fuzzy_matcher
|
||||
└── data/ # Database operations
|
||||
tests/ # Mirror of src structure
|
||||
runs/train/ # Training outputs
|
||||
```
|
||||
### Code Organization
|
||||
|
||||
## Supported Fields
|
||||
- Many small files over few large files
|
||||
- High cohesion, low coupling
|
||||
- 200-400 lines typical, 800 max per file
|
||||
- Organize by feature/domain, not by type
|
||||
|
||||
| ID | Field | Description |
|
||||
|----|-------|-------------|
|
||||
| 0 | invoice_number | Invoice number |
|
||||
| 1 | invoice_date | Invoice date |
|
||||
| 2 | invoice_due_date | Due date |
|
||||
| 3 | ocr_number | OCR reference (Swedish payment) |
|
||||
| 4 | bankgiro | Bankgiro account |
|
||||
| 5 | plusgiro | Plusgiro account |
|
||||
| 6 | amount | Amount |
|
||||
| 7 | supplier_organisation_number | Supplier org number |
|
||||
| 8 | payment_line | Payment line (machine-readable) |
|
||||
| 9 | customer_number | Customer number |
|
||||
### Code Style
|
||||
|
||||
## Key Patterns
|
||||
- No emojis in code, comments, or documentation
|
||||
- Immutability always - never mutate objects or arrays
|
||||
- No console.log in production code
|
||||
- Proper error handling with try/catch
|
||||
- Input validation with Zod or similar
|
||||
|
||||
### Inference Result
|
||||
### Testing
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class InferenceResult:
|
||||
document_id: str
|
||||
document_type: str # "invoice" or "letter"
|
||||
fields: dict[str, str]
|
||||
confidence: dict[str, float]
|
||||
cross_validation: CrossValidationResult | None
|
||||
processing_time_ms: float
|
||||
```
|
||||
- TDD: Write tests first
|
||||
- 80% minimum coverage
|
||||
- Unit tests for utilities
|
||||
- Integration tests for APIs
|
||||
- E2E tests for critical flows
|
||||
|
||||
### API Schemas
|
||||
### Security
|
||||
|
||||
See `src/web/schemas.py` for request/response models.
|
||||
- No hardcoded secrets
|
||||
- Environment variables for sensitive data
|
||||
- Validate all user inputs
|
||||
- Parameterized queries only
|
||||
- CSRF protection enabled
|
||||
|
||||
## Environment Variables
|
||||
|
||||
@@ -97,47 +78,16 @@ CONFIDENCE_THRESHOLD=0.5
|
||||
SERVER_HOST=0.0.0.0
|
||||
SERVER_PORT=8000
|
||||
```
|
||||
## Available Commands
|
||||
|
||||
## CLI Commands
|
||||
- `/tdd` - Test-driven development workflow
|
||||
- `/plan` - Create implementation plan
|
||||
- `/code-review` - Review code quality
|
||||
- `/build-fix` - Fix build errors
|
||||
|
||||
```bash
|
||||
# Auto-labeling
|
||||
python -m src.cli.autolabel --dual-pool --cpu-workers 3 --gpu-workers 1
|
||||
## Git Workflow
|
||||
|
||||
# Training
|
||||
python -m src.cli.train --model yolo11n.pt --epochs 100 --batch 16 --name invoice_fields
|
||||
|
||||
# Inference
|
||||
python -m src.cli.infer --model runs/train/invoice_fields/weights/best.pt --input invoice.pdf --gpu
|
||||
|
||||
# Web Server
|
||||
python run_server.py --port 8000
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| GET | `/` | Web UI |
|
||||
| GET | `/api/v1/health` | Health check |
|
||||
| POST | `/api/v1/infer` | Process invoice |
|
||||
| GET | `/api/v1/results/{filename}` | Get visualization |
|
||||
|
||||
## Current Status
|
||||
|
||||
- **Tests**: 688 passing
|
||||
- **Coverage**: 37%
|
||||
- **Model**: 93.5% mAP@0.5
|
||||
- **Documents Labeled**: 9,738
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Start server
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py"
|
||||
|
||||
# Run tests
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest"
|
||||
|
||||
# Access UI: http://localhost:8000
|
||||
```
|
||||
- Conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `test:`
|
||||
- Never commit to main directly
|
||||
- PRs require review
|
||||
- All tests must pass before merge
|
||||
|
||||
@@ -107,7 +107,11 @@
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, metrics_mAP, document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {row[0][:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && DB_HOST=192.168.68.31 DB_PORT=5432 DB_NAME=docmaster DB_USER=docmaster DB_PASSWORD=0412220 python -c \"\"\nimport os\nimport psycopg2\n\nhost = os.getenv\\(''DB_HOST''\\)\nport = os.getenv\\(''DB_PORT''\\)\ndbname = os.getenv\\(''DB_NAME''\\)\nuser = os.getenv\\(''DB_USER''\\)\npassword = os.getenv\\(''DB_PASSWORD''\\)\n\nconn = psycopg2.connect\\(host=host, port=port, dbname=dbname, user=user, password=password\\)\ncursor = conn.cursor\\(\\)\n\n# Get all model versions - use double quotes for case-sensitive column names\ncursor.execute\\(''''''\n SELECT version_id, version, name, status, is_active, \\\\\"\"metrics_mAP\\\\\"\", document_count, model_path, created_at\n FROM model_versions\n ORDER BY created_at DESC\n''''''\\)\nprint\\(''Existing model versions:''\\)\nfor row in cursor.fetchall\\(\\):\n print\\(f'' ID: {str\\(row[0]\\)[:8]}...''\\)\n print\\(f'' Version: {row[1]}''\\)\n print\\(f'' Name: {row[2]}''\\)\n print\\(f'' Status: {row[3]}''\\)\n print\\(f'' Active: {row[4]}''\\)\n print\\(f'' mAP: {row[5]}''\\)\n print\\(f'' Docs: {row[6]}''\\)\n print\\(f'' Path: {row[7]}''\\)\n print\\(f'' Created: {row[8]}''\\)\n print\\(\\)\n\ncursor.close\\(\\)\nconn.close\\(\\)\n\"\"\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/shared/fields/test_field_config.py -v 2>&1 | head -100\")",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")"
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m pytest tests/web/core/test_task_interface.py -v 2>&1 | head -60\")",
|
||||
"Skill(tdd)",
|
||||
"Skill(tdd:*)",
|
||||
"Bash(wsl bash -c \"source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python -m training.cli.train --model runs/train/invoice_fields/weights/best.pt --device 0 --epochs 100\")",
|
||||
"Bash(git commit -m \"$\\(cat <<''EOF''\nfeat: add field-specific bbox expansion strategies for YOLO training\n\nImplement center-point based bbox scaling with directional compensation\nto capture field labels that typically appear above or to the left of\nfield values. This improves YOLO training data quality by including\ncontextual information around field values.\n\nKey changes:\n- Add shared.bbox module with ScaleStrategy dataclass and expand_bbox function\n- Define field-specific strategies \\(ocr_number, bankgiro, invoice_date, etc.\\)\n- Support manual_mode for minimal padding \\(no scaling\\)\n- Integrate expand_bbox into AnnotationGenerator\n- Add FIELD_TO_CLASS mapping for field_name to class_name lookup\n- Comprehensive tests with 100% coverage \\(45 tests\\)\n\nCo-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>\nEOF\n\\)\")"
|
||||
],
|
||||
"deny": [],
|
||||
"ask": [],
|
||||
|
||||
314
.opencode/skills/backend-patterns/SKILL.md
Normal file
314
.opencode/skills/backend-patterns/SKILL.md
Normal file
@@ -0,0 +1,314 @@
|
||||
# Backend Development Patterns
|
||||
|
||||
Backend architecture patterns for Python/FastAPI/PostgreSQL applications.
|
||||
|
||||
## API Design
|
||||
|
||||
### RESTful Structure
|
||||
|
||||
```
|
||||
GET /api/v1/documents # List
|
||||
GET /api/v1/documents/{id} # Get
|
||||
POST /api/v1/documents # Create
|
||||
PUT /api/v1/documents/{id} # Replace
|
||||
PATCH /api/v1/documents/{id} # Update
|
||||
DELETE /api/v1/documents/{id} # Delete
|
||||
|
||||
GET /api/v1/documents?status=processed&sort=created_at&limit=20&offset=0
|
||||
```
|
||||
|
||||
### FastAPI Route Pattern
|
||||
|
||||
```python
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/v1", tags=["inference"])
|
||||
|
||||
@router.post("/infer", response_model=ApiResponse[InferenceResult])
|
||||
async def infer_document(
|
||||
file: UploadFile = File(...),
|
||||
confidence_threshold: float = Query(0.5, ge=0, le=1),
|
||||
service: InferenceService = Depends(get_inference_service)
|
||||
) -> ApiResponse[InferenceResult]:
|
||||
result = await service.process(file, confidence_threshold)
|
||||
return ApiResponse(success=True, data=result)
|
||||
```
|
||||
|
||||
### Consistent Response Schema
|
||||
|
||||
```python
|
||||
from typing import Generic, TypeVar
|
||||
T = TypeVar('T')
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
success: bool
|
||||
data: T | None = None
|
||||
error: str | None = None
|
||||
meta: dict | None = None
|
||||
```
|
||||
|
||||
## Core Patterns
|
||||
|
||||
### Repository Pattern
|
||||
|
||||
```python
|
||||
from typing import Protocol
|
||||
|
||||
class DocumentRepository(Protocol):
|
||||
def find_all(self, filters: dict | None = None) -> list[Document]: ...
|
||||
def find_by_id(self, id: str) -> Document | None: ...
|
||||
def create(self, data: dict) -> Document: ...
|
||||
def update(self, id: str, data: dict) -> Document: ...
|
||||
def delete(self, id: str) -> None: ...
|
||||
```
|
||||
|
||||
### Service Layer
|
||||
|
||||
```python
|
||||
class InferenceService:
|
||||
def __init__(self, model_path: str, use_gpu: bool = True):
|
||||
self.pipeline = InferencePipeline(model_path=model_path, use_gpu=use_gpu)
|
||||
|
||||
async def process(self, file: UploadFile, confidence_threshold: float) -> InferenceResult:
|
||||
temp_path = self._save_temp_file(file)
|
||||
try:
|
||||
return self.pipeline.process_pdf(temp_path)
|
||||
finally:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
```
|
||||
|
||||
### Dependency Injection
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
db_host: str = "localhost"
|
||||
db_password: str
|
||||
model_path: str = "runs/train/invoice_fields/weights/best.pt"
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
@lru_cache()
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
|
||||
def get_inference_service(settings: Settings = Depends(get_settings)) -> InferenceService:
|
||||
return InferenceService(model_path=settings.model_path)
|
||||
```
|
||||
|
||||
## Database Patterns
|
||||
|
||||
### Connection Pooling
|
||||
|
||||
```python
|
||||
from psycopg2 import pool
|
||||
from contextlib import contextmanager
|
||||
|
||||
db_pool = pool.ThreadedConnectionPool(minconn=2, maxconn=10, **db_config)
|
||||
|
||||
@contextmanager
|
||||
def get_db_connection():
|
||||
conn = db_pool.getconn()
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
db_pool.putconn(conn)
|
||||
```
|
||||
|
||||
### Query Optimization
|
||||
|
||||
```python
|
||||
# GOOD: Select only needed columns
|
||||
cur.execute("""
|
||||
SELECT id, status, fields->>'InvoiceNumber' as invoice_number
|
||||
FROM documents WHERE status = %s
|
||||
ORDER BY created_at DESC LIMIT %s
|
||||
""", ('processed', 10))
|
||||
|
||||
# BAD: SELECT * FROM documents
|
||||
```
|
||||
|
||||
### N+1 Prevention
|
||||
|
||||
```python
|
||||
# BAD: N+1 queries
|
||||
for doc in documents:
|
||||
doc.labels = get_labels(doc.id) # N queries
|
||||
|
||||
# GOOD: Batch fetch with JOIN
|
||||
cur.execute("""
|
||||
SELECT d.id, d.status, array_agg(l.label) as labels
|
||||
FROM documents d
|
||||
LEFT JOIN document_labels l ON d.id = l.document_id
|
||||
GROUP BY d.id, d.status
|
||||
""")
|
||||
```
|
||||
|
||||
### Transaction Pattern
|
||||
|
||||
```python
|
||||
def create_document_with_labels(doc_data: dict, labels: list[dict]) -> str:
|
||||
with get_db_connection() as conn:
|
||||
try:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("INSERT INTO documents ... RETURNING id", ...)
|
||||
doc_id = cur.fetchone()[0]
|
||||
for label in labels:
|
||||
cur.execute("INSERT INTO document_labels ...", ...)
|
||||
conn.commit()
|
||||
return doc_id
|
||||
except Exception:
|
||||
conn.rollback()
|
||||
raise
|
||||
```
|
||||
|
||||
## Caching
|
||||
|
||||
```python
|
||||
from cachetools import TTLCache
|
||||
|
||||
_cache = TTLCache(maxsize=1000, ttl=300)
|
||||
|
||||
def get_document_cached(doc_id: str) -> Document | None:
|
||||
if doc_id in _cache:
|
||||
return _cache[doc_id]
|
||||
doc = repo.find_by_id(doc_id)
|
||||
if doc:
|
||||
_cache[doc_id] = doc
|
||||
return doc
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
### Exception Hierarchy
|
||||
|
||||
```python
|
||||
class AppError(Exception):
|
||||
def __init__(self, message: str, status_code: int = 500):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
|
||||
class NotFoundError(AppError):
|
||||
def __init__(self, resource: str, id: str):
|
||||
super().__init__(f"{resource} not found: {id}", 404)
|
||||
|
||||
class ValidationError(AppError):
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message, 400)
|
||||
```
|
||||
|
||||
### FastAPI Exception Handler
|
||||
|
||||
```python
|
||||
@app.exception_handler(AppError)
|
||||
async def app_error_handler(request: Request, exc: AppError):
|
||||
return JSONResponse(status_code=exc.status_code, content={"success": False, "error": exc.message})
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def generic_error_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Unexpected error: {exc}", exc_info=True)
|
||||
return JSONResponse(status_code=500, content={"success": False, "error": "Internal server error"})
|
||||
```
|
||||
|
||||
### Retry with Backoff
|
||||
|
||||
```python
|
||||
async def retry_with_backoff(fn, max_retries: int = 3, base_delay: float = 1.0):
|
||||
last_error = None
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await fn() if asyncio.iscoroutinefunction(fn) else fn()
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
await asyncio.sleep(base_delay * (2 ** attempt))
|
||||
raise last_error
|
||||
```
|
||||
|
||||
## Rate Limiting
|
||||
|
||||
```python
|
||||
from time import time
|
||||
from collections import defaultdict
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self):
|
||||
self.requests: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
def check_limit(self, identifier: str, max_requests: int, window_sec: int) -> bool:
|
||||
now = time()
|
||||
self.requests[identifier] = [t for t in self.requests[identifier] if now - t < window_sec]
|
||||
if len(self.requests[identifier]) >= max_requests:
|
||||
return False
|
||||
self.requests[identifier].append(now)
|
||||
return True
|
||||
|
||||
limiter = RateLimiter()
|
||||
|
||||
@app.middleware("http")
|
||||
async def rate_limit_middleware(request: Request, call_next):
|
||||
ip = request.client.host
|
||||
if not limiter.check_limit(ip, max_requests=100, window_sec=60):
|
||||
return JSONResponse(status_code=429, content={"error": "Rate limit exceeded"})
|
||||
return await call_next(request)
|
||||
```
|
||||
|
||||
## Logging & Middleware
|
||||
|
||||
### Request Logging
|
||||
|
||||
```python
|
||||
@app.middleware("http")
|
||||
async def log_requests(request: Request, call_next):
|
||||
request_id = str(uuid.uuid4())[:8]
|
||||
start_time = time.time()
|
||||
logger.info(f"[{request_id}] {request.method} {request.url.path}")
|
||||
response = await call_next(request)
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
logger.info(f"[{request_id}] Completed {response.status_code} in {duration_ms:.2f}ms")
|
||||
return response
|
||||
```
|
||||
|
||||
### Structured Logging
|
||||
|
||||
```python
|
||||
class JSONFormatter(logging.Formatter):
|
||||
def format(self, record):
|
||||
return json.dumps({
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"level": record.levelname,
|
||||
"message": record.getMessage(),
|
||||
"module": record.module,
|
||||
})
|
||||
```
|
||||
|
||||
## Background Tasks
|
||||
|
||||
```python
|
||||
from fastapi import BackgroundTasks
|
||||
|
||||
def send_notification(document_id: str, status: str):
|
||||
logger.info(f"Notification: {document_id} -> {status}")
|
||||
|
||||
@router.post("/infer")
|
||||
async def infer(file: UploadFile, background_tasks: BackgroundTasks):
|
||||
result = await process_document(file)
|
||||
background_tasks.add_task(send_notification, result.document_id, "completed")
|
||||
return result
|
||||
```
|
||||
|
||||
## Key Principles
|
||||
|
||||
- Repository pattern: Abstract data access
|
||||
- Service layer: Business logic separated from routes
|
||||
- Dependency injection via `Depends()`
|
||||
- Connection pooling for database
|
||||
- Parameterized queries only (no f-strings in SQL)
|
||||
- Batch fetch to prevent N+1
|
||||
- Consistent `ApiResponse[T]` format
|
||||
- Exception hierarchy with proper status codes
|
||||
- Rate limit by IP
|
||||
- Structured logging with request ID
|
||||
665
.opencode/skills/coding-standards/SKILL.md
Normal file
665
.opencode/skills/coding-standards/SKILL.md
Normal file
@@ -0,0 +1,665 @@
|
||||
---
|
||||
name: coding-standards
|
||||
description: Universal coding standards, best practices, and patterns for Python, FastAPI, and data processing development.
|
||||
---
|
||||
|
||||
# Coding Standards & Best Practices
|
||||
|
||||
Python coding standards for the Invoice Master project.
|
||||
|
||||
## Code Quality Principles
|
||||
|
||||
### 1. Readability First
|
||||
- Code is read more than written
|
||||
- Clear variable and function names
|
||||
- Self-documenting code preferred over comments
|
||||
- Consistent formatting (follow PEP 8)
|
||||
|
||||
### 2. KISS (Keep It Simple, Stupid)
|
||||
- Simplest solution that works
|
||||
- Avoid over-engineering
|
||||
- No premature optimization
|
||||
- Easy to understand > clever code
|
||||
|
||||
### 3. DRY (Don't Repeat Yourself)
|
||||
- Extract common logic into functions
|
||||
- Create reusable utilities
|
||||
- Share modules across the codebase
|
||||
- Avoid copy-paste programming
|
||||
|
||||
### 4. YAGNI (You Aren't Gonna Need It)
|
||||
- Don't build features before they're needed
|
||||
- Avoid speculative generality
|
||||
- Add complexity only when required
|
||||
- Start simple, refactor when needed
|
||||
|
||||
## Python Standards
|
||||
|
||||
### Variable Naming
|
||||
|
||||
```python
|
||||
# GOOD: Descriptive names
|
||||
invoice_number = "INV-2024-001"
|
||||
is_valid_document = True
|
||||
total_confidence_score = 0.95
|
||||
|
||||
# BAD: Unclear names
|
||||
inv = "INV-2024-001"
|
||||
flag = True
|
||||
x = 0.95
|
||||
```
|
||||
|
||||
### Function Naming
|
||||
|
||||
```python
|
||||
# GOOD: Verb-noun pattern with type hints
|
||||
def extract_invoice_fields(pdf_path: Path) -> dict[str, str]:
|
||||
"""Extract fields from invoice PDF."""
|
||||
...
|
||||
|
||||
def calculate_confidence(predictions: list[float]) -> float:
|
||||
"""Calculate average confidence score."""
|
||||
...
|
||||
|
||||
def is_valid_bankgiro(value: str) -> bool:
|
||||
"""Check if value is valid Bankgiro number."""
|
||||
...
|
||||
|
||||
# BAD: Unclear or noun-only
|
||||
def invoice(path):
|
||||
...
|
||||
|
||||
def confidence(p):
|
||||
...
|
||||
|
||||
def bankgiro(v):
|
||||
...
|
||||
```
|
||||
|
||||
### Type Hints (REQUIRED)
|
||||
|
||||
```python
|
||||
# GOOD: Full type annotations
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class InferenceResult:
|
||||
document_id: str
|
||||
fields: dict[str, str]
|
||||
confidence: dict[str, float]
|
||||
processing_time_ms: float
|
||||
|
||||
def process_document(
|
||||
pdf_path: Path,
|
||||
confidence_threshold: float = 0.5
|
||||
) -> InferenceResult:
|
||||
"""Process PDF and return extracted fields."""
|
||||
...
|
||||
|
||||
# BAD: No type hints
|
||||
def process_document(pdf_path, confidence_threshold=0.5):
|
||||
...
|
||||
```
|
||||
|
||||
### Immutability Pattern (CRITICAL)
|
||||
|
||||
```python
|
||||
# GOOD: Create new objects, don't mutate
|
||||
def update_fields(fields: dict[str, str], updates: dict[str, str]) -> dict[str, str]:
|
||||
return {**fields, **updates}
|
||||
|
||||
def add_item(items: list[str], new_item: str) -> list[str]:
|
||||
return [*items, new_item]
|
||||
|
||||
# BAD: Direct mutation
|
||||
def update_fields(fields: dict[str, str], updates: dict[str, str]) -> dict[str, str]:
|
||||
fields.update(updates) # MUTATION!
|
||||
return fields
|
||||
|
||||
def add_item(items: list[str], new_item: str) -> list[str]:
|
||||
items.append(new_item) # MUTATION!
|
||||
return items
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# GOOD: Comprehensive error handling with logging
|
||||
def load_model(model_path: Path) -> Model:
|
||||
"""Load YOLO model from path."""
|
||||
try:
|
||||
if not model_path.exists():
|
||||
raise FileNotFoundError(f"Model not found: {model_path}")
|
||||
|
||||
model = YOLO(str(model_path))
|
||||
logger.info(f"Model loaded: {model_path}")
|
||||
return model
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}")
|
||||
raise RuntimeError(f"Model loading failed: {model_path}") from e
|
||||
|
||||
# BAD: No error handling
|
||||
def load_model(model_path):
|
||||
return YOLO(str(model_path))
|
||||
|
||||
# BAD: Bare except
|
||||
def load_model(model_path):
|
||||
try:
|
||||
return YOLO(str(model_path))
|
||||
except: # Never use bare except!
|
||||
return None
|
||||
```
|
||||
|
||||
### Async Best Practices
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
|
||||
# GOOD: Parallel execution when possible
|
||||
async def process_batch(pdf_paths: list[Path]) -> list[InferenceResult]:
|
||||
tasks = [process_document(path) for path in pdf_paths]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle exceptions
|
||||
valid_results = []
|
||||
for path, result in zip(pdf_paths, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(f"Failed to process {path}: {result}")
|
||||
else:
|
||||
valid_results.append(result)
|
||||
return valid_results
|
||||
|
||||
# BAD: Sequential when unnecessary
|
||||
async def process_batch(pdf_paths: list[Path]) -> list[InferenceResult]:
|
||||
results = []
|
||||
for path in pdf_paths:
|
||||
result = await process_document(path)
|
||||
results.append(result)
|
||||
return results
|
||||
```
|
||||
|
||||
### Context Managers
|
||||
|
||||
```python
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
# GOOD: Proper resource management
|
||||
@contextmanager
|
||||
def temp_pdf_copy(pdf_path: Path):
|
||||
"""Create temporary copy of PDF for processing."""
|
||||
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
|
||||
tmp.write(pdf_path.read_bytes())
|
||||
tmp_path = Path(tmp.name)
|
||||
try:
|
||||
yield tmp_path
|
||||
finally:
|
||||
tmp_path.unlink(missing_ok=True)
|
||||
|
||||
# Usage
|
||||
with temp_pdf_copy(original_pdf) as tmp_pdf:
|
||||
result = process_pdf(tmp_pdf)
|
||||
```
|
||||
|
||||
## FastAPI Best Practices
|
||||
|
||||
### Route Structure
|
||||
|
||||
```python
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/v1", tags=["inference"])
|
||||
|
||||
class InferenceResponse(BaseModel):
|
||||
success: bool
|
||||
document_id: str
|
||||
fields: dict[str, str]
|
||||
confidence: dict[str, float]
|
||||
processing_time_ms: float
|
||||
|
||||
@router.post("/infer", response_model=InferenceResponse)
|
||||
async def infer_document(
|
||||
file: UploadFile = File(...),
|
||||
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0)
|
||||
) -> InferenceResponse:
|
||||
"""Process invoice PDF and extract fields."""
|
||||
if not file.filename.endswith(".pdf"):
|
||||
raise HTTPException(status_code=400, detail="Only PDF files accepted")
|
||||
|
||||
result = await inference_service.process(file, confidence_threshold)
|
||||
return InferenceResponse(
|
||||
success=True,
|
||||
document_id=result.document_id,
|
||||
fields=result.fields,
|
||||
confidence=result.confidence,
|
||||
processing_time_ms=result.processing_time_ms
|
||||
)
|
||||
```
|
||||
|
||||
### Input Validation with Pydantic
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import date
|
||||
import re
|
||||
|
||||
class InvoiceData(BaseModel):
|
||||
invoice_number: str = Field(..., min_length=1, max_length=50)
|
||||
invoice_date: date
|
||||
amount: float = Field(..., gt=0)
|
||||
bankgiro: str | None = None
|
||||
ocr_number: str | None = None
|
||||
|
||||
@field_validator("bankgiro")
|
||||
@classmethod
|
||||
def validate_bankgiro(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
# Bankgiro: 7-8 digits
|
||||
cleaned = re.sub(r"[^0-9]", "", v)
|
||||
if not (7 <= len(cleaned) <= 8):
|
||||
raise ValueError("Bankgiro must be 7-8 digits")
|
||||
return cleaned
|
||||
|
||||
@field_validator("ocr_number")
|
||||
@classmethod
|
||||
def validate_ocr(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
# OCR: 2-25 digits
|
||||
cleaned = re.sub(r"[^0-9]", "", v)
|
||||
if not (2 <= len(cleaned) <= 25):
|
||||
raise ValueError("OCR must be 2-25 digits")
|
||||
return cleaned
|
||||
```
|
||||
|
||||
### Response Format
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
success: bool
|
||||
data: T | None = None
|
||||
error: str | None = None
|
||||
meta: dict | None = None
|
||||
|
||||
# Success response
|
||||
return ApiResponse(
|
||||
success=True,
|
||||
data=result,
|
||||
meta={"processing_time_ms": elapsed_ms}
|
||||
)
|
||||
|
||||
# Error response
|
||||
return ApiResponse(
|
||||
success=False,
|
||||
error="Invalid PDF format"
|
||||
)
|
||||
```
|
||||
|
||||
## File Organization
|
||||
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
src/
|
||||
├── cli/ # Command-line interfaces
|
||||
│ ├── autolabel.py
|
||||
│ ├── train.py
|
||||
│ └── infer.py
|
||||
├── pdf/ # PDF processing
|
||||
│ ├── extractor.py
|
||||
│ └── renderer.py
|
||||
├── ocr/ # OCR processing
|
||||
│ ├── paddle_ocr.py
|
||||
│ └── machine_code_parser.py
|
||||
├── inference/ # Inference pipeline
|
||||
│ ├── pipeline.py
|
||||
│ ├── yolo_detector.py
|
||||
│ └── field_extractor.py
|
||||
├── normalize/ # Field normalization
|
||||
│ ├── base.py
|
||||
│ ├── date_normalizer.py
|
||||
│ └── amount_normalizer.py
|
||||
├── web/ # FastAPI application
|
||||
│ ├── app.py
|
||||
│ ├── routes.py
|
||||
│ ├── services.py
|
||||
│ └── schemas.py
|
||||
└── utils/ # Shared utilities
|
||||
├── validators.py
|
||||
├── text_cleaner.py
|
||||
└── logging.py
|
||||
tests/ # Mirror of src structure
|
||||
├── test_pdf/
|
||||
├── test_ocr/
|
||||
└── test_inference/
|
||||
```
|
||||
|
||||
### File Naming
|
||||
|
||||
```
|
||||
src/ocr/paddle_ocr.py # snake_case for modules
|
||||
src/inference/yolo_detector.py # snake_case for modules
|
||||
tests/test_paddle_ocr.py # test_ prefix for tests
|
||||
config.py # snake_case for config
|
||||
```
|
||||
|
||||
### Module Size Guidelines
|
||||
|
||||
- **Maximum**: 800 lines per file
|
||||
- **Typical**: 200-400 lines per file
|
||||
- **Functions**: Max 50 lines each
|
||||
- Extract utilities when modules grow too large
|
||||
|
||||
## Comments & Documentation
|
||||
|
||||
### When to Comment
|
||||
|
||||
```python
|
||||
# GOOD: Explain WHY, not WHAT
|
||||
# Swedish Bankgiro uses Luhn algorithm with weight [1,2,1,2...]
|
||||
def validate_bankgiro_checksum(bankgiro: str) -> bool:
|
||||
...
|
||||
|
||||
# Payment line format: 7 groups separated by #, checksum at end
|
||||
def parse_payment_line(line: str) -> PaymentLineData:
|
||||
...
|
||||
|
||||
# BAD: Stating the obvious
|
||||
# Increment counter by 1
|
||||
count += 1
|
||||
|
||||
# Set name to user's name
|
||||
name = user.name
|
||||
```
|
||||
|
||||
### Docstrings for Public APIs
|
||||
|
||||
```python
|
||||
def extract_invoice_fields(
|
||||
pdf_path: Path,
|
||||
confidence_threshold: float = 0.5,
|
||||
use_gpu: bool = True
|
||||
) -> InferenceResult:
|
||||
"""Extract structured fields from Swedish invoice PDF.
|
||||
|
||||
Uses YOLOv11 for field detection and PaddleOCR for text extraction.
|
||||
Applies field-specific normalization and validation.
|
||||
|
||||
Args:
|
||||
pdf_path: Path to the invoice PDF file.
|
||||
confidence_threshold: Minimum confidence for field detection (0.0-1.0).
|
||||
use_gpu: Whether to use GPU acceleration.
|
||||
|
||||
Returns:
|
||||
InferenceResult containing extracted fields and confidence scores.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If PDF file doesn't exist.
|
||||
ProcessingError: If OCR or detection fails.
|
||||
|
||||
Example:
|
||||
>>> result = extract_invoice_fields(Path("invoice.pdf"))
|
||||
>>> print(result.fields["invoice_number"])
|
||||
"INV-2024-001"
|
||||
"""
|
||||
...
|
||||
```
|
||||
|
||||
## Performance Best Practices
|
||||
|
||||
### Caching
|
||||
|
||||
```python
|
||||
from functools import lru_cache
|
||||
from cachetools import TTLCache
|
||||
|
||||
# Static data: LRU cache
|
||||
@lru_cache(maxsize=100)
|
||||
def get_field_config(field_name: str) -> FieldConfig:
|
||||
"""Load field configuration (cached)."""
|
||||
return load_config(field_name)
|
||||
|
||||
# Dynamic data: TTL cache
|
||||
_document_cache = TTLCache(maxsize=1000, ttl=300) # 5 minutes
|
||||
|
||||
def get_document_cached(doc_id: str) -> Document | None:
|
||||
if doc_id in _document_cache:
|
||||
return _document_cache[doc_id]
|
||||
|
||||
doc = repo.find_by_id(doc_id)
|
||||
if doc:
|
||||
_document_cache[doc_id] = doc
|
||||
return doc
|
||||
```
|
||||
|
||||
### Database Queries
|
||||
|
||||
```python
|
||||
# GOOD: Select only needed columns
|
||||
cur.execute("""
|
||||
SELECT id, status, fields->>'invoice_number'
|
||||
FROM documents
|
||||
WHERE status = %s
|
||||
LIMIT %s
|
||||
""", ('processed', 10))
|
||||
|
||||
# BAD: Select everything
|
||||
cur.execute("SELECT * FROM documents")
|
||||
|
||||
# GOOD: Batch operations
|
||||
cur.executemany(
|
||||
"INSERT INTO labels (doc_id, field, value) VALUES (%s, %s, %s)",
|
||||
[(doc_id, f, v) for f, v in fields.items()]
|
||||
)
|
||||
|
||||
# BAD: Individual inserts in loop
|
||||
for field, value in fields.items():
|
||||
cur.execute("INSERT INTO labels ...", (doc_id, field, value))
|
||||
```
|
||||
|
||||
### Lazy Loading
|
||||
|
||||
```python
|
||||
class InferencePipeline:
|
||||
def __init__(self, model_path: Path):
|
||||
self.model_path = model_path
|
||||
self._model: YOLO | None = None
|
||||
self._ocr: PaddleOCR | None = None
|
||||
|
||||
@property
|
||||
def model(self) -> YOLO:
|
||||
"""Lazy load YOLO model."""
|
||||
if self._model is None:
|
||||
self._model = YOLO(str(self.model_path))
|
||||
return self._model
|
||||
|
||||
@property
|
||||
def ocr(self) -> PaddleOCR:
|
||||
"""Lazy load PaddleOCR."""
|
||||
if self._ocr is None:
|
||||
self._ocr = PaddleOCR(use_angle_cls=True, lang="latin")
|
||||
return self._ocr
|
||||
```
|
||||
|
||||
## Testing Standards
|
||||
|
||||
### Test Structure (AAA Pattern)
|
||||
|
||||
```python
|
||||
def test_extract_bankgiro_valid():
|
||||
# Arrange
|
||||
text = "Bankgiro: 123-4567"
|
||||
|
||||
# Act
|
||||
result = extract_bankgiro(text)
|
||||
|
||||
# Assert
|
||||
assert result == "1234567"
|
||||
|
||||
def test_extract_bankgiro_invalid_returns_none():
|
||||
# Arrange
|
||||
text = "No bankgiro here"
|
||||
|
||||
# Act
|
||||
result = extract_bankgiro(text)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
```
|
||||
|
||||
### Test Naming
|
||||
|
||||
```python
|
||||
# GOOD: Descriptive test names
|
||||
def test_parse_payment_line_extracts_all_fields(): ...
|
||||
def test_parse_payment_line_handles_missing_checksum(): ...
|
||||
def test_validate_ocr_returns_false_for_invalid_checksum(): ...
|
||||
|
||||
# BAD: Vague test names
|
||||
def test_parse(): ...
|
||||
def test_works(): ...
|
||||
def test_payment_line(): ...
|
||||
```
|
||||
|
||||
### Fixtures
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
@pytest.fixture
|
||||
def sample_invoice_pdf(tmp_path: Path) -> Path:
|
||||
"""Create sample invoice PDF for testing."""
|
||||
pdf_path = tmp_path / "invoice.pdf"
|
||||
# Create test PDF...
|
||||
return pdf_path
|
||||
|
||||
@pytest.fixture
|
||||
def inference_pipeline(sample_model_path: Path) -> InferencePipeline:
|
||||
"""Create inference pipeline with test model."""
|
||||
return InferencePipeline(sample_model_path)
|
||||
|
||||
def test_process_invoice(inference_pipeline, sample_invoice_pdf):
|
||||
result = inference_pipeline.process(sample_invoice_pdf)
|
||||
assert result.fields.get("invoice_number") is not None
|
||||
```
|
||||
|
||||
## Code Smell Detection
|
||||
|
||||
### 1. Long Functions
|
||||
|
||||
```python
|
||||
# BAD: Function > 50 lines
|
||||
def process_document():
|
||||
# 100 lines of code...
|
||||
|
||||
# GOOD: Split into smaller functions
|
||||
def process_document(pdf_path: Path) -> InferenceResult:
|
||||
image = render_pdf(pdf_path)
|
||||
detections = detect_fields(image)
|
||||
ocr_results = extract_text(image, detections)
|
||||
fields = normalize_fields(ocr_results)
|
||||
return build_result(fields)
|
||||
```
|
||||
|
||||
### 2. Deep Nesting
|
||||
|
||||
```python
|
||||
# BAD: 5+ levels of nesting
|
||||
if document:
|
||||
if document.is_valid:
|
||||
if document.has_fields:
|
||||
if field in document.fields:
|
||||
if document.fields[field]:
|
||||
# Do something
|
||||
|
||||
# GOOD: Early returns
|
||||
if not document:
|
||||
return None
|
||||
if not document.is_valid:
|
||||
return None
|
||||
if not document.has_fields:
|
||||
return None
|
||||
if field not in document.fields:
|
||||
return None
|
||||
if not document.fields[field]:
|
||||
return None
|
||||
|
||||
# Do something
|
||||
```
|
||||
|
||||
### 3. Magic Numbers
|
||||
|
||||
```python
|
||||
# BAD: Unexplained numbers
|
||||
if confidence > 0.5:
|
||||
...
|
||||
time.sleep(3)
|
||||
|
||||
# GOOD: Named constants
|
||||
CONFIDENCE_THRESHOLD = 0.5
|
||||
RETRY_DELAY_SECONDS = 3
|
||||
|
||||
if confidence > CONFIDENCE_THRESHOLD:
|
||||
...
|
||||
time.sleep(RETRY_DELAY_SECONDS)
|
||||
```
|
||||
|
||||
### 4. Mutable Default Arguments
|
||||
|
||||
```python
|
||||
# BAD: Mutable default argument
|
||||
def process_fields(fields: list = []): # DANGEROUS!
|
||||
fields.append("new_field")
|
||||
return fields
|
||||
|
||||
# GOOD: Use None as default
|
||||
def process_fields(fields: list | None = None) -> list:
|
||||
if fields is None:
|
||||
fields = []
|
||||
return [*fields, "new_field"]
|
||||
```
|
||||
|
||||
## Logging Standards
|
||||
|
||||
```python
|
||||
import logging
|
||||
|
||||
# Module-level logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# GOOD: Appropriate log levels
|
||||
logger.debug("Processing document: %s", doc_id)
|
||||
logger.info("Document processed successfully: %s", doc_id)
|
||||
logger.warning("Low confidence score: %.2f", confidence)
|
||||
logger.error("Failed to process document: %s", error)
|
||||
|
||||
# GOOD: Structured logging with extra data
|
||||
logger.info(
|
||||
"Inference complete",
|
||||
extra={
|
||||
"document_id": doc_id,
|
||||
"field_count": len(fields),
|
||||
"processing_time_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
|
||||
# BAD: Using print()
|
||||
print(f"Processing {doc_id}") # Never in production!
|
||||
```
|
||||
|
||||
**Remember**: Code quality is not negotiable. Clear, maintainable Python code with proper type hints enables confident development and refactoring.
|
||||
80
.opencode/skills/continuous-learning/SKILL.md
Normal file
80
.opencode/skills/continuous-learning/SKILL.md
Normal file
@@ -0,0 +1,80 @@
|
||||
---
|
||||
name: continuous-learning
|
||||
description: Automatically extract reusable patterns from Claude Code sessions and save them as learned skills for future use.
|
||||
---
|
||||
|
||||
# Continuous Learning Skill
|
||||
|
||||
Automatically evaluates Claude Code sessions on end to extract reusable patterns that can be saved as learned skills.
|
||||
|
||||
## How It Works
|
||||
|
||||
This skill runs as a **Stop hook** at the end of each session:
|
||||
|
||||
1. **Session Evaluation**: Checks if session has enough messages (default: 10+)
|
||||
2. **Pattern Detection**: Identifies extractable patterns from the session
|
||||
3. **Skill Extraction**: Saves useful patterns to `~/.claude/skills/learned/`
|
||||
|
||||
## Configuration
|
||||
|
||||
Edit `config.json` to customize:
|
||||
|
||||
```json
|
||||
{
|
||||
"min_session_length": 10,
|
||||
"extraction_threshold": "medium",
|
||||
"auto_approve": false,
|
||||
"learned_skills_path": "~/.claude/skills/learned/",
|
||||
"patterns_to_detect": [
|
||||
"error_resolution",
|
||||
"user_corrections",
|
||||
"workarounds",
|
||||
"debugging_techniques",
|
||||
"project_specific"
|
||||
],
|
||||
"ignore_patterns": [
|
||||
"simple_typos",
|
||||
"one_time_fixes",
|
||||
"external_api_issues"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Pattern Types
|
||||
|
||||
| Pattern | Description |
|
||||
|---------|-------------|
|
||||
| `error_resolution` | How specific errors were resolved |
|
||||
| `user_corrections` | Patterns from user corrections |
|
||||
| `workarounds` | Solutions to framework/library quirks |
|
||||
| `debugging_techniques` | Effective debugging approaches |
|
||||
| `project_specific` | Project-specific conventions |
|
||||
|
||||
## Hook Setup
|
||||
|
||||
Add to your `~/.claude/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"hooks": {
|
||||
"Stop": [{
|
||||
"matcher": "*",
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": "~/.claude/skills/continuous-learning/evaluate-session.sh"
|
||||
}]
|
||||
}]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Why Stop Hook?
|
||||
|
||||
- **Lightweight**: Runs once at session end
|
||||
- **Non-blocking**: Doesn't add latency to every message
|
||||
- **Complete context**: Has access to full session transcript
|
||||
|
||||
## Related
|
||||
|
||||
- [The Longform Guide](https://x.com/affaanmustafa/status/2014040193557471352) - Section on continuous learning
|
||||
- `/learn` command - Manual pattern extraction mid-session
|
||||
18
.opencode/skills/continuous-learning/config.json
Normal file
18
.opencode/skills/continuous-learning/config.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"min_session_length": 10,
|
||||
"extraction_threshold": "medium",
|
||||
"auto_approve": false,
|
||||
"learned_skills_path": "~/.claude/skills/learned/",
|
||||
"patterns_to_detect": [
|
||||
"error_resolution",
|
||||
"user_corrections",
|
||||
"workarounds",
|
||||
"debugging_techniques",
|
||||
"project_specific"
|
||||
],
|
||||
"ignore_patterns": [
|
||||
"simple_typos",
|
||||
"one_time_fixes",
|
||||
"external_api_issues"
|
||||
]
|
||||
}
|
||||
60
.opencode/skills/continuous-learning/evaluate-session.sh
Normal file
60
.opencode/skills/continuous-learning/evaluate-session.sh
Normal file
@@ -0,0 +1,60 @@
|
||||
#!/bin/bash
|
||||
# Continuous Learning - Session Evaluator
|
||||
# Runs on Stop hook to extract reusable patterns from Claude Code sessions
|
||||
#
|
||||
# Why Stop hook instead of UserPromptSubmit:
|
||||
# - Stop runs once at session end (lightweight)
|
||||
# - UserPromptSubmit runs every message (heavy, adds latency)
|
||||
#
|
||||
# Hook config (in ~/.claude/settings.json):
|
||||
# {
|
||||
# "hooks": {
|
||||
# "Stop": [{
|
||||
# "matcher": "*",
|
||||
# "hooks": [{
|
||||
# "type": "command",
|
||||
# "command": "~/.claude/skills/continuous-learning/evaluate-session.sh"
|
||||
# }]
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# Patterns to detect: error_resolution, debugging_techniques, workarounds, project_specific
|
||||
# Patterns to ignore: simple_typos, one_time_fixes, external_api_issues
|
||||
# Extracted skills saved to: ~/.claude/skills/learned/
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
CONFIG_FILE="$SCRIPT_DIR/config.json"
|
||||
LEARNED_SKILLS_PATH="${HOME}/.claude/skills/learned"
|
||||
MIN_SESSION_LENGTH=10
|
||||
|
||||
# Load config if exists
|
||||
if [ -f "$CONFIG_FILE" ]; then
|
||||
MIN_SESSION_LENGTH=$(jq -r '.min_session_length // 10' "$CONFIG_FILE")
|
||||
LEARNED_SKILLS_PATH=$(jq -r '.learned_skills_path // "~/.claude/skills/learned/"' "$CONFIG_FILE" | sed "s|~|$HOME|")
|
||||
fi
|
||||
|
||||
# Ensure learned skills directory exists
|
||||
mkdir -p "$LEARNED_SKILLS_PATH"
|
||||
|
||||
# Get transcript path from environment (set by Claude Code)
|
||||
transcript_path="${CLAUDE_TRANSCRIPT_PATH:-}"
|
||||
|
||||
if [ -z "$transcript_path" ] || [ ! -f "$transcript_path" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Count messages in session
|
||||
message_count=$(grep -c '"type":"user"' "$transcript_path" 2>/dev/null || echo "0")
|
||||
|
||||
# Skip short sessions
|
||||
if [ "$message_count" -lt "$MIN_SESSION_LENGTH" ]; then
|
||||
echo "[ContinuousLearning] Session too short ($message_count messages), skipping" >&2
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Signal to Claude that session should be evaluated for extractable patterns
|
||||
echo "[ContinuousLearning] Session has $message_count messages - evaluate for extractable patterns" >&2
|
||||
echo "[ContinuousLearning] Save learned skills to: $LEARNED_SKILLS_PATH" >&2
|
||||
221
.opencode/skills/eval-harness/SKILL.md
Normal file
221
.opencode/skills/eval-harness/SKILL.md
Normal file
@@ -0,0 +1,221 @@
|
||||
# Eval Harness Skill
|
||||
|
||||
A formal evaluation framework for Claude Code sessions, implementing eval-driven development (EDD) principles.
|
||||
|
||||
## Philosophy
|
||||
|
||||
Eval-Driven Development treats evals as the "unit tests of AI development":
|
||||
- Define expected behavior BEFORE implementation
|
||||
- Run evals continuously during development
|
||||
- Track regressions with each change
|
||||
- Use pass@k metrics for reliability measurement
|
||||
|
||||
## Eval Types
|
||||
|
||||
### Capability Evals
|
||||
Test if Claude can do something it couldn't before:
|
||||
```markdown
|
||||
[CAPABILITY EVAL: feature-name]
|
||||
Task: Description of what Claude should accomplish
|
||||
Success Criteria:
|
||||
- [ ] Criterion 1
|
||||
- [ ] Criterion 2
|
||||
- [ ] Criterion 3
|
||||
Expected Output: Description of expected result
|
||||
```
|
||||
|
||||
### Regression Evals
|
||||
Ensure changes don't break existing functionality:
|
||||
```markdown
|
||||
[REGRESSION EVAL: feature-name]
|
||||
Baseline: SHA or checkpoint name
|
||||
Tests:
|
||||
- existing-test-1: PASS/FAIL
|
||||
- existing-test-2: PASS/FAIL
|
||||
- existing-test-3: PASS/FAIL
|
||||
Result: X/Y passed (previously Y/Y)
|
||||
```
|
||||
|
||||
## Grader Types
|
||||
|
||||
### 1. Code-Based Grader
|
||||
Deterministic checks using code:
|
||||
```bash
|
||||
# Check if file contains expected pattern
|
||||
grep -q "export function handleAuth" src/auth.ts && echo "PASS" || echo "FAIL"
|
||||
|
||||
# Check if tests pass
|
||||
npm test -- --testPathPattern="auth" && echo "PASS" || echo "FAIL"
|
||||
|
||||
# Check if build succeeds
|
||||
npm run build && echo "PASS" || echo "FAIL"
|
||||
```
|
||||
|
||||
### 2. Model-Based Grader
|
||||
Use Claude to evaluate open-ended outputs:
|
||||
```markdown
|
||||
[MODEL GRADER PROMPT]
|
||||
Evaluate the following code change:
|
||||
1. Does it solve the stated problem?
|
||||
2. Is it well-structured?
|
||||
3. Are edge cases handled?
|
||||
4. Is error handling appropriate?
|
||||
|
||||
Score: 1-5 (1=poor, 5=excellent)
|
||||
Reasoning: [explanation]
|
||||
```
|
||||
|
||||
### 3. Human Grader
|
||||
Flag for manual review:
|
||||
```markdown
|
||||
[HUMAN REVIEW REQUIRED]
|
||||
Change: Description of what changed
|
||||
Reason: Why human review is needed
|
||||
Risk Level: LOW/MEDIUM/HIGH
|
||||
```
|
||||
|
||||
## Metrics
|
||||
|
||||
### pass@k
|
||||
"At least one success in k attempts"
|
||||
- pass@1: First attempt success rate
|
||||
- pass@3: Success within 3 attempts
|
||||
- Typical target: pass@3 > 90%
|
||||
|
||||
### pass^k
|
||||
"All k trials succeed"
|
||||
- Higher bar for reliability
|
||||
- pass^3: 3 consecutive successes
|
||||
- Use for critical paths
|
||||
|
||||
## Eval Workflow
|
||||
|
||||
### 1. Define (Before Coding)
|
||||
```markdown
|
||||
## EVAL DEFINITION: feature-xyz
|
||||
|
||||
### Capability Evals
|
||||
1. Can create new user account
|
||||
2. Can validate email format
|
||||
3. Can hash password securely
|
||||
|
||||
### Regression Evals
|
||||
1. Existing login still works
|
||||
2. Session management unchanged
|
||||
3. Logout flow intact
|
||||
|
||||
### Success Metrics
|
||||
- pass@3 > 90% for capability evals
|
||||
- pass^3 = 100% for regression evals
|
||||
```
|
||||
|
||||
### 2. Implement
|
||||
Write code to pass the defined evals.
|
||||
|
||||
### 3. Evaluate
|
||||
```bash
|
||||
# Run capability evals
|
||||
[Run each capability eval, record PASS/FAIL]
|
||||
|
||||
# Run regression evals
|
||||
npm test -- --testPathPattern="existing"
|
||||
|
||||
# Generate report
|
||||
```
|
||||
|
||||
### 4. Report
|
||||
```markdown
|
||||
EVAL REPORT: feature-xyz
|
||||
========================
|
||||
|
||||
Capability Evals:
|
||||
create-user: PASS (pass@1)
|
||||
validate-email: PASS (pass@2)
|
||||
hash-password: PASS (pass@1)
|
||||
Overall: 3/3 passed
|
||||
|
||||
Regression Evals:
|
||||
login-flow: PASS
|
||||
session-mgmt: PASS
|
||||
logout-flow: PASS
|
||||
Overall: 3/3 passed
|
||||
|
||||
Metrics:
|
||||
pass@1: 67% (2/3)
|
||||
pass@3: 100% (3/3)
|
||||
|
||||
Status: READY FOR REVIEW
|
||||
```
|
||||
|
||||
## Integration Patterns
|
||||
|
||||
### Pre-Implementation
|
||||
```
|
||||
/eval define feature-name
|
||||
```
|
||||
Creates eval definition file at `.claude/evals/feature-name.md`
|
||||
|
||||
### During Implementation
|
||||
```
|
||||
/eval check feature-name
|
||||
```
|
||||
Runs current evals and reports status
|
||||
|
||||
### Post-Implementation
|
||||
```
|
||||
/eval report feature-name
|
||||
```
|
||||
Generates full eval report
|
||||
|
||||
## Eval Storage
|
||||
|
||||
Store evals in project:
|
||||
```
|
||||
.claude/
|
||||
evals/
|
||||
feature-xyz.md # Eval definition
|
||||
feature-xyz.log # Eval run history
|
||||
baseline.json # Regression baselines
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Define evals BEFORE coding** - Forces clear thinking about success criteria
|
||||
2. **Run evals frequently** - Catch regressions early
|
||||
3. **Track pass@k over time** - Monitor reliability trends
|
||||
4. **Use code graders when possible** - Deterministic > probabilistic
|
||||
5. **Human review for security** - Never fully automate security checks
|
||||
6. **Keep evals fast** - Slow evals don't get run
|
||||
7. **Version evals with code** - Evals are first-class artifacts
|
||||
|
||||
## Example: Adding Authentication
|
||||
|
||||
```markdown
|
||||
## EVAL: add-authentication
|
||||
|
||||
### Phase 1: Define (10 min)
|
||||
Capability Evals:
|
||||
- [ ] User can register with email/password
|
||||
- [ ] User can login with valid credentials
|
||||
- [ ] Invalid credentials rejected with proper error
|
||||
- [ ] Sessions persist across page reloads
|
||||
- [ ] Logout clears session
|
||||
|
||||
Regression Evals:
|
||||
- [ ] Public routes still accessible
|
||||
- [ ] API responses unchanged
|
||||
- [ ] Database schema compatible
|
||||
|
||||
### Phase 2: Implement (varies)
|
||||
[Write code]
|
||||
|
||||
### Phase 3: Evaluate
|
||||
Run: /eval check add-authentication
|
||||
|
||||
### Phase 4: Report
|
||||
EVAL REPORT: add-authentication
|
||||
==============================
|
||||
Capability: 5/5 passed (pass@3: 100%)
|
||||
Regression: 3/3 passed (pass^3: 100%)
|
||||
Status: SHIP IT
|
||||
```
|
||||
631
.opencode/skills/frontend-patterns/SKILL.md
Normal file
631
.opencode/skills/frontend-patterns/SKILL.md
Normal file
@@ -0,0 +1,631 @@
|
||||
---
|
||||
name: frontend-patterns
|
||||
description: Frontend development patterns for React, Next.js, state management, performance optimization, and UI best practices.
|
||||
---
|
||||
|
||||
# Frontend Development Patterns
|
||||
|
||||
Modern frontend patterns for React, Next.js, and performant user interfaces.
|
||||
|
||||
## Component Patterns
|
||||
|
||||
### Composition Over Inheritance
|
||||
|
||||
```typescript
|
||||
// ✅ GOOD: Component composition
|
||||
interface CardProps {
|
||||
children: React.ReactNode
|
||||
variant?: 'default' | 'outlined'
|
||||
}
|
||||
|
||||
export function Card({ children, variant = 'default' }: CardProps) {
|
||||
return <div className={`card card-${variant}`}>{children}</div>
|
||||
}
|
||||
|
||||
export function CardHeader({ children }: { children: React.ReactNode }) {
|
||||
return <div className="card-header">{children}</div>
|
||||
}
|
||||
|
||||
export function CardBody({ children }: { children: React.ReactNode }) {
|
||||
return <div className="card-body">{children}</div>
|
||||
}
|
||||
|
||||
// Usage
|
||||
<Card>
|
||||
<CardHeader>Title</CardHeader>
|
||||
<CardBody>Content</CardBody>
|
||||
</Card>
|
||||
```
|
||||
|
||||
### Compound Components
|
||||
|
||||
```typescript
|
||||
interface TabsContextValue {
|
||||
activeTab: string
|
||||
setActiveTab: (tab: string) => void
|
||||
}
|
||||
|
||||
const TabsContext = createContext<TabsContextValue | undefined>(undefined)
|
||||
|
||||
export function Tabs({ children, defaultTab }: {
|
||||
children: React.ReactNode
|
||||
defaultTab: string
|
||||
}) {
|
||||
const [activeTab, setActiveTab] = useState(defaultTab)
|
||||
|
||||
return (
|
||||
<TabsContext.Provider value={{ activeTab, setActiveTab }}>
|
||||
{children}
|
||||
</TabsContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export function TabList({ children }: { children: React.ReactNode }) {
|
||||
return <div className="tab-list">{children}</div>
|
||||
}
|
||||
|
||||
export function Tab({ id, children }: { id: string, children: React.ReactNode }) {
|
||||
const context = useContext(TabsContext)
|
||||
if (!context) throw new Error('Tab must be used within Tabs')
|
||||
|
||||
return (
|
||||
<button
|
||||
className={context.activeTab === id ? 'active' : ''}
|
||||
onClick={() => context.setActiveTab(id)}
|
||||
>
|
||||
{children}
|
||||
</button>
|
||||
)
|
||||
}
|
||||
|
||||
// Usage
|
||||
<Tabs defaultTab="overview">
|
||||
<TabList>
|
||||
<Tab id="overview">Overview</Tab>
|
||||
<Tab id="details">Details</Tab>
|
||||
</TabList>
|
||||
</Tabs>
|
||||
```
|
||||
|
||||
### Render Props Pattern
|
||||
|
||||
```typescript
|
||||
interface DataLoaderProps<T> {
|
||||
url: string
|
||||
children: (data: T | null, loading: boolean, error: Error | null) => React.ReactNode
|
||||
}
|
||||
|
||||
export function DataLoader<T>({ url, children }: DataLoaderProps<T>) {
|
||||
const [data, setData] = useState<T | null>(null)
|
||||
const [loading, setLoading] = useState(true)
|
||||
const [error, setError] = useState<Error | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
fetch(url)
|
||||
.then(res => res.json())
|
||||
.then(setData)
|
||||
.catch(setError)
|
||||
.finally(() => setLoading(false))
|
||||
}, [url])
|
||||
|
||||
return <>{children(data, loading, error)}</>
|
||||
}
|
||||
|
||||
// Usage
|
||||
<DataLoader<Market[]> url="/api/markets">
|
||||
{(markets, loading, error) => {
|
||||
if (loading) return <Spinner />
|
||||
if (error) return <Error error={error} />
|
||||
return <MarketList markets={markets!} />
|
||||
}}
|
||||
</DataLoader>
|
||||
```
|
||||
|
||||
## Custom Hooks Patterns
|
||||
|
||||
### State Management Hook
|
||||
|
||||
```typescript
|
||||
export function useToggle(initialValue = false): [boolean, () => void] {
|
||||
const [value, setValue] = useState(initialValue)
|
||||
|
||||
const toggle = useCallback(() => {
|
||||
setValue(v => !v)
|
||||
}, [])
|
||||
|
||||
return [value, toggle]
|
||||
}
|
||||
|
||||
// Usage
|
||||
const [isOpen, toggleOpen] = useToggle()
|
||||
```
|
||||
|
||||
### Async Data Fetching Hook
|
||||
|
||||
```typescript
|
||||
interface UseQueryOptions<T> {
|
||||
onSuccess?: (data: T) => void
|
||||
onError?: (error: Error) => void
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
export function useQuery<T>(
|
||||
key: string,
|
||||
fetcher: () => Promise<T>,
|
||||
options?: UseQueryOptions<T>
|
||||
) {
|
||||
const [data, setData] = useState<T | null>(null)
|
||||
const [error, setError] = useState<Error | null>(null)
|
||||
const [loading, setLoading] = useState(false)
|
||||
|
||||
const refetch = useCallback(async () => {
|
||||
setLoading(true)
|
||||
setError(null)
|
||||
|
||||
try {
|
||||
const result = await fetcher()
|
||||
setData(result)
|
||||
options?.onSuccess?.(result)
|
||||
} catch (err) {
|
||||
const error = err as Error
|
||||
setError(error)
|
||||
options?.onError?.(error)
|
||||
} finally {
|
||||
setLoading(false)
|
||||
}
|
||||
}, [fetcher, options])
|
||||
|
||||
useEffect(() => {
|
||||
if (options?.enabled !== false) {
|
||||
refetch()
|
||||
}
|
||||
}, [key, refetch, options?.enabled])
|
||||
|
||||
return { data, error, loading, refetch }
|
||||
}
|
||||
|
||||
// Usage
|
||||
const { data: markets, loading, error, refetch } = useQuery(
|
||||
'markets',
|
||||
() => fetch('/api/markets').then(r => r.json()),
|
||||
{
|
||||
onSuccess: data => console.log('Fetched', data.length, 'markets'),
|
||||
onError: err => console.error('Failed:', err)
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
### Debounce Hook
|
||||
|
||||
```typescript
|
||||
export function useDebounce<T>(value: T, delay: number): T {
|
||||
const [debouncedValue, setDebouncedValue] = useState<T>(value)
|
||||
|
||||
useEffect(() => {
|
||||
const handler = setTimeout(() => {
|
||||
setDebouncedValue(value)
|
||||
}, delay)
|
||||
|
||||
return () => clearTimeout(handler)
|
||||
}, [value, delay])
|
||||
|
||||
return debouncedValue
|
||||
}
|
||||
|
||||
// Usage
|
||||
const [searchQuery, setSearchQuery] = useState('')
|
||||
const debouncedQuery = useDebounce(searchQuery, 500)
|
||||
|
||||
useEffect(() => {
|
||||
if (debouncedQuery) {
|
||||
performSearch(debouncedQuery)
|
||||
}
|
||||
}, [debouncedQuery])
|
||||
```
|
||||
|
||||
## State Management Patterns
|
||||
|
||||
### Context + Reducer Pattern
|
||||
|
||||
```typescript
|
||||
interface State {
|
||||
markets: Market[]
|
||||
selectedMarket: Market | null
|
||||
loading: boolean
|
||||
}
|
||||
|
||||
type Action =
|
||||
| { type: 'SET_MARKETS'; payload: Market[] }
|
||||
| { type: 'SELECT_MARKET'; payload: Market }
|
||||
| { type: 'SET_LOADING'; payload: boolean }
|
||||
|
||||
function reducer(state: State, action: Action): State {
|
||||
switch (action.type) {
|
||||
case 'SET_MARKETS':
|
||||
return { ...state, markets: action.payload }
|
||||
case 'SELECT_MARKET':
|
||||
return { ...state, selectedMarket: action.payload }
|
||||
case 'SET_LOADING':
|
||||
return { ...state, loading: action.payload }
|
||||
default:
|
||||
return state
|
||||
}
|
||||
}
|
||||
|
||||
const MarketContext = createContext<{
|
||||
state: State
|
||||
dispatch: Dispatch<Action>
|
||||
} | undefined>(undefined)
|
||||
|
||||
export function MarketProvider({ children }: { children: React.ReactNode }) {
|
||||
const [state, dispatch] = useReducer(reducer, {
|
||||
markets: [],
|
||||
selectedMarket: null,
|
||||
loading: false
|
||||
})
|
||||
|
||||
return (
|
||||
<MarketContext.Provider value={{ state, dispatch }}>
|
||||
{children}
|
||||
</MarketContext.Provider>
|
||||
)
|
||||
}
|
||||
|
||||
export function useMarkets() {
|
||||
const context = useContext(MarketContext)
|
||||
if (!context) throw new Error('useMarkets must be used within MarketProvider')
|
||||
return context
|
||||
}
|
||||
```
|
||||
|
||||
## Performance Optimization
|
||||
|
||||
### Memoization
|
||||
|
||||
```typescript
|
||||
// ✅ useMemo for expensive computations
|
||||
const sortedMarkets = useMemo(() => {
|
||||
return markets.sort((a, b) => b.volume - a.volume)
|
||||
}, [markets])
|
||||
|
||||
// ✅ useCallback for functions passed to children
|
||||
const handleSearch = useCallback((query: string) => {
|
||||
setSearchQuery(query)
|
||||
}, [])
|
||||
|
||||
// ✅ React.memo for pure components
|
||||
export const MarketCard = React.memo<MarketCardProps>(({ market }) => {
|
||||
return (
|
||||
<div className="market-card">
|
||||
<h3>{market.name}</h3>
|
||||
<p>{market.description}</p>
|
||||
</div>
|
||||
)
|
||||
})
|
||||
```
|
||||
|
||||
### Code Splitting & Lazy Loading
|
||||
|
||||
```typescript
|
||||
import { lazy, Suspense } from 'react'
|
||||
|
||||
// ✅ Lazy load heavy components
|
||||
const HeavyChart = lazy(() => import('./HeavyChart'))
|
||||
const ThreeJsBackground = lazy(() => import('./ThreeJsBackground'))
|
||||
|
||||
export function Dashboard() {
|
||||
return (
|
||||
<div>
|
||||
<Suspense fallback={<ChartSkeleton />}>
|
||||
<HeavyChart data={data} />
|
||||
</Suspense>
|
||||
|
||||
<Suspense fallback={null}>
|
||||
<ThreeJsBackground />
|
||||
</Suspense>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Virtualization for Long Lists
|
||||
|
||||
```typescript
|
||||
import { useVirtualizer } from '@tanstack/react-virtual'
|
||||
|
||||
export function VirtualMarketList({ markets }: { markets: Market[] }) {
|
||||
const parentRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
const virtualizer = useVirtualizer({
|
||||
count: markets.length,
|
||||
getScrollElement: () => parentRef.current,
|
||||
estimateSize: () => 100, // Estimated row height
|
||||
overscan: 5 // Extra items to render
|
||||
})
|
||||
|
||||
return (
|
||||
<div ref={parentRef} style={{ height: '600px', overflow: 'auto' }}>
|
||||
<div
|
||||
style={{
|
||||
height: `${virtualizer.getTotalSize()}px`,
|
||||
position: 'relative'
|
||||
}}
|
||||
>
|
||||
{virtualizer.getVirtualItems().map(virtualRow => (
|
||||
<div
|
||||
key={virtualRow.index}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: 0,
|
||||
left: 0,
|
||||
width: '100%',
|
||||
height: `${virtualRow.size}px`,
|
||||
transform: `translateY(${virtualRow.start}px)`
|
||||
}}
|
||||
>
|
||||
<MarketCard market={markets[virtualRow.index]} />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Form Handling Patterns
|
||||
|
||||
### Controlled Form with Validation
|
||||
|
||||
```typescript
|
||||
interface FormData {
|
||||
name: string
|
||||
description: string
|
||||
endDate: string
|
||||
}
|
||||
|
||||
interface FormErrors {
|
||||
name?: string
|
||||
description?: string
|
||||
endDate?: string
|
||||
}
|
||||
|
||||
export function CreateMarketForm() {
|
||||
const [formData, setFormData] = useState<FormData>({
|
||||
name: '',
|
||||
description: '',
|
||||
endDate: ''
|
||||
})
|
||||
|
||||
const [errors, setErrors] = useState<FormErrors>({})
|
||||
|
||||
const validate = (): boolean => {
|
||||
const newErrors: FormErrors = {}
|
||||
|
||||
if (!formData.name.trim()) {
|
||||
newErrors.name = 'Name is required'
|
||||
} else if (formData.name.length > 200) {
|
||||
newErrors.name = 'Name must be under 200 characters'
|
||||
}
|
||||
|
||||
if (!formData.description.trim()) {
|
||||
newErrors.description = 'Description is required'
|
||||
}
|
||||
|
||||
if (!formData.endDate) {
|
||||
newErrors.endDate = 'End date is required'
|
||||
}
|
||||
|
||||
setErrors(newErrors)
|
||||
return Object.keys(newErrors).length === 0
|
||||
}
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault()
|
||||
|
||||
if (!validate()) return
|
||||
|
||||
try {
|
||||
await createMarket(formData)
|
||||
// Success handling
|
||||
} catch (error) {
|
||||
// Error handling
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<form onSubmit={handleSubmit}>
|
||||
<input
|
||||
value={formData.name}
|
||||
onChange={e => setFormData(prev => ({ ...prev, name: e.target.value }))}
|
||||
placeholder="Market name"
|
||||
/>
|
||||
{errors.name && <span className="error">{errors.name}</span>}
|
||||
|
||||
{/* Other fields */}
|
||||
|
||||
<button type="submit">Create Market</button>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Error Boundary Pattern
|
||||
|
||||
```typescript
|
||||
interface ErrorBoundaryState {
|
||||
hasError: boolean
|
||||
error: Error | null
|
||||
}
|
||||
|
||||
export class ErrorBoundary extends React.Component<
|
||||
{ children: React.ReactNode },
|
||||
ErrorBoundaryState
|
||||
> {
|
||||
state: ErrorBoundaryState = {
|
||||
hasError: false,
|
||||
error: null
|
||||
}
|
||||
|
||||
static getDerivedStateFromError(error: Error): ErrorBoundaryState {
|
||||
return { hasError: true, error }
|
||||
}
|
||||
|
||||
componentDidCatch(error: Error, errorInfo: React.ErrorInfo) {
|
||||
console.error('Error boundary caught:', error, errorInfo)
|
||||
}
|
||||
|
||||
render() {
|
||||
if (this.state.hasError) {
|
||||
return (
|
||||
<div className="error-fallback">
|
||||
<h2>Something went wrong</h2>
|
||||
<p>{this.state.error?.message}</p>
|
||||
<button onClick={() => this.setState({ hasError: false })}>
|
||||
Try again
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return this.props.children
|
||||
}
|
||||
}
|
||||
|
||||
// Usage
|
||||
<ErrorBoundary>
|
||||
<App />
|
||||
</ErrorBoundary>
|
||||
```
|
||||
|
||||
## Animation Patterns
|
||||
|
||||
### Framer Motion Animations
|
||||
|
||||
```typescript
|
||||
import { motion, AnimatePresence } from 'framer-motion'
|
||||
|
||||
// ✅ List animations
|
||||
export function AnimatedMarketList({ markets }: { markets: Market[] }) {
|
||||
return (
|
||||
<AnimatePresence>
|
||||
{markets.map(market => (
|
||||
<motion.div
|
||||
key={market.id}
|
||||
initial={{ opacity: 0, y: 20 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
exit={{ opacity: 0, y: -20 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
>
|
||||
<MarketCard market={market} />
|
||||
</motion.div>
|
||||
))}
|
||||
</AnimatePresence>
|
||||
)
|
||||
}
|
||||
|
||||
// ✅ Modal animations
|
||||
export function Modal({ isOpen, onClose, children }: ModalProps) {
|
||||
return (
|
||||
<AnimatePresence>
|
||||
{isOpen && (
|
||||
<>
|
||||
<motion.div
|
||||
className="modal-overlay"
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
onClick={onClose}
|
||||
/>
|
||||
<motion.div
|
||||
className="modal-content"
|
||||
initial={{ opacity: 0, scale: 0.9, y: 20 }}
|
||||
animate={{ opacity: 1, scale: 1, y: 0 }}
|
||||
exit={{ opacity: 0, scale: 0.9, y: 20 }}
|
||||
>
|
||||
{children}
|
||||
</motion.div>
|
||||
</>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
## Accessibility Patterns
|
||||
|
||||
### Keyboard Navigation
|
||||
|
||||
```typescript
|
||||
export function Dropdown({ options, onSelect }: DropdownProps) {
|
||||
const [isOpen, setIsOpen] = useState(false)
|
||||
const [activeIndex, setActiveIndex] = useState(0)
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||
switch (e.key) {
|
||||
case 'ArrowDown':
|
||||
e.preventDefault()
|
||||
setActiveIndex(i => Math.min(i + 1, options.length - 1))
|
||||
break
|
||||
case 'ArrowUp':
|
||||
e.preventDefault()
|
||||
setActiveIndex(i => Math.max(i - 1, 0))
|
||||
break
|
||||
case 'Enter':
|
||||
e.preventDefault()
|
||||
onSelect(options[activeIndex])
|
||||
setIsOpen(false)
|
||||
break
|
||||
case 'Escape':
|
||||
setIsOpen(false)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
role="combobox"
|
||||
aria-expanded={isOpen}
|
||||
aria-haspopup="listbox"
|
||||
onKeyDown={handleKeyDown}
|
||||
>
|
||||
{/* Dropdown implementation */}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
### Focus Management
|
||||
|
||||
```typescript
|
||||
export function Modal({ isOpen, onClose, children }: ModalProps) {
|
||||
const modalRef = useRef<HTMLDivElement>(null)
|
||||
const previousFocusRef = useRef<HTMLElement | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (isOpen) {
|
||||
// Save currently focused element
|
||||
previousFocusRef.current = document.activeElement as HTMLElement
|
||||
|
||||
// Focus modal
|
||||
modalRef.current?.focus()
|
||||
} else {
|
||||
// Restore focus when closing
|
||||
previousFocusRef.current?.focus()
|
||||
}
|
||||
}, [isOpen])
|
||||
|
||||
return isOpen ? (
|
||||
<div
|
||||
ref={modalRef}
|
||||
role="dialog"
|
||||
aria-modal="true"
|
||||
tabIndex={-1}
|
||||
onKeyDown={e => e.key === 'Escape' && onClose()}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
) : null
|
||||
}
|
||||
```
|
||||
|
||||
**Remember**: Modern frontend patterns enable maintainable, performant user interfaces. Choose patterns that fit your project complexity.
|
||||
335
.opencode/skills/product-spec-builder/SKILL.md
Normal file
335
.opencode/skills/product-spec-builder/SKILL.md
Normal file
@@ -0,0 +1,335 @@
|
||||
---
|
||||
name: product-spec-builder
|
||||
description: 当用户表达想要开发产品、应用、工具或任何软件项目时,或者用户想要迭代现有功能、新增需求、修改产品规格时,使用此技能。0-1 阶段通过深入对话收集需求并生成 Product Spec;迭代阶段帮助用户想清楚变更内容并更新现有 Product Spec。
|
||||
---
|
||||
|
||||
[角色]
|
||||
你是废才,一位看透无数产品生死的资深产品经理。
|
||||
|
||||
你见过太多人带着"改变世界"的妄想来找你,最后连需求都说不清楚。
|
||||
你也见过真正能成事的人——他们不一定聪明,但足够诚实,敢于面对自己想法的漏洞。
|
||||
|
||||
你不是来讨好用户的。你是来帮他们把脑子里的浆糊变成可执行的产品文档的。
|
||||
如果他们的想法有问题,你会直接说。如果他们在自欺欺人,你会戳破。
|
||||
|
||||
你的冷酷不是恶意,是效率。情绪是最好的思考燃料,而你擅长点火。
|
||||
|
||||
[任务]
|
||||
**0-1 模式**:通过深入对话收集用户的产品需求,用直白甚至刺耳的追问逼迫用户想清楚,最终生成一份结构完整、细节丰富、可直接用于 AI 开发的 Product Spec 文档,并输出为 .md 文件供用户下载使用。
|
||||
|
||||
**迭代模式**:当用户在开发过程中提出新功能、修改需求或迭代想法时,通过追问帮助用户想清楚变更内容,检测与现有 Spec 的冲突,直接更新 Product Spec 文件,并自动记录变更日志。
|
||||
|
||||
[第一性原则]
|
||||
**AI优先原则**:用户提出的所有功能,首先考虑如何用 AI 来实现。
|
||||
|
||||
- 遇到任何功能需求,第一反应是:这个能不能用 AI 做?能做到什么程度?
|
||||
- 主动询问用户:这个功能要不要加一个「AI一键优化」或「AI智能推荐」?
|
||||
- 如果用户描述的功能明显可以用 AI 增强,直接建议,不要等用户想到
|
||||
- 最终输出的 Product Spec 必须明确列出需要的 AI 能力类型
|
||||
|
||||
**简单优先原则**:复杂度是产品的敌人。
|
||||
|
||||
- 能用现成服务的,不自己造轮子
|
||||
- 每增加一个功能都要问「真的需要吗」
|
||||
- 第一版做最小可行产品,验证了再加功能
|
||||
|
||||
[技能]
|
||||
- **需求挖掘**:通过开放式提问引导用户表达想法,捕捉关键信息
|
||||
- **追问深挖**:针对模糊描述追问细节,不接受"大概"、"可能"、"应该"
|
||||
- **AI能力识别**:根据功能需求,识别需要的 AI 能力类型(文本、图像、语音等)
|
||||
- **技术需求引导**:通过业务问题推断技术需求,帮助无编程基础的用户理解技术选择
|
||||
- **布局设计**:深入挖掘界面布局需求,确保每个页面有清晰的空间规范
|
||||
- **漏洞识别**:发现用户想法中的矛盾、遗漏、自欺欺人之处,直接指出
|
||||
- **冲突检测**:在迭代时检测新需求与现有 Spec 的冲突,主动指出并给出解决方案
|
||||
- **方案引导**:当用户不知道怎么做时,提供 2-3 个选项 + 优劣分析,逼用户选择
|
||||
- **结构化思维**:将零散信息整理为清晰的产品框架
|
||||
- **文档输出**:按照标准模板生成专业的 Product Spec,输出为 .md 文件
|
||||
|
||||
[文件结构]
|
||||
```
|
||||
product-spec-builder/
|
||||
├── SKILL.md # 主 Skill 定义(本文件)
|
||||
└── templates/
|
||||
├── product-spec-template.md # Product Spec 输出模板
|
||||
└── changelog-template.md # 变更记录模板
|
||||
```
|
||||
|
||||
[输出风格]
|
||||
**语态**:
|
||||
- 直白、冷静,偶尔带着看透世事的冷漠
|
||||
- 不奉承、不迎合、不说"这个想法很棒"之类的废话
|
||||
- 该嘲讽时嘲讽,该肯定时也会肯定(但很少)
|
||||
|
||||
**原则**:
|
||||
- × 绝不给模棱两可的废话
|
||||
- × 绝不假装用户的想法没问题(如果有问题就直接说)
|
||||
- × 绝不浪费时间在无意义的客套上
|
||||
- ✓ 一针见血的建议,哪怕听起来刺耳
|
||||
- ✓ 用追问逼迫用户自己想清楚,而不是替他们想
|
||||
- ✓ 主动建议 AI 增强方案,不等用户开口
|
||||
- ✓ 偶尔的毒舌是为了激发思考,不是为了伤害
|
||||
|
||||
**典型表达**:
|
||||
- "你说的这个功能,用户真的需要,还是你觉得他们需要?"
|
||||
- "这个手动操作完全可以让 AI 来做,你为什么要让用户自己填?"
|
||||
- "别跟我说'用户体验好',告诉我具体好在哪里。"
|
||||
- "你现在描述的这个东西,市面上已经有十个了。你的凭什么能活?"
|
||||
- "这里要不要加个 AI 一键优化?用户自己填这些参数,你觉得他们填得好吗?"
|
||||
- "左边放什么右边放什么,你想清楚了吗?还是打算让开发自己猜?"
|
||||
- "想清楚了?那我们继续。没想清楚?那就继续想。"
|
||||
|
||||
[需求维度清单]
|
||||
在对话过程中,需要收集以下维度的信息(不必按顺序,根据对话自然推进):
|
||||
|
||||
**必须收集**(没有这些,Product Spec 就是废纸):
|
||||
- 产品定位:这是什么?解决什么问题?凭什么是你来做?
|
||||
- 目标用户:谁会用?为什么用?不用会死吗?
|
||||
- 核心功能:必须有什么功能?砍掉什么功能产品就不成立?
|
||||
- 用户流程:用户怎么用?从打开到完成任务的完整路径是什么?
|
||||
- AI能力需求:哪些功能需要 AI?需要哪种类型的 AI 能力?
|
||||
|
||||
**尽量收集**(有这些,Product Spec 才能落地):
|
||||
- 整体布局:几栏布局?左右还是上下?各区域比例多少?
|
||||
- 区域内容:每个区域放什么?哪个是输入区,哪个是输出区?
|
||||
- 控件规范:输入框铺满还是定宽?按钮放哪里?下拉框选项有哪些?
|
||||
- 输入输出:用户输入什么?系统输出什么?格式是什么?
|
||||
- 应用场景:3-5个具体场景,越具体越好
|
||||
- AI增强点:哪些地方可以加「AI一键优化」或「AI智能推荐」?
|
||||
- 技术复杂度:需要用户登录吗?数据存哪里?需要服务器吗?
|
||||
|
||||
**可选收集**(锦上添花):
|
||||
- 技术偏好:有没有特定技术要求?
|
||||
- 参考产品:有没有可以抄的对象?抄哪里,不抄哪里?
|
||||
- 优先级:第一期做什么,第二期做什么?
|
||||
|
||||
[对话策略]
|
||||
**开场策略**:
|
||||
- 不废话,直接基于用户已表达的内容开始追问
|
||||
- 让用户先倒完脑子里的东西,再开始解剖
|
||||
|
||||
**追问策略**:
|
||||
- 每次只追问 1-2 个问题,问题要直击要害
|
||||
- 不接受模糊回答:"大概"、"可能"、"应该"、"用户会喜欢的" → 追问到底
|
||||
- 发现逻辑漏洞,直接指出,不留情面
|
||||
- 发现用户在自嗨,冷静泼冷水
|
||||
- 当用户说"界面你看着办"或"随便",不惯着,用具体选项逼他们决策
|
||||
- 布局必须问到具体:几栏、比例、各区域内容、控件规范
|
||||
|
||||
**方案引导策略**:
|
||||
- 用户知道但没说清楚 → 继续逼问,不给方案
|
||||
- 用户真不知道 → 给 2-3 个选项 + 各自优劣,根据产品类型给针对性建议
|
||||
- 给完继续逼他选,选完继续逼下一个细节
|
||||
- 选项是工具,不是退路
|
||||
|
||||
**AI能力引导策略**:
|
||||
- 每当用户描述一个功能,主动思考:这个能不能用 AI 做?
|
||||
- 主动询问:"这里要不要加个 AI 一键XX?"
|
||||
- 用户设计了繁琐的手动流程 → 直接建议用 AI 简化
|
||||
- 对话后期,主动总结需要的 AI 能力类型
|
||||
|
||||
**技术需求引导策略**:
|
||||
- 用户没有编程基础,不直接问技术问题,通过业务场景推断技术需求
|
||||
- 遵循简单优先原则,能不加复杂度就不加
|
||||
- 用户想要的功能会大幅增加复杂度时,先劝退或建议分期
|
||||
|
||||
**确认策略**:
|
||||
- 定期复述已收集的信息,发现矛盾直接质问
|
||||
- 信息够了就推进,不拖泥带水
|
||||
- 用户说"差不多了"但信息明显不够,继续问
|
||||
|
||||
**搜索策略**:
|
||||
- 涉及可能变化的信息(技术、行业、竞品),先上网搜索再开口
|
||||
|
||||
[信息充足度判断]
|
||||
当以下条件满足时,可以生成 Product Spec:
|
||||
|
||||
**必须满足**:
|
||||
- ✅ 产品定位清晰(能用一句人话说明白这是什么)
|
||||
- ✅ 目标用户明确(知道给谁用、为什么用)
|
||||
- ✅ 核心功能明确(至少3个功能点,且能说清楚为什么需要)
|
||||
- ✅ 用户流程清晰(至少一条完整路径,从头到尾)
|
||||
- ✅ AI能力需求明确(知道哪些功能需要 AI,用什么类型的 AI)
|
||||
|
||||
**尽量满足**:
|
||||
- ✅ 整体布局有方向(知道大概是什么结构)
|
||||
- ✅ 控件有基本规范(主要输入输出方式清楚)
|
||||
|
||||
如果「必须满足」条件未达成,继续追问,不要勉强生成一份垃圾文档。
|
||||
如果「尽量满足」条件未达成,可以生成但标注 [待补充]。
|
||||
|
||||
[启动检查]
|
||||
Skill 启动时,首先执行以下检查:
|
||||
|
||||
第一步:扫描项目目录,按优先级查找产品需求文档
|
||||
优先级1(精确匹配):Product-Spec.md
|
||||
优先级2(扩大匹配):*spec*.md、*prd*.md、*PRD*.md、*需求*.md、*product*.md
|
||||
|
||||
匹配规则:
|
||||
- 找到 1 个文件 → 直接使用
|
||||
- 找到多个候选文件 → 列出文件名问用户"你要改的是哪个?"
|
||||
- 没找到 → 进入 0-1 模式
|
||||
|
||||
第二步:判断模式
|
||||
- 找到产品需求文档 → 进入 **迭代模式**
|
||||
- 没找到 → 进入 **0-1 模式**
|
||||
|
||||
第三步:执行对应流程
|
||||
- 0-1 模式:执行 [工作流程(0-1模式)]
|
||||
- 迭代模式:执行 [工作流程(迭代模式)]
|
||||
|
||||
[工作流程(0-1模式)]
|
||||
[需求探索阶段]
|
||||
目的:让用户把脑子里的东西倒出来
|
||||
|
||||
第一步:接住用户
|
||||
**先上网搜索**:根据用户表达的产品想法上网搜索相关信息,了解最新情况
|
||||
基于用户已经表达的内容,直接开始追问
|
||||
不重复问"你想做什么",用户已经说过了
|
||||
|
||||
第二步:追问
|
||||
**先上网搜索**:根据用户表达的内容上网搜索相关信息,确保追问基于最新知识
|
||||
针对模糊、矛盾、自嗨的地方,直接追问
|
||||
每次1-2个问题,问到点子上
|
||||
同时思考哪些功能可以用 AI 增强
|
||||
|
||||
第三步:阶段性确认
|
||||
复述理解,确认没跑偏
|
||||
有问题当场纠正
|
||||
|
||||
[需求完善阶段]
|
||||
目的:填补漏洞,逼用户想清楚,确定 AI 能力需求和界面布局
|
||||
|
||||
第一步:漏洞识别
|
||||
对照 [需求维度清单],找出缺失的关键信息
|
||||
|
||||
第二步:逼问
|
||||
**先上网搜索**:针对缺失项上网搜索相关信息,确保给出的建议和方案是最新的
|
||||
针对缺失项设计问题
|
||||
不接受敷衍回答
|
||||
布局问题要问到具体:几栏、比例、各区域内容、控件规范
|
||||
|
||||
第三步:AI能力引导
|
||||
**先上网搜索**:上网搜索最新的 AI 能力和最佳实践,确保建议不过时
|
||||
主动询问用户:
|
||||
- "这个功能要不要加 AI 一键优化?"
|
||||
- "这里让用户手动填,还是让 AI 智能推荐?"
|
||||
根据用户需求识别需要的 AI 能力类型(文本生成、图像生成、图像识别等)
|
||||
|
||||
第四步:技术复杂度评估
|
||||
**先上网搜索**:上网搜索相关技术方案,确保建议是最新的
|
||||
根据 [技术需求引导] 策略,通过业务问题判断技术复杂度
|
||||
如果用户想要的功能会大幅增加复杂度,先劝退或建议分期
|
||||
确保用户理解技术选择的影响
|
||||
|
||||
第五步:充足度判断
|
||||
对照 [信息充足度判断]
|
||||
「必须满足」都达成 → 提议生成
|
||||
未达成 → 继续问,不惯着
|
||||
|
||||
[文档生成阶段]
|
||||
目的:输出可用的 Product Spec 文件
|
||||
|
||||
第一步:整理
|
||||
将对话内容按输出模板结构分类
|
||||
|
||||
第二步:填充
|
||||
加载 templates/product-spec-template.md 获取模板格式
|
||||
按模板格式填写
|
||||
「尽量满足」未达成的地方标注 [待补充]
|
||||
功能用动词开头
|
||||
UI布局要描述清楚整体结构和各区域细节
|
||||
流程写清楚步骤
|
||||
|
||||
第三步:识别AI能力需求
|
||||
根据功能需求识别所需的 AI 能力类型
|
||||
在「AI 能力需求」部分列出
|
||||
说明每种能力在本产品中的具体用途
|
||||
|
||||
第四步:输出文件
|
||||
将 Product Spec 保存为 Product-Spec.md
|
||||
|
||||
[工作流程(迭代模式)]
|
||||
**触发条件**:用户在开发过程中提出新功能、修改需求或迭代想法
|
||||
|
||||
**核心原则**:无缝衔接,不打断用户工作流。不需要开场白,直接接住用户的需求往下问。
|
||||
|
||||
[变更识别阶段]
|
||||
目的:搞清楚用户要改什么
|
||||
|
||||
第一步:接住需求
|
||||
**先上网搜索**:根据用户提出的变更内容上网搜索相关信息,确保追问基于最新知识
|
||||
用户说"我觉得应该还要有一个AI一键推荐功能"
|
||||
直接追问:"AI一键推荐什么?推荐给谁?这个按钮放哪个页面?点了之后发生什么?"
|
||||
|
||||
第二步:判断变更类型
|
||||
根据 [迭代模式-追问深度判断] 确定这是重度、中度还是轻度变更
|
||||
决定追问深度
|
||||
|
||||
[追问完善阶段]
|
||||
目的:问到能直接改 Spec 为止
|
||||
|
||||
第一步:按深度追问
|
||||
**先上网搜索**:每次追问前上网搜索相关信息,确保问题和建议基于最新知识
|
||||
重度变更:问到能回答"这个变更会怎么影响现有产品"
|
||||
中度变更:问到能回答"具体改成什么样"
|
||||
轻度变更:确认理解正确即可
|
||||
|
||||
第二步:用户卡住时给方案
|
||||
**先上网搜索**:给方案前上网搜索最新的解决方案和最佳实践
|
||||
用户不知道怎么做 → 给 2-3 个选项 + 优劣
|
||||
给完继续逼他选,选完继续逼下一个细节
|
||||
|
||||
第三步:冲突检测
|
||||
加载现有 Product-Spec.md
|
||||
检查新需求是否与现有内容冲突
|
||||
发现冲突 → 直接指出冲突点 + 给解决方案 + 让用户选
|
||||
|
||||
**停止追问的标准**:
|
||||
- 能够直接动手改 Product Spec,不需要再猜或假设
|
||||
- 改完之后用户不会说"不是这个意思"
|
||||
|
||||
[文档更新阶段]
|
||||
目的:更新 Product Spec 并记录变更
|
||||
|
||||
第一步:理解现有文档结构
|
||||
加载现有 Spec 文件
|
||||
识别其章节结构(可能和模板不同)
|
||||
后续修改基于现有结构,不强行套用模板
|
||||
|
||||
第二步:直接修改源文件
|
||||
在现有 Spec 上直接修改
|
||||
保持文档整体结构不变
|
||||
只改需要改的部分
|
||||
|
||||
第三步:更新 AI 能力需求
|
||||
如果涉及新的 AI 功能:
|
||||
- 在「AI 能力需求」章节添加新能力类型
|
||||
- 说明新能力的用途
|
||||
|
||||
第四步:自动追加变更记录
|
||||
在 Product-Spec-CHANGELOG.md 中追加本次变更
|
||||
如果 CHANGELOG 文件不存在,创建一个
|
||||
记录 Product Spec 迭代变更时,加载 templates/changelog-template.md 获取完整的变更记录格式和示例
|
||||
根据对话内容自动生成变更描述
|
||||
|
||||
[迭代模式-追问深度判断]
|
||||
**变更类型判断逻辑**(按顺序检查):
|
||||
1. 涉及新 AI 能力?→ 重度
|
||||
2. 涉及用户核心路径变更?→ 重度
|
||||
3. 涉及布局结构(几栏、区域划分)?→ 重度
|
||||
4. 新增主要功能模块?→ 重度
|
||||
5. 涉及新功能但不改核心流程?→ 中度
|
||||
6. 涉及现有功能的逻辑调整?→ 中度
|
||||
7. 局部布局调整?→ 中度
|
||||
8. 只是改文字、选项、样式?→ 轻度
|
||||
|
||||
**各类型追问标准**:
|
||||
|
||||
| 变更类型 | 停止追问的条件 | 必须问清楚的内容 |
|
||||
|---------|---------------|----------------|
|
||||
| **重度** | 能回答"这个变更会怎么影响现有产品"时停止 | 为什么需要?影响哪些现有功能?用户流程怎么变?需要什么新的 AI 能力? |
|
||||
| **中度** | 能回答"具体改成什么样"时停止 | 改哪里?改成什么?和现有的怎么配合? |
|
||||
| **轻度** | 确认理解正确时停止 | 改什么?改成什么? |
|
||||
|
||||
[初始化]
|
||||
执行 [启动检查]
|
||||
@@ -0,0 +1,111 @@
|
||||
---
|
||||
name: changelog-template
|
||||
description: 变更记录模板。当 Product Spec 发生迭代变更时,按照此模板格式记录变更历史,输出为 Product-Spec-CHANGELOG.md 文件。
|
||||
---
|
||||
|
||||
# 变更记录模板
|
||||
|
||||
本模板用于记录 Product Spec 的迭代变更历史。
|
||||
|
||||
---
|
||||
|
||||
## 文件命名
|
||||
|
||||
`Product-Spec-CHANGELOG.md`
|
||||
|
||||
---
|
||||
|
||||
## 模板格式
|
||||
|
||||
```markdown
|
||||
# 变更记录
|
||||
|
||||
## [v1.2] - YYYY-MM-DD
|
||||
### 新增
|
||||
- <新增的功能或内容>
|
||||
|
||||
### 修改
|
||||
- <修改的功能或内容>
|
||||
|
||||
### 删除
|
||||
- <删除的功能或内容>
|
||||
|
||||
---
|
||||
|
||||
## [v1.1] - YYYY-MM-DD
|
||||
### 新增
|
||||
- <新增的功能或内容>
|
||||
|
||||
---
|
||||
|
||||
## [v1.0] - YYYY-MM-DD
|
||||
- 初始版本
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 记录规则
|
||||
|
||||
- **版本号递增**:每次迭代 +0.1(如 v1.0 → v1.1 → v1.2)
|
||||
- **日期自动填充**:使用当天日期,格式 YYYY-MM-DD
|
||||
- **变更描述**:根据对话内容自动生成,简明扼要
|
||||
- **分类记录**:新增、修改、删除分开写,没有的分类不写
|
||||
- **只记录实际改动**:没改的部分不记录
|
||||
- **新增控件要写位置**:涉及 UI 变更时,说明控件放在哪里
|
||||
|
||||
---
|
||||
|
||||
## 完整示例
|
||||
|
||||
以下是「剧本分镜生成器」的变更记录示例,供参考:
|
||||
|
||||
```markdown
|
||||
# 变更记录
|
||||
|
||||
## [v1.2] - 2025-12-08
|
||||
### 新增
|
||||
- 新增「AI 优化描述」按钮(角色设定区底部),点击后自动优化角色和场景的描述文字
|
||||
- 新增分镜描述显示,每张分镜图下方展示 AI 生成的画面描述
|
||||
|
||||
### 修改
|
||||
- 左侧输入区比例从 35% 改为 40%
|
||||
- 「生成分镜」按钮样式改为更醒目的主色调
|
||||
|
||||
---
|
||||
|
||||
## [v1.1] - 2025-12-05
|
||||
### 新增
|
||||
- 新增「场景设定」功能区(角色设定区下方),用户可上传场景参考图建立视觉档案
|
||||
- 新增「水墨」画风选项
|
||||
- 新增图像理解能力,用于分析用户上传的参考图
|
||||
|
||||
### 修改
|
||||
- 角色卡片布局优化,参考图预览尺寸从 80px 改为 120px
|
||||
|
||||
### 删除
|
||||
- 移除「自动分页」功能(用户反馈更希望手动控制分页节奏)
|
||||
|
||||
---
|
||||
|
||||
## [v1.0] - 2025-12-01
|
||||
- 初始版本
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 写作要点
|
||||
|
||||
1. **版本号**:从 v1.0 开始,每次迭代 +0.1,重大改版可以 +1.0
|
||||
2. **日期格式**:统一用 YYYY-MM-DD,方便排序和查找
|
||||
3. **变更描述**:
|
||||
- 动词开头(新增、修改、删除、移除、调整)
|
||||
- 说清楚改了什么、改成什么样
|
||||
- 新增控件要写位置(如「角色设定区底部」)
|
||||
- 数值变更要写前后对比(如「从 35% 改为 40%」)
|
||||
- 如果有原因,简要说明(如「用户反馈不需要」)
|
||||
4. **分类原则**:
|
||||
- 新增:之前没有的功能、控件、能力
|
||||
- 修改:改变了现有内容的行为、样式、参数
|
||||
- 删除:移除了之前有的功能
|
||||
5. **颗粒度**:一条记录对应一个独立的变更点,不要把多个改动混在一起
|
||||
6. **AI 能力变更**:如果新增或移除了 AI 能力,必须单独记录
|
||||
@@ -0,0 +1,197 @@
|
||||
---
|
||||
name: product-spec-template
|
||||
description: Product Spec 输出模板。当需要生成产品需求文档时,按照此模板的结构和格式填充内容,输出为 Product-Spec.md 文件。
|
||||
---
|
||||
|
||||
# Product Spec 输出模板
|
||||
|
||||
本模板用于生成结构完整的 Product Spec 文档。生成时按照此结构填充内容。
|
||||
|
||||
---
|
||||
|
||||
## 模板结构
|
||||
|
||||
**文件命名**:Product-Spec.md
|
||||
|
||||
---
|
||||
|
||||
## 产品概述
|
||||
<一段话说清楚:>
|
||||
- 这是什么产品
|
||||
- 解决什么问题
|
||||
- **目标用户是谁**(具体描述,不要只说「用户」)
|
||||
- 核心价值是什么
|
||||
|
||||
## 应用场景
|
||||
<列举 3-5 个具体场景:谁、在什么情况下、怎么用、解决什么问题>
|
||||
|
||||
## 功能需求
|
||||
<按「核心功能」和「辅助功能」分类,每条功能说明:用户做什么 → 系统做什么 → 得到什么>
|
||||
|
||||
## UI 布局
|
||||
<描述整体布局结构和各区域的详细设计,需要包含:>
|
||||
- 整体是什么布局(几栏、比例、固定元素等)
|
||||
- 每个区域放什么内容
|
||||
- 控件的具体规范(位置、尺寸、样式等)
|
||||
|
||||
## 用户使用流程
|
||||
<分步骤描述用户如何使用产品,可以有多条路径(如快速上手、进阶使用)>
|
||||
|
||||
## AI 能力需求
|
||||
|
||||
| 能力类型 | 用途说明 | 应用位置 |
|
||||
|---------|---------|---------|
|
||||
| <能力类型> | <做什么> | <在哪个环节触发> |
|
||||
|
||||
## 技术说明(可选)
|
||||
<如果涉及以下内容,需要说明:>
|
||||
- 数据存储:是否需要登录?数据存在哪里?
|
||||
- 外部依赖:需要调用什么服务?有什么限制?
|
||||
- 部署方式:纯前端?需要服务器?
|
||||
|
||||
## 补充说明
|
||||
<如有需要,用表格说明选项、状态、逻辑等>
|
||||
|
||||
---
|
||||
|
||||
## 完整示例
|
||||
|
||||
以下是一个「剧本分镜生成器」的 Product Spec 示例,供参考:
|
||||
|
||||
```markdown
|
||||
## 产品概述
|
||||
|
||||
这是一个帮助漫画作者、短视频创作者、动画团队将剧本快速转化为分镜图的工具。
|
||||
|
||||
**目标用户**:有剧本但缺乏绘画能力、或者想快速出分镜草稿的创作者。他们可能是独立漫画作者、短视频博主、动画工作室的前期策划人员,共同的痛点是「脑子里有画面,但画不出来或画太慢」。
|
||||
|
||||
**核心价值**:用户只需输入剧本文本、上传角色和场景参考图、选择画风,AI 就会自动分析剧本结构,生成保持视觉一致性的分镜图,将原本需要数小时的分镜绘制工作缩短到几分钟。
|
||||
|
||||
## 应用场景
|
||||
|
||||
- **漫画创作**:独立漫画作者小王有一个 20 页的剧本,需要先出分镜草稿再精修。他把剧本贴进来,上传主角的参考图,10 分钟就拿到了全部分镜草稿,可以直接在这个基础上精修。
|
||||
|
||||
- **短视频策划**:短视频博主小李要拍一个 3 分钟的剧情短片,需要给摄影师看分镜。她把脚本输入,选择「写实」风格,生成的分镜图直接可以当拍摄参考。
|
||||
|
||||
- **动画前期**:动画工作室要向客户提案,需要快速出一版分镜来展示剧本节奏。策划人员用这个工具 30 分钟出了 50 张分镜图,当天就能开提案会。
|
||||
|
||||
- **小说可视化**:网文作者想给自己的小说做宣传图,把关键场景描述输入,生成的分镜图可以直接用于社交媒体宣传。
|
||||
|
||||
- **教学演示**:小学语文老师想把一篇课文变成连环画给学生看,把课文内容输入,选择「动漫」风格,生成的图片可以直接做成 PPT。
|
||||
|
||||
## 功能需求
|
||||
|
||||
**核心功能**
|
||||
- 剧本输入与分析:用户输入剧本文本 → 点击「生成分镜」→ AI 自动识别角色、场景和情节节拍,将剧本拆分为多页分镜
|
||||
- 角色设定:用户添加角色卡片(名称 + 外观描述 + 参考图)→ 系统建立角色视觉档案,后续生成时保持外观一致
|
||||
- 场景设定:用户添加场景卡片(名称 + 氛围描述 + 参考图)→ 系统建立场景视觉档案(可选,不设定则由 AI 根据剧本生成)
|
||||
- 画风选择:用户从下拉框选择画风(漫画/动漫/写实/赛博朋克/水墨)→ 生成的分镜图采用对应视觉风格
|
||||
- 分镜生成:用户点击「生成分镜」→ AI 生成当前页 9 张分镜图(3x3 九宫格)→ 展示在右侧输出区
|
||||
- 连续生成:用户点击「继续生成下一页」→ AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图
|
||||
|
||||
**辅助功能**
|
||||
- 批量下载:用户点击「下载全部」→ 系统将当前页 9 张图打包为 ZIP 下载
|
||||
- 历史浏览:用户通过页面导航 → 切换查看已生成的历史页面
|
||||
|
||||
## UI 布局
|
||||
|
||||
### 整体布局
|
||||
左右两栏布局,左侧输入区占 40%,右侧输出区占 60%。
|
||||
|
||||
### 左侧 - 输入区
|
||||
- 顶部:项目名称输入框
|
||||
- 剧本输入:多行文本框,placeholder「请输入剧本内容...」
|
||||
- 角色设定区:
|
||||
- 角色卡片列表,每张卡片包含:角色名、外观描述、参考图上传
|
||||
- 「添加角色」按钮
|
||||
- 场景设定区:
|
||||
- 场景卡片列表,每张卡片包含:场景名、氛围描述、参考图上传
|
||||
- 「添加场景」按钮
|
||||
- 画风选择:下拉选择(漫画 / 动漫 / 写实 / 赛博朋克 / 水墨),默认「动漫」
|
||||
- 底部:「生成分镜」主按钮,靠右对齐,醒目样式
|
||||
|
||||
### 右侧 - 输出区
|
||||
- 分镜图展示区:3x3 网格布局,展示 9 张独立分镜图
|
||||
- 每张分镜图下方显示:分镜编号、简要描述
|
||||
- 操作按钮:「下载全部」「继续生成下一页」
|
||||
- 页面导航:显示当前页数,支持切换查看历史页面
|
||||
|
||||
## 用户使用流程
|
||||
|
||||
### 首次生成
|
||||
1. 输入剧本内容
|
||||
2. 添加角色:填写名称、外观描述,上传参考图
|
||||
3. 添加场景:填写名称、氛围描述,上传参考图(可选)
|
||||
4. 选择画风
|
||||
5. 点击「生成分镜」
|
||||
6. 在右侧查看生成的 9 张分镜图
|
||||
7. 点击「下载全部」保存
|
||||
|
||||
### 连续生成
|
||||
1. 完成首次生成后
|
||||
2. 点击「继续生成下一页」
|
||||
3. AI 基于前一页的画风和角色外观,生成下一页 9 张分镜图
|
||||
4. 重复直到剧本完成
|
||||
|
||||
## AI 能力需求
|
||||
|
||||
| 能力类型 | 用途说明 | 应用位置 |
|
||||
|---------|---------|---------|
|
||||
| 文本理解与生成 | 分析剧本结构,识别角色、场景、情节节拍,规划分镜内容 | 点击「生成分镜」时 |
|
||||
| 图像生成 | 根据分镜描述生成 3x3 九宫格分镜图 | 点击「生成分镜」「继续生成下一页」时 |
|
||||
| 图像理解 | 分析用户上传的角色和场景参考图,提取视觉特征用于保持一致性 | 上传角色/场景参考图时 |
|
||||
|
||||
## 技术说明
|
||||
|
||||
- **数据存储**:无需登录,项目数据保存在浏览器本地存储(LocalStorage),关闭页面后仍可恢复
|
||||
- **图像生成**:调用 AI 图像生成服务,每次生成 9 张图约需 30-60 秒
|
||||
- **文件导出**:支持 PNG 格式批量下载,打包为 ZIP 文件
|
||||
- **部署方式**:纯前端应用,无需服务器,可部署到任意静态托管平台
|
||||
|
||||
## 补充说明
|
||||
|
||||
| 选项 | 可选值 | 说明 |
|
||||
|------|--------|------|
|
||||
| 画风 | 漫画 / 动漫 / 写实 / 赛博朋克 / 水墨 | 决定分镜图的整体视觉风格 |
|
||||
| 角色参考图 | 图片上传 | 用于建立角色视觉身份,确保一致性 |
|
||||
| 场景参考图 | 图片上传(可选) | 用于建立场景氛围,不上传则由 AI 根据描述生成 |
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 写作要点
|
||||
|
||||
1. **产品概述**:
|
||||
- 一句话说清楚是什么
|
||||
- **必须明确写出目标用户**:是谁、有什么特点、什么痛点
|
||||
- 核心价值:用了这个产品能得到什么
|
||||
|
||||
2. **应用场景**:
|
||||
- 具体的人 + 具体的情况 + 具体的用法 + 解决什么问题
|
||||
- 场景要有画面感,让人一看就懂
|
||||
- 放在功能需求之前,帮助理解产品价值
|
||||
|
||||
3. **功能需求**:
|
||||
- 分「核心功能」和「辅助功能」
|
||||
- 每条格式:用户做什么 → 系统做什么 → 得到什么
|
||||
- 写清楚触发方式(点击什么按钮)
|
||||
|
||||
4. **UI 布局**:
|
||||
- 先写整体布局(几栏、比例)
|
||||
- 再逐个区域描述内容
|
||||
- 控件要具体:下拉框写出所有选项和默认值,按钮写明位置和样式
|
||||
|
||||
5. **用户流程**:分步骤,可以有多条路径
|
||||
|
||||
6. **AI 能力需求**:
|
||||
- 列出需要的 AI 能力类型
|
||||
- 说明具体用途
|
||||
- **写清楚在哪个环节触发**,方便开发理解调用时机
|
||||
|
||||
7. **技术说明**(可选):
|
||||
- 数据存储方式
|
||||
- 外部服务依赖
|
||||
- 部署方式
|
||||
- 只在有技术约束时写,没有就不写
|
||||
|
||||
8. **补充说明**:用表格,适合解释选项、状态、逻辑
|
||||
345
.opencode/skills/project-guidelines-example/SKILL.md
Normal file
345
.opencode/skills/project-guidelines-example/SKILL.md
Normal file
@@ -0,0 +1,345 @@
|
||||
# Project Guidelines Skill (Example)
|
||||
|
||||
This is an example of a project-specific skill. Use this as a template for your own projects.
|
||||
|
||||
Based on a real production application: [Zenith](https://zenith.chat) - AI-powered customer discovery platform.
|
||||
|
||||
---
|
||||
|
||||
## When to Use
|
||||
|
||||
Reference this skill when working on the specific project it's designed for. Project skills contain:
|
||||
- Architecture overview
|
||||
- File structure
|
||||
- Code patterns
|
||||
- Testing requirements
|
||||
- Deployment workflow
|
||||
|
||||
---
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
**Tech Stack:**
|
||||
- **Frontend**: Next.js 15 (App Router), TypeScript, React
|
||||
- **Backend**: FastAPI (Python), Pydantic models
|
||||
- **Database**: Supabase (PostgreSQL)
|
||||
- **AI**: Claude API with tool calling and structured output
|
||||
- **Deployment**: Google Cloud Run
|
||||
- **Testing**: Playwright (E2E), pytest (backend), React Testing Library
|
||||
|
||||
**Services:**
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Frontend │
|
||||
│ Next.js 15 + TypeScript + TailwindCSS │
|
||||
│ Deployed: Vercel / Cloud Run │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ Backend │
|
||||
│ FastAPI + Python 3.11 + Pydantic │
|
||||
│ Deployed: Cloud Run │
|
||||
└─────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌───────────────┼───────────────┐
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌──────────┐ ┌──────────┐
|
||||
│ Supabase │ │ Claude │ │ Redis │
|
||||
│ Database │ │ API │ │ Cache │
|
||||
└──────────┘ └──────────┘ └──────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
project/
|
||||
├── frontend/
|
||||
│ └── src/
|
||||
│ ├── app/ # Next.js app router pages
|
||||
│ │ ├── api/ # API routes
|
||||
│ │ ├── (auth)/ # Auth-protected routes
|
||||
│ │ └── workspace/ # Main app workspace
|
||||
│ ├── components/ # React components
|
||||
│ │ ├── ui/ # Base UI components
|
||||
│ │ ├── forms/ # Form components
|
||||
│ │ └── layouts/ # Layout components
|
||||
│ ├── hooks/ # Custom React hooks
|
||||
│ ├── lib/ # Utilities
|
||||
│ ├── types/ # TypeScript definitions
|
||||
│ └── config/ # Configuration
|
||||
│
|
||||
├── backend/
|
||||
│ ├── routers/ # FastAPI route handlers
|
||||
│ ├── models.py # Pydantic models
|
||||
│ ├── main.py # FastAPI app entry
|
||||
│ ├── auth_system.py # Authentication
|
||||
│ ├── database.py # Database operations
|
||||
│ ├── services/ # Business logic
|
||||
│ └── tests/ # pytest tests
|
||||
│
|
||||
├── deploy/ # Deployment configs
|
||||
├── docs/ # Documentation
|
||||
└── scripts/ # Utility scripts
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Code Patterns
|
||||
|
||||
### API Response Format (FastAPI)
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel
|
||||
from typing import Generic, TypeVar, Optional
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
class ApiResponse(BaseModel, Generic[T]):
|
||||
success: bool
|
||||
data: Optional[T] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
@classmethod
|
||||
def ok(cls, data: T) -> "ApiResponse[T]":
|
||||
return cls(success=True, data=data)
|
||||
|
||||
@classmethod
|
||||
def fail(cls, error: str) -> "ApiResponse[T]":
|
||||
return cls(success=False, error=error)
|
||||
```
|
||||
|
||||
### Frontend API Calls (TypeScript)
|
||||
|
||||
```typescript
|
||||
interface ApiResponse<T> {
|
||||
success: boolean
|
||||
data?: T
|
||||
error?: string
|
||||
}
|
||||
|
||||
async function fetchApi<T>(
|
||||
endpoint: string,
|
||||
options?: RequestInit
|
||||
): Promise<ApiResponse<T>> {
|
||||
try {
|
||||
const response = await fetch(`/api${endpoint}`, {
|
||||
...options,
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...options?.headers,
|
||||
},
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
return { success: false, error: `HTTP ${response.status}` }
|
||||
}
|
||||
|
||||
return await response.json()
|
||||
} catch (error) {
|
||||
return { success: false, error: String(error) }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Claude AI Integration (Structured Output)
|
||||
|
||||
```python
|
||||
from anthropic import Anthropic
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AnalysisResult(BaseModel):
|
||||
summary: str
|
||||
key_points: list[str]
|
||||
confidence: float
|
||||
|
||||
async def analyze_with_claude(content: str) -> AnalysisResult:
|
||||
client = Anthropic()
|
||||
|
||||
response = client.messages.create(
|
||||
model="claude-sonnet-4-5-20250514",
|
||||
max_tokens=1024,
|
||||
messages=[{"role": "user", "content": content}],
|
||||
tools=[{
|
||||
"name": "provide_analysis",
|
||||
"description": "Provide structured analysis",
|
||||
"input_schema": AnalysisResult.model_json_schema()
|
||||
}],
|
||||
tool_choice={"type": "tool", "name": "provide_analysis"}
|
||||
)
|
||||
|
||||
# Extract tool use result
|
||||
tool_use = next(
|
||||
block for block in response.content
|
||||
if block.type == "tool_use"
|
||||
)
|
||||
|
||||
return AnalysisResult(**tool_use.input)
|
||||
```
|
||||
|
||||
### Custom Hooks (React)
|
||||
|
||||
```typescript
|
||||
import { useState, useCallback } from 'react'
|
||||
|
||||
interface UseApiState<T> {
|
||||
data: T | null
|
||||
loading: boolean
|
||||
error: string | null
|
||||
}
|
||||
|
||||
export function useApi<T>(
|
||||
fetchFn: () => Promise<ApiResponse<T>>
|
||||
) {
|
||||
const [state, setState] = useState<UseApiState<T>>({
|
||||
data: null,
|
||||
loading: false,
|
||||
error: null,
|
||||
})
|
||||
|
||||
const execute = useCallback(async () => {
|
||||
setState(prev => ({ ...prev, loading: true, error: null }))
|
||||
|
||||
const result = await fetchFn()
|
||||
|
||||
if (result.success) {
|
||||
setState({ data: result.data!, loading: false, error: null })
|
||||
} else {
|
||||
setState({ data: null, loading: false, error: result.error! })
|
||||
}
|
||||
}, [fetchFn])
|
||||
|
||||
return { ...state, execute }
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing Requirements
|
||||
|
||||
### Backend (pytest)
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
poetry run pytest tests/
|
||||
|
||||
# Run with coverage
|
||||
poetry run pytest tests/ --cov=. --cov-report=html
|
||||
|
||||
# Run specific test file
|
||||
poetry run pytest tests/test_auth.py -v
|
||||
```
|
||||
|
||||
**Test structure:**
|
||||
```python
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from main import app
|
||||
|
||||
@pytest.fixture
|
||||
async def client():
|
||||
async with AsyncClient(app=app, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(client: AsyncClient):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "healthy"
|
||||
```
|
||||
|
||||
### Frontend (React Testing Library)
|
||||
|
||||
```bash
|
||||
# Run tests
|
||||
npm run test
|
||||
|
||||
# Run with coverage
|
||||
npm run test -- --coverage
|
||||
|
||||
# Run E2E tests
|
||||
npm run test:e2e
|
||||
```
|
||||
|
||||
**Test structure:**
|
||||
```typescript
|
||||
import { render, screen, fireEvent } from '@testing-library/react'
|
||||
import { WorkspacePanel } from './WorkspacePanel'
|
||||
|
||||
describe('WorkspacePanel', () => {
|
||||
it('renders workspace correctly', () => {
|
||||
render(<WorkspacePanel />)
|
||||
expect(screen.getByRole('main')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles session creation', async () => {
|
||||
render(<WorkspacePanel />)
|
||||
fireEvent.click(screen.getByText('New Session'))
|
||||
expect(await screen.findByText('Session created')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Deployment Workflow
|
||||
|
||||
### Pre-Deployment Checklist
|
||||
|
||||
- [ ] All tests passing locally
|
||||
- [ ] `npm run build` succeeds (frontend)
|
||||
- [ ] `poetry run pytest` passes (backend)
|
||||
- [ ] No hardcoded secrets
|
||||
- [ ] Environment variables documented
|
||||
- [ ] Database migrations ready
|
||||
|
||||
### Deployment Commands
|
||||
|
||||
```bash
|
||||
# Build and deploy frontend
|
||||
cd frontend && npm run build
|
||||
gcloud run deploy frontend --source .
|
||||
|
||||
# Build and deploy backend
|
||||
cd backend
|
||||
gcloud run deploy backend --source .
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Frontend (.env.local)
|
||||
NEXT_PUBLIC_API_URL=https://api.example.com
|
||||
NEXT_PUBLIC_SUPABASE_URL=https://xxx.supabase.co
|
||||
NEXT_PUBLIC_SUPABASE_ANON_KEY=eyJ...
|
||||
|
||||
# Backend (.env)
|
||||
DATABASE_URL=postgresql://...
|
||||
ANTHROPIC_API_KEY=sk-ant-...
|
||||
SUPABASE_URL=https://xxx.supabase.co
|
||||
SUPABASE_KEY=eyJ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Critical Rules
|
||||
|
||||
1. **No emojis** in code, comments, or documentation
|
||||
2. **Immutability** - never mutate objects or arrays
|
||||
3. **TDD** - write tests before implementation
|
||||
4. **80% coverage** minimum
|
||||
5. **Many small files** - 200-400 lines typical, 800 max
|
||||
6. **No console.log** in production code
|
||||
7. **Proper error handling** with try/catch
|
||||
8. **Input validation** with Pydantic/Zod
|
||||
|
||||
---
|
||||
|
||||
## Related Skills
|
||||
|
||||
- `coding-standards.md` - General coding best practices
|
||||
- `backend-patterns.md` - API and database patterns
|
||||
- `frontend-patterns.md` - React and Next.js patterns
|
||||
- `tdd-workflow/` - Test-driven development methodology
|
||||
568
.opencode/skills/security-review/SKILL.md
Normal file
568
.opencode/skills/security-review/SKILL.md
Normal file
@@ -0,0 +1,568 @@
|
||||
---
|
||||
name: security-review
|
||||
description: Use this skill when adding authentication, handling user input, working with secrets, creating API endpoints, or implementing payment/sensitive features. Provides comprehensive security checklist and patterns.
|
||||
---
|
||||
|
||||
# Security Review Skill
|
||||
|
||||
Security best practices for Python/FastAPI applications handling sensitive invoice data.
|
||||
|
||||
## When to Activate
|
||||
|
||||
- Implementing authentication or authorization
|
||||
- Handling user input or file uploads
|
||||
- Creating new API endpoints
|
||||
- Working with secrets or credentials
|
||||
- Processing sensitive invoice data
|
||||
- Integrating third-party APIs
|
||||
- Database operations with user data
|
||||
|
||||
## Security Checklist
|
||||
|
||||
### 1. Secrets Management
|
||||
|
||||
#### NEVER Do This
|
||||
```python
|
||||
# Hardcoded secrets - CRITICAL VULNERABILITY
|
||||
api_key = "sk-proj-xxxxx"
|
||||
db_password = "password123"
|
||||
```
|
||||
|
||||
#### ALWAYS Do This
|
||||
```python
|
||||
import os
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
class Settings(BaseSettings):
|
||||
db_password: str
|
||||
api_key: str
|
||||
model_path: str = "runs/train/invoice_fields/weights/best.pt"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# Verify secrets exist
|
||||
if not settings.db_password:
|
||||
raise RuntimeError("DB_PASSWORD not configured")
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] No hardcoded API keys, tokens, or passwords
|
||||
- [ ] All secrets in environment variables
|
||||
- [ ] `.env` in .gitignore
|
||||
- [ ] No secrets in git history
|
||||
- [ ] `.env.example` with placeholder values
|
||||
|
||||
### 2. Input Validation
|
||||
|
||||
#### Always Validate User Input
|
||||
```python
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from fastapi import HTTPException
|
||||
import re
|
||||
|
||||
class InvoiceRequest(BaseModel):
|
||||
invoice_number: str = Field(..., min_length=1, max_length=50)
|
||||
amount: float = Field(..., gt=0, le=1_000_000)
|
||||
bankgiro: str | None = None
|
||||
|
||||
@field_validator("invoice_number")
|
||||
@classmethod
|
||||
def validate_invoice_number(cls, v: str) -> str:
|
||||
# Whitelist validation - only allow safe characters
|
||||
if not re.match(r"^[A-Za-z0-9\-_]+$", v):
|
||||
raise ValueError("Invalid invoice number format")
|
||||
return v
|
||||
|
||||
@field_validator("bankgiro")
|
||||
@classmethod
|
||||
def validate_bankgiro(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return None
|
||||
cleaned = re.sub(r"[^0-9]", "", v)
|
||||
if not (7 <= len(cleaned) <= 8):
|
||||
raise ValueError("Bankgiro must be 7-8 digits")
|
||||
return cleaned
|
||||
```
|
||||
|
||||
#### File Upload Validation
|
||||
```python
|
||||
from fastapi import UploadFile, HTTPException
|
||||
from pathlib import Path
|
||||
|
||||
ALLOWED_EXTENSIONS = {".pdf"}
|
||||
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
|
||||
async def validate_pdf_upload(file: UploadFile) -> bytes:
|
||||
"""Validate PDF upload with security checks."""
|
||||
# Extension check
|
||||
ext = Path(file.filename or "").suffix.lower()
|
||||
if ext not in ALLOWED_EXTENSIONS:
|
||||
raise HTTPException(400, f"Only PDF files allowed, got {ext}")
|
||||
|
||||
# Read content
|
||||
content = await file.read()
|
||||
|
||||
# Size check
|
||||
if len(content) > MAX_FILE_SIZE:
|
||||
raise HTTPException(400, f"File too large (max {MAX_FILE_SIZE // 1024 // 1024}MB)")
|
||||
|
||||
# Magic bytes check (PDF signature)
|
||||
if not content.startswith(b"%PDF"):
|
||||
raise HTTPException(400, "Invalid PDF file format")
|
||||
|
||||
return content
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] All user inputs validated with Pydantic
|
||||
- [ ] File uploads restricted (size, type, extension, magic bytes)
|
||||
- [ ] No direct use of user input in queries
|
||||
- [ ] Whitelist validation (not blacklist)
|
||||
- [ ] Error messages don't leak sensitive info
|
||||
|
||||
### 3. SQL Injection Prevention
|
||||
|
||||
#### NEVER Concatenate SQL
|
||||
```python
|
||||
# DANGEROUS - SQL Injection vulnerability
|
||||
query = f"SELECT * FROM documents WHERE id = '{user_input}'"
|
||||
cur.execute(query)
|
||||
```
|
||||
|
||||
#### ALWAYS Use Parameterized Queries
|
||||
```python
|
||||
import psycopg2
|
||||
|
||||
# Safe - parameterized query with %s placeholders
|
||||
cur.execute(
|
||||
"SELECT * FROM documents WHERE id = %s AND status = %s",
|
||||
(document_id, status)
|
||||
)
|
||||
|
||||
# Safe - named parameters
|
||||
cur.execute(
|
||||
"SELECT * FROM documents WHERE id = %(id)s",
|
||||
{"id": document_id}
|
||||
)
|
||||
|
||||
# Safe - psycopg2.sql for dynamic identifiers
|
||||
from psycopg2 import sql
|
||||
|
||||
cur.execute(
|
||||
sql.SQL("SELECT {} FROM {} WHERE id = %s").format(
|
||||
sql.Identifier("invoice_number"),
|
||||
sql.Identifier("documents")
|
||||
),
|
||||
(document_id,)
|
||||
)
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] All database queries use parameterized queries (%s or %(name)s)
|
||||
- [ ] No string concatenation or f-strings in SQL
|
||||
- [ ] psycopg2.sql module used for dynamic identifiers
|
||||
- [ ] No user input in table/column names
|
||||
|
||||
### 4. Path Traversal Prevention
|
||||
|
||||
#### NEVER Trust User Paths
|
||||
```python
|
||||
# DANGEROUS - Path traversal vulnerability
|
||||
filename = request.query_params.get("file")
|
||||
with open(f"/data/{filename}", "r") as f: # Attacker: ../../../etc/passwd
|
||||
return f.read()
|
||||
```
|
||||
|
||||
#### ALWAYS Validate Paths
|
||||
```python
|
||||
from pathlib import Path
|
||||
|
||||
ALLOWED_DIR = Path("/data/uploads").resolve()
|
||||
|
||||
def get_safe_path(filename: str) -> Path:
|
||||
"""Get safe file path, preventing path traversal."""
|
||||
# Remove any path components
|
||||
safe_name = Path(filename).name
|
||||
|
||||
# Validate filename characters
|
||||
if not re.match(r"^[A-Za-z0-9_\-\.]+$", safe_name):
|
||||
raise HTTPException(400, "Invalid filename")
|
||||
|
||||
# Resolve and verify within allowed directory
|
||||
full_path = (ALLOWED_DIR / safe_name).resolve()
|
||||
|
||||
if not full_path.is_relative_to(ALLOWED_DIR):
|
||||
raise HTTPException(400, "Invalid file path")
|
||||
|
||||
return full_path
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] User-provided filenames sanitized
|
||||
- [ ] Paths resolved and validated against allowed directory
|
||||
- [ ] No direct concatenation of user input into paths
|
||||
- [ ] Whitelist characters in filenames
|
||||
|
||||
### 5. Authentication & Authorization
|
||||
|
||||
#### API Key Validation
|
||||
```python
|
||||
from fastapi import Depends, HTTPException, Security
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
async def verify_api_key(api_key: str = Security(api_key_header)) -> str:
|
||||
if not api_key:
|
||||
raise HTTPException(401, "API key required")
|
||||
|
||||
# Constant-time comparison to prevent timing attacks
|
||||
import hmac
|
||||
if not hmac.compare_digest(api_key, settings.api_key):
|
||||
raise HTTPException(403, "Invalid API key")
|
||||
|
||||
return api_key
|
||||
|
||||
@router.post("/infer")
|
||||
async def infer(
|
||||
file: UploadFile,
|
||||
api_key: str = Depends(verify_api_key)
|
||||
):
|
||||
...
|
||||
```
|
||||
|
||||
#### Role-Based Access Control
|
||||
```python
|
||||
from enum import Enum
|
||||
|
||||
class UserRole(str, Enum):
|
||||
USER = "user"
|
||||
ADMIN = "admin"
|
||||
|
||||
def require_role(required_role: UserRole):
|
||||
async def role_checker(current_user: User = Depends(get_current_user)):
|
||||
if current_user.role != required_role:
|
||||
raise HTTPException(403, "Insufficient permissions")
|
||||
return current_user
|
||||
return role_checker
|
||||
|
||||
@router.delete("/documents/{doc_id}")
|
||||
async def delete_document(
|
||||
doc_id: str,
|
||||
user: User = Depends(require_role(UserRole.ADMIN))
|
||||
):
|
||||
...
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] API keys validated with constant-time comparison
|
||||
- [ ] Authorization checks before sensitive operations
|
||||
- [ ] Role-based access control implemented
|
||||
- [ ] Session/token validation on protected routes
|
||||
|
||||
### 6. Rate Limiting
|
||||
|
||||
#### Rate Limiter Implementation
|
||||
```python
|
||||
from time import time
|
||||
from collections import defaultdict
|
||||
from fastapi import Request, HTTPException
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self):
|
||||
self.requests: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
def check_limit(
|
||||
self,
|
||||
identifier: str,
|
||||
max_requests: int,
|
||||
window_seconds: int
|
||||
) -> bool:
|
||||
now = time()
|
||||
# Clean old requests
|
||||
self.requests[identifier] = [
|
||||
t for t in self.requests[identifier]
|
||||
if now - t < window_seconds
|
||||
]
|
||||
# Check limit
|
||||
if len(self.requests[identifier]) >= max_requests:
|
||||
return False
|
||||
self.requests[identifier].append(now)
|
||||
return True
|
||||
|
||||
limiter = RateLimiter()
|
||||
|
||||
@app.middleware("http")
|
||||
async def rate_limit_middleware(request: Request, call_next):
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
# 100 requests per minute for general endpoints
|
||||
if not limiter.check_limit(client_ip, max_requests=100, window_seconds=60):
|
||||
raise HTTPException(429, "Rate limit exceeded. Try again later.")
|
||||
|
||||
return await call_next(request)
|
||||
```
|
||||
|
||||
#### Stricter Limits for Expensive Operations
|
||||
```python
|
||||
# Inference endpoint: 10 requests per minute
|
||||
async def check_inference_rate_limit(request: Request):
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
if not limiter.check_limit(f"infer:{client_ip}", max_requests=10, window_seconds=60):
|
||||
raise HTTPException(429, "Inference rate limit exceeded")
|
||||
|
||||
@router.post("/infer")
|
||||
async def infer(
|
||||
file: UploadFile,
|
||||
_: None = Depends(check_inference_rate_limit)
|
||||
):
|
||||
...
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] Rate limiting on all API endpoints
|
||||
- [ ] Stricter limits on expensive operations (inference, OCR)
|
||||
- [ ] IP-based rate limiting
|
||||
- [ ] Clear error messages for rate-limited requests
|
||||
|
||||
### 7. Sensitive Data Exposure
|
||||
|
||||
#### Logging
|
||||
```python
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# WRONG: Logging sensitive data
|
||||
logger.info(f"Processing invoice: {invoice_data}") # May contain sensitive info
|
||||
logger.error(f"DB error with password: {db_password}")
|
||||
|
||||
# CORRECT: Redact sensitive data
|
||||
logger.info(f"Processing invoice: id={doc_id}")
|
||||
logger.error(f"DB connection failed to {db_host}:{db_port}")
|
||||
|
||||
# CORRECT: Structured logging with safe fields only
|
||||
logger.info(
|
||||
"Invoice processed",
|
||||
extra={
|
||||
"document_id": doc_id,
|
||||
"field_count": len(fields),
|
||||
"processing_time_ms": elapsed_ms
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
#### Error Messages
|
||||
```python
|
||||
# WRONG: Exposing internal details
|
||||
@app.exception_handler(Exception)
|
||||
async def error_handler(request: Request, exc: Exception):
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"error": str(exc),
|
||||
"traceback": traceback.format_exc() # NEVER expose!
|
||||
}
|
||||
)
|
||||
|
||||
# CORRECT: Generic error messages
|
||||
@app.exception_handler(Exception)
|
||||
async def error_handler(request: Request, exc: Exception):
|
||||
logger.error(f"Unhandled error: {exc}", exc_info=True) # Log internally
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"success": False, "error": "An error occurred"}
|
||||
)
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] No passwords, tokens, or secrets in logs
|
||||
- [ ] Error messages generic for users
|
||||
- [ ] Detailed errors only in server logs
|
||||
- [ ] No stack traces exposed to users
|
||||
- [ ] Invoice data (amounts, account numbers) not logged
|
||||
|
||||
### 8. CORS Configuration
|
||||
|
||||
```python
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
# WRONG: Allow all origins
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # DANGEROUS in production
|
||||
allow_credentials=True,
|
||||
)
|
||||
|
||||
# CORRECT: Specific origins
|
||||
ALLOWED_ORIGINS = [
|
||||
"http://localhost:8000",
|
||||
"https://your-domain.com",
|
||||
]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=ALLOWED_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] CORS origins explicitly listed
|
||||
- [ ] No wildcard origins in production
|
||||
- [ ] Credentials only with specific origins
|
||||
|
||||
### 9. Temporary File Security
|
||||
|
||||
```python
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def secure_temp_file(suffix: str = ".pdf"):
|
||||
"""Create secure temporary file that is always cleaned up."""
|
||||
tmp_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=suffix,
|
||||
delete=False,
|
||||
dir="/tmp/invoice-master" # Dedicated temp directory
|
||||
) as tmp:
|
||||
tmp_path = Path(tmp.name)
|
||||
yield tmp_path
|
||||
finally:
|
||||
if tmp_path and tmp_path.exists():
|
||||
tmp_path.unlink()
|
||||
|
||||
# Usage
|
||||
async def process_upload(file: UploadFile):
|
||||
with secure_temp_file(".pdf") as tmp_path:
|
||||
content = await validate_pdf_upload(file)
|
||||
tmp_path.write_bytes(content)
|
||||
result = pipeline.process(tmp_path)
|
||||
# File automatically cleaned up
|
||||
return result
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] Temporary files always cleaned up (use context managers)
|
||||
- [ ] Temp directory has restricted permissions
|
||||
- [ ] No leftover files after processing errors
|
||||
|
||||
### 10. Dependency Security
|
||||
|
||||
#### Regular Updates
|
||||
```bash
|
||||
# Check for vulnerabilities
|
||||
pip-audit
|
||||
|
||||
# Update dependencies
|
||||
pip install --upgrade -r requirements.txt
|
||||
|
||||
# Check for outdated packages
|
||||
pip list --outdated
|
||||
```
|
||||
|
||||
#### Lock Files
|
||||
```bash
|
||||
# Create requirements lock file
|
||||
pip freeze > requirements.lock
|
||||
|
||||
# Install from lock file for reproducible builds
|
||||
pip install -r requirements.lock
|
||||
```
|
||||
|
||||
#### Verification Steps
|
||||
- [ ] Dependencies up to date
|
||||
- [ ] No known vulnerabilities (pip-audit clean)
|
||||
- [ ] requirements.txt pinned versions
|
||||
- [ ] Regular security updates scheduled
|
||||
|
||||
## Security Testing
|
||||
|
||||
### Automated Security Tests
|
||||
```python
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
def test_requires_api_key(client: TestClient):
|
||||
"""Test authentication required."""
|
||||
response = client.post("/api/v1/infer")
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_invalid_api_key_rejected(client: TestClient):
|
||||
"""Test invalid API key rejected."""
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
headers={"X-API-Key": "invalid-key"}
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
def test_sql_injection_prevented(client: TestClient):
|
||||
"""Test SQL injection attempt rejected."""
|
||||
response = client.get(
|
||||
"/api/v1/documents",
|
||||
params={"id": "'; DROP TABLE documents; --"}
|
||||
)
|
||||
# Should return validation error, not execute SQL
|
||||
assert response.status_code in (400, 422)
|
||||
|
||||
def test_path_traversal_prevented(client: TestClient):
|
||||
"""Test path traversal attempt rejected."""
|
||||
response = client.get("/api/v1/results/../../etc/passwd")
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_rate_limit_enforced(client: TestClient):
|
||||
"""Test rate limiting works."""
|
||||
responses = [
|
||||
client.post("/api/v1/infer", files={"file": b"test"})
|
||||
for _ in range(15)
|
||||
]
|
||||
rate_limited = [r for r in responses if r.status_code == 429]
|
||||
assert len(rate_limited) > 0
|
||||
|
||||
def test_large_file_rejected(client: TestClient):
|
||||
"""Test file size limit enforced."""
|
||||
large_content = b"x" * (11 * 1024 * 1024) # 11MB
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("test.pdf", large_content)}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
```
|
||||
|
||||
## Pre-Deployment Security Checklist
|
||||
|
||||
Before ANY production deployment:
|
||||
|
||||
- [ ] **Secrets**: No hardcoded secrets, all in env vars
|
||||
- [ ] **Input Validation**: All user inputs validated with Pydantic
|
||||
- [ ] **SQL Injection**: All queries use parameterized queries
|
||||
- [ ] **Path Traversal**: File paths validated and sanitized
|
||||
- [ ] **Authentication**: API key or token validation
|
||||
- [ ] **Authorization**: Role checks in place
|
||||
- [ ] **Rate Limiting**: Enabled on all endpoints
|
||||
- [ ] **HTTPS**: Enforced in production
|
||||
- [ ] **CORS**: Properly configured (no wildcards)
|
||||
- [ ] **Error Handling**: No sensitive data in errors
|
||||
- [ ] **Logging**: No sensitive data logged
|
||||
- [ ] **File Uploads**: Validated (size, type, magic bytes)
|
||||
- [ ] **Temp Files**: Always cleaned up
|
||||
- [ ] **Dependencies**: Up to date, no vulnerabilities
|
||||
|
||||
## Resources
|
||||
|
||||
- [OWASP Top 10](https://owasp.org/www-project-top-ten/)
|
||||
- [FastAPI Security](https://fastapi.tiangolo.com/tutorial/security/)
|
||||
- [Bandit (Python Security Linter)](https://bandit.readthedocs.io/)
|
||||
- [pip-audit](https://pypi.org/project/pip-audit/)
|
||||
|
||||
---
|
||||
|
||||
**Remember**: Security is not optional. One vulnerability can compromise sensitive invoice data. When in doubt, err on the side of caution.
|
||||
63
.opencode/skills/strategic-compact/SKILL.md
Normal file
63
.opencode/skills/strategic-compact/SKILL.md
Normal file
@@ -0,0 +1,63 @@
|
||||
---
|
||||
name: strategic-compact
|
||||
description: Suggests manual context compaction at logical intervals to preserve context through task phases rather than arbitrary auto-compaction.
|
||||
---
|
||||
|
||||
# Strategic Compact Skill
|
||||
|
||||
Suggests manual `/compact` at strategic points in your workflow rather than relying on arbitrary auto-compaction.
|
||||
|
||||
## Why Strategic Compaction?
|
||||
|
||||
Auto-compaction triggers at arbitrary points:
|
||||
- Often mid-task, losing important context
|
||||
- No awareness of logical task boundaries
|
||||
- Can interrupt complex multi-step operations
|
||||
|
||||
Strategic compaction at logical boundaries:
|
||||
- **After exploration, before execution** - Compact research context, keep implementation plan
|
||||
- **After completing a milestone** - Fresh start for next phase
|
||||
- **Before major context shifts** - Clear exploration context before different task
|
||||
|
||||
## How It Works
|
||||
|
||||
The `suggest-compact.sh` script runs on PreToolUse (Edit/Write) and:
|
||||
|
||||
1. **Tracks tool calls** - Counts tool invocations in session
|
||||
2. **Threshold detection** - Suggests at configurable threshold (default: 50 calls)
|
||||
3. **Periodic reminders** - Reminds every 25 calls after threshold
|
||||
|
||||
## Hook Setup
|
||||
|
||||
Add to your `~/.claude/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"hooks": {
|
||||
"PreToolUse": [{
|
||||
"matcher": "tool == \"Edit\" || tool == \"Write\"",
|
||||
"hooks": [{
|
||||
"type": "command",
|
||||
"command": "~/.claude/skills/strategic-compact/suggest-compact.sh"
|
||||
}]
|
||||
}]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
Environment variables:
|
||||
- `COMPACT_THRESHOLD` - Tool calls before first suggestion (default: 50)
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Compact after planning** - Once plan is finalized, compact to start fresh
|
||||
2. **Compact after debugging** - Clear error-resolution context before continuing
|
||||
3. **Don't compact mid-implementation** - Preserve context for related changes
|
||||
4. **Read the suggestion** - The hook tells you *when*, you decide *if*
|
||||
|
||||
## Related
|
||||
|
||||
- [The Longform Guide](https://x.com/affaanmustafa/status/2014040193557471352) - Token optimization section
|
||||
- Memory persistence hooks - For state that survives compaction
|
||||
52
.opencode/skills/strategic-compact/suggest-compact.sh
Normal file
52
.opencode/skills/strategic-compact/suggest-compact.sh
Normal file
@@ -0,0 +1,52 @@
|
||||
#!/bin/bash
|
||||
# Strategic Compact Suggester
|
||||
# Runs on PreToolUse or periodically to suggest manual compaction at logical intervals
|
||||
#
|
||||
# Why manual over auto-compact:
|
||||
# - Auto-compact happens at arbitrary points, often mid-task
|
||||
# - Strategic compacting preserves context through logical phases
|
||||
# - Compact after exploration, before execution
|
||||
# - Compact after completing a milestone, before starting next
|
||||
#
|
||||
# Hook config (in ~/.claude/settings.json):
|
||||
# {
|
||||
# "hooks": {
|
||||
# "PreToolUse": [{
|
||||
# "matcher": "Edit|Write",
|
||||
# "hooks": [{
|
||||
# "type": "command",
|
||||
# "command": "~/.claude/skills/strategic-compact/suggest-compact.sh"
|
||||
# }]
|
||||
# }]
|
||||
# }
|
||||
# }
|
||||
#
|
||||
# Criteria for suggesting compact:
|
||||
# - Session has been running for extended period
|
||||
# - Large number of tool calls made
|
||||
# - Transitioning from research/exploration to implementation
|
||||
# - Plan has been finalized
|
||||
|
||||
# Track tool call count (increment in a temp file)
|
||||
COUNTER_FILE="/tmp/claude-tool-count-$$"
|
||||
THRESHOLD=${COMPACT_THRESHOLD:-50}
|
||||
|
||||
# Initialize or increment counter
|
||||
if [ -f "$COUNTER_FILE" ]; then
|
||||
count=$(cat "$COUNTER_FILE")
|
||||
count=$((count + 1))
|
||||
echo "$count" > "$COUNTER_FILE"
|
||||
else
|
||||
echo "1" > "$COUNTER_FILE"
|
||||
count=1
|
||||
fi
|
||||
|
||||
# Suggest compact after threshold tool calls
|
||||
if [ "$count" -eq "$THRESHOLD" ]; then
|
||||
echo "[StrategicCompact] $THRESHOLD tool calls reached - consider /compact if transitioning phases" >&2
|
||||
fi
|
||||
|
||||
# Suggest at regular intervals after threshold
|
||||
if [ "$count" -gt "$THRESHOLD" ] && [ $((count % 25)) -eq 0 ]; then
|
||||
echo "[StrategicCompact] $count tool calls - good checkpoint for /compact if context is stale" >&2
|
||||
fi
|
||||
553
.opencode/skills/tdd-workflow/SKILL.md
Normal file
553
.opencode/skills/tdd-workflow/SKILL.md
Normal file
@@ -0,0 +1,553 @@
|
||||
---
|
||||
name: tdd-workflow
|
||||
description: Use this skill when writing new features, fixing bugs, or refactoring code. Enforces test-driven development with 80%+ coverage including unit, integration, and E2E tests.
|
||||
---
|
||||
|
||||
# Test-Driven Development Workflow
|
||||
|
||||
TDD principles for Python/FastAPI development with pytest.
|
||||
|
||||
## When to Activate
|
||||
|
||||
- Writing new features or functionality
|
||||
- Fixing bugs or issues
|
||||
- Refactoring existing code
|
||||
- Adding API endpoints
|
||||
- Creating new field extractors or normalizers
|
||||
|
||||
## Core Principles
|
||||
|
||||
### 1. Tests BEFORE Code
|
||||
ALWAYS write tests first, then implement code to make tests pass.
|
||||
|
||||
### 2. Coverage Requirements
|
||||
- Minimum 80% coverage (unit + integration + E2E)
|
||||
- All edge cases covered
|
||||
- Error scenarios tested
|
||||
- Boundary conditions verified
|
||||
|
||||
### 3. Test Types
|
||||
|
||||
#### Unit Tests
|
||||
- Individual functions and utilities
|
||||
- Normalizers and validators
|
||||
- Parsers and extractors
|
||||
- Pure functions
|
||||
|
||||
#### Integration Tests
|
||||
- API endpoints
|
||||
- Database operations
|
||||
- OCR + YOLO pipeline
|
||||
- Service interactions
|
||||
|
||||
#### E2E Tests
|
||||
- Complete inference pipeline
|
||||
- PDF → Fields workflow
|
||||
- API health and inference endpoints
|
||||
|
||||
## TDD Workflow Steps
|
||||
|
||||
### Step 1: Write User Journeys
|
||||
```
|
||||
As a [role], I want to [action], so that [benefit]
|
||||
|
||||
Example:
|
||||
As an invoice processor, I want to extract Bankgiro from payment_line,
|
||||
so that I can cross-validate OCR results.
|
||||
```
|
||||
|
||||
### Step 2: Generate Test Cases
|
||||
For each user journey, create comprehensive test cases:
|
||||
|
||||
```python
|
||||
import pytest
|
||||
|
||||
class TestPaymentLineParser:
|
||||
"""Tests for payment_line parsing and field extraction."""
|
||||
|
||||
def test_parse_payment_line_extracts_bankgiro(self):
|
||||
"""Should extract Bankgiro from valid payment line."""
|
||||
# Test implementation
|
||||
pass
|
||||
|
||||
def test_parse_payment_line_handles_missing_checksum(self):
|
||||
"""Should handle payment lines without checksum."""
|
||||
pass
|
||||
|
||||
def test_parse_payment_line_validates_checksum(self):
|
||||
"""Should validate checksum when present."""
|
||||
pass
|
||||
|
||||
def test_parse_payment_line_returns_none_for_invalid(self):
|
||||
"""Should return None for invalid payment lines."""
|
||||
pass
|
||||
```
|
||||
|
||||
### Step 3: Run Tests (They Should Fail)
|
||||
```bash
|
||||
pytest tests/test_ocr/test_machine_code_parser.py -v
|
||||
# Tests should fail - we haven't implemented yet
|
||||
```
|
||||
|
||||
### Step 4: Implement Code
|
||||
Write minimal code to make tests pass:
|
||||
|
||||
```python
|
||||
def parse_payment_line(line: str) -> PaymentLineData | None:
|
||||
"""Parse Swedish payment line and extract fields."""
|
||||
# Implementation guided by tests
|
||||
pass
|
||||
```
|
||||
|
||||
### Step 5: Run Tests Again
|
||||
```bash
|
||||
pytest tests/test_ocr/test_machine_code_parser.py -v
|
||||
# Tests should now pass
|
||||
```
|
||||
|
||||
### Step 6: Refactor
|
||||
Improve code quality while keeping tests green:
|
||||
- Remove duplication
|
||||
- Improve naming
|
||||
- Optimize performance
|
||||
- Enhance readability
|
||||
|
||||
### Step 7: Verify Coverage
|
||||
```bash
|
||||
pytest --cov=src --cov-report=term-missing
|
||||
# Verify 80%+ coverage achieved
|
||||
```
|
||||
|
||||
## Testing Patterns
|
||||
|
||||
### Unit Test Pattern (pytest)
|
||||
```python
|
||||
import pytest
|
||||
from src.normalize.bankgiro_normalizer import normalize_bankgiro
|
||||
|
||||
class TestBankgiroNormalizer:
|
||||
"""Tests for Bankgiro normalization."""
|
||||
|
||||
def test_normalize_removes_hyphens(self):
|
||||
"""Should remove hyphens from Bankgiro."""
|
||||
result = normalize_bankgiro("123-4567")
|
||||
assert result == "1234567"
|
||||
|
||||
def test_normalize_removes_spaces(self):
|
||||
"""Should remove spaces from Bankgiro."""
|
||||
result = normalize_bankgiro("123 4567")
|
||||
assert result == "1234567"
|
||||
|
||||
def test_normalize_validates_length(self):
|
||||
"""Should validate Bankgiro is 7-8 digits."""
|
||||
result = normalize_bankgiro("123456") # 6 digits
|
||||
assert result is None
|
||||
|
||||
def test_normalize_validates_checksum(self):
|
||||
"""Should validate Luhn checksum."""
|
||||
result = normalize_bankgiro("1234568") # Invalid checksum
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.parametrize("input_value,expected", [
|
||||
("123-4567", "1234567"),
|
||||
("1234567", "1234567"),
|
||||
("123 4567", "1234567"),
|
||||
("BG 123-4567", "1234567"),
|
||||
])
|
||||
def test_normalize_various_formats(self, input_value, expected):
|
||||
"""Should handle various input formats."""
|
||||
result = normalize_bankgiro(input_value)
|
||||
assert result == expected
|
||||
```
|
||||
|
||||
### API Integration Test Pattern
|
||||
```python
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from src.web.app import app
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(app)
|
||||
|
||||
class TestHealthEndpoint:
|
||||
"""Tests for /api/v1/health endpoint."""
|
||||
|
||||
def test_health_returns_200(self, client):
|
||||
"""Should return 200 OK."""
|
||||
response = client.get("/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_health_returns_status(self, client):
|
||||
"""Should return health status."""
|
||||
response = client.get("/api/v1/health")
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "model_loaded" in data
|
||||
|
||||
class TestInferEndpoint:
|
||||
"""Tests for /api/v1/infer endpoint."""
|
||||
|
||||
def test_infer_requires_file(self, client):
|
||||
"""Should require file upload."""
|
||||
response = client.post("/api/v1/infer")
|
||||
assert response.status_code == 422
|
||||
|
||||
def test_infer_rejects_non_pdf(self, client):
|
||||
"""Should reject non-PDF files."""
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("test.txt", b"not a pdf", "text/plain")}
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_infer_returns_fields(self, client, sample_invoice_pdf):
|
||||
"""Should return extracted fields."""
|
||||
with open(sample_invoice_pdf, "rb") as f:
|
||||
response = client.post(
|
||||
"/api/v1/infer",
|
||||
files={"file": ("invoice.pdf", f, "application/pdf")}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "fields" in data
|
||||
```
|
||||
|
||||
### E2E Test Pattern
|
||||
```python
|
||||
import pytest
|
||||
import httpx
|
||||
from pathlib import Path
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def running_server():
|
||||
"""Ensure server is running for E2E tests."""
|
||||
# Server should be started before running E2E tests
|
||||
base_url = "http://localhost:8000"
|
||||
yield base_url
|
||||
|
||||
class TestInferencePipeline:
|
||||
"""E2E tests for complete inference pipeline."""
|
||||
|
||||
def test_health_check(self, running_server):
|
||||
"""Should pass health check."""
|
||||
response = httpx.get(f"{running_server}/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert data["model_loaded"] is True
|
||||
|
||||
def test_pdf_inference_returns_fields(self, running_server):
|
||||
"""Should extract fields from PDF."""
|
||||
pdf_path = Path("tests/fixtures/sample_invoice.pdf")
|
||||
with open(pdf_path, "rb") as f:
|
||||
response = httpx.post(
|
||||
f"{running_server}/api/v1/infer",
|
||||
files={"file": ("invoice.pdf", f, "application/pdf")}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert "fields" in data
|
||||
assert len(data["fields"]) > 0
|
||||
|
||||
def test_cross_validation_included(self, running_server):
|
||||
"""Should include cross-validation for invoices with payment_line."""
|
||||
pdf_path = Path("tests/fixtures/invoice_with_payment_line.pdf")
|
||||
with open(pdf_path, "rb") as f:
|
||||
response = httpx.post(
|
||||
f"{running_server}/api/v1/infer",
|
||||
files={"file": ("invoice.pdf", f, "application/pdf")}
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
if data["fields"].get("payment_line"):
|
||||
assert "cross_validation" in data
|
||||
```
|
||||
|
||||
## Test File Organization
|
||||
|
||||
```
|
||||
tests/
|
||||
├── conftest.py # Shared fixtures
|
||||
├── fixtures/ # Test data files
|
||||
│ ├── sample_invoice.pdf
|
||||
│ └── invoice_with_payment_line.pdf
|
||||
├── test_cli/
|
||||
│ └── test_infer.py
|
||||
├── test_pdf/
|
||||
│ ├── test_extractor.py
|
||||
│ └── test_renderer.py
|
||||
├── test_ocr/
|
||||
│ ├── test_paddle_ocr.py
|
||||
│ └── test_machine_code_parser.py
|
||||
├── test_inference/
|
||||
│ ├── test_pipeline.py
|
||||
│ ├── test_yolo_detector.py
|
||||
│ └── test_field_extractor.py
|
||||
├── test_normalize/
|
||||
│ ├── test_bankgiro_normalizer.py
|
||||
│ ├── test_date_normalizer.py
|
||||
│ └── test_amount_normalizer.py
|
||||
├── test_web/
|
||||
│ ├── test_routes.py
|
||||
│ └── test_services.py
|
||||
└── e2e/
|
||||
└── test_inference_e2e.py
|
||||
```
|
||||
|
||||
## Mocking External Services
|
||||
|
||||
### Mock PaddleOCR
|
||||
```python
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@pytest.fixture
|
||||
def mock_paddle_ocr():
|
||||
"""Mock PaddleOCR for unit tests."""
|
||||
with patch("src.ocr.paddle_ocr.PaddleOCR") as mock:
|
||||
instance = Mock()
|
||||
instance.ocr.return_value = [
|
||||
[
|
||||
[[[0, 0], [100, 0], [100, 20], [0, 20]], ("Invoice Number", 0.95)],
|
||||
[[[0, 30], [100, 30], [100, 50], [0, 50]], ("INV-2024-001", 0.98)]
|
||||
]
|
||||
]
|
||||
mock.return_value = instance
|
||||
yield instance
|
||||
```
|
||||
|
||||
### Mock YOLO Model
|
||||
```python
|
||||
@pytest.fixture
|
||||
def mock_yolo_model():
|
||||
"""Mock YOLO model for unit tests."""
|
||||
with patch("src.inference.yolo_detector.YOLO") as mock:
|
||||
instance = Mock()
|
||||
# Mock detection results
|
||||
instance.return_value = Mock(
|
||||
boxes=Mock(
|
||||
xyxy=[[10, 20, 100, 50]],
|
||||
conf=[0.95],
|
||||
cls=[0] # invoice_number class
|
||||
)
|
||||
)
|
||||
mock.return_value = instance
|
||||
yield instance
|
||||
```
|
||||
|
||||
### Mock Database
|
||||
```python
|
||||
@pytest.fixture
|
||||
def mock_db_connection():
|
||||
"""Mock database connection for unit tests."""
|
||||
with patch("src.data.db.get_db_connection") as mock:
|
||||
conn = Mock()
|
||||
cursor = Mock()
|
||||
cursor.fetchall.return_value = [
|
||||
("doc-123", "processed", {"invoice_number": "INV-001"})
|
||||
]
|
||||
cursor.fetchone.return_value = ("doc-123",)
|
||||
conn.cursor.return_value.__enter__ = Mock(return_value=cursor)
|
||||
conn.cursor.return_value.__exit__ = Mock(return_value=False)
|
||||
mock.return_value.__enter__ = Mock(return_value=conn)
|
||||
mock.return_value.__exit__ = Mock(return_value=False)
|
||||
yield conn
|
||||
```
|
||||
|
||||
## Test Coverage Verification
|
||||
|
||||
### Run Coverage Report
|
||||
```bash
|
||||
# Run with coverage
|
||||
pytest --cov=src --cov-report=term-missing
|
||||
|
||||
# Generate HTML report
|
||||
pytest --cov=src --cov-report=html
|
||||
# Open htmlcov/index.html in browser
|
||||
```
|
||||
|
||||
### Coverage Configuration (pyproject.toml)
|
||||
```toml
|
||||
[tool.coverage.run]
|
||||
source = ["src"]
|
||||
omit = ["*/__init__.py", "*/test_*.py"]
|
||||
|
||||
[tool.coverage.report]
|
||||
fail_under = 80
|
||||
show_missing = true
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"if TYPE_CHECKING:",
|
||||
"raise NotImplementedError",
|
||||
]
|
||||
```
|
||||
|
||||
## Common Testing Mistakes to Avoid
|
||||
|
||||
### WRONG: Testing Implementation Details
|
||||
```python
|
||||
# Don't test internal state
|
||||
def test_parser_internal_state():
|
||||
parser = PaymentLineParser()
|
||||
parser._parse("...")
|
||||
assert parser._groups == [...] # Internal state
|
||||
```
|
||||
|
||||
### CORRECT: Test Public Interface
|
||||
```python
|
||||
# Test what users see
|
||||
def test_parser_extracts_bankgiro():
|
||||
result = parse_payment_line("...")
|
||||
assert result.bankgiro == "1234567"
|
||||
```
|
||||
|
||||
### WRONG: No Test Isolation
|
||||
```python
|
||||
# Tests depend on each other
|
||||
class TestDocuments:
|
||||
def test_creates_document(self):
|
||||
create_document(...) # Creates in DB
|
||||
|
||||
def test_updates_document(self):
|
||||
update_document(...) # Depends on previous test
|
||||
```
|
||||
|
||||
### CORRECT: Independent Tests
|
||||
```python
|
||||
# Each test sets up its own data
|
||||
class TestDocuments:
|
||||
def test_creates_document(self, mock_db):
|
||||
result = create_document(...)
|
||||
assert result.id is not None
|
||||
|
||||
def test_updates_document(self, mock_db):
|
||||
# Create own test data
|
||||
doc = create_document(...)
|
||||
result = update_document(doc.id, ...)
|
||||
assert result.status == "updated"
|
||||
```
|
||||
|
||||
### WRONG: Testing Too Much
|
||||
```python
|
||||
# One test doing everything
|
||||
def test_full_invoice_processing():
|
||||
# Load PDF
|
||||
# Extract images
|
||||
# Run YOLO
|
||||
# Run OCR
|
||||
# Normalize fields
|
||||
# Save to DB
|
||||
# Return response
|
||||
```
|
||||
|
||||
### CORRECT: Focused Tests
|
||||
```python
|
||||
def test_yolo_detects_invoice_number():
|
||||
"""Test only YOLO detection."""
|
||||
result = detector.detect(image)
|
||||
assert any(d.label == "invoice_number" for d in result)
|
||||
|
||||
def test_ocr_extracts_text():
|
||||
"""Test only OCR extraction."""
|
||||
result = ocr.extract(image, bbox)
|
||||
assert result == "INV-2024-001"
|
||||
|
||||
def test_normalizer_formats_date():
|
||||
"""Test only date normalization."""
|
||||
result = normalize_date("2024-01-15")
|
||||
assert result == "2024-01-15"
|
||||
```
|
||||
|
||||
## Fixtures (conftest.py)
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
@pytest.fixture
|
||||
def sample_invoice_pdf(tmp_path: Path) -> Path:
|
||||
"""Create sample invoice PDF for testing."""
|
||||
pdf_path = tmp_path / "invoice.pdf"
|
||||
# Copy from fixtures or create minimal PDF
|
||||
src = Path("tests/fixtures/sample_invoice.pdf")
|
||||
if src.exists():
|
||||
pdf_path.write_bytes(src.read_bytes())
|
||||
return pdf_path
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""FastAPI test client."""
|
||||
from src.web.app import app
|
||||
return TestClient(app)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_payment_line() -> str:
|
||||
"""Sample Swedish payment line for testing."""
|
||||
return "1234567#0000000012345#230115#00012345678901234567#1"
|
||||
```
|
||||
|
||||
## Continuous Testing
|
||||
|
||||
### Watch Mode During Development
|
||||
```bash
|
||||
# Using pytest-watch
|
||||
ptw -- tests/test_ocr/
|
||||
# Tests run automatically on file changes
|
||||
```
|
||||
|
||||
### Pre-Commit Hook
|
||||
```bash
|
||||
# .pre-commit-config.yaml
|
||||
repos:
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: pytest
|
||||
name: pytest
|
||||
entry: pytest --tb=short -q
|
||||
language: system
|
||||
pass_filenames: false
|
||||
always_run: true
|
||||
```
|
||||
|
||||
### CI/CD Integration (GitHub Actions)
|
||||
```yaml
|
||||
- name: Run Tests
|
||||
run: |
|
||||
pytest --cov=src --cov-report=xml
|
||||
|
||||
- name: Upload Coverage
|
||||
uses: codecov/codecov-action@v3
|
||||
with:
|
||||
file: coverage.xml
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Write Tests First** - Always TDD
|
||||
2. **One Assert Per Test** - Focus on single behavior
|
||||
3. **Descriptive Test Names** - `test_<what>_<condition>_<expected>`
|
||||
4. **Arrange-Act-Assert** - Clear test structure
|
||||
5. **Mock External Dependencies** - Isolate unit tests
|
||||
6. **Test Edge Cases** - None, empty, invalid, boundary
|
||||
7. **Test Error Paths** - Not just happy paths
|
||||
8. **Keep Tests Fast** - Unit tests < 50ms each
|
||||
9. **Clean Up After Tests** - Use fixtures with cleanup
|
||||
10. **Review Coverage Reports** - Identify gaps
|
||||
|
||||
## Success Metrics
|
||||
|
||||
- 80%+ code coverage achieved
|
||||
- All tests passing (green)
|
||||
- No skipped or disabled tests
|
||||
- Fast test execution (< 60s for unit tests)
|
||||
- E2E tests cover critical inference flow
|
||||
- Tests catch bugs before production
|
||||
|
||||
---
|
||||
|
||||
**Remember**: Tests are not optional. They are the safety net that enables confident refactoring, rapid development, and production reliability.
|
||||
139
.opencode/skills/ui-prompt-generator/SKILL.md
Normal file
139
.opencode/skills/ui-prompt-generator/SKILL.md
Normal file
@@ -0,0 +1,139 @@
|
||||
---
|
||||
name: ui-prompt-generator
|
||||
description: 读取 Product-Spec.md 中的功能需求和 UI 布局,生成可用于 AI 绘图工具的原型图提示词。与 product-spec-builder 配套使用,帮助用户快速将需求文档转化为视觉原型。
|
||||
---
|
||||
|
||||
[角色]
|
||||
你是一位 UI/UX 设计专家,擅长将产品需求转化为精准的视觉描述。
|
||||
|
||||
你能够从结构化的产品文档中提取关键信息,并转化为 AI 绘图工具可以理解的提示词,帮助用户快速生成产品原型图。
|
||||
|
||||
[任务]
|
||||
读取 Product-Spec.md,提取功能需求和 UI 布局信息,补充必要的视觉参数,生成可直接用于文生图工具的原型图提示词。
|
||||
|
||||
最终输出按页面拆分的提示词,用户可以直接复制到 AI 绘图工具生成原型图。
|
||||
|
||||
[技能]
|
||||
- **文档解析**:从 Product-Spec.md 提取产品概述、功能需求、UI 布局、用户流程
|
||||
- **页面识别**:根据产品复杂度识别需要生成几个页面
|
||||
- **视觉转换**:将结构化的布局描述转化为视觉语言
|
||||
- **提示词生成**:输出高质量的英文文生图提示词
|
||||
|
||||
[文件结构]
|
||||
```
|
||||
ui-prompt-generator/
|
||||
├── SKILL.md # 主 Skill 定义(本文件)
|
||||
└── templates/
|
||||
└── ui-prompt-template.md # 提示词输出模板
|
||||
```
|
||||
|
||||
[总体规则]
|
||||
- 始终使用中文与用户交流
|
||||
- 提示词使用英文输出(AI 绘图工具英文效果更好)
|
||||
- 必须先读取 Product-Spec.md,不存在则提示用户先完成需求收集
|
||||
- 不重复追问 Product-Spec.md 里已有的信息
|
||||
- 用户不确定的信息,直接使用默认值继续推进
|
||||
- 按页面拆分生成提示词,每个页面一条提示词
|
||||
- 保持专业友好的语气
|
||||
|
||||
[视觉风格选项]
|
||||
| 风格 | 英文 | 说明 | 适用场景 |
|
||||
|------|------|------|---------|
|
||||
| 现代极简 | Minimalism | 简洁留白、干净利落 | 工具类、企业应用 |
|
||||
| 玻璃拟态 | Glassmorphism | 毛玻璃效果、半透明层叠 | 科技产品、仪表盘 |
|
||||
| 新拟态 | Neomorphism | 柔和阴影、微凸起效果 | 音乐播放器、控制面板 |
|
||||
| 便当盒布局 | Bento Grid | 模块化卡片、网格排列 | 数据展示、功能聚合页 |
|
||||
| 暗黑模式 | Dark Mode | 深色背景、低亮度护眼 | 开发工具、影音类 |
|
||||
| 新野兽派 | Neo-Brutalism | 粗黑边框、高对比、大胆配色 | 创意类、潮流品牌 |
|
||||
|
||||
**默认值**:现代极简(Minimalism)
|
||||
|
||||
[配色选项]
|
||||
| 选项 | 说明 |
|
||||
|------|------|
|
||||
| 浅色系 | 白色/浅灰背景,深色文字 |
|
||||
| 深色系 | 深色/黑色背景,浅色文字 |
|
||||
| 指定主色 | 用户指定品牌色或主题色 |
|
||||
|
||||
**默认值**:浅色系
|
||||
|
||||
[目标平台选项]
|
||||
| 选项 | 说明 |
|
||||
|------|------|
|
||||
| 桌面端 | Desktop application,宽屏布局 |
|
||||
| 网页 | Web application,响应式布局 |
|
||||
| 移动端 | Mobile application,竖屏布局 |
|
||||
|
||||
**默认值**:网页
|
||||
|
||||
[工作流程]
|
||||
[启动阶段]
|
||||
目的:读取 Product-Spec.md,提取信息,补充缺失的视觉参数
|
||||
|
||||
第一步:检测文件
|
||||
检测项目目录中是否存在 Product-Spec.md
|
||||
不存在 → 提示:「未找到 Product-Spec.md,请先使用 /prd 完成需求收集。」,终止流程
|
||||
存在 → 继续
|
||||
|
||||
第二步:解析 Product-Spec.md
|
||||
读取 Product-Spec.md 文件内容
|
||||
提取以下信息:
|
||||
- 产品概述:了解产品是什么
|
||||
- 功能需求:了解有哪些功能
|
||||
- UI 布局:了解界面结构和控件
|
||||
- 用户流程:了解有哪些页面和状态
|
||||
- 视觉风格(如果文档里提到了)
|
||||
- 配色方案(如果文档里提到了)
|
||||
- 目标平台(如果文档里提到了)
|
||||
|
||||
第三步:识别页面
|
||||
根据 UI 布局和用户流程,识别产品包含几个页面
|
||||
|
||||
判断逻辑:
|
||||
- 只有一个主界面 → 单页面产品
|
||||
- 有多个界面(如:主界面、设置页、详情页)→ 多页面产品
|
||||
- 有明显的多步骤流程 → 按步骤拆分页面
|
||||
|
||||
输出页面清单:
|
||||
"📄 **识别到以下页面:**
|
||||
1. [页面名称]:[简要描述]
|
||||
2. [页面名称]:[简要描述]
|
||||
..."
|
||||
|
||||
第四步:补充缺失的视觉参数
|
||||
检查是否已提取到:视觉风格、配色方案、目标平台
|
||||
|
||||
全部已有 → 跳过提问,直接进入提示词生成阶段
|
||||
有缺失项 → 只针对缺失项询问用户:
|
||||
|
||||
"🎨 **还需要确认几个视觉参数:**
|
||||
|
||||
[只列出缺失的项目,已有的不列]
|
||||
|
||||
直接回复你的选择,或回复「默认」使用默认值。"
|
||||
|
||||
用户回复后解析选择
|
||||
用户不确定或回复「默认」→ 使用默认值
|
||||
|
||||
[提示词生成阶段]
|
||||
目的:为每个页面生成提示词
|
||||
|
||||
第一步:准备生成参数
|
||||
整合所有信息:
|
||||
- 产品类型(从产品概述提取)
|
||||
- 页面列表(从启动阶段获取)
|
||||
- 每个页面的布局和控件(从 UI 布局提取)
|
||||
- 视觉风格(从 Product-Spec.md 提取或用户选择)
|
||||
- 配色方案(从 Product-Spec.md 提取或用户选择)
|
||||
- 目标平台(从 Product-Spec.md 提取或用户选择)
|
||||
|
||||
第二步:按页面生成提示词
|
||||
加载 templates/ui-prompt-template.md 获取提示词结构和输出格式
|
||||
为每个页面生成一条英文提示词
|
||||
按模板中的提示词结构组织内容
|
||||
|
||||
第三步:输出文件
|
||||
将生成的提示词保存为 UI-Prompts.md
|
||||
|
||||
[初始化]
|
||||
执行 [启动阶段]
|
||||
@@ -0,0 +1,154 @@
|
||||
---
|
||||
name: ui-prompt-template
|
||||
description: UI 原型图提示词输出模板。当需要生成文生图提示词时,按照此模板的结构和格式填充内容,输出为 UI-Prompts.md 文件。
|
||||
---
|
||||
|
||||
# UI 原型图提示词模板
|
||||
|
||||
本模板用于生成可直接用于 AI 绘图工具的原型图提示词。生成时按照此结构填充内容。
|
||||
|
||||
---
|
||||
|
||||
## 文件命名
|
||||
|
||||
`UI-Prompts.md`
|
||||
|
||||
---
|
||||
|
||||
## 提示词结构
|
||||
|
||||
每条提示词按以下结构组织:
|
||||
|
||||
```
|
||||
[主体] + [布局] + [控件] + [风格] + [质量词]
|
||||
```
|
||||
|
||||
### [主体]
|
||||
产品类型 + 界面类型 + 页面名称
|
||||
|
||||
示例:
|
||||
- `A modern web application UI for a storyboard generator tool, main interface`
|
||||
- `A mobile app screen for a task management application, settings page`
|
||||
|
||||
### [布局]
|
||||
整体结构 + 比例 + 区域划分
|
||||
|
||||
示例:
|
||||
- `split layout with left panel (40%) and right content area (60%)`
|
||||
- `single column layout with top navigation bar and main content below`
|
||||
- `grid layout with 2x2 card arrangement`
|
||||
|
||||
### [控件]
|
||||
各区域的具体控件,从上到下、从左到右描述
|
||||
|
||||
示例:
|
||||
- `left panel contains: project name input at top, large text area for content, dropdown menu for style selection, primary action button at bottom`
|
||||
- `right panel shows: 3x3 grid of image cards with frame numbers and captions, action buttons below`
|
||||
|
||||
### [风格]
|
||||
视觉风格 + 配色 + 细节特征
|
||||
|
||||
| 风格 | 英文描述 |
|
||||
|------|---------|
|
||||
| 现代极简 | minimalist design, clean layout, ample white space, subtle shadows |
|
||||
| 玻璃拟态 | glassmorphism style, frosted glass effect, translucent panels, blur background |
|
||||
| 新拟态 | neumorphism design, soft shadows, subtle highlights, extruded elements |
|
||||
| 便当盒布局 | bento grid layout, modular cards, organized sections, clean borders |
|
||||
| 暗黑模式 | dark mode UI, dark background, light text, subtle glow effects |
|
||||
| 新野兽派 | neo-brutalist design, bold black borders, high contrast, raw aesthetic |
|
||||
|
||||
配色描述:
|
||||
- 浅色系:`light color scheme, white background, dark text, [accent color] accent`
|
||||
- 深色系:`dark color scheme, dark gray background, light text, [accent color] accent`
|
||||
|
||||
### [质量词]
|
||||
确保生成质量的关键词,放在提示词末尾
|
||||
|
||||
```
|
||||
UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style, dribbble, behance
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 输出格式
|
||||
|
||||
```markdown
|
||||
# [产品名称] 原型图提示词
|
||||
|
||||
> 视觉风格:[风格名称]
|
||||
> 配色方案:[配色名称]
|
||||
> 目标平台:[平台名称]
|
||||
|
||||
---
|
||||
|
||||
## 页面 1:[页面名称]
|
||||
|
||||
**页面说明**:[一句话描述这个页面是什么]
|
||||
|
||||
**提示词**:
|
||||
```
|
||||
[完整的英文提示词]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 页面 2:[页面名称]
|
||||
|
||||
**页面说明**:[一句话描述]
|
||||
|
||||
**提示词**:
|
||||
```
|
||||
[完整的英文提示词]
|
||||
```
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 完整示例
|
||||
|
||||
以下是「剧本分镜生成器」的原型图提示词示例,供参考:
|
||||
|
||||
```markdown
|
||||
# 剧本分镜生成器 原型图提示词
|
||||
|
||||
> 视觉风格:现代极简(Minimalism)
|
||||
> 配色方案:浅色系
|
||||
> 目标平台:网页(Web)
|
||||
|
||||
---
|
||||
|
||||
## 页面 1:主界面
|
||||
|
||||
**页面说明**:用户输入剧本、设置角色和场景、生成分镜图的主要工作界面
|
||||
|
||||
**提示词**:
|
||||
```
|
||||
A modern web application UI for a storyboard generator tool, main interface, split layout with left input panel (40% width) and right output area (60% width), left panel contains: project name input field at top, large multiline text area for script input with placeholder text, character cards section with image thumbnails and text fields and add button, scene cards section below, style dropdown menu, prominent generate button at bottom, right panel shows: 3x3 grid of storyboard image cards with frame numbers and short descriptions below each image, download all button and continue generating button below the grid, page navigation at bottom, minimalist design, clean layout, white background, light gray borders, blue accent color for primary actions, subtle shadows, rounded corners, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 页面 2:空状态界面
|
||||
|
||||
**页面说明**:用户首次打开、尚未输入内容时的引导界面
|
||||
|
||||
**提示词**:
|
||||
```
|
||||
A modern web application UI for a storyboard generator tool, empty state screen, split layout with left panel (40%) and right panel (60%), left panel shows: empty input fields with placeholder text and helper icons, right panel displays: large empty state illustration in the center, welcome message and getting started tips below, minimalist design, clean layout, white background, soft gray placeholder elements, blue accent color, friendly and inviting atmosphere, UI/UX design, high fidelity mockup, 4K resolution, professional, Figma style
|
||||
```
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 写作要点
|
||||
|
||||
1. **提示词语言**:始终使用英文,AI 绘图工具对英文理解更好
|
||||
2. **结构完整**:确保包含主体、布局、控件、风格、质量词五个部分
|
||||
3. **控件描述**:
|
||||
- 按空间顺序描述(上到下、左到右)
|
||||
- 具体到控件类型(input field, button, dropdown, card)
|
||||
- 包含控件状态(placeholder text, selected state)
|
||||
4. **布局比例**:写明具体比例(40%/60%),不要只说「左右布局」
|
||||
5. **风格一致**:同一产品的多个页面使用相同的风格描述
|
||||
6. **质量词**:始终在末尾加上质量词确保生成效果
|
||||
7. **页面说明**:用中文写一句话说明,帮助理解这个页面是什么
|
||||
242
.opencode/skills/verification-loop/SKILL.md
Normal file
242
.opencode/skills/verification-loop/SKILL.md
Normal file
@@ -0,0 +1,242 @@
|
||||
# Verification Loop Skill
|
||||
|
||||
Comprehensive verification system for Python/FastAPI development.
|
||||
|
||||
## When to Use
|
||||
|
||||
Invoke this skill:
|
||||
- After completing a feature or significant code change
|
||||
- Before creating a PR
|
||||
- When you want to ensure quality gates pass
|
||||
- After refactoring
|
||||
- Before deployment
|
||||
|
||||
## Verification Phases
|
||||
|
||||
### Phase 1: Type Check
|
||||
```bash
|
||||
# Run mypy type checker
|
||||
mypy src/ --ignore-missing-imports 2>&1 | head -30
|
||||
```
|
||||
|
||||
Report all type errors. Fix critical ones before continuing.
|
||||
|
||||
### Phase 2: Lint Check
|
||||
```bash
|
||||
# Run ruff linter
|
||||
ruff check src/ 2>&1 | head -30
|
||||
|
||||
# Auto-fix if desired
|
||||
ruff check src/ --fix
|
||||
```
|
||||
|
||||
Check for:
|
||||
- Unused imports
|
||||
- Code style violations
|
||||
- Common Python anti-patterns
|
||||
|
||||
### Phase 3: Test Suite
|
||||
```bash
|
||||
# Run tests with coverage
|
||||
pytest --cov=src --cov-report=term-missing -q 2>&1 | tail -50
|
||||
|
||||
# Run specific test file
|
||||
pytest tests/test_ocr/test_machine_code_parser.py -v
|
||||
|
||||
# Run with short traceback
|
||||
pytest -x --tb=short
|
||||
```
|
||||
|
||||
Report:
|
||||
- Total tests: X
|
||||
- Passed: X
|
||||
- Failed: X
|
||||
- Coverage: X%
|
||||
- Target: 80% minimum
|
||||
|
||||
### Phase 4: Security Scan
|
||||
```bash
|
||||
# Check for hardcoded secrets
|
||||
grep -rn "password\s*=" --include="*.py" src/ 2>/dev/null | grep -v "db_password:" | head -10
|
||||
grep -rn "api_key\s*=" --include="*.py" src/ 2>/dev/null | head -10
|
||||
grep -rn "sk-" --include="*.py" src/ 2>/dev/null | head -10
|
||||
|
||||
# Check for print statements (should use logging)
|
||||
grep -rn "print(" --include="*.py" src/ 2>/dev/null | head -10
|
||||
|
||||
# Check for bare except
|
||||
grep -rn "except:" --include="*.py" src/ 2>/dev/null | head -10
|
||||
|
||||
# Check for SQL injection risks (f-strings in execute)
|
||||
grep -rn 'execute(f"' --include="*.py" src/ 2>/dev/null | head -10
|
||||
grep -rn "execute(f'" --include="*.py" src/ 2>/dev/null | head -10
|
||||
```
|
||||
|
||||
### Phase 5: Import Check
|
||||
```bash
|
||||
# Verify all imports work
|
||||
python -c "from src.web.app import app; print('Web app OK')"
|
||||
python -c "from src.inference.pipeline import InferencePipeline; print('Pipeline OK')"
|
||||
python -c "from src.ocr.machine_code_parser import parse_payment_line; print('Parser OK')"
|
||||
```
|
||||
|
||||
### Phase 6: Diff Review
|
||||
```bash
|
||||
# Show what changed
|
||||
git diff --stat
|
||||
git diff HEAD --name-only
|
||||
|
||||
# Show staged changes
|
||||
git diff --staged --stat
|
||||
```
|
||||
|
||||
Review each changed file for:
|
||||
- Unintended changes
|
||||
- Missing error handling
|
||||
- Potential edge cases
|
||||
- Missing type hints
|
||||
- Mutable default arguments
|
||||
|
||||
### Phase 7: API Smoke Test (if server running)
|
||||
```bash
|
||||
# Health check
|
||||
curl -s http://localhost:8000/api/v1/health | python -m json.tool
|
||||
|
||||
# Verify response format
|
||||
curl -s http://localhost:8000/api/v1/health | grep -q "healthy" && echo "Health: OK" || echo "Health: FAIL"
|
||||
```
|
||||
|
||||
## Output Format
|
||||
|
||||
After running all phases, produce a verification report:
|
||||
|
||||
```
|
||||
VERIFICATION REPORT
|
||||
==================
|
||||
|
||||
Types: [PASS/FAIL] (X errors)
|
||||
Lint: [PASS/FAIL] (X warnings)
|
||||
Tests: [PASS/FAIL] (X/Y passed, Z% coverage)
|
||||
Security: [PASS/FAIL] (X issues)
|
||||
Imports: [PASS/FAIL]
|
||||
Diff: [X files changed]
|
||||
|
||||
Overall: [READY/NOT READY] for PR
|
||||
|
||||
Issues to Fix:
|
||||
1. ...
|
||||
2. ...
|
||||
```
|
||||
|
||||
## Quick Commands
|
||||
|
||||
```bash
|
||||
# Full verification (WSL)
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && mypy src/ --ignore-missing-imports && ruff check src/ && pytest -x --tb=short"
|
||||
|
||||
# Type check only
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && mypy src/ --ignore-missing-imports"
|
||||
|
||||
# Tests only
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && pytest --cov=src -q"
|
||||
```
|
||||
|
||||
## Verification Checklist
|
||||
|
||||
### Before Commit
|
||||
- [ ] mypy passes (no type errors)
|
||||
- [ ] ruff check passes (no lint errors)
|
||||
- [ ] All tests pass
|
||||
- [ ] No print() statements in production code
|
||||
- [ ] No hardcoded secrets
|
||||
- [ ] No bare `except:` clauses
|
||||
- [ ] No SQL injection risks (f-strings in queries)
|
||||
- [ ] Coverage >= 80% for changed code
|
||||
|
||||
### Before PR
|
||||
- [ ] All above checks pass
|
||||
- [ ] git diff reviewed for unintended changes
|
||||
- [ ] New code has tests
|
||||
- [ ] Type hints on all public functions
|
||||
- [ ] Docstrings on public APIs
|
||||
- [ ] No TODO/FIXME for critical items
|
||||
|
||||
### Before Deployment
|
||||
- [ ] All above checks pass
|
||||
- [ ] E2E tests pass
|
||||
- [ ] Health check returns healthy
|
||||
- [ ] Model loaded successfully
|
||||
- [ ] No server errors in logs
|
||||
|
||||
## Common Issues and Fixes
|
||||
|
||||
### Type Error: Missing return type
|
||||
```python
|
||||
# Before
|
||||
def process(data):
|
||||
return result
|
||||
|
||||
# After
|
||||
def process(data: dict) -> InferenceResult:
|
||||
return result
|
||||
```
|
||||
|
||||
### Lint Error: Unused import
|
||||
```python
|
||||
# Remove unused imports or add to __all__
|
||||
```
|
||||
|
||||
### Security: print() in production
|
||||
```python
|
||||
# Before
|
||||
print(f"Processing {doc_id}")
|
||||
|
||||
# After
|
||||
logger.info(f"Processing {doc_id}")
|
||||
```
|
||||
|
||||
### Security: Bare except
|
||||
```python
|
||||
# Before
|
||||
except:
|
||||
pass
|
||||
|
||||
# After
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
raise
|
||||
```
|
||||
|
||||
### Security: SQL injection
|
||||
```python
|
||||
# Before (DANGEROUS)
|
||||
cur.execute(f"SELECT * FROM docs WHERE id = '{user_input}'")
|
||||
|
||||
# After (SAFE)
|
||||
cur.execute("SELECT * FROM docs WHERE id = %s", (user_input,))
|
||||
```
|
||||
|
||||
## Continuous Mode
|
||||
|
||||
For long sessions, run verification after major changes:
|
||||
|
||||
```markdown
|
||||
Checkpoints:
|
||||
- After completing each function
|
||||
- After finishing a module
|
||||
- Before moving to next task
|
||||
- Every 15-20 minutes of coding
|
||||
|
||||
Run: /verify
|
||||
```
|
||||
|
||||
## Integration with Other Skills
|
||||
|
||||
| Skill | Purpose |
|
||||
|-------|---------|
|
||||
| code-review | Detailed code analysis |
|
||||
| security-review | Deep security audit |
|
||||
| tdd-workflow | Test coverage |
|
||||
| build-fix | Fix errors incrementally |
|
||||
|
||||
This skill provides quick, comprehensive verification. Use specialized skills for deeper analysis.
|
||||
226
AGENTS.md
226
AGENTS.md
@@ -1,179 +1,93 @@
|
||||
# AGENTS.md - Coding Guidelines for AI Agents
|
||||
# Invoice Master POC v2
|
||||
|
||||
## Build / Test / Lint Commands
|
||||
Swedish Invoice Field Extraction System - YOLOv11 + PaddleOCR 从瑞典 PDF 发票中提取结构化数据。
|
||||
|
||||
## Tech Stack
|
||||
|
||||
| Component | Technology |
|
||||
|-----------|------------|
|
||||
| Object Detection | YOLOv11 (Ultralytics) |
|
||||
| OCR Engine | PaddleOCR v5 (PP-OCRv5) |
|
||||
| PDF Processing | PyMuPDF (fitz) |
|
||||
| Database | PostgreSQL + psycopg2 |
|
||||
| Web Framework | FastAPI + Uvicorn |
|
||||
| Deep Learning | PyTorch + CUDA 12.x |
|
||||
|
||||
## WSL Environment (REQUIRED)
|
||||
|
||||
**Prefix ALL commands with:**
|
||||
|
||||
### Python Backend
|
||||
```bash
|
||||
# Install packages (editable mode)
|
||||
pip install -e packages/shared
|
||||
pip install -e packages/training
|
||||
pip install -e packages/backend
|
||||
|
||||
# Run all tests
|
||||
DB_PASSWORD=xxx pytest tests/ -q
|
||||
|
||||
# Run single test file
|
||||
DB_PASSWORD=xxx pytest tests/path/to/test_file.py -v
|
||||
|
||||
# Run with coverage
|
||||
DB_PASSWORD=xxx pytest tests/ --cov=packages --cov-report=term-missing
|
||||
|
||||
# Format code
|
||||
black packages/ tests/
|
||||
ruff check packages/ tests/
|
||||
|
||||
# Type checking
|
||||
mypy packages/
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && <command>"
|
||||
```
|
||||
|
||||
### Frontend
|
||||
```bash
|
||||
cd frontend
|
||||
**NEVER run Python commands directly in Windows PowerShell/CMD.**
|
||||
|
||||
# Install dependencies
|
||||
npm install
|
||||
## Project-Specific Rules
|
||||
|
||||
# Development server
|
||||
npm run dev
|
||||
- Python 3.11+ with type hints
|
||||
- No print() in production - use logging
|
||||
- Run tests: `pytest --cov=src`
|
||||
|
||||
# Build
|
||||
npm run build
|
||||
## Critical Rules
|
||||
|
||||
# Run tests
|
||||
npm run test
|
||||
### Code Organization
|
||||
|
||||
# Run single test
|
||||
npx vitest run src/path/to/file.test.ts
|
||||
- Many small files over few large files
|
||||
- High cohesion, low coupling
|
||||
- 200-400 lines typical, 800 max per file
|
||||
- Organize by feature/domain, not by type
|
||||
|
||||
# Watch mode
|
||||
npm run test:watch
|
||||
### Code Style
|
||||
|
||||
# Coverage
|
||||
npm run test:coverage
|
||||
```
|
||||
- No emojis in code, comments, or documentation
|
||||
- Immutability always - never mutate objects or arrays
|
||||
- No console.log in production code
|
||||
- Proper error handling with try/catch
|
||||
- Input validation with Zod or similar
|
||||
|
||||
## Code Style Guidelines
|
||||
### Testing
|
||||
|
||||
### Python
|
||||
- TDD: Write tests first
|
||||
- 80% minimum coverage
|
||||
- Unit tests for utilities
|
||||
- Integration tests for APIs
|
||||
- E2E tests for critical flows
|
||||
|
||||
**Imports:**
|
||||
- Use absolute imports within packages: `from shared.pdf.extractor import PDFDocument`
|
||||
- Group imports: stdlib → third-party → local (separated by blank lines)
|
||||
- Use `from __future__ import annotations` for forward references when needed
|
||||
### Security
|
||||
|
||||
**Type Hints:**
|
||||
- All functions must have type hints (enforced by mypy)
|
||||
- Use `| None` instead of `Optional[...]` (Python 3.10+)
|
||||
- Use `list[str]` instead of `List[str]` (Python 3.10+)
|
||||
|
||||
**Naming:**
|
||||
- Classes: `PascalCase` (e.g., `PDFDocument`, `InferencePipeline`)
|
||||
- Functions/variables: `snake_case` (e.g., `extract_text`, `get_db_connection`)
|
||||
- Constants: `UPPER_SNAKE_CASE` (e.g., `DEFAULT_DPI`, `DATABASE`)
|
||||
- Private: `_leading_underscore` for internal use
|
||||
|
||||
**Error Handling:**
|
||||
- Use custom exceptions from `shared.exceptions`
|
||||
- Base exception: `InvoiceExtractionError`
|
||||
- Specific exceptions: `PDFProcessingError`, `OCRError`, `DatabaseError`, etc.
|
||||
- Always include context in exceptions via `details` dict
|
||||
|
||||
**Docstrings:**
|
||||
- Use Google-style docstrings
|
||||
- All public functions/classes must have docstrings
|
||||
- Include Args/Returns sections for complex functions
|
||||
|
||||
**Code Organization:**
|
||||
- Maximum line length: 100 characters (black config)
|
||||
- Target Python: 3.10+
|
||||
- Keep files under 800 lines, ideally 200-400 lines
|
||||
|
||||
### TypeScript / React Frontend
|
||||
|
||||
**Imports:**
|
||||
- Use path alias `@/` for project imports: `import { Button } from '@/components/Button'`
|
||||
- Group: React → third-party → local (@/) → relative
|
||||
|
||||
**Naming:**
|
||||
- Components: `PascalCase` (e.g., `Dashboard.tsx`, `InferenceDemo.tsx`)
|
||||
- Hooks: `camelCase` with `use` prefix (e.g., `useDocuments.ts`)
|
||||
- Types/Interfaces: `PascalCase` (e.g., `DocumentListResponse`)
|
||||
- API endpoints: `camelCase` (e.g., `documentsApi`)
|
||||
|
||||
**TypeScript:**
|
||||
- Strict mode enabled
|
||||
- Use explicit return types on exported functions
|
||||
- Prefer `type` over `interface` for simple shapes
|
||||
- Use enums for fixed sets of values
|
||||
|
||||
**React Patterns:**
|
||||
- Functional components with hooks
|
||||
- Use React Query for server state
|
||||
- Use Zustand for client state (if needed)
|
||||
- Props interfaces named `{ComponentName}Props`
|
||||
|
||||
**Styling:**
|
||||
- Use Tailwind CSS exclusively
|
||||
- Custom colors: `warm-*` theme (e.g., `bg-warm-text-secondary`)
|
||||
- Component variants defined as objects (see Button.tsx pattern)
|
||||
|
||||
**Testing:**
|
||||
- Use Vitest + React Testing Library
|
||||
- Test files: `{name}.test.ts` or `{name}.test.tsx`
|
||||
- Co-locate tests with source files when possible
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
packages/
|
||||
shared/ # Shared utilities (PDF, OCR, storage, config)
|
||||
training/ # Training service (GPU, CLI commands)
|
||||
backend/ # Web API + inference (FastAPI)
|
||||
frontend/ # React + TypeScript + Vite
|
||||
tests/ # Test suite
|
||||
migrations/ # Database SQL migrations
|
||||
```
|
||||
|
||||
## Key Configuration
|
||||
|
||||
- **DPI:** 150 (must match between training and inference)
|
||||
- **Database:** PostgreSQL (configured via env vars)
|
||||
- **Storage:** Abstracted (Local/Azure/S3 via storage.yaml)
|
||||
- **Python:** 3.10+ (3.11 recommended, 3.10 for RTX 50 series)
|
||||
- No hardcoded secrets
|
||||
- Environment variables for sensitive data
|
||||
- Validate all user inputs
|
||||
- Parameterized queries only
|
||||
- CSRF protection enabled
|
||||
|
||||
## Environment Variables
|
||||
|
||||
Required: `DB_PASSWORD`
|
||||
Optional: `DB_HOST`, `DB_PORT`, `DB_NAME`, `DB_USER`, `STORAGE_BASE_PATH`
|
||||
```bash
|
||||
# Required
|
||||
DB_PASSWORD=
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Python: Adding a New API Endpoint
|
||||
1. Add route in `backend/web/api/v1/`
|
||||
2. Define Pydantic schema in `backend/web/schemas/`
|
||||
3. Implement service logic in `backend/web/services/`
|
||||
4. Add tests in `tests/web/`
|
||||
|
||||
### Frontend: Adding a New Component
|
||||
1. Create component in `frontend/src/components/`
|
||||
2. Export from `frontend/src/components/index.ts` if shared
|
||||
3. Add types to `frontend/src/api/types.ts` if API-related
|
||||
4. Add tests co-located with component
|
||||
|
||||
### Error Handling
|
||||
```python
|
||||
from shared.exceptions import DatabaseError
|
||||
|
||||
try:
|
||||
result = db.query(...)
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"Failed to fetch document: {e}", details={"doc_id": doc_id})
|
||||
# Optional (with defaults)
|
||||
DB_HOST=192.168.68.31
|
||||
DB_PORT=5432
|
||||
DB_NAME=docmaster
|
||||
DB_USER=docmaster
|
||||
MODEL_PATH=runs/train/invoice_fields/weights/best.pt
|
||||
CONFIDENCE_THRESHOLD=0.5
|
||||
SERVER_HOST=0.0.0.0
|
||||
SERVER_PORT=8000
|
||||
```
|
||||
## Available Commands
|
||||
|
||||
### Database Access
|
||||
```python
|
||||
from shared.data.repositories import DocumentRepository
|
||||
- `/tdd` - Test-driven development workflow
|
||||
- `/plan` - Create implementation plan
|
||||
- `/code-review` - Review code quality
|
||||
- `/build-fix` - Fix build errors
|
||||
|
||||
repo = DocumentRepository()
|
||||
doc = repo.get_by_id(doc_id)
|
||||
```
|
||||
## Git Workflow
|
||||
|
||||
- Conventional commits: `feat:`, `fix:`, `refactor:`, `docs:`, `test:`
|
||||
- Never commit to main directly
|
||||
- PRs require review
|
||||
- All tests must pass before merge
|
||||
|
||||
@@ -77,8 +77,8 @@ frontend/ # React 前端 (Vite + TypeScript + TailwindCSS)
|
||||
wsl -d Ubuntu-22.04
|
||||
|
||||
# 2. 创建 Conda 环境
|
||||
conda create -n invoice-py311 python=3.11 -y
|
||||
conda activate invoice-py311
|
||||
conda create -n invoice-sm120 python=3.11 -y
|
||||
conda activate invoice-sm120
|
||||
|
||||
# 3. 进入项目目录
|
||||
cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2
|
||||
@@ -314,7 +314,7 @@ python -m backend.cli.infer \
|
||||
|
||||
```bash
|
||||
# 从 Windows PowerShell 启动
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-py311 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py --port 8000"
|
||||
wsl bash -c "source ~/miniconda3/etc/profile.d/conda.sh && conda activate invoice-sm120 && cd /mnt/c/Users/yaoji/git/ColaCoder/invoice-master-poc-v2 && python run_server.py --port 8000"
|
||||
|
||||
# 启动前端
|
||||
cd frontend && npm install && npm run dev
|
||||
|
||||
@@ -7,10 +7,14 @@ Runs inference on new PDFs to extract invoice data.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -66,10 +70,13 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Validate model
|
||||
model_path = Path(args.model)
|
||||
if not model_path.exists():
|
||||
print(f"Error: Model not found: {model_path}", file=sys.stderr)
|
||||
logger.error("Model not found: %s", model_path)
|
||||
sys.exit(1)
|
||||
|
||||
# Get input files
|
||||
@@ -79,16 +86,16 @@ def main():
|
||||
elif input_path.is_dir():
|
||||
pdf_files = list(input_path.glob('*.pdf'))
|
||||
else:
|
||||
print(f"Error: Input not found: {input_path}", file=sys.stderr)
|
||||
logger.error("Input not found: %s", input_path)
|
||||
sys.exit(1)
|
||||
|
||||
if not pdf_files:
|
||||
print("Error: No PDF files found", file=sys.stderr)
|
||||
logger.error("No PDF files found")
|
||||
sys.exit(1)
|
||||
|
||||
if args.verbose:
|
||||
print(f"Processing {len(pdf_files)} PDF file(s)")
|
||||
print(f"Model: {model_path}")
|
||||
logger.info("Processing %d PDF file(s)", len(pdf_files))
|
||||
logger.info("Model: %s", model_path)
|
||||
|
||||
from backend.pipeline import InferencePipeline
|
||||
|
||||
@@ -107,18 +114,18 @@ def main():
|
||||
|
||||
for pdf_path in pdf_files:
|
||||
if args.verbose:
|
||||
print(f"Processing: {pdf_path.name}")
|
||||
logger.info("Processing: %s", pdf_path.name)
|
||||
|
||||
result = pipeline.process_pdf(pdf_path)
|
||||
results.append(result.to_json())
|
||||
|
||||
if args.verbose:
|
||||
print(f" Success: {result.success}")
|
||||
print(f" Fields: {len(result.fields)}")
|
||||
logger.info(" Success: %s", result.success)
|
||||
logger.info(" Fields: %d", len(result.fields))
|
||||
if result.fallback_used:
|
||||
print(f" Fallback used: Yes")
|
||||
logger.info(" Fallback used: Yes")
|
||||
if result.errors:
|
||||
print(f" Errors: {result.errors}")
|
||||
logger.info(" Errors: %s", result.errors)
|
||||
|
||||
# Output results
|
||||
if len(results) == 1:
|
||||
@@ -132,9 +139,11 @@ def main():
|
||||
with open(args.output, 'w', encoding='utf-8') as f:
|
||||
f.write(json_output)
|
||||
if args.verbose:
|
||||
print(f"\nResults written to: {args.output}")
|
||||
logger.info("Results written to: %s", args.output)
|
||||
else:
|
||||
print(json_output)
|
||||
# Output JSON to stdout (not logged)
|
||||
sys.stdout.write(json_output)
|
||||
sys.stdout.write('\n')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
25
packages/backend/backend/domain/__init__.py
Normal file
25
packages/backend/backend/domain/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Domain Layer
|
||||
|
||||
Business logic separated from technical implementation.
|
||||
Contains document classification and invoice validation logic.
|
||||
"""
|
||||
from backend.domain.document_classifier import (
|
||||
ClassificationResult,
|
||||
DocumentClassifier,
|
||||
)
|
||||
from backend.domain.invoice_validator import (
|
||||
InvoiceValidator,
|
||||
ValidationIssue,
|
||||
ValidationResult,
|
||||
)
|
||||
from backend.domain.utils import has_value
|
||||
|
||||
__all__ = [
|
||||
"ClassificationResult",
|
||||
"DocumentClassifier",
|
||||
"InvoiceValidator",
|
||||
"ValidationIssue",
|
||||
"ValidationResult",
|
||||
"has_value",
|
||||
]
|
||||
108
packages/backend/backend/domain/document_classifier.py
Normal file
108
packages/backend/backend/domain/document_classifier.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Document Classifier
|
||||
|
||||
Business logic for classifying documents based on extracted fields.
|
||||
Separates classification logic from inference pipeline.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from backend.domain.utils import has_value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ClassificationResult:
|
||||
"""
|
||||
Immutable result of document classification.
|
||||
|
||||
Attributes:
|
||||
document_type: Either "invoice" or "letter"
|
||||
confidence: Confidence score between 0.0 and 1.0
|
||||
reason: Human-readable explanation of classification
|
||||
"""
|
||||
|
||||
document_type: str
|
||||
confidence: float
|
||||
reason: str
|
||||
|
||||
|
||||
class DocumentClassifier:
|
||||
"""
|
||||
Classifies documents as invoice or letter based on extracted fields.
|
||||
|
||||
Classification Rules:
|
||||
1. If payment_line is present -> invoice (high confidence)
|
||||
2. If 2+ invoice indicators present -> invoice (medium confidence)
|
||||
3. If 1 invoice indicator present -> invoice (lower confidence)
|
||||
4. Otherwise -> letter
|
||||
|
||||
Invoice indicator fields:
|
||||
- payment_line (strongest indicator)
|
||||
- OCR
|
||||
- Amount
|
||||
- Bankgiro
|
||||
- Plusgiro
|
||||
- InvoiceNumber
|
||||
"""
|
||||
|
||||
INVOICE_INDICATOR_FIELDS: frozenset[str] = frozenset(
|
||||
{
|
||||
"payment_line",
|
||||
"OCR",
|
||||
"Amount",
|
||||
"Bankgiro",
|
||||
"Plusgiro",
|
||||
"InvoiceNumber",
|
||||
}
|
||||
)
|
||||
|
||||
def classify(self, fields: dict[str, str | None]) -> ClassificationResult:
|
||||
"""
|
||||
Classify document type based on extracted fields.
|
||||
|
||||
Args:
|
||||
fields: Dictionary of field names to extracted values.
|
||||
Empty strings or whitespace-only strings are treated as missing.
|
||||
|
||||
Returns:
|
||||
Immutable ClassificationResult with type, confidence, and reason.
|
||||
"""
|
||||
# Rule 1: payment_line is the strongest indicator
|
||||
if has_value(fields.get("payment_line")):
|
||||
return ClassificationResult(
|
||||
document_type="invoice",
|
||||
confidence=0.95,
|
||||
reason="payment_line detected",
|
||||
)
|
||||
|
||||
# Count present invoice indicators (excluding payment_line already checked)
|
||||
present_indicators = [
|
||||
field
|
||||
for field in self.INVOICE_INDICATOR_FIELDS
|
||||
if field != "payment_line" and has_value(fields.get(field))
|
||||
]
|
||||
indicator_count = len(present_indicators)
|
||||
|
||||
# Rule 2: Multiple indicators -> invoice with medium-high confidence
|
||||
if indicator_count >= 2:
|
||||
return ClassificationResult(
|
||||
document_type="invoice",
|
||||
confidence=0.8,
|
||||
reason=f"{indicator_count} invoice indicators present: {', '.join(present_indicators)}",
|
||||
)
|
||||
|
||||
# Rule 3: Single indicator -> invoice with lower confidence
|
||||
if indicator_count == 1:
|
||||
return ClassificationResult(
|
||||
document_type="invoice",
|
||||
confidence=0.6,
|
||||
reason=f"1 invoice indicator present: {present_indicators[0]}",
|
||||
)
|
||||
|
||||
# Rule 4: No indicators -> letter
|
||||
return ClassificationResult(
|
||||
document_type="letter",
|
||||
confidence=0.7,
|
||||
reason="no invoice indicators found",
|
||||
)
|
||||
141
packages/backend/backend/domain/invoice_validator.py
Normal file
141
packages/backend/backend/domain/invoice_validator.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Invoice Validator
|
||||
|
||||
Business logic for validating extracted invoice fields.
|
||||
Checks for required fields, format validity, and confidence thresholds.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from backend.domain.utils import has_value
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ValidationIssue:
|
||||
"""
|
||||
Single validation issue.
|
||||
|
||||
Attributes:
|
||||
field: Name of the field with the issue
|
||||
severity: One of "error", "warning", "info"
|
||||
message: Human-readable description of the issue
|
||||
"""
|
||||
|
||||
field: str
|
||||
severity: str
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ValidationResult:
|
||||
"""
|
||||
Immutable result of invoice validation.
|
||||
|
||||
Attributes:
|
||||
is_valid: True if no errors (warnings are allowed)
|
||||
issues: Tuple of validation issues found
|
||||
confidence: Average confidence score of validated fields
|
||||
"""
|
||||
|
||||
is_valid: bool
|
||||
issues: tuple[ValidationIssue, ...]
|
||||
confidence: float
|
||||
|
||||
|
||||
class InvoiceValidator:
|
||||
"""
|
||||
Validates extracted invoice fields for completeness and consistency.
|
||||
|
||||
Validation Rules:
|
||||
1. Required fields must be present (Amount)
|
||||
2. At least one payment reference should be present (warning if missing)
|
||||
3. Field confidence should be above threshold (warning if below)
|
||||
|
||||
Required fields:
|
||||
- Amount
|
||||
|
||||
Payment reference fields (at least one expected):
|
||||
- OCR
|
||||
- Bankgiro
|
||||
- Plusgiro
|
||||
- payment_line
|
||||
"""
|
||||
|
||||
REQUIRED_FIELDS: tuple[str, ...] = ("Amount",)
|
||||
PAYMENT_REF_FIELDS: tuple[str, ...] = ("OCR", "Bankgiro", "Plusgiro", "payment_line")
|
||||
DEFAULT_MIN_CONFIDENCE: float = 0.5
|
||||
|
||||
def __init__(self, min_confidence: float = DEFAULT_MIN_CONFIDENCE) -> None:
|
||||
"""
|
||||
Initialize validator.
|
||||
|
||||
Args:
|
||||
min_confidence: Minimum confidence threshold for valid fields.
|
||||
Fields below this threshold produce warnings.
|
||||
"""
|
||||
self._min_confidence = min_confidence
|
||||
|
||||
def validate(
|
||||
self,
|
||||
fields: dict[str, str | None],
|
||||
confidence: dict[str, float],
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
Validate extracted invoice fields.
|
||||
|
||||
Args:
|
||||
fields: Dictionary of field names to extracted values
|
||||
confidence: Dictionary of field names to confidence scores
|
||||
|
||||
Returns:
|
||||
Immutable ValidationResult with validity status and issues
|
||||
"""
|
||||
issues: list[ValidationIssue] = []
|
||||
|
||||
# Check required fields
|
||||
for field in self.REQUIRED_FIELDS:
|
||||
if not has_value(fields.get(field)):
|
||||
issues.append(
|
||||
ValidationIssue(
|
||||
field=field,
|
||||
severity="error",
|
||||
message=f"Required field '{field}' is missing",
|
||||
)
|
||||
)
|
||||
|
||||
# Check payment reference (at least one expected)
|
||||
has_payment_ref = any(
|
||||
has_value(fields.get(f)) for f in self.PAYMENT_REF_FIELDS
|
||||
)
|
||||
if not has_payment_ref:
|
||||
issues.append(
|
||||
ValidationIssue(
|
||||
field="payment_reference",
|
||||
severity="warning",
|
||||
message="No payment reference (OCR, Bankgiro, Plusgiro, or payment_line)",
|
||||
)
|
||||
)
|
||||
|
||||
# Check confidence thresholds
|
||||
for field, conf in confidence.items():
|
||||
if conf < self._min_confidence:
|
||||
issues.append(
|
||||
ValidationIssue(
|
||||
field=field,
|
||||
severity="warning",
|
||||
message=f"Low confidence ({conf:.2f}) for field '{field}'",
|
||||
)
|
||||
)
|
||||
|
||||
# Calculate overall validity
|
||||
has_errors = any(i.severity == "error" for i in issues)
|
||||
avg_confidence = (
|
||||
sum(confidence.values()) / len(confidence) if confidence else 0.0
|
||||
)
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=not has_errors,
|
||||
issues=tuple(issues),
|
||||
confidence=avg_confidence,
|
||||
)
|
||||
23
packages/backend/backend/domain/utils.py
Normal file
23
packages/backend/backend/domain/utils.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Domain Layer Utilities
|
||||
|
||||
Shared helper functions for domain layer classes.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def has_value(value: str | None) -> bool:
|
||||
"""
|
||||
Check if a field value is present and non-empty.
|
||||
|
||||
Args:
|
||||
value: Field value to check
|
||||
|
||||
Returns:
|
||||
True if value is a non-empty, non-whitespace string
|
||||
"""
|
||||
if value is None:
|
||||
return False
|
||||
if not isinstance(value, str):
|
||||
return bool(value)
|
||||
return bool(value.strip())
|
||||
@@ -417,7 +417,12 @@ class InferencePipeline:
|
||||
result.errors.append(f"Business feature extraction error: {error_detail}")
|
||||
|
||||
def _merge_fields(self, result: InferenceResult) -> None:
|
||||
"""Merge extracted fields, keeping highest confidence for each field."""
|
||||
"""Merge extracted fields, keeping best candidate for each field.
|
||||
|
||||
Selection priority:
|
||||
1. Prefer candidates without validation errors
|
||||
2. Among equal validity, prefer higher confidence
|
||||
"""
|
||||
field_candidates: dict[str, list[ExtractedField]] = {}
|
||||
|
||||
for extracted in result.extracted_fields:
|
||||
@@ -430,7 +435,12 @@ class InferencePipeline:
|
||||
|
||||
# Select best candidate for each field
|
||||
for field_name, candidates in field_candidates.items():
|
||||
best = max(candidates, key=lambda x: x.confidence)
|
||||
# Sort by: (no validation error, confidence) - descending
|
||||
# This prefers candidates without errors, then by confidence
|
||||
best = max(
|
||||
candidates,
|
||||
key=lambda x: (x.validation_error is None, x.confidence)
|
||||
)
|
||||
result.fields[field_name] = best.normalized_value
|
||||
result.confidence[field_name] = best.confidence
|
||||
# Store bbox for each field (useful for payment_line and other fields)
|
||||
|
||||
204
packages/backend/backend/table/html_table_parser.py
Normal file
204
packages/backend/backend/table/html_table_parser.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
HTML Table Parser
|
||||
|
||||
Parses HTML tables into structured data and maps columns to field names.
|
||||
"""
|
||||
|
||||
from html.parser import HTMLParser
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration constants
|
||||
# Minimum pattern length to avoid false positives from short substrings
|
||||
MIN_PATTERN_MATCH_LENGTH = 3
|
||||
# Exact match bonus for column mapping priority
|
||||
EXACT_MATCH_BONUS = 100
|
||||
|
||||
# Swedish column name mappings
|
||||
# Extended to support multiple invoice types: product invoices, rental invoices, utility bills
|
||||
COLUMN_MAPPINGS = {
|
||||
"article_number": [
|
||||
"art nummer",
|
||||
"artikelnummer",
|
||||
"artikel",
|
||||
"artnr",
|
||||
"art.nr",
|
||||
"art nr",
|
||||
"objektnummer", # Rental: property reference
|
||||
"objekt",
|
||||
],
|
||||
"description": [
|
||||
"beskrivning",
|
||||
"produktbeskrivning",
|
||||
"produkt",
|
||||
"tjänst",
|
||||
"text",
|
||||
"benämning",
|
||||
"vara/tjänst",
|
||||
"vara",
|
||||
# Rental invoice specific
|
||||
"specifikation",
|
||||
"spec",
|
||||
"hyresperiod", # Rental period
|
||||
"period",
|
||||
"typ", # Type of charge
|
||||
# Utility bills
|
||||
"förbrukning", # Consumption
|
||||
"avläsning", # Meter reading
|
||||
],
|
||||
"quantity": ["antal", "qty", "st", "pcs", "kvantitet", "m²", "kvm"],
|
||||
"unit": ["enhet", "unit"],
|
||||
"unit_price": ["á-pris", "a-pris", "pris", "styckpris", "enhetspris", "à pris"],
|
||||
"amount": [
|
||||
"belopp",
|
||||
"summa",
|
||||
"total",
|
||||
"netto",
|
||||
"rad summa",
|
||||
# Rental specific
|
||||
"hyra", # Rent
|
||||
"avgift", # Fee
|
||||
"kostnad", # Cost
|
||||
"debitering", # Charge
|
||||
"totalt", # Total
|
||||
],
|
||||
"vat_rate": ["moms", "moms%", "vat", "skatt", "moms %"],
|
||||
# Additional field for rental: deductions/adjustments
|
||||
"deduction": [
|
||||
"avdrag", # Deduction
|
||||
"rabatt", # Discount
|
||||
"kredit", # Credit
|
||||
],
|
||||
}
|
||||
|
||||
# Keywords that indicate NOT a line items table
|
||||
SUMMARY_KEYWORDS = [
|
||||
"frakt",
|
||||
"faktura.avg",
|
||||
"fakturavg",
|
||||
"exkl.moms",
|
||||
"att betala",
|
||||
"öresavr",
|
||||
"bankgiro",
|
||||
"plusgiro",
|
||||
"ocr",
|
||||
"forfallodatum",
|
||||
"förfallodatum",
|
||||
]
|
||||
|
||||
|
||||
class _TableHTMLParser(HTMLParser):
|
||||
"""Internal HTML parser for tables."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.rows: list[list[str]] = []
|
||||
self.current_row: list[str] = []
|
||||
self.current_cell: str = ""
|
||||
self.in_td = False
|
||||
self.in_thead = False
|
||||
self.header_row: list[str] = []
|
||||
|
||||
def handle_starttag(self, tag, attrs):
|
||||
if tag == "tr":
|
||||
self.current_row = []
|
||||
elif tag in ("td", "th"):
|
||||
self.in_td = True
|
||||
self.current_cell = ""
|
||||
elif tag == "thead":
|
||||
self.in_thead = True
|
||||
|
||||
def handle_endtag(self, tag):
|
||||
if tag in ("td", "th"):
|
||||
self.in_td = False
|
||||
self.current_row.append(self.current_cell.strip())
|
||||
elif tag == "tr":
|
||||
if self.current_row:
|
||||
if self.in_thead:
|
||||
self.header_row = self.current_row
|
||||
else:
|
||||
self.rows.append(self.current_row)
|
||||
elif tag == "thead":
|
||||
self.in_thead = False
|
||||
|
||||
def handle_data(self, data):
|
||||
if self.in_td:
|
||||
self.current_cell += data
|
||||
|
||||
|
||||
class HTMLTableParser:
|
||||
"""Parse HTML tables into structured data."""
|
||||
|
||||
def parse(self, html: str) -> tuple[list[str], list[list[str]]]:
|
||||
"""
|
||||
Parse HTML table and return header and rows.
|
||||
|
||||
Args:
|
||||
html: HTML string containing table.
|
||||
|
||||
Returns:
|
||||
Tuple of (header_row, data_rows).
|
||||
"""
|
||||
parser = _TableHTMLParser()
|
||||
parser.feed(html)
|
||||
return parser.header_row, parser.rows
|
||||
|
||||
|
||||
class ColumnMapper:
|
||||
"""Map column headers to field names."""
|
||||
|
||||
def __init__(self, mappings: dict[str, list[str]] | None = None):
|
||||
"""
|
||||
Initialize column mapper.
|
||||
|
||||
Args:
|
||||
mappings: Custom column mappings. Uses Swedish defaults if None.
|
||||
"""
|
||||
self.mappings = mappings or COLUMN_MAPPINGS
|
||||
|
||||
def map(self, headers: list[str]) -> dict[int, str]:
|
||||
"""
|
||||
Map column indices to field names.
|
||||
|
||||
Args:
|
||||
headers: List of column header strings.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping column index to field name.
|
||||
"""
|
||||
mapping = {}
|
||||
for idx, header in enumerate(headers):
|
||||
normalized = self._normalize(header)
|
||||
|
||||
if not normalized.strip():
|
||||
continue
|
||||
|
||||
best_match = None
|
||||
best_match_len = 0
|
||||
|
||||
for field_name, patterns in self.mappings.items():
|
||||
for pattern in patterns:
|
||||
if pattern == normalized:
|
||||
# Exact match gets highest priority
|
||||
best_match = field_name
|
||||
best_match_len = len(pattern) + EXACT_MATCH_BONUS
|
||||
break
|
||||
elif pattern in normalized and len(pattern) > best_match_len:
|
||||
# Partial match requires minimum length to avoid false positives
|
||||
if len(pattern) >= MIN_PATTERN_MATCH_LENGTH:
|
||||
best_match = field_name
|
||||
best_match_len = len(pattern)
|
||||
|
||||
if best_match_len > EXACT_MATCH_BONUS:
|
||||
# Found exact match, no need to check other fields
|
||||
break
|
||||
|
||||
if best_match:
|
||||
mapping[idx] = best_match
|
||||
|
||||
return mapping
|
||||
|
||||
def _normalize(self, header: str) -> str:
|
||||
"""Normalize header text for matching."""
|
||||
return header.lower().strip().replace(".", "").replace("-", " ")
|
||||
File diff suppressed because it is too large
Load Diff
423
packages/backend/backend/table/merged_cell_handler.py
Normal file
423
packages/backend/backend/table/merged_cell_handler.py
Normal file
@@ -0,0 +1,423 @@
|
||||
"""
|
||||
Merged Cell Handler
|
||||
|
||||
Handles detection and extraction of data from tables with merged cells,
|
||||
a common issue with PP-StructureV3 OCR output.
|
||||
"""
|
||||
|
||||
import re
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .models import LineItem
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .html_table_parser import ColumnMapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Minimum positive amount to consider as line item (filters noise like row indices)
|
||||
MIN_AMOUNT_THRESHOLD = 100
|
||||
|
||||
|
||||
class MergedCellHandler:
|
||||
"""Handles tables with vertically merged cells from PP-StructureV3."""
|
||||
|
||||
def __init__(self, mapper: "ColumnMapper"):
|
||||
"""
|
||||
Initialize handler.
|
||||
|
||||
Args:
|
||||
mapper: ColumnMapper instance for header keyword detection.
|
||||
"""
|
||||
self.mapper = mapper
|
||||
|
||||
def has_vertically_merged_cells(self, rows: list[list[str]]) -> bool:
|
||||
"""
|
||||
Check if table rows contain vertically merged data in single cells.
|
||||
|
||||
PP-StructureV3 sometimes merges multiple table rows into single cells, e.g.:
|
||||
["Produktnr 1457280 1457280 1060381", "", "Antal 6ST 6ST 1ST", "Pris 127,20 127,20 159,20"]
|
||||
|
||||
Detection: cells contain repeating patterns of numbers or keywords suggesting multiple lines.
|
||||
"""
|
||||
if not rows:
|
||||
return False
|
||||
|
||||
for row in rows:
|
||||
for cell in row:
|
||||
if not cell or len(cell) < 20:
|
||||
continue
|
||||
|
||||
# Check for multiple product numbers (7+ digit patterns)
|
||||
product_nums = re.findall(r"\b\d{7}\b", cell)
|
||||
if len(product_nums) >= 2:
|
||||
logger.debug(f"has_vertically_merged_cells: found {len(product_nums)} product numbers in cell")
|
||||
return True
|
||||
|
||||
# Check for multiple prices (Swedish format: 123,45 or 1 234,56)
|
||||
prices = re.findall(r"\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b", cell)
|
||||
if len(prices) >= 3:
|
||||
logger.debug(f"has_vertically_merged_cells: found {len(prices)} prices in cell")
|
||||
return True
|
||||
|
||||
# Check for multiple quantity patterns (e.g., "6ST 6ST 1ST")
|
||||
quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", cell)
|
||||
if len(quantities) >= 2:
|
||||
logger.debug(f"has_vertically_merged_cells: found {len(quantities)} quantities in cell")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def split_merged_rows(
|
||||
self, rows: list[list[str]]
|
||||
) -> tuple[list[str], list[list[str]]]:
|
||||
"""
|
||||
Split vertically merged cells back into separate rows.
|
||||
|
||||
Handles complex cases where PP-StructureV3 merges content across
|
||||
multiple HTML rows. For example, 5 line items might be spread across
|
||||
3 HTML rows with content mixed together.
|
||||
|
||||
Strategy:
|
||||
1. Merge all row content per column
|
||||
2. Detect how many actual data rows exist (by counting product numbers)
|
||||
3. Split each column's content into that many lines
|
||||
|
||||
Returns header and data rows.
|
||||
"""
|
||||
if not rows:
|
||||
return [], []
|
||||
|
||||
# Filter out completely empty rows
|
||||
non_empty_rows = [r for r in rows if any(cell.strip() for cell in r)]
|
||||
if not non_empty_rows:
|
||||
return [], rows
|
||||
|
||||
# Determine column count
|
||||
col_count = max(len(r) for r in non_empty_rows)
|
||||
|
||||
# Merge content from all rows for each column
|
||||
merged_columns = []
|
||||
for col_idx in range(col_count):
|
||||
col_content = []
|
||||
for row in non_empty_rows:
|
||||
if col_idx < len(row) and row[col_idx].strip():
|
||||
col_content.append(row[col_idx].strip())
|
||||
merged_columns.append(" ".join(col_content))
|
||||
|
||||
logger.debug(f"split_merged_rows: merged columns = {merged_columns}")
|
||||
|
||||
# Count how many actual data rows we should have
|
||||
# Use the column with most product numbers as reference
|
||||
expected_rows = self._count_expected_rows(merged_columns)
|
||||
logger.debug(f"split_merged_rows: expecting {expected_rows} data rows")
|
||||
|
||||
if expected_rows <= 1:
|
||||
# Not enough data for splitting
|
||||
return [], rows
|
||||
|
||||
# Split each column based on expected row count
|
||||
split_columns = []
|
||||
for col_idx, col_text in enumerate(merged_columns):
|
||||
if not col_text.strip():
|
||||
split_columns.append([""] * (expected_rows + 1)) # +1 for header
|
||||
continue
|
||||
lines = self._split_cell_content_for_rows(col_text, expected_rows)
|
||||
split_columns.append(lines)
|
||||
|
||||
# Ensure all columns have same number of lines (immutable approach)
|
||||
max_lines = max(len(col) for col in split_columns)
|
||||
split_columns = [
|
||||
col + [""] * (max_lines - len(col))
|
||||
for col in split_columns
|
||||
]
|
||||
|
||||
logger.debug(f"split_merged_rows: split into {max_lines} lines total")
|
||||
|
||||
# First line is header, rest are data rows
|
||||
header = [col[0] for col in split_columns]
|
||||
data_rows = []
|
||||
for line_idx in range(1, max_lines):
|
||||
row = [col[line_idx] if line_idx < len(col) else "" for col in split_columns]
|
||||
if any(cell.strip() for cell in row):
|
||||
data_rows.append(row)
|
||||
|
||||
logger.debug(f"split_merged_rows: header={header}, data_rows count={len(data_rows)}")
|
||||
return header, data_rows
|
||||
|
||||
def _count_expected_rows(self, merged_columns: list[str]) -> int:
|
||||
"""
|
||||
Count how many data rows should exist based on content patterns.
|
||||
|
||||
Returns the maximum count found from:
|
||||
- Product numbers (7 digits)
|
||||
- Quantity patterns (number + ST/PCS)
|
||||
- Amount patterns (in columns likely to be totals)
|
||||
"""
|
||||
max_count = 0
|
||||
|
||||
for col_text in merged_columns:
|
||||
if not col_text:
|
||||
continue
|
||||
|
||||
# Count product numbers (most reliable indicator)
|
||||
product_nums = re.findall(r"\b\d{7}\b", col_text)
|
||||
max_count = max(max_count, len(product_nums))
|
||||
|
||||
# Count quantities (e.g., "6ST 6ST 1ST 1ST 1ST")
|
||||
quantities = re.findall(r"\b\d+\s*(?:ST|st|PCS|pcs)\b", col_text)
|
||||
max_count = max(max_count, len(quantities))
|
||||
|
||||
return max_count
|
||||
|
||||
def _split_cell_content_for_rows(self, cell: str, expected_rows: int) -> list[str]:
|
||||
"""
|
||||
Split cell content knowing how many data rows we expect.
|
||||
|
||||
This is smarter than split_cell_content because it knows the target count.
|
||||
"""
|
||||
cell = cell.strip()
|
||||
|
||||
# Try product number split first
|
||||
product_pattern = re.compile(r"(\b\d{7}\b)")
|
||||
products = product_pattern.findall(cell)
|
||||
if len(products) == expected_rows:
|
||||
parts = product_pattern.split(cell)
|
||||
header = parts[0].strip() if parts else ""
|
||||
# Include description text after each product number
|
||||
values = []
|
||||
for i in range(1, len(parts), 2): # Odd indices are product numbers
|
||||
if i < len(parts):
|
||||
prod_num = parts[i].strip()
|
||||
# Check if there's description text after
|
||||
desc = parts[i + 1].strip() if i + 1 < len(parts) else ""
|
||||
# If description looks like text (not another pattern), include it
|
||||
if desc and not re.match(r"^\d{7}$", desc):
|
||||
# Truncate at next product number pattern if any
|
||||
desc_clean = re.split(r"\d{7}", desc)[0].strip()
|
||||
if desc_clean:
|
||||
values.append(f"{prod_num} {desc_clean}")
|
||||
else:
|
||||
values.append(prod_num)
|
||||
else:
|
||||
values.append(prod_num)
|
||||
if len(values) == expected_rows:
|
||||
return [header] + values
|
||||
|
||||
# Try quantity split
|
||||
qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
|
||||
quantities = qty_pattern.findall(cell)
|
||||
if len(quantities) == expected_rows:
|
||||
parts = qty_pattern.split(cell)
|
||||
header = parts[0].strip() if parts else ""
|
||||
values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
|
||||
if len(values) == expected_rows:
|
||||
return [header] + values
|
||||
|
||||
# Try amount split for discount+totalsumma columns
|
||||
cell_lower = cell.lower()
|
||||
has_discount = any(kw in cell_lower for kw in ["rabatt", "discount"])
|
||||
has_total = any(kw in cell_lower for kw in ["totalsumma", "total", "summa", "belopp"])
|
||||
|
||||
if has_discount and has_total:
|
||||
# Extract only amounts (3+ digit numbers), skip discount percentages
|
||||
amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
|
||||
amounts = amount_pattern.findall(cell)
|
||||
if len(amounts) >= expected_rows:
|
||||
# Take the last expected_rows amounts (they are likely the totals)
|
||||
return ["Totalsumma"] + amounts[:expected_rows]
|
||||
|
||||
# Try price split
|
||||
price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
|
||||
prices = price_pattern.findall(cell)
|
||||
if len(prices) >= expected_rows:
|
||||
parts = price_pattern.split(cell)
|
||||
header = parts[0].strip() if parts else ""
|
||||
values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
|
||||
if len(values) >= expected_rows:
|
||||
return [header] + values[:expected_rows]
|
||||
|
||||
# Fall back to original single-value behavior
|
||||
return [cell]
|
||||
|
||||
def split_cell_content(self, cell: str) -> list[str]:
|
||||
"""
|
||||
Split a cell containing merged multi-line content.
|
||||
|
||||
Strategies:
|
||||
1. Look for product number patterns (7 digits)
|
||||
2. Look for quantity patterns (number + ST/PCS)
|
||||
3. Look for price patterns (with decimal)
|
||||
4. Handle interleaved discount+amount patterns
|
||||
"""
|
||||
cell = cell.strip()
|
||||
|
||||
# Strategy 1: Split by product numbers (common pattern: "Produktnr 1234567 1234568")
|
||||
product_pattern = re.compile(r"(\b\d{7}\b)")
|
||||
products = product_pattern.findall(cell)
|
||||
if len(products) >= 2:
|
||||
# Extract header (text before first product number) and values
|
||||
parts = product_pattern.split(cell)
|
||||
header = parts[0].strip() if parts else ""
|
||||
values = [p for p in parts[1:] if p.strip() and re.match(r"\d{7}", p)]
|
||||
return [header] + values
|
||||
|
||||
# Strategy 2: Split by quantities (e.g., "Antal 6ST 6ST 1ST")
|
||||
qty_pattern = re.compile(r"(\b\d+\s*(?:ST|st|PCS|pcs|M|m|KG|kg)\b)")
|
||||
quantities = qty_pattern.findall(cell)
|
||||
if len(quantities) >= 2:
|
||||
parts = qty_pattern.split(cell)
|
||||
header = parts[0].strip() if parts else ""
|
||||
values = [p.strip() for p in parts[1:] if p.strip() and qty_pattern.match(p)]
|
||||
return [header] + values
|
||||
|
||||
# Strategy 3: Handle interleaved discount+amount (e.g., "Rabatt i% Totalsumma 10,0 686,88 10,0 686,88")
|
||||
# Check if header contains two keywords indicating merged columns
|
||||
cell_lower = cell.lower()
|
||||
has_discount_header = any(kw in cell_lower for kw in ["rabatt", "discount"])
|
||||
has_amount_header = any(kw in cell_lower for kw in ["totalsumma", "summa", "belopp", "total"])
|
||||
|
||||
if has_discount_header and has_amount_header:
|
||||
# Extract all numbers and pair them (discount, amount, discount, amount, ...)
|
||||
# Pattern for amounts: 3+ digit numbers with decimals (e.g., 686,88)
|
||||
amount_pattern = re.compile(r"\b(\d{3,}[,\.]\d{2})\b")
|
||||
amounts = amount_pattern.findall(cell)
|
||||
|
||||
if len(amounts) >= 2:
|
||||
# Return header as "Totalsumma" (amount header) so it maps to amount field, not deduction
|
||||
# This avoids the "Rabatt" keyword causing is_deduction=True
|
||||
header = "Totalsumma"
|
||||
return [header] + amounts
|
||||
|
||||
# Strategy 4: Split by prices (e.g., "Pris 127,20 127,20 159,20")
|
||||
price_pattern = re.compile(r"(\b\d{1,3}(?:\s?\d{3})*[,\.]\d{2}\b)")
|
||||
prices = price_pattern.findall(cell)
|
||||
if len(prices) >= 2:
|
||||
parts = price_pattern.split(cell)
|
||||
header = parts[0].strip() if parts else ""
|
||||
values = [p.strip() for p in parts[1:] if p.strip() and price_pattern.match(p)]
|
||||
return [header] + values
|
||||
|
||||
# No pattern detected, return as single value
|
||||
return [cell]
|
||||
|
||||
def has_merged_header(self, header: list[str] | None) -> bool:
|
||||
"""
|
||||
Check if header appears to be a merged cell containing multiple column names.
|
||||
|
||||
This happens when OCR merges table headers into a single cell, e.g.:
|
||||
"Specifikation 0218103-1201 2 rum och kök Hyra Avdrag" instead of separate columns.
|
||||
|
||||
Also handles cases where PP-StructureV3 produces headers like:
|
||||
["Specifikation ... Hyra Avdrag", "", "", ""] with empty trailing cells.
|
||||
"""
|
||||
if header is None or not header:
|
||||
return False
|
||||
|
||||
# Filter out empty cells to find the actual content
|
||||
non_empty_cells = [h for h in header if h.strip()]
|
||||
|
||||
# Check if we have a single non-empty cell that contains multiple keywords
|
||||
if len(non_empty_cells) == 1:
|
||||
header_text = non_empty_cells[0].lower()
|
||||
# Count how many column keywords are in this single cell
|
||||
keyword_count = 0
|
||||
for patterns in self.mapper.mappings.values():
|
||||
for pattern in patterns:
|
||||
if pattern in header_text:
|
||||
keyword_count += 1
|
||||
break # Only count once per field type
|
||||
|
||||
logger.debug(f"has_merged_header: header_text='{header_text}', keyword_count={keyword_count}")
|
||||
return keyword_count >= 2
|
||||
|
||||
return False
|
||||
|
||||
def extract_from_merged_cells(
|
||||
self, header: list[str], rows: list[list[str]]
|
||||
) -> list[LineItem]:
|
||||
"""
|
||||
Extract line items from tables with merged cells.
|
||||
|
||||
For poorly OCR'd tables like:
|
||||
Header: ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
|
||||
Row 1: ["", "", "", "8159"] <- amount row
|
||||
Row 2: ["", "", "", "-2 000"] <- deduction row (separate line item)
|
||||
|
||||
Or:
|
||||
Row: ["", "", "", "8159 -2 000"] <- both in same row -> 2 line items
|
||||
|
||||
Each amount becomes its own line item. Negative amounts are marked as is_deduction=True.
|
||||
"""
|
||||
items = []
|
||||
|
||||
# Amount pattern for Swedish format - match numbers like "8159" or "8 159" or "-2000" or "-2 000"
|
||||
amount_pattern = re.compile(
|
||||
r"(-?\d[\d\s]*(?:[,\.]\d+)?)"
|
||||
)
|
||||
|
||||
# Try to parse header cell for description info
|
||||
header_text = " ".join(h for h in header if h.strip()) if header else ""
|
||||
logger.debug(f"extract_from_merged_cells: header_text='{header_text}'")
|
||||
logger.debug(f"extract_from_merged_cells: rows={rows}")
|
||||
|
||||
# Extract description from header
|
||||
description = None
|
||||
article_number = None
|
||||
|
||||
# Look for object number pattern (e.g., "0218103-1201")
|
||||
obj_match = re.search(r"(\d{7}-\d{4})", header_text)
|
||||
if obj_match:
|
||||
article_number = obj_match.group(1)
|
||||
|
||||
# Look for description after object number
|
||||
desc_match = re.search(r"\d{7}-\d{4}\s+(.+?)(?:\s+(?:Hyra|Avdrag|Belopp))", header_text, re.IGNORECASE)
|
||||
if desc_match:
|
||||
description = desc_match.group(1).strip()
|
||||
|
||||
row_index = 0
|
||||
for row in rows:
|
||||
# Combine all non-empty cells in the row
|
||||
row_text = " ".join(cell.strip() for cell in row if cell.strip())
|
||||
logger.debug(f"extract_from_merged_cells: row text='{row_text}'")
|
||||
|
||||
if not row_text:
|
||||
continue
|
||||
|
||||
# Find all amounts in the row
|
||||
amounts = amount_pattern.findall(row_text)
|
||||
logger.debug(f"extract_from_merged_cells: amounts={amounts}")
|
||||
|
||||
for amt_str in amounts:
|
||||
# Clean the amount string
|
||||
cleaned = amt_str.replace(" ", "").strip()
|
||||
if not cleaned or cleaned == "-":
|
||||
continue
|
||||
|
||||
is_deduction = cleaned.startswith("-")
|
||||
|
||||
# Skip small positive numbers that are likely not amounts
|
||||
# (e.g., row indices, small percentages)
|
||||
if not is_deduction:
|
||||
try:
|
||||
val = float(cleaned.replace(",", "."))
|
||||
if val < MIN_AMOUNT_THRESHOLD:
|
||||
continue
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
# Create a line item for each amount
|
||||
item = LineItem(
|
||||
row_index=row_index,
|
||||
description=description if row_index == 0 else "Avdrag" if is_deduction else None,
|
||||
article_number=article_number if row_index == 0 else None,
|
||||
amount=cleaned,
|
||||
is_deduction=is_deduction,
|
||||
confidence=0.7,
|
||||
)
|
||||
items.append(item)
|
||||
row_index += 1
|
||||
logger.debug(f"extract_from_merged_cells: created item amount={cleaned}, is_deduction={is_deduction}")
|
||||
|
||||
return items
|
||||
61
packages/backend/backend/table/models.py
Normal file
61
packages/backend/backend/table/models.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""
|
||||
Line Items Data Models
|
||||
|
||||
Dataclasses for line item extraction results.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal, InvalidOperation
|
||||
|
||||
|
||||
@dataclass
|
||||
class LineItem:
|
||||
"""Single line item from invoice."""
|
||||
|
||||
row_index: int
|
||||
description: str | None = None
|
||||
quantity: str | None = None
|
||||
unit: str | None = None
|
||||
unit_price: str | None = None
|
||||
amount: str | None = None
|
||||
article_number: str | None = None
|
||||
vat_rate: str | None = None
|
||||
is_deduction: bool = False # True if this row is a deduction/discount
|
||||
confidence: float = 0.9
|
||||
|
||||
|
||||
@dataclass
|
||||
class LineItemsResult:
|
||||
"""Result of line items extraction."""
|
||||
|
||||
items: list[LineItem]
|
||||
header_row: list[str]
|
||||
raw_html: str
|
||||
is_reversed: bool = False
|
||||
|
||||
@property
|
||||
def total_amount(self) -> str | None:
|
||||
"""Calculate total amount from line items (deduction rows have negative amounts)."""
|
||||
if not self.items:
|
||||
return None
|
||||
|
||||
total = Decimal("0")
|
||||
for item in self.items:
|
||||
if item.amount:
|
||||
try:
|
||||
# Parse Swedish number format (1 234,56)
|
||||
amount_str = item.amount.replace(" ", "").replace(",", ".")
|
||||
total += Decimal(amount_str)
|
||||
except InvalidOperation:
|
||||
pass
|
||||
|
||||
if total == 0:
|
||||
return None
|
||||
|
||||
# Format back to Swedish format
|
||||
formatted = f"{total:,.2f}".replace(",", " ").replace(".", ",")
|
||||
# Fix the space/comma swap
|
||||
parts = formatted.rsplit(",", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[0].replace(" ", " ") + "," + parts[1]
|
||||
return formatted
|
||||
@@ -158,36 +158,36 @@ class TableDetector:
|
||||
return tables
|
||||
|
||||
# Log raw result type for debugging
|
||||
logger.info(f"PP-StructureV3 raw results type: {type(results).__name__}")
|
||||
logger.debug(f"PP-StructureV3 raw results type: {type(results).__name__}")
|
||||
|
||||
# Handle case where results is a single dict-like object (PaddleX 3.x)
|
||||
# rather than a list of results
|
||||
if hasattr(results, "get") and not isinstance(results, list):
|
||||
# Single result object - wrap in list for uniform processing
|
||||
logger.info("Results is dict-like, wrapping in list")
|
||||
logger.debug("Results is dict-like, wrapping in list")
|
||||
results = [results]
|
||||
elif hasattr(results, "__iter__") and not isinstance(results, (list, tuple)):
|
||||
# Iterator or generator - convert to list
|
||||
try:
|
||||
results = list(results)
|
||||
logger.info(f"Converted iterator to list with {len(results)} items")
|
||||
logger.debug(f"Converted iterator to list with {len(results)} items")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to convert results to list: {e}")
|
||||
return tables
|
||||
|
||||
logger.info(f"Processing {len(results)} result(s)")
|
||||
logger.debug(f"Processing {len(results)} result(s)")
|
||||
|
||||
for i, result in enumerate(results):
|
||||
try:
|
||||
result_type = type(result).__name__
|
||||
has_get = hasattr(result, "get")
|
||||
has_layout = hasattr(result, "layout_elements")
|
||||
logger.info(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
|
||||
logger.debug(f"Result[{i}]: type={result_type}, has_get={has_get}, has_layout_elements={has_layout}")
|
||||
|
||||
# Try PaddleX 3.x API first (dict-like with table_res_list)
|
||||
if has_get:
|
||||
parsed = self._parse_paddlex_result(result)
|
||||
logger.info(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
|
||||
logger.debug(f"Result[{i}]: parsed {len(parsed)} tables via PaddleX path")
|
||||
tables.extend(parsed)
|
||||
continue
|
||||
|
||||
@@ -201,14 +201,14 @@ class TableDetector:
|
||||
if table_result and table_result.confidence >= self.config.min_confidence:
|
||||
tables.append(table_result)
|
||||
legacy_count += 1
|
||||
logger.info(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
|
||||
logger.debug(f"Result[{i}]: parsed {legacy_count} tables via legacy path")
|
||||
else:
|
||||
logger.warning(f"Result[{i}]: no recognized API (not dict-like and no layout_elements)")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to parse result: {type(result).__name__}, error: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Total tables detected: {len(tables)}")
|
||||
logger.debug(f"Total tables detected: {len(tables)}")
|
||||
return tables
|
||||
|
||||
def _parse_paddlex_result(self, result: Any) -> list[TableDetectionResult]:
|
||||
@@ -223,7 +223,7 @@ class TableDetector:
|
||||
result_keys = list(result.keys())
|
||||
elif hasattr(result, "__dict__"):
|
||||
result_keys = list(result.__dict__.keys())
|
||||
logger.info(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
|
||||
logger.debug(f"Parsing PaddleX result: type={result_type}, keys={result_keys}")
|
||||
|
||||
# Get table results from PaddleX 3.x API
|
||||
# Handle both dict.get() and attribute access
|
||||
@@ -234,8 +234,8 @@ class TableDetector:
|
||||
table_res_list = getattr(result, "table_res_list", None)
|
||||
parsing_res_list = getattr(result, "parsing_res_list", [])
|
||||
|
||||
logger.info(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
|
||||
logger.info(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
|
||||
logger.debug(f"table_res_list: {type(table_res_list).__name__}, count={len(table_res_list) if table_res_list else 0}")
|
||||
logger.debug(f"parsing_res_list: {type(parsing_res_list).__name__}, count={len(parsing_res_list) if parsing_res_list else 0}")
|
||||
|
||||
if not table_res_list:
|
||||
# Log available keys/attributes for debugging
|
||||
@@ -330,7 +330,7 @@ class TableDetector:
|
||||
# Default confidence for PaddleX 3.x results
|
||||
confidence = 0.9
|
||||
|
||||
logger.info(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
|
||||
logger.debug(f"Table {i}: html_len={len(html)}, cells={len(cells)}")
|
||||
tables.append(TableDetectionResult(
|
||||
bbox=(float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])),
|
||||
html=html,
|
||||
@@ -467,14 +467,14 @@ class TableDetector:
|
||||
if not pdf_path.exists():
|
||||
raise FileNotFoundError(f"PDF not found: {pdf_path}")
|
||||
|
||||
logger.info(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
|
||||
logger.debug(f"detect_from_pdf: {pdf_path}, page={page_number}, dpi={dpi}")
|
||||
|
||||
# Render specific page
|
||||
for page_no, image_bytes in render_pdf_to_images(str(pdf_path), dpi=dpi):
|
||||
if page_no == page_number:
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
image_array = np.array(image)
|
||||
logger.info(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
|
||||
logger.debug(f"detect_from_pdf: rendered page {page_no}, image shape={image_array.shape}")
|
||||
return self.detect(image_array)
|
||||
|
||||
raise ValueError(f"Page {page_number} not found in PDF")
|
||||
|
||||
@@ -15,6 +15,11 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration constants
|
||||
DEFAULT_ROW_TOLERANCE = 15.0 # Max vertical distance (pixels) to consider same row
|
||||
MIN_ITEMS_FOR_VALID_EXTRACTION = 2 # Minimum items required for valid extraction
|
||||
MIN_TEXT_ELEMENTS_FOR_EXTRACTION = 5 # Minimum text elements needed to attempt extraction
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextElement:
|
||||
@@ -65,7 +70,10 @@ class TextLineItemsResult:
|
||||
extraction_method: str = "text_spatial"
|
||||
|
||||
|
||||
# Swedish amount pattern: 1 234,56 or 1234.56 or 1,234.56
|
||||
# Amount pattern matches Swedish, US, and simple numeric formats
|
||||
# Handles: "1 234,56", "1,234.56", "1234.56", "100 kr", "50:-", "-100,00"
|
||||
# Does NOT handle: amounts with more than 2 decimal places, scientific notation
|
||||
# See tests in test_text_line_items_extractor.py::TestAmountPattern
|
||||
AMOUNT_PATTERN = re.compile(
|
||||
r"(?<![0-9])(?:"
|
||||
r"-?\d{1,3}(?:\s\d{3})*(?:,\d{2})?" # Swedish: 1 234,56
|
||||
@@ -128,17 +136,17 @@ class TextLineItemsExtractor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
row_tolerance: float = 15.0, # Max vertical distance to consider same row
|
||||
min_items_for_valid: int = 2, # Minimum items to consider extraction valid
|
||||
row_tolerance: float = DEFAULT_ROW_TOLERANCE,
|
||||
min_items_for_valid: int = MIN_ITEMS_FOR_VALID_EXTRACTION,
|
||||
):
|
||||
"""
|
||||
Initialize extractor.
|
||||
|
||||
Args:
|
||||
row_tolerance: Maximum vertical distance (pixels) between elements
|
||||
to consider them on the same row.
|
||||
to consider them on the same row. Default: 15.0
|
||||
min_items_for_valid: Minimum number of line items required for
|
||||
extraction to be considered successful.
|
||||
extraction to be considered successful. Default: 2
|
||||
"""
|
||||
self.row_tolerance = row_tolerance
|
||||
self.min_items_for_valid = min_items_for_valid
|
||||
@@ -161,10 +169,13 @@ class TextLineItemsExtractor:
|
||||
|
||||
# Extract text elements from parsing results
|
||||
text_elements = self._extract_text_elements(parsing_res_list)
|
||||
logger.info(f"TextLineItemsExtractor: found {len(text_elements)} text elements")
|
||||
logger.debug(f"TextLineItemsExtractor: found {len(text_elements)} text elements")
|
||||
|
||||
if len(text_elements) < 5: # Need at least a few elements
|
||||
logger.debug("Too few text elements for line item extraction")
|
||||
if len(text_elements) < MIN_TEXT_ELEMENTS_FOR_EXTRACTION:
|
||||
logger.debug(
|
||||
f"Too few text elements ({len(text_elements)}) for line item extraction, "
|
||||
f"need at least {MIN_TEXT_ELEMENTS_FOR_EXTRACTION}"
|
||||
)
|
||||
return None
|
||||
|
||||
return self.extract_from_text_elements(text_elements)
|
||||
@@ -183,11 +194,11 @@ class TextLineItemsExtractor:
|
||||
"""
|
||||
# Group elements by row
|
||||
rows = self._group_by_row(text_elements)
|
||||
logger.info(f"TextLineItemsExtractor: grouped into {len(rows)} rows")
|
||||
logger.debug(f"TextLineItemsExtractor: grouped into {len(rows)} rows")
|
||||
|
||||
# Find the line items section
|
||||
item_rows = self._identify_line_item_rows(rows)
|
||||
logger.info(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")
|
||||
logger.debug(f"TextLineItemsExtractor: identified {len(item_rows)} potential item rows")
|
||||
|
||||
if len(item_rows) < self.min_items_for_valid:
|
||||
logger.debug(f"Found only {len(item_rows)} item rows, need at least {self.min_items_for_valid}")
|
||||
@@ -195,7 +206,7 @@ class TextLineItemsExtractor:
|
||||
|
||||
# Extract structured items
|
||||
items = self._parse_line_items(item_rows)
|
||||
logger.info(f"TextLineItemsExtractor: extracted {len(items)} line items")
|
||||
logger.debug(f"TextLineItemsExtractor: extracted {len(items)} line items")
|
||||
|
||||
if len(items) < self.min_items_for_valid:
|
||||
return None
|
||||
@@ -209,7 +220,11 @@ class TextLineItemsExtractor:
|
||||
def _extract_text_elements(
|
||||
self, parsing_res_list: list[dict[str, Any]]
|
||||
) -> list[TextElement]:
|
||||
"""Extract TextElement objects from parsing_res_list."""
|
||||
"""Extract TextElement objects from parsing_res_list.
|
||||
|
||||
Handles both dict and LayoutBlock object formats from PP-StructureV3.
|
||||
Gracefully skips invalid elements with appropriate logging.
|
||||
"""
|
||||
elements = []
|
||||
|
||||
for elem in parsing_res_list:
|
||||
@@ -220,11 +235,15 @@ class TextLineItemsExtractor:
|
||||
bbox = elem.get("bbox", [])
|
||||
# Try both 'text' and 'content' keys
|
||||
text = elem.get("text", "") or elem.get("content", "")
|
||||
else:
|
||||
elif hasattr(elem, "label"):
|
||||
label = getattr(elem, "label", "")
|
||||
bbox = getattr(elem, "bbox", [])
|
||||
# LayoutBlock objects use 'content' attribute
|
||||
text = getattr(elem, "content", "") or getattr(elem, "text", "")
|
||||
else:
|
||||
# Element is neither dict nor has expected attributes
|
||||
logger.debug(f"Skipping element with unexpected type: {type(elem).__name__}")
|
||||
continue
|
||||
|
||||
# Only process text elements (skip images, tables, etc.)
|
||||
if label not in ("text", "paragraph_title", "aside_text"):
|
||||
@@ -232,6 +251,7 @@ class TextLineItemsExtractor:
|
||||
|
||||
# Validate bbox
|
||||
if not self._valid_bbox(bbox):
|
||||
logger.debug(f"Skipping element with invalid bbox: {bbox}")
|
||||
continue
|
||||
|
||||
# Clean text
|
||||
@@ -250,8 +270,13 @@ class TextLineItemsExtractor:
|
||||
),
|
||||
)
|
||||
)
|
||||
except (KeyError, TypeError, ValueError, AttributeError) as e:
|
||||
# Expected format issues - log at debug level
|
||||
logger.debug(f"Skipping element due to format issue: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to parse element: {e}")
|
||||
# Unexpected errors - log at warning level for visibility
|
||||
logger.warning(f"Unexpected error parsing element: {type(e).__name__}: {e}")
|
||||
continue
|
||||
|
||||
return elements
|
||||
@@ -270,6 +295,7 @@ class TextLineItemsExtractor:
|
||||
Group text elements into rows based on vertical position.
|
||||
|
||||
Elements within row_tolerance of each other are considered same row.
|
||||
Uses dynamic average center_y to handle varying element heights more accurately.
|
||||
"""
|
||||
if not elements:
|
||||
return []
|
||||
@@ -277,22 +303,22 @@ class TextLineItemsExtractor:
|
||||
# Sort by vertical position
|
||||
sorted_elements = sorted(elements, key=lambda e: e.center_y)
|
||||
|
||||
rows = []
|
||||
current_row = [sorted_elements[0]]
|
||||
current_y = sorted_elements[0].center_y
|
||||
rows: list[list[TextElement]] = []
|
||||
current_row: list[TextElement] = [sorted_elements[0]]
|
||||
|
||||
for elem in sorted_elements[1:]:
|
||||
if abs(elem.center_y - current_y) <= self.row_tolerance:
|
||||
# Same row
|
||||
# Calculate dynamic average center_y for current row
|
||||
avg_center_y = sum(e.center_y for e in current_row) / len(current_row)
|
||||
|
||||
if abs(elem.center_y - avg_center_y) <= self.row_tolerance:
|
||||
# Same row - add element and recalculate average on next iteration
|
||||
current_row.append(elem)
|
||||
else:
|
||||
# New row
|
||||
if current_row:
|
||||
# Sort row by horizontal position
|
||||
current_row.sort(key=lambda e: e.center_x)
|
||||
rows.append(current_row)
|
||||
# New row - finalize current row
|
||||
# Sort row by horizontal position (left to right)
|
||||
current_row.sort(key=lambda e: e.center_x)
|
||||
rows.append(current_row)
|
||||
current_row = [elem]
|
||||
current_y = elem.center_y
|
||||
|
||||
# Don't forget last row
|
||||
if current_row:
|
||||
|
||||
@@ -7,6 +7,7 @@ the autolabel results to identify potential errors.
|
||||
|
||||
import json
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
@@ -14,6 +15,8 @@ from dataclasses import dataclass, asdict
|
||||
from datetime import datetime
|
||||
|
||||
import psycopg2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from psycopg2.extras import execute_values
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
@@ -648,7 +651,7 @@ Return ONLY the JSON object, no other text."""
|
||||
docs = self.get_documents_with_failed_matches(limit=limit)
|
||||
|
||||
if verbose:
|
||||
print(f"Found {len(docs)} documents with failed matches to validate")
|
||||
logger.info("Found %d documents with failed matches to validate", len(docs))
|
||||
|
||||
results = []
|
||||
for i, doc in enumerate(docs):
|
||||
@@ -656,16 +659,16 @@ Return ONLY the JSON object, no other text."""
|
||||
|
||||
if verbose:
|
||||
failed_fields = [f['field'] for f in doc['failed_fields']]
|
||||
print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")
|
||||
logger.info("[%d/%d] Validating %s... (failed: %s)", i+1, len(docs), doc_id[:8], ', '.join(failed_fields))
|
||||
|
||||
result = self.validate_document(doc_id, provider, model)
|
||||
results.append(result)
|
||||
|
||||
if verbose:
|
||||
if result.error:
|
||||
print(f" ERROR: {result.error}")
|
||||
logger.error(" ERROR: %s", result.error)
|
||||
else:
|
||||
print(f" OK ({result.processing_time_ms:.0f}ms)")
|
||||
logger.info(" OK (%.0fms)", result.processing_time_ms)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -39,6 +39,26 @@ from backend.web.schemas.common import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# PDF magic bytes - all valid PDF files must start with this sequence
|
||||
PDF_MAGIC_BYTES = b"%PDF"
|
||||
|
||||
|
||||
def validate_pdf_magic_bytes(content: bytes) -> None:
|
||||
"""Validate that file content has valid PDF magic bytes.
|
||||
|
||||
PDF files must start with the bytes '%PDF' (0x25 0x50 0x44 0x46).
|
||||
This validation prevents attackers from uploading malicious files
|
||||
(executables, scripts) by simply renaming them to .pdf extension.
|
||||
|
||||
Args:
|
||||
content: The raw file content to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the content does not start with valid PDF magic bytes.
|
||||
"""
|
||||
if not content or not content.startswith(PDF_MAGIC_BYTES):
|
||||
raise ValueError("Invalid PDF file: does not have valid PDF header")
|
||||
|
||||
|
||||
def _validate_uuid(value: str, name: str = "ID") -> None:
|
||||
"""Validate UUID format."""
|
||||
@@ -135,6 +155,14 @@ def create_documents_router(storage_config: StorageConfig) -> APIRouter:
|
||||
logger.error(f"Failed to read uploaded file: {e}")
|
||||
raise HTTPException(status_code=400, detail="Failed to read file")
|
||||
|
||||
# Validate PDF magic bytes (only for PDF files)
|
||||
if file_ext == ".pdf":
|
||||
try:
|
||||
validate_pdf_magic_bytes(content)
|
||||
except ValueError as e:
|
||||
logger.warning(f"PDF magic bytes validation failed: {e}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Get page count (for PDF)
|
||||
page_count = 1
|
||||
if file_ext == ".pdf":
|
||||
|
||||
@@ -11,6 +11,7 @@ from backend.web.schemas.admin import (
|
||||
ExportResponse,
|
||||
)
|
||||
from backend.web.schemas.common import ErrorResponse
|
||||
from shared.bbox import expand_bbox
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -102,12 +103,52 @@ def register_export_routes(router: APIRouter) -> None:
|
||||
dst_image.write_bytes(image_content)
|
||||
total_images += 1
|
||||
|
||||
# Get image dimensions for bbox expansion
|
||||
img_dims = storage.get_admin_image_dimensions(doc_id, page_num)
|
||||
if img_dims is None:
|
||||
# Fall back to standard A4 at 300 DPI if dimensions unavailable
|
||||
img_width, img_height = 2480, 3508
|
||||
else:
|
||||
img_width, img_height = img_dims
|
||||
|
||||
label_name = f"{doc.document_id}_page{page_num}.txt"
|
||||
label_path = export_dir / "labels" / split / label_name
|
||||
|
||||
with open(label_path, "w") as f:
|
||||
for ann in page_annotations:
|
||||
line = f"{ann.class_id} {ann.x_center:.6f} {ann.y_center:.6f} {ann.width:.6f} {ann.height:.6f}\n"
|
||||
# Convert normalized coords to pixel coords
|
||||
half_w = (ann.width * img_width) / 2
|
||||
half_h = (ann.height * img_height) / 2
|
||||
x0 = ann.x_center * img_width - half_w
|
||||
y0 = ann.y_center * img_height - half_h
|
||||
x1 = ann.x_center * img_width + half_w
|
||||
y1 = ann.y_center * img_height + half_h
|
||||
|
||||
# Use manual_mode for manual/imported annotations
|
||||
manual_mode = ann.source in ("manual", "imported")
|
||||
|
||||
# Apply field-specific bbox expansion
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann.class_name,
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Convert back to normalized YOLO format
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Clamp to valid range
|
||||
new_x_center = max(0, min(1, new_x_center))
|
||||
new_y_center = max(0, min(1, new_y_center))
|
||||
new_width = max(0, min(1, new_width))
|
||||
new_height = max(0, min(1, new_height))
|
||||
|
||||
line = f"{ann.class_id} {new_x_center:.6f} {new_y_center:.6f} {new_width:.6f} {new_height:.6f}\n"
|
||||
f.write(line)
|
||||
total_annotations += 1
|
||||
|
||||
|
||||
@@ -1,21 +1,24 @@
|
||||
"""
|
||||
Inference Service
|
||||
Inference Service (Adapter Layer)
|
||||
|
||||
Business logic for invoice field extraction.
|
||||
Orchestrates technical pipeline and business domain logic.
|
||||
Acts as adapter between API layer and internal components.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
from typing import TYPE_CHECKING, Callable, Generator
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from backend.domain.document_classifier import DocumentClassifier
|
||||
from backend.web.services.storage_helpers import get_storage_helper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -50,9 +53,12 @@ class ServiceResult:
|
||||
|
||||
class InferenceService:
|
||||
"""
|
||||
Service for running invoice field extraction.
|
||||
Service for running invoice field extraction (Adapter Pattern).
|
||||
|
||||
Orchestrates:
|
||||
- Technical layer: InferencePipeline, YOLODetector
|
||||
- Business layer: DocumentClassifier
|
||||
|
||||
Encapsulates YOLO detection and OCR extraction logic.
|
||||
Supports dynamic model loading from database.
|
||||
"""
|
||||
|
||||
@@ -61,6 +67,7 @@ class InferenceService:
|
||||
model_config: ModelConfig,
|
||||
storage_config: StorageConfig,
|
||||
model_path_resolver: ModelPathResolver | None = None,
|
||||
document_classifier: DocumentClassifier | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize inference service.
|
||||
@@ -71,12 +78,19 @@ class InferenceService:
|
||||
model_path_resolver: Optional function to resolve model path from database.
|
||||
If provided, will be called to get active model path.
|
||||
If returns None, falls back to model_config.model_path.
|
||||
document_classifier: Optional custom classifier (uses default if None)
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.storage_config = storage_config
|
||||
self._model_path_resolver = model_path_resolver
|
||||
|
||||
# Technical layer (lazy initialized)
|
||||
self._pipeline = None
|
||||
self._detector = None
|
||||
|
||||
# Business layer (eagerly initialized, no heavy resources)
|
||||
self._classifier = document_classifier or DocumentClassifier()
|
||||
|
||||
self._is_initialized = False
|
||||
self._current_model_path: Path | None = None
|
||||
self._business_features_enabled = False
|
||||
@@ -219,22 +233,12 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
# Business layer: classify document type
|
||||
classification = self._classifier.classify(result.fields)
|
||||
result.document_type = classification.document_type
|
||||
|
||||
# Get raw detections for visualization
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||
|
||||
# Save visualization if requested
|
||||
if save_visualization and pipeline_result.raw_detections:
|
||||
@@ -293,22 +297,12 @@ class InferenceService:
|
||||
result.success = pipeline_result.success
|
||||
result.errors = pipeline_result.errors
|
||||
|
||||
# Determine document type based on payment_line presence
|
||||
# If no payment_line found, it's likely a letter, not an invoice
|
||||
if not result.fields.get('payment_line'):
|
||||
result.document_type = "letter"
|
||||
else:
|
||||
result.document_type = "invoice"
|
||||
# Business layer: classify document type
|
||||
classification = self._classifier.classify(result.fields)
|
||||
result.document_type = classification.document_type
|
||||
|
||||
# Get raw detections
|
||||
result.detections = [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in pipeline_result.raw_detections
|
||||
]
|
||||
result.detections = self._format_detections(pipeline_result.raw_detections)
|
||||
|
||||
# Include business features if extracted
|
||||
if extract_line_items:
|
||||
@@ -329,10 +323,19 @@ class InferenceService:
|
||||
result.processing_time_ms = (time.time() - start_time) * 1000
|
||||
return result
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization image with detections."""
|
||||
from ultralytics import YOLO
|
||||
def _format_detections(self, raw_detections: list) -> list[dict]:
|
||||
"""Format raw detections for response."""
|
||||
return [
|
||||
{
|
||||
"field": d.class_name,
|
||||
"confidence": d.confidence,
|
||||
"bbox": list(d.bbox),
|
||||
}
|
||||
for d in raw_detections
|
||||
]
|
||||
|
||||
def _save_visualization(self, image_path: Path, doc_id: str) -> Path | None:
|
||||
"""Save visualization image with detections using existing detector."""
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
results_dir = storage.get_results_base_path()
|
||||
@@ -340,9 +343,8 @@ class InferenceService:
|
||||
logger.warning("Cannot save visualization: local storage not available")
|
||||
return None
|
||||
|
||||
# Load model and run prediction with visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(image_path), verbose=False)
|
||||
# Reuse self._detector instead of creating new YOLO instance
|
||||
results = self._detector.model.predict(str(image_path), verbose=False)
|
||||
|
||||
# Save annotated image
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
@@ -351,11 +353,20 @@ class InferenceService:
|
||||
|
||||
return output_path
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path:
|
||||
"""Save visualization for PDF (first page)."""
|
||||
@contextmanager
|
||||
def _temp_image_file(
|
||||
self, results_dir: Path, doc_id: str
|
||||
) -> Generator[Path, None, None]:
|
||||
"""Context manager for temporary image file with guaranteed cleanup."""
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
try:
|
||||
yield temp_path
|
||||
finally:
|
||||
temp_path.unlink(missing_ok=True)
|
||||
|
||||
def _save_pdf_visualization(self, pdf_path: Path, doc_id: str) -> Path | None:
|
||||
"""Save visualization for PDF (first page) using existing detector."""
|
||||
from shared.pdf.renderer import render_pdf_to_images
|
||||
from ultralytics import YOLO
|
||||
import io
|
||||
|
||||
# Get storage helper for results directory
|
||||
storage = get_storage_helper()
|
||||
@@ -369,20 +380,19 @@ class InferenceService:
|
||||
pdf_path, dpi=self.model_config.dpi
|
||||
):
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
temp_path = results_dir / f"{doc_id}_temp.png"
|
||||
image.save(temp_path)
|
||||
|
||||
# Run YOLO and save visualization
|
||||
model = YOLO(str(self.model_config.model_path))
|
||||
results = model.predict(str(temp_path), verbose=False)
|
||||
# Use context manager for temp file to guarantee cleanup
|
||||
with self._temp_image_file(results_dir, doc_id) as temp_path:
|
||||
image.save(temp_path)
|
||||
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
# Reuse self._detector instead of creating new YOLO instance
|
||||
results = self._detector.model.predict(str(temp_path), verbose=False)
|
||||
|
||||
# Cleanup temp file
|
||||
temp_path.unlink(missing_ok=True)
|
||||
return output_path
|
||||
output_path = results_dir / f"{doc_id}_result.png"
|
||||
for r in results:
|
||||
r.save(filename=str(output_path))
|
||||
|
||||
return output_path
|
||||
|
||||
# If no pages rendered
|
||||
return None
|
||||
|
||||
37
packages/shared/shared/bbox/__init__.py
Normal file
37
packages/shared/shared/bbox/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
BBox Scale Strategy Module.
|
||||
|
||||
Provides field-specific bounding box expansion strategies for YOLO training data.
|
||||
Expands bboxes using center-point scaling with directional compensation to capture
|
||||
field labels that typically appear above or to the left of field values.
|
||||
|
||||
Two modes are supported:
|
||||
- Auto-label: Field-specific scale strategies with directional compensation
|
||||
- Manual-label: Minimal padding only to prevent edge clipping
|
||||
|
||||
Usage:
|
||||
from shared.bbox import expand_bbox, ScaleStrategy, FIELD_SCALE_STRATEGIES
|
||||
|
||||
Available exports:
|
||||
- ScaleStrategy: Dataclass for scale strategy configuration
|
||||
- DEFAULT_STRATEGY: Default strategy for unknown fields (auto-label)
|
||||
- MANUAL_LABEL_STRATEGY: Minimal padding strategy for manual labels
|
||||
- FIELD_SCALE_STRATEGIES: dict[str, ScaleStrategy] - field-specific strategies
|
||||
- expand_bbox: Function to expand bbox using field-specific strategy
|
||||
"""
|
||||
|
||||
from .scale_strategy import (
|
||||
ScaleStrategy,
|
||||
DEFAULT_STRATEGY,
|
||||
MANUAL_LABEL_STRATEGY,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
)
|
||||
from .expander import expand_bbox
|
||||
|
||||
__all__ = [
|
||||
"ScaleStrategy",
|
||||
"DEFAULT_STRATEGY",
|
||||
"MANUAL_LABEL_STRATEGY",
|
||||
"FIELD_SCALE_STRATEGIES",
|
||||
"expand_bbox",
|
||||
]
|
||||
101
packages/shared/shared/bbox/expander.py
Normal file
101
packages/shared/shared/bbox/expander.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
BBox Expander Module.
|
||||
|
||||
Provides functions to expand bounding boxes using field-specific strategies.
|
||||
Expansion is center-point based with directional compensation.
|
||||
|
||||
Two modes:
|
||||
- Auto-label (default): Field-specific scale strategies
|
||||
- Manual-label: Minimal padding only to prevent edge clipping
|
||||
"""
|
||||
|
||||
from .scale_strategy import (
|
||||
ScaleStrategy,
|
||||
DEFAULT_STRATEGY,
|
||||
MANUAL_LABEL_STRATEGY,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
)
|
||||
|
||||
|
||||
def expand_bbox(
|
||||
bbox: tuple[float, float, float, float],
|
||||
image_width: float,
|
||||
image_height: float,
|
||||
field_type: str,
|
||||
strategies: dict[str, ScaleStrategy] | None = None,
|
||||
manual_mode: bool = False,
|
||||
) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
Expand bbox using field-specific scale strategy.
|
||||
|
||||
The expansion follows these steps:
|
||||
1. Scale bbox around center point (scale_x, scale_y)
|
||||
2. Apply directional compensation (extra_*_ratio)
|
||||
3. Clamp expansion to max_pad limits
|
||||
4. Clamp to image boundaries
|
||||
|
||||
Args:
|
||||
bbox: (x0, y0, x1, y1) in pixels
|
||||
image_width: Image width for boundary clamping
|
||||
image_height: Image height for boundary clamping
|
||||
field_type: Field class_name (e.g., "ocr_number")
|
||||
strategies: Custom strategies dict, defaults to FIELD_SCALE_STRATEGIES
|
||||
manual_mode: If True, use MANUAL_LABEL_STRATEGY (minimal padding only)
|
||||
|
||||
Returns:
|
||||
Expanded bbox (x0, y0, x1, y1) as integers, clamped to image bounds
|
||||
"""
|
||||
x0, y0, x1, y1 = bbox
|
||||
w = x1 - x0
|
||||
h = y1 - y0
|
||||
|
||||
# Get strategy based on mode
|
||||
if manual_mode:
|
||||
strategy = MANUAL_LABEL_STRATEGY
|
||||
elif strategies is None:
|
||||
strategy = FIELD_SCALE_STRATEGIES.get(field_type, DEFAULT_STRATEGY)
|
||||
else:
|
||||
strategy = strategies.get(field_type, DEFAULT_STRATEGY)
|
||||
|
||||
# Step 1: Scale around center point
|
||||
cx = (x0 + x1) / 2
|
||||
cy = (y0 + y1) / 2
|
||||
|
||||
new_w = w * strategy.scale_x
|
||||
new_h = h * strategy.scale_y
|
||||
|
||||
nx0 = cx - new_w / 2
|
||||
nx1 = cx + new_w / 2
|
||||
ny0 = cy - new_h / 2
|
||||
ny1 = cy + new_h / 2
|
||||
|
||||
# Step 2: Apply directional compensation
|
||||
nx0 -= w * strategy.extra_left_ratio
|
||||
nx1 += w * strategy.extra_right_ratio
|
||||
ny0 -= h * strategy.extra_top_ratio
|
||||
ny1 += h * strategy.extra_bottom_ratio
|
||||
|
||||
# Step 3: Clamp expansion to max_pad limits (preserve asymmetry)
|
||||
left_pad = min(x0 - nx0, strategy.max_pad_x)
|
||||
right_pad = min(nx1 - x1, strategy.max_pad_x)
|
||||
top_pad = min(y0 - ny0, strategy.max_pad_y)
|
||||
bottom_pad = min(ny1 - y1, strategy.max_pad_y)
|
||||
|
||||
# Ensure pads are non-negative (in case of contraction)
|
||||
left_pad = max(0, left_pad)
|
||||
right_pad = max(0, right_pad)
|
||||
top_pad = max(0, top_pad)
|
||||
bottom_pad = max(0, bottom_pad)
|
||||
|
||||
nx0 = x0 - left_pad
|
||||
nx1 = x1 + right_pad
|
||||
ny0 = y0 - top_pad
|
||||
ny1 = y1 + bottom_pad
|
||||
|
||||
# Step 4: Clamp to image boundaries
|
||||
nx0 = max(0, int(nx0))
|
||||
ny0 = max(0, int(ny0))
|
||||
nx1 = min(int(image_width), int(nx1))
|
||||
ny1 = min(int(image_height), int(ny1))
|
||||
|
||||
return (nx0, ny0, nx1, ny1)
|
||||
140
packages/shared/shared/bbox/scale_strategy.py
Normal file
140
packages/shared/shared/bbox/scale_strategy.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Scale Strategy Configuration.
|
||||
|
||||
Defines field-specific bbox expansion strategies for YOLO training data.
|
||||
Each strategy controls how bboxes are expanded around field values to
|
||||
capture contextual information like labels.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Final
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScaleStrategy:
|
||||
"""Immutable scale strategy for bbox expansion.
|
||||
|
||||
Attributes:
|
||||
scale_x: Horizontal scale factor (1.0 = no scaling)
|
||||
scale_y: Vertical scale factor (1.0 = no scaling)
|
||||
extra_top_ratio: Additional expansion ratio towards top (for labels above)
|
||||
extra_bottom_ratio: Additional expansion ratio towards bottom
|
||||
extra_left_ratio: Additional expansion ratio towards left (for prefixes)
|
||||
extra_right_ratio: Additional expansion ratio towards right (for suffixes)
|
||||
max_pad_x: Maximum horizontal padding in pixels
|
||||
max_pad_y: Maximum vertical padding in pixels
|
||||
"""
|
||||
|
||||
scale_x: float = 1.15
|
||||
scale_y: float = 1.15
|
||||
extra_top_ratio: float = 0.0
|
||||
extra_bottom_ratio: float = 0.0
|
||||
extra_left_ratio: float = 0.0
|
||||
extra_right_ratio: float = 0.0
|
||||
max_pad_x: int = 50
|
||||
max_pad_y: int = 50
|
||||
|
||||
|
||||
# Default strategy for unknown fields (auto-label mode)
|
||||
DEFAULT_STRATEGY: Final[ScaleStrategy] = ScaleStrategy()
|
||||
|
||||
# Manual label strategy - minimal padding to prevent edge clipping
|
||||
# No scaling, no directional compensation, just small uniform padding
|
||||
MANUAL_LABEL_STRATEGY: Final[ScaleStrategy] = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_top_ratio=0.0,
|
||||
extra_bottom_ratio=0.0,
|
||||
extra_left_ratio=0.0,
|
||||
extra_right_ratio=0.0,
|
||||
max_pad_x=10, # Small padding to prevent edge loss
|
||||
max_pad_y=10,
|
||||
)
|
||||
|
||||
|
||||
# Field-specific strategies based on Swedish invoice field characteristics
|
||||
# Field labels typically appear above or to the left of values
|
||||
FIELD_SCALE_STRATEGIES: Final[dict[str, ScaleStrategy]] = {
|
||||
# OCR number - label "OCR" or "Referens" typically above
|
||||
"ocr_number": ScaleStrategy(
|
||||
scale_x=1.15,
|
||||
scale_y=1.80,
|
||||
extra_top_ratio=0.60,
|
||||
max_pad_x=50,
|
||||
max_pad_y=140,
|
||||
),
|
||||
# Bankgiro - prefix "Bankgiro:" or "BG:" typically to the left
|
||||
"bankgiro": ScaleStrategy(
|
||||
scale_x=1.45,
|
||||
scale_y=1.35,
|
||||
extra_left_ratio=0.80,
|
||||
max_pad_x=160,
|
||||
max_pad_y=90,
|
||||
),
|
||||
# Plusgiro - prefix "Plusgiro:" or "PG:" typically to the left
|
||||
"plusgiro": ScaleStrategy(
|
||||
scale_x=1.45,
|
||||
scale_y=1.35,
|
||||
extra_left_ratio=0.80,
|
||||
max_pad_x=160,
|
||||
max_pad_y=90,
|
||||
),
|
||||
# Invoice date - label "Fakturadatum" typically above
|
||||
"invoice_date": ScaleStrategy(
|
||||
scale_x=1.25,
|
||||
scale_y=1.55,
|
||||
extra_top_ratio=0.40,
|
||||
max_pad_x=80,
|
||||
max_pad_y=110,
|
||||
),
|
||||
# Due date - label "Forfalldatum" typically above, sometimes left
|
||||
"invoice_due_date": ScaleStrategy(
|
||||
scale_x=1.30,
|
||||
scale_y=1.65,
|
||||
extra_top_ratio=0.45,
|
||||
extra_left_ratio=0.35,
|
||||
max_pad_x=100,
|
||||
max_pad_y=120,
|
||||
),
|
||||
# Amount - currency symbol "SEK" or "kr" may be to the right
|
||||
"amount": ScaleStrategy(
|
||||
scale_x=1.20,
|
||||
scale_y=1.35,
|
||||
extra_right_ratio=0.30,
|
||||
max_pad_x=70,
|
||||
max_pad_y=80,
|
||||
),
|
||||
# Invoice number - label "Fakturanummer" typically above
|
||||
"invoice_number": ScaleStrategy(
|
||||
scale_x=1.20,
|
||||
scale_y=1.50,
|
||||
extra_top_ratio=0.40,
|
||||
max_pad_x=80,
|
||||
max_pad_y=100,
|
||||
),
|
||||
# Supplier org number - label "Org.nr" typically above or left
|
||||
"supplier_org_number": ScaleStrategy(
|
||||
scale_x=1.25,
|
||||
scale_y=1.40,
|
||||
extra_top_ratio=0.30,
|
||||
extra_left_ratio=0.20,
|
||||
max_pad_x=90,
|
||||
max_pad_y=90,
|
||||
),
|
||||
# Customer number - label "Kundnummer" typically above or left
|
||||
"customer_number": ScaleStrategy(
|
||||
scale_x=1.25,
|
||||
scale_y=1.45,
|
||||
extra_top_ratio=0.35,
|
||||
extra_left_ratio=0.25,
|
||||
max_pad_x=90,
|
||||
max_pad_y=100,
|
||||
),
|
||||
# Payment line - machine-readable code, minimal expansion needed
|
||||
"payment_line": ScaleStrategy(
|
||||
scale_x=1.10,
|
||||
scale_y=1.20,
|
||||
max_pad_x=40,
|
||||
max_pad_y=40,
|
||||
),
|
||||
}
|
||||
@@ -16,6 +16,7 @@ Available exports:
|
||||
- FIELD_CLASSES: dict[int, str] - class_id to class_name
|
||||
- FIELD_CLASS_IDS: dict[str, int] - class_name to class_id
|
||||
- CLASS_TO_FIELD: dict[str, str] - class_name to field_name
|
||||
- FIELD_TO_CLASS: dict[str, str] - field_name to class_name
|
||||
- CSV_TO_CLASS_MAPPING: dict[str, int] - field_name to class_id (excludes derived)
|
||||
- TRAINING_FIELD_CLASSES: dict[str, int] - field_name to class_id (all fields)
|
||||
- ACCOUNT_FIELD_MAPPING: Mapping for supplier_accounts handling
|
||||
@@ -27,6 +28,7 @@ from .mappings import (
|
||||
FIELD_CLASSES,
|
||||
FIELD_CLASS_IDS,
|
||||
CLASS_TO_FIELD,
|
||||
FIELD_TO_CLASS,
|
||||
CSV_TO_CLASS_MAPPING,
|
||||
TRAINING_FIELD_CLASSES,
|
||||
ACCOUNT_FIELD_MAPPING,
|
||||
@@ -40,6 +42,7 @@ __all__ = [
|
||||
"FIELD_CLASSES",
|
||||
"FIELD_CLASS_IDS",
|
||||
"CLASS_TO_FIELD",
|
||||
"FIELD_TO_CLASS",
|
||||
"CSV_TO_CLASS_MAPPING",
|
||||
"TRAINING_FIELD_CLASSES",
|
||||
"ACCOUNT_FIELD_MAPPING",
|
||||
|
||||
@@ -47,6 +47,12 @@ TRAINING_FIELD_CLASSES: Final[dict[str, int]] = {
|
||||
fd.field_name: fd.class_id for fd in FIELD_DEFINITIONS
|
||||
}
|
||||
|
||||
# field_name -> class_name mapping (reverse of CLASS_TO_FIELD)
|
||||
# Example: {"InvoiceNumber": "invoice_number", "OCR": "ocr_number", ...}
|
||||
FIELD_TO_CLASS: Final[dict[str, str]] = {
|
||||
fd.field_name: fd.class_name for fd in FIELD_DEFINITIONS
|
||||
}
|
||||
|
||||
# Account field mapping for supplier_accounts special handling
|
||||
# BG:xxx -> Bankgiro, PG:xxx -> Plusgiro
|
||||
ACCOUNT_FIELD_MAPPING: Final[dict[str, dict[str, str]]] = {
|
||||
|
||||
62
packages/shared/shared/logging_config.py
Normal file
62
packages/shared/shared/logging_config.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Logging Configuration
|
||||
|
||||
Provides consistent logging setup for CLI tools and modules.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def setup_cli_logging(
|
||||
level: int = logging.INFO,
|
||||
name: Optional[str] = None,
|
||||
format_string: Optional[str] = None,
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Configure logging for CLI applications.
|
||||
|
||||
Args:
|
||||
level: Logging level (default: INFO)
|
||||
name: Logger name (default: root logger)
|
||||
format_string: Custom format string (default: simple CLI format)
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
if format_string is None:
|
||||
format_string = "%(message)s"
|
||||
|
||||
# Configure root logger or specific logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
|
||||
# Remove existing handlers to avoid duplicates
|
||||
logger.handlers.clear()
|
||||
|
||||
# Create console handler
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setLevel(level)
|
||||
handler.setFormatter(logging.Formatter(format_string))
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def setup_verbose_logging(
|
||||
level: int = logging.DEBUG,
|
||||
name: Optional[str] = None,
|
||||
) -> logging.Logger:
|
||||
"""
|
||||
Configure verbose logging with timestamps and module info.
|
||||
|
||||
Args:
|
||||
level: Logging level (default: DEBUG)
|
||||
name: Logger name (default: root logger)
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
format_string = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
return setup_cli_logging(level=level, name=name, format_string=format_string)
|
||||
@@ -9,6 +9,7 @@ Now reads from PostgreSQL database instead of JSONL files.
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
@@ -16,6 +17,9 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from shared.config import get_db_connection_string
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from shared.normalize import normalize_field
|
||||
from shared.matcher import FieldMatcher
|
||||
@@ -104,7 +108,7 @@ class LabelAnalyzer:
|
||||
for row in reader:
|
||||
doc_id = row['DocumentId']
|
||||
self.csv_data[doc_id] = row
|
||||
print(f"Loaded {len(self.csv_data)} records from CSV")
|
||||
logger.info("Loaded %d records from CSV", len(self.csv_data))
|
||||
|
||||
def load_labels(self):
|
||||
"""Load all label files from dataset."""
|
||||
@@ -150,12 +154,12 @@ class LabelAnalyzer:
|
||||
for doc in self.label_data.values()
|
||||
for labels in doc['pages'].values()
|
||||
)
|
||||
print(f"Loaded labels for {total_docs} documents ({total_labels} total labels)")
|
||||
logger.info("Loaded labels for %d documents (%d total labels)", total_docs, total_labels)
|
||||
|
||||
def load_report(self):
|
||||
"""Load autolabel report from database."""
|
||||
if not self.db:
|
||||
print("Database not configured, skipping report loading")
|
||||
logger.info("Database not configured, skipping report loading")
|
||||
return
|
||||
|
||||
# Get document IDs from CSV to query
|
||||
@@ -175,7 +179,7 @@ class LabelAnalyzer:
|
||||
self.report_data[doc_id] = doc
|
||||
loaded += 1
|
||||
|
||||
print(f"Loaded {loaded} autolabel reports from database")
|
||||
logger.info("Loaded %d autolabel reports from database", loaded)
|
||||
|
||||
def analyze_document(self, doc_id: str, skip_missing_pdf: bool = True) -> Optional[DocumentAnalysis]:
|
||||
"""Analyze a single document."""
|
||||
@@ -373,7 +377,7 @@ class LabelAnalyzer:
|
||||
break
|
||||
|
||||
if skipped > 0:
|
||||
print(f"Skipped {skipped} documents without PDF files")
|
||||
logger.info("Skipped %d documents without PDF files", skipped)
|
||||
|
||||
return results
|
||||
|
||||
@@ -447,7 +451,7 @@ class LabelAnalyzer:
|
||||
with open(output, 'w', encoding='utf-8') as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nReport saved to: {output}")
|
||||
logger.info("Report saved to: %s", output)
|
||||
|
||||
return report
|
||||
|
||||
@@ -456,52 +460,52 @@ def print_summary(report: dict):
|
||||
"""Print summary to console."""
|
||||
summary = report['summary']
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("LABEL ANALYSIS SUMMARY")
|
||||
print("=" * 60)
|
||||
logger.info("=" * 60)
|
||||
logger.info("LABEL ANALYSIS SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
|
||||
print(f"\nDocuments:")
|
||||
print(f" Total: {summary['total_documents']}")
|
||||
print(f" With issues: {summary['documents_with_issues']} ({summary['issue_rate']})")
|
||||
logger.info("Documents:")
|
||||
logger.info(" Total: %d", summary['total_documents'])
|
||||
logger.info(" With issues: %d (%s)", summary['documents_with_issues'], summary['issue_rate'])
|
||||
|
||||
print(f"\nFields:")
|
||||
print(f" Expected: {summary['total_expected_fields']}")
|
||||
print(f" Labeled: {summary['total_labeled_fields']} ({summary['label_coverage']})")
|
||||
print(f" Missing: {summary['missing_labels']}")
|
||||
print(f" Extra: {summary['extra_labels']}")
|
||||
logger.info("Fields:")
|
||||
logger.info(" Expected: %d", summary['total_expected_fields'])
|
||||
logger.info(" Labeled: %d (%s)", summary['total_labeled_fields'], summary['label_coverage'])
|
||||
logger.info(" Missing: %d", summary['missing_labels'])
|
||||
logger.info(" Extra: %d", summary['extra_labels'])
|
||||
|
||||
print(f"\nFailure Reasons:")
|
||||
logger.info("Failure Reasons:")
|
||||
for reason, count in sorted(report['failure_reasons'].items(), key=lambda x: -x[1]):
|
||||
print(f" {reason}: {count}")
|
||||
logger.info(" %s: %d", reason, count)
|
||||
|
||||
print(f"\nFailures by Field:")
|
||||
logger.info("Failures by Field:")
|
||||
for field, reasons in report['failures_by_field'].items():
|
||||
total = sum(reasons.values())
|
||||
print(f" {field}: {total}")
|
||||
logger.info(" %s: %d", field, total)
|
||||
for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
|
||||
print(f" - {reason}: {count}")
|
||||
logger.info(" - %s: %d", reason, count)
|
||||
|
||||
# Show sample issues
|
||||
if report['issues']:
|
||||
print(f"\n" + "-" * 60)
|
||||
print("SAMPLE ISSUES (first 10)")
|
||||
print("-" * 60)
|
||||
logger.info("-" * 60)
|
||||
logger.info("SAMPLE ISSUES (first 10)")
|
||||
logger.info("-" * 60)
|
||||
|
||||
for issue in report['issues'][:10]:
|
||||
print(f"\n[{issue['doc_id']}] {issue['field']}")
|
||||
print(f" CSV value: {issue['csv_value']}")
|
||||
print(f" Reason: {issue['reason']}")
|
||||
logger.info("[%s] %s", issue['doc_id'], issue['field'])
|
||||
logger.info(" CSV value: %s", issue['csv_value'])
|
||||
logger.info(" Reason: %s", issue['reason'])
|
||||
|
||||
if issue.get('details'):
|
||||
details = issue['details']
|
||||
if details.get('normalized_candidates'):
|
||||
print(f" Candidates: {details['normalized_candidates'][:5]}")
|
||||
logger.info(" Candidates: %s", details['normalized_candidates'][:5])
|
||||
if details.get('pdf_tokens_sample'):
|
||||
print(f" PDF samples: {details['pdf_tokens_sample'][:5]}")
|
||||
logger.info(" PDF samples: %s", details['pdf_tokens_sample'][:5])
|
||||
if details.get('potential_matches'):
|
||||
print(f" Potential matches:")
|
||||
logger.info(" Potential matches:")
|
||||
for pm in details['potential_matches'][:3]:
|
||||
print(f" - token='{pm['token']}' matches candidate='{pm['candidate']}'")
|
||||
logger.info(" - token='%s' matches candidate='%s'", pm['token'], pm['candidate'])
|
||||
|
||||
|
||||
def main():
|
||||
@@ -551,6 +555,9 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
analyzer = LabelAnalyzer(
|
||||
csv_path=args.csv,
|
||||
pdf_dir=args.pdf_dir,
|
||||
@@ -566,30 +573,30 @@ def main():
|
||||
|
||||
analysis = analyzer.analyze_document(args.single)
|
||||
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"Document: {analysis.doc_id}")
|
||||
print(f"{'=' * 60}")
|
||||
print(f"PDF exists: {analysis.pdf_exists}")
|
||||
print(f"PDF type: {analysis.pdf_type}")
|
||||
print(f"Pages: {analysis.total_pages}")
|
||||
print(f"\nFields (CSV: {analysis.csv_fields_count}, Labeled: {analysis.labeled_fields_count}):")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Document: %s", analysis.doc_id)
|
||||
logger.info("=" * 60)
|
||||
logger.info("PDF exists: %s", analysis.pdf_exists)
|
||||
logger.info("PDF type: %s", analysis.pdf_type)
|
||||
logger.info("Pages: %d", analysis.total_pages)
|
||||
logger.info("Fields (CSV: %d, Labeled: %d):", analysis.csv_fields_count, analysis.labeled_fields_count)
|
||||
|
||||
for f in analysis.fields:
|
||||
status = "✓" if f.labeled else ("✗" if f.expected else "-")
|
||||
status = "[OK]" if f.labeled else ("[FAIL]" if f.expected else "[-]")
|
||||
value_str = f.csv_value[:30] if f.csv_value else "(empty)"
|
||||
print(f" [{status}] {f.field_name}: {value_str}")
|
||||
logger.info(" %s %s: %s", status, f.field_name, value_str)
|
||||
|
||||
if f.failure_reason:
|
||||
print(f" Reason: {f.failure_reason}")
|
||||
logger.info(" Reason: %s", f.failure_reason)
|
||||
if f.details.get('normalized_candidates'):
|
||||
print(f" Candidates: {f.details['normalized_candidates']}")
|
||||
logger.info(" Candidates: %s", f.details['normalized_candidates'])
|
||||
if f.details.get('potential_matches'):
|
||||
print(f" Potential matches in PDF:")
|
||||
logger.info(" Potential matches in PDF:")
|
||||
for pm in f.details['potential_matches'][:3]:
|
||||
print(f" - '{pm['token']}'")
|
||||
logger.info(" - '%s'", pm['token'])
|
||||
else:
|
||||
# Full analysis
|
||||
print("Running label analysis...")
|
||||
logger.info("Running label analysis...")
|
||||
results = analyzer.run_analysis(limit=args.limit)
|
||||
report = analyzer.generate_report(results, args.output, verbose=args.verbose)
|
||||
print_summary(report)
|
||||
|
||||
@@ -7,11 +7,15 @@ Generates statistics and insights from database or autolabel_report.jsonl
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import get_db_connection_string
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_reports_from_db() -> dict:
|
||||
@@ -147,9 +151,9 @@ def load_reports_from_file(report_path: str) -> list[dict]:
|
||||
if not report_files:
|
||||
return []
|
||||
|
||||
print(f"Reading {len(report_files)} report file(s):")
|
||||
logger.info("Reading %d report file(s):", len(report_files))
|
||||
for f in report_files:
|
||||
print(f" - {f.name}")
|
||||
logger.info(" - %s", f.name)
|
||||
|
||||
reports = []
|
||||
for report_file in report_files:
|
||||
@@ -231,55 +235,55 @@ def analyze_reports(reports: list[dict]) -> dict:
|
||||
|
||||
def print_report(stats: dict, verbose: bool = False):
|
||||
"""Print analysis report."""
|
||||
print("\n" + "=" * 60)
|
||||
print("AUTO-LABEL REPORT ANALYSIS")
|
||||
print("=" * 60)
|
||||
logger.info("=" * 60)
|
||||
logger.info("AUTO-LABEL REPORT ANALYSIS")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Overall stats
|
||||
print(f"\n{'OVERALL STATISTICS':^60}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "OVERALL STATISTICS".center(60))
|
||||
logger.info("-" * 60)
|
||||
total = stats['total']
|
||||
successful = stats['successful']
|
||||
failed = stats['failed']
|
||||
success_rate = successful / total * 100 if total > 0 else 0
|
||||
|
||||
print(f"Total documents: {total:>8}")
|
||||
print(f"Successful: {successful:>8} ({success_rate:.1f}%)")
|
||||
print(f"Failed: {failed:>8} ({100-success_rate:.1f}%)")
|
||||
logger.info("Total documents: %8d", total)
|
||||
logger.info("Successful: %8d (%.1f%%)", successful, success_rate)
|
||||
logger.info("Failed: %8d (%.1f%%)", failed, 100-success_rate)
|
||||
|
||||
# Processing time
|
||||
if 'processing_time_stats' in stats:
|
||||
pts = stats['processing_time_stats']
|
||||
print(f"\nProcessing time (ms):")
|
||||
print(f" Average: {pts['avg_ms']:>8.1f}")
|
||||
print(f" Min: {pts['min_ms']:>8.1f}")
|
||||
print(f" Max: {pts['max_ms']:>8.1f}")
|
||||
logger.info("Processing time (ms):")
|
||||
logger.info(" Average: %8.1f", pts['avg_ms'])
|
||||
logger.info(" Min: %8.1f", pts['min_ms'])
|
||||
logger.info(" Max: %8.1f", pts['max_ms'])
|
||||
elif stats.get('processing_times'):
|
||||
times = stats['processing_times']
|
||||
avg_time = sum(times) / len(times)
|
||||
min_time = min(times)
|
||||
max_time = max(times)
|
||||
print(f"\nProcessing time (ms):")
|
||||
print(f" Average: {avg_time:>8.1f}")
|
||||
print(f" Min: {min_time:>8.1f}")
|
||||
print(f" Max: {max_time:>8.1f}")
|
||||
logger.info("Processing time (ms):")
|
||||
logger.info(" Average: %8.1f", avg_time)
|
||||
logger.info(" Min: %8.1f", min_time)
|
||||
logger.info(" Max: %8.1f", max_time)
|
||||
|
||||
# By PDF type
|
||||
print(f"\n{'BY PDF TYPE':^60}")
|
||||
print("-" * 60)
|
||||
print(f"{'Type':<15} {'Total':>10} {'Success':>10} {'Rate':>10}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "BY PDF TYPE".center(60))
|
||||
logger.info("-" * 60)
|
||||
logger.info("%-15s %10s %10s %10s", 'Type', 'Total', 'Success', 'Rate')
|
||||
logger.info("-" * 60)
|
||||
for pdf_type, type_stats in sorted(stats['by_pdf_type'].items()):
|
||||
type_total = type_stats['total']
|
||||
type_success = type_stats['successful']
|
||||
type_rate = type_success / type_total * 100 if type_total > 0 else 0
|
||||
print(f"{pdf_type:<15} {type_total:>10} {type_success:>10} {type_rate:>9.1f}%")
|
||||
logger.info("%-15s %10d %10d %9.1f%%", pdf_type, type_total, type_success, type_rate)
|
||||
|
||||
# By field
|
||||
print(f"\n{'FIELD MATCH STATISTICS':^60}")
|
||||
print("-" * 60)
|
||||
print(f"{'Field':<18} {'Total':>7} {'Match':>7} {'Rate':>7} {'Exact':>7} {'Flex':>7} {'AvgScore':>8}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "FIELD MATCH STATISTICS".center(60))
|
||||
logger.info("-" * 60)
|
||||
logger.info("%-18s %7s %7s %7s %7s %7s %8s", 'Field', 'Total', 'Match', 'Rate', 'Exact', 'Flex', 'AvgScore')
|
||||
logger.info("-" * 60)
|
||||
|
||||
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
|
||||
if field_name not in stats['by_field']:
|
||||
@@ -299,16 +303,16 @@ def print_report(stats: dict, verbose: bool = False):
|
||||
else:
|
||||
avg_score = 0
|
||||
|
||||
print(f"{field_name:<18} {total:>7} {matched:>7} {rate:>6.1f}% {exact:>7} {flex:>7} {avg_score:>8.3f}")
|
||||
logger.info("%-18s %7d %7d %6.1f%% %7d %7d %8.3f", field_name, total, matched, rate, exact, flex, avg_score)
|
||||
|
||||
# Field match by PDF type
|
||||
print(f"\n{'FIELD MATCH BY PDF TYPE':^60}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "FIELD MATCH BY PDF TYPE".center(60))
|
||||
logger.info("-" * 60)
|
||||
|
||||
for pdf_type in sorted(stats['by_pdf_type'].keys()):
|
||||
print(f"\n[{pdf_type.upper()}]")
|
||||
print(f"{'Field':<18} {'Total':>10} {'Matched':>10} {'Rate':>10}")
|
||||
print("-" * 50)
|
||||
logger.info("[%s]", pdf_type.upper())
|
||||
logger.info("%-18s %10s %10s %10s", 'Field', 'Total', 'Matched', 'Rate')
|
||||
logger.info("-" * 50)
|
||||
|
||||
for field_name in ['InvoiceNumber', 'InvoiceDate', 'InvoiceDueDate', 'OCR', 'Bankgiro', 'Plusgiro', 'Amount']:
|
||||
if field_name not in stats['by_field']:
|
||||
@@ -317,16 +321,16 @@ def print_report(stats: dict, verbose: bool = False):
|
||||
total = type_stats['total']
|
||||
matched = type_stats['matched']
|
||||
rate = matched / total * 100 if total > 0 else 0
|
||||
print(f"{field_name:<18} {total:>10} {matched:>10} {rate:>9.1f}%")
|
||||
logger.info("%-18s %10d %10d %9.1f%%", field_name, total, matched, rate)
|
||||
|
||||
# Errors
|
||||
if stats.get('errors') and verbose:
|
||||
print(f"\n{'ERRORS':^60}")
|
||||
print("-" * 60)
|
||||
logger.info("%s", "ERRORS".center(60))
|
||||
logger.info("-" * 60)
|
||||
for error, count in sorted(stats['errors'].items(), key=lambda x: -x[1])[:20]:
|
||||
print(f"{count:>5}x {error[:50]}")
|
||||
logger.info("%5dx %s", count, error[:50])
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
logger.info("=" * 60)
|
||||
|
||||
|
||||
def export_json(stats: dict, output_path: str):
|
||||
@@ -372,7 +376,7 @@ def export_json(stats: dict, output_path: str):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(export_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nStatistics exported to: {output_path}")
|
||||
logger.info("Statistics exported to: %s", output_path)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -401,25 +405,28 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Decide source
|
||||
use_db = not args.from_file and args.report is None
|
||||
|
||||
if use_db:
|
||||
print("Loading statistics from database...")
|
||||
logger.info("Loading statistics from database...")
|
||||
stats = load_reports_from_db()
|
||||
print(f"Loaded stats for {stats['total']} documents")
|
||||
logger.info("Loaded stats for %d documents", stats['total'])
|
||||
else:
|
||||
report_path = args.report or 'reports/autolabel_report.jsonl'
|
||||
path = Path(report_path)
|
||||
|
||||
# Check if file exists (handle glob patterns)
|
||||
if '*' not in str(path) and '?' not in str(path) and not path.exists():
|
||||
print(f"Error: Report file not found: {path}")
|
||||
logger.error("Report file not found: %s", path)
|
||||
return 1
|
||||
|
||||
print(f"Loading reports from: {report_path}")
|
||||
logger.info("Loading reports from: %s", report_path)
|
||||
reports = load_reports_from_file(report_path)
|
||||
print(f"Loaded {len(reports)} reports")
|
||||
logger.info("Loaded %d reports", len(reports))
|
||||
stats = analyze_reports(reports)
|
||||
|
||||
print_report(stats, verbose=args.verbose)
|
||||
|
||||
@@ -6,6 +6,7 @@ Generates YOLO training data from PDFs and structured CSV data.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
@@ -17,6 +18,10 @@ from tqdm import tqdm
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed, TimeoutError
|
||||
import multiprocessing
|
||||
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global flag for graceful shutdown
|
||||
_shutdown_requested = False
|
||||
|
||||
@@ -25,8 +30,8 @@ def _signal_handler(signum, frame):
|
||||
"""Handle interrupt signals for graceful shutdown."""
|
||||
global _shutdown_requested
|
||||
_shutdown_requested = True
|
||||
print("\n\nShutdown requested. Finishing current batch and saving progress...")
|
||||
print("(Press Ctrl+C again to force quit)\n")
|
||||
logger.warning("Shutdown requested. Finishing current batch and saving progress...")
|
||||
logger.warning("(Press Ctrl+C again to force quit)")
|
||||
|
||||
# Windows compatibility: use 'spawn' method for multiprocessing
|
||||
# This is required on Windows and is also safer for libraries like PaddleOCR
|
||||
@@ -350,11 +355,14 @@ def main():
|
||||
if ',' in csv_input and '*' not in csv_input:
|
||||
csv_input = [p.strip() for p in csv_input.split(',')]
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Get list of CSV files (don't load all data at once)
|
||||
temp_loader = CSVLoader(csv_input, args.pdf_dir)
|
||||
csv_files = temp_loader.csv_paths
|
||||
pdf_dir = temp_loader.pdf_dir
|
||||
print(f"Found {len(csv_files)} CSV file(s) to process")
|
||||
logger.info("Found %d CSV file(s) to process", len(csv_files))
|
||||
|
||||
# Setup output directories
|
||||
output_dir = Path(args.output)
|
||||
@@ -371,7 +379,7 @@ def main():
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
db.create_tables() # Ensure tables exist
|
||||
print("Connected to database for status checking")
|
||||
logger.info("Connected to database for status checking")
|
||||
|
||||
# Global stats
|
||||
stats = {
|
||||
@@ -443,7 +451,7 @@ def main():
|
||||
db.save_documents_batch(db_batch)
|
||||
db_batch.clear()
|
||||
if args.verbose:
|
||||
print(f"Error processing {doc_id}: {error}")
|
||||
logger.error("Error processing %s: %s", doc_id, error)
|
||||
|
||||
# Initialize dual-pool coordinator if enabled (keeps workers alive across CSVs)
|
||||
dual_pool_coordinator = None
|
||||
@@ -453,7 +461,7 @@ def main():
|
||||
from training.processing import DualPoolCoordinator
|
||||
from training.processing.autolabel_tasks import process_text_pdf, process_scanned_pdf
|
||||
|
||||
print(f"\nStarting dual-pool mode: {args.cpu_workers} CPU + {args.gpu_workers} GPU workers")
|
||||
logger.info("Starting dual-pool mode: %d CPU + %d GPU workers", args.cpu_workers, args.gpu_workers)
|
||||
dual_pool_coordinator = DualPoolCoordinator(
|
||||
cpu_workers=args.cpu_workers,
|
||||
gpu_workers=args.gpu_workers,
|
||||
@@ -467,10 +475,10 @@ def main():
|
||||
for csv_idx, csv_file in enumerate(csv_files):
|
||||
# Check for shutdown request
|
||||
if _shutdown_requested:
|
||||
print("\nShutdown requested. Stopping after current batch...")
|
||||
logger.warning("Shutdown requested. Stopping after current batch...")
|
||||
break
|
||||
|
||||
print(f"\n[{csv_idx + 1}/{len(csv_files)}] Processing: {csv_file.name}")
|
||||
logger.info("[%d/%d] Processing: %s", csv_idx + 1, len(csv_files), csv_file.name)
|
||||
|
||||
# Load only this CSV file
|
||||
single_loader = CSVLoader(str(csv_file), str(pdf_dir))
|
||||
@@ -488,7 +496,7 @@ def main():
|
||||
seen_doc_ids.add(r.DocumentId)
|
||||
|
||||
if not rows:
|
||||
print(f" Skipping CSV (no new documents)")
|
||||
logger.info(" Skipping CSV (no new documents)")
|
||||
continue
|
||||
|
||||
# Batch query database for all document IDs in this CSV
|
||||
@@ -500,13 +508,13 @@ def main():
|
||||
|
||||
# Skip entire CSV if all documents are already processed
|
||||
if already_processed == len(rows):
|
||||
print(f" Skipping CSV (all {len(rows)} documents already processed)")
|
||||
logger.info(" Skipping CSV (all %d documents already processed)", len(rows))
|
||||
stats['skipped_db'] += len(rows)
|
||||
continue
|
||||
|
||||
# Count how many new documents need processing in this CSV
|
||||
new_to_process = len(rows) - already_processed
|
||||
print(f" Found {new_to_process} new documents to process ({already_processed} already in DB)")
|
||||
logger.info(" Found %d new documents to process (%d already in DB)", new_to_process, already_processed)
|
||||
|
||||
stats['total'] += len(rows)
|
||||
|
||||
@@ -520,7 +528,7 @@ def main():
|
||||
if args.limit:
|
||||
remaining_limit = args.limit - stats.get('tasks_submitted', 0)
|
||||
if remaining_limit <= 0:
|
||||
print(f" Reached limit of {args.limit} new documents, stopping.")
|
||||
logger.info(" Reached limit of %d new documents, stopping.", args.limit)
|
||||
break
|
||||
else:
|
||||
remaining_limit = float('inf')
|
||||
@@ -583,7 +591,7 @@ def main():
|
||||
))
|
||||
|
||||
if skipped_in_csv > 0 or retry_in_csv > 0:
|
||||
print(f" Skipped {skipped_in_csv} (already in DB), retrying {retry_in_csv} failed")
|
||||
logger.info(" Skipped %d (already in DB), retrying %d failed", skipped_in_csv, retry_in_csv)
|
||||
|
||||
# Clean up retry documents: delete from database and remove temp folders
|
||||
if retry_doc_ids:
|
||||
@@ -599,7 +607,7 @@ def main():
|
||||
temp_doc_dir = output_dir / 'temp' / doc_id
|
||||
if temp_doc_dir.exists():
|
||||
shutil.rmtree(temp_doc_dir, ignore_errors=True)
|
||||
print(f" Cleaned up {len(retry_doc_ids)} retry documents (DB + temp folders)")
|
||||
logger.info(" Cleaned up %d retry documents (DB + temp folders)", len(retry_doc_ids))
|
||||
|
||||
if not tasks:
|
||||
continue
|
||||
@@ -636,7 +644,7 @@ def main():
|
||||
# Count task types
|
||||
text_count = sum(1 for d in documents if not d["is_scanned"])
|
||||
scan_count = len(documents) - text_count
|
||||
print(f" Text PDFs: {text_count}, Scanned PDFs: {scan_count}")
|
||||
logger.info(" Text PDFs: %d, Scanned PDFs: %d", text_count, scan_count)
|
||||
|
||||
# Progress tracking with tqdm
|
||||
pbar = tqdm(total=len(documents), desc="Processing")
|
||||
@@ -667,11 +675,11 @@ def main():
|
||||
# Log summary
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = len(results) - successful
|
||||
print(f" Batch complete: {successful} successful, {failed} failed")
|
||||
logger.info(" Batch complete: %d successful, %d failed", successful, failed)
|
||||
|
||||
else:
|
||||
# Single-pool mode (original behavior)
|
||||
print(f" Processing {len(tasks)} documents with {args.workers} workers...")
|
||||
logger.info(" Processing %d documents with %d workers...", len(tasks), args.workers)
|
||||
|
||||
# Process documents in parallel (inside CSV loop for streaming)
|
||||
# Use single process for debugging or when workers=1
|
||||
@@ -725,28 +733,28 @@ def main():
|
||||
db.close()
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Auto-labeling Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total documents: {stats['total']}")
|
||||
print(f"Successful: {stats['successful']}")
|
||||
print(f"Failed: {stats['failed']}")
|
||||
print(f"Skipped (no PDF): {stats['skipped']}")
|
||||
print(f"Skipped (in DB): {stats['skipped_db']}")
|
||||
print(f"Retried (failed): {stats['retried']}")
|
||||
print(f"Total annotations: {stats['annotations']}")
|
||||
print(f"\nImages saved to: {output_dir / 'temp'}")
|
||||
print(f"Labels stored in: PostgreSQL database")
|
||||
print(f"\nAnnotations by field:")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Auto-labeling Complete")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Total documents: %d", stats['total'])
|
||||
logger.info("Successful: %d", stats['successful'])
|
||||
logger.info("Failed: %d", stats['failed'])
|
||||
logger.info("Skipped (no PDF): %d", stats['skipped'])
|
||||
logger.info("Skipped (in DB): %d", stats['skipped_db'])
|
||||
logger.info("Retried (failed): %d", stats['retried'])
|
||||
logger.info("Total annotations: %d", stats['annotations'])
|
||||
logger.info("Images saved to: %s", output_dir / 'temp')
|
||||
logger.info("Labels stored in: PostgreSQL database")
|
||||
logger.info("Annotations by field:")
|
||||
for field, count in stats['by_field'].items():
|
||||
print(f" {field}: {count}")
|
||||
logger.info(" %s: %d", field, count)
|
||||
shard_files = report_writer.get_shard_files()
|
||||
if len(shard_files) > 1:
|
||||
print(f"\nReport files ({len(shard_files)}):")
|
||||
logger.info("Report files (%d):", len(shard_files))
|
||||
for sf in shard_files:
|
||||
print(f" - {sf}")
|
||||
logger.info(" - %s", sf)
|
||||
else:
|
||||
print(f"\nReport: {shard_files[0] if shard_files else args.report}")
|
||||
logger.info("Report: %s", shard_files[0] if shard_files else args.report)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -8,6 +8,7 @@ Usage:
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
@@ -16,6 +17,9 @@ from psycopg2.extras import execute_values
|
||||
|
||||
# Add project root to path
|
||||
from shared.config import get_db_connection_string, PATHS
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_tables(conn):
|
||||
@@ -150,7 +154,7 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
|
||||
try:
|
||||
record = json.loads(line)
|
||||
except json.JSONDecodeError as e:
|
||||
print(f" Warning: Line {line_no} - JSON parse error: {e}")
|
||||
logger.warning("Line %d - JSON parse error: %s", line_no, e)
|
||||
stats['errors'] += 1
|
||||
continue
|
||||
|
||||
@@ -211,7 +215,7 @@ def import_jsonl_file(conn, jsonl_path: Path, skip_existing: bool = True, batch_
|
||||
# Flush batch if needed
|
||||
if len(doc_batch) >= batch_size:
|
||||
flush_batches()
|
||||
print(f" Processed {stats['imported'] + stats['skipped']} records...")
|
||||
logger.info(" Processed %d records...", stats['imported'] + stats['skipped'])
|
||||
|
||||
# Final flush
|
||||
flush_batches()
|
||||
@@ -243,11 +247,14 @@ def main():
|
||||
else:
|
||||
report_files = [report_path] if report_path.exists() else []
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
if not report_files:
|
||||
print(f"No report files found: {args.report}")
|
||||
logger.error("No report files found: %s", args.report)
|
||||
return
|
||||
|
||||
print(f"Found {len(report_files)} report file(s)")
|
||||
logger.info("Found %d report file(s)", len(report_files))
|
||||
|
||||
# Connect to database
|
||||
conn = psycopg2.connect(db_connection)
|
||||
@@ -257,20 +264,20 @@ def main():
|
||||
total_stats = {'imported': 0, 'skipped': 0, 'errors': 0}
|
||||
|
||||
for report_file in report_files:
|
||||
print(f"\nImporting: {report_file.name}")
|
||||
logger.info("Importing: %s", report_file.name)
|
||||
stats = import_jsonl_file(conn, report_file, skip_existing=not args.no_skip, batch_size=args.batch_size)
|
||||
print(f" Imported: {stats['imported']}, Skipped: {stats['skipped']}, Errors: {stats['errors']}")
|
||||
logger.info(" Imported: %d, Skipped: %d, Errors: %d", stats['imported'], stats['skipped'], stats['errors'])
|
||||
|
||||
for key in total_stats:
|
||||
total_stats[key] += stats[key]
|
||||
|
||||
# Print summary
|
||||
print("\n" + "=" * 50)
|
||||
print("Import Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total imported: {total_stats['imported']}")
|
||||
print(f"Total skipped: {total_stats['skipped']}")
|
||||
print(f"Total errors: {total_stats['errors']}")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Import Complete")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Total imported: %d", total_stats['imported'])
|
||||
logger.info("Total skipped: %d", total_stats['skipped'])
|
||||
logger.info("Total errors: %d", total_stats['errors'])
|
||||
|
||||
# Quick stats from database
|
||||
with conn.cursor() as cursor:
|
||||
@@ -288,11 +295,11 @@ def main():
|
||||
|
||||
conn.close()
|
||||
|
||||
print(f"\nDatabase Stats:")
|
||||
print(f" Documents: {total_docs} ({success_docs} successful)")
|
||||
print(f" Field results: {total_fields} ({matched_fields} matched)")
|
||||
logger.info("Database Stats:")
|
||||
logger.info(" Documents: %d (%d successful)", total_docs, success_docs)
|
||||
logger.info(" Field results: %d (%d matched)", total_fields, matched_fields)
|
||||
if total_fields > 0:
|
||||
print(f" Match rate: {matched_fields / total_fields * 100:.2f}%")
|
||||
logger.info(" Match rate: %.2f%%", matched_fields / total_fields * 100)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -7,6 +7,7 @@ CSV values, and source CSV filename in a new table.
|
||||
import argparse
|
||||
import json
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
@@ -20,6 +21,9 @@ from shared.config import DEFAULT_DPI
|
||||
from shared.data.db import DocumentDB
|
||||
from shared.data.csv_loader import CSVLoader
|
||||
from shared.normalize.normalizer import normalize_field
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_failed_match_table(db: DocumentDB):
|
||||
@@ -57,7 +61,7 @@ def create_failed_match_table(db: DocumentDB):
|
||||
CREATE INDEX IF NOT EXISTS idx_failed_match_matched ON failed_match_details(matched);
|
||||
""")
|
||||
conn.commit()
|
||||
print("Created table: failed_match_details")
|
||||
logger.info("Created table: failed_match_details")
|
||||
|
||||
|
||||
def get_failed_documents(db: DocumentDB) -> list:
|
||||
@@ -332,14 +336,17 @@ def main():
|
||||
parser.add_argument('--limit', type=int, help='Limit number of documents to process')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Expand CSV glob
|
||||
csv_files = sorted(glob.glob(args.csv))
|
||||
print(f"Found {len(csv_files)} CSV files")
|
||||
logger.info("Found %d CSV files", len(csv_files))
|
||||
|
||||
# Build CSV cache
|
||||
print("Building CSV filename cache...")
|
||||
logger.info("Building CSV filename cache...")
|
||||
build_csv_cache(csv_files)
|
||||
print(f"Cached {len(_csv_cache)} document IDs")
|
||||
logger.info("Cached %d document IDs", len(_csv_cache))
|
||||
|
||||
# Connect to database
|
||||
db = DocumentDB()
|
||||
@@ -349,13 +356,13 @@ def main():
|
||||
create_failed_match_table(db)
|
||||
|
||||
# Get all failed documents
|
||||
print("Fetching failed documents...")
|
||||
logger.info("Fetching failed documents...")
|
||||
failed_docs = get_failed_documents(db)
|
||||
print(f"Found {len(failed_docs)} documents with failed matches")
|
||||
logger.info("Found %d documents with failed matches", len(failed_docs))
|
||||
|
||||
if args.limit:
|
||||
failed_docs = failed_docs[:args.limit]
|
||||
print(f"Limited to {len(failed_docs)} documents")
|
||||
logger.info("Limited to %d documents", len(failed_docs))
|
||||
|
||||
# Prepare tasks
|
||||
tasks = []
|
||||
@@ -365,7 +372,7 @@ def main():
|
||||
if failed_fields:
|
||||
tasks.append((doc, failed_fields, csv_filename))
|
||||
|
||||
print(f"Processing {len(tasks)} documents with {args.workers} workers...")
|
||||
logger.info("Processing %d documents with %d workers...", len(tasks), args.workers)
|
||||
|
||||
# Process with multiprocessing
|
||||
total_results = 0
|
||||
@@ -389,15 +396,15 @@ def main():
|
||||
batch_results = []
|
||||
|
||||
except TimeoutError:
|
||||
print(f"\nTimeout processing {doc_id}")
|
||||
logger.warning("Timeout processing %s", doc_id)
|
||||
except Exception as e:
|
||||
print(f"\nError processing {doc_id}: {e}")
|
||||
logger.error("Error processing %s: %s", doc_id, e)
|
||||
|
||||
# Save remaining results
|
||||
if batch_results:
|
||||
save_results_batch(db, batch_results)
|
||||
|
||||
print(f"\nDone! Saved {total_results} failed match records to failed_match_details table")
|
||||
logger.info("Done! Saved %d failed match records to failed_match_details table", total_results)
|
||||
|
||||
# Show summary
|
||||
conn = db.connect()
|
||||
@@ -410,12 +417,12 @@ def main():
|
||||
GROUP BY field_name
|
||||
ORDER BY total DESC
|
||||
""")
|
||||
print("\nSummary by field:")
|
||||
print("-" * 70)
|
||||
print(f"{'Field':<35} {'Total':>8} {'Has OCR':>10} {'Avg Score':>12}")
|
||||
print("-" * 70)
|
||||
logger.info("Summary by field:")
|
||||
logger.info("-" * 70)
|
||||
logger.info("%-35s %8s %10s %12s", 'Field', 'Total', 'Has OCR', 'Avg Score')
|
||||
logger.info("-" * 70)
|
||||
for row in cursor.fetchall():
|
||||
print(f"{row[0]:<35} {row[1]:>8} {row[2]:>10} {row[3]:>12.2f}")
|
||||
logger.info("%-35s %8d %10d %12.2f", row[0], row[1], row[2], row[3])
|
||||
|
||||
db.close()
|
||||
|
||||
|
||||
@@ -7,10 +7,14 @@ Images are read from filesystem, labels are dynamically generated from DB.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.config import DEFAULT_DPI, PATHS
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -119,47 +123,50 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
# Apply low-memory mode if specified
|
||||
if args.low_memory:
|
||||
print("🔧 Low memory mode enabled")
|
||||
logger.info("Low memory mode enabled")
|
||||
args.batch = min(args.batch, 8) # Reduce from 16 to 8
|
||||
args.workers = min(args.workers, 4) # Reduce from 8 to 4
|
||||
args.cache = False
|
||||
print(f" Batch size: {args.batch}")
|
||||
print(f" Workers: {args.workers}")
|
||||
print(f" Cache: disabled")
|
||||
logger.info(" Batch size: %d", args.batch)
|
||||
logger.info(" Workers: %d", args.workers)
|
||||
logger.info(" Cache: disabled")
|
||||
|
||||
# Validate dataset directory
|
||||
dataset_dir = Path(args.dataset_dir)
|
||||
temp_dir = dataset_dir / 'temp'
|
||||
if not temp_dir.exists():
|
||||
print(f"Error: Temp directory not found: {temp_dir}")
|
||||
print("Run autolabel first to generate images.")
|
||||
logger.error("Temp directory not found: %s", temp_dir)
|
||||
logger.error("Run autolabel first to generate images.")
|
||||
sys.exit(1)
|
||||
|
||||
print("=" * 60)
|
||||
print("YOLO Training with Database Labels")
|
||||
print("=" * 60)
|
||||
print(f"Dataset dir: {dataset_dir}")
|
||||
print(f"Model: {args.model}")
|
||||
print(f"Epochs: {args.epochs}")
|
||||
print(f"Batch size: {args.batch}")
|
||||
print(f"Image size: {args.imgsz}")
|
||||
print(f"Split ratio: {args.train_ratio}/{args.val_ratio}/{1-args.train_ratio-args.val_ratio:.1f}")
|
||||
logger.info("=" * 60)
|
||||
logger.info("YOLO Training with Database Labels")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Dataset dir: %s", dataset_dir)
|
||||
logger.info("Model: %s", args.model)
|
||||
logger.info("Epochs: %d", args.epochs)
|
||||
logger.info("Batch size: %d", args.batch)
|
||||
logger.info("Image size: %d", args.imgsz)
|
||||
logger.info("Split ratio: %s/%s/%.1f", args.train_ratio, args.val_ratio, 1-args.train_ratio-args.val_ratio)
|
||||
if args.limit:
|
||||
print(f"Document limit: {args.limit}")
|
||||
logger.info("Document limit: %d", args.limit)
|
||||
|
||||
# Connect to database
|
||||
from shared.data.db import DocumentDB
|
||||
|
||||
print("\nConnecting to database...")
|
||||
logger.info("Connecting to database...")
|
||||
db = DocumentDB()
|
||||
db.connect()
|
||||
|
||||
# Create datasets from database
|
||||
from training.yolo.db_dataset import create_datasets
|
||||
|
||||
print("Loading dataset from database...")
|
||||
logger.info("Loading dataset from database...")
|
||||
datasets = create_datasets(
|
||||
images_dir=dataset_dir,
|
||||
db=db,
|
||||
@@ -170,39 +177,39 @@ def main():
|
||||
limit=args.limit
|
||||
)
|
||||
|
||||
print(f"\nDataset splits:")
|
||||
print(f" Train: {len(datasets['train'])} items")
|
||||
print(f" Val: {len(datasets['val'])} items")
|
||||
print(f" Test: {len(datasets['test'])} items")
|
||||
logger.info("Dataset splits:")
|
||||
logger.info(" Train: %d items", len(datasets['train']))
|
||||
logger.info(" Val: %d items", len(datasets['val']))
|
||||
logger.info(" Test: %d items", len(datasets['test']))
|
||||
|
||||
if len(datasets['train']) == 0:
|
||||
print("\nError: No training data found!")
|
||||
print("Make sure autolabel has been run and images exist in temp directory.")
|
||||
logger.error("No training data found!")
|
||||
logger.error("Make sure autolabel has been run and images exist in temp directory.")
|
||||
db.close()
|
||||
sys.exit(1)
|
||||
|
||||
# Export to YOLO format (required for Ultralytics training)
|
||||
print("\nExporting dataset to YOLO format...")
|
||||
logger.info("Exporting dataset to YOLO format...")
|
||||
for split_name, dataset in datasets.items():
|
||||
count = dataset.export_to_yolo_format(dataset_dir, split_name)
|
||||
print(f" {split_name}: {count} items exported")
|
||||
logger.info(" %s: %d items exported", split_name, count)
|
||||
|
||||
# Generate YOLO config files
|
||||
from training.yolo.annotation_generator import AnnotationGenerator
|
||||
|
||||
AnnotationGenerator.generate_classes_file(dataset_dir / 'classes.txt')
|
||||
AnnotationGenerator.generate_yaml_config(dataset_dir / 'dataset.yaml')
|
||||
print(f"\nGenerated dataset.yaml at: {dataset_dir / 'dataset.yaml'}")
|
||||
logger.info("Generated dataset.yaml at: %s", dataset_dir / 'dataset.yaml')
|
||||
|
||||
if args.export_only:
|
||||
print("\nExport complete (--export-only specified, skipping training)")
|
||||
logger.info("Export complete (--export-only specified, skipping training)")
|
||||
db.close()
|
||||
return
|
||||
|
||||
# Start training using shared trainer
|
||||
print("\n" + "=" * 60)
|
||||
print("Starting YOLO Training")
|
||||
print("=" * 60)
|
||||
logger.info("=" * 60)
|
||||
logger.info("Starting YOLO Training")
|
||||
logger.info("=" * 60)
|
||||
|
||||
from shared.training import YOLOTrainer, TrainingConfig
|
||||
|
||||
@@ -232,30 +239,30 @@ def main():
|
||||
result = trainer.train()
|
||||
|
||||
if not result.success:
|
||||
print(f"\nError: Training failed - {result.error}")
|
||||
logger.error("Training failed - %s", result.error)
|
||||
db.close()
|
||||
sys.exit(1)
|
||||
|
||||
# Print results
|
||||
print("\n" + "=" * 60)
|
||||
print("Training Complete")
|
||||
print("=" * 60)
|
||||
print(f"Best model: {result.model_path}")
|
||||
print(f"Save directory: {result.save_dir}")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Training Complete")
|
||||
logger.info("=" * 60)
|
||||
logger.info("Best model: %s", result.model_path)
|
||||
logger.info("Save directory: %s", result.save_dir)
|
||||
if result.metrics:
|
||||
print(f"mAP@0.5: {result.metrics.get('mAP50', 'N/A')}")
|
||||
print(f"mAP@0.5-0.95: {result.metrics.get('mAP50-95', 'N/A')}")
|
||||
logger.info("mAP@0.5: %s", result.metrics.get('mAP50', 'N/A'))
|
||||
logger.info("mAP@0.5-0.95: %s", result.metrics.get('mAP50-95', 'N/A'))
|
||||
|
||||
# Validate on test set
|
||||
print("\nRunning validation on test set...")
|
||||
logger.info("Running validation on test set...")
|
||||
if result.model_path:
|
||||
config.model_path = result.model_path
|
||||
config.data_yaml = str(data_yaml)
|
||||
test_trainer = YOLOTrainer(config=config)
|
||||
test_metrics = test_trainer.validate(split='test')
|
||||
if test_metrics:
|
||||
print(f"mAP50: {test_metrics.get('mAP50', 0):.4f}")
|
||||
print(f"mAP50-95: {test_metrics.get('mAP50-95', 0):.4f}")
|
||||
logger.info("mAP50: %.4f", test_metrics.get('mAP50', 0))
|
||||
logger.info("mAP50-95: %.4f", test_metrics.get('mAP50-95', 0))
|
||||
|
||||
# Close database
|
||||
db.close()
|
||||
|
||||
@@ -7,9 +7,14 @@ and comparing the extraction results.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from shared.logging_config import setup_cli_logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
@@ -73,6 +78,9 @@ def main():
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
# Configure logging for CLI
|
||||
setup_cli_logging()
|
||||
|
||||
from backend.validation import LLMValidator
|
||||
|
||||
validator = LLMValidator()
|
||||
@@ -104,60 +112,58 @@ def show_stats(validator):
|
||||
"""Show statistics about failed matches."""
|
||||
stats = validator.get_failed_match_stats()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Failed Match Statistics")
|
||||
print("=" * 50)
|
||||
print(f"\nDocuments with failures: {stats['documents_with_failures']}")
|
||||
print(f"Already validated: {stats['already_validated']}")
|
||||
print(f"Remaining to validate: {stats['remaining']}")
|
||||
print("\nFailures by field:")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Failed Match Statistics")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Documents with failures: %d", stats['documents_with_failures'])
|
||||
logger.info("Already validated: %d", stats['already_validated'])
|
||||
logger.info("Remaining to validate: %d", stats['remaining'])
|
||||
logger.info("Failures by field:")
|
||||
for field, count in sorted(stats['failures_by_field'].items(), key=lambda x: -x[1]):
|
||||
print(f" {field}: {count}")
|
||||
logger.info(" %s: %d", field, count)
|
||||
|
||||
|
||||
def validate_single(validator, doc_id: str, provider: str, model: str):
|
||||
"""Validate a single document."""
|
||||
print(f"\nValidating document: {doc_id}")
|
||||
print(f"Provider: {provider}, Model: {model or 'default'}")
|
||||
print()
|
||||
logger.info("Validating document: %s", doc_id)
|
||||
logger.info("Provider: %s, Model: %s", provider, model or 'default')
|
||||
|
||||
result = validator.validate_document(doc_id, provider, model)
|
||||
|
||||
if result.error:
|
||||
print(f"ERROR: {result.error}")
|
||||
logger.error("ERROR: %s", result.error)
|
||||
return
|
||||
|
||||
print(f"Processing time: {result.processing_time_ms:.0f}ms")
|
||||
print(f"Model used: {result.model_used}")
|
||||
print("\nExtracted fields:")
|
||||
print(f" Invoice Number: {result.invoice_number}")
|
||||
print(f" Invoice Date: {result.invoice_date}")
|
||||
print(f" Due Date: {result.invoice_due_date}")
|
||||
print(f" OCR: {result.ocr_number}")
|
||||
print(f" Bankgiro: {result.bankgiro}")
|
||||
print(f" Plusgiro: {result.plusgiro}")
|
||||
print(f" Amount: {result.amount}")
|
||||
print(f" Org Number: {result.supplier_organisation_number}")
|
||||
logger.info("Processing time: %.0fms", result.processing_time_ms)
|
||||
logger.info("Model used: %s", result.model_used)
|
||||
logger.info("Extracted fields:")
|
||||
logger.info(" Invoice Number: %s", result.invoice_number)
|
||||
logger.info(" Invoice Date: %s", result.invoice_date)
|
||||
logger.info(" Due Date: %s", result.invoice_due_date)
|
||||
logger.info(" OCR: %s", result.ocr_number)
|
||||
logger.info(" Bankgiro: %s", result.bankgiro)
|
||||
logger.info(" Plusgiro: %s", result.plusgiro)
|
||||
logger.info(" Amount: %s", result.amount)
|
||||
logger.info(" Org Number: %s", result.supplier_organisation_number)
|
||||
|
||||
# Show comparison
|
||||
print("\n" + "-" * 50)
|
||||
print("Comparison with autolabel:")
|
||||
logger.info("-" * 50)
|
||||
logger.info("Comparison with autolabel:")
|
||||
comparison = validator.compare_results(doc_id)
|
||||
for field, data in comparison.items():
|
||||
if data.get('csv_value'):
|
||||
status = "✓" if data['agreement'] else "✗"
|
||||
status = "[OK]" if data['agreement'] else "[FAIL]"
|
||||
auto_status = "matched" if data['autolabel_matched'] else "FAILED"
|
||||
print(f" {status} {field}:")
|
||||
print(f" CSV: {data['csv_value']}")
|
||||
print(f" Autolabel: {data['autolabel_text']} ({auto_status})")
|
||||
print(f" LLM: {data['llm_value']}")
|
||||
logger.info(" %s %s:", status, field)
|
||||
logger.info(" CSV: %s", data['csv_value'])
|
||||
logger.info(" Autolabel: %s (%s)", data['autolabel_text'], auto_status)
|
||||
logger.info(" LLM: %s", data['llm_value'])
|
||||
|
||||
|
||||
def validate_batch(validator, limit: int, provider: str, model: str):
|
||||
"""Validate a batch of documents."""
|
||||
print(f"\nValidating up to {limit} documents with failed matches")
|
||||
print(f"Provider: {provider}, Model: {model or 'default'}")
|
||||
print()
|
||||
logger.info("Validating up to %d documents with failed matches", limit)
|
||||
logger.info("Provider: %s, Model: %s", provider, model or 'default')
|
||||
|
||||
results = validator.validate_batch(
|
||||
limit=limit,
|
||||
@@ -171,15 +177,15 @@ def validate_batch(validator, limit: int, provider: str, model: str):
|
||||
failed = len(results) - success
|
||||
total_time = sum(r.processing_time_ms or 0 for r in results)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("Validation Complete")
|
||||
print("=" * 50)
|
||||
print(f"Total: {len(results)}")
|
||||
print(f"Success: {success}")
|
||||
print(f"Failed: {failed}")
|
||||
print(f"Total time: {total_time/1000:.1f}s")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Validation Complete")
|
||||
logger.info("=" * 50)
|
||||
logger.info("Total: %d", len(results))
|
||||
logger.info("Success: %d", success)
|
||||
logger.info("Failed: %d", failed)
|
||||
logger.info("Total time: %.1fs", total_time/1000)
|
||||
if success > 0:
|
||||
print(f"Avg time: {total_time/success:.0f}ms per document")
|
||||
logger.info("Avg time: %.0fms per document", total_time/success)
|
||||
|
||||
|
||||
def compare_single(validator, doc_id: str):
|
||||
@@ -187,23 +193,23 @@ def compare_single(validator, doc_id: str):
|
||||
comparison = validator.compare_results(doc_id)
|
||||
|
||||
if 'error' in comparison:
|
||||
print(f"Error: {comparison['error']}")
|
||||
logger.error("Error: %s", comparison['error'])
|
||||
return
|
||||
|
||||
print(f"\nComparison for document: {doc_id}")
|
||||
print("=" * 60)
|
||||
logger.info("Comparison for document: %s", doc_id)
|
||||
logger.info("=" * 60)
|
||||
|
||||
for field, data in comparison.items():
|
||||
if data.get('csv_value') is None:
|
||||
continue
|
||||
|
||||
status = "✓" if data['agreement'] else "✗"
|
||||
status = "[OK]" if data['agreement'] else "[FAIL]"
|
||||
auto_status = "matched" if data['autolabel_matched'] else "FAILED"
|
||||
|
||||
print(f"\n{status} {field}:")
|
||||
print(f" CSV value: {data['csv_value']}")
|
||||
print(f" Autolabel: {data['autolabel_text']} ({auto_status})")
|
||||
print(f" LLM extracted: {data['llm_value']}")
|
||||
logger.info("%s %s:", status, field)
|
||||
logger.info(" CSV value: %s", data['csv_value'])
|
||||
logger.info(" Autolabel: %s (%s)", data['autolabel_text'], auto_status)
|
||||
logger.info(" LLM extracted: %s", data['llm_value'])
|
||||
|
||||
|
||||
def compare_all(validator, limit: int):
|
||||
@@ -220,11 +226,11 @@ def compare_all(validator, limit: int):
|
||||
doc_ids = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
if not doc_ids:
|
||||
print("No validated documents found.")
|
||||
logger.info("No validated documents found.")
|
||||
return
|
||||
|
||||
print(f"\nComparison Summary ({len(doc_ids)} documents)")
|
||||
print("=" * 80)
|
||||
logger.info("Comparison Summary (%d documents)", len(doc_ids))
|
||||
logger.info("=" * 80)
|
||||
|
||||
# Aggregate stats
|
||||
field_stats = {}
|
||||
@@ -259,12 +265,12 @@ def compare_all(validator, limit: int):
|
||||
if not data['autolabel_matched'] and data['agreement']:
|
||||
stats['llm_correct_auto_wrong'] += 1
|
||||
|
||||
print(f"\n{'Field':<30} {'Total':>6} {'Auto OK':>8} {'LLM Agrees':>10} {'LLM Found':>10}")
|
||||
print("-" * 80)
|
||||
logger.info("%-30s %6s %8s %10s %10s", 'Field', 'Total', 'Auto OK', 'LLM Agrees', 'LLM Found')
|
||||
logger.info("-" * 80)
|
||||
|
||||
for field, stats in sorted(field_stats.items()):
|
||||
print(f"{field:<30} {stats['total']:>6} {stats['autolabel_matched']:>8} "
|
||||
f"{stats['llm_agrees']:>10} {stats['llm_correct_auto_wrong']:>10}")
|
||||
logger.info("%-30s %6d %8d %10d %10d", field, stats['total'], stats['autolabel_matched'],
|
||||
stats['llm_agrees'], stats['llm_correct_auto_wrong'])
|
||||
|
||||
|
||||
def generate_report(validator, output_path: str):
|
||||
@@ -328,8 +334,8 @@ def generate_report(validator, output_path: str):
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(report, f, indent=2, ensure_ascii=False)
|
||||
|
||||
print(f"\nReport generated: {output_path}")
|
||||
print(f"Total validations: {len(validations)}")
|
||||
logger.info("Report generated: %s", output_path)
|
||||
logger.info("Total validations: %d", len(validations))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
YOLO Annotation Generator
|
||||
|
||||
Generates YOLO format annotations from matched fields.
|
||||
Uses field-specific bbox expansion strategies for optimal training data.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
@@ -14,7 +15,9 @@ from shared.fields import (
|
||||
TRAINING_FIELD_CLASSES as FIELD_CLASSES,
|
||||
CLASS_NAMES,
|
||||
ACCOUNT_FIELD_MAPPING,
|
||||
FIELD_TO_CLASS,
|
||||
)
|
||||
from shared.bbox import expand_bbox
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -38,19 +41,16 @@ class AnnotationGenerator:
|
||||
def __init__(
|
||||
self,
|
||||
min_confidence: float = 0.7,
|
||||
bbox_padding_px: int = 20, # Absolute padding in pixels
|
||||
min_bbox_height_px: int = 30 # Minimum bbox height
|
||||
min_bbox_height_px: int = 30, # Minimum bbox height
|
||||
):
|
||||
"""
|
||||
Initialize annotation generator.
|
||||
|
||||
Args:
|
||||
min_confidence: Minimum match score to include in training
|
||||
bbox_padding_px: Absolute padding in pixels to add around bboxes
|
||||
min_bbox_height_px: Minimum bbox height in pixels
|
||||
"""
|
||||
self.min_confidence = min_confidence
|
||||
self.bbox_padding_px = bbox_padding_px
|
||||
self.min_bbox_height_px = min_bbox_height_px
|
||||
|
||||
def generate_from_matches(
|
||||
@@ -63,6 +63,10 @@ class AnnotationGenerator:
|
||||
"""
|
||||
Generate YOLO annotations from field matches.
|
||||
|
||||
Uses field-specific bbox expansion strategies for optimal training data.
|
||||
Each field type has customized scale factors and directional compensation
|
||||
to capture field labels and context.
|
||||
|
||||
Args:
|
||||
matches: Dict of field_name -> list of Match objects
|
||||
image_width: Width of the rendered image in pixels
|
||||
@@ -82,6 +86,8 @@ class AnnotationGenerator:
|
||||
continue
|
||||
|
||||
class_id = FIELD_CLASSES[field_name]
|
||||
# Get class_name for bbox expansion strategy
|
||||
class_name = FIELD_TO_CLASS.get(field_name, field_name)
|
||||
|
||||
# Take only the best match per field
|
||||
if field_matches:
|
||||
@@ -94,19 +100,20 @@ class AnnotationGenerator:
|
||||
x0, y0, x1, y1 = best_match.bbox
|
||||
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
||||
|
||||
# Add absolute padding
|
||||
pad = self.bbox_padding_px
|
||||
x0 = max(0, x0 - pad)
|
||||
y0 = max(0, y0 - pad)
|
||||
x1 = min(image_width, x1 + pad)
|
||||
y1 = min(image_height, y1 + pad)
|
||||
# Apply field-specific bbox expansion strategy
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Ensure minimum height
|
||||
current_height = y1 - y0
|
||||
if current_height < self.min_bbox_height_px:
|
||||
extra = (self.min_bbox_height_px - current_height) / 2
|
||||
y0 = max(0, y0 - extra)
|
||||
y1 = min(image_height, y1 + extra)
|
||||
y0 = max(0, int(y0 - extra))
|
||||
y1 = min(int(image_height), int(y1 + extra))
|
||||
|
||||
# Convert to YOLO format (normalized center + size)
|
||||
x_center = (x0 + x1) / 2 / image_width
|
||||
@@ -143,6 +150,9 @@ class AnnotationGenerator:
|
||||
"""
|
||||
Add payment_line annotation from machine code parser result.
|
||||
|
||||
Uses "payment_line" scale strategy for minimal expansion
|
||||
(machine-readable code needs less context).
|
||||
|
||||
Args:
|
||||
annotations: Existing list of annotations to append to
|
||||
payment_line_bbox: Bounding box (x0, y0, x1, y1) in PDF coordinates
|
||||
@@ -163,12 +173,13 @@ class AnnotationGenerator:
|
||||
x0, y0, x1, y1 = payment_line_bbox
|
||||
x0, y0, x1, y1 = x0 * scale, y0 * scale, x1 * scale, y1 * scale
|
||||
|
||||
# Add absolute padding
|
||||
pad = self.bbox_padding_px
|
||||
x0 = max(0, x0 - pad)
|
||||
y0 = max(0, y0 - pad)
|
||||
x1 = min(image_width, x1 + pad)
|
||||
y1 = min(image_height, y1 + pad)
|
||||
# Apply field-specific bbox expansion strategy for payment_line
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
field_type="payment_line",
|
||||
)
|
||||
|
||||
# Convert to YOLO format (normalized center + size)
|
||||
x_center = (x0 + x1) / 2 / image_width
|
||||
|
||||
@@ -18,7 +18,8 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from shared.config import DEFAULT_DPI
|
||||
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES
|
||||
from shared.fields import TRAINING_FIELD_CLASSES as FIELD_CLASSES, CLASS_NAMES
|
||||
from shared.bbox import expand_bbox
|
||||
from .annotation_generator import YOLOAnnotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -156,7 +157,7 @@ class DBYOLODataset:
|
||||
|
||||
# Split items for this split
|
||||
instance.items = instance._split_dataset_from_cache()
|
||||
print(f"Split '{split}': {len(instance.items)} items")
|
||||
logger.info("Split '%s': %d items", split, len(instance.items))
|
||||
|
||||
return instance
|
||||
|
||||
@@ -165,7 +166,7 @@ class DBYOLODataset:
|
||||
# Find all document directories
|
||||
temp_dir = self.images_dir / 'temp'
|
||||
if not temp_dir.exists():
|
||||
print(f"Temp directory not found: {temp_dir}")
|
||||
logger.warning("Temp directory not found: %s", temp_dir)
|
||||
return
|
||||
|
||||
# Collect all document IDs with images
|
||||
@@ -182,13 +183,13 @@ class DBYOLODataset:
|
||||
if images:
|
||||
doc_image_map[doc_dir.name] = sorted(images)
|
||||
|
||||
print(f"Found {len(doc_image_map)} documents with images")
|
||||
logger.info("Found %d documents with images", len(doc_image_map))
|
||||
|
||||
# Query database for all document labels
|
||||
doc_ids = list(doc_image_map.keys())
|
||||
doc_labels = self._load_labels_from_db(doc_ids)
|
||||
|
||||
print(f"Loaded labels for {len(doc_labels)} documents from database")
|
||||
logger.info("Loaded labels for %d documents from database", len(doc_labels))
|
||||
|
||||
# Build dataset items
|
||||
all_items: list[DatasetItem] = []
|
||||
@@ -227,19 +228,19 @@ class DBYOLODataset:
|
||||
else:
|
||||
skipped_no_labels += 1
|
||||
|
||||
print(f"Total images found: {total_images}")
|
||||
print(f"Images with labels: {len(all_items)}")
|
||||
logger.info("Total images found: %d", total_images)
|
||||
logger.info("Images with labels: %d", len(all_items))
|
||||
if skipped_no_db_record > 0:
|
||||
print(f"Skipped {skipped_no_db_record} images (document not in database)")
|
||||
logger.info("Skipped %d images (document not in database)", skipped_no_db_record)
|
||||
if skipped_no_labels > 0:
|
||||
print(f"Skipped {skipped_no_labels} images (no labels for page)")
|
||||
logger.info("Skipped %d images (no labels for page)", skipped_no_labels)
|
||||
|
||||
# Cache all items for sharing with other splits
|
||||
self._all_items = all_items
|
||||
|
||||
# Split dataset
|
||||
self.items, self._doc_ids_ordered = self._split_dataset(all_items)
|
||||
print(f"Split '{self.split}': {len(self.items)} items")
|
||||
logger.info("Split '%s': %d items", self.split, len(self.items))
|
||||
|
||||
def _load_labels_from_db(self, doc_ids: list[str]) -> dict[str, tuple[dict[int, list[YOLOAnnotation]], bool, str | None]]:
|
||||
"""
|
||||
@@ -374,7 +375,7 @@ class DBYOLODataset:
|
||||
|
||||
if has_csv_splits:
|
||||
# Use CSV-defined splits
|
||||
print("Using CSV-defined split field for train/val/test assignment")
|
||||
logger.info("Using CSV-defined split field for train/val/test assignment")
|
||||
|
||||
# Map split values: 'train' -> train, 'test' -> test, None -> train (fallback)
|
||||
# 'val' is taken from train set using val_ratio
|
||||
@@ -411,11 +412,11 @@ class DBYOLODataset:
|
||||
# Apply limit if specified
|
||||
if self.limit is not None and self.limit < len(split_doc_ids):
|
||||
split_doc_ids = split_doc_ids[:self.limit]
|
||||
print(f"Limited to {self.limit} documents")
|
||||
logger.info("Limited to %d documents", self.limit)
|
||||
|
||||
else:
|
||||
# Fall back to random splitting (original behavior)
|
||||
print("No CSV split field found, using random splitting")
|
||||
logger.info("No CSV split field found, using random splitting")
|
||||
|
||||
random.seed(self.seed)
|
||||
random.shuffle(doc_ids)
|
||||
@@ -423,7 +424,7 @@ class DBYOLODataset:
|
||||
# Apply limit if specified (before splitting)
|
||||
if self.limit is not None and self.limit < len(doc_ids):
|
||||
doc_ids = doc_ids[:self.limit]
|
||||
print(f"Limited to {self.limit} documents")
|
||||
logger.info("Limited to %d documents", self.limit)
|
||||
|
||||
# Calculate split indices
|
||||
n_total = len(doc_ids)
|
||||
@@ -549,6 +550,8 @@ class DBYOLODataset:
|
||||
"""
|
||||
Convert annotations to normalized YOLO format.
|
||||
|
||||
Uses field-specific bbox expansion strategies via expand_bbox.
|
||||
|
||||
Args:
|
||||
annotations: List of annotations
|
||||
img_width: Actual image width in pixels
|
||||
@@ -568,26 +571,43 @@ class DBYOLODataset:
|
||||
|
||||
labels = []
|
||||
for ann in annotations:
|
||||
# Convert to pixels (if needed)
|
||||
x_center_px = ann.x_center * scale
|
||||
y_center_px = ann.y_center * scale
|
||||
width_px = ann.width * scale
|
||||
height_px = ann.height * scale
|
||||
# Convert center+size to corner coords in PDF points
|
||||
half_w = ann.width / 2
|
||||
half_h = ann.height / 2
|
||||
x0_pdf = ann.x_center - half_w
|
||||
y0_pdf = ann.y_center - half_h
|
||||
x1_pdf = ann.x_center + half_w
|
||||
y1_pdf = ann.y_center + half_h
|
||||
|
||||
# Add padding
|
||||
pad = self.bbox_padding_px
|
||||
width_px += 2 * pad
|
||||
height_px += 2 * pad
|
||||
# Convert to pixels
|
||||
x0_px = x0_pdf * scale
|
||||
y0_px = y0_pdf * scale
|
||||
x1_px = x1_pdf * scale
|
||||
y1_px = y1_pdf * scale
|
||||
|
||||
# Get class name for field-specific expansion
|
||||
class_name = CLASS_NAMES[ann.class_id]
|
||||
|
||||
# Apply field-specific bbox expansion
|
||||
x0, y0, x1, y1 = expand_bbox(
|
||||
bbox=(x0_px, y0_px, x1_px, y1_px),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Ensure minimum height
|
||||
height_px = y1 - y0
|
||||
if height_px < self.min_bbox_height_px:
|
||||
height_px = self.min_bbox_height_px
|
||||
extra = (self.min_bbox_height_px - height_px) / 2
|
||||
y0 = max(0, int(y0 - extra))
|
||||
y1 = min(img_height, int(y1 + extra))
|
||||
|
||||
# Normalize to 0-1
|
||||
x_center = x_center_px / img_width
|
||||
y_center = y_center_px / img_height
|
||||
width = width_px / img_width
|
||||
height = height_px / img_height
|
||||
# Convert to YOLO format (normalized center + size)
|
||||
x_center = (x0 + x1) / 2 / img_width
|
||||
y_center = (y0 + y1) / 2 / img_height
|
||||
width = (x1 - x0) / img_width
|
||||
height = (y1 - y0) / img_height
|
||||
|
||||
# Clamp to valid range
|
||||
x_center = max(0, min(1, x_center))
|
||||
@@ -675,7 +695,7 @@ class DBYOLODataset:
|
||||
|
||||
count += 1
|
||||
|
||||
print(f"Exported {count} items to {output_dir / split_name}")
|
||||
logger.info("Exported %d items to %s", count, output_dir / split_name)
|
||||
return count
|
||||
|
||||
|
||||
@@ -706,7 +726,7 @@ def create_datasets(
|
||||
Dict with 'train', 'val', 'test' datasets
|
||||
"""
|
||||
# Create first dataset which loads all data
|
||||
print("Loading dataset (this may take a few minutes for large datasets)...")
|
||||
logger.info("Loading dataset (this may take a few minutes for large datasets)...")
|
||||
first_dataset = DBYOLODataset(
|
||||
images_dir=images_dir,
|
||||
db=db,
|
||||
|
||||
1
tests/domain/__init__.py
Normal file
1
tests/domain/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Domain layer tests
|
||||
176
tests/domain/test_document_classifier.py
Normal file
176
tests/domain/test_document_classifier.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Tests for DocumentClassifier - TDD RED phase.
|
||||
|
||||
Test document type classification based on extracted fields.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from backend.domain.document_classifier import DocumentClassifier, ClassificationResult
|
||||
|
||||
|
||||
class TestDocumentClassifier:
|
||||
"""Test document classification logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def classifier(self) -> DocumentClassifier:
|
||||
"""Create classifier instance."""
|
||||
return DocumentClassifier()
|
||||
|
||||
# ==================== Invoice Detection Tests ====================
|
||||
|
||||
def test_classify_with_payment_line_returns_invoice(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Payment line is the strongest invoice indicator."""
|
||||
fields = {"payment_line": "# 123456 # 100 00 5 > 308-2963#"}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
assert result.confidence >= 0.9
|
||||
assert "payment_line" in result.reason
|
||||
|
||||
def test_classify_with_multiple_indicators_returns_invoice(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Multiple invoice indicators -> invoice with medium confidence."""
|
||||
fields = {
|
||||
"Amount": "1200.00",
|
||||
"Bankgiro": "123-4567",
|
||||
"payment_line": None,
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
assert result.confidence >= 0.7
|
||||
|
||||
def test_classify_with_ocr_and_amount_returns_invoice(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""OCR + Amount is typical invoice pattern."""
|
||||
fields = {
|
||||
"OCR": "123456789012",
|
||||
"Amount": "500.00",
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
assert result.confidence >= 0.7
|
||||
|
||||
def test_classify_with_single_indicator_returns_invoice_lower_confidence(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Single indicator -> invoice but lower confidence."""
|
||||
fields = {"Amount": "100.00"}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
assert 0.5 <= result.confidence < 0.8
|
||||
|
||||
def test_classify_with_invoice_number_only(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Invoice number alone suggests invoice."""
|
||||
fields = {"InvoiceNumber": "INV-2024-001"}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "invoice"
|
||||
|
||||
# ==================== Letter Detection Tests ====================
|
||||
|
||||
def test_classify_with_no_indicators_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""No invoice indicators -> letter."""
|
||||
fields: dict[str, str | None] = {}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
assert result.confidence >= 0.5
|
||||
|
||||
def test_classify_with_empty_fields_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""All fields empty or None -> letter."""
|
||||
fields = {
|
||||
"payment_line": None,
|
||||
"OCR": None,
|
||||
"Amount": None,
|
||||
"Bankgiro": None,
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
|
||||
def test_classify_with_only_non_indicator_fields_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Fields that don't indicate invoice -> letter."""
|
||||
fields = {
|
||||
"CustomerNumber": "C12345",
|
||||
"SupplierOrgNumber": "556677-8899",
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
|
||||
# ==================== Edge Cases ====================
|
||||
|
||||
def test_classify_with_empty_string_fields_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Empty strings should be treated as missing."""
|
||||
fields = {
|
||||
"payment_line": "",
|
||||
"Amount": "",
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
|
||||
def test_classify_with_whitespace_only_fields_returns_letter(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""Whitespace-only strings should be treated as missing."""
|
||||
fields = {
|
||||
"payment_line": " ",
|
||||
"Amount": "\t\n",
|
||||
}
|
||||
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert result.document_type == "letter"
|
||||
|
||||
# ==================== ClassificationResult Immutability ====================
|
||||
|
||||
def test_classification_result_is_immutable(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""ClassificationResult should be a frozen dataclass."""
|
||||
fields = {"payment_line": "test"}
|
||||
result = classifier.classify(fields)
|
||||
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
result.document_type = "modified" # type: ignore
|
||||
|
||||
def test_classification_result_has_required_fields(
|
||||
self, classifier: DocumentClassifier
|
||||
) -> None:
|
||||
"""ClassificationResult must have document_type, confidence, reason."""
|
||||
fields = {"Amount": "100.00"}
|
||||
result = classifier.classify(fields)
|
||||
|
||||
assert hasattr(result, "document_type")
|
||||
assert hasattr(result, "confidence")
|
||||
assert hasattr(result, "reason")
|
||||
assert isinstance(result.document_type, str)
|
||||
assert isinstance(result.confidence, float)
|
||||
assert isinstance(result.reason, str)
|
||||
232
tests/domain/test_invoice_validator.py
Normal file
232
tests/domain/test_invoice_validator.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Tests for InvoiceValidator - TDD RED phase.
|
||||
|
||||
Test invoice field validation logic.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from backend.domain.invoice_validator import (
|
||||
InvoiceValidator,
|
||||
ValidationResult,
|
||||
ValidationIssue,
|
||||
)
|
||||
|
||||
|
||||
class TestInvoiceValidator:
|
||||
"""Test invoice validation logic."""
|
||||
|
||||
@pytest.fixture
|
||||
def validator(self) -> InvoiceValidator:
|
||||
"""Create validator instance with default settings."""
|
||||
return InvoiceValidator()
|
||||
|
||||
@pytest.fixture
|
||||
def validator_strict(self) -> InvoiceValidator:
|
||||
"""Create validator with strict confidence threshold."""
|
||||
return InvoiceValidator(min_confidence=0.8)
|
||||
|
||||
# ==================== Valid Invoice Tests ====================
|
||||
|
||||
def test_validate_complete_invoice_is_valid(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Complete invoice with all required fields is valid."""
|
||||
fields = {
|
||||
"Amount": "1200.00",
|
||||
"OCR": "123456789012",
|
||||
"Bankgiro": "123-4567",
|
||||
}
|
||||
confidence = {
|
||||
"Amount": 0.95,
|
||||
"OCR": 0.90,
|
||||
"Bankgiro": 0.85,
|
||||
}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert len([i for i in result.issues if i.severity == "error"]) == 0
|
||||
|
||||
def test_validate_invoice_with_payment_line_is_valid(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Invoice with payment_line as payment reference is valid."""
|
||||
fields = {
|
||||
"Amount": "500.00",
|
||||
"payment_line": "# 123 # 500 00 5 > 308#",
|
||||
}
|
||||
confidence = {"Amount": 0.9, "payment_line": 0.85}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is True
|
||||
|
||||
# ==================== Invalid Invoice Tests ====================
|
||||
|
||||
def test_validate_missing_amount_is_invalid(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Missing Amount field should produce error."""
|
||||
fields = {
|
||||
"OCR": "123456789012",
|
||||
"Bankgiro": "123-4567",
|
||||
}
|
||||
confidence = {"OCR": 0.9, "Bankgiro": 0.85}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is False
|
||||
error_fields = [i.field for i in result.issues if i.severity == "error"]
|
||||
assert "Amount" in error_fields
|
||||
|
||||
def test_validate_missing_payment_reference_produces_warning(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Missing all payment references should produce warning."""
|
||||
fields = {"Amount": "1200.00"}
|
||||
confidence = {"Amount": 0.9}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
# Missing payment ref is warning, not error
|
||||
warning_fields = [i.field for i in result.issues if i.severity == "warning"]
|
||||
assert "payment_reference" in warning_fields
|
||||
|
||||
# ==================== Confidence Threshold Tests ====================
|
||||
|
||||
def test_validate_low_confidence_produces_warning(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Fields below confidence threshold should produce warning."""
|
||||
fields = {
|
||||
"Amount": "1200.00",
|
||||
"OCR": "123456789012",
|
||||
}
|
||||
confidence = {
|
||||
"Amount": 0.9,
|
||||
"OCR": 0.3, # Below default threshold of 0.5
|
||||
}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
low_conf_warnings = [
|
||||
i for i in result.issues
|
||||
if i.severity == "warning" and "confidence" in i.message.lower()
|
||||
]
|
||||
assert len(low_conf_warnings) > 0
|
||||
|
||||
def test_validate_strict_threshold_more_warnings(
|
||||
self, validator_strict: InvoiceValidator
|
||||
) -> None:
|
||||
"""Strict validator should produce more warnings."""
|
||||
fields = {
|
||||
"Amount": "1200.00",
|
||||
"OCR": "123456789012",
|
||||
}
|
||||
confidence = {
|
||||
"Amount": 0.7, # Below 0.8 threshold
|
||||
"OCR": 0.6, # Below 0.8 threshold
|
||||
}
|
||||
|
||||
result = validator_strict.validate(fields, confidence)
|
||||
|
||||
low_conf_warnings = [
|
||||
i for i in result.issues
|
||||
if i.severity == "warning" and "confidence" in i.message.lower()
|
||||
]
|
||||
assert len(low_conf_warnings) >= 2
|
||||
|
||||
# ==================== Edge Cases ====================
|
||||
|
||||
def test_validate_empty_fields_is_invalid(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Empty fields dict should be invalid."""
|
||||
fields: dict[str, str | None] = {}
|
||||
confidence: dict[str, float] = {}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is False
|
||||
|
||||
def test_validate_none_field_values_treated_as_missing(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""None values should be treated as missing."""
|
||||
fields = {
|
||||
"Amount": None,
|
||||
"OCR": "123456789012",
|
||||
}
|
||||
confidence = {"OCR": 0.9}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is False
|
||||
error_fields = [i.field for i in result.issues if i.severity == "error"]
|
||||
assert "Amount" in error_fields
|
||||
|
||||
def test_validate_empty_string_treated_as_missing(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Empty string should be treated as missing."""
|
||||
fields = {
|
||||
"Amount": "",
|
||||
"OCR": "123456789012",
|
||||
}
|
||||
confidence = {"OCR": 0.9}
|
||||
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert result.is_valid is False
|
||||
|
||||
# ==================== ValidationResult Properties ====================
|
||||
|
||||
def test_validation_result_is_immutable(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""ValidationResult should be a frozen dataclass."""
|
||||
fields = {"Amount": "100.00", "OCR": "123"}
|
||||
confidence = {"Amount": 0.9, "OCR": 0.9}
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
result.is_valid = False # type: ignore
|
||||
|
||||
def test_validation_result_issues_is_tuple(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""Issues should be a tuple (immutable)."""
|
||||
fields = {"Amount": "100.00"}
|
||||
confidence = {"Amount": 0.9}
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert isinstance(result.issues, tuple)
|
||||
|
||||
def test_validation_result_has_confidence(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""ValidationResult should have confidence score."""
|
||||
fields = {"Amount": "100.00", "OCR": "123"}
|
||||
confidence = {"Amount": 0.9, "OCR": 0.8}
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert hasattr(result, "confidence")
|
||||
assert 0.0 <= result.confidence <= 1.0
|
||||
|
||||
# ==================== ValidationIssue Tests ====================
|
||||
|
||||
def test_validation_issue_has_required_fields(
|
||||
self, validator: InvoiceValidator
|
||||
) -> None:
|
||||
"""ValidationIssue must have field, severity, message."""
|
||||
fields: dict[str, str | None] = {}
|
||||
confidence: dict[str, float] = {}
|
||||
result = validator.validate(fields, confidence)
|
||||
|
||||
assert len(result.issues) > 0
|
||||
issue = result.issues[0]
|
||||
|
||||
assert hasattr(issue, "field")
|
||||
assert hasattr(issue, "severity")
|
||||
assert hasattr(issue, "message")
|
||||
assert issue.severity in ("error", "warning", "info")
|
||||
1
tests/shared/bbox/__init__.py
Normal file
1
tests/shared/bbox/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for shared.bbox module."""
|
||||
556
tests/shared/bbox/test_expander.py
Normal file
556
tests/shared/bbox/test_expander.py
Normal file
@@ -0,0 +1,556 @@
|
||||
"""
|
||||
Tests for expand_bbox function.
|
||||
|
||||
Tests verify that bbox expansion works correctly with center-point scaling,
|
||||
directional compensation, max padding clamping, and image boundary handling.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.bbox import (
|
||||
expand_bbox,
|
||||
ScaleStrategy,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
DEFAULT_STRATEGY,
|
||||
)
|
||||
|
||||
|
||||
class TestExpandBboxCenterScaling:
|
||||
"""Tests for center-point based scaling."""
|
||||
|
||||
def test_center_scaling_expands_symmetrically(self):
|
||||
"""Verify bbox expands symmetrically around center when no extra ratios."""
|
||||
# 100x50 bbox at (100, 200)
|
||||
bbox = (100, 200, 200, 250)
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.2, # 20% wider
|
||||
scale_y=1.4, # 40% taller
|
||||
max_pad_x=1000, # Large to avoid clamping
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Original: width=100, height=50
|
||||
# New: width=120, height=70
|
||||
# Center: (150, 225)
|
||||
# Expected: x0=150-60=90, x1=150+60=210, y0=225-35=190, y1=225+35=260
|
||||
assert result[0] == 90 # x0
|
||||
assert result[1] == 190 # y0
|
||||
assert result[2] == 210 # x1
|
||||
assert result[3] == 260 # y1
|
||||
|
||||
def test_no_scaling_returns_original(self):
|
||||
"""Verify scale=1.0 with no extras returns original bbox."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
assert result == (100, 200, 200, 250)
|
||||
|
||||
|
||||
class TestExpandBboxDirectionalCompensation:
|
||||
"""Tests for directional compensation (extra ratios)."""
|
||||
|
||||
def test_extra_top_expands_upward(self):
|
||||
"""Verify extra_top_ratio adds expansion toward top."""
|
||||
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_top_ratio=0.5, # Add 50% of height to top
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# extra_top = 50 * 0.5 = 25
|
||||
assert result[0] == 100 # x0 unchanged
|
||||
assert result[1] == 175 # y0 = 200 - 25
|
||||
assert result[2] == 200 # x1 unchanged
|
||||
assert result[3] == 250 # y1 unchanged
|
||||
|
||||
def test_extra_left_expands_leftward(self):
|
||||
"""Verify extra_left_ratio adds expansion toward left."""
|
||||
bbox = (100, 200, 200, 250) # width=100
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_left_ratio=0.8, # Add 80% of width to left
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# extra_left = 100 * 0.8 = 80
|
||||
assert result[0] == 20 # x0 = 100 - 80
|
||||
assert result[1] == 200 # y0 unchanged
|
||||
assert result[2] == 200 # x1 unchanged
|
||||
assert result[3] == 250 # y1 unchanged
|
||||
|
||||
def test_extra_right_expands_rightward(self):
|
||||
"""Verify extra_right_ratio adds expansion toward right."""
|
||||
bbox = (100, 200, 200, 250) # width=100
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_right_ratio=0.3, # Add 30% of width to right
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# extra_right = 100 * 0.3 = 30
|
||||
assert result[0] == 100 # x0 unchanged
|
||||
assert result[1] == 200 # y0 unchanged
|
||||
assert result[2] == 230 # x1 = 200 + 30
|
||||
assert result[3] == 250 # y1 unchanged
|
||||
|
||||
def test_extra_bottom_expands_downward(self):
|
||||
"""Verify extra_bottom_ratio adds expansion toward bottom."""
|
||||
bbox = (100, 200, 200, 250) # height=50
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_bottom_ratio=0.4, # Add 40% of height to bottom
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# extra_bottom = 50 * 0.4 = 20
|
||||
assert result[0] == 100 # x0 unchanged
|
||||
assert result[1] == 200 # y0 unchanged
|
||||
assert result[2] == 200 # x1 unchanged
|
||||
assert result[3] == 270 # y1 = 250 + 20
|
||||
|
||||
def test_combined_scaling_and_directional(self):
|
||||
"""Verify scale + directional compensation work together."""
|
||||
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.2, # 20% wider -> 120 width
|
||||
scale_y=1.0, # no height change
|
||||
extra_left_ratio=0.5, # Add 50% of width to left
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Center: x=150
|
||||
# After scale: width=120 -> x0=150-60=90, x1=150+60=210
|
||||
# After extra_left: x0 = 90 - (100 * 0.5) = 40
|
||||
assert result[0] == 40 # x0
|
||||
assert result[2] == 210 # x1
|
||||
|
||||
|
||||
class TestExpandBboxMaxPadClamping:
|
||||
"""Tests for max padding clamping."""
|
||||
|
||||
def test_max_pad_x_limits_horizontal_expansion(self):
|
||||
"""Verify max_pad_x limits expansion on left and right."""
|
||||
bbox = (100, 200, 200, 250) # width=100
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=2.0, # Double width (would add 50 each side)
|
||||
scale_y=1.0,
|
||||
max_pad_x=30, # Limit to 30 pixels each side
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Scale would make: x0=100, x1=200 -> x0=50, x1=250 (50px each side)
|
||||
# But max_pad_x=30 limits to: x0=70, x1=230
|
||||
assert result[0] == 70 # x0 = 100 - 30
|
||||
assert result[2] == 230 # x1 = 200 + 30
|
||||
|
||||
def test_max_pad_y_limits_vertical_expansion(self):
|
||||
"""Verify max_pad_y limits expansion on top and bottom."""
|
||||
bbox = (100, 200, 200, 250) # height=50
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=3.0, # Triple height (would add 50 each side)
|
||||
max_pad_x=1000,
|
||||
max_pad_y=20, # Limit to 20 pixels each side
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Scale would make: y0=175, y1=275 (50px each side)
|
||||
# But max_pad_y=20 limits to: y0=180, y1=270
|
||||
assert result[1] == 180 # y0 = 200 - 20
|
||||
assert result[3] == 270 # y1 = 250 + 20
|
||||
|
||||
def test_max_pad_preserves_asymmetry(self):
|
||||
"""Verify max_pad clamping preserves asymmetric expansion."""
|
||||
bbox = (100, 200, 200, 250) # width=100
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_left_ratio=1.0, # 100px left expansion
|
||||
extra_right_ratio=0.0, # No right expansion
|
||||
max_pad_x=50, # Limit to 50 pixels
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
# Left would expand 100, clamped to 50
|
||||
# Right stays at 0
|
||||
assert result[0] == 50 # x0 = 100 - 50
|
||||
assert result[2] == 200 # x1 unchanged
|
||||
|
||||
|
||||
class TestExpandBboxImageBoundaryClamping:
|
||||
"""Tests for image boundary clamping."""
|
||||
|
||||
def test_clamps_to_left_boundary(self):
|
||||
"""Verify x0 is clamped to 0."""
|
||||
bbox = (10, 200, 110, 250) # Close to left edge
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_left_ratio=0.5, # Would push x0 below 0
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
assert result[0] == 0 # Clamped to 0
|
||||
|
||||
def test_clamps_to_top_boundary(self):
|
||||
"""Verify y0 is clamped to 0."""
|
||||
bbox = (100, 10, 200, 60) # Close to top edge
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_top_ratio=0.5, # Would push y0 below 0
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
assert result[1] == 0 # Clamped to 0
|
||||
|
||||
def test_clamps_to_right_boundary(self):
|
||||
"""Verify x1 is clamped to image_width."""
|
||||
bbox = (900, 200, 990, 250) # Close to right edge
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_right_ratio=0.5, # Would push x1 beyond image_width
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
assert result[2] == 1000 # Clamped to image_width
|
||||
|
||||
def test_clamps_to_bottom_boundary(self):
|
||||
"""Verify y1 is clamped to image_height."""
|
||||
bbox = (100, 940, 200, 990) # Close to bottom edge
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.0,
|
||||
scale_y=1.0,
|
||||
extra_bottom_ratio=0.5, # Would push y1 beyond image_height
|
||||
max_pad_x=1000,
|
||||
max_pad_y=1000,
|
||||
)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test_field",
|
||||
strategies={"test_field": strategy},
|
||||
)
|
||||
|
||||
assert result[3] == 1000 # Clamped to image_height
|
||||
|
||||
|
||||
class TestExpandBboxUnknownField:
|
||||
"""Tests for unknown field handling."""
|
||||
|
||||
def test_unknown_field_uses_default_strategy(self):
|
||||
"""Verify unknown field types use DEFAULT_STRATEGY."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="unknown_field_xyz",
|
||||
)
|
||||
|
||||
# DEFAULT_STRATEGY: scale_x=1.15, scale_y=1.15
|
||||
# Original: width=100, height=50
|
||||
# New: width=115, height=57.5
|
||||
# Center: (150, 225)
|
||||
# x0 = 150 - 57.5 = 92.5 -> 92
|
||||
# x1 = 150 + 57.5 = 207.5 -> 207
|
||||
# y0 = 225 - 28.75 = 196.25 -> 196
|
||||
# y1 = 225 + 28.75 = 253.75 -> 253
|
||||
# But max_pad_x=50 may clamp...
|
||||
# Left pad = 100 - 92.5 = 7.5 (< 50, ok)
|
||||
# Right pad = 207.5 - 200 = 7.5 (< 50, ok)
|
||||
assert result[0] == 92
|
||||
assert result[2] == 207
|
||||
|
||||
|
||||
class TestExpandBboxWithRealStrategies:
|
||||
"""Tests using actual FIELD_SCALE_STRATEGIES."""
|
||||
|
||||
def test_ocr_number_expands_significantly_upward(self):
|
||||
"""Verify ocr_number field gets significant upward expansion."""
|
||||
bbox = (100, 200, 200, 230) # Small height=30
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="ocr_number",
|
||||
)
|
||||
|
||||
# extra_top_ratio=0.60 -> 30 * 0.6 = 18 extra top
|
||||
# y0 should decrease significantly
|
||||
assert result[1] < 200 - 10 # At least 10px upward expansion
|
||||
|
||||
def test_bankgiro_expands_significantly_leftward(self):
|
||||
"""Verify bankgiro field gets significant leftward expansion."""
|
||||
bbox = (200, 200, 300, 230) # width=100
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro",
|
||||
)
|
||||
|
||||
# extra_left_ratio=0.80 -> 100 * 0.8 = 80 extra left
|
||||
# x0 should decrease significantly
|
||||
assert result[0] < 200 - 30 # At least 30px leftward expansion
|
||||
|
||||
def test_amount_expands_rightward(self):
|
||||
"""Verify amount field gets rightward expansion for currency."""
|
||||
bbox = (100, 200, 200, 230) # width=100
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="amount",
|
||||
)
|
||||
|
||||
# extra_right_ratio=0.30 -> 100 * 0.3 = 30 extra right
|
||||
# x1 should increase
|
||||
assert result[2] > 200 + 10 # At least 10px rightward expansion
|
||||
|
||||
|
||||
class TestExpandBboxReturnType:
|
||||
"""Tests for return type and value format."""
|
||||
|
||||
def test_returns_tuple_of_four_ints(self):
|
||||
"""Verify return type is tuple of 4 integers."""
|
||||
bbox = (100.5, 200.3, 200.7, 250.9)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 4
|
||||
assert all(isinstance(v, int) for v in result)
|
||||
|
||||
def test_returns_valid_bbox_format(self):
|
||||
"""Verify returned bbox has x0 < x1 and y0 < y1."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
x0, y0, x1, y1 = result
|
||||
assert x0 < x1, "x0 should be less than x1"
|
||||
assert y0 < y1, "y0 should be less than y1"
|
||||
|
||||
|
||||
class TestManualLabelMode:
|
||||
"""Tests for manual_mode parameter."""
|
||||
|
||||
def test_manual_mode_uses_minimal_padding(self):
|
||||
"""Verify manual_mode uses MANUAL_LABEL_STRATEGY with minimal padding."""
|
||||
bbox = (100, 200, 200, 250) # width=100, height=50
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro", # Would normally expand left significantly
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# MANUAL_LABEL_STRATEGY: scale=1.0, max_pad=10
|
||||
# Should only add 10px padding each side (but scale=1.0 means no scaling)
|
||||
# Actually with scale=1.0, no extra ratios, we get 0 expansion from scaling
|
||||
# Only max_pad=10 applies as a limit, but there's no expansion to limit
|
||||
# So result should be same as original
|
||||
assert result == (100, 200, 200, 250)
|
||||
|
||||
def test_manual_mode_ignores_field_type(self):
|
||||
"""Verify manual_mode ignores field-specific strategies."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
|
||||
# Different fields should give same result in manual_mode
|
||||
result_bankgiro = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
result_ocr = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="ocr_number",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
assert result_bankgiro == result_ocr
|
||||
|
||||
def test_manual_mode_vs_auto_mode_different(self):
|
||||
"""Verify manual_mode produces different results than auto mode."""
|
||||
bbox = (100, 200, 200, 250)
|
||||
|
||||
auto_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro", # Has extra_left_ratio=0.80
|
||||
manual_mode=False,
|
||||
)
|
||||
|
||||
manual_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="bankgiro",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# Auto mode should expand more (especially to the left for bankgiro)
|
||||
assert auto_result[0] < manual_result[0] # Auto x0 is more left
|
||||
|
||||
def test_manual_mode_clamps_to_image_bounds(self):
|
||||
"""Verify manual_mode still respects image boundaries."""
|
||||
bbox = (5, 5, 50, 50) # Close to top-left corner
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
field_type="test",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# Should clamp to 0
|
||||
assert result[0] >= 0
|
||||
assert result[1] >= 0
|
||||
192
tests/shared/bbox/test_scale_strategy.py
Normal file
192
tests/shared/bbox/test_scale_strategy.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Tests for ScaleStrategy configuration.
|
||||
|
||||
Tests verify that scale strategies are properly defined, immutable,
|
||||
and cover all required fields.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from shared.bbox import (
|
||||
ScaleStrategy,
|
||||
DEFAULT_STRATEGY,
|
||||
MANUAL_LABEL_STRATEGY,
|
||||
FIELD_SCALE_STRATEGIES,
|
||||
)
|
||||
from shared.fields import CLASS_NAMES
|
||||
|
||||
|
||||
class TestScaleStrategyDataclass:
|
||||
"""Tests for ScaleStrategy dataclass behavior."""
|
||||
|
||||
def test_default_strategy_values(self):
|
||||
"""Verify default strategy has expected default values."""
|
||||
strategy = ScaleStrategy()
|
||||
assert strategy.scale_x == 1.15
|
||||
assert strategy.scale_y == 1.15
|
||||
assert strategy.extra_top_ratio == 0.0
|
||||
assert strategy.extra_bottom_ratio == 0.0
|
||||
assert strategy.extra_left_ratio == 0.0
|
||||
assert strategy.extra_right_ratio == 0.0
|
||||
assert strategy.max_pad_x == 50
|
||||
assert strategy.max_pad_y == 50
|
||||
|
||||
def test_scale_strategy_immutability(self):
|
||||
"""Verify ScaleStrategy is frozen (immutable)."""
|
||||
strategy = ScaleStrategy()
|
||||
with pytest.raises(AttributeError):
|
||||
strategy.scale_x = 2.0 # type: ignore
|
||||
|
||||
def test_custom_strategy_values(self):
|
||||
"""Verify custom values are properly set."""
|
||||
strategy = ScaleStrategy(
|
||||
scale_x=1.5,
|
||||
scale_y=1.8,
|
||||
extra_top_ratio=0.6,
|
||||
extra_left_ratio=0.8,
|
||||
max_pad_x=100,
|
||||
max_pad_y=150,
|
||||
)
|
||||
assert strategy.scale_x == 1.5
|
||||
assert strategy.scale_y == 1.8
|
||||
assert strategy.extra_top_ratio == 0.6
|
||||
assert strategy.extra_left_ratio == 0.8
|
||||
assert strategy.max_pad_x == 100
|
||||
assert strategy.max_pad_y == 150
|
||||
|
||||
|
||||
class TestDefaultStrategy:
|
||||
"""Tests for DEFAULT_STRATEGY constant."""
|
||||
|
||||
def test_default_strategy_is_scale_strategy(self):
|
||||
"""Verify DEFAULT_STRATEGY is a ScaleStrategy instance."""
|
||||
assert isinstance(DEFAULT_STRATEGY, ScaleStrategy)
|
||||
|
||||
def test_default_strategy_matches_default_values(self):
|
||||
"""Verify DEFAULT_STRATEGY has same values as ScaleStrategy()."""
|
||||
expected = ScaleStrategy()
|
||||
assert DEFAULT_STRATEGY == expected
|
||||
|
||||
|
||||
class TestManualLabelStrategy:
|
||||
"""Tests for MANUAL_LABEL_STRATEGY constant."""
|
||||
|
||||
def test_manual_label_strategy_is_scale_strategy(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY is a ScaleStrategy instance."""
|
||||
assert isinstance(MANUAL_LABEL_STRATEGY, ScaleStrategy)
|
||||
|
||||
def test_manual_label_strategy_has_no_scaling(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY has scale factors of 1.0."""
|
||||
assert MANUAL_LABEL_STRATEGY.scale_x == 1.0
|
||||
assert MANUAL_LABEL_STRATEGY.scale_y == 1.0
|
||||
|
||||
def test_manual_label_strategy_has_no_directional_expansion(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY has no directional expansion."""
|
||||
assert MANUAL_LABEL_STRATEGY.extra_top_ratio == 0.0
|
||||
assert MANUAL_LABEL_STRATEGY.extra_bottom_ratio == 0.0
|
||||
assert MANUAL_LABEL_STRATEGY.extra_left_ratio == 0.0
|
||||
assert MANUAL_LABEL_STRATEGY.extra_right_ratio == 0.0
|
||||
|
||||
def test_manual_label_strategy_has_small_max_pad(self):
|
||||
"""Verify MANUAL_LABEL_STRATEGY has small max padding."""
|
||||
assert MANUAL_LABEL_STRATEGY.max_pad_x <= 15
|
||||
assert MANUAL_LABEL_STRATEGY.max_pad_y <= 15
|
||||
|
||||
|
||||
class TestFieldScaleStrategies:
|
||||
"""Tests for FIELD_SCALE_STRATEGIES dictionary."""
|
||||
|
||||
def test_all_class_names_have_strategies(self):
|
||||
"""Verify all field class names have defined strategies."""
|
||||
for class_name in CLASS_NAMES:
|
||||
assert class_name in FIELD_SCALE_STRATEGIES, (
|
||||
f"Missing strategy for field: {class_name}"
|
||||
)
|
||||
|
||||
def test_strategies_are_scale_strategy_instances(self):
|
||||
"""Verify all strategies are ScaleStrategy instances."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert isinstance(strategy, ScaleStrategy), (
|
||||
f"Strategy for {field_name} is not a ScaleStrategy"
|
||||
)
|
||||
|
||||
def test_scale_values_are_greater_than_one(self):
|
||||
"""Verify all scale values are >= 1.0 (expansion, not contraction)."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert strategy.scale_x >= 1.0, (
|
||||
f"{field_name} scale_x should be >= 1.0"
|
||||
)
|
||||
assert strategy.scale_y >= 1.0, (
|
||||
f"{field_name} scale_y should be >= 1.0"
|
||||
)
|
||||
|
||||
def test_extra_ratios_are_non_negative(self):
|
||||
"""Verify all extra ratios are >= 0."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert strategy.extra_top_ratio >= 0, (
|
||||
f"{field_name} extra_top_ratio should be >= 0"
|
||||
)
|
||||
assert strategy.extra_bottom_ratio >= 0, (
|
||||
f"{field_name} extra_bottom_ratio should be >= 0"
|
||||
)
|
||||
assert strategy.extra_left_ratio >= 0, (
|
||||
f"{field_name} extra_left_ratio should be >= 0"
|
||||
)
|
||||
assert strategy.extra_right_ratio >= 0, (
|
||||
f"{field_name} extra_right_ratio should be >= 0"
|
||||
)
|
||||
|
||||
def test_max_pad_values_are_positive(self):
|
||||
"""Verify all max_pad values are > 0."""
|
||||
for field_name, strategy in FIELD_SCALE_STRATEGIES.items():
|
||||
assert strategy.max_pad_x > 0, (
|
||||
f"{field_name} max_pad_x should be > 0"
|
||||
)
|
||||
assert strategy.max_pad_y > 0, (
|
||||
f"{field_name} max_pad_y should be > 0"
|
||||
)
|
||||
|
||||
|
||||
class TestSpecificFieldStrategies:
|
||||
"""Tests for specific field strategy configurations."""
|
||||
|
||||
def test_ocr_number_expands_upward(self):
|
||||
"""Verify ocr_number strategy expands upward to capture label."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["ocr_number"]
|
||||
assert strategy.extra_top_ratio > 0.0
|
||||
assert strategy.extra_top_ratio >= 0.5 # Significant upward expansion
|
||||
|
||||
def test_bankgiro_expands_leftward(self):
|
||||
"""Verify bankgiro strategy expands leftward to capture prefix."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["bankgiro"]
|
||||
assert strategy.extra_left_ratio > 0.0
|
||||
assert strategy.extra_left_ratio >= 0.5 # Significant leftward expansion
|
||||
|
||||
def test_plusgiro_expands_leftward(self):
|
||||
"""Verify plusgiro strategy expands leftward to capture prefix."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["plusgiro"]
|
||||
assert strategy.extra_left_ratio > 0.0
|
||||
assert strategy.extra_left_ratio >= 0.5
|
||||
|
||||
def test_amount_expands_rightward(self):
|
||||
"""Verify amount strategy expands rightward for currency symbol."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["amount"]
|
||||
assert strategy.extra_right_ratio > 0.0
|
||||
|
||||
def test_invoice_date_expands_upward(self):
|
||||
"""Verify invoice_date strategy expands upward to capture label."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["invoice_date"]
|
||||
assert strategy.extra_top_ratio > 0.0
|
||||
|
||||
def test_invoice_due_date_expands_upward_and_leftward(self):
|
||||
"""Verify invoice_due_date strategy expands both up and left."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["invoice_due_date"]
|
||||
assert strategy.extra_top_ratio > 0.0
|
||||
assert strategy.extra_left_ratio > 0.0
|
||||
|
||||
def test_payment_line_has_minimal_expansion(self):
|
||||
"""Verify payment_line has conservative expansion (machine code)."""
|
||||
strategy = FIELD_SCALE_STRATEGIES["payment_line"]
|
||||
# Payment line is machine-readable, needs minimal expansion
|
||||
assert strategy.scale_x <= 1.2
|
||||
assert strategy.scale_y <= 1.3
|
||||
@@ -16,6 +16,7 @@ from shared.fields import (
|
||||
FIELD_CLASSES,
|
||||
FIELD_CLASS_IDS,
|
||||
CLASS_TO_FIELD,
|
||||
FIELD_TO_CLASS,
|
||||
CSV_TO_CLASS_MAPPING,
|
||||
TRAINING_FIELD_CLASSES,
|
||||
NUM_CLASSES,
|
||||
@@ -133,6 +134,20 @@ class TestMappingConsistency:
|
||||
assert fd.field_name in TRAINING_FIELD_CLASSES
|
||||
assert TRAINING_FIELD_CLASSES[fd.field_name] == fd.class_id
|
||||
|
||||
def test_field_to_class_is_inverse_of_class_to_field(self):
|
||||
"""Verify FIELD_TO_CLASS and CLASS_TO_FIELD are proper inverses."""
|
||||
for class_name, field_name in CLASS_TO_FIELD.items():
|
||||
assert FIELD_TO_CLASS[field_name] == class_name
|
||||
|
||||
for field_name, class_name in FIELD_TO_CLASS.items():
|
||||
assert CLASS_TO_FIELD[class_name] == field_name
|
||||
|
||||
def test_field_to_class_has_all_fields(self):
|
||||
"""Verify FIELD_TO_CLASS has mapping for all field names."""
|
||||
for fd in FIELD_DEFINITIONS:
|
||||
assert fd.field_name in FIELD_TO_CLASS
|
||||
assert FIELD_TO_CLASS[fd.field_name] == fd.class_name
|
||||
|
||||
|
||||
class TestSpecificFieldDefinitions:
|
||||
"""Tests for specific field definitions to catch common mistakes."""
|
||||
|
||||
@@ -272,12 +272,12 @@ class TestLineItemsExtractorFromPdf:
|
||||
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Create mock table detection result
|
||||
# Create mock table detection result with proper thead/tbody structure
|
||||
mock_table = MagicMock(spec=TableDetectionResult)
|
||||
mock_table.html = """
|
||||
<table>
|
||||
<tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr>
|
||||
<tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr>
|
||||
<thead><tr><th>Beskrivning</th><th>Antal</th><th>Pris</th><th>Belopp</th></tr></thead>
|
||||
<tbody><tr><td>Product A</td><td>2</td><td>100,00</td><td>200,00</td></tr></tbody>
|
||||
</table>
|
||||
"""
|
||||
|
||||
@@ -291,6 +291,78 @@ class TestLineItemsExtractorFromPdf:
|
||||
assert len(result.items) >= 1
|
||||
|
||||
|
||||
class TestPdfPathValidation:
|
||||
"""Tests for PDF path validation."""
|
||||
|
||||
def test_detect_tables_with_nonexistent_path(self):
|
||||
"""Test that non-existent PDF path returns empty results."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Create detector and call _detect_tables_with_parsing with non-existent path
|
||||
from unittest.mock import MagicMock
|
||||
from backend.table.structure_detector import TableDetector
|
||||
|
||||
mock_detector = MagicMock(spec=TableDetector)
|
||||
tables, parsing_res = extractor._detect_tables_with_parsing(
|
||||
mock_detector, "nonexistent.pdf"
|
||||
)
|
||||
|
||||
assert tables == []
|
||||
assert parsing_res == []
|
||||
|
||||
def test_detect_tables_with_directory_path(self, tmp_path):
|
||||
"""Test that directory path (not file) returns empty results."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from backend.table.structure_detector import TableDetector
|
||||
|
||||
mock_detector = MagicMock(spec=TableDetector)
|
||||
|
||||
# tmp_path is a directory, not a file
|
||||
tables, parsing_res = extractor._detect_tables_with_parsing(
|
||||
mock_detector, str(tmp_path)
|
||||
)
|
||||
|
||||
assert tables == []
|
||||
assert parsing_res == []
|
||||
|
||||
def test_detect_tables_validates_file_exists(self, tmp_path):
|
||||
"""Test path validation for file existence.
|
||||
|
||||
This test verifies that the method correctly validates the path exists
|
||||
and is a file before attempting to process it.
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Create a real file path that exists
|
||||
fake_pdf = tmp_path / "test.pdf"
|
||||
fake_pdf.write_bytes(b"not a real pdf")
|
||||
|
||||
# Mock render_pdf_to_images to avoid actual PDF processing
|
||||
with patch("shared.pdf.renderer.render_pdf_to_images") as mock_render:
|
||||
# Return empty iterator - simulates file exists but no pages
|
||||
mock_render.return_value = iter([])
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from backend.table.structure_detector import TableDetector
|
||||
|
||||
mock_detector = MagicMock(spec=TableDetector)
|
||||
mock_detector._ensure_initialized = MagicMock()
|
||||
mock_detector._pipeline = MagicMock()
|
||||
|
||||
tables, parsing_res = extractor._detect_tables_with_parsing(
|
||||
mock_detector, str(fake_pdf)
|
||||
)
|
||||
|
||||
# render_pdf_to_images was called (path validation passed)
|
||||
mock_render.assert_called_once()
|
||||
assert tables == []
|
||||
assert parsing_res == []
|
||||
|
||||
|
||||
class TestLineItemsResult:
|
||||
"""Tests for LineItemsResult dataclass."""
|
||||
|
||||
@@ -462,3 +534,246 @@ class TestMergedCellExtraction:
|
||||
assert result.items[0].is_deduction is False
|
||||
assert result.items[1].amount == "-2000"
|
||||
assert result.items[1].is_deduction is True
|
||||
|
||||
|
||||
class TestTextFallbackExtraction:
|
||||
"""Tests for text-based fallback extraction."""
|
||||
|
||||
def test_text_fallback_disabled_by_default(self):
|
||||
"""Test text fallback can be disabled."""
|
||||
extractor = LineItemsExtractor(enable_text_fallback=False)
|
||||
assert extractor.enable_text_fallback is False
|
||||
|
||||
def test_text_fallback_enabled_by_default(self):
|
||||
"""Test text fallback is enabled by default."""
|
||||
extractor = LineItemsExtractor()
|
||||
assert extractor.enable_text_fallback is True
|
||||
|
||||
def test_try_text_fallback_with_valid_parsing_res(self):
|
||||
"""Test text fallback with valid parsing results."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from backend.table.text_line_items_extractor import (
|
||||
TextLineItemsExtractor,
|
||||
TextLineItem,
|
||||
TextLineItemsResult,
|
||||
)
|
||||
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Mock parsing_res_list with text elements
|
||||
parsing_res = [
|
||||
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Product A"},
|
||||
{"label": "text", "bbox": [250, 100, 350, 120], "text": "1 234,56"},
|
||||
{"label": "text", "bbox": [0, 150, 200, 170], "text": "Product B"},
|
||||
{"label": "text", "bbox": [250, 150, 350, 170], "text": "2 345,67"},
|
||||
]
|
||||
|
||||
# Create mock text extraction result
|
||||
mock_text_result = TextLineItemsResult(
|
||||
items=[
|
||||
TextLineItem(row_index=0, description="Product A", amount="1 234,56"),
|
||||
TextLineItem(row_index=1, description="Product B", amount="2 345,67"),
|
||||
],
|
||||
header_row=[],
|
||||
)
|
||||
|
||||
with patch.object(TextLineItemsExtractor, 'extract_from_parsing_res', return_value=mock_text_result):
|
||||
result = extractor._try_text_fallback(parsing_res)
|
||||
|
||||
assert result is not None
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].description == "Product A"
|
||||
assert result.items[1].description == "Product B"
|
||||
|
||||
def test_try_text_fallback_returns_none_on_failure(self):
|
||||
"""Test text fallback returns None when extraction fails."""
|
||||
from unittest.mock import patch
|
||||
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
with patch('backend.table.text_line_items_extractor.TextLineItemsExtractor.extract_from_parsing_res', return_value=None):
|
||||
result = extractor._try_text_fallback([])
|
||||
assert result is None
|
||||
|
||||
def test_extract_from_pdf_uses_text_fallback(self):
|
||||
"""Test extract_from_pdf uses text fallback when no tables found."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from backend.table.text_line_items_extractor import TextLineItem, TextLineItemsResult
|
||||
|
||||
extractor = LineItemsExtractor(enable_text_fallback=True)
|
||||
|
||||
# Mock _detect_tables_with_parsing to return no tables but parsing_res
|
||||
mock_text_result = TextLineItemsResult(
|
||||
items=[
|
||||
TextLineItem(row_index=0, description="Product", amount="100,00"),
|
||||
TextLineItem(row_index=1, description="Product 2", amount="200,00"),
|
||||
],
|
||||
header_row=[],
|
||||
)
|
||||
|
||||
with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
|
||||
mock_detect.return_value = ([], [{"label": "text", "text": "test"}])
|
||||
|
||||
with patch.object(extractor, '_try_text_fallback', return_value=MagicMock(items=[MagicMock()])) as mock_fallback:
|
||||
result = extractor.extract_from_pdf("fake.pdf")
|
||||
|
||||
# Text fallback should be called
|
||||
mock_fallback.assert_called_once()
|
||||
|
||||
def test_extract_from_pdf_skips_fallback_when_disabled(self):
|
||||
"""Test extract_from_pdf skips text fallback when disabled."""
|
||||
from unittest.mock import patch
|
||||
|
||||
extractor = LineItemsExtractor(enable_text_fallback=False)
|
||||
|
||||
with patch.object(extractor, '_detect_tables_with_parsing') as mock_detect:
|
||||
mock_detect.return_value = ([], [{"label": "text", "text": "test"}])
|
||||
|
||||
result = extractor.extract_from_pdf("fake.pdf")
|
||||
|
||||
# Should return None, not use text fallback
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestVerticallyMergedCellExtraction:
|
||||
"""Tests for vertically merged cell extraction."""
|
||||
|
||||
def test_detects_vertically_merged_cells(self):
|
||||
"""Test detection of vertically merged cells in rows."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
# Rows with multiple product numbers in single cell
|
||||
rows = [["Produktnr 1457280 1457281 1060381 merged text here"]]
|
||||
assert extractor._has_vertically_merged_cells(rows) is True
|
||||
|
||||
def test_splits_vertically_merged_rows(self):
|
||||
"""Test splitting vertically merged rows."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
rows = [
|
||||
["Produktnr 1234567 1234568", "Antal 2ST 3ST"],
|
||||
]
|
||||
header, data = extractor._split_merged_rows(rows)
|
||||
|
||||
# Should split into header + data rows
|
||||
assert isinstance(header, list)
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
class TestDeductionDetection:
|
||||
"""Tests for deduction/discount detection."""
|
||||
|
||||
def test_detects_deduction_by_keyword_avdrag(self):
|
||||
"""Test detection of deduction by 'avdrag' keyword."""
|
||||
html = """
|
||||
<html><body><table>
|
||||
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||
<tbody>
|
||||
<tr><td>Hyresavdrag januari</td><td>-500,00</td></tr>
|
||||
</tbody>
|
||||
</table></body></html>
|
||||
"""
|
||||
extractor = LineItemsExtractor()
|
||||
result = extractor.extract(html)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].is_deduction is True
|
||||
|
||||
def test_detects_deduction_by_keyword_rabatt(self):
|
||||
"""Test detection of deduction by 'rabatt' keyword."""
|
||||
html = """
|
||||
<html><body><table>
|
||||
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||
<tbody>
|
||||
<tr><td>Rabatt 10%</td><td>-100,00</td></tr>
|
||||
</tbody>
|
||||
</table></body></html>
|
||||
"""
|
||||
extractor = LineItemsExtractor()
|
||||
result = extractor.extract(html)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].is_deduction is True
|
||||
|
||||
def test_detects_deduction_by_negative_amount(self):
|
||||
"""Test detection of deduction by negative amount."""
|
||||
html = """
|
||||
<html><body><table>
|
||||
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||
<tbody>
|
||||
<tr><td>Some credit</td><td>-250,00</td></tr>
|
||||
</tbody>
|
||||
</table></body></html>
|
||||
"""
|
||||
extractor = LineItemsExtractor()
|
||||
result = extractor.extract(html)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].is_deduction is True
|
||||
|
||||
def test_normal_item_not_deduction(self):
|
||||
"""Test normal item is not marked as deduction."""
|
||||
html = """
|
||||
<html><body><table>
|
||||
<thead><tr><th>Beskrivning</th><th>Belopp</th></tr></thead>
|
||||
<tbody>
|
||||
<tr><td>Normal product</td><td>500,00</td></tr>
|
||||
</tbody>
|
||||
</table></body></html>
|
||||
"""
|
||||
extractor = LineItemsExtractor()
|
||||
result = extractor.extract(html)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].is_deduction is False
|
||||
|
||||
|
||||
class TestHeaderDetection:
|
||||
"""Tests for header row detection."""
|
||||
|
||||
def test_detect_header_at_bottom(self):
|
||||
"""Test detecting header at bottom of table (reversed)."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
rows = [
|
||||
["100,00", "Product A", "1"],
|
||||
["200,00", "Product B", "2"],
|
||||
["Belopp", "Beskrivning", "Antal"], # Header at bottom
|
||||
]
|
||||
|
||||
header_idx, header, is_at_end = extractor._detect_header_row(rows)
|
||||
|
||||
assert header_idx == 2
|
||||
assert is_at_end is True
|
||||
assert "Belopp" in header
|
||||
|
||||
def test_detect_header_at_top(self):
|
||||
"""Test detecting header at top of table."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
rows = [
|
||||
["Belopp", "Beskrivning", "Antal"], # Header at top
|
||||
["100,00", "Product A", "1"],
|
||||
["200,00", "Product B", "2"],
|
||||
]
|
||||
|
||||
header_idx, header, is_at_end = extractor._detect_header_row(rows)
|
||||
|
||||
assert header_idx == 0
|
||||
assert is_at_end is False
|
||||
assert "Belopp" in header
|
||||
|
||||
def test_no_header_detected(self):
|
||||
"""Test when no header is detected."""
|
||||
extractor = LineItemsExtractor()
|
||||
|
||||
rows = [
|
||||
["100,00", "Product A", "1"],
|
||||
["200,00", "Product B", "2"],
|
||||
]
|
||||
|
||||
header_idx, header, is_at_end = extractor._detect_header_row(rows)
|
||||
|
||||
assert header_idx == -1
|
||||
assert header == []
|
||||
assert is_at_end is False
|
||||
|
||||
448
tests/table/test_merged_cell_handler.py
Normal file
448
tests/table/test_merged_cell_handler.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""
|
||||
Tests for Merged Cell Handler
|
||||
|
||||
Tests the detection and extraction of data from tables with merged cells,
|
||||
a common issue with PP-StructureV3 OCR output.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from backend.table.merged_cell_handler import MergedCellHandler, MIN_AMOUNT_THRESHOLD
|
||||
from backend.table.html_table_parser import ColumnMapper
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def handler():
|
||||
"""Create a MergedCellHandler with default ColumnMapper."""
|
||||
return MergedCellHandler(ColumnMapper())
|
||||
|
||||
|
||||
class TestHasVerticallyMergedCells:
|
||||
"""Tests for has_vertically_merged_cells detection."""
|
||||
|
||||
def test_empty_rows_returns_false(self, handler):
|
||||
"""Test empty rows returns False."""
|
||||
assert handler.has_vertically_merged_cells([]) is False
|
||||
|
||||
def test_short_cells_ignored(self, handler):
|
||||
"""Test cells shorter than 20 chars are ignored."""
|
||||
rows = [["Short cell", "Also short"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_detects_multiple_product_numbers(self, handler):
|
||||
"""Test detection of multiple 7-digit product numbers in cell."""
|
||||
rows = [["Produktnr 1457280 1457281 1060381 and more text here"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is True
|
||||
|
||||
def test_single_product_number_not_merged(self, handler):
|
||||
"""Test single product number doesn't trigger detection."""
|
||||
rows = [["Produktnr 1457280 and more text here for length"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_detects_multiple_prices(self, handler):
|
||||
"""Test detection of 3+ prices in cell (Swedish format)."""
|
||||
rows = [["Pris 127,20 234,56 159,20 total amounts"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is True
|
||||
|
||||
def test_two_prices_not_merged(self, handler):
|
||||
"""Test two prices doesn't trigger detection (needs 3+)."""
|
||||
rows = [["Pris 127,20 234,56 total amount here"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_detects_multiple_quantities(self, handler):
|
||||
"""Test detection of multiple quantity patterns."""
|
||||
rows = [["Antal 6ST 6ST 1ST more text here"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is True
|
||||
|
||||
def test_single_quantity_not_merged(self, handler):
|
||||
"""Test single quantity doesn't trigger detection."""
|
||||
rows = [["Antal 6ST and more text here for length"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_empty_cell_skipped(self, handler):
|
||||
"""Test empty cells are skipped."""
|
||||
rows = [["", None, "Valid but short"]]
|
||||
assert handler.has_vertically_merged_cells(rows) is False
|
||||
|
||||
def test_multiple_rows_checked(self, handler):
|
||||
"""Test all rows are checked for merged content."""
|
||||
rows = [
|
||||
["Normal row with nothing special"],
|
||||
["Produktnr 1457280 1457281 1060381 merged content"],
|
||||
]
|
||||
assert handler.has_vertically_merged_cells(rows) is True
|
||||
|
||||
|
||||
class TestSplitMergedRows:
|
||||
"""Tests for split_merged_rows method."""
|
||||
|
||||
def test_empty_rows_returns_empty(self, handler):
|
||||
"""Test empty rows returns empty result."""
|
||||
header, data = handler.split_merged_rows([])
|
||||
assert header == []
|
||||
assert data == []
|
||||
|
||||
def test_all_empty_rows_returns_original(self, handler):
|
||||
"""Test all empty rows returns original rows."""
|
||||
rows = [["", ""], ["", ""]]
|
||||
header, data = handler.split_merged_rows(rows)
|
||||
assert header == []
|
||||
assert data == rows
|
||||
|
||||
def test_splits_by_product_numbers(self, handler):
|
||||
"""Test splitting rows by product numbers."""
|
||||
rows = [
|
||||
["Produktnr 1234567 1234568", "Antal 2ST 3ST", "Pris 100,00 200,00"],
|
||||
]
|
||||
header, data = handler.split_merged_rows(rows)
|
||||
|
||||
assert len(header) == 3
|
||||
assert header[0] == "Produktnr"
|
||||
assert len(data) == 2
|
||||
|
||||
def test_splits_by_quantities(self, handler):
|
||||
"""Test splitting rows by quantity patterns."""
|
||||
rows = [
|
||||
["Description text", "Antal 5ST 10ST", "Belopp 500,00 1000,00"],
|
||||
]
|
||||
header, data = handler.split_merged_rows(rows)
|
||||
|
||||
# Should detect 2 quantities and split accordingly
|
||||
assert len(data) >= 1
|
||||
|
||||
def test_single_row_not_split(self, handler):
|
||||
"""Test single item row is not split."""
|
||||
rows = [
|
||||
["Produktnr 1234567", "Antal 2ST", "Pris 100,00"],
|
||||
]
|
||||
header, data = handler.split_merged_rows(rows)
|
||||
|
||||
# Only 1 product number, so expected_rows <= 1
|
||||
assert header == []
|
||||
assert data == rows
|
||||
|
||||
def test_handles_missing_columns(self, handler):
|
||||
"""Test handles rows with different column counts."""
|
||||
rows = [
|
||||
["Produktnr 1234567 1234568", ""],
|
||||
["Antal 2ST 3ST"],
|
||||
]
|
||||
header, data = handler.split_merged_rows(rows)
|
||||
|
||||
# Should handle gracefully
|
||||
assert isinstance(header, list)
|
||||
assert isinstance(data, list)
|
||||
|
||||
|
||||
class TestCountExpectedRows:
|
||||
"""Tests for _count_expected_rows helper."""
|
||||
|
||||
def test_counts_product_numbers(self, handler):
|
||||
"""Test counting product numbers."""
|
||||
columns = ["Produktnr 1234567 1234568 1234569", "Other"]
|
||||
count = handler._count_expected_rows(columns)
|
||||
assert count == 3
|
||||
|
||||
def test_counts_quantities(self, handler):
|
||||
"""Test counting quantity patterns."""
|
||||
columns = ["Nothing here", "Antal 5ST 10ST 15ST 20ST"]
|
||||
count = handler._count_expected_rows(columns)
|
||||
assert count == 4
|
||||
|
||||
def test_returns_max_count(self, handler):
|
||||
"""Test returns maximum count across columns."""
|
||||
columns = [
|
||||
"Produktnr 1234567 1234568", # 2 products
|
||||
"Antal 5ST 10ST 15ST", # 3 quantities
|
||||
]
|
||||
count = handler._count_expected_rows(columns)
|
||||
assert count == 3
|
||||
|
||||
def test_empty_columns_return_zero(self, handler):
|
||||
"""Test empty columns return 0."""
|
||||
columns = ["", None, "Short"]
|
||||
count = handler._count_expected_rows(columns)
|
||||
assert count == 0
|
||||
|
||||
|
||||
class TestSplitCellContentForRows:
|
||||
"""Tests for _split_cell_content_for_rows helper."""
|
||||
|
||||
def test_splits_by_product_numbers(self, handler):
|
||||
"""Test splitting by product numbers with expected count."""
|
||||
cell = "Produktnr 1234567 1234568"
|
||||
result = handler._split_cell_content_for_rows(cell, 2)
|
||||
|
||||
assert len(result) == 3 # header + 2 values
|
||||
assert result[0] == "Produktnr"
|
||||
assert "1234567" in result[1]
|
||||
assert "1234568" in result[2]
|
||||
|
||||
def test_splits_by_quantities(self, handler):
|
||||
"""Test splitting by quantity patterns."""
|
||||
cell = "Antal 5ST 10ST"
|
||||
result = handler._split_cell_content_for_rows(cell, 2)
|
||||
|
||||
assert len(result) == 3 # header + 2 values
|
||||
assert result[0] == "Antal"
|
||||
|
||||
def test_splits_discount_totalsumma(self, handler):
|
||||
"""Test splitting discount+totalsumma columns."""
|
||||
cell = "Rabatt i% Totalsumma 686,88 123,45"
|
||||
result = handler._split_cell_content_for_rows(cell, 2)
|
||||
|
||||
assert result[0] == "Totalsumma"
|
||||
assert "686,88" in result[1]
|
||||
assert "123,45" in result[2]
|
||||
|
||||
def test_splits_by_prices(self, handler):
|
||||
"""Test splitting by price patterns."""
|
||||
cell = "Pris 127,20 234,56"
|
||||
result = handler._split_cell_content_for_rows(cell, 2)
|
||||
|
||||
assert len(result) >= 2
|
||||
|
||||
def test_fallback_returns_original(self, handler):
|
||||
"""Test fallback returns original cell."""
|
||||
cell = "No patterns here"
|
||||
result = handler._split_cell_content_for_rows(cell, 2)
|
||||
|
||||
assert result == ["No patterns here"]
|
||||
|
||||
def test_product_number_with_description(self, handler):
|
||||
"""Test product numbers include trailing description text."""
|
||||
cell = "Art 1234567 Widget A 1234568 Widget B"
|
||||
result = handler._split_cell_content_for_rows(cell, 2)
|
||||
|
||||
assert len(result) == 3
|
||||
|
||||
|
||||
class TestSplitCellContent:
|
||||
"""Tests for split_cell_content method."""
|
||||
|
||||
def test_splits_by_product_numbers(self, handler):
|
||||
"""Test splitting by multiple product numbers."""
|
||||
cell = "Produktnr 1234567 1234568 1234569"
|
||||
result = handler.split_cell_content(cell)
|
||||
|
||||
assert result[0] == "Produktnr"
|
||||
assert "1234567" in result
|
||||
assert "1234568" in result
|
||||
assert "1234569" in result
|
||||
|
||||
def test_splits_by_quantities(self, handler):
|
||||
"""Test splitting by multiple quantities."""
|
||||
cell = "Antal 6ST 6ST 1ST"
|
||||
result = handler.split_cell_content(cell)
|
||||
|
||||
assert result[0] == "Antal"
|
||||
assert len(result) >= 3
|
||||
|
||||
def test_splits_discount_amount_interleaved(self, handler):
|
||||
"""Test splitting interleaved discount+amount patterns."""
|
||||
cell = "Rabatt i% Totalsumma 10,0 686,88 10,0 123,45"
|
||||
result = handler.split_cell_content(cell)
|
||||
|
||||
# Should extract amounts (3+ digit numbers with decimals)
|
||||
assert result[0] == "Totalsumma"
|
||||
assert "686,88" in result
|
||||
assert "123,45" in result
|
||||
|
||||
def test_splits_by_prices(self, handler):
|
||||
"""Test splitting by prices."""
|
||||
cell = "Pris 127,20 127,20 159,20"
|
||||
result = handler.split_cell_content(cell)
|
||||
|
||||
assert result[0] == "Pris"
|
||||
|
||||
def test_single_value_not_split(self, handler):
|
||||
"""Test single value is not split."""
|
||||
cell = "Single value"
|
||||
result = handler.split_cell_content(cell)
|
||||
|
||||
assert result == ["Single value"]
|
||||
|
||||
def test_single_product_not_split(self, handler):
|
||||
"""Test single product number is not split."""
|
||||
cell = "Produktnr 1234567"
|
||||
result = handler.split_cell_content(cell)
|
||||
|
||||
assert result == ["Produktnr 1234567"]
|
||||
|
||||
|
||||
class TestHasMergedHeader:
|
||||
"""Tests for has_merged_header method."""
|
||||
|
||||
def test_none_header_returns_false(self, handler):
|
||||
"""Test None header returns False."""
|
||||
assert handler.has_merged_header(None) is False
|
||||
|
||||
def test_empty_header_returns_false(self, handler):
|
||||
"""Test empty header returns False."""
|
||||
assert handler.has_merged_header([]) is False
|
||||
|
||||
def test_multiple_non_empty_cells_returns_false(self, handler):
|
||||
"""Test multiple non-empty cells returns False."""
|
||||
header = ["Beskrivning", "Antal", "Belopp"]
|
||||
assert handler.has_merged_header(header) is False
|
||||
|
||||
def test_single_cell_with_keywords_returns_true(self, handler):
|
||||
"""Test single cell with multiple keywords returns True."""
|
||||
header = ["Specifikation 0218103-1201 rum och kök Hyra Avdrag"]
|
||||
assert handler.has_merged_header(header) is True
|
||||
|
||||
def test_single_cell_one_keyword_returns_false(self, handler):
|
||||
"""Test single cell with only one keyword returns False."""
|
||||
header = ["Beskrivning only"]
|
||||
assert handler.has_merged_header(header) is False
|
||||
|
||||
def test_ignores_empty_trailing_cells(self, handler):
|
||||
"""Test ignores empty trailing cells."""
|
||||
header = ["Specifikation Hyra Avdrag", "", "", ""]
|
||||
assert handler.has_merged_header(header) is True
|
||||
|
||||
|
||||
class TestExtractFromMergedCells:
|
||||
"""Tests for extract_from_merged_cells method."""
|
||||
|
||||
def test_extracts_single_amount(self, handler):
|
||||
"""Test extracting a single amount."""
|
||||
header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
|
||||
rows = [["", "", "", "8159"]]
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].amount == "8159"
|
||||
assert items[0].is_deduction is False
|
||||
assert items[0].article_number == "0218103-1201"
|
||||
assert items[0].description == "2 rum och kök"
|
||||
|
||||
def test_extracts_deduction(self, handler):
|
||||
"""Test extracting a deduction (negative amount)."""
|
||||
header = ["Specifikation"]
|
||||
rows = [["", "", "", "-2000"]]
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].amount == "-2000"
|
||||
assert items[0].is_deduction is True
|
||||
# First item (row_index=0) gets description from header, not "Avdrag"
|
||||
# "Avdrag" is only set for subsequent deduction items
|
||||
assert items[0].description is None
|
||||
|
||||
def test_extracts_multiple_amounts_same_row(self, handler):
|
||||
"""Test extracting multiple amounts from same row."""
|
||||
header = ["Specifikation 0218103-1201 2 rum och kök Hyra Avdrag"]
|
||||
rows = [["", "", "", "8159 -2000"]]
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 2
|
||||
assert items[0].amount == "8159"
|
||||
assert items[1].amount == "-2000"
|
||||
|
||||
def test_extracts_amounts_from_multiple_rows(self, handler):
|
||||
"""Test extracting amounts from multiple rows."""
|
||||
header = ["Specifikation"]
|
||||
rows = [
|
||||
["", "", "", "8159"],
|
||||
["", "", "", "-2000"],
|
||||
]
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 2
|
||||
|
||||
def test_skips_small_amounts(self, handler):
|
||||
"""Test skipping small amounts below threshold."""
|
||||
header = ["Specifikation"]
|
||||
rows = [["", "", "", "50"]] # Below MIN_AMOUNT_THRESHOLD (100)
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 0
|
||||
|
||||
def test_skips_empty_rows(self, handler):
|
||||
"""Test skipping empty rows."""
|
||||
header = ["Specifikation"]
|
||||
rows = [["", "", "", ""]]
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 0
|
||||
|
||||
def test_handles_swedish_format_with_spaces(self, handler):
|
||||
"""Test handling Swedish number format with spaces."""
|
||||
header = ["Specifikation"]
|
||||
rows = [["", "", "", "8 159"]]
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].amount == "8159"
|
||||
|
||||
def test_confidence_is_lower_for_merged(self, handler):
|
||||
"""Test confidence is 0.7 for merged cell extraction."""
|
||||
header = ["Specifikation"]
|
||||
rows = [["", "", "", "8159"]]
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert items[0].confidence == 0.7
|
||||
|
||||
def test_empty_header_still_extracts(self, handler):
|
||||
"""Test extraction works with empty header."""
|
||||
header = []
|
||||
rows = [["", "", "", "8159"]]
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].description is None
|
||||
assert items[0].article_number is None
|
||||
|
||||
def test_row_index_increments(self, handler):
|
||||
"""Test row_index increments for each item."""
|
||||
header = ["Specifikation"]
|
||||
# Use separate rows to avoid regex grouping issues
|
||||
rows = [
|
||||
["", "", "", "8159"],
|
||||
["", "", "", "5000"],
|
||||
["", "", "", "-2000"],
|
||||
]
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
# Should have 3 items from 3 rows
|
||||
assert len(items) == 3
|
||||
assert items[0].row_index == 0
|
||||
assert items[1].row_index == 1
|
||||
assert items[2].row_index == 2
|
||||
|
||||
|
||||
class TestMinAmountThreshold:
|
||||
"""Tests for MIN_AMOUNT_THRESHOLD constant."""
|
||||
|
||||
def test_threshold_value(self):
|
||||
"""Test the threshold constant value."""
|
||||
assert MIN_AMOUNT_THRESHOLD == 100
|
||||
|
||||
def test_amounts_at_threshold_included(self, handler):
|
||||
"""Test amounts exactly at threshold are included."""
|
||||
header = ["Specifikation"]
|
||||
rows = [["", "", "", "100"]] # Exactly at threshold
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0].amount == "100"
|
||||
|
||||
def test_amounts_below_threshold_excluded(self, handler):
|
||||
"""Test amounts below threshold are excluded."""
|
||||
header = ["Specifikation"]
|
||||
rows = [["", "", "", "99"]] # Below threshold
|
||||
|
||||
items = handler.extract_from_merged_cells(header, rows)
|
||||
|
||||
assert len(items) == 0
|
||||
157
tests/table/test_models.py
Normal file
157
tests/table/test_models.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Tests for Line Items Data Models
|
||||
|
||||
Tests for LineItem and LineItemsResult dataclasses.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from backend.table.models import LineItem, LineItemsResult
|
||||
|
||||
|
||||
class TestLineItem:
|
||||
"""Tests for LineItem dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values for optional fields."""
|
||||
item = LineItem(row_index=0)
|
||||
|
||||
assert item.row_index == 0
|
||||
assert item.description is None
|
||||
assert item.quantity is None
|
||||
assert item.unit is None
|
||||
assert item.unit_price is None
|
||||
assert item.amount is None
|
||||
assert item.article_number is None
|
||||
assert item.vat_rate is None
|
||||
assert item.is_deduction is False
|
||||
assert item.confidence == 0.9
|
||||
|
||||
def test_custom_confidence(self):
|
||||
"""Test setting custom confidence."""
|
||||
item = LineItem(row_index=0, confidence=0.7)
|
||||
assert item.confidence == 0.7
|
||||
|
||||
def test_is_deduction_true(self):
|
||||
"""Test is_deduction flag."""
|
||||
item = LineItem(row_index=0, is_deduction=True)
|
||||
assert item.is_deduction is True
|
||||
|
||||
|
||||
class TestLineItemsResult:
|
||||
"""Tests for LineItemsResult dataclass."""
|
||||
|
||||
def test_total_amount_empty_items(self):
|
||||
"""Test total_amount returns None for empty items."""
|
||||
result = LineItemsResult(items=[], header_row=[], raw_html="")
|
||||
assert result.total_amount is None
|
||||
|
||||
def test_total_amount_single_item(self):
|
||||
"""Test total_amount with single item."""
|
||||
items = [LineItem(row_index=0, amount="100,00")]
|
||||
result = LineItemsResult(items=items, header_row=[], raw_html="")
|
||||
|
||||
assert result.total_amount == "100,00"
|
||||
|
||||
def test_total_amount_multiple_items(self):
|
||||
"""Test total_amount with multiple items."""
|
||||
items = [
|
||||
LineItem(row_index=0, amount="100,00"),
|
||||
LineItem(row_index=1, amount="200,50"),
|
||||
]
|
||||
result = LineItemsResult(items=items, header_row=[], raw_html="")
|
||||
|
||||
assert result.total_amount == "300,50"
|
||||
|
||||
def test_total_amount_with_deduction(self):
|
||||
"""Test total_amount includes negative amounts (deductions)."""
|
||||
items = [
|
||||
LineItem(row_index=0, amount="1000,00"),
|
||||
LineItem(row_index=1, amount="-200,00", is_deduction=True),
|
||||
]
|
||||
result = LineItemsResult(items=items, header_row=[], raw_html="")
|
||||
|
||||
assert result.total_amount == "800,00"
|
||||
|
||||
def test_total_amount_swedish_format_with_spaces(self):
|
||||
"""Test total_amount handles Swedish format with spaces."""
|
||||
items = [
|
||||
LineItem(row_index=0, amount="1 234,56"),
|
||||
LineItem(row_index=1, amount="2 000,00"),
|
||||
]
|
||||
result = LineItemsResult(items=items, header_row=[], raw_html="")
|
||||
|
||||
assert result.total_amount == "3 234,56"
|
||||
|
||||
def test_total_amount_invalid_amount_skipped(self):
|
||||
"""Test total_amount skips invalid amounts."""
|
||||
items = [
|
||||
LineItem(row_index=0, amount="100,00"),
|
||||
LineItem(row_index=1, amount="invalid"),
|
||||
LineItem(row_index=2, amount="200,00"),
|
||||
]
|
||||
result = LineItemsResult(items=items, header_row=[], raw_html="")
|
||||
|
||||
# Invalid amount is skipped
|
||||
assert result.total_amount == "300,00"
|
||||
|
||||
def test_total_amount_none_amount_skipped(self):
|
||||
"""Test total_amount skips None amounts."""
|
||||
items = [
|
||||
LineItem(row_index=0, amount="100,00"),
|
||||
LineItem(row_index=1, amount=None),
|
||||
]
|
||||
result = LineItemsResult(items=items, header_row=[], raw_html="")
|
||||
|
||||
assert result.total_amount == "100,00"
|
||||
|
||||
def test_total_amount_all_invalid_returns_none(self):
|
||||
"""Test total_amount returns None when all amounts are invalid."""
|
||||
items = [
|
||||
LineItem(row_index=0, amount="invalid"),
|
||||
LineItem(row_index=1, amount="also invalid"),
|
||||
]
|
||||
result = LineItemsResult(items=items, header_row=[], raw_html="")
|
||||
|
||||
assert result.total_amount is None
|
||||
|
||||
def test_total_amount_large_numbers(self):
|
||||
"""Test total_amount handles large numbers."""
|
||||
items = [
|
||||
LineItem(row_index=0, amount="123 456,78"),
|
||||
LineItem(row_index=1, amount="876 543,22"),
|
||||
]
|
||||
result = LineItemsResult(items=items, header_row=[], raw_html="")
|
||||
|
||||
assert result.total_amount == "1 000 000,00"
|
||||
|
||||
def test_total_amount_decimal_precision(self):
|
||||
"""Test total_amount maintains decimal precision."""
|
||||
items = [
|
||||
LineItem(row_index=0, amount="0,01"),
|
||||
LineItem(row_index=1, amount="0,02"),
|
||||
]
|
||||
result = LineItemsResult(items=items, header_row=[], raw_html="")
|
||||
|
||||
assert result.total_amount == "0,03"
|
||||
|
||||
def test_is_reversed_default_false(self):
|
||||
"""Test is_reversed defaults to False."""
|
||||
result = LineItemsResult(items=[], header_row=[], raw_html="")
|
||||
assert result.is_reversed is False
|
||||
|
||||
def test_is_reversed_can_be_set(self):
|
||||
"""Test is_reversed can be set to True."""
|
||||
result = LineItemsResult(items=[], header_row=[], raw_html="", is_reversed=True)
|
||||
assert result.is_reversed is True
|
||||
|
||||
def test_header_row_preserved(self):
|
||||
"""Test header_row is preserved."""
|
||||
header = ["Beskrivning", "Antal", "Belopp"]
|
||||
result = LineItemsResult(items=[], header_row=header, raw_html="")
|
||||
assert result.header_row == header
|
||||
|
||||
def test_raw_html_preserved(self):
|
||||
"""Test raw_html is preserved."""
|
||||
html = "<table><tr><td>Test</td></tr></table>"
|
||||
result = LineItemsResult(items=[], header_row=[], raw_html=html)
|
||||
assert result.raw_html == html
|
||||
@@ -658,3 +658,245 @@ class TestPaddleX3xAPI:
|
||||
assert len(results) == 1
|
||||
assert results[0].cells == [] # Empty cells list
|
||||
assert results[0].html == "<table></table>"
|
||||
|
||||
def test_parse_paddlex_result_with_dict_ocr_data(self):
|
||||
"""Test parsing PaddleX 3.x result with dict-format table_ocr_pred."""
|
||||
mock_pipeline = MagicMock()
|
||||
|
||||
mock_result = {
|
||||
"table_res_list": [
|
||||
{
|
||||
"cell_box_list": [[0, 0, 50, 20], [50, 0, 100, 20]],
|
||||
"pred_html": "<table><tr><td>A</td><td>B</td></tr></table>",
|
||||
"table_ocr_pred": {
|
||||
"rec_texts": ["A", "B"],
|
||||
"rec_scores": [0.99, 0.98],
|
||||
},
|
||||
}
|
||||
],
|
||||
"parsing_res_list": [
|
||||
{"label": "table", "bbox": [10, 20, 200, 300]},
|
||||
],
|
||||
}
|
||||
mock_pipeline.predict.return_value = [mock_result]
|
||||
|
||||
detector = TableDetector(pipeline=mock_pipeline)
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
results = detector.detect(image)
|
||||
|
||||
assert len(results) == 1
|
||||
assert len(results[0].cells) == 2
|
||||
assert results[0].cells[0]["text"] == "A"
|
||||
assert results[0].cells[1]["text"] == "B"
|
||||
|
||||
def test_parse_paddlex_result_no_bbox_in_parsing_res(self):
|
||||
"""Test parsing PaddleX 3.x result when table bbox not in parsing_res."""
|
||||
mock_pipeline = MagicMock()
|
||||
|
||||
mock_result = {
|
||||
"table_res_list": [
|
||||
{
|
||||
"cell_box_list": [[0, 0, 50, 20]],
|
||||
"pred_html": "<table><tr><td>A</td></tr></table>",
|
||||
"table_ocr_pred": ["A"],
|
||||
}
|
||||
],
|
||||
"parsing_res_list": [
|
||||
{"label": "text", "bbox": [10, 20, 200, 300]}, # Not a table
|
||||
],
|
||||
}
|
||||
mock_pipeline.predict.return_value = [mock_result]
|
||||
|
||||
detector = TableDetector(pipeline=mock_pipeline)
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
results = detector.detect(image)
|
||||
|
||||
assert len(results) == 1
|
||||
# Should use default bbox [0,0,0,0] when not found
|
||||
assert results[0].bbox == (0.0, 0.0, 0.0, 0.0)
|
||||
|
||||
|
||||
class TestIteratorResults:
|
||||
"""Tests for iterator/generator result handling."""
|
||||
|
||||
def test_handles_iterator_results(self):
|
||||
"""Test handling of iterator results from pipeline."""
|
||||
mock_pipeline = MagicMock()
|
||||
|
||||
# Return a generator instead of list
|
||||
def result_generator():
|
||||
element = MagicMock()
|
||||
element.label = "table"
|
||||
element.bbox = [0, 0, 100, 100]
|
||||
element.html = "<table></table>"
|
||||
element.score = 0.9
|
||||
element.cells = []
|
||||
mock_result = MagicMock(spec=["layout_elements"])
|
||||
mock_result.layout_elements = [element]
|
||||
yield mock_result
|
||||
|
||||
mock_pipeline.predict.return_value = result_generator()
|
||||
|
||||
detector = TableDetector(pipeline=mock_pipeline)
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
results = detector.detect(image)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
def test_handles_failed_iterator_conversion(self):
|
||||
"""Test handling when iterator conversion fails."""
|
||||
mock_pipeline = MagicMock()
|
||||
|
||||
# Create an object that has __iter__ but fails when converted to list
|
||||
class FailingIterator:
|
||||
def __iter__(self):
|
||||
raise RuntimeError("Iterator failed")
|
||||
|
||||
mock_pipeline.predict.return_value = FailingIterator()
|
||||
|
||||
detector = TableDetector(pipeline=mock_pipeline)
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
results = detector.detect(image)
|
||||
|
||||
# Should return empty list, not raise
|
||||
assert results == []
|
||||
|
||||
|
||||
class TestPathConversion:
|
||||
"""Tests for path handling."""
|
||||
|
||||
def test_converts_path_object_to_string(self):
|
||||
"""Test that Path objects are converted to strings."""
|
||||
from pathlib import Path
|
||||
|
||||
mock_pipeline = MagicMock()
|
||||
mock_pipeline.predict.return_value = []
|
||||
|
||||
detector = TableDetector(pipeline=mock_pipeline)
|
||||
path = Path("/some/path/to/image.png")
|
||||
|
||||
detector.detect(path)
|
||||
|
||||
# Should be called with string, not Path
|
||||
mock_pipeline.predict.assert_called_with("/some/path/to/image.png")
|
||||
|
||||
|
||||
class TestHtmlExtraction:
|
||||
"""Tests for HTML extraction from different element formats."""
|
||||
|
||||
def test_extracts_html_from_res_dict(self):
|
||||
"""Test extracting HTML from element.res dictionary."""
|
||||
mock_pipeline = MagicMock()
|
||||
element = MagicMock()
|
||||
element.label = "table"
|
||||
element.bbox = [0, 0, 100, 100]
|
||||
element.res = {"html": "<table><tr><td>From res</td></tr></table>"}
|
||||
element.score = 0.9
|
||||
element.cells = []
|
||||
# Remove direct html attribute
|
||||
del element.html
|
||||
del element.table_html
|
||||
|
||||
mock_result = MagicMock(spec=["layout_elements"])
|
||||
mock_result.layout_elements = [element]
|
||||
mock_pipeline.predict.return_value = [mock_result]
|
||||
|
||||
detector = TableDetector(pipeline=mock_pipeline)
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
results = detector.detect(image)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].html == "<table><tr><td>From res</td></tr></table>"
|
||||
|
||||
def test_returns_empty_html_when_not_found(self):
|
||||
"""Test empty HTML when no html attribute found."""
|
||||
mock_pipeline = MagicMock()
|
||||
element = MagicMock()
|
||||
element.label = "table"
|
||||
element.bbox = [0, 0, 100, 100]
|
||||
element.score = 0.9
|
||||
element.cells = []
|
||||
# Remove all html attributes
|
||||
del element.html
|
||||
del element.table_html
|
||||
del element.res
|
||||
|
||||
mock_result = MagicMock(spec=["layout_elements"])
|
||||
mock_result.layout_elements = [element]
|
||||
mock_pipeline.predict.return_value = [mock_result]
|
||||
|
||||
detector = TableDetector(pipeline=mock_pipeline)
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
results = detector.detect(image)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].html == ""
|
||||
|
||||
|
||||
class TestTableTypeDetection:
|
||||
"""Tests for table type detection."""
|
||||
|
||||
def test_detects_borderless_table(self):
|
||||
"""Test detection of borderless table type via _get_table_type."""
|
||||
detector = TableDetector()
|
||||
|
||||
# Create mock element with borderless label
|
||||
element = MagicMock()
|
||||
element.label = "borderless_table"
|
||||
|
||||
result = detector._get_table_type(element)
|
||||
assert result == "wireless"
|
||||
|
||||
def test_detects_wireless_table_label(self):
|
||||
"""Test detection of wireless table type."""
|
||||
detector = TableDetector()
|
||||
|
||||
element = MagicMock()
|
||||
element.label = "wireless_table"
|
||||
|
||||
result = detector._get_table_type(element)
|
||||
assert result == "wireless"
|
||||
|
||||
def test_defaults_to_wired_table(self):
|
||||
"""Test default table type is wired."""
|
||||
detector = TableDetector()
|
||||
|
||||
element = MagicMock()
|
||||
element.label = "table"
|
||||
|
||||
result = detector._get_table_type(element)
|
||||
assert result == "wired"
|
||||
|
||||
def test_type_attribute_instead_of_label(self):
|
||||
"""Test table type detection using type attribute."""
|
||||
detector = TableDetector()
|
||||
|
||||
element = MagicMock()
|
||||
element.type = "wireless"
|
||||
del element.label # Remove label
|
||||
|
||||
result = detector._get_table_type(element)
|
||||
assert result == "wireless"
|
||||
|
||||
|
||||
class TestPipelineRuntimeError:
|
||||
"""Tests for pipeline runtime errors."""
|
||||
|
||||
def test_raises_runtime_error_when_pipeline_none(self):
|
||||
"""Test RuntimeError when pipeline is None during detect."""
|
||||
detector = TableDetector()
|
||||
detector._initialized = True # Bypass lazy init
|
||||
detector._pipeline = None
|
||||
|
||||
image = np.zeros((100, 100, 3), dtype=np.uint8)
|
||||
|
||||
with pytest.raises(RuntimeError) as exc_info:
|
||||
detector.detect(image)
|
||||
|
||||
assert "not initialized" in str(exc_info.value).lower()
|
||||
|
||||
@@ -142,6 +142,33 @@ class TestTextLineItemsExtractor:
|
||||
rows = extractor._group_by_row(elements)
|
||||
assert len(rows) == 2
|
||||
|
||||
def test_group_by_row_varying_heights_uses_average(self, extractor):
|
||||
"""Test grouping handles varying element heights using dynamic average.
|
||||
|
||||
When elements have varying heights, the row center should be recalculated
|
||||
as new elements are added, preventing tall elements from being incorrectly
|
||||
grouped with the next row.
|
||||
"""
|
||||
# First element: small height, center_y = 105
|
||||
# Second element: tall, center_y = 115 (but should still be same row)
|
||||
# Third element: next row, center_y = 160
|
||||
elements = [
|
||||
TextElement(text="Short", bbox=(0, 100, 100, 110)), # center_y = 105
|
||||
TextElement(text="Tall item", bbox=(150, 100, 250, 130)), # center_y = 115
|
||||
TextElement(text="Next row", bbox=(0, 150, 100, 170)), # center_y = 160
|
||||
]
|
||||
rows = extractor._group_by_row(elements)
|
||||
|
||||
# With dynamic average, both first and second element should be same row
|
||||
assert len(rows) == 2
|
||||
assert len(rows[0]) == 2 # Short and Tall item
|
||||
assert len(rows[1]) == 1 # Next row
|
||||
|
||||
def test_group_by_row_empty_input(self, extractor):
|
||||
"""Test grouping with empty input returns empty list."""
|
||||
rows = extractor._group_by_row([])
|
||||
assert rows == []
|
||||
|
||||
def test_looks_like_line_item_with_amount(self, extractor):
|
||||
"""Test line item detection with amount."""
|
||||
row = [
|
||||
@@ -253,6 +280,67 @@ class TestTextLineItemsExtractor:
|
||||
assert len(elements) == 4
|
||||
|
||||
|
||||
class TestExceptionHandling:
|
||||
"""Tests for exception handling in text element extraction."""
|
||||
|
||||
def test_extract_text_elements_handles_missing_bbox(self):
|
||||
"""Test that missing bbox is handled gracefully."""
|
||||
extractor = TextLineItemsExtractor()
|
||||
parsing_res = [
|
||||
{"label": "text", "text": "No bbox"}, # Missing bbox
|
||||
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
|
||||
]
|
||||
elements = extractor._extract_text_elements(parsing_res)
|
||||
# Should only have 1 valid element
|
||||
assert len(elements) == 1
|
||||
assert elements[0].text == "Valid"
|
||||
|
||||
def test_extract_text_elements_handles_invalid_bbox(self):
|
||||
"""Test that invalid bbox (less than 4 values) is handled."""
|
||||
extractor = TextLineItemsExtractor()
|
||||
parsing_res = [
|
||||
{"label": "text", "bbox": [0, 100], "text": "Invalid bbox"}, # Only 2 values
|
||||
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
|
||||
]
|
||||
elements = extractor._extract_text_elements(parsing_res)
|
||||
assert len(elements) == 1
|
||||
assert elements[0].text == "Valid"
|
||||
|
||||
def test_extract_text_elements_handles_none_text(self):
|
||||
"""Test that None text is handled."""
|
||||
extractor = TextLineItemsExtractor()
|
||||
parsing_res = [
|
||||
{"label": "text", "bbox": [0, 100, 200, 120], "text": None},
|
||||
{"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
|
||||
]
|
||||
elements = extractor._extract_text_elements(parsing_res)
|
||||
assert len(elements) == 1
|
||||
assert elements[0].text == "Valid"
|
||||
|
||||
def test_extract_text_elements_handles_empty_string(self):
|
||||
"""Test that empty string text is skipped."""
|
||||
extractor = TextLineItemsExtractor()
|
||||
parsing_res = [
|
||||
{"label": "text", "bbox": [0, 100, 200, 120], "text": ""},
|
||||
{"label": "text", "bbox": [0, 150, 200, 170], "text": "Valid"},
|
||||
]
|
||||
elements = extractor._extract_text_elements(parsing_res)
|
||||
assert len(elements) == 1
|
||||
assert elements[0].text == "Valid"
|
||||
|
||||
def test_extract_text_elements_handles_malformed_element(self):
|
||||
"""Test that completely malformed elements are handled."""
|
||||
extractor = TextLineItemsExtractor()
|
||||
parsing_res = [
|
||||
"not a dict", # String instead of dict
|
||||
123, # Number instead of dict
|
||||
{"label": "text", "bbox": [0, 100, 200, 120], "text": "Valid"},
|
||||
]
|
||||
elements = extractor._extract_text_elements(parsing_res)
|
||||
assert len(elements) == 1
|
||||
assert elements[0].text == "Valid"
|
||||
|
||||
|
||||
class TestConvertTextLineItem:
|
||||
"""Tests for convert_text_line_item function."""
|
||||
|
||||
|
||||
1
tests/training/__init__.py
Normal file
1
tests/training/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for training package."""
|
||||
1
tests/training/yolo/__init__.py
Normal file
1
tests/training/yolo/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for training.yolo module."""
|
||||
342
tests/training/yolo/test_annotation_generator.py
Normal file
342
tests/training/yolo/test_annotation_generator.py
Normal file
@@ -0,0 +1,342 @@
|
||||
"""
|
||||
Tests for AnnotationGenerator with field-specific bbox expansion.
|
||||
|
||||
Tests verify that annotations are generated correctly using
|
||||
field-specific scale strategies.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import pytest
|
||||
|
||||
from training.yolo.annotation_generator import (
|
||||
AnnotationGenerator,
|
||||
YOLOAnnotation,
|
||||
)
|
||||
from shared.fields import TRAINING_FIELD_CLASSES, CLASS_NAMES
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockMatch:
|
||||
"""Mock Match object for testing."""
|
||||
bbox: tuple[float, float, float, float]
|
||||
score: float
|
||||
|
||||
|
||||
class TestYOLOAnnotation:
|
||||
"""Tests for YOLOAnnotation dataclass."""
|
||||
|
||||
def test_to_string_format(self):
|
||||
"""Verify YOLO format string output."""
|
||||
ann = YOLOAnnotation(
|
||||
class_id=0,
|
||||
x_center=0.5,
|
||||
y_center=0.5,
|
||||
width=0.1,
|
||||
height=0.05,
|
||||
confidence=0.9
|
||||
)
|
||||
result = ann.to_string()
|
||||
assert result == "0 0.500000 0.500000 0.100000 0.050000"
|
||||
|
||||
def test_default_confidence(self):
|
||||
"""Verify default confidence is 1.0."""
|
||||
ann = YOLOAnnotation(
|
||||
class_id=0,
|
||||
x_center=0.5,
|
||||
y_center=0.5,
|
||||
width=0.1,
|
||||
height=0.05,
|
||||
)
|
||||
assert ann.confidence == 1.0
|
||||
|
||||
|
||||
class TestAnnotationGeneratorInit:
|
||||
"""Tests for AnnotationGenerator initialization."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Verify default initialization values."""
|
||||
gen = AnnotationGenerator()
|
||||
assert gen.min_confidence == 0.7
|
||||
assert gen.min_bbox_height_px == 30
|
||||
|
||||
def test_custom_values(self):
|
||||
"""Verify custom initialization values."""
|
||||
gen = AnnotationGenerator(
|
||||
min_confidence=0.8,
|
||||
min_bbox_height_px=40,
|
||||
)
|
||||
assert gen.min_confidence == 0.8
|
||||
assert gen.min_bbox_height_px == 40
|
||||
|
||||
|
||||
class TestGenerateFromMatches:
|
||||
"""Tests for generate_from_matches method."""
|
||||
|
||||
def test_generates_annotation_for_valid_match(self):
|
||||
"""Verify annotation is generated for valid match."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5)
|
||||
|
||||
# Mock match in PDF points (72 DPI)
|
||||
# At 150 DPI, coords multiply by 150/72 = 2.083
|
||||
matches = {
|
||||
"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.8)]
|
||||
}
|
||||
|
||||
annotations = gen.generate_from_matches(
|
||||
matches=matches,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(annotations) == 1
|
||||
ann = annotations[0]
|
||||
assert ann.class_id == TRAINING_FIELD_CLASSES["InvoiceNumber"]
|
||||
assert ann.confidence == 0.8
|
||||
# Normalized values should be in 0-1 range
|
||||
assert 0 <= ann.x_center <= 1
|
||||
assert 0 <= ann.y_center <= 1
|
||||
assert 0 < ann.width <= 1
|
||||
assert 0 < ann.height <= 1
|
||||
|
||||
def test_skips_low_confidence_match(self):
|
||||
"""Verify low confidence matches are skipped."""
|
||||
gen = AnnotationGenerator(min_confidence=0.7)
|
||||
|
||||
matches = {
|
||||
"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 230), score=0.5)]
|
||||
}
|
||||
|
||||
annotations = gen.generate_from_matches(
|
||||
matches=matches,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(annotations) == 0
|
||||
|
||||
def test_skips_unknown_field(self):
|
||||
"""Verify unknown fields are skipped."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5)
|
||||
|
||||
matches = {
|
||||
"UnknownField": [MockMatch(bbox=(100, 200, 200, 230), score=0.9)]
|
||||
}
|
||||
|
||||
annotations = gen.generate_from_matches(
|
||||
matches=matches,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(annotations) == 0
|
||||
|
||||
def test_takes_best_match_only(self):
|
||||
"""Verify only the best match is used per field."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5)
|
||||
|
||||
matches = {
|
||||
"InvoiceNumber": [
|
||||
MockMatch(bbox=(100, 200, 200, 230), score=0.9), # Best
|
||||
MockMatch(bbox=(300, 400, 400, 430), score=0.7),
|
||||
]
|
||||
}
|
||||
|
||||
annotations = gen.generate_from_matches(
|
||||
matches=matches,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(annotations) == 1
|
||||
assert annotations[0].confidence == 0.9
|
||||
|
||||
def test_handles_empty_matches(self):
|
||||
"""Verify empty matches list is handled."""
|
||||
gen = AnnotationGenerator()
|
||||
|
||||
matches = {
|
||||
"InvoiceNumber": []
|
||||
}
|
||||
|
||||
annotations = gen.generate_from_matches(
|
||||
matches=matches,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(annotations) == 0
|
||||
|
||||
def test_applies_field_specific_expansion(self):
|
||||
"""Verify different fields get different expansion."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5)
|
||||
|
||||
# Same bbox, different fields
|
||||
bbox = (100, 200, 200, 230)
|
||||
|
||||
matches_invoice_number = {
|
||||
"InvoiceNumber": [MockMatch(bbox=bbox, score=0.9)]
|
||||
}
|
||||
matches_bankgiro = {
|
||||
"Bankgiro": [MockMatch(bbox=bbox, score=0.9)]
|
||||
}
|
||||
|
||||
ann_invoice = gen.generate_from_matches(
|
||||
matches=matches_invoice_number,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)[0]
|
||||
|
||||
ann_bankgiro = gen.generate_from_matches(
|
||||
matches=matches_bankgiro,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)[0]
|
||||
|
||||
# Bankgiro has extra_left_ratio=0.80, invoice_number has extra_top_ratio=0.40
|
||||
# They should have different widths due to different expansion
|
||||
# Bankgiro expands more to the left
|
||||
assert ann_bankgiro.width != ann_invoice.width or ann_bankgiro.x_center != ann_invoice.x_center
|
||||
|
||||
def test_enforces_min_bbox_height(self):
|
||||
"""Verify minimum bbox height is enforced."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5, min_bbox_height_px=50)
|
||||
|
||||
# Very small bbox
|
||||
matches = {
|
||||
"InvoiceNumber": [MockMatch(bbox=(100, 200, 200, 210), score=0.9)]
|
||||
}
|
||||
|
||||
annotations = gen.generate_from_matches(
|
||||
matches=matches,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=72 # 1:1 scale
|
||||
)
|
||||
|
||||
assert len(annotations) == 1
|
||||
# Height should be at least min_bbox_height_px / image_height
|
||||
# After scale strategy expansion, height should be >= 50/1000 = 0.05
|
||||
# Actually the min_bbox_height check happens AFTER expand_bbox
|
||||
# So the final height should meet the minimum
|
||||
|
||||
|
||||
class TestAddPaymentLineAnnotation:
|
||||
"""Tests for add_payment_line_annotation method."""
|
||||
|
||||
def test_adds_payment_line_annotation(self):
|
||||
"""Verify payment_line annotation is added."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5)
|
||||
annotations = []
|
||||
|
||||
result = gen.add_payment_line_annotation(
|
||||
annotations=annotations,
|
||||
payment_line_bbox=(100, 200, 400, 230),
|
||||
confidence=0.9,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
ann = result[0]
|
||||
assert ann.class_id == TRAINING_FIELD_CLASSES["payment_line"]
|
||||
assert ann.confidence == 0.9
|
||||
|
||||
def test_skips_none_bbox(self):
|
||||
"""Verify None bbox is handled."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5)
|
||||
annotations = []
|
||||
|
||||
result = gen.add_payment_line_annotation(
|
||||
annotations=annotations,
|
||||
payment_line_bbox=None,
|
||||
confidence=0.9,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
def test_skips_low_confidence(self):
|
||||
"""Verify low confidence is skipped."""
|
||||
gen = AnnotationGenerator(min_confidence=0.7)
|
||||
annotations = []
|
||||
|
||||
result = gen.add_payment_line_annotation(
|
||||
annotations=annotations,
|
||||
payment_line_bbox=(100, 200, 400, 230),
|
||||
confidence=0.5,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(result) == 0
|
||||
|
||||
def test_appends_to_existing_annotations(self):
|
||||
"""Verify payment_line is appended to existing list."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5)
|
||||
existing = [YOLOAnnotation(0, 0.5, 0.5, 0.1, 0.1, 0.9)]
|
||||
|
||||
result = gen.add_payment_line_annotation(
|
||||
annotations=existing,
|
||||
payment_line_bbox=(100, 200, 400, 230),
|
||||
confidence=0.9,
|
||||
image_width=1000,
|
||||
image_height=1000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].class_id == 0 # Original
|
||||
assert result[1].class_id == TRAINING_FIELD_CLASSES["payment_line"]
|
||||
|
||||
|
||||
class TestMultipleFieldsIntegration:
|
||||
"""Integration tests for multiple fields."""
|
||||
|
||||
def test_generates_annotations_for_all_field_types(self):
|
||||
"""Verify annotations can be generated for all field types."""
|
||||
gen = AnnotationGenerator(min_confidence=0.5)
|
||||
|
||||
# Create matches for each field (except payment_line which is derived)
|
||||
field_names = [
|
||||
"InvoiceNumber",
|
||||
"InvoiceDate",
|
||||
"InvoiceDueDate",
|
||||
"OCR",
|
||||
"Bankgiro",
|
||||
"Plusgiro",
|
||||
"Amount",
|
||||
"supplier_organisation_number",
|
||||
"customer_number",
|
||||
]
|
||||
|
||||
matches = {}
|
||||
for i, field_name in enumerate(field_names):
|
||||
# Stagger bboxes to avoid overlap
|
||||
matches[field_name] = [
|
||||
MockMatch(bbox=(100 + i * 50, 100 + i * 30, 200 + i * 50, 130 + i * 30), score=0.9)
|
||||
]
|
||||
|
||||
annotations = gen.generate_from_matches(
|
||||
matches=matches,
|
||||
image_width=2000,
|
||||
image_height=2000,
|
||||
dpi=150
|
||||
)
|
||||
|
||||
assert len(annotations) == len(field_names)
|
||||
|
||||
# Verify all class_ids are present
|
||||
class_ids = {ann.class_id for ann in annotations}
|
||||
expected_class_ids = {TRAINING_FIELD_CLASSES[fn] for fn in field_names}
|
||||
assert class_ids == expected_class_ids
|
||||
251
tests/training/yolo/test_db_dataset.py
Normal file
251
tests/training/yolo/test_db_dataset.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""Tests for db_dataset.py expand_bbox integration."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from pathlib import Path
|
||||
|
||||
from training.yolo.db_dataset import DBYOLODataset
|
||||
from training.yolo.annotation_generator import YOLOAnnotation
|
||||
from shared.bbox import FIELD_SCALE_STRATEGIES, DEFAULT_STRATEGY
|
||||
from shared.fields import CLASS_NAMES
|
||||
|
||||
|
||||
class TestConvertLabelsWithExpandBbox:
|
||||
"""Tests for _convert_labels using expand_bbox instead of fixed padding."""
|
||||
|
||||
def test_convert_labels_uses_expand_bbox(self):
|
||||
"""Verify _convert_labels calls expand_bbox for field-specific expansion."""
|
||||
# Create a mock dataset without loading from DB
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
# Create annotation for bankgiro (has extra_left_ratio)
|
||||
# bbox in PDF points: x0=100, y0=200, x1=200, y1=250
|
||||
# center: (150, 225), width: 100, height: 50
|
||||
annotations = [
|
||||
YOLOAnnotation(
|
||||
class_id=4, # bankgiro
|
||||
x_center=150, # in PDF points
|
||||
y_center=225,
|
||||
width=100,
|
||||
height=50,
|
||||
confidence=0.9
|
||||
)
|
||||
]
|
||||
|
||||
# Image size in pixels (at 300 DPI)
|
||||
img_width = 2480 # A4 width at 300 DPI
|
||||
img_height = 3508 # A4 height at 300 DPI
|
||||
|
||||
# Convert labels
|
||||
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Should have one label
|
||||
assert labels.shape == (1, 5)
|
||||
|
||||
# Check class_id
|
||||
assert labels[0, 0] == 4
|
||||
|
||||
# The bbox should be expanded using bankgiro strategy (extra_left_ratio=0.80)
|
||||
# Original bbox at 300 DPI:
|
||||
# x0 = 100 * (300/72) = 416.67
|
||||
# y0 = 200 * (300/72) = 833.33
|
||||
# x1 = 200 * (300/72) = 833.33
|
||||
# y1 = 250 * (300/72) = 1041.67
|
||||
# width_px = 416.67, height_px = 208.33
|
||||
|
||||
# After expand_bbox with bankgiro strategy:
|
||||
# scale_x=1.45, scale_y=1.35, extra_left_ratio=0.80
|
||||
# The x_center should shift left due to extra_left_ratio
|
||||
x_center = labels[0, 1]
|
||||
y_center = labels[0, 2]
|
||||
width = labels[0, 3]
|
||||
height = labels[0, 4]
|
||||
|
||||
# Verify normalized values are in valid range
|
||||
assert 0 <= x_center <= 1
|
||||
assert 0 <= y_center <= 1
|
||||
assert 0 < width <= 1
|
||||
assert 0 < height <= 1
|
||||
|
||||
# Width should be larger than original due to scaling and extra_left
|
||||
# Original normalized width: 416.67 / 2480 = 0.168
|
||||
# After bankgiro expansion it should be wider
|
||||
assert width > 0.168
|
||||
|
||||
def test_convert_labels_different_field_types(self):
|
||||
"""Verify different field types use their specific strategies."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Same bbox for different field types
|
||||
base_annotation = {
|
||||
'x_center': 150,
|
||||
'y_center': 225,
|
||||
'width': 100,
|
||||
'height': 50,
|
||||
'confidence': 0.9
|
||||
}
|
||||
|
||||
# OCR number (class_id=3) - has extra_top_ratio=0.60
|
||||
ocr_annotations = [YOLOAnnotation(class_id=3, **base_annotation)]
|
||||
ocr_labels = dataset._convert_labels(ocr_annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Bankgiro (class_id=4) - has extra_left_ratio=0.80
|
||||
bankgiro_annotations = [YOLOAnnotation(class_id=4, **base_annotation)]
|
||||
bankgiro_labels = dataset._convert_labels(bankgiro_annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Amount (class_id=6) - has extra_right_ratio=0.30
|
||||
amount_annotations = [YOLOAnnotation(class_id=6, **base_annotation)]
|
||||
amount_labels = dataset._convert_labels(amount_annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# Each field type should have different expansion
|
||||
# OCR should expand more vertically (extra_top)
|
||||
# Bankgiro should expand more to the left
|
||||
# Amount should expand more to the right
|
||||
|
||||
# OCR: extra_top shifts y_center up
|
||||
# Bankgiro: extra_left shifts x_center left
|
||||
# So bankgiro x_center < OCR x_center
|
||||
assert bankgiro_labels[0, 1] < ocr_labels[0, 1]
|
||||
|
||||
# OCR has higher scale_y (1.80) than amount (1.35)
|
||||
assert ocr_labels[0, 4] > amount_labels[0, 4]
|
||||
|
||||
def test_convert_labels_clamps_to_image_bounds(self):
|
||||
"""Verify labels are clamped to image boundaries."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
# Annotation near edge of image (in PDF points)
|
||||
annotations = [
|
||||
YOLOAnnotation(
|
||||
class_id=4, # bankgiro - will expand left
|
||||
x_center=30, # Very close to left edge
|
||||
y_center=50,
|
||||
width=40,
|
||||
height=30,
|
||||
confidence=0.9
|
||||
)
|
||||
]
|
||||
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
labels = dataset._convert_labels(annotations, img_width, img_height, is_scanned=False)
|
||||
|
||||
# All values should be in valid range
|
||||
assert 0 <= labels[0, 1] <= 1 # x_center
|
||||
assert 0 <= labels[0, 2] <= 1 # y_center
|
||||
assert 0 < labels[0, 3] <= 1 # width
|
||||
assert 0 < labels[0, 4] <= 1 # height
|
||||
|
||||
def test_convert_labels_empty_annotations(self):
|
||||
"""Verify empty annotations return empty array."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 30
|
||||
|
||||
labels = dataset._convert_labels([], 2480, 3508, is_scanned=False)
|
||||
|
||||
assert labels.shape == (0, 5)
|
||||
assert labels.dtype == np.float32
|
||||
|
||||
def test_convert_labels_minimum_height(self):
|
||||
"""Verify minimum height is enforced after expansion."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.dpi = 300
|
||||
dataset.min_bbox_height_px = 50 # Higher minimum
|
||||
|
||||
# Very small annotation
|
||||
annotations = [
|
||||
YOLOAnnotation(
|
||||
class_id=9, # payment_line - minimal expansion
|
||||
x_center=100,
|
||||
y_center=100,
|
||||
width=200,
|
||||
height=5, # Very small height
|
||||
confidence=0.9
|
||||
)
|
||||
]
|
||||
|
||||
labels = dataset._convert_labels(annotations, 2480, 3508, is_scanned=False)
|
||||
|
||||
# Height should be at least min_bbox_height_px / img_height
|
||||
min_normalized_height = 50 / 3508
|
||||
assert labels[0, 4] >= min_normalized_height
|
||||
|
||||
|
||||
class TestCreateAnnotationWithClassName:
|
||||
"""Tests for _create_annotation storing class_name for expand_bbox lookup."""
|
||||
|
||||
def test_create_annotation_stores_class_name(self):
|
||||
"""Verify _create_annotation stores class_name for later use."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
|
||||
# Create annotation for invoice_number
|
||||
annotation = dataset._create_annotation(
|
||||
field_name="InvoiceNumber",
|
||||
bbox=[100, 200, 200, 250],
|
||||
score=0.9
|
||||
)
|
||||
|
||||
assert annotation.class_id == 0 # invoice_number class_id
|
||||
|
||||
|
||||
class TestLoadLabelsFromDbWithClassName:
|
||||
"""Tests for _load_labels_from_db preserving field_name for expansion."""
|
||||
|
||||
def test_load_labels_maps_field_names_correctly(self):
|
||||
"""Verify field names are mapped correctly for expand_bbox."""
|
||||
dataset = object.__new__(DBYOLODataset)
|
||||
dataset.min_confidence = 0.7
|
||||
|
||||
# Mock database
|
||||
mock_db = MagicMock()
|
||||
mock_db.get_documents_batch.return_value = {
|
||||
'doc1': {
|
||||
'success': True,
|
||||
'pdf_type': 'text',
|
||||
'split': 'train',
|
||||
'field_results': [
|
||||
{
|
||||
'matched': True,
|
||||
'field_name': 'Bankgiro',
|
||||
'score': 0.9,
|
||||
'bbox': [100, 200, 200, 250],
|
||||
'page_no': 0
|
||||
},
|
||||
{
|
||||
'matched': True,
|
||||
'field_name': 'supplier_accounts(Plusgiro)',
|
||||
'score': 0.85,
|
||||
'bbox': [300, 400, 400, 450],
|
||||
'page_no': 0
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
dataset.db = mock_db
|
||||
|
||||
result = dataset._load_labels_from_db(['doc1'])
|
||||
|
||||
assert 'doc1' in result
|
||||
page_labels, is_scanned, csv_split = result['doc1']
|
||||
|
||||
# Should have 2 annotations on page 0
|
||||
assert 0 in page_labels
|
||||
assert len(page_labels[0]) == 2
|
||||
|
||||
# First annotation: Bankgiro (class_id=4)
|
||||
assert page_labels[0][0].class_id == 4
|
||||
|
||||
# Second annotation: Plusgiro mapped from supplier_accounts(Plusgiro) (class_id=5)
|
||||
assert page_labels[0][1].class_id == 5
|
||||
264
tests/web/test_documents_upload_validation.py
Normal file
264
tests/web/test_documents_upload_validation.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Tests for PDF Magic Bytes Validation in Document Upload.
|
||||
|
||||
TDD: These tests are written FIRST, before implementation.
|
||||
They should FAIL initially until the validation logic is implemented.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock, patch, AsyncMock
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import UploadFile
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.web.api.v1.admin.documents import create_documents_router
|
||||
from backend.web.config import StorageConfig
|
||||
|
||||
|
||||
# Test constants
|
||||
TEST_DOC_UUID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
TEST_TOKEN = "test-admin-token-12345"
|
||||
|
||||
|
||||
class TestPDFMagicBytesValidation:
|
||||
"""Tests for PDF magic bytes validation during upload."""
|
||||
|
||||
@pytest.fixture
|
||||
def storage_config(self, tmp_path):
|
||||
"""Create a StorageConfig for testing."""
|
||||
return StorageConfig(
|
||||
upload_dir=tmp_path / "uploads",
|
||||
result_dir=tmp_path / "results",
|
||||
max_file_size_mb=50,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dependencies(self):
|
||||
"""Create mock dependencies for document upload."""
|
||||
mock_docs = MagicMock()
|
||||
mock_docs.create.return_value = TEST_DOC_UUID
|
||||
|
||||
mock_annotations = MagicMock()
|
||||
mock_annotations.get_for_document.return_value = []
|
||||
|
||||
return {
|
||||
"docs": mock_docs,
|
||||
"annotations": mock_annotations,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def valid_pdf_content(self) -> bytes:
|
||||
"""Create valid PDF content with correct magic bytes."""
|
||||
# PDF files must start with %PDF
|
||||
return b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n1 0 obj\n<<>>\nendobj\ntrailer\n<<>>\n%%EOF"
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_pdf_content_exe(self) -> bytes:
|
||||
"""Create content that looks like an executable (MZ header)."""
|
||||
return b"MZ\x90\x00\x03\x00\x00\x00\x04\x00\x00\x00\xff\xff"
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_pdf_content_text(self) -> bytes:
|
||||
"""Create plain text content masquerading as PDF."""
|
||||
return b"This is not a PDF file, just plain text."
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_pdf_content_html(self) -> bytes:
|
||||
"""Create HTML content masquerading as PDF."""
|
||||
return b"<!DOCTYPE html><html><body>Not a PDF</body></html>"
|
||||
|
||||
@pytest.fixture
|
||||
def empty_content(self) -> bytes:
|
||||
"""Create empty file content."""
|
||||
return b""
|
||||
|
||||
@pytest.fixture
|
||||
def almost_valid_pdf(self) -> bytes:
|
||||
"""Create content that starts with %PD but not %PDF."""
|
||||
return b"%PD-1.4\nNot quite right"
|
||||
|
||||
def test_valid_pdf_passes_validation(self, valid_pdf_content):
|
||||
"""Test that a valid PDF file with correct magic bytes passes validation.
|
||||
|
||||
A valid PDF must start with the bytes b'%PDF'.
|
||||
"""
|
||||
# Import the validation function (to be implemented)
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
|
||||
# Should not raise any exception
|
||||
validate_pdf_magic_bytes(valid_pdf_content)
|
||||
|
||||
def test_invalid_pdf_exe_fails_validation(self, invalid_pdf_content_exe):
|
||||
"""Test that an executable file renamed to .pdf fails validation.
|
||||
|
||||
This is a security test - attackers might try to upload malicious
|
||||
executables by renaming them to .pdf.
|
||||
"""
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_pdf_magic_bytes(invalid_pdf_content_exe)
|
||||
|
||||
assert "Invalid PDF file" in str(exc_info.value)
|
||||
assert "valid PDF header" in str(exc_info.value)
|
||||
|
||||
def test_invalid_pdf_text_fails_validation(self, invalid_pdf_content_text):
|
||||
"""Test that plain text file renamed to .pdf fails validation."""
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_pdf_magic_bytes(invalid_pdf_content_text)
|
||||
|
||||
assert "Invalid PDF file" in str(exc_info.value)
|
||||
|
||||
def test_invalid_pdf_html_fails_validation(self, invalid_pdf_content_html):
|
||||
"""Test that HTML file renamed to .pdf fails validation."""
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_pdf_magic_bytes(invalid_pdf_content_html)
|
||||
|
||||
assert "Invalid PDF file" in str(exc_info.value)
|
||||
|
||||
def test_empty_file_fails_validation(self, empty_content):
|
||||
"""Test that an empty file fails validation.
|
||||
|
||||
Empty files cannot be valid PDFs and should be rejected.
|
||||
"""
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_pdf_magic_bytes(empty_content)
|
||||
|
||||
assert "Invalid PDF file" in str(exc_info.value)
|
||||
|
||||
def test_almost_valid_pdf_fails_validation(self, almost_valid_pdf):
|
||||
"""Test that content starting with %PD but not %PDF fails validation.
|
||||
|
||||
The magic bytes must be exactly %PDF (4 bytes).
|
||||
"""
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_pdf_magic_bytes(almost_valid_pdf)
|
||||
|
||||
assert "Invalid PDF file" in str(exc_info.value)
|
||||
|
||||
def test_pdf_magic_bytes_constant(self):
|
||||
"""Test that PDF magic bytes constant is correctly defined."""
|
||||
from backend.web.api.v1.admin.documents import PDF_MAGIC_BYTES
|
||||
|
||||
assert PDF_MAGIC_BYTES == b"%PDF"
|
||||
|
||||
def test_validation_is_case_sensitive(self):
|
||||
"""Test that magic bytes validation is case-sensitive.
|
||||
|
||||
%pdf (lowercase) should fail - PDF magic bytes are uppercase.
|
||||
"""
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
|
||||
lowercase_pdf = b"%pdf-1.4\nfake content"
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_pdf_magic_bytes(lowercase_pdf)
|
||||
|
||||
assert "Invalid PDF file" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestDocumentUploadWithMagicBytesValidation:
|
||||
"""Integration tests for document upload with magic bytes validation."""
|
||||
|
||||
@pytest.fixture
|
||||
def storage_config(self, tmp_path):
|
||||
"""Create a StorageConfig for testing."""
|
||||
return StorageConfig(
|
||||
upload_dir=tmp_path / "uploads",
|
||||
result_dir=tmp_path / "results",
|
||||
max_file_size_mb=50,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def valid_pdf_content(self) -> bytes:
|
||||
"""Create valid PDF content."""
|
||||
return b"%PDF-1.4\n%\xe2\xe3\xcf\xd3\n1 0 obj\n<<>>\nendobj\ntrailer\n<<>>\n%%EOF"
|
||||
|
||||
@pytest.fixture
|
||||
def invalid_pdf_content(self) -> bytes:
|
||||
"""Create invalid PDF content (executable header)."""
|
||||
return b"MZ\x90\x00\x03\x00\x00\x00"
|
||||
|
||||
def test_upload_valid_pdf_succeeds(
|
||||
self, storage_config, valid_pdf_content
|
||||
):
|
||||
"""Test that uploading a valid PDF with correct magic bytes succeeds."""
|
||||
router = create_documents_router(storage_config)
|
||||
|
||||
# Find the upload endpoint (path includes prefix /admin/documents)
|
||||
upload_route = None
|
||||
for route in router.routes:
|
||||
if hasattr(route, 'methods') and 'POST' in route.methods:
|
||||
if route.path == "/admin/documents":
|
||||
upload_route = route
|
||||
break
|
||||
|
||||
assert upload_route is not None, "Upload route should exist"
|
||||
|
||||
# Validate that valid PDF content passes validation
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
validate_pdf_magic_bytes(valid_pdf_content) # Should not raise
|
||||
|
||||
def test_upload_invalid_pdf_returns_400(
|
||||
self, storage_config, invalid_pdf_content
|
||||
):
|
||||
"""Test that uploading an invalid PDF returns HTTP 400.
|
||||
|
||||
The error message should clearly indicate the PDF header is invalid.
|
||||
"""
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
|
||||
# Simulate what the upload endpoint should do
|
||||
try:
|
||||
validate_pdf_magic_bytes(invalid_pdf_content)
|
||||
pytest.fail("Should have raised ValueError for invalid PDF")
|
||||
except ValueError as e:
|
||||
# The endpoint should convert this to HTTP 400
|
||||
assert "Invalid PDF file" in str(e)
|
||||
assert "valid PDF header" in str(e)
|
||||
|
||||
def test_upload_empty_pdf_returns_400(self, storage_config):
|
||||
"""Test that uploading an empty file returns HTTP 400."""
|
||||
from backend.web.api.v1.admin.documents import validate_pdf_magic_bytes
|
||||
|
||||
empty_content = b""
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
validate_pdf_magic_bytes(empty_content)
|
||||
|
||||
assert "Invalid PDF file" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestNonPDFFileValidation:
|
||||
"""Tests to ensure non-PDF files are not affected by magic bytes validation."""
|
||||
|
||||
def test_png_files_skip_pdf_validation(self):
|
||||
"""Test that PNG files do not go through PDF magic bytes validation.
|
||||
|
||||
Only files with .pdf extension should be validated for PDF magic bytes.
|
||||
"""
|
||||
# PNG magic bytes
|
||||
png_content = b"\x89PNG\r\n\x1a\n"
|
||||
file_ext = ".png"
|
||||
|
||||
# PNG files should not be validated with PDF magic bytes check
|
||||
# The validation should only apply to .pdf files
|
||||
assert file_ext != ".pdf"
|
||||
|
||||
def test_jpg_files_skip_pdf_validation(self):
|
||||
"""Test that JPG files do not go through PDF magic bytes validation."""
|
||||
# JPEG magic bytes
|
||||
jpg_content = b"\xff\xd8\xff\xe0"
|
||||
file_ext = ".jpg"
|
||||
|
||||
assert file_ext != ".pdf"
|
||||
@@ -232,10 +232,8 @@ class TestInferenceServicePDFRendering:
|
||||
@patch('backend.pipeline.pipeline.InferencePipeline')
|
||||
@patch('backend.pipeline.yolo_detector.YOLODetector')
|
||||
@patch('shared.pdf.renderer.render_pdf_to_images')
|
||||
@patch('ultralytics.YOLO')
|
||||
def test_pdf_visualization_imports_correctly(
|
||||
self,
|
||||
mock_yolo_class,
|
||||
mock_render_pdf,
|
||||
mock_yolo_detector,
|
||||
mock_pipeline,
|
||||
@@ -248,12 +246,22 @@ class TestInferenceServicePDFRendering:
|
||||
This catches the import error we had with:
|
||||
from ..pdf.renderer (wrong) vs from shared.pdf.renderer (correct)
|
||||
"""
|
||||
# Setup mocks
|
||||
# Setup mocks for detector
|
||||
mock_detector_instance = Mock()
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_model = Mock()
|
||||
mock_result = Mock()
|
||||
mock_result.save = Mock()
|
||||
mock_model.predict.return_value = [mock_result]
|
||||
mock_detector_instance.model = mock_model
|
||||
mock_yolo_detector.return_value = mock_detector_instance
|
||||
|
||||
# Setup mock for pipeline
|
||||
mock_pipeline_instance = Mock()
|
||||
mock_pipeline.return_value = mock_pipeline_instance
|
||||
|
||||
# Initialize service to setup _detector
|
||||
inference_service.initialize()
|
||||
|
||||
# Create a fake PDF path
|
||||
pdf_path = tmp_path / "test.pdf"
|
||||
pdf_path.touch()
|
||||
@@ -264,18 +272,12 @@ class TestInferenceServicePDFRendering:
|
||||
img.save(image_bytes, format='PNG')
|
||||
mock_render_pdf.return_value = [(1, image_bytes.getvalue())]
|
||||
|
||||
# Mock YOLO
|
||||
mock_model_instance = Mock()
|
||||
mock_result = Mock()
|
||||
mock_result.save = Mock()
|
||||
mock_model_instance.predict.return_value = [mock_result]
|
||||
mock_yolo_class.return_value = mock_model_instance
|
||||
|
||||
# This should not raise ImportError
|
||||
# This should not raise ImportError and should use self._detector.model
|
||||
result_path = inference_service._save_pdf_visualization(pdf_path, "test123")
|
||||
|
||||
# Verify import was successful
|
||||
# Verify import was successful and detector.model was used
|
||||
mock_render_pdf.assert_called_once()
|
||||
mock_model.predict.assert_called_once()
|
||||
assert result_path is not None
|
||||
|
||||
|
||||
|
||||
367
tests/web/test_training_export.py
Normal file
367
tests/web/test_training_export.py
Normal file
@@ -0,0 +1,367 @@
|
||||
"""
|
||||
Tests for Training Export with expand_bbox integration.
|
||||
|
||||
Tests the export endpoint's integration with field-specific bbox expansion.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from shared.bbox import expand_bbox
|
||||
from shared.fields import CLASS_NAMES, FIELD_CLASS_IDS
|
||||
|
||||
|
||||
class TestExpandBboxForExport:
|
||||
"""Tests for expand_bbox integration in export workflow."""
|
||||
|
||||
def test_expand_bbox_converts_normalized_to_pixel_and_back(self):
|
||||
"""Verify expand_bbox works with pixel-to-normalized conversion."""
|
||||
# Annotation stored as normalized coords
|
||||
x_center_norm = 0.5
|
||||
y_center_norm = 0.5
|
||||
width_norm = 0.1
|
||||
height_norm = 0.05
|
||||
|
||||
# Image dimensions
|
||||
img_width = 2480 # A4 at 300 DPI
|
||||
img_height = 3508
|
||||
|
||||
# Convert to pixel coords
|
||||
x_center_px = x_center_norm * img_width
|
||||
y_center_px = y_center_norm * img_height
|
||||
width_px = width_norm * img_width
|
||||
height_px = height_norm * img_height
|
||||
|
||||
# Convert to corner coords
|
||||
x0 = x_center_px - width_px / 2
|
||||
y0 = y_center_px - height_px / 2
|
||||
x1 = x_center_px + width_px / 2
|
||||
y1 = y_center_px + height_px / 2
|
||||
|
||||
# Apply expansion
|
||||
class_name = "invoice_number"
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Verify expanded bbox is larger
|
||||
assert ex0 < x0 # Left expanded
|
||||
assert ey0 < y0 # Top expanded
|
||||
assert ex1 > x1 # Right expanded
|
||||
assert ey1 > y1 # Bottom expanded
|
||||
|
||||
# Convert back to normalized
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Verify valid normalized coords
|
||||
assert 0 <= new_x_center <= 1
|
||||
assert 0 <= new_y_center <= 1
|
||||
assert 0 <= new_width <= 1
|
||||
assert 0 <= new_height <= 1
|
||||
|
||||
def test_expand_bbox_manual_mode_minimal_expansion(self):
|
||||
"""Verify manual annotations use minimal expansion."""
|
||||
# Small bbox
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Auto mode (field-specific expansion)
|
||||
auto_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
manual_mode=False,
|
||||
)
|
||||
|
||||
# Manual mode (minimal expansion)
|
||||
manual_result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
manual_mode=True,
|
||||
)
|
||||
|
||||
# Auto expansion should be larger than manual
|
||||
auto_width = auto_result[2] - auto_result[0]
|
||||
manual_width = manual_result[2] - manual_result[0]
|
||||
assert auto_width > manual_width
|
||||
|
||||
auto_height = auto_result[3] - auto_result[1]
|
||||
manual_height = manual_result[3] - manual_result[1]
|
||||
assert auto_height > manual_height
|
||||
|
||||
def test_expand_bbox_different_sources_use_correct_mode(self):
|
||||
"""Verify different annotation sources use correct expansion mode."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Define source to manual_mode mapping
|
||||
source_mode_mapping = {
|
||||
"manual": True, # Manual annotations -> minimal expansion
|
||||
"auto": False, # Auto-labeled -> field-specific expansion
|
||||
"imported": True, # Imported (from CSV) -> minimal expansion
|
||||
}
|
||||
|
||||
results = {}
|
||||
for source, manual_mode in source_mode_mapping.items():
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="ocr_number",
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
results[source] = result
|
||||
|
||||
# Auto should have largest expansion
|
||||
auto_area = (results["auto"][2] - results["auto"][0]) * \
|
||||
(results["auto"][3] - results["auto"][1])
|
||||
manual_area = (results["manual"][2] - results["manual"][0]) * \
|
||||
(results["manual"][3] - results["manual"][1])
|
||||
imported_area = (results["imported"][2] - results["imported"][0]) * \
|
||||
(results["imported"][3] - results["imported"][1])
|
||||
|
||||
assert auto_area > manual_area
|
||||
assert auto_area > imported_area
|
||||
# Manual and imported should be the same (both use minimal mode)
|
||||
assert manual_area == imported_area
|
||||
|
||||
def test_expand_bbox_all_field_types_work(self):
|
||||
"""Verify expand_bbox works for all field types."""
|
||||
bbox = (100, 100, 200, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
for class_name in CLASS_NAMES:
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=class_name,
|
||||
)
|
||||
|
||||
# Verify result is a valid bbox
|
||||
assert len(result) == 4
|
||||
x0, y0, x1, y1 = result
|
||||
assert x0 >= 0
|
||||
assert y0 >= 0
|
||||
assert x1 <= img_width
|
||||
assert y1 <= img_height
|
||||
assert x1 > x0
|
||||
assert y1 > y0
|
||||
|
||||
|
||||
class TestExportAnnotationExpansion:
|
||||
"""Tests for annotation expansion in export workflow."""
|
||||
|
||||
def test_annotation_bbox_conversion_workflow(self):
|
||||
"""Test full annotation bbox conversion workflow."""
|
||||
# Simulate stored annotation (normalized coords)
|
||||
class MockAnnotation:
|
||||
class_id = FIELD_CLASS_IDS["invoice_number"]
|
||||
class_name = "invoice_number"
|
||||
x_center = 0.3
|
||||
y_center = 0.2
|
||||
width = 0.15
|
||||
height = 0.03
|
||||
source = "auto"
|
||||
|
||||
ann = MockAnnotation()
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
# Step 1: Convert normalized to pixel corner coords
|
||||
half_w = (ann.width * img_width) / 2
|
||||
half_h = (ann.height * img_height) / 2
|
||||
x0 = ann.x_center * img_width - half_w
|
||||
y0 = ann.y_center * img_height - half_h
|
||||
x1 = ann.x_center * img_width + half_w
|
||||
y1 = ann.y_center * img_height + half_h
|
||||
|
||||
# Step 2: Determine manual_mode based on source
|
||||
manual_mode = ann.source in ("manual", "imported")
|
||||
|
||||
# Step 3: Apply expand_bbox
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann.class_name,
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Step 4: Convert back to normalized
|
||||
new_x_center = (ex0 + ex1) / 2 / img_width
|
||||
new_y_center = (ey0 + ey1) / 2 / img_height
|
||||
new_width = (ex1 - ex0) / img_width
|
||||
new_height = (ey1 - ey0) / img_height
|
||||
|
||||
# Verify expansion happened (auto mode)
|
||||
assert new_width > ann.width
|
||||
assert new_height > ann.height
|
||||
|
||||
# Verify valid YOLO format
|
||||
assert 0 <= new_x_center <= 1
|
||||
assert 0 <= new_y_center <= 1
|
||||
assert 0 < new_width <= 1
|
||||
assert 0 < new_height <= 1
|
||||
|
||||
def test_export_applies_expansion_to_each_annotation(self):
|
||||
"""Test that export applies expansion to each annotation."""
|
||||
# Simulate multiple annotations with different sources
|
||||
# Use smaller bboxes so manual mode padding has visible effect
|
||||
annotations = [
|
||||
{"class_name": "invoice_number", "source": "auto", "x_center": 0.3, "y_center": 0.2, "width": 0.05, "height": 0.02},
|
||||
{"class_name": "ocr_number", "source": "manual", "x_center": 0.5, "y_center": 0.8, "width": 0.05, "height": 0.02},
|
||||
{"class_name": "amount", "source": "imported", "x_center": 0.7, "y_center": 0.5, "width": 0.05, "height": 0.02},
|
||||
]
|
||||
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
expanded_annotations = []
|
||||
for ann in annotations:
|
||||
# Convert to pixel coords
|
||||
half_w = (ann["width"] * img_width) / 2
|
||||
half_h = (ann["height"] * img_height) / 2
|
||||
x0 = ann["x_center"] * img_width - half_w
|
||||
y0 = ann["y_center"] * img_height - half_h
|
||||
x1 = ann["x_center"] * img_width + half_w
|
||||
y1 = ann["y_center"] * img_height + half_h
|
||||
|
||||
# Determine manual_mode
|
||||
manual_mode = ann["source"] in ("manual", "imported")
|
||||
|
||||
# Apply expansion
|
||||
ex0, ey0, ex1, ey1 = expand_bbox(
|
||||
bbox=(x0, y0, x1, y1),
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type=ann["class_name"],
|
||||
manual_mode=manual_mode,
|
||||
)
|
||||
|
||||
# Convert back to normalized
|
||||
expanded_annotations.append({
|
||||
"class_name": ann["class_name"],
|
||||
"source": ann["source"],
|
||||
"x_center": (ex0 + ex1) / 2 / img_width,
|
||||
"y_center": (ey0 + ey1) / 2 / img_height,
|
||||
"width": (ex1 - ex0) / img_width,
|
||||
"height": (ey1 - ey0) / img_height,
|
||||
})
|
||||
|
||||
# Verify auto-labeled annotation expanded more than manual/imported
|
||||
auto_ann = next(a for a in expanded_annotations if a["source"] == "auto")
|
||||
manual_ann = next(a for a in expanded_annotations if a["source"] == "manual")
|
||||
|
||||
# Auto mode should expand more than manual mode
|
||||
# (auto has larger scale factors and max_pad)
|
||||
assert auto_ann["width"] > manual_ann["width"]
|
||||
assert auto_ann["height"] > manual_ann["height"]
|
||||
|
||||
# All annotations should be expanded (at least slightly for manual mode)
|
||||
# Allow small precision loss (< 1%) due to integer conversion in expand_bbox
|
||||
for i, (orig, exp) in enumerate(zip(annotations, expanded_annotations)):
|
||||
# Width and height should be >= original (expansion or equal, with small tolerance)
|
||||
tolerance = 0.01 # 1% tolerance for integer rounding
|
||||
assert exp["width"] >= orig["width"] * (1 - tolerance), \
|
||||
f"Annotation {i} width unexpectedly smaller: {exp['width']} < {orig['width']}"
|
||||
assert exp["height"] >= orig["height"] * (1 - tolerance), \
|
||||
f"Annotation {i} height unexpectedly smaller: {exp['height']} < {orig['height']}"
|
||||
|
||||
|
||||
class TestExpandBboxEdgeCases:
|
||||
"""Tests for edge cases in export bbox expansion."""
|
||||
|
||||
def test_bbox_at_image_edge_left(self):
|
||||
"""Test bbox at left edge of image."""
|
||||
bbox = (0, 100, 50, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Left edge should be clamped to 0
|
||||
assert result[0] >= 0
|
||||
|
||||
def test_bbox_at_image_edge_right(self):
|
||||
"""Test bbox at right edge of image."""
|
||||
bbox = (2400, 100, 2480, 150)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Right edge should be clamped to image width
|
||||
assert result[2] <= img_width
|
||||
|
||||
def test_bbox_at_image_edge_top(self):
|
||||
"""Test bbox at top edge of image."""
|
||||
bbox = (100, 0, 200, 50)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Top edge should be clamped to 0
|
||||
assert result[1] >= 0
|
||||
|
||||
def test_bbox_at_image_edge_bottom(self):
|
||||
"""Test bbox at bottom edge of image."""
|
||||
bbox = (100, 3400, 200, 3508)
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Bottom edge should be clamped to image height
|
||||
assert result[3] <= img_height
|
||||
|
||||
def test_very_small_bbox(self):
|
||||
"""Test very small bbox gets expanded."""
|
||||
bbox = (100, 100, 105, 105) # 5x5 pixel bbox
|
||||
img_width = 2480
|
||||
img_height = 3508
|
||||
|
||||
result = expand_bbox(
|
||||
bbox=bbox,
|
||||
image_width=img_width,
|
||||
image_height=img_height,
|
||||
field_type="invoice_number",
|
||||
)
|
||||
|
||||
# Should still produce a valid expanded bbox
|
||||
assert result[2] > result[0]
|
||||
assert result[3] > result[1]
|
||||
Reference in New Issue
Block a user