"""
|
|
LLM-based cross-validation for invoice field extraction.
|
|
|
|
Uses a vision LLM to extract fields from invoice PDFs and compare with
|
|
the autolabel results to identify potential errors.
|
|
"""
|
|
|
|
import json
|
|
import base64
|
|
import os
|
|
from pathlib import Path
|
|
from typing import Optional, Dict, Any, List
|
|
from dataclasses import dataclass, asdict
|
|
from datetime import datetime
|
|
|
|
import psycopg2
|
|
from psycopg2.extras import execute_values
|
|
|
|
from shared.config import DEFAULT_DPI


@dataclass
class LLMExtractionResult:
    """Result of LLM field extraction."""
    document_id: str
    invoice_number: Optional[str] = None
    invoice_date: Optional[str] = None
    invoice_due_date: Optional[str] = None
    ocr_number: Optional[str] = None
    bankgiro: Optional[str] = None
    plusgiro: Optional[str] = None
    amount: Optional[str] = None
    supplier_organisation_number: Optional[str] = None
    raw_response: Optional[str] = None
    model_used: Optional[str] = None
    processing_time_ms: Optional[float] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)


class LLMValidator:
    """
    Cross-validates invoice field extraction using an LLM.

    Queries documents with failed field matches from the database,
    sends the PDF images to an LLM for extraction, and stores
    the results for comparison.
    """

    # Fields to extract (customer_number is deliberately excluded)
    FIELDS_TO_EXTRACT = [
        'InvoiceNumber',
        'InvoiceDate',
        'InvoiceDueDate',
        'OCR',
        'Bankgiro',
        'Plusgiro',
        'Amount',
        'supplier_organisation_number',
    ]

    EXTRACTION_PROMPT = """You are an expert at extracting structured data from Swedish invoices.

Analyze this invoice image and extract the following fields. Return ONLY a valid JSON object with these exact keys:

{
  "invoice_number": "the invoice number/fakturanummer",
  "invoice_date": "the invoice date in YYYY-MM-DD format",
  "invoice_due_date": "the due date/förfallodatum in YYYY-MM-DD format",
  "ocr_number": "the OCR payment reference number",
  "bankgiro": "the bankgiro number (format: XXXX-XXXX or XXXXXXXX)",
  "plusgiro": "the plusgiro number",
  "amount": "the total amount to pay (just the number, e.g., 1234.56)",
  "supplier_organisation_number": "the supplier's organisation number (format: XXXXXX-XXXX)"
}

Rules:
- If a field is not found or not visible, use null
- For dates, convert Swedish month names (januari, februari, etc.) to YYYY-MM-DD
- For amounts, extract just the numeric value without currency symbols
- The OCR number is typically a long number used for payment reference
- Look for "Att betala" or "Summa att betala" for the amount
- The organisation number is 10 digits, often shown as XXXXXX-XXXX

Return ONLY the JSON object, no other text."""

    def __init__(self, connection_string: Optional[str] = None, pdf_dir: Optional[str] = None):
        """
        Initialize the validator.

        Args:
            connection_string: PostgreSQL connection string
            pdf_dir: Directory containing PDF files
        """
        import sys
        sys.path.insert(0, str(Path(__file__).parent.parent.parent))
        from config import get_db_connection_string, PATHS

        self.connection_string = connection_string or get_db_connection_string()
        self.pdf_dir = Path(pdf_dir or PATHS['pdf_dir'])
        self.conn = None

    def connect(self):
        """Connect to the database, reusing an open connection if one exists."""
        if self.conn is None:
            self.conn = psycopg2.connect(self.connection_string)
        return self.conn

    def close(self):
        """Close the database connection."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def create_validation_table(self):
        """Create the llm_validations table if it doesn't exist."""
        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS llm_validations (
                    id SERIAL PRIMARY KEY,
                    document_id TEXT NOT NULL,
                    -- Extracted fields
                    invoice_number TEXT,
                    invoice_date TEXT,
                    invoice_due_date TEXT,
                    ocr_number TEXT,
                    bankgiro TEXT,
                    plusgiro TEXT,
                    amount TEXT,
                    supplier_organisation_number TEXT,
                    -- Metadata
                    raw_response TEXT,
                    model_used TEXT,
                    processing_time_ms REAL,
                    error TEXT,
                    created_at TIMESTAMPTZ DEFAULT NOW(),
                    -- Comparison results (populated later)
                    comparison_results JSONB,
                    UNIQUE(document_id)
                );

                CREATE INDEX IF NOT EXISTS idx_llm_validations_document_id
                    ON llm_validations(document_id);
            """)
            conn.commit()

    def get_documents_with_failed_matches(
        self,
        exclude_customer_number: bool = True,
        limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Get documents that have at least one failed field match.

        Args:
            exclude_customer_number: If True, ignore customer_number failures
            limit: Maximum number of documents to return

        Returns:
            List of document info with failed fields
        """
        conn = self.connect()
        with conn.cursor() as cursor:
            # Find documents with failed matches (excluding customer_number if requested)
            exclude_clause = ""
            if exclude_customer_number:
                exclude_clause = "AND fr.field_name != 'customer_number'"

            query = f"""
                SELECT DISTINCT d.document_id, d.pdf_path, d.pdf_type,
                       d.supplier_name, d.split
                FROM documents d
                JOIN field_results fr ON d.document_id = fr.document_id
                WHERE fr.matched = false
                  AND fr.field_name NOT LIKE 'supplier_accounts%%'
                  {exclude_clause}
                  AND d.document_id NOT IN (
                      SELECT document_id FROM llm_validations WHERE error IS NULL
                  )
                ORDER BY d.document_id
            """
            if limit:
                query += f" LIMIT {int(limit)}"  # int() guards against injection via limit

            cursor.execute(query)

            results = []
            for row in cursor.fetchall():
                doc_id = row[0]

                # Get failed fields for this document
                exclude_clause_inner = ""
                if exclude_customer_number:
                    exclude_clause_inner = "AND field_name != 'customer_number'"
                cursor.execute(f"""
                    SELECT field_name, csv_value, score
                    FROM field_results
                    WHERE document_id = %s
                      AND matched = false
                      AND field_name NOT LIKE 'supplier_accounts%%'
                      {exclude_clause_inner}
                """, (doc_id,))

                failed_fields = [
                    {'field': r[0], 'csv_value': r[1], 'score': r[2]}
                    for r in cursor.fetchall()
                ]

                results.append({
                    'document_id': doc_id,
                    'pdf_path': row[1],
                    'pdf_type': row[2],
                    'supplier_name': row[3],
                    'split': row[4],
                    'failed_fields': failed_fields,
                })

            return results
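
    # Illustrative shape of one returned entry (hypothetical values, for
    # reference only):
    #
    #     {
    #         'document_id': 'a1b2c3d4',
    #         'pdf_path': '/data/pdfs/a1b2c3d4.pdf',
    #         'pdf_type': 'native',
    #         'supplier_name': 'Example AB',
    #         'split': 'train',
    #         'failed_fields': [
    #             {'field': 'Amount', 'csv_value': '1234.56', 'score': 0.42},
    #         ],
    #     }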

    def get_failed_match_stats(self, exclude_customer_number: bool = True) -> Dict[str, Any]:
        """Get statistics about failed matches."""
        conn = self.connect()
        with conn.cursor() as cursor:
            exclude_clause = ""
            if exclude_customer_number:
                exclude_clause = "AND field_name != 'customer_number'"

            # Count by field
            cursor.execute(f"""
                SELECT field_name, COUNT(*) as cnt
                FROM field_results
                WHERE matched = false
                  AND field_name NOT LIKE 'supplier_accounts%%'
                  {exclude_clause}
                GROUP BY field_name
                ORDER BY cnt DESC
            """)
            by_field = {row[0]: row[1] for row in cursor.fetchall()}

            # Count documents with failures
            cursor.execute(f"""
                SELECT COUNT(DISTINCT document_id)
                FROM field_results
                WHERE matched = false
                  AND field_name NOT LIKE 'supplier_accounts%%'
                  {exclude_clause}
            """)
            doc_count = cursor.fetchone()[0]

            # Already validated count
            cursor.execute("""
                SELECT COUNT(*) FROM llm_validations WHERE error IS NULL
            """)
            validated_count = cursor.fetchone()[0]

            return {
                'documents_with_failures': doc_count,
                'already_validated': validated_count,
                'remaining': doc_count - validated_count,
                'failures_by_field': by_field,
            }
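
    # Illustrative return value (hypothetical counts):
    #
    #     {
    #         'documents_with_failures': 120,
    #         'already_validated': 45,
    #         'remaining': 75,
    #         'failures_by_field': {'Amount': 60, 'OCR': 38, 'Bankgiro': 22},
    #     }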

    def render_pdf_to_image(
        self,
        pdf_path: Path,
        page_no: int = 0,
        dpi: int = DEFAULT_DPI,
        max_size_mb: float = 18.0
    ) -> bytes:
        """
        Render a PDF page to PNG (or, as a fallback, JPEG) image bytes.

        Args:
            pdf_path: Path to PDF file
            page_no: Page number to render (0-indexed)
            dpi: Resolution for rendering
            max_size_mb: Maximum image size in MB (Azure OpenAI limit is 20MB)

        Returns:
            PNG or JPEG image bytes
        """
        import fitz  # PyMuPDF
        from io import BytesIO
        from PIL import Image

        doc = fitz.open(pdf_path)
        page = doc[page_no]

        # Try successively lower DPI values until the image is small enough
        dpi_values = [dpi, 120, 100, 72, 50]

        for current_dpi in dpi_values:
            mat = fitz.Matrix(current_dpi / 72, current_dpi / 72)
            pix = page.get_pixmap(matrix=mat)
            png_bytes = pix.tobytes("png")

            size_mb = len(png_bytes) / (1024 * 1024)
            if size_mb <= max_size_mb:
                doc.close()
                return png_bytes

        # If still too large, fall back to JPEG compression at the lowest DPI
        mat = fitz.Matrix(72 / 72, 72 / 72)
        pix = page.get_pixmap(matrix=mat)

        # Convert to PIL Image and compress as JPEG
        img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

        # Try progressively lower JPEG quality levels
        for quality in [85, 70, 50, 30]:
            buffer = BytesIO()
            img.save(buffer, format="JPEG", quality=quality)
            jpeg_bytes = buffer.getvalue()

            size_mb = len(jpeg_bytes) / (1024 * 1024)
            if size_mb <= max_size_mb:
                doc.close()
                return jpeg_bytes

        doc.close()
        # Return the smallest attempt and let the API report an error if it
        # is still too large
        return jpeg_bytes

    def extract_with_openai(
        self,
        image_bytes: bytes,
        model: str = "gpt-4o"
    ) -> LLMExtractionResult:
        """
        Extract fields using OpenAI's vision API (supports Azure OpenAI).

        Args:
            image_bytes: PNG or JPEG image bytes
            model: Model to use (gpt-4o, gpt-4o-mini, etc.)

        Returns:
            Extraction result
        """
        import openai
        import time

        start_time = time.time()

        # Encode image to base64 and detect format
        image_b64 = base64.b64encode(image_bytes).decode('utf-8')

        # Detect image format (PNG starts with \x89PNG, JPEG with \xFF\xD8)
        if image_bytes[:4] == b'\x89PNG':
            media_type = "image/png"
        else:
            media_type = "image/jpeg"

        # Check for Azure OpenAI configuration
        azure_endpoint = os.environ.get('AZURE_OPENAI_ENDPOINT')
        azure_api_key = os.environ.get('AZURE_OPENAI_API_KEY')
        azure_deployment = os.environ.get('AZURE_OPENAI_DEPLOYMENT', model)

        if azure_endpoint and azure_api_key:
            # Use Azure OpenAI
            client = openai.AzureOpenAI(
                azure_endpoint=azure_endpoint,
                api_key=azure_api_key,
                api_version="2024-02-15-preview"
            )
            model = azure_deployment  # Use the deployment name for Azure
        else:
            # Use standard OpenAI
            client = openai.OpenAI()

        raw_response = None
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": self.EXTRACTION_PROMPT},
                            {
                                "type": "image_url",
                                "image_url": {
                                    "url": f"data:{media_type};base64,{image_b64}",
                                    "detail": "high"
                                }
                            }
                        ]
                    }
                ],
                max_tokens=1000,
                temperature=0,
            )

            raw_response = response.choices[0].message.content
            processing_time = (time.time() - start_time) * 1000

            # Parse the JSON response; the model may wrap it in markdown code fences
            json_str = raw_response
            if "```json" in json_str:
                json_str = json_str.split("```json")[1].split("```")[0]
            elif "```" in json_str:
                json_str = json_str.split("```")[1].split("```")[0]

            data = json.loads(json_str.strip())

            return LLMExtractionResult(
                document_id="",  # Will be set by the caller
                invoice_number=data.get('invoice_number'),
                invoice_date=data.get('invoice_date'),
                invoice_due_date=data.get('invoice_due_date'),
                ocr_number=data.get('ocr_number'),
                bankgiro=data.get('bankgiro'),
                plusgiro=data.get('plusgiro'),
                amount=data.get('amount'),
                supplier_organisation_number=data.get('supplier_organisation_number'),
                raw_response=raw_response,
                model_used=model,
                processing_time_ms=processing_time,
            )

        except json.JSONDecodeError as e:
            return LLMExtractionResult(
                document_id="",
                raw_response=raw_response,
                model_used=model,
                processing_time_ms=(time.time() - start_time) * 1000,
                error=f"JSON parse error: {str(e)}"
            )
        except Exception as e:
            return LLMExtractionResult(
                document_id="",
                model_used=model,
                processing_time_ms=(time.time() - start_time) * 1000,
                error=str(e)
            )

    def extract_with_anthropic(
        self,
        image_bytes: bytes,
        model: str = "claude-sonnet-4-20250514"
    ) -> LLMExtractionResult:
        """
        Extract fields using Anthropic's Claude API.

        Args:
            image_bytes: PNG or JPEG image bytes
            model: Model to use

        Returns:
            Extraction result
        """
        import anthropic
        import time

        start_time = time.time()

        # Encode image to base64 and detect format (the renderer may have
        # fallen back to JPEG)
        image_b64 = base64.b64encode(image_bytes).decode('utf-8')
        if image_bytes[:4] == b'\x89PNG':
            media_type = "image/png"
        else:
            media_type = "image/jpeg"

        client = anthropic.Anthropic()

        raw_response = None
        try:
            response = client.messages.create(
                model=model,
                max_tokens=1000,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": media_type,
                                    "data": image_b64,
                                }
                            },
                            {
                                "type": "text",
                                "text": self.EXTRACTION_PROMPT
                            }
                        ]
                    }
                ],
            )

            raw_response = response.content[0].text
            processing_time = (time.time() - start_time) * 1000

            # Parse the JSON response; the model may wrap it in markdown code fences
            json_str = raw_response
            if "```json" in json_str:
                json_str = json_str.split("```json")[1].split("```")[0]
            elif "```" in json_str:
                json_str = json_str.split("```")[1].split("```")[0]

            data = json.loads(json_str.strip())

            return LLMExtractionResult(
                document_id="",
                invoice_number=data.get('invoice_number'),
                invoice_date=data.get('invoice_date'),
                invoice_due_date=data.get('invoice_due_date'),
                ocr_number=data.get('ocr_number'),
                bankgiro=data.get('bankgiro'),
                plusgiro=data.get('plusgiro'),
                amount=data.get('amount'),
                supplier_organisation_number=data.get('supplier_organisation_number'),
                raw_response=raw_response,
                model_used=model,
                processing_time_ms=processing_time,
            )

        except json.JSONDecodeError as e:
            return LLMExtractionResult(
                document_id="",
                raw_response=raw_response,
                model_used=model,
                processing_time_ms=(time.time() - start_time) * 1000,
                error=f"JSON parse error: {str(e)}"
            )
        except Exception as e:
            return LLMExtractionResult(
                document_id="",
                model_used=model,
                processing_time_ms=(time.time() - start_time) * 1000,
                error=str(e)
            )

    def save_validation_result(self, result: LLMExtractionResult):
        """Save an extraction result to the database (upsert on document_id)."""
        conn = self.connect()
        with conn.cursor() as cursor:
            cursor.execute("""
                INSERT INTO llm_validations (
                    document_id, invoice_number, invoice_date, invoice_due_date,
                    ocr_number, bankgiro, plusgiro, amount,
                    supplier_organisation_number, raw_response, model_used,
                    processing_time_ms, error
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                ON CONFLICT (document_id) DO UPDATE SET
                    invoice_number = EXCLUDED.invoice_number,
                    invoice_date = EXCLUDED.invoice_date,
                    invoice_due_date = EXCLUDED.invoice_due_date,
                    ocr_number = EXCLUDED.ocr_number,
                    bankgiro = EXCLUDED.bankgiro,
                    plusgiro = EXCLUDED.plusgiro,
                    amount = EXCLUDED.amount,
                    supplier_organisation_number = EXCLUDED.supplier_organisation_number,
                    raw_response = EXCLUDED.raw_response,
                    model_used = EXCLUDED.model_used,
                    processing_time_ms = EXCLUDED.processing_time_ms,
                    error = EXCLUDED.error,
                    created_at = NOW()
            """, (
                result.document_id,
                result.invoice_number,
                result.invoice_date,
                result.invoice_due_date,
                result.ocr_number,
                result.bankgiro,
                result.plusgiro,
                result.amount,
                result.supplier_organisation_number,
                result.raw_response,
                result.model_used,
                result.processing_time_ms,
                result.error,
            ))
            conn.commit()

    def validate_document(
        self,
        doc_id: str,
        provider: str = "openai",
        model: Optional[str] = None
    ) -> LLMExtractionResult:
        """
        Validate a single document using an LLM.

        Args:
            doc_id: Document ID
            provider: LLM provider ("openai" or "anthropic")
            model: Model to use (defaults based on provider)

        Returns:
            Extraction result
        """
        # Get PDF path
        pdf_path = self.pdf_dir / f"{doc_id}.pdf"
        if not pdf_path.exists():
            return LLMExtractionResult(
                document_id=doc_id,
                error=f"PDF not found: {pdf_path}"
            )

        # Render the first page
        try:
            image_bytes = self.render_pdf_to_image(pdf_path, page_no=0)
        except Exception as e:
            return LLMExtractionResult(
                document_id=doc_id,
                error=f"Failed to render PDF: {str(e)}"
            )

        # Extract with the LLM
        if provider == "openai":
            model = model or "gpt-4o"
            result = self.extract_with_openai(image_bytes, model)
        elif provider == "anthropic":
            model = model or "claude-sonnet-4-20250514"
            result = self.extract_with_anthropic(image_bytes, model)
        else:
            return LLMExtractionResult(
                document_id=doc_id,
                error=f"Unknown provider: {provider}"
            )

        result.document_id = doc_id

        # Save to database
        self.save_validation_result(result)

        return result

    def validate_batch(
        self,
        limit: int = 10,
        provider: str = "openai",
        model: Optional[str] = None,
        verbose: bool = True
    ) -> List[LLMExtractionResult]:
        """
        Validate a batch of documents with failed matches.

        Args:
            limit: Maximum number of documents to validate
            provider: LLM provider
            model: Model to use
            verbose: Print progress

        Returns:
            List of extraction results
        """
        # Get documents to validate
        docs = self.get_documents_with_failed_matches(limit=limit)

        if verbose:
            print(f"Found {len(docs)} documents with failed matches to validate")

        results = []
        for i, doc in enumerate(docs):
            doc_id = doc['document_id']

            if verbose:
                failed_fields = [f['field'] for f in doc['failed_fields']]
                print(f"[{i+1}/{len(docs)}] Validating {doc_id[:8]}... (failed: {', '.join(failed_fields)})")

            result = self.validate_document(doc_id, provider, model)
            results.append(result)

            if verbose:
                if result.error:
                    print(f"  ERROR: {result.error}")
                else:
                    print(f"  OK ({result.processing_time_ms:.0f}ms)")

        return results

    def compare_results(self, doc_id: str) -> Dict[str, Any]:
        """
        Compare LLM extraction with autolabel results.

        Args:
            doc_id: Document ID

        Returns:
            Per-field comparison results
        """
        conn = self.connect()
        with conn.cursor() as cursor:
            # Get autolabel results
            cursor.execute("""
                SELECT field_name, csv_value, matched, matched_text
                FROM field_results
                WHERE document_id = %s
            """, (doc_id,))

            autolabel = {}
            for row in cursor.fetchall():
                autolabel[row[0]] = {
                    'csv_value': row[1],
                    'matched': row[2],
                    'matched_text': row[3],
                }

            # Get LLM results
            cursor.execute("""
                SELECT invoice_number, invoice_date, invoice_due_date,
                       ocr_number, bankgiro, plusgiro, amount,
                       supplier_organisation_number
                FROM llm_validations
                WHERE document_id = %s
            """, (doc_id,))

            row = cursor.fetchone()
            if not row:
                return {'error': 'No LLM validation found'}

            llm = {
                'InvoiceNumber': row[0],
                'InvoiceDate': row[1],
                'InvoiceDueDate': row[2],
                'OCR': row[3],
                'Bankgiro': row[4],
                'Plusgiro': row[5],
                'Amount': row[6],
                'supplier_organisation_number': row[7],
            }

            # Compare field by field
            comparison = {}
            for field in self.FIELDS_TO_EXTRACT:
                auto = autolabel.get(field, {})
                llm_value = llm.get(field)

                comparison[field] = {
                    'csv_value': auto.get('csv_value'),
                    'autolabel_matched': auto.get('matched'),
                    'autolabel_text': auto.get('matched_text'),
                    'llm_value': llm_value,
                    'agreement': self._values_match(auto.get('csv_value'), llm_value),
                }

            return comparison
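
    # Illustrative shape of one comparison entry (hypothetical values):
    #
    #     {'Amount': {'csv_value': '1234.56', 'autolabel_matched': False,
    #                 'autolabel_text': None, 'llm_value': '1234.56',
    #                 'agreement': True}}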

    def _values_match(self, csv_value: Optional[str], llm_value: Optional[str]) -> bool:
        """Check whether the CSV value matches the LLM-extracted value."""
        if csv_value is None or llm_value is None:
            return csv_value == llm_value

        # Normalize for comparison: case, dashes, and spaces are ignored
        csv_norm = str(csv_value).strip().lower().replace('-', '').replace(' ', '')
        llm_norm = str(llm_value).strip().lower().replace('-', '').replace(' ', '')

        return csv_norm == llm_norm
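

# Minimal usage sketch (illustrative, not part of the module's public API;
# assumes database credentials and either OPENAI_API_KEY or the Azure OpenAI
# environment variables are configured):
if __name__ == "__main__":
    validator = LLMValidator()
    try:
        validator.create_validation_table()

        # Show how much work remains before spending LLM calls
        stats = validator.get_failed_match_stats()
        print(f"{stats['remaining']} documents still need validation")

        # Validate a small batch and print per-field disagreements
        for result in validator.validate_batch(limit=5, provider="openai"):
            if result.error:
                continue
            comparison = validator.compare_results(result.document_id)
            if 'error' in comparison:
                continue
            for field, info in comparison.items():
                if not info.get('agreement'):
                    print(f"{result.document_id} {field}: "
                          f"csv={info['csv_value']!r} llm={info['llm_value']!r}")
    finally:
        validator.close()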