Add more tests

Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions
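
The test files below rely on shared pytest fixtures (patched_session, admin_token, sample_document, sample_annotation, sample_batch_upload, sample_dataset, sample_training_task, sample_model_version, multiple_documents) that are defined elsewhere in this changeset. A minimal conftest.py sketch for two of them, assuming a SQLModel engine and a patchable session factory; the fixture names come from the tests, while the wiring and field values are illustrative only:

# conftest.py sketch -- not the project's actual fixtures.
import pytest
from sqlmodel import SQLModel, Session, create_engine
from inference.data.admin_models import AdminDocument  # also imported by the document tests below

@pytest.fixture
def patched_session(monkeypatch):
    """Yield an in-memory database session for the repositories to use.
    The patch target below is hypothetical; the real fixture wires this differently."""
    engine = create_engine("sqlite://")
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        # monkeypatch.setattr("inference.data.session.get_session", lambda: session)
        yield session

@pytest.fixture
def sample_document(patched_session):
    """One pre-inserted document row; field names follow what the tests pass
    to DocumentRepository.create()."""
    doc = AdminDocument(
        filename="sample.pdf",
        file_size=1024,
        content_type="application/pdf",
        file_path="/uploads/sample.pdf",
    )
    patched_session.add(doc)
    patched_session.commit()
    patched_session.refresh(doc)
    return doc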


@@ -0,0 +1 @@
"""Repository integration tests."""


@@ -0,0 +1,464 @@
"""
Annotation Repository Integration Tests
Tests AnnotationRepository with real database operations.
"""
from uuid import uuid4
import pytest
from inference.data.repositories.annotation_repository import AnnotationRepository
class TestAnnotationRepositoryCreate:
"""Tests for annotation creation."""
def test_create_annotation(self, patched_session, sample_document):
"""Test creating a single annotation."""
repo = AnnotationRepository()
ann_id = repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.3,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=240,
bbox_width=160,
bbox_height=40,
text_value="INV-2024-001",
confidence=0.95,
source="auto",
)
assert ann_id is not None
ann = repo.get(ann_id)
assert ann is not None
assert ann.class_name == "invoice_number"
assert ann.text_value == "INV-2024-001"
assert ann.confidence == 0.95
assert ann.source == "auto"
def test_create_batch_annotations(self, patched_session, sample_document):
"""Test batch creation of annotations."""
repo = AnnotationRepository()
annotations_data = [
{
"document_id": str(sample_document.document_id),
"page_number": 1,
"class_id": 0,
"class_name": "invoice_number",
"x_center": 0.5,
"y_center": 0.1,
"width": 0.2,
"height": 0.05,
"bbox_x": 400,
"bbox_y": 80,
"bbox_width": 160,
"bbox_height": 40,
"text_value": "INV-001",
"confidence": 0.95,
},
{
"document_id": str(sample_document.document_id),
"page_number": 1,
"class_id": 1,
"class_name": "invoice_date",
"x_center": 0.5,
"y_center": 0.2,
"width": 0.15,
"height": 0.04,
"bbox_x": 400,
"bbox_y": 160,
"bbox_width": 120,
"bbox_height": 32,
"text_value": "2024-01-15",
"confidence": 0.92,
},
{
"document_id": str(sample_document.document_id),
"page_number": 1,
"class_id": 6,
"class_name": "amount",
"x_center": 0.7,
"y_center": 0.8,
"width": 0.1,
"height": 0.04,
"bbox_x": 560,
"bbox_y": 640,
"bbox_width": 80,
"bbox_height": 32,
"text_value": "1500.00",
"confidence": 0.98,
},
]
ids = repo.create_batch(annotations_data)
assert len(ids) == 3
# Verify all annotations exist
for ann_id in ids:
ann = repo.get(ann_id)
assert ann is not None
class TestAnnotationRepositoryRead:
"""Tests for annotation retrieval."""
def test_get_nonexistent_annotation(self, patched_session):
"""Test getting an annotation that doesn't exist."""
repo = AnnotationRepository()
ann = repo.get(str(uuid4()))
assert ann is None
def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation):
"""Test getting all annotations for a document."""
repo = AnnotationRepository()
# Add another annotation
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=1,
class_name="invoice_date",
x_center=0.5,
y_center=0.4,
width=0.15,
height=0.04,
bbox_x=400,
bbox_y=320,
bbox_width=120,
bbox_height=32,
text_value="2024-01-15",
)
annotations = repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 2
# Should be ordered by class_id
assert annotations[0].class_id == 0
assert annotations[1].class_id == 1
def test_get_annotations_for_specific_page(self, patched_session, sample_document):
"""Test getting annotations for a specific page."""
repo = AnnotationRepository()
# Create annotations on different pages
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
)
repo.create(
document_id=str(sample_document.document_id),
page_number=2,
class_id=6,
class_name="amount",
x_center=0.7,
y_center=0.8,
width=0.1,
height=0.04,
bbox_x=560,
bbox_y=640,
bbox_width=80,
bbox_height=32,
)
page1_annotations = repo.get_for_document(
str(sample_document.document_id),
page_number=1,
)
page2_annotations = repo.get_for_document(
str(sample_document.document_id),
page_number=2,
)
assert len(page1_annotations) == 1
assert len(page2_annotations) == 1
assert page1_annotations[0].page_number == 1
assert page2_annotations[0].page_number == 2
class TestAnnotationRepositoryUpdate:
"""Tests for annotation updates."""
def test_update_annotation_bbox(self, patched_session, sample_annotation):
"""Test updating annotation bounding box."""
repo = AnnotationRepository()
result = repo.update(
str(sample_annotation.annotation_id),
x_center=0.6,
y_center=0.4,
width=0.25,
height=0.06,
bbox_x=480,
bbox_y=320,
bbox_width=200,
bbox_height=48,
)
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is not None
assert ann.x_center == 0.6
assert ann.y_center == 0.4
assert ann.bbox_x == 480
assert ann.bbox_width == 200
def test_update_annotation_text(self, patched_session, sample_annotation):
"""Test updating annotation text value."""
repo = AnnotationRepository()
result = repo.update(
str(sample_annotation.annotation_id),
text_value="INV-2024-002",
)
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is not None
assert ann.text_value == "INV-2024-002"
def test_update_annotation_class(self, patched_session, sample_annotation):
"""Test updating annotation class."""
repo = AnnotationRepository()
result = repo.update(
str(sample_annotation.annotation_id),
class_id=1,
class_name="invoice_date",
)
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is not None
assert ann.class_id == 1
assert ann.class_name == "invoice_date"
def test_update_nonexistent_annotation(self, patched_session):
"""Test updating annotation that doesn't exist."""
repo = AnnotationRepository()
result = repo.update(
str(uuid4()),
text_value="new value",
)
assert result is False
class TestAnnotationRepositoryDelete:
"""Tests for annotation deletion."""
def test_delete_annotation(self, patched_session, sample_annotation):
"""Test deleting a single annotation."""
repo = AnnotationRepository()
result = repo.delete(str(sample_annotation.annotation_id))
assert result is True
ann = repo.get(str(sample_annotation.annotation_id))
assert ann is None
def test_delete_nonexistent_annotation(self, patched_session):
"""Test deleting annotation that doesn't exist."""
repo = AnnotationRepository()
result = repo.delete(str(uuid4()))
assert result is False
def test_delete_annotations_for_document(self, patched_session, sample_document):
"""Test deleting all annotations for a document."""
repo = AnnotationRepository()
# Create multiple annotations
for i in range(3):
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=i,
class_name=f"field_{i}",
x_center=0.5,
y_center=0.1 + i * 0.2,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80 + i * 160,
bbox_width=160,
bbox_height=40,
)
# Delete all
count = repo.delete_for_document(str(sample_document.document_id))
assert count == 3
annotations = repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 0
def test_delete_annotations_by_source(self, patched_session, sample_document):
"""Test deleting annotations by source type."""
repo = AnnotationRepository()
# Create auto and manual annotations
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=0,
class_name="invoice_number",
x_center=0.5,
y_center=0.1,
width=0.2,
height=0.05,
bbox_x=400,
bbox_y=80,
bbox_width=160,
bbox_height=40,
source="auto",
)
repo.create(
document_id=str(sample_document.document_id),
page_number=1,
class_id=1,
class_name="invoice_date",
x_center=0.5,
y_center=0.2,
width=0.15,
height=0.04,
bbox_x=400,
bbox_y=160,
bbox_width=120,
bbox_height=32,
source="manual",
)
# Delete only auto annotations
count = repo.delete_for_document(str(sample_document.document_id), source="auto")
assert count == 1
remaining = repo.get_for_document(str(sample_document.document_id))
assert len(remaining) == 1
assert remaining[0].source == "manual"
class TestAnnotationVerification:
"""Tests for annotation verification."""
def test_verify_annotation(self, patched_session, admin_token, sample_annotation):
"""Test marking annotation as verified."""
repo = AnnotationRepository()
ann = repo.verify(str(sample_annotation.annotation_id), admin_token.token)
assert ann is not None
assert ann.is_verified is True
assert ann.verified_by == admin_token.token
assert ann.verified_at is not None
class TestAnnotationOverride:
"""Tests for annotation override functionality."""
def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation):
"""Test overriding an auto-generated annotation."""
repo = AnnotationRepository()
# Override the annotation
ann = repo.override(
str(sample_annotation.annotation_id),
admin_token.token,
change_reason="Correcting OCR error",
text_value="INV-2024-CORRECTED",
x_center=0.55,
)
assert ann is not None
assert ann.text_value == "INV-2024-CORRECTED"
assert ann.x_center == 0.55
assert ann.source == "manual" # Changed from auto to manual
assert ann.override_source == "auto"
class TestAnnotationHistory:
"""Tests for annotation history tracking."""
def test_create_history_record(self, patched_session, sample_annotation):
"""Test creating annotation history record."""
repo = AnnotationRepository()
history = repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_annotation.document_id,
action="created",
new_value={"text_value": "INV-001"},
changed_by="test-user",
)
assert history is not None
assert history.action == "created"
assert history.changed_by == "test-user"
def test_get_annotation_history(self, patched_session, sample_annotation):
"""Test getting history for an annotation."""
repo = AnnotationRepository()
# Create history records
repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_annotation.document_id,
action="created",
new_value={"text_value": "INV-001"},
)
repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_annotation.document_id,
action="updated",
previous_value={"text_value": "INV-001"},
new_value={"text_value": "INV-002"},
)
history = repo.get_history(sample_annotation.annotation_id)
assert len(history) == 2
# Should be ordered by created_at desc
assert history[0].action == "updated"
assert history[1].action == "created"
def test_get_document_history(self, patched_session, sample_document, sample_annotation):
"""Test getting all annotation history for a document."""
repo = AnnotationRepository()
repo.create_history(
annotation_id=sample_annotation.annotation_id,
document_id=sample_document.document_id,
action="created",
new_value={"class_name": "invoice_number"},
)
history = repo.get_document_history(sample_document.document_id)
assert len(history) >= 1
assert all(h.document_id == sample_document.document_id for h in history)
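
A note on the coordinate fixtures used throughout this file: each annotation carries both normalized values (x_center, y_center, width, height in the 0-1 range) and pixel values (bbox_x, bbox_y, bbox_width, bbox_height), and in the test data the pixel values are consistently the normalized values times 800 (x_center=0.5 -> bbox_x=400, width=0.2 -> bbox_width=160). The helper below only reproduces that arithmetic; the 800x800 page size and the exact bbox semantics are inferred from these numbers, not stated by the repository:

# Sketch: reproduce the pixel fixture values from the normalized ones,
# assuming an 800x800 page raster (inferred from the test data above).
PAGE_SIZE = 800

def to_pixels(x_center: float, y_center: float, width: float, height: float) -> dict:
    return {
        "bbox_x": int(x_center * PAGE_SIZE),
        "bbox_y": int(y_center * PAGE_SIZE),
        "bbox_width": int(width * PAGE_SIZE),
        "bbox_height": int(height * PAGE_SIZE),
    }

assert to_pixels(0.5, 0.3, 0.2, 0.05) == {"bbox_x": 400, "bbox_y": 240, "bbox_width": 160, "bbox_height": 40}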


@@ -0,0 +1,355 @@
"""
Batch Upload Repository Integration Tests
Tests BatchUploadRepository with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.repositories.batch_upload_repository import BatchUploadRepository
class TestBatchUploadCreate:
"""Tests for batch upload creation."""
def test_create_batch_upload(self, patched_session, admin_token):
"""Test creating a batch upload."""
repo = BatchUploadRepository()
batch = repo.create(
admin_token=admin_token.token,
filename="test_batch.zip",
file_size=10240,
upload_source="api",
)
assert batch is not None
assert batch.batch_id is not None
assert batch.filename == "test_batch.zip"
assert batch.file_size == 10240
assert batch.upload_source == "api"
assert batch.status == "processing"
assert batch.total_files == 0
assert batch.processed_files == 0
def test_create_batch_upload_default_source(self, patched_session, admin_token):
"""Test creating batch upload with default source."""
repo = BatchUploadRepository()
batch = repo.create(
admin_token=admin_token.token,
filename="ui_batch.zip",
file_size=5120,
)
assert batch.upload_source == "ui"
class TestBatchUploadRead:
"""Tests for batch upload retrieval."""
def test_get_batch_upload(self, patched_session, sample_batch_upload):
"""Test getting a batch upload by ID."""
repo = BatchUploadRepository()
batch = repo.get(sample_batch_upload.batch_id)
assert batch is not None
assert batch.batch_id == sample_batch_upload.batch_id
assert batch.filename == sample_batch_upload.filename
def test_get_nonexistent_batch_upload(self, patched_session):
"""Test getting a batch upload that doesn't exist."""
repo = BatchUploadRepository()
batch = repo.get(uuid4())
assert batch is None
def test_get_paginated_batch_uploads(self, patched_session, admin_token):
"""Test paginated batch upload listing."""
repo = BatchUploadRepository()
# Create multiple batches
for i in range(5):
repo.create(
admin_token=admin_token.token,
filename=f"batch_{i}.zip",
file_size=1024 * (i + 1),
)
batches, total = repo.get_paginated(limit=3, offset=0)
assert total == 5
assert len(batches) == 3
def test_get_paginated_with_offset(self, patched_session, admin_token):
"""Test pagination offset."""
repo = BatchUploadRepository()
for i in range(5):
repo.create(
admin_token=admin_token.token,
filename=f"batch_{i}.zip",
file_size=1024,
)
page1, _ = repo.get_paginated(limit=2, offset=0)
page2, _ = repo.get_paginated(limit=2, offset=2)
ids_page1 = {b.batch_id for b in page1}
ids_page2 = {b.batch_id for b in page2}
assert len(ids_page1 & ids_page2) == 0
class TestBatchUploadUpdate:
"""Tests for batch upload updates."""
def test_update_batch_status(self, patched_session, sample_batch_upload):
"""Test updating batch upload status."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
status="completed",
total_files=10,
processed_files=10,
successful_files=8,
failed_files=2,
)
# Need to commit to see changes
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.status == "completed"
assert batch.total_files == 10
assert batch.successful_files == 8
assert batch.failed_files == 2
def test_update_batch_with_error(self, patched_session, sample_batch_upload):
"""Test updating batch upload with error message."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
status="failed",
error_message="ZIP extraction failed",
)
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.status == "failed"
assert batch.error_message == "ZIP extraction failed"
def test_update_batch_csv_info(self, patched_session, sample_batch_upload):
"""Test updating batch with CSV information."""
repo = BatchUploadRepository()
repo.update(
sample_batch_upload.batch_id,
csv_filename="manifest.csv",
csv_row_count=100,
)
patched_session.commit()
batch = repo.get(sample_batch_upload.batch_id)
assert batch.csv_filename == "manifest.csv"
assert batch.csv_row_count == 100
class TestBatchUploadFiles:
"""Tests for batch upload file management."""
def test_create_batch_file(self, patched_session, sample_batch_upload):
"""Test creating a batch upload file record."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_001.pdf",
status="pending",
)
assert file_record is not None
assert file_record.file_id is not None
assert file_record.filename == "invoice_001.pdf"
assert file_record.batch_id == sample_batch_upload.batch_id
assert file_record.status == "pending"
def test_create_batch_file_with_document_link(self, patched_session, sample_batch_upload, sample_document):
"""Test creating batch file linked to a document."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_linked.pdf",
document_id=sample_document.document_id,
status="completed",
annotation_count=5,
)
assert file_record.document_id == sample_document.document_id
assert file_record.status == "completed"
assert file_record.annotation_count == 5
def test_get_batch_files(self, patched_session, sample_batch_upload):
"""Test getting all files for a batch."""
repo = BatchUploadRepository()
# Create multiple files
for i in range(3):
repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename=f"file_{i}.pdf",
)
files = repo.get_files(sample_batch_upload.batch_id)
assert len(files) == 3
assert all(f.batch_id == sample_batch_upload.batch_id for f in files)
def test_get_batch_files_empty(self, patched_session, sample_batch_upload):
"""Test getting files for batch with no files."""
repo = BatchUploadRepository()
files = repo.get_files(sample_batch_upload.batch_id)
assert files == []
def test_update_batch_file_status(self, patched_session, sample_batch_upload):
"""Test updating batch file status."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="test.pdf",
)
repo.update_file(
file_record.file_id,
status="completed",
annotation_count=10,
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.status == "completed"
assert updated_file.annotation_count == 10
def test_update_batch_file_with_error(self, patched_session, sample_batch_upload):
"""Test updating batch file with error."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="corrupt.pdf",
)
repo.update_file(
file_record.file_id,
status="failed",
error_message="Invalid PDF format",
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.status == "failed"
assert updated_file.error_message == "Invalid PDF format"
def test_update_batch_file_with_csv_data(self, patched_session, sample_batch_upload):
"""Test updating batch file with CSV row data."""
repo = BatchUploadRepository()
file_record = repo.create_file(
batch_id=sample_batch_upload.batch_id,
filename="invoice_with_csv.pdf",
)
csv_data = {
"invoice_number": "INV-001",
"amount": "1500.00",
"supplier": "Test Corp",
}
repo.update_file(
file_record.file_id,
csv_row_data=csv_data,
)
patched_session.commit()
files = repo.get_files(sample_batch_upload.batch_id)
updated_file = files[0]
assert updated_file.csv_row_data == csv_data
class TestBatchUploadWorkflow:
"""Tests for complete batch upload workflows."""
def test_complete_batch_workflow(self, patched_session, admin_token):
"""Test complete batch upload workflow."""
repo = BatchUploadRepository()
# 1. Create batch
batch = repo.create(
admin_token=admin_token.token,
filename="full_workflow.zip",
file_size=50000,
)
# 2. Update with file count
repo.update(batch.batch_id, total_files=3)
patched_session.commit()
# 3. Create file records
file_ids = []
for i in range(3):
file_record = repo.create_file(
batch_id=batch.batch_id,
filename=f"doc_{i}.pdf",
)
file_ids.append(file_record.file_id)
# 4. Process files one by one
for i, file_id in enumerate(file_ids):
status = "completed" if i < 2 else "failed"
repo.update_file(
file_id,
status=status,
annotation_count=5 if status == "completed" else 0,
)
# 5. Update batch progress
repo.update(
batch.batch_id,
processed_files=3,
successful_files=2,
failed_files=1,
status="partial",
)
patched_session.commit()
# Verify final state
final_batch = repo.get(batch.batch_id)
assert final_batch.status == "partial"
assert final_batch.total_files == 3
assert final_batch.processed_files == 3
assert final_batch.successful_files == 2
assert final_batch.failed_files == 1
files = repo.get_files(batch.batch_id)
assert len(files) == 3
completed = [f for f in files if f.status == "completed"]
failed = [f for f in files if f.status == "failed"]
assert len(completed) == 2
assert len(failed) == 1
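
Several repositories in this commit share the get_paginated(limit=..., offset=...) -> (items, total) contract exercised above. A small helper sketch for walking every page with that contract; the contract itself comes from the tests, the helper is illustrative:

# Sketch: iterate all rows of any repository exposing
# get_paginated(limit=..., offset=...) -> (items, total).
def iter_all(repo, page_size: int = 100):
    offset = 0
    while True:
        items, total = repo.get_paginated(limit=page_size, offset=offset)
        yield from items
        offset += len(items)
        if not items or offset >= total:
            break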


@@ -0,0 +1,321 @@
"""
Dataset Repository Integration Tests
Tests DatasetRepository with real database operations.
"""
from uuid import uuid4
import pytest
from inference.data.repositories.dataset_repository import DatasetRepository
class TestDatasetRepositoryCreate:
"""Tests for dataset creation."""
def test_create_dataset(self, patched_session):
"""Test creating a training dataset."""
repo = DatasetRepository()
dataset = repo.create(
name="Test Dataset",
description="Dataset for integration testing",
train_ratio=0.8,
val_ratio=0.1,
seed=42,
)
assert dataset is not None
assert dataset.name == "Test Dataset"
assert dataset.description == "Dataset for integration testing"
assert dataset.train_ratio == 0.8
assert dataset.val_ratio == 0.1
assert dataset.seed == 42
assert dataset.status == "building"
def test_create_dataset_with_defaults(self, patched_session):
"""Test creating dataset with default values."""
repo = DatasetRepository()
dataset = repo.create(name="Minimal Dataset")
assert dataset is not None
assert dataset.train_ratio == 0.8
assert dataset.val_ratio == 0.1
assert dataset.seed == 42
class TestDatasetRepositoryRead:
"""Tests for dataset retrieval."""
def test_get_dataset_by_id(self, patched_session, sample_dataset):
"""Test getting dataset by ID."""
repo = DatasetRepository()
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.dataset_id == sample_dataset.dataset_id
assert dataset.name == sample_dataset.name
def test_get_nonexistent_dataset(self, patched_session):
"""Test getting dataset that doesn't exist."""
repo = DatasetRepository()
dataset = repo.get(str(uuid4()))
assert dataset is None
def test_get_paginated_datasets(self, patched_session):
"""Test paginated dataset listing."""
repo = DatasetRepository()
# Create multiple datasets
for i in range(5):
repo.create(name=f"Dataset {i}")
datasets, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(datasets) == 2
def test_get_paginated_with_status_filter(self, patched_session):
"""Test filtering datasets by status."""
repo = DatasetRepository()
# Create datasets with different statuses
d1 = repo.create(name="Building Dataset")
repo.update_status(str(d1.dataset_id), "ready")
d2 = repo.create(name="Another Building Dataset")
# stays as "building"
datasets, total = repo.get_paginated(status="ready")
assert total == 1
assert datasets[0].status == "ready"
class TestDatasetRepositoryUpdate:
"""Tests for dataset updates."""
def test_update_status(self, patched_session, sample_dataset):
"""Test updating dataset status."""
repo = DatasetRepository()
repo.update_status(
str(sample_dataset.dataset_id),
status="ready",
total_documents=100,
total_images=150,
total_annotations=500,
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.status == "ready"
assert dataset.total_documents == 100
assert dataset.total_images == 150
assert dataset.total_annotations == 500
def test_update_status_with_error(self, patched_session, sample_dataset):
"""Test updating dataset status with error message."""
repo = DatasetRepository()
repo.update_status(
str(sample_dataset.dataset_id),
status="failed",
error_message="Failed to build dataset: insufficient documents",
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.status == "failed"
assert "insufficient documents" in dataset.error_message
def test_update_status_with_path(self, patched_session, sample_dataset):
"""Test updating dataset path."""
repo = DatasetRepository()
repo.update_status(
str(sample_dataset.dataset_id),
status="ready",
dataset_path="/datasets/test_dataset_2024",
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.dataset_path == "/datasets/test_dataset_2024"
def test_update_training_status(self, patched_session, sample_dataset, sample_training_task):
"""Test updating dataset training status."""
repo = DatasetRepository()
repo.update_training_status(
str(sample_dataset.dataset_id),
training_status="running",
active_training_task_id=str(sample_training_task.task_id),
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.training_status == "running"
assert dataset.active_training_task_id == sample_training_task.task_id
def test_update_training_status_completed(self, patched_session, sample_dataset):
"""Test that completing training also updates the dataset's main status."""
repo = DatasetRepository()
# First set to ready
repo.update_status(str(sample_dataset.dataset_id), status="ready")
# Then complete training
repo.update_training_status(
str(sample_dataset.dataset_id),
training_status="completed",
update_main_status=True,
)
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is not None
assert dataset.training_status == "completed"
assert dataset.status == "trained"
class TestDatasetDocuments:
"""Tests for dataset document management."""
def test_add_documents_to_dataset(self, patched_session, sample_dataset, multiple_documents):
"""Test adding documents to a dataset."""
repo = DatasetRepository()
documents_data = [
{
"document_id": str(multiple_documents[0].document_id),
"split": "train",
"page_count": 1,
"annotation_count": 5,
},
{
"document_id": str(multiple_documents[1].document_id),
"split": "train",
"page_count": 2,
"annotation_count": 8,
},
{
"document_id": str(multiple_documents[2].document_id),
"split": "val",
"page_count": 1,
"annotation_count": 3,
},
]
repo.add_documents(str(sample_dataset.dataset_id), documents_data)
# Verify documents were added
docs = repo.get_documents(str(sample_dataset.dataset_id))
assert len(docs) == 3
train_docs = [d for d in docs if d.split == "train"]
val_docs = [d for d in docs if d.split == "val"]
assert len(train_docs) == 2
assert len(val_docs) == 1
def test_get_dataset_documents(self, patched_session, sample_dataset, sample_document):
"""Test getting documents from a dataset."""
repo = DatasetRepository()
repo.add_documents(
str(sample_dataset.dataset_id),
[
{
"document_id": str(sample_document.document_id),
"split": "train",
"page_count": 1,
"annotation_count": 5,
}
],
)
docs = repo.get_documents(str(sample_dataset.dataset_id))
assert len(docs) == 1
assert docs[0].document_id == sample_document.document_id
assert docs[0].split == "train"
assert docs[0].page_count == 1
assert docs[0].annotation_count == 5
class TestDatasetRepositoryDelete:
"""Tests for dataset deletion."""
def test_delete_dataset(self, patched_session, sample_dataset):
"""Test deleting a dataset."""
repo = DatasetRepository()
result = repo.delete(str(sample_dataset.dataset_id))
assert result is True
dataset = repo.get(str(sample_dataset.dataset_id))
assert dataset is None
def test_delete_nonexistent_dataset(self, patched_session):
"""Test deleting dataset that doesn't exist."""
repo = DatasetRepository()
result = repo.delete(str(uuid4()))
assert result is False
def test_delete_dataset_cascades_documents(self, patched_session, sample_dataset, sample_document):
"""Test that deleting a dataset also removes its document links."""
repo = DatasetRepository()
# Add document to dataset
repo.add_documents(
str(sample_dataset.dataset_id),
[
{
"document_id": str(sample_document.document_id),
"split": "train",
"page_count": 1,
"annotation_count": 5,
}
],
)
# Delete dataset
repo.delete(str(sample_dataset.dataset_id))
# Document links should be gone
docs = repo.get_documents(str(sample_dataset.dataset_id))
assert len(docs) == 0
class TestActiveTrainingTasks:
"""Tests for active training task queries."""
def test_get_active_training_tasks(self, patched_session, sample_dataset, sample_training_task):
"""Test getting active training tasks for datasets."""
repo = DatasetRepository()
# Update task to running
from inference.data.repositories.training_task_repository import TrainingTaskRepository
task_repo = TrainingTaskRepository()
task_repo.update_status(str(sample_training_task.task_id), "running")
result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])
assert str(sample_dataset.dataset_id) in result
assert result[str(sample_dataset.dataset_id)]["status"] == "running"
def test_get_active_training_tasks_empty(self, patched_session, sample_dataset):
"""Test getting active training tasks returns empty when no tasks exist."""
repo = DatasetRepository()
result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])
# No training task exists for this dataset, so result should be empty
assert str(sample_dataset.dataset_id) not in result
assert result == {}
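
The defaults asserted above (train_ratio=0.8, val_ratio=0.1, seed=42) leave 10% of documents unaccounted for; treating that remainder as a held-out test split is an assumption, not something the repository states. A tiny sketch of the resulting counts under that assumption:

# Sketch: split sizes for the default ratios the tests assert.
# Treating the remainder as a test split is an assumption.
def split_counts(n_documents: int, train_ratio: float = 0.8, val_ratio: float = 0.1):
    n_train = round(n_documents * train_ratio)
    n_val = round(n_documents * val_ratio)
    return n_train, n_val, n_documents - n_train - n_val

assert split_counts(100) == (80, 10, 10)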


@@ -0,0 +1,350 @@
"""
Document Repository Integration Tests
Tests DocumentRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from sqlmodel import select
from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.document_repository import DocumentRepository
def ensure_utc(dt: datetime | None) -> datetime | None:
"""Ensure a datetime is timezone-aware (UTC).
PostgreSQL may return offset-naive datetimes; this helper tags
naive values as UTC so they compare correctly with aware datetimes.
"""
if dt is None:
return None
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
class TestDocumentRepositoryCreate:
"""Tests for document creation."""
def test_create_document(self, patched_session):
"""Test creating a document and retrieving it."""
repo = DocumentRepository()
doc_id = repo.create(
filename="test_invoice.pdf",
file_size=2048,
content_type="application/pdf",
file_path="/uploads/test_invoice.pdf",
page_count=2,
upload_source="api",
category="invoice",
)
assert doc_id is not None
doc = repo.get(doc_id)
assert doc is not None
assert doc.filename == "test_invoice.pdf"
assert doc.file_size == 2048
assert doc.page_count == 2
assert doc.upload_source == "api"
assert doc.category == "invoice"
assert doc.status == "pending"
def test_create_document_with_csv_values(self, patched_session):
"""Test creating document with CSV field values."""
repo = DocumentRepository()
csv_values = {
"invoice_number": "INV-001",
"amount": "1500.00",
"supplier_name": "Test Supplier AB",
}
doc_id = repo.create(
filename="invoice_with_csv.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/invoice_with_csv.pdf",
csv_field_values=csv_values,
)
doc = repo.get(doc_id)
assert doc is not None
assert doc.csv_field_values == csv_values
def test_create_document_with_group_key(self, patched_session):
"""Test creating document with group key."""
repo = DocumentRepository()
doc_id = repo.create(
filename="grouped_doc.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/uploads/grouped_doc.pdf",
group_key="batch-2024-01",
)
doc = repo.get(doc_id)
assert doc is not None
assert doc.group_key == "batch-2024-01"
class TestDocumentRepositoryRead:
"""Tests for document retrieval."""
def test_get_nonexistent_document(self, patched_session):
"""Test getting a document that doesn't exist."""
repo = DocumentRepository()
doc = repo.get(str(uuid4()))
assert doc is None
def test_get_paginated_documents(self, patched_session, multiple_documents):
"""Test paginated document listing."""
repo = DocumentRepository()
docs, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(docs) == 2
def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
"""Test filtering documents by status."""
repo = DocumentRepository()
docs, total = repo.get_paginated(status="labeled")
assert total == 2
for doc in docs:
assert doc.status == "labeled"
def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
"""Test filtering documents by category."""
repo = DocumentRepository()
docs, total = repo.get_paginated(category="letter")
assert total == 1
assert docs[0].category == "letter"
def test_get_paginated_with_offset(self, patched_session, multiple_documents):
"""Test pagination offset."""
repo = DocumentRepository()
docs_page1, _ = repo.get_paginated(limit=2, offset=0)
docs_page2, _ = repo.get_paginated(limit=2, offset=2)
doc_ids_page1 = {str(d.document_id) for d in docs_page1}
doc_ids_page2 = {str(d.document_id) for d in docs_page2}
assert len(doc_ids_page1 & doc_ids_page2) == 0
def test_get_by_ids(self, patched_session, multiple_documents):
"""Test getting multiple documents by IDs."""
repo = DocumentRepository()
ids_to_fetch = [str(multiple_documents[0].document_id), str(multiple_documents[2].document_id)]
docs = repo.get_by_ids(ids_to_fetch)
assert len(docs) == 2
fetched_ids = {str(d.document_id) for d in docs}
assert fetched_ids == set(ids_to_fetch)
class TestDocumentRepositoryUpdate:
"""Tests for document updates."""
def test_update_status(self, patched_session, sample_document):
"""Test updating document status."""
repo = DocumentRepository()
repo.update_status(
str(sample_document.document_id),
status="labeled",
auto_label_status="completed",
)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.status == "labeled"
assert doc.auto_label_status == "completed"
def test_update_status_with_error(self, patched_session, sample_document):
"""Test updating document status with error message."""
repo = DocumentRepository()
repo.update_status(
str(sample_document.document_id),
status="pending",
auto_label_status="failed",
auto_label_error="OCR extraction failed",
)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.auto_label_status == "failed"
assert doc.auto_label_error == "OCR extraction failed"
def test_update_file_path(self, patched_session, sample_document):
"""Test updating document file path."""
repo = DocumentRepository()
new_path = "/archive/2024/test_invoice.pdf"
repo.update_file_path(str(sample_document.document_id), new_path)
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.file_path == new_path
def test_update_group_key(self, patched_session, sample_document):
"""Test updating document group key."""
repo = DocumentRepository()
result = repo.update_group_key(str(sample_document.document_id), "new-group-key")
assert result is True
doc = repo.get(str(sample_document.document_id))
assert doc is not None
assert doc.group_key == "new-group-key"
def test_update_category(self, patched_session, sample_document):
"""Test updating document category."""
repo = DocumentRepository()
doc = repo.update_category(str(sample_document.document_id), "letter")
assert doc is not None
assert doc.category == "letter"
class TestDocumentRepositoryDelete:
"""Tests for document deletion."""
def test_delete_document(self, patched_session, sample_document):
"""Test deleting a document."""
repo = DocumentRepository()
result = repo.delete(str(sample_document.document_id))
assert result is True
doc = repo.get(str(sample_document.document_id))
assert doc is None
def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
"""Test that deleting a document also deletes its annotations."""
repo = DocumentRepository()
result = repo.delete(str(sample_document.document_id))
assert result is True
# Verify annotation is also deleted
from inference.data.repositories.annotation_repository import AnnotationRepository
ann_repo = AnnotationRepository()
annotations = ann_repo.get_for_document(str(sample_document.document_id))
assert len(annotations) == 0
def test_delete_nonexistent_document(self, patched_session):
"""Test deleting a document that doesn't exist."""
repo = DocumentRepository()
result = repo.delete(str(uuid4()))
assert result is False
class TestDocumentRepositoryQueries:
"""Tests for complex document queries."""
def test_count_by_status(self, patched_session, multiple_documents):
"""Test counting documents by status."""
repo = DocumentRepository()
counts = repo.count_by_status()
assert counts.get("pending") == 2
assert counts.get("labeled") == 2
assert counts.get("exported") == 1
def test_get_categories(self, patched_session, multiple_documents):
"""Test getting unique categories."""
repo = DocumentRepository()
categories = repo.get_categories()
assert "invoice" in categories
assert "letter" in categories
def test_get_labeled_for_export(self, patched_session, multiple_documents):
"""Test getting labeled documents for export."""
repo = DocumentRepository()
docs = repo.get_labeled_for_export()
assert len(docs) == 2
for doc in docs:
assert doc.status == "labeled"
class TestDocumentAnnotationLocking:
"""Tests for annotation locking mechanism."""
def test_acquire_annotation_lock(self, patched_session, sample_document):
"""Test acquiring annotation lock."""
repo = DocumentRepository()
doc = repo.acquire_annotation_lock(
str(sample_document.document_id),
duration_seconds=300,
)
assert doc is not None
assert doc.annotation_lock_until is not None
lock_until = ensure_utc(doc.annotation_lock_until)
assert lock_until > datetime.now(timezone.utc)
def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
"""Test acquiring lock fails when already locked."""
repo = DocumentRepository()
# First lock
repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
# Second lock attempt should fail
result = repo.acquire_annotation_lock(str(sample_document.document_id))
assert result is None
def test_release_annotation_lock(self, patched_session, sample_document):
"""Test releasing annotation lock."""
repo = DocumentRepository()
repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
doc = repo.release_annotation_lock(str(sample_document.document_id))
assert doc is not None
assert doc.annotation_lock_until is None
def test_extend_annotation_lock(self, patched_session, sample_document):
"""Test extending annotation lock."""
repo = DocumentRepository()
# Acquire initial lock
initial_doc = repo.acquire_annotation_lock(
str(sample_document.document_id),
duration_seconds=300,
)
initial_expiry = ensure_utc(initial_doc.annotation_lock_until)
# Extend lock
extended_doc = repo.extend_annotation_lock(
str(sample_document.document_id),
additional_seconds=300,
)
assert extended_doc is not None
extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
assert extended_expiry > initial_expiry
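
The locking tests pin down a check-out pattern: acquire_annotation_lock() returns the document or None when another session already holds the lock, extend_annotation_lock() pushes the expiry forward, and release_annotation_lock() clears it. A usage sketch built only from those calls; do_edit and the error handling are placeholders:

# Sketch: cooperative annotation editing using the lock methods tested above.
def edit_with_lock(repo, document_id: str, do_edit) -> bool:
    doc = repo.acquire_annotation_lock(document_id, duration_seconds=300)
    if doc is None:
        return False  # another session holds the lock
    try:
        do_edit(doc)  # placeholder for the actual annotation work
        # For long edits, repo.extend_annotation_lock(document_id, additional_seconds=300)
        # can push the expiry forward before it lapses.
        return True
    finally:
        repo.release_annotation_lock(document_id)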


@@ -0,0 +1,310 @@
"""
Model Version Repository Integration Tests
Tests ModelVersionRepository with real database operations.
"""
from datetime import datetime, timezone
from uuid import uuid4
import pytest
from inference.data.repositories.model_version_repository import ModelVersionRepository
class TestModelVersionCreate:
"""Tests for model version creation."""
def test_create_model_version(self, patched_session):
"""Test creating a model version."""
repo = ModelVersionRepository()
model = repo.create(
version="1.0.0",
name="Invoice Extractor v1",
model_path="/models/invoice_v1.pt",
description="Initial production model",
metrics_mAP=0.92,
metrics_precision=0.89,
metrics_recall=0.85,
document_count=1000,
file_size=50000000,
)
assert model is not None
assert model.version == "1.0.0"
assert model.name == "Invoice Extractor v1"
assert model.model_path == "/models/invoice_v1.pt"
assert model.metrics_mAP == 0.92
assert model.is_active is False
assert model.status == "inactive"
def test_create_model_version_with_training_info(
self, patched_session, sample_training_task, sample_dataset
):
"""Test creating model version linked to training task and dataset."""
repo = ModelVersionRepository()
model = repo.create(
version="1.1.0",
name="Invoice Extractor v1.1",
model_path="/models/invoice_v1.1.pt",
task_id=sample_training_task.task_id,
dataset_id=sample_dataset.dataset_id,
training_config={"epochs": 100, "batch_size": 16},
trained_at=datetime.now(timezone.utc),
)
assert model is not None
assert model.task_id == sample_training_task.task_id
assert model.dataset_id == sample_dataset.dataset_id
assert model.training_config["epochs"] == 100
class TestModelVersionRead:
"""Tests for model version retrieval."""
def test_get_model_version_by_id(self, patched_session, sample_model_version):
"""Test getting model version by ID."""
repo = ModelVersionRepository()
model = repo.get(str(sample_model_version.version_id))
assert model is not None
assert model.version_id == sample_model_version.version_id
def test_get_nonexistent_model_version(self, patched_session):
"""Test getting model version that doesn't exist."""
repo = ModelVersionRepository()
model = repo.get(str(uuid4()))
assert model is None
def test_get_paginated_model_versions(self, patched_session):
"""Test paginated model version listing."""
repo = ModelVersionRepository()
# Create multiple versions
for i in range(5):
repo.create(
version=f"1.{i}.0",
name=f"Model v1.{i}",
model_path=f"/models/model_v1.{i}.pt",
)
models, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(models) == 2
def test_get_paginated_with_status_filter(self, patched_session):
"""Test filtering model versions by status."""
repo = ModelVersionRepository()
# Create active and inactive models
m1 = repo.create(version="1.0.0", name="Active Model", model_path="/models/active.pt")
repo.activate(str(m1.version_id))
repo.create(version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt")
active_models, active_total = repo.get_paginated(status="active")
inactive_models, inactive_total = repo.get_paginated(status="inactive")
assert active_total == 1
assert inactive_total == 1
class TestModelVersionActivation:
"""Tests for model version activation."""
def test_activate_model_version(self, patched_session, sample_model_version):
"""Test activating a model version."""
repo = ModelVersionRepository()
model = repo.activate(str(sample_model_version.version_id))
assert model is not None
assert model.is_active is True
assert model.status == "active"
assert model.activated_at is not None
def test_activate_deactivates_others(self, patched_session):
"""Test that activating one version deactivates others."""
repo = ModelVersionRepository()
# Create and activate first model
m1 = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
repo.activate(str(m1.version_id))
# Create and activate second model
m2 = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
repo.activate(str(m2.version_id))
# Check first model is now inactive
m1_after = repo.get(str(m1.version_id))
assert m1_after.is_active is False
assert m1_after.status == "inactive"
# Check second model is active
m2_after = repo.get(str(m2.version_id))
assert m2_after.is_active is True
def test_get_active_model(self, patched_session, sample_model_version):
"""Test getting the currently active model."""
repo = ModelVersionRepository()
# Initially no active model
active = repo.get_active()
assert active is None
# Activate model
repo.activate(str(sample_model_version.version_id))
# Now should return active model
active = repo.get_active()
assert active is not None
assert active.version_id == sample_model_version.version_id
def test_deactivate_model_version(self, patched_session, sample_model_version):
"""Test deactivating a model version."""
repo = ModelVersionRepository()
# First activate
repo.activate(str(sample_model_version.version_id))
# Then deactivate
model = repo.deactivate(str(sample_model_version.version_id))
assert model is not None
assert model.is_active is False
assert model.status == "inactive"
class TestModelVersionUpdate:
"""Tests for model version updates."""
def test_update_model_metadata(self, patched_session, sample_model_version):
"""Test updating model version metadata."""
repo = ModelVersionRepository()
model = repo.update(
str(sample_model_version.version_id),
name="Updated Model Name",
description="Updated description",
)
assert model is not None
assert model.name == "Updated Model Name"
assert model.description == "Updated description"
def test_update_model_status(self, patched_session, sample_model_version):
"""Test updating model version status."""
repo = ModelVersionRepository()
model = repo.update(str(sample_model_version.version_id), status="deprecated")
assert model is not None
assert model.status == "deprecated"
def test_update_nonexistent_model(self, patched_session):
"""Test updating model that doesn't exist."""
repo = ModelVersionRepository()
model = repo.update(str(uuid4()), name="New Name")
assert model is None
class TestModelVersionArchive:
"""Tests for model version archiving."""
def test_archive_model_version(self, patched_session, sample_model_version):
"""Test archiving an inactive model version."""
repo = ModelVersionRepository()
model = repo.archive(str(sample_model_version.version_id))
assert model is not None
assert model.status == "archived"
def test_cannot_archive_active_model(self, patched_session, sample_model_version):
"""Test that active model cannot be archived."""
repo = ModelVersionRepository()
# Activate the model
repo.activate(str(sample_model_version.version_id))
# Try to archive
model = repo.archive(str(sample_model_version.version_id))
assert model is None
# Verify model is still active
current = repo.get(str(sample_model_version.version_id))
assert current.status == "active"
class TestModelVersionDelete:
"""Tests for model version deletion."""
def test_delete_inactive_model(self, patched_session, sample_model_version):
"""Test deleting an inactive model version."""
repo = ModelVersionRepository()
result = repo.delete(str(sample_model_version.version_id))
assert result is True
model = repo.get(str(sample_model_version.version_id))
assert model is None
def test_cannot_delete_active_model(self, patched_session, sample_model_version):
"""Test that active model cannot be deleted."""
repo = ModelVersionRepository()
# Activate the model
repo.activate(str(sample_model_version.version_id))
# Try to delete
result = repo.delete(str(sample_model_version.version_id))
assert result is False
# Verify model still exists
model = repo.get(str(sample_model_version.version_id))
assert model is not None
def test_delete_nonexistent_model(self, patched_session):
"""Test deleting model that doesn't exist."""
repo = ModelVersionRepository()
result = repo.delete(str(uuid4()))
assert result is False
class TestOnlyOneActiveModel:
"""Tests to verify only one model can be active at a time."""
def test_single_active_model_constraint(self, patched_session):
"""Test that only one model can be active at any time."""
repo = ModelVersionRepository()
# Create multiple models
models = []
for i in range(3):
m = repo.create(
version=f"1.{i}.0",
name=f"Model {i}",
model_path=f"/models/model_{i}.pt",
)
models.append(m)
# Activate each model in sequence
for model in models:
repo.activate(str(model.version_id))
# Count active models
all_models, _ = repo.get_paginated(status="active")
assert len(all_models) == 1
# Verify it's the last one activated
assert all_models[0].version_id == models[-1].version_id


@@ -0,0 +1,274 @@
"""
Token Repository Integration Tests
Tests TokenRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
import pytest
from inference.data.repositories.token_repository import TokenRepository
class TestTokenCreate:
"""Tests for token creation."""
def test_create_new_token(self, patched_session):
"""Test creating a new admin token."""
repo = TokenRepository()
repo.create(
token="new-test-token-abc123",
name="New Test Admin",
)
token = repo.get("new-test-token-abc123")
assert token is not None
assert token.token == "new-test-token-abc123"
assert token.name == "New Test Admin"
assert token.is_active is True
assert token.expires_at is None
def test_create_token_with_expiration(self, patched_session):
"""Test creating token with expiration date."""
repo = TokenRepository()
expiry = datetime.now(timezone.utc) + timedelta(days=30)
repo.create(
token="expiring-token-xyz789",
name="Expiring Token",
expires_at=expiry,
)
token = repo.get("expiring-token-xyz789")
assert token is not None
assert token.expires_at is not None
def test_create_updates_existing_token(self, patched_session, admin_token):
"""Test that create() with an existing token updates it in place."""
repo = TokenRepository()
new_expiry = datetime.now(timezone.utc) + timedelta(days=60)
repo.create(
token=admin_token.token,
name="Updated Admin Name",
expires_at=new_expiry,
)
token = repo.get(admin_token.token)
assert token is not None
assert token.name == "Updated Admin Name"
assert token.is_active is True
class TestTokenValidation:
"""Tests for token validation."""
def test_is_valid_active_token(self, patched_session, admin_token):
"""Test that active token is valid."""
repo = TokenRepository()
result = repo.is_valid(admin_token.token)
assert result is True
def test_is_valid_nonexistent_token(self, patched_session):
"""Test that nonexistent token is invalid."""
repo = TokenRepository()
result = repo.is_valid("nonexistent-token-12345")
assert result is False
def test_is_valid_deactivated_token(self, patched_session, admin_token):
"""Test that deactivated token is invalid."""
repo = TokenRepository()
repo.deactivate(admin_token.token)
result = repo.is_valid(admin_token.token)
assert result is False
def test_is_valid_expired_token(self, patched_session):
"""Test that expired token is invalid."""
repo = TokenRepository()
past_expiry = datetime.now(timezone.utc) - timedelta(days=1)
repo.create(
token="expired-token-test",
name="Expired Token",
expires_at=past_expiry,
)
result = repo.is_valid("expired-token-test")
assert result is False
def test_is_valid_not_yet_expired_token(self, patched_session):
"""Test that not-yet-expired token is valid."""
repo = TokenRepository()
future_expiry = datetime.now(timezone.utc) + timedelta(days=7)
repo.create(
token="valid-expiring-token",
name="Valid Expiring Token",
expires_at=future_expiry,
)
result = repo.is_valid("valid-expiring-token")
assert result is True
class TestTokenGet:
"""Tests for token retrieval."""
def test_get_existing_token(self, patched_session, admin_token):
"""Test getting an existing token."""
repo = TokenRepository()
token = repo.get(admin_token.token)
assert token is not None
assert token.token == admin_token.token
assert token.name == admin_token.name
def test_get_nonexistent_token(self, patched_session):
"""Test getting a token that doesn't exist."""
repo = TokenRepository()
token = repo.get("nonexistent-token-xyz")
assert token is None
class TestTokenDeactivate:
"""Tests for token deactivation."""
def test_deactivate_existing_token(self, patched_session, admin_token):
"""Test deactivating an existing token."""
repo = TokenRepository()
result = repo.deactivate(admin_token.token)
assert result is True
token = repo.get(admin_token.token)
assert token is not None
assert token.is_active is False
def test_deactivate_nonexistent_token(self, patched_session):
"""Test deactivating a token that doesn't exist."""
repo = TokenRepository()
result = repo.deactivate("nonexistent-token-abc")
assert result is False
def test_reactivate_deactivated_token(self, patched_session, admin_token):
"""Test reactivating a deactivated token via create."""
repo = TokenRepository()
# Deactivate first
repo.deactivate(admin_token.token)
assert repo.is_valid(admin_token.token) is False
# Reactivate via create
repo.create(
token=admin_token.token,
name="Reactivated Admin",
)
assert repo.is_valid(admin_token.token) is True
class TestTokenUsageTracking:
"""Tests for token usage tracking."""
def test_update_usage(self, patched_session, admin_token):
"""Test updating token last used timestamp."""
repo = TokenRepository()
# Initially last_used_at might be None
initial_token = repo.get(admin_token.token)
initial_last_used = initial_token.last_used_at
repo.update_usage(admin_token.token)
updated_token = repo.get(admin_token.token)
assert updated_token.last_used_at is not None
if initial_last_used:
assert updated_token.last_used_at >= initial_last_used
def test_update_usage_nonexistent_token(self, patched_session):
"""Test that updating usage for a nonexistent token does nothing."""
repo = TokenRepository()
# Should not raise, just does nothing
repo.update_usage("nonexistent-token-usage")
token = repo.get("nonexistent-token-usage")
assert token is None
class TestTokenWorkflow:
"""Tests for complete token workflows."""
def test_full_token_lifecycle(self, patched_session):
"""Test complete token lifecycle: create, validate, use, deactivate."""
repo = TokenRepository()
token_str = "lifecycle-test-token"
# 1. Create token
repo.create(token=token_str, name="Lifecycle Token")
assert repo.is_valid(token_str) is True
# 2. Use token
repo.update_usage(token_str)
token = repo.get(token_str)
assert token.last_used_at is not None
# 3. Update token info
new_expiry = datetime.now(timezone.utc) + timedelta(days=90)
repo.create(
token=token_str,
name="Updated Lifecycle Token",
expires_at=new_expiry,
)
token = repo.get(token_str)
assert token.name == "Updated Lifecycle Token"
# 4. Deactivate token
result = repo.deactivate(token_str)
assert result is True
assert repo.is_valid(token_str) is False
# 5. Reactivate token
repo.create(token=token_str, name="Reactivated Token")
assert repo.is_valid(token_str) is True
def test_multiple_tokens(self, patched_session):
"""Test managing multiple tokens."""
repo = TokenRepository()
# Create multiple tokens
tokens = [
("token-a", "Admin A"),
("token-b", "Admin B"),
("token-c", "Admin C"),
]
for token_str, name in tokens:
repo.create(token=token_str, name=name)
# Verify all are valid
for token_str, _ in tokens:
assert repo.is_valid(token_str) is True
# Deactivate one
repo.deactivate("token-b")
# Verify states
assert repo.is_valid("token-a") is True
assert repo.is_valid("token-b") is False
assert repo.is_valid("token-c") is True
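
The validation tests cover the three rejection paths (unknown, deactivated, expired). A minimal request-side check built from is_valid() and update_usage(); everything around it is illustrative:

# Sketch: gatekeeper built from the two calls exercised above.
def authenticate(repo, token: str) -> bool:
    if not repo.is_valid(token):   # unknown, deactivated, or expired
        return False
    repo.update_usage(token)       # record last_used_at for auditing
    return True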


@@ -0,0 +1,364 @@
"""
Training Task Repository Integration Tests
Tests TrainingTaskRepository with real database operations.
"""
from datetime import datetime, timezone, timedelta
from uuid import uuid4
import pytest
from inference.data.repositories.training_task_repository import TrainingTaskRepository
class TestTrainingTaskCreate:
"""Tests for training task creation."""
def test_create_training_task(self, patched_session, admin_token):
"""Test creating a training task."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Test Training Task",
task_type="train",
description="Integration test training task",
config={"epochs": 100, "batch_size": 16},
)
assert task_id is not None
task = repo.get(task_id)
assert task is not None
assert task.name == "Test Training Task"
assert task.task_type == "train"
assert task.status == "pending"
assert task.config["epochs"] == 100
def test_create_scheduled_task(self, patched_session, admin_token):
"""Test creating a scheduled training task."""
repo = TrainingTaskRepository()
scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)
task_id = repo.create(
admin_token=admin_token.token,
name="Scheduled Task",
scheduled_at=scheduled_time,
)
task = repo.get(task_id)
assert task is not None
assert task.status == "scheduled"
assert task.scheduled_at is not None
def test_create_recurring_task(self, patched_session, admin_token):
"""Test creating a recurring training task."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Recurring Task",
cron_expression="0 2 * * *",
is_recurring=True,
)
task = repo.get(task_id)
assert task is not None
assert task.is_recurring is True
assert task.cron_expression == "0 2 * * *"
def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
"""Test creating task linked to a dataset."""
repo = TrainingTaskRepository()
task_id = repo.create(
admin_token=admin_token.token,
name="Dataset Training Task",
dataset_id=str(sample_dataset.dataset_id),
)
task = repo.get(task_id)
assert task is not None
assert task.dataset_id == sample_dataset.dataset_id
class TestTrainingTaskRead:
"""Tests for training task retrieval."""
def test_get_task_by_id(self, patched_session, sample_training_task):
"""Test getting task by ID."""
repo = TrainingTaskRepository()
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.task_id == sample_training_task.task_id
def test_get_nonexistent_task(self, patched_session):
"""Test getting task that doesn't exist."""
repo = TrainingTaskRepository()
task = repo.get(str(uuid4()))
assert task is None
def test_get_paginated_tasks(self, patched_session, admin_token):
"""Test paginated task listing."""
repo = TrainingTaskRepository()
# Create multiple tasks
for i in range(5):
repo.create(admin_token=admin_token.token, name=f"Task {i}")
tasks, total = repo.get_paginated(limit=2, offset=0)
assert total == 5
assert len(tasks) == 2
def test_get_paginated_with_status_filter(self, patched_session, admin_token):
"""Test filtering tasks by status."""
repo = TrainingTaskRepository()
# Create tasks with different statuses
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
repo.update_status(task_id, "running")
repo.create(admin_token=admin_token.token, name="Pending Task")
tasks, total = repo.get_paginated(status="running")
assert total == 1
assert tasks[0].status == "running"
def test_get_pending_tasks(self, patched_session, admin_token):
"""Test getting pending tasks ready to run."""
repo = TrainingTaskRepository()
# Create pending task
repo.create(admin_token=admin_token.token, name="Ready Task")
# Create scheduled task in the past (should be included)
past_time = datetime.now(timezone.utc) - timedelta(hours=1)
repo.create(
admin_token=admin_token.token,
name="Past Scheduled Task",
scheduled_at=past_time,
)
# Create scheduled task in the future (should not be included)
future_time = datetime.now(timezone.utc) + timedelta(hours=1)
repo.create(
admin_token=admin_token.token,
name="Future Scheduled Task",
scheduled_at=future_time,
)
pending = repo.get_pending()
# Should include pending and past scheduled, not future scheduled
assert len(pending) >= 2
names = [t.name for t in pending]
assert "Ready Task" in names
assert "Past Scheduled Task" in names
def test_get_running_task(self, patched_session, admin_token):
"""Test getting currently running task."""
repo = TrainingTaskRepository()
task_id = repo.create(admin_token=admin_token.token, name="Running Task")
repo.update_status(task_id, "running")
running = repo.get_running()
assert running is not None
assert running.status == "running"
def test_get_running_task_none(self, patched_session, admin_token):
"""Test getting running task when none is running."""
repo = TrainingTaskRepository()
repo.create(admin_token=admin_token.token, name="Pending Task")
running = repo.get_running()
assert running is None
class TestTrainingTaskUpdate:
"""Tests for training task updates."""
def test_update_status_to_running(self, patched_session, sample_training_task):
"""Test updating task status to running."""
repo = TrainingTaskRepository()
repo.update_status(str(sample_training_task.task_id), "running")
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "running"
assert task.started_at is not None
def test_update_status_to_completed(self, patched_session, sample_training_task):
"""Test updating task status to completed."""
repo = TrainingTaskRepository()
metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}
repo.update_status(
str(sample_training_task.task_id),
"completed",
result_metrics=metrics,
model_path="/models/trained_model.pt",
)
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "completed"
assert task.completed_at is not None
assert task.result_metrics["mAP"] == 0.92
assert task.model_path == "/models/trained_model.pt"
def test_update_status_to_failed(self, patched_session, sample_training_task):
"""Test updating task status to failed with error message."""
repo = TrainingTaskRepository()
repo.update_status(
str(sample_training_task.task_id),
"failed",
error_message="CUDA out of memory",
)
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "failed"
assert task.completed_at is not None
assert "CUDA out of memory" in task.error_message
def test_cancel_pending_task(self, patched_session, sample_training_task):
"""Test cancelling a pending task."""
repo = TrainingTaskRepository()
result = repo.cancel(str(sample_training_task.task_id))
assert result is True
task = repo.get(str(sample_training_task.task_id))
assert task is not None
assert task.status == "cancelled"
def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
"""Test that running task cannot be cancelled."""
repo = TrainingTaskRepository()
repo.update_status(str(sample_training_task.task_id), "running")
result = repo.cancel(str(sample_training_task.task_id))
assert result is False
task = repo.get(str(sample_training_task.task_id))
assert task.status == "running"
class TestTrainingLogs:
"""Tests for training log management."""
def test_add_log_entry(self, patched_session, sample_training_task):
"""Test adding a training log entry."""
repo = TrainingTaskRepository()
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message="Starting training...",
details={"epoch": 1, "batch": 0},
)
logs = repo.get_logs(str(sample_training_task.task_id))
assert len(logs) == 1
assert logs[0].level == "INFO"
assert logs[0].message == "Starting training..."
def test_add_multiple_log_entries(self, patched_session, sample_training_task):
"""Test adding multiple log entries."""
repo = TrainingTaskRepository()
for i in range(5):
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message=f"Epoch {i} completed",
details={"epoch": i, "loss": 0.5 - i * 0.1},
)
logs = repo.get_logs(str(sample_training_task.task_id))
assert len(logs) == 5
def test_get_logs_pagination(self, patched_session, sample_training_task):
"""Test paginated log retrieval."""
repo = TrainingTaskRepository()
for i in range(10):
repo.add_log(
str(sample_training_task.task_id),
level="INFO",
message=f"Log entry {i}",
)
logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0)
assert len(logs) == 5
logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5)
assert len(logs_page2) == 5
class TestDocumentLinks:
"""Tests for training document link management."""
def test_create_document_link(self, patched_session, sample_training_task, sample_document):
"""Test creating a document link."""
repo = TrainingTaskRepository()
link = repo.create_document_link(
task_id=sample_training_task.task_id,
document_id=sample_document.document_id,
annotation_snapshot={"count": 5, "verified": 3},
)
assert link is not None
assert link.task_id == sample_training_task.task_id
assert link.document_id == sample_document.document_id
assert link.annotation_snapshot["count"] == 5
def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
"""Test getting all document links for a task."""
repo = TrainingTaskRepository()
for doc in multiple_documents[:3]:
repo.create_document_link(
task_id=sample_training_task.task_id,
document_id=doc.document_id,
)
links = repo.get_document_links(sample_training_task.task_id)
assert len(links) == 3
def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
"""Test getting training tasks that used a document."""
repo = TrainingTaskRepository()
# Create multiple tasks using the same document
task1_id = repo.create(admin_token=admin_token.token, name="Task 1")
task2_id = repo.create(admin_token=admin_token.token, name="Task 2")
repo.create_document_link(
task_id=repo.get(task1_id).task_id,
document_id=sample_document.document_id,
)
repo.create_document_link(
task_id=repo.get(task2_id).task_id,
document_id=sample_document.document_id,
)
links = repo.get_document_training_tasks(sample_document.document_id)
assert len(links) == 2
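
Taken together, get_pending(), get_running(), update_status() and add_log() suggest a single-worker scheduler: pick up due tasks only while nothing is running, then report progress through the log table. A polling-loop sketch under that reading; run_task and the polling interval are placeholders:

# Sketch: a worker loop consuming tasks the way the tests imply.
import time

def run_scheduler(repo, run_task, poll_seconds: int = 30):
    while True:
        if repo.get_running() is None:       # at most one task runs at a time
            for task in repo.get_pending():  # pending plus past-due scheduled tasks
                repo.update_status(str(task.task_id), "running")
                repo.add_log(str(task.task_id), level="INFO", message="Task started")
                try:
                    metrics, model_path = run_task(task)  # placeholder trainer
                    repo.update_status(str(task.task_id), "completed",
                                       result_metrics=metrics, model_path=model_path)
                except Exception as exc:
                    repo.update_status(str(task.task_id), "failed", error_message=str(exc))
                break  # handle one task, then re-check get_running()
        time.sleep(poll_seconds)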