using ColaFlow.Modules.Mcp.Contracts.JsonRpc;
using ColaFlow.Modules.Mcp.Domain.Exceptions;
using ColaFlow.Modules.Mcp.Infrastructure.Middleware;
using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using NSubstitute;
using System.Text.Json;

namespace ColaFlow.Modules.Mcp.Tests.Infrastructure.Middleware;

/// <summary>
/// Unit tests for <see cref="McpExceptionHandlerMiddleware"/>: pass-through behavior, translation of
/// exceptions into JSON-RPC error responses, HTTP status-code mapping, and error logging.
/// </summary>
public class McpExceptionHandlerMiddlewareTests
{
    // Responses are read back case-insensitively so assertions do not depend on the
    // middleware's property naming policy.
    private static readonly JsonSerializerOptions JsonOptions = new() { PropertyNameCaseInsensitive = true };

    private readonly ILogger<McpExceptionHandlerMiddleware> _logger;

    public McpExceptionHandlerMiddlewareTests()
    {
        _logger = Substitute.For<ILogger<McpExceptionHandlerMiddleware>>();
    }

    [Fact]
    public async Task InvokeAsync_ShouldCallNextMiddleware_WhenNoExceptionThrown()
    {
        // Arrange
        var context = new DefaultHttpContext();
        var nextCalled = false;
        RequestDelegate next = _ =>
        {
            nextCalled = true;
            return Task.CompletedTask;
        };
        var middleware = new McpExceptionHandlerMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        nextCalled.Should().BeTrue();
    }

    [Fact]
    public async Task InvokeAsync_ShouldHandleMcpException_AndReturnJsonRpcErrorResponse()
    {
        // Arrange
        var context = CreateHttpContextWithBufferedBody();
        context.Items["CorrelationId"] = "test-correlation-id";
        context.Items["McpRequestId"] = "test-request-id";
        var expectedException = new McpNotFoundException("Task", "task-123");
        RequestDelegate next = _ => throw expectedException;
        var middleware = new McpExceptionHandlerMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        context.Response.StatusCode.Should().Be(404);
        context.Response.ContentType.Should().Be("application/json");

        var response = await ReadJsonRpcResponseAsync(context);
        response.Should().NotBeNull();
        response!.JsonRpc.Should().Be("2.0");
        response.Error.Should().NotBeNull();
        response.Error!.Code.Should().Be((int)JsonRpcErrorCode.NotFound);
        response.Error.Message.Should().Be("Task not found: task-123");
        // The request id stored in HttpContext.Items must be echoed back in the error response.
        response.Id.Should().NotBeNull();
        response.Id!.ToString().Should().Be("test-request-id");
    }

    [Fact]
    public async Task InvokeAsync_ShouldHandleUnexpectedException_AndReturnInternalError()
    {
        // Arrange
        var context = CreateHttpContextWithBufferedBody();
        context.Items["CorrelationId"] = "test-correlation-id";
        RequestDelegate next = _ => throw new InvalidOperationException("Unexpected error");
        var middleware = new McpExceptionHandlerMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        context.Response.StatusCode.Should().Be(500);
        context.Response.ContentType.Should().Be("application/json");

        var response = await ReadJsonRpcResponseAsync(context);
        response.Should().NotBeNull();
        response!.Error.Should().NotBeNull();
        response.Error!.Code.Should().Be((int)JsonRpcErrorCode.InternalError);
        response.Error.Message.Should().Be("Internal server error");
        // Should NOT expose exception details to the client.
        response.Error.Data.Should().BeNull();
    }

    [Theory]
    [InlineData(typeof(McpUnauthorizedException), 401)]
    [InlineData(typeof(McpForbiddenException), 403)]
    [InlineData(typeof(McpNotFoundException), 404)]
    [InlineData(typeof(McpValidationException), 422)]
    [InlineData(typeof(McpParseException), 400)]
    [InlineData(typeof(McpInvalidRequestException), 400)]
    [InlineData(typeof(McpMethodNotFoundException), 404)]
    [InlineData(typeof(McpInvalidParamsException), 400)]
    public async Task InvokeAsync_ShouldMapExceptionToCorrectHttpStatusCode(Type exceptionType, int expectedStatusCode)
    {
        // Arrange
        var context = CreateHttpContextWithBufferedBody();
        // Some exception types need constructor arguments, so build each one explicitly
        // rather than via reflection.
        McpException exception = exceptionType.Name switch
        {
            nameof(McpUnauthorizedException) => new McpUnauthorizedException(),
            nameof(McpForbiddenException) => new McpForbiddenException(),
            nameof(McpNotFoundException) => new McpNotFoundException("Resource", "123"),
            nameof(McpValidationException) => new McpValidationException(),
            nameof(McpParseException) => new McpParseException(),
            nameof(McpInvalidRequestException) => new McpInvalidRequestException(),
            nameof(McpMethodNotFoundException) => new McpMethodNotFoundException("test"),
            nameof(McpInvalidParamsException) => new McpInvalidParamsException(),
            _ => throw new ArgumentException("Unknown exception type")
        };
        RequestDelegate next = _ => throw exception;
        var middleware = new McpExceptionHandlerMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        context.Response.StatusCode.Should().Be(expectedStatusCode);
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogErrorWithStructuredData()
    {
        // Arrange
        var context = CreateHttpContextWithBufferedBody();
        context.Items["CorrelationId"] = "test-correlation-id";
        context.Items["TenantId"] = "tenant-123";
        context.Items["ApiKeyId"] = "key-456";
        var exception = new McpValidationException("Test validation error");
        RequestDelegate next = _ => throw exception;
        var middleware = new McpExceptionHandlerMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        // ILogger.Log<TState> is generic. Logging extension methods typically pass a non-public
        // state type, so matching via Received(1).Log(..., object state, ...) depends on how
        // NSubstitute compares generic arguments. Inspecting the recorded calls directly avoids that.
        var matchingErrorLogs = _logger.ReceivedCalls()
            .Where(call => call.GetMethodInfo().Name == nameof(ILogger.Log))
            .Select(call => call.GetArguments())
            .Count(args =>
                args[0] is LogLevel.Error
                && (args[2]?.ToString()?.Contains("MCP Error", StringComparison.Ordinal) ?? false)
                && ReferenceEquals(args[3], exception));

        matchingErrorLogs.Should().Be(1);
    }

    /// <summary>
    /// Creates an <see cref="HttpContext"/> whose response body is a seekable in-memory stream,
    /// so the middleware's output can be read back after invocation.
    /// </summary>
    private static DefaultHttpContext CreateHttpContextWithBufferedBody()
    {
        var context = new DefaultHttpContext();
        context.Response.Body = new MemoryStream();
        return context;
    }

    /// <summary>
    /// Rewinds the response body and deserializes it as a JSON-RPC response.
    /// </summary>
    /// <param name="context">A context created by <see cref="CreateHttpContextWithBufferedBody"/>.</param>
    /// <returns>The deserialized response, or <c>null</c> if the body is the JSON literal <c>null</c>.</returns>
    private static async Task<JsonRpcResponse?> ReadJsonRpcResponseAsync(HttpContext context)
    {
        context.Response.Body.Seek(0, SeekOrigin.Begin);
        // leaveOpen keeps the response stream usable after the reader is disposed.
        using var reader = new StreamReader(context.Response.Body, leaveOpen: true);
        var responseBody = await reader.ReadToEndAsync();
        return JsonSerializer.Deserialize<JsonRpcResponse>(responseBody, JsonOptions);
    }
}