Add more tests

This commit is contained in:
Yaojia Wang
2026-02-01 22:40:41 +01:00
parent a564ac9d70
commit 400b12a967
55 changed files with 9306 additions and 267 deletions

View File

@@ -0,0 +1,258 @@
"""
Database Setup Integration Tests
Tests for database connection, session management, and basic operations.
"""
import pytest
from sqlmodel import Session, select
from inference.data.admin_models import AdminDocument, AdminToken
class TestDatabaseConnection:
"""Tests for database engine and connection."""
def test_engine_connection(self, test_engine):
"""Verify database engine can establish connection."""
with test_engine.connect() as conn:
result = conn.execute(select(1))
assert result.scalar() == 1
def test_tables_created(self, test_engine):
"""Verify all expected tables are created."""
from sqlmodel import SQLModel
table_names = SQLModel.metadata.tables.keys()
expected_tables = [
"admin_tokens",
"admin_documents",
"admin_annotations",
"training_tasks",
"training_logs",
"batch_uploads",
"batch_upload_files",
"training_datasets",
"dataset_documents",
"training_document_links",
"model_versions",
]
for table in expected_tables:
assert table in table_names, f"Table '{table}' not found"
class TestSessionManagement:
"""Tests for database session context manager."""
def test_session_commit(self, db_session):
"""Verify session commits changes successfully."""
token = AdminToken(
token="commit-test-token",
name="Commit Test",
is_active=True,
)
db_session.add(token)
db_session.commit()
result = db_session.exec(
select(AdminToken).where(AdminToken.token == "commit-test-token")
).first()
assert result is not None
assert result.name == "Commit Test"
def test_session_rollback_on_error(self, test_engine):
"""Verify session rollback on exception."""
session = Session(test_engine)
try:
token = AdminToken(
token="rollback-test-token",
name="Rollback Test",
is_active=True,
)
session.add(token)
session.commit()
# Try to insert duplicate (should fail)
duplicate = AdminToken(
token="rollback-test-token", # Same primary key
name="Duplicate",
is_active=True,
)
session.add(duplicate)
session.commit()
except Exception:
session.rollback()
finally:
session.close()
# Verify original record exists
with Session(test_engine) as verify_session:
result = verify_session.exec(
select(AdminToken).where(AdminToken.token == "rollback-test-token")
).first()
assert result is not None
assert result.name == "Rollback Test"
def test_session_isolation(self, test_engine):
"""Verify sessions are isolated from each other."""
session1 = Session(test_engine)
session2 = Session(test_engine)
try:
# Insert in session1, don't commit
token = AdminToken(
token="isolation-test-token",
name="Isolation Test",
is_active=True,
)
session1.add(token)
session1.flush()
# Session2 should not see uncommitted data (with proper isolation)
# Note: SQLite in-memory may have different isolation behavior
session1.commit()
result = session2.exec(
select(AdminToken).where(AdminToken.token == "isolation-test-token")
).first()
# After commit, session2 should see the data
assert result is not None
finally:
session1.close()
session2.close()
class TestBasicCRUDOperations:
"""Tests for basic CRUD operations on database."""
def test_create_and_read_token(self, db_session):
"""Test creating and reading admin token."""
token = AdminToken(
token="crud-test-token",
name="CRUD Test",
is_active=True,
)
db_session.add(token)
db_session.commit()
result = db_session.get(AdminToken, "crud-test-token")
assert result is not None
assert result.name == "CRUD Test"
assert result.is_active is True
def test_update_entity(self, db_session, admin_token):
"""Test updating an entity."""
admin_token.name = "Updated Name"
db_session.add(admin_token)
db_session.commit()
result = db_session.get(AdminToken, admin_token.token)
assert result is not None
assert result.name == "Updated Name"
def test_delete_entity(self, db_session):
"""Test deleting an entity."""
token = AdminToken(
token="delete-test-token",
name="Delete Test",
is_active=True,
)
db_session.add(token)
db_session.commit()
db_session.delete(token)
db_session.commit()
result = db_session.get(AdminToken, "delete-test-token")
assert result is None
def test_foreign_key_constraint(self, db_session, admin_token):
"""Test foreign key constraints are enforced."""
from uuid import uuid4
doc = AdminDocument(
document_id=uuid4(),
admin_token=admin_token.token,
filename="fk_test.pdf",
file_size=1024,
content_type="application/pdf",
file_path="/test/fk_test.pdf",
page_count=1,
status="pending",
)
db_session.add(doc)
db_session.commit()
# Document should reference valid token
result = db_session.get(AdminDocument, doc.document_id)
assert result is not None
assert result.admin_token == admin_token.token
class TestQueryOperations:
"""Tests for various query operations."""
def test_select_with_filter(self, db_session, multiple_documents):
"""Test SELECT with WHERE clause."""
results = db_session.exec(
select(AdminDocument).where(AdminDocument.status == "labeled")
).all()
assert len(results) == 2
for doc in results:
assert doc.status == "labeled"
def test_select_with_order(self, db_session, multiple_documents):
"""Test SELECT with ORDER BY clause."""
results = db_session.exec(
select(AdminDocument).order_by(AdminDocument.file_size.desc())
).all()
file_sizes = [doc.file_size for doc in results]
assert file_sizes == sorted(file_sizes, reverse=True)
def test_select_with_limit_offset(self, db_session, multiple_documents):
"""Test SELECT with LIMIT and OFFSET."""
results = db_session.exec(
select(AdminDocument)
.order_by(AdminDocument.filename)
.offset(2)
.limit(2)
).all()
assert len(results) == 2
def test_count_query(self, db_session, multiple_documents):
"""Test COUNT aggregation."""
from sqlalchemy import func
count = db_session.exec(
select(func.count()).select_from(AdminDocument)
).one()
assert count == 5
def test_group_by_query(self, db_session, multiple_documents):
"""Test GROUP BY aggregation."""
from sqlalchemy import func
results = db_session.exec(
select(
AdminDocument.status,
func.count(AdminDocument.document_id).label("count"),
).group_by(AdminDocument.status)
).all()
status_counts = {row[0]: row[1] for row in results}
assert status_counts.get("pending") == 2
assert status_counts.get("labeled") == 2
assert status_counts.get("exported") == 1