Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 02290bb935 | |||
| 87ba009bd7 | |||
| f87834a1b3 |
16
.vscode/launch.json
vendored
Normal file
16
.vscode/launch.json
vendored
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
{
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Python: FastAPI",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "launch",
|
||||||
|
"module": "uvicorn",
|
||||||
|
"args": [
|
||||||
|
"app.main:app",
|
||||||
|
"--reload"
|
||||||
|
],
|
||||||
|
"justMyCode": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
@@ -15,11 +15,14 @@ The document's primary language is '{language}'.
|
|||||||
## Instructions:
|
## Instructions:
|
||||||
Carefully analyze the invoice image and extract the following fields according to these specific rules. Do not invent information. If a field is not found or is unclear, follow the specific instruction for that field.
|
Carefully analyze the invoice image and extract the following fields according to these specific rules. Do not invent information. If a field is not found or is unclear, follow the specific instruction for that field.
|
||||||
|
|
||||||
- `date`: Extract in YYYY-MM-DD format. If unclear, leave as an empty string.
|
- `invoice_date`: The invoice date. Extract in YYYY-MM-DD format. If unclear, leave as an empty string.
|
||||||
|
- `invoice_due_date`: The invoice due date.Extract in YYYY-MM-DD format. If unclear, leave as an empty string.
|
||||||
- `invoice_number`: If not found or unclear, leave as an empty string.
|
- `invoice_number`: If not found or unclear, leave as an empty string.
|
||||||
|
- `ocr_number`: The OCR number from the invoice. If not found or unclear, leave as an empty string.
|
||||||
- `supplier_number`: This is the organisation number. If not found or unclear, leave as an empty string.
|
- `supplier_number`: This is the organisation number. If not found or unclear, leave as an empty string.
|
||||||
- `biller_name`: This is the sender's name. If not found or unclear, leave as an empty string.
|
- `biller_name`: This is the sender's name. If not found or unclear, leave as an empty string.
|
||||||
- `amount`: Extract the final total amount and format it to a decimal number. If not present, leave as null.
|
- `amount`: Extract the final total amount and format it to a decimal number. If not present, leave as null.
|
||||||
|
- `tax_exclusive_amount`: Extract the the amount excluding taxes and format it to a decimal number. If not present, leave as null.
|
||||||
- `customer_name`: This is the receiver's name. Ensure it is a name and clear any special characters. If not found or unclear, leave as an empty string.
|
- `customer_name`: This is the receiver's name. Ensure it is a name and clear any special characters. If not found or unclear, leave as an empty string.
|
||||||
- `customer_address`: This is the receiver's full address. Put it in one line. If not found or unclear, leave as an empty string.
|
- `customer_address`: This is the receiver's full address. Put it in one line. If not found or unclear, leave as an empty string.
|
||||||
- `customer_address_line`: This is only the street address line from the receiver's address. If not found or unclear, leave as an empty string.
|
- `customer_address_line`: This is only the street address line from the receiver's address. If not found or unclear, leave as an empty string.
|
||||||
@@ -33,7 +36,7 @@ Carefully analyze the invoice image and extract the following fields according t
|
|||||||
- `bank_giro`: If found, extract the bank giro number. It often follows patterns like 'ddd-dddd', 'dddd-dddd', or 'dddddddd #41#'. If not found or unclear, leave as an empty string.
|
- `bank_giro`: If found, extract the bank giro number. It often follows patterns like 'ddd-dddd', 'dddd-dddd', or 'dddddddd #41#'. If not found or unclear, leave as an empty string.
|
||||||
- `plus_giro`: If found, extract the plus giro number. It often follows patterns like 'ddddddd-d #16#', 'ddddddd-d', or 'ddd dd dd-d'. If not found or unclear, leave as an empty string.
|
- `plus_giro`: If found, extract the plus giro number. It often follows patterns like 'ddddddd-d #16#', 'ddddddd-d', or 'ddd dd dd-d'. If not found or unclear, leave as an empty string.
|
||||||
- `customer_ssn`: If found, extract the customer social security number (personnummer). It follows the pattern 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'. If not found or unclear, leave as an empty string.
|
- `customer_ssn`: If found, extract the customer social security number (personnummer). It follows the pattern 'YYYYMMDD-XXXX' or 'YYMMDD-XXXX'. If not found or unclear, leave as an empty string.
|
||||||
- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. If a value is not present, leave it as null.
|
- `line_items`: Extract all line items from the invoice. For each item, extract the `description`, `quantity`, `unit_price`, and `total_price`. A list of all line items from the invoice. Make sure all of them are extracted. If a value is not present, leave it as null.
|
||||||
|
|
||||||
## Example:
|
## Example:
|
||||||
If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $200.00", the output for that line item should be:
|
If the invoice shows a line item "Consulting Services | 2 hours | $100.00/hr | $200.00", the output for that line item should be:
|
||||||
|
|||||||
@@ -1,29 +1,25 @@
|
|||||||
# app/agents/vectorization_agent.py
|
# app/agents/vectorization_agent.py
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
from langchain_openai import OpenAIEmbeddings
|
|
||||||
embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
|
|
||||||
import chromadb
|
|
||||||
|
|
||||||
client = chromadb.PersistentClient(path="./chroma_db")
|
|
||||||
vector_store = client.get_or_create_collection(name="documents")
|
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
chunk_size=1000,
|
chunk_size=1000,
|
||||||
chunk_overlap=100,
|
chunk_overlap=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
def agent_vectorize_and_store(doc_id: str, text: str, category: str, language: str):
|
def agent_vectorize_and_store(
|
||||||
"""
|
doc_id: str,
|
||||||
Agent 4: Vectorizes a document and stores it in ChromaDB.
|
text: str,
|
||||||
"""
|
category: str,
|
||||||
|
language: str,
|
||||||
|
embedding_model,
|
||||||
|
vector_store
|
||||||
|
):
|
||||||
print(f"--- [Background Task] Starting vectorization (ID: {doc_id})...")
|
print(f"--- [Background Task] Starting vectorization (ID: {doc_id})...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return
|
|
||||||
|
|
||||||
chunks = text_splitter.split_text(text)
|
chunks = text_splitter.split_text(text)
|
||||||
if not chunks:
|
if not chunks:
|
||||||
print(f"--- [Background Task] document {doc_id} has no text to vectorize.")
|
print(f"--- [Background task] document is empty, skip vectorization. (ID: {doc_id})")
|
||||||
return
|
return
|
||||||
|
|
||||||
chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
|
chunk_ids = [f"{doc_id}_{i}" for i in range(len(chunks))]
|
||||||
@@ -38,6 +34,6 @@ def agent_vectorize_and_store(doc_id: str, text: str, category: str, language: s
|
|||||||
documents=chunks,
|
documents=chunks,
|
||||||
metadatas=metadatas
|
metadatas=metadatas
|
||||||
)
|
)
|
||||||
print(f"--- [Background Task] Document {doc_id} vectorized and stored successfully.")
|
print(f"--- [Background Task] Document {doc_id} vectorized。")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"--- [background Task] Vectorization failed (ID: {doc_id}): {e}")
|
print(f"--- [Background Task] Document vectorization failed (ID: {doc_id}): {e}")
|
||||||
|
|||||||
@@ -4,20 +4,16 @@ from typing import List
|
|||||||
|
|
||||||
|
|
||||||
def extract_text_from_images(images: List[Image.Image]) -> str:
|
def extract_text_from_images(images: List[Image.Image]) -> str:
|
||||||
"""
|
print("--- [Core OCR] Extracting text...")
|
||||||
使用Tesseract OCR从一系列图片中提取并合并所有文本。
|
|
||||||
"""
|
|
||||||
print("--- [Core OCR] 正在从图片中提取文本用于向量化...")
|
|
||||||
full_text = []
|
full_text = []
|
||||||
for img in images:
|
for img in images:
|
||||||
try:
|
try:
|
||||||
# lang='chi_sim+eng' 表示同时识别简体中文和英文
|
|
||||||
text = pytesseract.image_to_string(img, lang='chi_sim+eng')
|
text = pytesseract.image_to_string(img, lang='chi_sim+eng')
|
||||||
full_text.append(text)
|
full_text.append(text)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"--- [Core OCR] 单页处理失败: {e}")
|
print(f"--- [Core OCR] Processing image failed: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
combined_text = "\n\n--- Page Break ---\n\n".join(full_text)
|
combined_text = "\n\n--- Page Break ---\n\n".join(full_text)
|
||||||
print("--- [Core OCR] 文本提取成功。")
|
print("--- [Core OCR] Text extraction completed.")
|
||||||
return combined_text
|
return combined_text
|
||||||
|
|||||||
@@ -1,47 +1,28 @@
|
|||||||
# app/core/vector_store.py
|
|
||||||
import os
|
import os
|
||||||
import chromadb
|
import chromadb
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
|
LLM_PROVIDER = os.getenv("LLM_PROVIDER", "openai").lower()
|
||||||
|
|
||||||
embedding_model = None
|
embedding_model = None
|
||||||
|
|
||||||
print(f"--- [Core] Initializing Embeddings with provider: {LLM_PROVIDER} ---")
|
print(f"--- [Core] Initializing Embeddings with provider: {LLM_PROVIDER} ---")
|
||||||
|
|
||||||
if LLM_PROVIDER == "azure":
|
if LLM_PROVIDER == "azure":
|
||||||
required_vars = [
|
|
||||||
"AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_API_KEY",
|
|
||||||
"OPENAI_API_VERSION", "AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"
|
|
||||||
]
|
|
||||||
if not all(os.getenv(var) for var in required_vars):
|
|
||||||
raise ValueError("One or more Azure OpenAI environment variables for embeddings are not set.")
|
|
||||||
|
|
||||||
embedding_model = AzureOpenAIEmbeddings(
|
embedding_model = AzureOpenAIEmbeddings(
|
||||||
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
|
||||||
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
|
||||||
api_version=os.getenv("OPENAI_API_VERSION"),
|
api_version=os.getenv("OPENAI_API_VERSION"),
|
||||||
azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"),
|
azure_deployment=os.getenv("AZURE_OPENAI_EMBEDDING_DEPLOYMENT_NAME"),
|
||||||
)
|
)
|
||||||
|
|
||||||
elif LLM_PROVIDER == "openai":
|
elif LLM_PROVIDER == "openai":
|
||||||
if not os.getenv("OPENAI_API_KEY"):
|
|
||||||
raise ValueError("OPENAI_API_KEY is not set for the 'openai' provider.")
|
|
||||||
|
|
||||||
embedding_model = OpenAIEmbeddings(
|
embedding_model = OpenAIEmbeddings(
|
||||||
api_key=os.getenv("OPENAI_API_KEY"),
|
api_key=os.getenv("OPENAI_API_KEY"),
|
||||||
model=os.getenv("OPENAI_EMBEDDING_MODEL_NAME", "text-embedding-3-small")
|
model=os.getenv("OPENAI_EMBEDDING_MODEL_NAME", "text-embedding-3-small")
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}. Please use 'azure' or 'openai'.")
|
raise ValueError(f"Unsupported LLM_PROVIDER: {LLM_PROVIDER}.")
|
||||||
|
|
||||||
|
|
||||||
client = chromadb.PersistentClient(path="./chroma_db")
|
client = chromadb.PersistentClient(path="./chroma_db")
|
||||||
vector_store = client.get_or_create_collection(
|
vector_store = client.get_or_create_collection(name="documents")
|
||||||
name="documents",
|
|
||||||
metadata={"hnsw:space": "cosine"}
|
|
||||||
)
|
|
||||||
@@ -5,10 +5,10 @@ from typing import Dict, Any, List
|
|||||||
from fastapi.concurrency import run_in_threadpool
|
from fastapi.concurrency import run_in_threadpool
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from .. import agents
|
from .. import agents
|
||||||
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
|
from ..core.pdf_processor import convert_pdf_to_images, image_to_base64_str
|
||||||
from ..core.ocr import extract_text_from_images
|
from ..core.ocr import extract_text_from_images
|
||||||
|
from ..core.vector_store import embedding_model, vector_store
|
||||||
|
|
||||||
# Create an APIRouter instance
|
# Create an APIRouter instance
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@@ -102,10 +102,12 @@ async def upload_and_process_document(
|
|||||||
full_text = await run_in_threadpool(extract_text_from_images, images)
|
full_text = await run_in_threadpool(extract_text_from_images, images)
|
||||||
background_tasks.add_task(
|
background_tasks.add_task(
|
||||||
agents.agent_vectorize_and_store,
|
agents.agent_vectorize_and_store,
|
||||||
doc_id,
|
doc_id=doc_id,
|
||||||
full_text,
|
text=full_text,
|
||||||
category,
|
category=category,
|
||||||
language
|
language=language,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
vector_store=vector_store
|
||||||
)
|
)
|
||||||
print("--- [Main] Vectorization job added to background tasks.")
|
print("--- [Main] Vectorization job added to background tasks.")
|
||||||
|
|
||||||
@@ -118,4 +120,4 @@ async def upload_and_process_document(
|
|||||||
async def get_result(doc_id: str):
|
async def get_result(doc_id: str):
|
||||||
if doc_id in db_results:
|
if doc_id in db_results:
|
||||||
return db_results[doc_id]
|
return db_results[doc_id]
|
||||||
raise HTTPException(status_code=404, detail="Document not found.")
|
raise HTTPException(status_code=404, detail="Document not found.")
|
||||||
@@ -32,11 +32,14 @@ class LineItem(BaseModel):
|
|||||||
|
|
||||||
class InvoiceInfo(BaseModel):
|
class InvoiceInfo(BaseModel):
|
||||||
"""Defines the detailed, structured information to be extracted from an invoice."""
|
"""Defines the detailed, structured information to be extracted from an invoice."""
|
||||||
date: Optional[str] = Field("", description="Extract in YYYY-MM-DD format. If unclear, leave as an empty string.")
|
invoice_date: Optional[str] = Field("", description="The invoice date. Extract in YYYY-MM-DD format. If unclear, leave as an empty string.")
|
||||||
|
invoice_due_date: Optional[str] = Field("", description="The invoice due date.Extract in YYYY-MM-DD format. If unclear, leave as an empty string.")
|
||||||
invoice_number: Optional[str] = Field("", description="If not found or unclear, leave as an empty string.")
|
invoice_number: Optional[str] = Field("", description="If not found or unclear, leave as an empty string.")
|
||||||
|
ocr_number: Optional[str] = Field("", description="The OCR number from the invoice. If not found or unclear, leave as an empty string.")
|
||||||
supplier_number: Optional[str] = Field("", description="It's the organisation number. If not found or unclear, leave as an empty string.")
|
supplier_number: Optional[str] = Field("", description="It's the organisation number. If not found or unclear, leave as an empty string.")
|
||||||
biller_name: Optional[str] = Field("", description="It's the sender's name. If not found or unclear, leave as an empty string.")
|
biller_name: Optional[str] = Field("", description="It's the sender's name. If not found or unclear, leave as an empty string.")
|
||||||
amount: Optional[float] = Field(None, description="Extract and format to decimal. If not present, leave as null.")
|
amount: Optional[float] = Field(None, description="Extract and format to decimal. If not present, leave as null.")
|
||||||
|
tax_exclusive_amount: Optional[float] = Field(None, description="Extract the the amount excluding taxes and format it to a decimal number. If not present, leave as null.")
|
||||||
customer_name: Optional[str] = Field("", description="It's the receiver's name. Clean any special chars from the name. If not found or unclear, leave as an empty string.")
|
customer_name: Optional[str] = Field("", description="It's the receiver's name. Clean any special chars from the name. If not found or unclear, leave as an empty string.")
|
||||||
customer_address: Optional[str] = Field("", description="It's the receiver's address. Put it in one line. If not found or unclear, leave as an empty string.")
|
customer_address: Optional[str] = Field("", description="It's the receiver's address. Put it in one line. If not found or unclear, leave as an empty string.")
|
||||||
customer_address_line: Optional[str] = Field("", description="It's the receiver's address line, not the whole address. If not found or unclear, leave as an empty string.")
|
customer_address_line: Optional[str] = Field("", description="It's the receiver's address line, not the whole address. If not found or unclear, leave as an empty string.")
|
||||||
@@ -50,4 +53,4 @@ class InvoiceInfo(BaseModel):
|
|||||||
bank_giro: Optional[str] = Field("", description="BankGiro number, e.g., '123-4567'. If not found, leave as an empty string.")
|
bank_giro: Optional[str] = Field("", description="BankGiro number, e.g., '123-4567'. If not found, leave as an empty string.")
|
||||||
plus_giro: Optional[str] = Field("", description="PlusGiro number, e.g., '123456-7'. If not found, leave as an empty string.")
|
plus_giro: Optional[str] = Field("", description="PlusGiro number, e.g., '123456-7'. If not found, leave as an empty string.")
|
||||||
customer_ssn: Optional[str] = Field("", description="Customer's social security number, e.g., 'YYYYMMDD-XXXX'. If not found, leave as an empty string.")
|
customer_ssn: Optional[str] = Field("", description="Customer's social security number, e.g., 'YYYYMMDD-XXXX'. If not found, leave as an empty string.")
|
||||||
line_items: List[LineItem] = Field([], description="A list of all line items from the invoice.")
|
line_items: List[LineItem] = Field([], description="A list of all line items from the invoice. Make sure all of them are extracted.")
|
||||||
Binary file not shown.
Reference in New Issue
Block a user