"""
Tests for Admin Authentication.

Covers the ``validate_admin_token`` dependency, ``TokenRepository.is_valid``
(with the database session mocked out), and the ``get_token_repository`` /
``reset_token_repository`` singleton helpers.
"""

import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch

import pytest
from fastapi import HTTPException

from inference.data.repositories import TokenRepository
from inference.data.admin_models import AdminToken
from inference.web.core.auth import (
    get_token_repository,
    reset_token_repository,
    validate_admin_token,
)

# Patch target for the session context-manager factory TokenRepository uses.
_SESSION_TARGET = (
    "inference.data.repositories.token_repository.BaseRepository._session"
)


def _utcnow_naive() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value as ``datetime.utcnow()`` (deprecated since Python 3.12), kept
    naive so it compares the same way against naive ``expires_at`` values.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@pytest.fixture
def mock_token_repo():
    """Create a mock TokenRepository whose ``is_valid`` returns True by default."""
    repo = MagicMock(spec=TokenRepository)
    repo.is_valid.return_value = True
    return repo


@pytest.fixture(autouse=True)
def reset_repo():
    """Reset the token repository singleton before and after each test."""
    # Resetting up front isolates each test from state left by other modules.
    reset_token_repository()
    yield
    reset_token_repository()


@pytest.fixture
def mock_session():
    """Patch ``BaseRepository._session`` and yield the mocked session.

    The patched ``_session`` returns a context manager whose ``__enter__``
    yields the ``MagicMock`` returned here. Tests configure it (e.g.
    ``mock_session.get.return_value``) to control what the repository sees.
    """
    with patch(_SESSION_TARGET) as mock_ctx:
        session = MagicMock()
        mock_ctx.return_value.__enter__.return_value = session
        yield session


class TestValidateAdminToken:
    """Tests for validate_admin_token dependency."""

    def test_missing_token_raises_401(self, mock_token_repo):
        """Test that missing token raises 401."""
        with pytest.raises(HTTPException) as exc_info:
            # asyncio.run creates and closes a fresh loop; get_event_loop()
            # outside a running loop is deprecated.
            asyncio.run(validate_admin_token(None, mock_token_repo))

        assert exc_info.value.status_code == 401
        assert "Admin token required" in exc_info.value.detail

    def test_invalid_token_raises_401(self, mock_token_repo):
        """Test that invalid token raises 401."""
        mock_token_repo.is_valid.return_value = False

        with pytest.raises(HTTPException) as exc_info:
            asyncio.run(validate_admin_token("invalid-token", mock_token_repo))

        assert exc_info.value.status_code == 401
        assert "Invalid or expired" in exc_info.value.detail

    def test_valid_token_returns_token(self, mock_token_repo):
        """Test that valid token is returned and its usage is recorded."""
        token = "valid-test-token"
        mock_token_repo.is_valid.return_value = True

        result = asyncio.run(validate_admin_token(token, mock_token_repo))

        assert result == token
        mock_token_repo.update_usage.assert_called_once_with(token)


class TestTokenRepository:
    """Tests for TokenRepository operations."""

    def test_is_valid_active_token(self, mock_session):
        """Test valid active token."""
        mock_session.get.return_value = AdminToken(
            token="test-token",
            name="Test",
            is_active=True,
            expires_at=None,
        )

        repo = TokenRepository()
        assert repo.is_valid("test-token") is True

    def test_is_valid_inactive_token(self, mock_session):
        """Test inactive token."""
        mock_session.get.return_value = AdminToken(
            token="test-token",
            name="Test",
            is_active=False,
            expires_at=None,
        )

        repo = TokenRepository()
        assert repo.is_valid("test-token") is False

    def test_is_valid_expired_token(self, mock_session):
        """Test expired token."""
        now = _utcnow_naive()
        mock_session.get.return_value = AdminToken(
            token="test-token",
            name="Test",
            is_active=True,
            expires_at=now - timedelta(days=1),
        )

        repo = TokenRepository()
        # Pin _now() to the same instant used to build expires_at so the
        # expiry comparison is deterministic.
        with patch.object(repo, "_now", return_value=now):
            assert repo.is_valid("test-token") is False

    def test_is_valid_token_not_found(self, mock_session):
        """Test token not found."""
        mock_session.get.return_value = None

        repo = TokenRepository()
        assert repo.is_valid("nonexistent") is False


class TestGetTokenRepository:
    """Tests for get_token_repository function."""

    def test_returns_singleton(self):
        """Test that get_token_repository returns singleton."""
        reset_token_repository()
        repo1 = get_token_repository()
        repo2 = get_token_repository()
        assert repo1 is repo2

    def test_reset_clears_singleton(self):
        """Test that reset clears singleton."""
        repo1 = get_token_repository()
        reset_token_repository()
        repo2 = get_token_repository()
        assert repo1 is not repo2