Compare commits

3 Commits

Author SHA1 Message Date
02290bb935 WIP 2025-08-13 23:30:22 +02:00
87ba009bd7 Vector. 2025-08-11 21:38:25 +02:00
f87834a1b3 Add vscode debug config 2025-08-11 16:58:45 +02:00
8 changed files with 50 additions and 53 deletions

16
.vscode/launch.json vendored Normal file
View File

@@ -0,0 +1,16 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Python: FastAPI",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"args": [
"app.main:app",
"--reload"
],
"justMyCode": true
}
]
}

View File

@@ -15,11 +15,14 @@ The document's primary language is '{language}'.
## Instructions: ## Instructions:
Carefully analyze the invoice image and extract the following fields according to these specific rules. Do not invent information. If a field is not found or is unclear, follow the specific instruction for that field. Carefully analyze the invoice image and extract the following fields according to these specific rules. Do not invent information. If a field is not found or is unclear, follow the specific instruction for that field.
- `date`: Extract in YYYY-MM-DD format. If unclear, leave as an empty string. - `invoice_date`: The invoice date. Extract in YYYY-MM-DD format. If unclear, leave as an empty string.
- `invoice_due_date`: The invoice due date. Extract in YYYY-MM-DD format. If unclear, leave as an empty string.
- `invoice_number`: If not found or unclear, leave as an empty string. - `invoice_number`: If not found or unclear, leave as an empty string.
- `ocr_number`: The OCR number from the invoice. If not found or unclear, leave as an empty string.
- `supplier_number`: This is the organisation number. If not found or unclear, leave as an empty string. - `supplier_number`: This is the organisation number. If not found or unclear, leave as an empty string.
- `biller_name`: This is the sender's name. If not found or unclear, leave as an empty string. - `biller_name`: This is the sender's name. If not found or unclear, leave as an empty string.
- `amount`: Extract the final total amount and format it to a decimal number. If not present, leave as null. - `amount`: Extract the final total amount and format it to a decimal number. If not present, leave as null.
- `tax_exclusive_amount`: Extract the amount excluding taxes and format it to a decimal number. If not present, leave as null.
- `customer_name`: This is the receiver's name. Ensure it is a name and clear any special characters. If not found or unclear, leave as an empty string. - `customer_name`: This is the receiver's name. Ensure it is a name and clear any special characters. If not found or unclear, leave as an empty string.
- `customer_address`: This is the receiver's full address. Put it in one line. If not found or unclear, leave as an empty string. - `customer_address`: This is the receiver's full address. Put it in one line. If not found or unclear, leave as an empty string.
- `customer_address_line`: This is only the street address line from the receiver's address. If not found or unclear, leave as an empty string. - `customer_address_line`: This is only the street address line from the receiver's address. If not found or unclear, leave as an empty string.
@@ -33,7 +36,7 @@ Carefully analyze the invoice image and extract the following fields according t
- `bank_giro`: If found, extract the bank giro number. It often follows patterns like 'ddd-dddd', 'dddd-dddd', or 'dddddddd #41#'. If not found or unclear, leave as an empty string. - `bank_giro`: If found, extract the bank giro number. It often follows patterns like 'ddd-dddd', 'dddd-dddd', or 'dddddddd #41#'. If not found or unclear, leave as an empty string.
- `plus_giro`: If found, extract the plus giro number. It often follows patterns like 'ddddddd-d #16#', 'ddddddd-d', or 'ddd dd dd-d'. If not found or unclear, leave as an empty string. - `plus_giro`: If found, extract the plus giro number. It often follows patterns like 'ddddddd-d #16#', 'ddddddd-d', or 'ddd dd dd-d'. If not found or unclear, leave as an empty string.
- `customer_ssn`: If found, extract the customer social security number (personnummer). It follows the pattern 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'. If not found or unclear, leave as an empty string. - `customer_ssn`: If found, extract the customer social security number (personnummer). It follows the pattern 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'. If not found or unclear, leave as an empty string.
- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null. - `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. Make sure every line item on the invoice is extracted. If a value is not present, leave it as null.
## Example: ## Example:
If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $200.00", the output for that line item should be: If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $200.00", the output for that line item should be:

View File

@@ -1,29 +1,25 @@
# app/agents/vectorization_agent.py # app/agents/vectorization_agent.py
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
import chromadb
client = chromadb.PersistentClient(path="./chroma_db")
vector_store = client.get_or_create_collection(name="documents")
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, chunk_size=1000,
chunk_overlap=100, chunk_overlap=100,
) )
def agent_vectorize_and_store(doc_id: str, text: str, category: str, language: str): def agent_vectorize_and_store(
""" doc_id: str,
Agent 4: Vectorizes a document and stores it in ChromaDB. text: str,
""" category: str,
language: str,
embedding_model,
vector_store
):
print(f"--- [Background Task] Starting vectorization (ID: {doc_id})...") print(f"--- [Background Task] Starting vectorization (ID: {doc_id})...")
try: try:
return
chunks = text_splitter.split_text(text) chunks = text_splitter.split_text(text)
if not chunks: if not chunks:
print(f"--- [Background Task] document {doc_id} has no text to vectorize.") print(f"--- [Background task] document is empty, skip vectorization. (ID: {doc_id})")
return return
chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))] chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
@@ -38,6 +34,6 @@ def agent_vectorize_and_store(doc_id: str, text: str, category: str, language: s
documents=chunks, documents=chunks,
metadatas=metadatas metadatas=metadatas
) )
print(f"--- [Background Task] Document {doc_id} vectorized and stored successfully.") print(f"--- [Background Task] Document {doc_id} vectorized")
except Exception as e: except Exception as e:
print(f"--- [background Task] Vectorization failed (ID: {doc_id}): {e}") print(f"--- [Background Task] Document vectorization failed (ID: {doc_id}): {e}")

View File

@@ -4,20 +4,16 @@ from typing import List
def extract_text_from_images(images: List[Image.Image]) -> str: def extract_text_from_images(images: List[Image.Image]) -> str:
""" print("--- [Core OCR] Extracting text...")
使用Tesseract OCR从一系列图片中提取并合并所有文本。
"""
print("--- [Core OCR] 正在从图片中提取文本用于向量化...")
full_text = [] full_text = []
for img in images: for img in images:
try: try:
# lang='chi_sim+eng' 表示同时识别简体中文和英文
text = pytesseract.image_to_string(img, lang='chi_sim+eng') text = pytesseract.image_to_string(img, lang='chi_sim+eng')
full_text.append(text) full_text.append(text)
except Exception as e: except Exception as e:
print(f"--- [Core OCR] 单页处理失败: {e}") print(f"--- [Core OCR] Processing image failed: {e}")
continue continue
combined_text = "\n\n--- Page Break ---\n\n".join(full_text) combined_text = "\n\n--- Page Break ---\n\n".join(full_text)
print("--- [Core OCR] 文本提取成功。") print("--- [Core OCR] Text extraction completed.")
return combined_text return combined_text

View File

@@ -1,47 +1,28 @@
# app/core/vector_store.py
import os import os
import chromadb import chromadb
from dotenv import load_dotenv from dotenv import load_dotenv
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
load_dotenv() load_dotenv()
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower() LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
embedding_model = None embedding_model = None
print(f"--- [Core] Initializing Embeddings with provider: {LLM_PROVIDER} ---") print(f"--- [Core] Initializing Embeddings with provider: {LLM_PROVIDER} ---")
if LLM_PROVIDER == "azure": if LLM_PROVIDER == "azure":
required_vars = [
"AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY",
"OPENAI_API_VERSION", "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"
]
if not all(os.getenv(var) for var in required_vars):
raise ValueError("One or more Azure OpenAI environment variables for embeddings are not set.")
embedding_model = AzureOpenAIEmbeddings( embedding_model = AzureOpenAIEmbeddings(
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("OPENAI_API_VERSION"), api_version=os.getenv("OPENAI_API_VERSION"),
azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"), azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"),
) )
elif LLM_PROVIDER == "openai": elif LLM_PROVIDER == "openai":
if not os.getenv("OPENAI_API_KEY"):
raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.")
embedding_model = OpenAIEmbeddings( embedding_model = OpenAIEmbeddings(
api_key=os.getenv("OPENAI_API_KEY"), api_key=os.getenv("OPENAI_API_KEY"),
model=os.getenv("OPENAI_EMBEDDING_MODEL_NAME", "text-embedding-3-small") model=os.getenv("OPENAI_EMBEDDING_MODEL_NAME", "text-embedding-3-small")
) )
else: else:
raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.") raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}.")
client = chromadb.PersistentClient(path="./chroma_db") client = chromadb.PersistentClient(path="./chroma_db")
vector_store = client.get_or_create_collection( vector_store = client.get_or_create_collection(name="documents")
name="documents",
metadata={"hnsw:space": "cosine"}
)

View File

@@ -5,10 +5,10 @@ from typing import Dict, Any, List
from fastapi.concurrency import run_in_threadpool from fastapi.concurrency import run_in_threadpool
from PIL import Image from PIL import Image
from io import BytesIO from io import BytesIO
from .. import agents from .. import agents
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
from ..core.ocr import extract_text_from_images from ..core.ocr import extract_text_from_images
from ..core.vector_store import embedding_model, vector_store
# Create an APIRouter instance # Create an APIRouter instance
router = APIRouter( router = APIRouter(
@@ -102,10 +102,12 @@ async def upload_and_process_document(
full_text = await run_in_threadpool(extract_text_from_images, images) full_text = await run_in_threadpool(extract_text_from_images, images)
background_tasks.add_task( background_tasks.add_task(
agents.agent_vectorize_and_store, agents.agent_vectorize_and_store,
doc_id, doc_id=doc_id,
full_text, text=full_text,
category, category=category,
language language=language,
embedding_model=embedding_model,
vector_store=vector_store
) )
print("--- [Main] Vectorization job added to background tasks.") print("--- [Main] Vectorization job added to background tasks.")
@@ -118,4 +120,4 @@ async def upload_and_process_document(
async def get_result(doc_id: str): async def get_result(doc_id: str):
if doc_id in db_results: if doc_id in db_results:
return db_results[doc_id] return db_results[doc_id]
raise HTTPException(status_code=404, detail="Document not found.") raise HTTPException(status_code=404, detail="Document not found.")

View File

@@ -32,11 +32,14 @@ class LineItem(BaseModel):
class InvoiceInfo(BaseModel): class InvoiceInfo(BaseModel):
"""Defines the detailed, structured information to be extracted from an invoice.""" """Defines the detailed, structured information to be extracted from an invoice."""
date: Optional[str] = Field("", description="Extract in YYYY-MM-DD format. If unclear, leave as an empty string.") invoice_date: Optional[str] = Field("", description="The invoice date. Extract in YYYY-MM-DD format. If unclear, leave as an empty string.")
invoice_due_date: Optional[str] = Field("", description="The invoice due date.Extract in YYYY-MM-DD format. If unclear, leave as an empty string.")
invoice_number: Optional[str] = Field("", description="If not found or unclear, leave as an empty string.") invoice_number: Optional[str] = Field("", description="If not found or unclear, leave as an empty string.")
ocr_number: Optional[str] = Field("", description="The OCR number from the invoice. If not found or unclear, leave as an empty string.")
supplier_number: Optional[str] = Field("", description="It's the organisation number. If not found or unclear, leave as an empty string.") supplier_number: Optional[str] = Field("", description="It's the organisation number. If not found or unclear, leave as an empty string.")
biller_name: Optional[str] = Field("", description="It's the sender's name. If not found or unclear, leave as an empty string.") biller_name: Optional[str] = Field("", description="It's the sender's name. If not found or unclear, leave as an empty string.")
amount: Optional[float] = Field(None, description="Extract and format to decimal. If not present, leave as null.") amount: Optional[float] = Field(None, description="Extract and format to decimal. If not present, leave as null.")
tax_exclusive_amount: Optional[float] = Field(None, description="Extract the the amount excluding taxes and format it to a decimal number. If not present, leave as null.")
customer_name: Optional[str] = Field("", description="It's the receiver's name. Clean any special chars from the name. If not found or unclear, leave as an empty string.") customer_name: Optional[str] = Field("", description="It's the receiver's name. Clean any special chars from the name. If not found or unclear, leave as an empty string.")
customer_address: Optional[str] = Field("", description="It's the receiver's address. Put it in one line. If not found or unclear, leave as an empty string.") customer_address: Optional[str] = Field("", description="It's the receiver's address. Put it in one line. If not found or unclear, leave as an empty string.")
customer_address_line: Optional[str] = Field("", description="It's the receiver's address line, not the whole address. If not found or unclear, leave as an empty string.") customer_address_line: Optional[str] = Field("", description="It's the receiver's address line, not the whole address. If not found or unclear, leave as an empty string.")
@@ -50,4 +53,4 @@ class InvoiceInfo(BaseModel):
bank_giro: Optional[str] = Field("", description="BankGiro number, e.g., '123-4567'. If not found, leave as an empty string.") bank_giro: Optional[str] = Field("", description="BankGiro number, e.g., '123-4567'. If not found, leave as an empty string.")
plus_giro: Optional[str] = Field("", description="PlusGiro number, e.g., '123456-7'. If not found, leave as an empty string.") plus_giro: Optional[str] = Field("", description="PlusGiro number, e.g., '123456-7'. If not found, leave as an empty string.")
customer_ssn: Optional[str] = Field("", description="Customer's social security number, e.g., 'YYYYMMDD-XXXX'. If not found, leave as an empty string.") customer_ssn: Optional[str] = Field("", description="Customer's social security number, e.g., 'YYYYMMDD-XXXX'. If not found, leave as an empty string.")
line_items: List[LineItem] = Field([], description="A list of all line items from the invoice.") line_items: List[LineItem] = Field([], description="A list of all line items from the invoice. Make sure all of them are extracted.")

Binary file not shown.