using ColaFlow.Modules.Mcp.Contracts.JsonRpc;
using ColaFlow.Modules.Mcp.Domain.Exceptions;
using ColaFlow.Modules.Mcp.Infrastructure.Middleware;
using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using NSubstitute;
using System.Text.Json;
namespace ColaFlow.Modules.Mcp.Tests.Infrastructure.Middleware;
/// <summary>
/// Unit tests for <see cref="McpExceptionHandlerMiddleware"/>.
/// </summary>
public class McpExceptionHandlerMiddlewareTests
{
// Logger substitute handed to every middleware instance under test; tests can
// inspect the calls it receives (e.g. via NSubstitute's Received()).
private readonly ILogger<McpExceptionHandlerMiddleware> _logger;

public McpExceptionHandlerMiddlewareTests()
{
    // xUnit constructs a new test-class instance per test, so each test gets a fresh substitute.
    _logger = Substitute.For<ILogger<McpExceptionHandlerMiddleware>>();
}
[Fact]
public async Task InvokeAsync_ShouldCallNextMiddleware_WhenNoExceptionThrown()
{
    // Arrange: the downstream delegate completes normally and only records that it ran.
    var invoked = false;
    RequestDelegate downstream = _ =>
    {
        invoked = true;
        return Task.CompletedTask;
    };
    var sut = new McpExceptionHandlerMiddleware(downstream, _logger);
    var httpContext = new DefaultHttpContext();

    // Act
    await sut.InvokeAsync(httpContext);

    // Assert: with no failure downstream, the middleware must pass control through.
    invoked.Should().BeTrue();
}
[Fact]
public async Task InvokeAsync_ShouldHandleMcpException_AndReturnJsonRpcErrorResponse()
{
    // Arrange: a downstream delegate that fails with an MCP domain exception, plus the
    // correlation/request identifiers placed in HttpContext.Items for the middleware.
    var context = new DefaultHttpContext();
    context.Response.Body = new MemoryStream();
    context.Items["CorrelationId"] = "test-correlation-id";
    context.Items["McpRequestId"] = "test-request-id";
    var expectedException = new McpNotFoundException("Task", "task-123");
    RequestDelegate next = _ => throw expectedException;
    var middleware = new McpExceptionHandlerMiddleware(next, _logger);

    // Act
    await middleware.InvokeAsync(context);

    // Assert: HTTP envelope.
    context.Response.StatusCode.Should().Be(404);
    context.Response.ContentType.Should().Be("application/json");

    // Assert: JSON-RPC body. Rewind the in-memory body before reading it back.
    context.Response.Body.Seek(0, SeekOrigin.Begin);
    using var reader = new StreamReader(context.Response.Body);
    var responseBody = await reader.ReadToEndAsync();
    // The type argument is required: the (string, options) overload cannot infer it.
    var response = JsonSerializer.Deserialize<JsonRpcResponse>(responseBody, new JsonSerializerOptions
    {
        PropertyNameCaseInsensitive = true
    });

    response.Should().NotBeNull();
    response!.JsonRpc.Should().Be("2.0");
    response.Error.Should().NotBeNull();
    response.Error!.Code.Should().Be((int)JsonRpcErrorCode.NotFound);
    response.Error.Message.Should().Be("Task not found: task-123");
    // The response id should echo the McpRequestId stored in HttpContext.Items.
    response.Id.Should().NotBeNull();
    response.Id!.ToString().Should().Be("test-request-id");
}
[Fact]
public async Task InvokeAsync_ShouldHandleUnexpectedException_AndReturnInternalError()
{
    // Arrange: a downstream delegate that fails with a non-MCP (unexpected) exception.
    var context = new DefaultHttpContext();
    context.Response.Body = new MemoryStream();
    context.Items["CorrelationId"] = "test-correlation-id";
    RequestDelegate next = _ => throw new InvalidOperationException("Unexpected error");
    var middleware = new McpExceptionHandlerMiddleware(next, _logger);

    // Act
    await middleware.InvokeAsync(context);

    // Assert: HTTP envelope.
    context.Response.StatusCode.Should().Be(500);
    context.Response.ContentType.Should().Be("application/json");

    // Assert: JSON-RPC body. Rewind the in-memory body before reading it back.
    context.Response.Body.Seek(0, SeekOrigin.Begin);
    using var reader = new StreamReader(context.Response.Body);
    var responseBody = await reader.ReadToEndAsync();
    // The type argument is required: the (string, options) overload cannot infer it.
    var response = JsonSerializer.Deserialize<JsonRpcResponse>(responseBody, new JsonSerializerOptions
    {
        PropertyNameCaseInsensitive = true
    });

    response.Should().NotBeNull();
    response!.Error.Should().NotBeNull();
    response.Error!.Code.Should().Be((int)JsonRpcErrorCode.InternalError);
    response.Error.Message.Should().Be("Internal server error");
    // Should NOT expose exception details
    response.Error.Data.Should().BeNull();
}
[Theory]
[InlineData(typeof(McpUnauthorizedException), 401)]
[InlineData(typeof(McpForbiddenException), 403)]
[InlineData(typeof(McpNotFoundException), 404)]
[InlineData(typeof(McpValidationException), 422)]
[InlineData(typeof(McpParseException), 400)]
[InlineData(typeof(McpInvalidRequestException), 400)]
[InlineData(typeof(McpMethodNotFoundException), 404)]
[InlineData(typeof(McpInvalidParamsException), 400)]
public async Task InvokeAsync_ShouldMapExceptionToCorrectHttpStatusCode(Type exceptionType, int expectedStatusCode)
{
    // Arrange: materialise the requested exception and a downstream delegate that throws it.
    var httpContext = new DefaultHttpContext();
    httpContext.Response.Body = new MemoryStream();
    var thrown = CreateException(exceptionType);
    RequestDelegate failingNext = _ => throw thrown;
    var sut = new McpExceptionHandlerMiddleware(failingNext, _logger);

    // Act
    await sut.InvokeAsync(httpContext);

    // Assert: only the HTTP status code is checked here.
    httpContext.Response.StatusCode.Should().Be(expectedStatusCode);

    // Builds one representative instance per InlineData type; constructor arguments
    // are arbitrary placeholders where the exception type requires them.
    static McpException CreateException(Type type) => type switch
    {
        var t when t == typeof(McpUnauthorizedException) => new McpUnauthorizedException(),
        var t when t == typeof(McpForbiddenException) => new McpForbiddenException(),
        var t when t == typeof(McpNotFoundException) => new McpNotFoundException("Resource", "123"),
        var t when t == typeof(McpValidationException) => new McpValidationException(),
        var t when t == typeof(McpParseException) => new McpParseException(),
        var t when t == typeof(McpInvalidRequestException) => new McpInvalidRequestException(),
        var t when t == typeof(McpMethodNotFoundException) => new McpMethodNotFoundException("test"),
        var t when t == typeof(McpInvalidParamsException) => new McpInvalidParamsException(),
        _ => throw new ArgumentException("Unknown exception type")
    };
}
[Fact]
public async Task InvokeAsync_ShouldLogErrorWithStructuredData()
{
// Arrange
var context = new DefaultHttpContext();
context.Response.Body = new MemoryStream();
context.Items["CorrelationId"] = "test-correlation-id";
context.Items["TenantId"] = "tenant-123";
context.Items["ApiKeyId"] = "key-456";
var exception = new McpValidationException("Test validation error");
RequestDelegate next = (HttpContext ctx) =>
{
throw exception;
};
var middleware = new McpExceptionHandlerMiddleware(next, _logger);
// Act
await middleware.InvokeAsync(context);
// Assert
_logger.Received(1).Log(
LogLevel.Error,
Arg.Any(),
Arg.Is