"""
Tests for the document category feature.

TDD tests for adding a ``category`` field to the ``admin_documents`` table.
Documents can be categorized (e.g. invoice, letter, receipt) so that
different models can be trained per category.
"""
|
|
|
|
from datetime import datetime, timezone
from unittest.mock import MagicMock
from uuid import UUID, uuid4

import pytest

from inference.data.admin_models import AdminDocument
|
|
|
|
|
|
# Test constants shared by the test classes in this module.

# Canonical (lowercase, hyphenated) UUID string used as the default document id.
TEST_DOC_UUID: str = "550e8400-e29b-41d4-a716-446655440000"

# Admin token placeholder (not referenced by the tests shown in this module).
TEST_TOKEN: str = "test-admin-token-12345"
|
|
|
|
|
|
class TestAdminDocumentCategoryField:
    """Tests for the ``category`` field on the AdminDocument model."""

    @staticmethod
    def _make_document(document_id=None, **overrides):
        """Build an AdminDocument with fixed, valid file metadata.

        Args:
            document_id: UUID for the document; ``UUID(TEST_DOC_UUID)`` when omitted.
            **overrides: Extra model fields (e.g. ``category``) passed through
                unchanged. Fields not given here are not passed at all, so
                model defaults still apply.

        Returns:
            A new AdminDocument instance.
        """
        fields = {
            "document_id": UUID(TEST_DOC_UUID) if document_id is None else document_id,
            "filename": "test.pdf",
            "file_size": 1024,
            "content_type": "application/pdf",
            "file_path": "/path/to/file.pdf",
        }
        fields.update(overrides)
        return AdminDocument(**fields)

    def test_document_has_category_field(self):
        """AdminDocument instances expose a ``category`` attribute."""
        doc = self._make_document()
        assert hasattr(doc, "category")

    def test_document_category_defaults_to_invoice(self):
        """Category defaults to 'invoice' when not specified."""
        # No ``category`` override: the model's own default must kick in.
        doc = self._make_document()
        assert doc.category == "invoice"

    # Parametrized (instead of a loop inside one test) so every value is
    # reported separately and one failure does not mask the others.
    # "custom_type" shows the field accepts values outside the common set.
    @pytest.mark.parametrize(
        "cat", ["invoice", "letter", "receipt", "contract", "custom_type"]
    )
    def test_document_accepts_custom_category(self, cat):
        """The document stores whatever category value it is given."""
        doc = self._make_document(document_id=uuid4(), category=cat)
        assert doc.category == cat

    def test_document_category_is_string_type(self):
        """The category field holds a ``str``."""
        doc = self._make_document(category="letter")
        assert isinstance(doc.category, str)
|
|
|
|
|
|
class TestDocumentCategoryInReadModel:
    """Response-model (read schema) coverage for the document category."""

    def test_admin_document_read_has_category(self):
        """AdminDocumentRead must declare ``category`` among its model fields."""
        from inference.data.admin_models import AdminDocumentRead

        # Inspect the declared schema rather than an instance.
        declared_fields = AdminDocumentRead.model_fields
        assert "category" in declared_fields
|
|
|
|
|
|
class TestDocumentCategoryAPI:
    """Tests that the admin API request/response schemas carry the category."""

    def test_upload_document_with_category(self):
        """DocumentUploadResponse accepts and exposes a ``category`` value.

        Only the response schema is checked here; no upload is performed.
        """
        from inference.web.schemas.admin import DocumentUploadResponse

        response = DocumentUploadResponse(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            message="Upload successful",
            category="letter",
        )
        assert response.category == "letter"

    def test_list_documents_returns_category(self):
        """DocumentItem (used by the list-documents endpoint) carries the category."""
        from inference.web.schemas.admin import DocumentItem

        # datetime.utcnow() is deprecated since Python 3.12; use an aware UTC
        # timestamp, computed once so both fields get the same value.
        now = datetime.now(timezone.utc)
        item = DocumentItem(
            document_id=TEST_DOC_UUID,
            filename="test.pdf",
            file_size=1024,
            page_count=1,
            status="pending",
            annotation_count=0,
            created_at=now,
            updated_at=now,
            category="invoice",
        )
        assert item.category == "invoice"

    def test_document_detail_includes_category(self):
        """DocumentDetailResponse declares a ``category`` field in its schema."""
        from inference.web.schemas.admin import DocumentDetailResponse

        assert "category" in DocumentDetailResponse.model_fields
|
|
|
|
|
|
class TestDocumentCategoryFiltering:
    """Tests for filtering documents by category."""

    @pytest.fixture
    def mock_admin_db(self):
        """Create a mock AdminDB whose category lookup really filters.

        ``get_documents_by_category`` is backed by a ``side_effect`` over one
        invoice document and one letter document. A call therefore returns
        only the documents whose ``category`` equals the requested value,
        instead of a canned list.
        """
        db = MagicMock()
        db.is_valid_admin_token.return_value = True

        # Mock documents with different categories.
        invoice_doc = MagicMock()
        invoice_doc.document_id = uuid4()
        invoice_doc.category = "invoice"

        letter_doc = MagicMock()
        letter_doc.document_id = uuid4()
        letter_doc.category = "letter"

        all_docs = [invoice_doc, letter_doc]
        db.get_documents_by_category.side_effect = lambda category: [
            doc for doc in all_docs if doc.category == category
        ]
        return db

    def test_filter_documents_by_category(self, mock_admin_db):
        """Filtering by 'invoice' returns only the invoice document."""
        # NOTE(review): this exercises a MagicMock, not the real AdminDB.
        # A plain MagicMock accepts any arguments, so this pins only the
        # expected call shape and result contract, not the real signature.
        result = mock_admin_db.get_documents_by_category("invoice")

        mock_admin_db.get_documents_by_category.assert_called_once_with("invoice")
        assert len(result) == 1
        assert result[0].category == "invoice"
|
|
|
|
|
|
class TestDocumentCategoryUpdate:
    """Schema tests for changing a document's category after creation."""

    def test_update_document_category_schema(self):
        """DocumentUpdateRequest keeps a category supplied by the caller."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        assert DocumentUpdateRequest(category="letter").category == "letter"

    def test_update_document_category_optional(self):
        """Omitting category is valid and leaves it as None."""
        from inference.web.schemas.admin import DocumentUpdateRequest

        # Construction must succeed with no arguments at all.
        empty_request = DocumentUpdateRequest()
        assert empty_request.category is None
|
|
|
|
|
|
class TestDatasetWithCategory:
    """Tests for dataset creation with an optional category filter."""

    def test_dataset_create_with_category_filter(self):
        """DatasetCreateRequest accepts a category used to filter documents."""
        from inference.web.schemas.admin import DatasetCreateRequest

        request = DatasetCreateRequest(
            name="Invoice Training Set",
            document_ids=[TEST_DOC_UUID],
            category="invoice",  # Optional filter
        )
        assert request.category == "invoice"

    def test_dataset_create_category_is_optional(self):
        """Omitting category is valid and leaves the filter unset (None)."""
        from inference.web.schemas.admin import DatasetCreateRequest

        request = DatasetCreateRequest(
            name="Mixed Training Set",
            document_ids=[TEST_DOC_UUID],
        )

        # Check the value directly: a `not hasattr(...)` escape hatch would
        # also pass if the field were missing from the schema entirely.
        assert request.category is None
|