165 lines
5.2 KiB
Python
165 lines
5.2 KiB
Python
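"""
Tests for Admin Authentication.

Covers the validate_admin_token dependency, TokenRepository.is_valid, and
the get_token_repository / reset_token_repository singleton helpers.
"""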
|
|
|
|
import pytest
|
|
from datetime import datetime, timedelta
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from fastapi import HTTPException
|
|
|
|
from inference.data.repositories import TokenRepository
|
|
from inference.data.admin_models import AdminToken
|
|
from inference.web.core.auth import (
|
|
get_token_repository,
|
|
reset_token_repository,
|
|
validate_admin_token,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def mock_token_repo():
    """Provide a MagicMock constrained to the TokenRepository interface.

    The mock's ``is_valid`` answers ``True`` by default; individual tests
    override that return value when they need a rejected token.
    """
    stub = MagicMock(spec=TokenRepository)
    stub.is_valid.return_value = True
    return stub
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_repo():
    """Call reset_token_repository() once each test has finished.

    Declared ``autouse`` so every test in this module gets the teardown
    without requesting the fixture explicitly.
    """
    # Nothing to set up: hand control to the test immediately.
    yield
    # Teardown phase.
    reset_token_repository()
|
|
|
|
|
|
class TestValidateAdminToken:
    """Tests for validate_admin_token dependency.

    The dependency is awaited from synchronous tests. Each call is driven on
    a private event loop instead of ``asyncio.get_event_loop()``: the implicit
    global loop is deprecated when no loop is set and is unavailable once any
    earlier code has cleared it, which made these tests order-dependent.
    """

    @staticmethod
    def _run(awaitable):
        """Run *awaitable* to completion on a fresh event loop.

        Returns the awaitable's result and re-raises any exception it raises.
        The loop is always closed, and the process-wide current event loop is
        left untouched.
        """
        import asyncio

        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(awaitable)
        finally:
            loop.close()

    def test_missing_token_raises_401(self, mock_token_repo):
        """Test that missing token raises 401."""
        with pytest.raises(HTTPException) as exc_info:
            self._run(validate_admin_token(None, mock_token_repo))

        assert exc_info.value.status_code == 401
        assert "Admin token required" in exc_info.value.detail

    def test_invalid_token_raises_401(self, mock_token_repo):
        """Test that invalid token raises 401."""
        # The repository reports the token as not valid.
        mock_token_repo.is_valid.return_value = False

        with pytest.raises(HTTPException) as exc_info:
            self._run(validate_admin_token("invalid-token", mock_token_repo))

        assert exc_info.value.status_code == 401
        assert "Invalid or expired" in exc_info.value.detail

    def test_valid_token_returns_token(self, mock_token_repo):
        """Test that valid token is returned and its usage is recorded."""
        token = "valid-test-token"
        mock_token_repo.is_valid.return_value = True

        result = self._run(validate_admin_token(token, mock_token_repo))

        assert result == token
        mock_token_repo.update_usage.assert_called_once_with(token)
|
|
|
|
|
|
class TestTokenRepository:
    """Tests for TokenRepository.is_valid with the database session mocked out."""

    # Attribute patched in every test; its return value is mocked as a
    # context manager that yields the fake session.
    _SESSION_TARGET = "inference.data.repositories.token_repository.BaseRepository._session"

    @staticmethod
    def _stub_lookup(session_factory, record):
        """Make the session yielded by *session_factory* return *record* from ``get``."""
        session = MagicMock()
        session_factory.return_value.__enter__.return_value = session
        session.get.return_value = record
        return session

    def test_is_valid_active_token(self):
        """An active token without an expiry date is reported valid."""
        with patch(self._SESSION_TARGET) as session_factory:
            stored = AdminToken(
                token="test-token",
                name="Test",
                is_active=True,
                expires_at=None,
            )
            self._stub_lookup(session_factory, stored)

            outcome = TokenRepository().is_valid("test-token")

            assert outcome is True

    def test_is_valid_inactive_token(self):
        """A token flagged inactive is reported invalid."""
        with patch(self._SESSION_TARGET) as session_factory:
            stored = AdminToken(
                token="test-token",
                name="Test",
                is_active=False,
                expires_at=None,
            )
            self._stub_lookup(session_factory, stored)

            outcome = TokenRepository().is_valid("test-token")

            assert outcome is False

    def test_is_valid_expired_token(self):
        """An active token whose expiry lies one day in the past is reported invalid."""
        with patch(self._SESSION_TARGET) as session_factory:
            stored = AdminToken(
                token="test-token",
                name="Test",
                is_active=True,
                expires_at=datetime.utcnow() - timedelta(days=1),
            )
            self._stub_lookup(session_factory, stored)

            repository = TokenRepository()
            # Pin the repository's clock so the expiry comparison uses a known "now".
            with patch.object(repository, "_now", return_value=datetime.utcnow()):
                outcome = repository.is_valid("test-token")
                assert outcome is False

    def test_is_valid_token_not_found(self):
        """A lookup that finds no stored token is reported invalid."""
        with patch(self._SESSION_TARGET) as session_factory:
            self._stub_lookup(session_factory, None)

            outcome = TokenRepository().is_valid("nonexistent")

            assert outcome is False
|
|
|
|
|
|
class TestGetTokenRepository:
    """Tests for the get_token_repository accessor and its reset helper."""

    def test_returns_singleton(self):
        """Two consecutive lookups hand back the very same object."""
        # Reset before the first lookup, as the check only compares the two calls below.
        reset_token_repository()

        first = get_token_repository()
        second = get_token_repository()

        assert first is second

    def test_reset_clears_singleton(self):
        """A lookup made after a reset yields a different object than before it."""
        before_reset = get_token_repository()

        reset_token_repository()
        after_reset = get_token_repository()

        assert before_reset is not after_reset
|