Add more tests
1
tests/integration/repositories/__init__.py
Normal file
@@ -0,0 +1 @@
"""Repository integration tests."""
@@ -0,0 +1,464 @@
"""
Annotation Repository Integration Tests

Tests AnnotationRepository with real database operations.
"""

from uuid import uuid4

import pytest

from inference.data.repositories.annotation_repository import AnnotationRepository


class TestAnnotationRepositoryCreate:
    """Tests for annotation creation."""

    def test_create_annotation(self, patched_session, sample_document):
        """Test creating a single annotation."""
        repo = AnnotationRepository()

        ann_id = repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.3,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=240,
            bbox_width=160,
            bbox_height=40,
            text_value="INV-2024-001",
            confidence=0.95,
            source="auto",
        )

        assert ann_id is not None

        ann = repo.get(ann_id)
        assert ann is not None
        assert ann.class_name == "invoice_number"
        assert ann.text_value == "INV-2024-001"
        assert ann.confidence == 0.95
        assert ann.source == "auto"

    def test_create_batch_annotations(self, patched_session, sample_document):
        """Test batch creation of annotations."""
        repo = AnnotationRepository()

        annotations_data = [
            {
                "document_id": str(sample_document.document_id),
                "page_number": 1,
                "class_id": 0,
                "class_name": "invoice_number",
                "x_center": 0.5,
                "y_center": 0.1,
                "width": 0.2,
                "height": 0.05,
                "bbox_x": 400,
                "bbox_y": 80,
                "bbox_width": 160,
                "bbox_height": 40,
                "text_value": "INV-001",
                "confidence": 0.95,
            },
            {
                "document_id": str(sample_document.document_id),
                "page_number": 1,
                "class_id": 1,
                "class_name": "invoice_date",
                "x_center": 0.5,
                "y_center": 0.2,
                "width": 0.15,
                "height": 0.04,
                "bbox_x": 400,
                "bbox_y": 160,
                "bbox_width": 120,
                "bbox_height": 32,
                "text_value": "2024-01-15",
                "confidence": 0.92,
            },
            {
                "document_id": str(sample_document.document_id),
                "page_number": 1,
                "class_id": 6,
                "class_name": "amount",
                "x_center": 0.7,
                "y_center": 0.8,
                "width": 0.1,
                "height": 0.04,
                "bbox_x": 560,
                "bbox_y": 640,
                "bbox_width": 80,
                "bbox_height": 32,
                "text_value": "1500.00",
                "confidence": 0.98,
            },
        ]

        ids = repo.create_batch(annotations_data)

        assert len(ids) == 3

        # Verify all annotations exist
        for ann_id in ids:
            ann = repo.get(ann_id)
            assert ann is not None


class TestAnnotationRepositoryRead:
    """Tests for annotation retrieval."""

    def test_get_nonexistent_annotation(self, patched_session):
        """Test getting an annotation that doesn't exist."""
        repo = AnnotationRepository()

        ann = repo.get(str(uuid4()))
        assert ann is None

    def test_get_annotations_for_document(self, patched_session, sample_document, sample_annotation):
        """Test getting all annotations for a document."""
        repo = AnnotationRepository()

        # Add another annotation
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=1,
            class_name="invoice_date",
            x_center=0.5,
            y_center=0.4,
            width=0.15,
            height=0.04,
            bbox_x=400,
            bbox_y=320,
            bbox_width=120,
            bbox_height=32,
            text_value="2024-01-15",
        )

        annotations = repo.get_for_document(str(sample_document.document_id))

        assert len(annotations) == 2
        # Should be ordered by class_id
        assert annotations[0].class_id == 0
        assert annotations[1].class_id == 1

    def test_get_annotations_for_specific_page(self, patched_session, sample_document):
        """Test getting annotations for a specific page."""
        repo = AnnotationRepository()

        # Create annotations on different pages
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
        )
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=2,
            class_id=6,
            class_name="amount",
            x_center=0.7,
            y_center=0.8,
            width=0.1,
            height=0.04,
            bbox_x=560,
            bbox_y=640,
            bbox_width=80,
            bbox_height=32,
        )

        page1_annotations = repo.get_for_document(
            str(sample_document.document_id),
            page_number=1,
        )
        page2_annotations = repo.get_for_document(
            str(sample_document.document_id),
            page_number=2,
        )

        assert len(page1_annotations) == 1
        assert len(page2_annotations) == 1
        assert page1_annotations[0].page_number == 1
        assert page2_annotations[0].page_number == 2


class TestAnnotationRepositoryUpdate:
    """Tests for annotation updates."""

    def test_update_annotation_bbox(self, patched_session, sample_annotation):
        """Test updating annotation bounding box."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            x_center=0.6,
            y_center=0.4,
            width=0.25,
            height=0.06,
            bbox_x=480,
            bbox_y=320,
            bbox_width=200,
            bbox_height=48,
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.x_center == 0.6
        assert ann.y_center == 0.4
        assert ann.bbox_x == 480
        assert ann.bbox_width == 200

    def test_update_annotation_text(self, patched_session, sample_annotation):
        """Test updating annotation text value."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            text_value="INV-2024-002",
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.text_value == "INV-2024-002"

    def test_update_annotation_class(self, patched_session, sample_annotation):
        """Test updating annotation class."""
        repo = AnnotationRepository()

        result = repo.update(
            str(sample_annotation.annotation_id),
            class_id=1,
            class_name="invoice_date",
        )

        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is not None
        assert ann.class_id == 1
        assert ann.class_name == "invoice_date"

    def test_update_nonexistent_annotation(self, patched_session):
        """Test updating annotation that doesn't exist."""
        repo = AnnotationRepository()

        result = repo.update(
            str(uuid4()),
            text_value="new value",
        )

        assert result is False


class TestAnnotationRepositoryDelete:
    """Tests for annotation deletion."""

    def test_delete_annotation(self, patched_session, sample_annotation):
        """Test deleting a single annotation."""
        repo = AnnotationRepository()

        result = repo.delete(str(sample_annotation.annotation_id))
        assert result is True

        ann = repo.get(str(sample_annotation.annotation_id))
        assert ann is None

    def test_delete_nonexistent_annotation(self, patched_session):
        """Test deleting annotation that doesn't exist."""
        repo = AnnotationRepository()

        result = repo.delete(str(uuid4()))
        assert result is False

    def test_delete_annotations_for_document(self, patched_session, sample_document):
        """Test deleting all annotations for a document."""
        repo = AnnotationRepository()

        # Create multiple annotations
        for i in range(3):
            repo.create(
                document_id=str(sample_document.document_id),
                page_number=1,
                class_id=i,
                class_name=f"field_{i}",
                x_center=0.5,
                y_center=0.1 + i * 0.2,
                width=0.2,
                height=0.05,
                bbox_x=400,
                bbox_y=80 + i * 160,
                bbox_width=160,
                bbox_height=40,
            )

        # Delete all
        count = repo.delete_for_document(str(sample_document.document_id))

        assert count == 3

        annotations = repo.get_for_document(str(sample_document.document_id))
        assert len(annotations) == 0

    def test_delete_annotations_by_source(self, patched_session, sample_document):
        """Test deleting annotations by source type."""
        repo = AnnotationRepository()

        # Create auto and manual annotations
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=0,
            class_name="invoice_number",
            x_center=0.5,
            y_center=0.1,
            width=0.2,
            height=0.05,
            bbox_x=400,
            bbox_y=80,
            bbox_width=160,
            bbox_height=40,
            source="auto",
        )
        repo.create(
            document_id=str(sample_document.document_id),
            page_number=1,
            class_id=1,
            class_name="invoice_date",
            x_center=0.5,
            y_center=0.2,
            width=0.15,
            height=0.04,
            bbox_x=400,
            bbox_y=160,
            bbox_width=120,
            bbox_height=32,
            source="manual",
        )

        # Delete only auto annotations
        count = repo.delete_for_document(str(sample_document.document_id), source="auto")

        assert count == 1

        remaining = repo.get_for_document(str(sample_document.document_id))
        assert len(remaining) == 1
        assert remaining[0].source == "manual"


class TestAnnotationVerification:
    """Tests for annotation verification."""

    def test_verify_annotation(self, patched_session, admin_token, sample_annotation):
        """Test marking annotation as verified."""
        repo = AnnotationRepository()

        ann = repo.verify(str(sample_annotation.annotation_id), admin_token.token)

        assert ann is not None
        assert ann.is_verified is True
        assert ann.verified_by == admin_token.token
        assert ann.verified_at is not None


class TestAnnotationOverride:
    """Tests for annotation override functionality."""

    def test_override_auto_annotation(self, patched_session, admin_token, sample_annotation):
        """Test overriding an auto-generated annotation."""
        repo = AnnotationRepository()

        # Override the annotation
        ann = repo.override(
            str(sample_annotation.annotation_id),
            admin_token.token,
            change_reason="Correcting OCR error",
            text_value="INV-2024-CORRECTED",
            x_center=0.55,
        )

        assert ann is not None
        assert ann.text_value == "INV-2024-CORRECTED"
        assert ann.x_center == 0.55
        assert ann.source == "manual"  # Changed from auto to manual
        assert ann.override_source == "auto"


class TestAnnotationHistory:
    """Tests for annotation history tracking."""

    def test_create_history_record(self, patched_session, sample_annotation):
        """Test creating annotation history record."""
        repo = AnnotationRepository()

        history = repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
            changed_by="test-user",
        )

        assert history is not None
        assert history.action == "created"
        assert history.changed_by == "test-user"

    def test_get_annotation_history(self, patched_session, sample_annotation):
        """Test getting history for an annotation."""
        repo = AnnotationRepository()

        # Create history records
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="created",
            new_value={"text_value": "INV-001"},
        )
        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_annotation.document_id,
            action="updated",
            previous_value={"text_value": "INV-001"},
            new_value={"text_value": "INV-002"},
        )

        history = repo.get_history(sample_annotation.annotation_id)

        assert len(history) == 2
        # Should be ordered by created_at desc
        assert history[0].action == "updated"
        assert history[1].action == "created"

    def test_get_document_history(self, patched_session, sample_document, sample_annotation):
        """Test getting all annotation history for a document."""
        repo = AnnotationRepository()

        repo.create_history(
            annotation_id=sample_annotation.annotation_id,
            document_id=sample_document.document_id,
            action="created",
            new_value={"class_name": "invoice_number"},
        )

        history = repo.get_document_history(sample_document.document_id)

        assert len(history) >= 1
        assert all(h.document_id == sample_document.document_id for h in history)
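Note: the paired coordinate fields in these fixtures are mutually consistent under one reading: each normalized YOLO-style value times an assumed 800 px page size reproduces the corresponding bbox_* value, with bbox_x/bbox_y holding the box center in pixels. The 800 px figure is inferred from the numbers themselves, not taken from the schema. A quick self-contained check:

# PAGE_SIZE is an assumption inferred from the fixture values above.
PAGE_SIZE = 800

norm = {"x_center": 0.5, "y_center": 0.3, "width": 0.2, "height": 0.05}
pixel = {"bbox_x": 400, "bbox_y": 240, "bbox_width": 160, "bbox_height": 40}

for n_key, p_key in [
    ("x_center", "bbox_x"),
    ("y_center", "bbox_y"),
    ("width", "bbox_width"),
    ("height", "bbox_height"),
]:
    # Every pixel field equals its normalized counterpart scaled by PAGE_SIZE.
    assert round(norm[n_key] * PAGE_SIZE) == pixel[p_key]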
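Note: the fixtures these tests lean on (patched_session, sample_document, sample_annotation, admin_token, and friends) are defined outside this commit. A minimal conftest.py sketch of what they might look like, assuming the repositories resolve their session through a patchable factory; the patch target and field values here are illustrative guesses, not the project's actual code:

from unittest.mock import patch

import pytest
from sqlmodel import Session, SQLModel, create_engine


@pytest.fixture
def patched_session():
    # Hypothetical: point the repository layer at a throwaway in-memory DB.
    engine = create_engine("sqlite://")
    SQLModel.metadata.create_all(engine)
    with Session(engine) as session:
        # "inference.data.repositories.base.get_session" is an assumed patch
        # target; the real suite may wire this differently.
        with patch(
            "inference.data.repositories.base.get_session",
            return_value=session,
        ):
            yield session


@pytest.fixture
def sample_document(patched_session):
    # AdminDocument is imported by the document tests below, so the model
    # name is real; the field values here are illustrative only.
    from inference.data.admin_models import AdminDocument

    doc = AdminDocument(
        filename="sample.pdf",
        file_size=1024,
        content_type="application/pdf",
        file_path="/uploads/sample.pdf",
    )
    patched_session.add(doc)
    patched_session.commit()
    patched_session.refresh(doc)
    return doc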
@@ -0,0 +1,355 @@
"""
Batch Upload Repository Integration Tests

Tests BatchUploadRepository with real database operations.
"""

from datetime import datetime, timezone
from uuid import uuid4

import pytest

from inference.data.repositories.batch_upload_repository import BatchUploadRepository


class TestBatchUploadCreate:
    """Tests for batch upload creation."""

    def test_create_batch_upload(self, patched_session, admin_token):
        """Test creating a batch upload."""
        repo = BatchUploadRepository()

        batch = repo.create(
            admin_token=admin_token.token,
            filename="test_batch.zip",
            file_size=10240,
            upload_source="api",
        )

        assert batch is not None
        assert batch.batch_id is not None
        assert batch.filename == "test_batch.zip"
        assert batch.file_size == 10240
        assert batch.upload_source == "api"
        assert batch.status == "processing"
        assert batch.total_files == 0
        assert batch.processed_files == 0

    def test_create_batch_upload_default_source(self, patched_session, admin_token):
        """Test creating batch upload with default source."""
        repo = BatchUploadRepository()

        batch = repo.create(
            admin_token=admin_token.token,
            filename="ui_batch.zip",
            file_size=5120,
        )

        assert batch.upload_source == "ui"


class TestBatchUploadRead:
    """Tests for batch upload retrieval."""

    def test_get_batch_upload(self, patched_session, sample_batch_upload):
        """Test getting a batch upload by ID."""
        repo = BatchUploadRepository()

        batch = repo.get(sample_batch_upload.batch_id)

        assert batch is not None
        assert batch.batch_id == sample_batch_upload.batch_id
        assert batch.filename == sample_batch_upload.filename

    def test_get_nonexistent_batch_upload(self, patched_session):
        """Test getting a batch upload that doesn't exist."""
        repo = BatchUploadRepository()

        batch = repo.get(uuid4())
        assert batch is None

    def test_get_paginated_batch_uploads(self, patched_session, admin_token):
        """Test paginated batch upload listing."""
        repo = BatchUploadRepository()

        # Create multiple batches
        for i in range(5):
            repo.create(
                admin_token=admin_token.token,
                filename=f"batch_{i}.zip",
                file_size=1024 * (i + 1),
            )

        batches, total = repo.get_paginated(limit=3, offset=0)

        assert total == 5
        assert len(batches) == 3

    def test_get_paginated_with_offset(self, patched_session, admin_token):
        """Test pagination offset."""
        repo = BatchUploadRepository()

        for i in range(5):
            repo.create(
                admin_token=admin_token.token,
                filename=f"batch_{i}.zip",
                file_size=1024,
            )

        page1, _ = repo.get_paginated(limit=2, offset=0)
        page2, _ = repo.get_paginated(limit=2, offset=2)

        ids_page1 = {b.batch_id for b in page1}
        ids_page2 = {b.batch_id for b in page2}

        assert len(ids_page1 & ids_page2) == 0


class TestBatchUploadUpdate:
    """Tests for batch upload updates."""

    def test_update_batch_status(self, patched_session, sample_batch_upload):
        """Test updating batch upload status."""
        repo = BatchUploadRepository()

        repo.update(
            sample_batch_upload.batch_id,
            status="completed",
            total_files=10,
            processed_files=10,
            successful_files=8,
            failed_files=2,
        )

        # Need to commit to see changes
        patched_session.commit()

        batch = repo.get(sample_batch_upload.batch_id)
        assert batch.status == "completed"
        assert batch.total_files == 10
        assert batch.successful_files == 8
        assert batch.failed_files == 2

    def test_update_batch_with_error(self, patched_session, sample_batch_upload):
        """Test updating batch upload with error message."""
        repo = BatchUploadRepository()

        repo.update(
            sample_batch_upload.batch_id,
            status="failed",
            error_message="ZIP extraction failed",
        )

        patched_session.commit()

        batch = repo.get(sample_batch_upload.batch_id)
        assert batch.status == "failed"
        assert batch.error_message == "ZIP extraction failed"

    def test_update_batch_csv_info(self, patched_session, sample_batch_upload):
        """Test updating batch with CSV information."""
        repo = BatchUploadRepository()

        repo.update(
            sample_batch_upload.batch_id,
            csv_filename="manifest.csv",
            csv_row_count=100,
        )

        patched_session.commit()

        batch = repo.get(sample_batch_upload.batch_id)
        assert batch.csv_filename == "manifest.csv"
        assert batch.csv_row_count == 100


class TestBatchUploadFiles:
    """Tests for batch upload file management."""

    def test_create_batch_file(self, patched_session, sample_batch_upload):
        """Test creating a batch upload file record."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="invoice_001.pdf",
            status="pending",
        )

        assert file_record is not None
        assert file_record.file_id is not None
        assert file_record.filename == "invoice_001.pdf"
        assert file_record.batch_id == sample_batch_upload.batch_id
        assert file_record.status == "pending"

    def test_create_batch_file_with_document_link(self, patched_session, sample_batch_upload, sample_document):
        """Test creating batch file linked to a document."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="invoice_linked.pdf",
            document_id=sample_document.document_id,
            status="completed",
            annotation_count=5,
        )

        assert file_record.document_id == sample_document.document_id
        assert file_record.status == "completed"
        assert file_record.annotation_count == 5

    def test_get_batch_files(self, patched_session, sample_batch_upload):
        """Test getting all files for a batch."""
        repo = BatchUploadRepository()

        # Create multiple files
        for i in range(3):
            repo.create_file(
                batch_id=sample_batch_upload.batch_id,
                filename=f"file_{i}.pdf",
            )

        files = repo.get_files(sample_batch_upload.batch_id)

        assert len(files) == 3
        assert all(f.batch_id == sample_batch_upload.batch_id for f in files)

    def test_get_batch_files_empty(self, patched_session, sample_batch_upload):
        """Test getting files for batch with no files."""
        repo = BatchUploadRepository()

        files = repo.get_files(sample_batch_upload.batch_id)

        assert files == []

    def test_update_batch_file_status(self, patched_session, sample_batch_upload):
        """Test updating batch file status."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="test.pdf",
        )

        repo.update_file(
            file_record.file_id,
            status="completed",
            annotation_count=10,
        )

        patched_session.commit()

        files = repo.get_files(sample_batch_upload.batch_id)
        updated_file = files[0]
        assert updated_file.status == "completed"
        assert updated_file.annotation_count == 10

    def test_update_batch_file_with_error(self, patched_session, sample_batch_upload):
        """Test updating batch file with error."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="corrupt.pdf",
        )

        repo.update_file(
            file_record.file_id,
            status="failed",
            error_message="Invalid PDF format",
        )

        patched_session.commit()

        files = repo.get_files(sample_batch_upload.batch_id)
        updated_file = files[0]
        assert updated_file.status == "failed"
        assert updated_file.error_message == "Invalid PDF format"

    def test_update_batch_file_with_csv_data(self, patched_session, sample_batch_upload):
        """Test updating batch file with CSV row data."""
        repo = BatchUploadRepository()

        file_record = repo.create_file(
            batch_id=sample_batch_upload.batch_id,
            filename="invoice_with_csv.pdf",
        )

        csv_data = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier": "Test Corp",
        }

        repo.update_file(
            file_record.file_id,
            csv_row_data=csv_data,
        )

        patched_session.commit()

        files = repo.get_files(sample_batch_upload.batch_id)
        updated_file = files[0]
        assert updated_file.csv_row_data == csv_data


class TestBatchUploadWorkflow:
    """Tests for complete batch upload workflows."""

    def test_complete_batch_workflow(self, patched_session, admin_token):
        """Test complete batch upload workflow."""
        repo = BatchUploadRepository()

        # 1. Create batch
        batch = repo.create(
            admin_token=admin_token.token,
            filename="full_workflow.zip",
            file_size=50000,
        )

        # 2. Update with file count
        repo.update(batch.batch_id, total_files=3)
        patched_session.commit()

        # 3. Create file records
        file_ids = []
        for i in range(3):
            file_record = repo.create_file(
                batch_id=batch.batch_id,
                filename=f"doc_{i}.pdf",
            )
            file_ids.append(file_record.file_id)

        # 4. Process files one by one
        for i, file_id in enumerate(file_ids):
            status = "completed" if i < 2 else "failed"
            repo.update_file(
                file_id,
                status=status,
                annotation_count=5 if status == "completed" else 0,
            )

        # 5. Update batch progress
        repo.update(
            batch.batch_id,
            processed_files=3,
            successful_files=2,
            failed_files=1,
            status="partial",
        )
        patched_session.commit()

        # Verify final state
        final_batch = repo.get(batch.batch_id)
        assert final_batch.status == "partial"
        assert final_batch.total_files == 3
        assert final_batch.processed_files == 3
        assert final_batch.successful_files == 2
        assert final_batch.failed_files == 1

        files = repo.get_files(batch.batch_id)
        assert len(files) == 3
        completed = [f for f in files if f.status == "completed"]
        failed = [f for f in files if f.status == "failed"]
        assert len(completed) == 2
        assert len(failed) == 1
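Note: the "# Need to commit to see changes" comments above imply that BatchUploadRepository.update mutates rows on the shared session without committing, leaving transaction control to the caller. A hedged sketch of that pattern (hypothetical helper, assuming SQLModel; not the repository's actual code):

from sqlmodel import Session


def apply_update(session: Session, row, **fields) -> None:
    # Stage the field changes and push the SQL, but keep the transaction
    # open; the caller decides when to commit, which is why the tests call
    # patched_session.commit() before re-reading the row.
    for key, value in fields.items():
        setattr(row, key, value)
    session.add(row)
    session.flush()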
321
tests/integration/repositories/test_dataset_repo_integration.py
Normal file
@@ -0,0 +1,321 @@
"""
Dataset Repository Integration Tests

Tests DatasetRepository with real database operations.
"""

from uuid import uuid4

import pytest

from inference.data.repositories.dataset_repository import DatasetRepository


class TestDatasetRepositoryCreate:
    """Tests for dataset creation."""

    def test_create_dataset(self, patched_session):
        """Test creating a training dataset."""
        repo = DatasetRepository()

        dataset = repo.create(
            name="Test Dataset",
            description="Dataset for integration testing",
            train_ratio=0.8,
            val_ratio=0.1,
            seed=42,
        )

        assert dataset is not None
        assert dataset.name == "Test Dataset"
        assert dataset.description == "Dataset for integration testing"
        assert dataset.train_ratio == 0.8
        assert dataset.val_ratio == 0.1
        assert dataset.seed == 42
        assert dataset.status == "building"

    def test_create_dataset_with_defaults(self, patched_session):
        """Test creating dataset with default values."""
        repo = DatasetRepository()

        dataset = repo.create(name="Minimal Dataset")

        assert dataset is not None
        assert dataset.train_ratio == 0.8
        assert dataset.val_ratio == 0.1
        assert dataset.seed == 42


class TestDatasetRepositoryRead:
    """Tests for dataset retrieval."""

    def test_get_dataset_by_id(self, patched_session, sample_dataset):
        """Test getting dataset by ID."""
        repo = DatasetRepository()

        dataset = repo.get(str(sample_dataset.dataset_id))

        assert dataset is not None
        assert dataset.dataset_id == sample_dataset.dataset_id
        assert dataset.name == sample_dataset.name

    def test_get_nonexistent_dataset(self, patched_session):
        """Test getting dataset that doesn't exist."""
        repo = DatasetRepository()

        dataset = repo.get(str(uuid4()))
        assert dataset is None

    def test_get_paginated_datasets(self, patched_session):
        """Test paginated dataset listing."""
        repo = DatasetRepository()

        # Create multiple datasets
        for i in range(5):
            repo.create(name=f"Dataset {i}")

        datasets, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(datasets) == 2

    def test_get_paginated_with_status_filter(self, patched_session):
        """Test filtering datasets by status."""
        repo = DatasetRepository()

        # Create datasets with different statuses
        d1 = repo.create(name="Building Dataset")
        repo.update_status(str(d1.dataset_id), "ready")

        d2 = repo.create(name="Another Building Dataset")
        # stays as "building"

        datasets, total = repo.get_paginated(status="ready")

        assert total == 1
        assert datasets[0].status == "ready"


class TestDatasetRepositoryUpdate:
    """Tests for dataset updates."""

    def test_update_status(self, patched_session, sample_dataset):
        """Test updating dataset status."""
        repo = DatasetRepository()

        repo.update_status(
            str(sample_dataset.dataset_id),
            status="ready",
            total_documents=100,
            total_images=150,
            total_annotations=500,
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.status == "ready"
        assert dataset.total_documents == 100
        assert dataset.total_images == 150
        assert dataset.total_annotations == 500

    def test_update_status_with_error(self, patched_session, sample_dataset):
        """Test updating dataset status with error message."""
        repo = DatasetRepository()

        repo.update_status(
            str(sample_dataset.dataset_id),
            status="failed",
            error_message="Failed to build dataset: insufficient documents",
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.status == "failed"
        assert "insufficient documents" in dataset.error_message

    def test_update_status_with_path(self, patched_session, sample_dataset):
        """Test updating dataset path."""
        repo = DatasetRepository()

        repo.update_status(
            str(sample_dataset.dataset_id),
            status="ready",
            dataset_path="/datasets/test_dataset_2024",
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.dataset_path == "/datasets/test_dataset_2024"

    def test_update_training_status(self, patched_session, sample_dataset, sample_training_task):
        """Test updating dataset training status."""
        repo = DatasetRepository()

        repo.update_training_status(
            str(sample_dataset.dataset_id),
            training_status="running",
            active_training_task_id=str(sample_training_task.task_id),
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.training_status == "running"
        assert dataset.active_training_task_id == sample_training_task.task_id

    def test_update_training_status_completed(self, patched_session, sample_dataset):
        """Test updating training status to completed updates main status."""
        repo = DatasetRepository()

        # First set to ready
        repo.update_status(str(sample_dataset.dataset_id), status="ready")

        # Then complete training
        repo.update_training_status(
            str(sample_dataset.dataset_id),
            training_status="completed",
            update_main_status=True,
        )

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is not None
        assert dataset.training_status == "completed"
        assert dataset.status == "trained"


class TestDatasetDocuments:
    """Tests for dataset document management."""

    def test_add_documents_to_dataset(self, patched_session, sample_dataset, multiple_documents):
        """Test adding documents to a dataset."""
        repo = DatasetRepository()

        documents_data = [
            {
                "document_id": str(multiple_documents[0].document_id),
                "split": "train",
                "page_count": 1,
                "annotation_count": 5,
            },
            {
                "document_id": str(multiple_documents[1].document_id),
                "split": "train",
                "page_count": 2,
                "annotation_count": 8,
            },
            {
                "document_id": str(multiple_documents[2].document_id),
                "split": "val",
                "page_count": 1,
                "annotation_count": 3,
            },
        ]

        repo.add_documents(str(sample_dataset.dataset_id), documents_data)

        # Verify documents were added
        docs = repo.get_documents(str(sample_dataset.dataset_id))
        assert len(docs) == 3

        train_docs = [d for d in docs if d.split == "train"]
        val_docs = [d for d in docs if d.split == "val"]

        assert len(train_docs) == 2
        assert len(val_docs) == 1

    def test_get_dataset_documents(self, patched_session, sample_dataset, sample_document):
        """Test getting documents from a dataset."""
        repo = DatasetRepository()

        repo.add_documents(
            str(sample_dataset.dataset_id),
            [
                {
                    "document_id": str(sample_document.document_id),
                    "split": "train",
                    "page_count": 1,
                    "annotation_count": 5,
                }
            ],
        )

        docs = repo.get_documents(str(sample_dataset.dataset_id))

        assert len(docs) == 1
        assert docs[0].document_id == sample_document.document_id
        assert docs[0].split == "train"
        assert docs[0].page_count == 1
        assert docs[0].annotation_count == 5


class TestDatasetRepositoryDelete:
    """Tests for dataset deletion."""

    def test_delete_dataset(self, patched_session, sample_dataset):
        """Test deleting a dataset."""
        repo = DatasetRepository()

        result = repo.delete(str(sample_dataset.dataset_id))
        assert result is True

        dataset = repo.get(str(sample_dataset.dataset_id))
        assert dataset is None

    def test_delete_nonexistent_dataset(self, patched_session):
        """Test deleting dataset that doesn't exist."""
        repo = DatasetRepository()

        result = repo.delete(str(uuid4()))
        assert result is False

    def test_delete_dataset_cascades_documents(self, patched_session, sample_dataset, sample_document):
        """Test deleting dataset also removes document links."""
        repo = DatasetRepository()

        # Add document to dataset
        repo.add_documents(
            str(sample_dataset.dataset_id),
            [
                {
                    "document_id": str(sample_document.document_id),
                    "split": "train",
                    "page_count": 1,
                    "annotation_count": 5,
                }
            ],
        )

        # Delete dataset
        repo.delete(str(sample_dataset.dataset_id))

        # Document links should be gone
        docs = repo.get_documents(str(sample_dataset.dataset_id))
        assert len(docs) == 0


class TestActiveTrainingTasks:
    """Tests for active training task queries."""

    def test_get_active_training_tasks(self, patched_session, sample_dataset, sample_training_task):
        """Test getting active training tasks for datasets."""
        repo = DatasetRepository()

        # Update task to running
        from inference.data.repositories.training_task_repository import TrainingTaskRepository

        task_repo = TrainingTaskRepository()
        task_repo.update_status(str(sample_training_task.task_id), "running")

        result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])

        assert str(sample_dataset.dataset_id) in result
        assert result[str(sample_dataset.dataset_id)]["status"] == "running"

    def test_get_active_training_tasks_empty(self, patched_session, sample_dataset):
        """Test getting active training tasks returns empty when no tasks exist."""
        repo = DatasetRepository()

        result = repo.get_active_training_tasks([str(sample_dataset.dataset_id)])

        # No training task exists for this dataset, so result should be empty
        assert str(sample_dataset.dataset_id) not in result
        assert result == {}
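Note: the two tests above pin down the contract of get_active_training_tasks: a dict keyed by dataset-ID string mapping to per-task info (at least a "status" key), with datasets that have no active task simply absent. A sketch of a query satisfying those assertions; the TrainingTask model location and the set of "active" statuses are assumptions, not the repository's actual code:

from sqlmodel import Session, select


def get_active_training_tasks(session: Session, dataset_ids: list[str]) -> dict[str, dict]:
    from inference.data.admin_models import TrainingTask  # assumed model name

    statement = select(TrainingTask).where(
        TrainingTask.dataset_id.in_(dataset_ids),
        TrainingTask.status.in_(["pending", "running"]),  # assumed "active" set
    )
    # Datasets without an active task never appear as keys, which matches
    # test_get_active_training_tasks_empty above.
    return {
        str(task.dataset_id): {"task_id": str(task.task_id), "status": task.status}
        for task in session.exec(statement)
    }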
350
tests/integration/repositories/test_document_repo_integration.py
Normal file
@@ -0,0 +1,350 @@
"""
Document Repository Integration Tests

Tests DocumentRepository with real database operations.
"""

from datetime import datetime, timezone, timedelta
from uuid import uuid4

import pytest
from sqlmodel import select

from inference.data.admin_models import AdminAnnotation, AdminDocument
from inference.data.repositories.document_repository import DocumentRepository


def ensure_utc(dt: datetime | None) -> datetime | None:
    """Ensure datetime is timezone-aware (UTC).

    PostgreSQL may return offset-naive datetimes. This helper
    converts them to UTC for proper comparison.
    """
    if dt is None:
        return None
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt


class TestDocumentRepositoryCreate:
    """Tests for document creation."""

    def test_create_document(self, patched_session):
        """Test creating a document and retrieving it."""
        repo = DocumentRepository()

        doc_id = repo.create(
            filename="test_invoice.pdf",
            file_size=2048,
            content_type="application/pdf",
            file_path="/uploads/test_invoice.pdf",
            page_count=2,
            upload_source="api",
            category="invoice",
        )

        assert doc_id is not None

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.filename == "test_invoice.pdf"
        assert doc.file_size == 2048
        assert doc.page_count == 2
        assert doc.upload_source == "api"
        assert doc.category == "invoice"
        assert doc.status == "pending"

    def test_create_document_with_csv_values(self, patched_session):
        """Test creating document with CSV field values."""
        repo = DocumentRepository()

        csv_values = {
            "invoice_number": "INV-001",
            "amount": "1500.00",
            "supplier_name": "Test Supplier AB",
        }

        doc_id = repo.create(
            filename="invoice_with_csv.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/invoice_with_csv.pdf",
            csv_field_values=csv_values,
        )

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.csv_field_values == csv_values

    def test_create_document_with_group_key(self, patched_session):
        """Test creating document with group key."""
        repo = DocumentRepository()

        doc_id = repo.create(
            filename="grouped_doc.pdf",
            file_size=1024,
            content_type="application/pdf",
            file_path="/uploads/grouped_doc.pdf",
            group_key="batch-2024-01",
        )

        doc = repo.get(doc_id)
        assert doc is not None
        assert doc.group_key == "batch-2024-01"


class TestDocumentRepositoryRead:
    """Tests for document retrieval."""

    def test_get_nonexistent_document(self, patched_session):
        """Test getting a document that doesn't exist."""
        repo = DocumentRepository()

        doc = repo.get(str(uuid4()))
        assert doc is None

    def test_get_paginated_documents(self, patched_session, multiple_documents):
        """Test paginated document listing."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(docs) == 2

    def test_get_paginated_with_status_filter(self, patched_session, multiple_documents):
        """Test filtering documents by status."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(status="labeled")

        assert total == 2
        for doc in docs:
            assert doc.status == "labeled"

    def test_get_paginated_with_category_filter(self, patched_session, multiple_documents):
        """Test filtering documents by category."""
        repo = DocumentRepository()

        docs, total = repo.get_paginated(category="letter")

        assert total == 1
        assert docs[0].category == "letter"

    def test_get_paginated_with_offset(self, patched_session, multiple_documents):
        """Test pagination offset."""
        repo = DocumentRepository()

        docs_page1, _ = repo.get_paginated(limit=2, offset=0)
        docs_page2, _ = repo.get_paginated(limit=2, offset=2)

        doc_ids_page1 = {str(d.document_id) for d in docs_page1}
        doc_ids_page2 = {str(d.document_id) for d in docs_page2}

        assert len(doc_ids_page1 & doc_ids_page2) == 0

    def test_get_by_ids(self, patched_session, multiple_documents):
        """Test getting multiple documents by IDs."""
        repo = DocumentRepository()

        ids_to_fetch = [str(multiple_documents[0].document_id), str(multiple_documents[2].document_id)]
        docs = repo.get_by_ids(ids_to_fetch)

        assert len(docs) == 2
        fetched_ids = {str(d.document_id) for d in docs}
        assert fetched_ids == set(ids_to_fetch)


class TestDocumentRepositoryUpdate:
    """Tests for document updates."""

    def test_update_status(self, patched_session, sample_document):
        """Test updating document status."""
        repo = DocumentRepository()

        repo.update_status(
            str(sample_document.document_id),
            status="labeled",
            auto_label_status="completed",
        )

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.status == "labeled"
        assert doc.auto_label_status == "completed"

    def test_update_status_with_error(self, patched_session, sample_document):
        """Test updating document status with error message."""
        repo = DocumentRepository()

        repo.update_status(
            str(sample_document.document_id),
            status="pending",
            auto_label_status="failed",
            auto_label_error="OCR extraction failed",
        )

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.auto_label_status == "failed"
        assert doc.auto_label_error == "OCR extraction failed"

    def test_update_file_path(self, patched_session, sample_document):
        """Test updating document file path."""
        repo = DocumentRepository()

        new_path = "/archive/2024/test_invoice.pdf"
        repo.update_file_path(str(sample_document.document_id), new_path)

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.file_path == new_path

    def test_update_group_key(self, patched_session, sample_document):
        """Test updating document group key."""
        repo = DocumentRepository()

        result = repo.update_group_key(str(sample_document.document_id), "new-group-key")
        assert result is True

        doc = repo.get(str(sample_document.document_id))
        assert doc is not None
        assert doc.group_key == "new-group-key"

    def test_update_category(self, patched_session, sample_document):
        """Test updating document category."""
        repo = DocumentRepository()

        doc = repo.update_category(str(sample_document.document_id), "letter")

        assert doc is not None
        assert doc.category == "letter"


class TestDocumentRepositoryDelete:
    """Tests for document deletion."""

    def test_delete_document(self, patched_session, sample_document):
        """Test deleting a document."""
        repo = DocumentRepository()

        result = repo.delete(str(sample_document.document_id))
        assert result is True

        doc = repo.get(str(sample_document.document_id))
        assert doc is None

    def test_delete_document_with_annotations(self, patched_session, sample_document, sample_annotation):
        """Test deleting document also deletes its annotations."""
        repo = DocumentRepository()

        result = repo.delete(str(sample_document.document_id))
        assert result is True

        # Verify annotation is also deleted
        from inference.data.repositories.annotation_repository import AnnotationRepository

        ann_repo = AnnotationRepository()
        annotations = ann_repo.get_for_document(str(sample_document.document_id))
        assert len(annotations) == 0

    def test_delete_nonexistent_document(self, patched_session):
        """Test deleting a document that doesn't exist."""
        repo = DocumentRepository()

        result = repo.delete(str(uuid4()))
        assert result is False


class TestDocumentRepositoryQueries:
    """Tests for complex document queries."""

    def test_count_by_status(self, patched_session, multiple_documents):
        """Test counting documents by status."""
        repo = DocumentRepository()

        counts = repo.count_by_status()

        assert counts.get("pending") == 2
        assert counts.get("labeled") == 2
        assert counts.get("exported") == 1

    def test_get_categories(self, patched_session, multiple_documents):
        """Test getting unique categories."""
        repo = DocumentRepository()

        categories = repo.get_categories()

        assert "invoice" in categories
        assert "letter" in categories

    def test_get_labeled_for_export(self, patched_session, multiple_documents):
        """Test getting labeled documents for export."""
        repo = DocumentRepository()

        docs = repo.get_labeled_for_export()

        assert len(docs) == 2
        for doc in docs:
            assert doc.status == "labeled"


class TestDocumentAnnotationLocking:
    """Tests for annotation locking mechanism."""

    def test_acquire_annotation_lock(self, patched_session, sample_document):
        """Test acquiring annotation lock."""
        repo = DocumentRepository()

        doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )

        assert doc is not None
        assert doc.annotation_lock_until is not None
        lock_until = ensure_utc(doc.annotation_lock_until)
        assert lock_until > datetime.now(timezone.utc)

    def test_acquire_lock_when_already_locked(self, patched_session, sample_document):
        """Test acquiring lock fails when already locked."""
        repo = DocumentRepository()

        # First lock
        repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)

        # Second lock attempt should fail
        result = repo.acquire_annotation_lock(str(sample_document.document_id))
        assert result is None

    def test_release_annotation_lock(self, patched_session, sample_document):
        """Test releasing annotation lock."""
        repo = DocumentRepository()

        repo.acquire_annotation_lock(str(sample_document.document_id), duration_seconds=300)
        doc = repo.release_annotation_lock(str(sample_document.document_id))

        assert doc is not None
        assert doc.annotation_lock_until is None

    def test_extend_annotation_lock(self, patched_session, sample_document):
        """Test extending annotation lock."""
        repo = DocumentRepository()

        # Acquire initial lock
        initial_doc = repo.acquire_annotation_lock(
            str(sample_document.document_id),
            duration_seconds=300,
        )
        initial_expiry = ensure_utc(initial_doc.annotation_lock_until)

        # Extend lock
        extended_doc = repo.extend_annotation_lock(
            str(sample_document.document_id),
            additional_seconds=300,
        )

        assert extended_doc is not None
        extended_expiry = ensure_utc(extended_doc.annotation_lock_until)
        assert extended_expiry > initial_expiry
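Note: the locking tests above fix the semantics: acquire returns the document and stamps annotation_lock_until while the lock is free, returns None while a live lock is held, release clears the timestamp, and extend pushes the expiry forward. A compact sketch of that check-and-set logic (illustrative only, reusing the same naive-datetime fix as ensure_utc; not the repository's actual code):

from datetime import datetime, timedelta, timezone


def try_acquire_lock(doc, duration_seconds: int = 300):
    now = datetime.now(timezone.utc)
    current = doc.annotation_lock_until
    if current is not None and current.tzinfo is None:
        # Same fix as ensure_utc(): treat naive timestamps as UTC.
        current = current.replace(tzinfo=timezone.utc)
    if current is not None and current > now:
        return None  # a live lock is held; caller must wait or give up
    doc.annotation_lock_until = now + timedelta(seconds=duration_seconds)
    return doc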
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
Model Version Repository Integration Tests
|
||||
|
||||
Tests ModelVersionRepository with real database operations.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from inference.data.repositories.model_version_repository import ModelVersionRepository
|
||||
|
||||
|
||||
class TestModelVersionCreate:
|
||||
"""Tests for model version creation."""
|
||||
|
||||
def test_create_model_version(self, patched_session):
|
||||
"""Test creating a model version."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
model = repo.create(
|
||||
version="1.0.0",
|
||||
name="Invoice Extractor v1",
|
||||
model_path="/models/invoice_v1.pt",
|
||||
description="Initial production model",
|
||||
metrics_mAP=0.92,
|
||||
metrics_precision=0.89,
|
||||
metrics_recall=0.85,
|
||||
document_count=1000,
|
||||
file_size=50000000,
|
||||
)
|
||||
|
||||
assert model is not None
|
||||
assert model.version == "1.0.0"
|
||||
assert model.name == "Invoice Extractor v1"
|
||||
assert model.model_path == "/models/invoice_v1.pt"
|
||||
assert model.metrics_mAP == 0.92
|
||||
assert model.is_active is False
|
||||
assert model.status == "inactive"
|
||||
|
||||
def test_create_model_version_with_training_info(
|
||||
self, patched_session, sample_training_task, sample_dataset
|
||||
):
|
||||
"""Test creating model version linked to training task and dataset."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
model = repo.create(
|
||||
version="1.1.0",
|
||||
name="Invoice Extractor v1.1",
|
||||
model_path="/models/invoice_v1.1.pt",
|
||||
task_id=sample_training_task.task_id,
|
||||
dataset_id=sample_dataset.dataset_id,
|
||||
training_config={"epochs": 100, "batch_size": 16},
|
||||
trained_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
assert model is not None
|
||||
assert model.task_id == sample_training_task.task_id
|
||||
assert model.dataset_id == sample_dataset.dataset_id
|
||||
assert model.training_config["epochs"] == 100
|
||||
|
||||
|
||||
class TestModelVersionRead:
|
||||
"""Tests for model version retrieval."""
|
||||
|
||||
def test_get_model_version_by_id(self, patched_session, sample_model_version):
|
||||
"""Test getting model version by ID."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
model = repo.get(str(sample_model_version.version_id))
|
||||
|
||||
assert model is not None
|
||||
assert model.version_id == sample_model_version.version_id
|
||||
|
||||
def test_get_nonexistent_model_version(self, patched_session):
|
||||
"""Test getting model version that doesn't exist."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
model = repo.get(str(uuid4()))
|
||||
assert model is None
|
||||
|
||||
def test_get_paginated_model_versions(self, patched_session):
|
||||
"""Test paginated model version listing."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
# Create multiple versions
|
||||
for i in range(5):
|
||||
repo.create(
|
||||
version=f"1.{i}.0",
|
||||
name=f"Model v1.{i}",
|
||||
model_path=f"/models/model_v1.{i}.pt",
|
||||
)
|
||||
|
||||
models, total = repo.get_paginated(limit=2, offset=0)
|
||||
|
||||
assert total == 5
|
||||
assert len(models) == 2
|
||||
|
||||
def test_get_paginated_with_status_filter(self, patched_session):
|
||||
"""Test filtering model versions by status."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
# Create active and inactive models
|
||||
m1 = repo.create(version="1.0.0", name="Active Model", model_path="/models/active.pt")
|
||||
repo.activate(str(m1.version_id))
|
||||
|
||||
repo.create(version="2.0.0", name="Inactive Model", model_path="/models/inactive.pt")
|
||||
|
||||
active_models, active_total = repo.get_paginated(status="active")
|
||||
inactive_models, inactive_total = repo.get_paginated(status="inactive")
|
||||
|
||||
assert active_total == 1
|
||||
assert inactive_total == 1
|
||||
|
||||
|
||||
class TestModelVersionActivation:
|
||||
"""Tests for model version activation."""
|
||||
|
||||
def test_activate_model_version(self, patched_session, sample_model_version):
|
||||
"""Test activating a model version."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
model = repo.activate(str(sample_model_version.version_id))
|
||||
|
||||
assert model is not None
|
||||
assert model.is_active is True
|
||||
assert model.status == "active"
|
||||
assert model.activated_at is not None
|
||||
|
||||
def test_activate_deactivates_others(self, patched_session):
|
||||
"""Test that activating one version deactivates others."""
|
||||
repo = ModelVersionRepository()
|
||||
|
||||
# Create and activate first model
|
||||
m1 = repo.create(version="1.0.0", name="Model 1", model_path="/models/m1.pt")
|
||||
repo.activate(str(m1.version_id))
|
||||
|
||||
# Create and activate second model
|
||||
m2 = repo.create(version="2.0.0", name="Model 2", model_path="/models/m2.pt")
|
||||
repo.activate(str(m2.version_id))

        # Check first model is now inactive
        m1_after = repo.get(str(m1.version_id))
        assert m1_after.is_active is False
        assert m1_after.status == "inactive"

        # Check second model is active
        m2_after = repo.get(str(m2.version_id))
        assert m2_after.is_active is True

    def test_get_active_model(self, patched_session, sample_model_version):
        """Test getting the currently active model."""
        repo = ModelVersionRepository()

        # Initially no active model
        active = repo.get_active()
        assert active is None

        # Activate model
        repo.activate(str(sample_model_version.version_id))

        # Now should return active model
        active = repo.get_active()
        assert active is not None
        assert active.version_id == sample_model_version.version_id

    def test_deactivate_model_version(self, patched_session, sample_model_version):
        """Test deactivating a model version."""
        repo = ModelVersionRepository()

        # First activate
        repo.activate(str(sample_model_version.version_id))

        # Then deactivate
        model = repo.deactivate(str(sample_model_version.version_id))

        assert model is not None
        assert model.is_active is False
        assert model.status == "inactive"


class TestModelVersionUpdate:
    """Tests for model version updates."""

    def test_update_model_metadata(self, patched_session, sample_model_version):
        """Test updating model version metadata."""
        repo = ModelVersionRepository()

        model = repo.update(
            str(sample_model_version.version_id),
            name="Updated Model Name",
            description="Updated description",
        )

        assert model is not None
        assert model.name == "Updated Model Name"
        assert model.description == "Updated description"

    def test_update_model_status(self, patched_session, sample_model_version):
        """Test updating model version status."""
        repo = ModelVersionRepository()

        model = repo.update(str(sample_model_version.version_id), status="deprecated")

        assert model is not None
        assert model.status == "deprecated"

    def test_update_nonexistent_model(self, patched_session):
        """Test updating a model that doesn't exist."""
        repo = ModelVersionRepository()

        model = repo.update(str(uuid4()), name="New Name")
        assert model is None


class TestModelVersionArchive:
    """Tests for model version archiving."""

    def test_archive_model_version(self, patched_session, sample_model_version):
        """Test archiving an inactive model version."""
        repo = ModelVersionRepository()

        model = repo.archive(str(sample_model_version.version_id))

        assert model is not None
        assert model.status == "archived"

    def test_cannot_archive_active_model(self, patched_session, sample_model_version):
        """Test that an active model cannot be archived."""
        repo = ModelVersionRepository()

        # Activate the model
        repo.activate(str(sample_model_version.version_id))

        # Try to archive
        model = repo.archive(str(sample_model_version.version_id))

        assert model is None

        # Verify model is still active
        current = repo.get(str(sample_model_version.version_id))
        assert current.status == "active"


class TestModelVersionDelete:
    """Tests for model version deletion."""

    def test_delete_inactive_model(self, patched_session, sample_model_version):
        """Test deleting an inactive model version."""
        repo = ModelVersionRepository()

        result = repo.delete(str(sample_model_version.version_id))

        assert result is True

        model = repo.get(str(sample_model_version.version_id))
        assert model is None

    def test_cannot_delete_active_model(self, patched_session, sample_model_version):
        """Test that an active model cannot be deleted."""
        repo = ModelVersionRepository()

        # Activate the model
        repo.activate(str(sample_model_version.version_id))

        # Try to delete
        result = repo.delete(str(sample_model_version.version_id))

        assert result is False

        # Verify model still exists
        model = repo.get(str(sample_model_version.version_id))
        assert model is not None

    def test_delete_nonexistent_model(self, patched_session):
        """Test deleting a model that doesn't exist."""
        repo = ModelVersionRepository()

        result = repo.delete(str(uuid4()))
        assert result is False


class TestOnlyOneActiveModel:
    """Tests to verify only one model can be active at a time."""

    def test_single_active_model_constraint(self, patched_session):
        """Test that only one model can be active at any time."""
        repo = ModelVersionRepository()

        # Create multiple models
        models = []
        for i in range(3):
            m = repo.create(
                version=f"1.{i}.0",
                name=f"Model {i}",
                model_path=f"/models/model_{i}.pt",
            )
            models.append(m)

        # Activate each model in sequence
        for model in models:
            repo.activate(str(model.version_id))

        # Count active models
        all_models, _ = repo.get_paginated(status="active")
        assert len(all_models) == 1

        # Verify it's the last one activated
        assert all_models[0].version_id == models[-1].version_id
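
        # Nothing here pins down *how* the invariant is enforced: a partial
        # unique index (e.g. on model_versions (is_active) WHERE is_active)
        # could guarantee it at the database level, or activate() may simply
        # deactivate all sibling rows in application code first.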

274
tests/integration/repositories/test_token_repo_integration.py
Normal file
@@ -0,0 +1,274 @@
"""
Token Repository Integration Tests

Tests TokenRepository with real database operations.
"""

from datetime import datetime, timezone, timedelta

import pytest

from inference.data.repositories.token_repository import TokenRepository


class TestTokenCreate:
    """Tests for token creation."""

    def test_create_new_token(self, patched_session):
        """Test creating a new admin token."""
        repo = TokenRepository()

        repo.create(
            token="new-test-token-abc123",
            name="New Test Admin",
        )

        token = repo.get("new-test-token-abc123")
        assert token is not None
        assert token.token == "new-test-token-abc123"
        assert token.name == "New Test Admin"
        assert token.is_active is True
        assert token.expires_at is None

    def test_create_token_with_expiration(self, patched_session):
        """Test creating a token with an expiration date."""
        repo = TokenRepository()
        expiry = datetime.now(timezone.utc) + timedelta(days=30)

        repo.create(
            token="expiring-token-xyz789",
            name="Expiring Token",
            expires_at=expiry,
        )

        token = repo.get("expiring-token-xyz789")
        assert token is not None
        assert token.expires_at is not None

    def test_create_updates_existing_token(self, patched_session, admin_token):
        """Test that creating with an existing token updates it."""
        repo = TokenRepository()
        new_expiry = datetime.now(timezone.utc) + timedelta(days=60)

        repo.create(
            token=admin_token.token,
            name="Updated Admin Name",
            expires_at=new_expiry,
        )

        token = repo.get(admin_token.token)
        assert token is not None
        assert token.name == "Updated Admin Name"
        assert token.is_active is True
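
        # create() evidently acts as an upsert keyed on the token string: the
        # existing row is updated in place (name, expiry) and stays active
        # rather than being duplicated. Whether that is an ON CONFLICT-style
        # write or a read-then-update is not observable from this test.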


class TestTokenValidation:
    """Tests for token validation."""

    def test_is_valid_active_token(self, patched_session, admin_token):
        """Test that an active token is valid."""
        repo = TokenRepository()

        result = repo.is_valid(admin_token.token)

        assert result is True

    def test_is_valid_nonexistent_token(self, patched_session):
        """Test that a nonexistent token is invalid."""
        repo = TokenRepository()

        result = repo.is_valid("nonexistent-token-12345")

        assert result is False

    def test_is_valid_deactivated_token(self, patched_session, admin_token):
        """Test that a deactivated token is invalid."""
        repo = TokenRepository()

        repo.deactivate(admin_token.token)
        result = repo.is_valid(admin_token.token)

        assert result is False

    def test_is_valid_expired_token(self, patched_session):
        """Test that an expired token is invalid."""
        repo = TokenRepository()
        past_expiry = datetime.now(timezone.utc) - timedelta(days=1)

        repo.create(
            token="expired-token-test",
            name="Expired Token",
            expires_at=past_expiry,
        )

        result = repo.is_valid("expired-token-test")

        assert result is False

    def test_is_valid_not_yet_expired_token(self, patched_session):
        """Test that a not-yet-expired token is valid."""
        repo = TokenRepository()
        future_expiry = datetime.now(timezone.utc) + timedelta(days=7)

        repo.create(
            token="valid-expiring-token",
            name="Valid Expiring Token",
            expires_at=future_expiry,
        )

        result = repo.is_valid("valid-expiring-token")

        assert result is True
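
    # Taken together, the cases above pin is_valid() down to:
    #   token exists AND token.is_active AND
    #   (expires_at is None OR expires_at > now, compared in UTC)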


class TestTokenGet:
    """Tests for token retrieval."""

    def test_get_existing_token(self, patched_session, admin_token):
        """Test getting an existing token."""
        repo = TokenRepository()

        token = repo.get(admin_token.token)

        assert token is not None
        assert token.token == admin_token.token
        assert token.name == admin_token.name

    def test_get_nonexistent_token(self, patched_session):
        """Test getting a token that doesn't exist."""
        repo = TokenRepository()

        token = repo.get("nonexistent-token-xyz")

        assert token is None


class TestTokenDeactivate:
    """Tests for token deactivation."""

    def test_deactivate_existing_token(self, patched_session, admin_token):
        """Test deactivating an existing token."""
        repo = TokenRepository()

        result = repo.deactivate(admin_token.token)

        assert result is True
        token = repo.get(admin_token.token)
        assert token is not None
        assert token.is_active is False

    def test_deactivate_nonexistent_token(self, patched_session):
        """Test deactivating a token that doesn't exist."""
        repo = TokenRepository()

        result = repo.deactivate("nonexistent-token-abc")

        assert result is False

    def test_reactivate_deactivated_token(self, patched_session, admin_token):
        """Test reactivating a deactivated token via create."""
        repo = TokenRepository()

        # Deactivate first
        repo.deactivate(admin_token.token)
        assert repo.is_valid(admin_token.token) is False

        # Reactivate via create
        repo.create(
            token=admin_token.token,
            name="Reactivated Admin",
        )

        assert repo.is_valid(admin_token.token) is True


class TestTokenUsageTracking:
    """Tests for token usage tracking."""

    def test_update_usage(self, patched_session, admin_token):
        """Test updating the token's last-used timestamp."""
        repo = TokenRepository()

        # Initially last_used_at might be None
        initial_token = repo.get(admin_token.token)
        initial_last_used = initial_token.last_used_at

        repo.update_usage(admin_token.token)

        updated_token = repo.get(admin_token.token)
        assert updated_token.last_used_at is not None
        if initial_last_used:
            assert updated_token.last_used_at >= initial_last_used

    def test_update_usage_nonexistent_token(self, patched_session):
        """Test that updating usage for a nonexistent token does nothing."""
        repo = TokenRepository()

        # Should not raise, just does nothing
        repo.update_usage("nonexistent-token-usage")

        token = repo.get("nonexistent-token-usage")
        assert token is None


class TestTokenWorkflow:
    """Tests for complete token workflows."""

    def test_full_token_lifecycle(self, patched_session):
        """Test the complete token lifecycle: create, validate, use, deactivate."""
        repo = TokenRepository()
        token_str = "lifecycle-test-token"

        # 1. Create token
        repo.create(token=token_str, name="Lifecycle Token")
        assert repo.is_valid(token_str) is True

        # 2. Use token
        repo.update_usage(token_str)
        token = repo.get(token_str)
        assert token.last_used_at is not None

        # 3. Update token info
        new_expiry = datetime.now(timezone.utc) + timedelta(days=90)
        repo.create(
            token=token_str,
            name="Updated Lifecycle Token",
            expires_at=new_expiry,
        )
        token = repo.get(token_str)
        assert token.name == "Updated Lifecycle Token"

        # 4. Deactivate token
        result = repo.deactivate(token_str)
        assert result is True
        assert repo.is_valid(token_str) is False

        # 5. Reactivate token
        repo.create(token=token_str, name="Reactivated Token")
        assert repo.is_valid(token_str) is True

    def test_multiple_tokens(self, patched_session):
        """Test managing multiple tokens."""
        repo = TokenRepository()

        # Create multiple tokens
        tokens = [
            ("token-a", "Admin A"),
            ("token-b", "Admin B"),
            ("token-c", "Admin C"),
        ]

        for token_str, name in tokens:
            repo.create(token=token_str, name=name)

        # Verify all are valid
        for token_str, _ in tokens:
            assert repo.is_valid(token_str) is True

        # Deactivate one
        repo.deactivate("token-b")

        # Verify states
        assert repo.is_valid("token-a") is True
        assert repo.is_valid("token-b") is False
        assert repo.is_valid("token-c") is True

@@ -0,0 +1,364 @@
"""
Training Task Repository Integration Tests

Tests TrainingTaskRepository with real database operations.
"""

from datetime import datetime, timezone, timedelta
from uuid import uuid4

import pytest

from inference.data.repositories.training_task_repository import TrainingTaskRepository


class TestTrainingTaskCreate:
    """Tests for training task creation."""

    def test_create_training_task(self, patched_session, admin_token):
        """Test creating a training task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Test Training Task",
            task_type="train",
            description="Integration test training task",
            config={"epochs": 100, "batch_size": 16},
        )

        assert task_id is not None

        task = repo.get(task_id)
        assert task is not None
        assert task.name == "Test Training Task"
        assert task.task_type == "train"
        assert task.status == "pending"
        assert task.config["epochs"] == 100

    def test_create_scheduled_task(self, patched_session, admin_token):
        """Test creating a scheduled training task."""
        repo = TrainingTaskRepository()

        scheduled_time = datetime.now(timezone.utc) + timedelta(hours=1)

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Scheduled Task",
            scheduled_at=scheduled_time,
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.status == "scheduled"
        assert task.scheduled_at is not None

    def test_create_recurring_task(self, patched_session, admin_token):
        """Test creating a recurring training task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Recurring Task",
            cron_expression="0 2 * * *",
            is_recurring=True,
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.is_recurring is True
        assert task.cron_expression == "0 2 * * *"
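
        # The cron expression "0 2 * * *" (minute 0, hour 2, any day/month/
        # weekday) fires once per day at 02:00; which timezone that clock uses
        # depends on the scheduler and is not exercised here.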

    def test_create_task_with_dataset(self, patched_session, admin_token, sample_dataset):
        """Test creating a task linked to a dataset."""
        repo = TrainingTaskRepository()

        task_id = repo.create(
            admin_token=admin_token.token,
            name="Dataset Training Task",
            dataset_id=str(sample_dataset.dataset_id),
        )

        task = repo.get(task_id)
        assert task is not None
        assert task.dataset_id == sample_dataset.dataset_id


class TestTrainingTaskRead:
    """Tests for training task retrieval."""

    def test_get_task_by_id(self, patched_session, sample_training_task):
        """Test getting a task by ID."""
        repo = TrainingTaskRepository()

        task = repo.get(str(sample_training_task.task_id))

        assert task is not None
        assert task.task_id == sample_training_task.task_id

    def test_get_nonexistent_task(self, patched_session):
        """Test getting a task that doesn't exist."""
        repo = TrainingTaskRepository()

        task = repo.get(str(uuid4()))
        assert task is None

    def test_get_paginated_tasks(self, patched_session, admin_token):
        """Test paginated task listing."""
        repo = TrainingTaskRepository()

        # Create multiple tasks
        for i in range(5):
            repo.create(admin_token=admin_token.token, name=f"Task {i}")

        tasks, total = repo.get_paginated(limit=2, offset=0)

        assert total == 5
        assert len(tasks) == 2

    def test_get_paginated_with_status_filter(self, patched_session, admin_token):
        """Test filtering tasks by status."""
        repo = TrainingTaskRepository()

        # Create tasks with different statuses
        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")

        repo.create(admin_token=admin_token.token, name="Pending Task")

        tasks, total = repo.get_paginated(status="running")

        assert total == 1
        assert tasks[0].status == "running"

    def test_get_pending_tasks(self, patched_session, admin_token):
        """Test getting pending tasks that are ready to run."""
        repo = TrainingTaskRepository()

        # Create pending task
        repo.create(admin_token=admin_token.token, name="Ready Task")

        # Create scheduled task in the past (should be included)
        past_time = datetime.now(timezone.utc) - timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Past Scheduled Task",
            scheduled_at=past_time,
        )

        # Create scheduled task in the future (should not be included)
        future_time = datetime.now(timezone.utc) + timedelta(hours=1)
        repo.create(
            admin_token=admin_token.token,
            name="Future Scheduled Task",
            scheduled_at=future_time,
        )

        pending = repo.get_pending()

        # Should include pending and past scheduled, not future scheduled
        assert len(pending) >= 2
        names = [t.name for t in pending]
        assert "Ready Task" in names
        assert "Past Scheduled Task" in names

    def test_get_running_task(self, patched_session, admin_token):
        """Test getting the currently running task."""
        repo = TrainingTaskRepository()

        task_id = repo.create(admin_token=admin_token.token, name="Running Task")
        repo.update_status(task_id, "running")

        running = repo.get_running()

        assert running is not None
        assert running.status == "running"

    def test_get_running_task_none(self, patched_session, admin_token):
        """Test getting the running task when none is running."""
        repo = TrainingTaskRepository()

        repo.create(admin_token=admin_token.token, name="Pending Task")

        running = repo.get_running()
        assert running is None


class TestTrainingTaskUpdate:
    """Tests for training task updates."""

    def test_update_status_to_running(self, patched_session, sample_training_task):
        """Test updating task status to running."""
        repo = TrainingTaskRepository()

        repo.update_status(str(sample_training_task.task_id), "running")

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "running"
        assert task.started_at is not None

    def test_update_status_to_completed(self, patched_session, sample_training_task):
        """Test updating task status to completed."""
        repo = TrainingTaskRepository()

        metrics = {"mAP": 0.92, "precision": 0.89, "recall": 0.85}

        repo.update_status(
            str(sample_training_task.task_id),
            "completed",
            result_metrics=metrics,
            model_path="/models/trained_model.pt",
        )

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "completed"
        assert task.completed_at is not None
        assert task.result_metrics["mAP"] == 0.92
        assert task.model_path == "/models/trained_model.pt"
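
        # update_status() apparently stamps timestamps as a side effect:
        # started_at on the move to "running" and completed_at on terminal
        # states, so callers never need to set them directly.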

    def test_update_status_to_failed(self, patched_session, sample_training_task):
        """Test updating task status to failed with an error message."""
        repo = TrainingTaskRepository()

        repo.update_status(
            str(sample_training_task.task_id),
            "failed",
            error_message="CUDA out of memory",
        )

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "failed"
        assert task.completed_at is not None
        assert "CUDA out of memory" in task.error_message

    def test_cancel_pending_task(self, patched_session, sample_training_task):
        """Test cancelling a pending task."""
        repo = TrainingTaskRepository()

        result = repo.cancel(str(sample_training_task.task_id))

        assert result is True

        task = repo.get(str(sample_training_task.task_id))
        assert task is not None
        assert task.status == "cancelled"

    def test_cannot_cancel_running_task(self, patched_session, sample_training_task):
        """Test that a running task cannot be cancelled."""
        repo = TrainingTaskRepository()

        repo.update_status(str(sample_training_task.task_id), "running")

        result = repo.cancel(str(sample_training_task.task_id))

        assert result is False

        task = repo.get(str(sample_training_task.task_id))
        assert task.status == "running"


class TestTrainingLogs:
    """Tests for training log management."""

    def test_add_log_entry(self, patched_session, sample_training_task):
        """Test adding a training log entry."""
        repo = TrainingTaskRepository()

        repo.add_log(
            str(sample_training_task.task_id),
            level="INFO",
            message="Starting training...",
            details={"epoch": 1, "batch": 0},
        )

        logs = repo.get_logs(str(sample_training_task.task_id))
        assert len(logs) == 1
        assert logs[0].level == "INFO"
        assert logs[0].message == "Starting training..."

    def test_add_multiple_log_entries(self, patched_session, sample_training_task):
        """Test adding multiple log entries."""
        repo = TrainingTaskRepository()

        for i in range(5):
            repo.add_log(
                str(sample_training_task.task_id),
                level="INFO",
                message=f"Epoch {i} completed",
                details={"epoch": i, "loss": 0.5 - i * 0.1},
            )

        logs = repo.get_logs(str(sample_training_task.task_id))
        assert len(logs) == 5

    def test_get_logs_pagination(self, patched_session, sample_training_task):
        """Test paginated log retrieval."""
        repo = TrainingTaskRepository()

        for i in range(10):
            repo.add_log(
                str(sample_training_task.task_id),
                level="INFO",
                message=f"Log entry {i}",
            )

        logs = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=0)
        assert len(logs) == 5

        logs_page2 = repo.get_logs(str(sample_training_task.task_id), limit=5, offset=5)
        assert len(logs_page2) == 5


class TestDocumentLinks:
    """Tests for training document link management."""

    def test_create_document_link(self, patched_session, sample_training_task, sample_document):
        """Test creating a document link."""
        repo = TrainingTaskRepository()

        link = repo.create_document_link(
            task_id=sample_training_task.task_id,
            document_id=sample_document.document_id,
            annotation_snapshot={"count": 5, "verified": 3},
        )

        assert link is not None
        assert link.task_id == sample_training_task.task_id
        assert link.document_id == sample_document.document_id
        assert link.annotation_snapshot["count"] == 5
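
        # annotation_snapshot presumably freezes the document's annotation
        # state (counts, verification status) at link-creation time, keeping a
        # training run auditable even if annotations change afterwards.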

    def test_get_document_links(self, patched_session, sample_training_task, multiple_documents):
        """Test getting all document links for a task."""
        repo = TrainingTaskRepository()

        for doc in multiple_documents[:3]:
            repo.create_document_link(
                task_id=sample_training_task.task_id,
                document_id=doc.document_id,
            )

        links = repo.get_document_links(sample_training_task.task_id)
        assert len(links) == 3

    def test_get_document_training_tasks(self, patched_session, admin_token, sample_document):
        """Test getting the training tasks that used a document."""
        repo = TrainingTaskRepository()

        # Create multiple tasks using the same document
        task1_id = repo.create(admin_token=admin_token.token, name="Task 1")
        task2_id = repo.create(admin_token=admin_token.token, name="Task 2")

        repo.create_document_link(
            task_id=repo.get(task1_id).task_id,
            document_id=sample_document.document_id,
        )
        repo.create_document_link(
            task_id=repo.get(task2_id).task_id,
            document_id=sample_document.document_id,
        )

        links = repo.get_document_training_tasks(sample_document.document_id)
        assert len(links) == 2