using ColaFlow.Modules.Mcp.Infrastructure.Middleware;
using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using NSubstitute;
using System.Text;
namespace ColaFlow.Modules.Mcp.Tests.Infrastructure.Middleware;
/// <summary>
/// Unit tests for <see cref="McpLoggingMiddleware"/>.
/// </summary>
/// <remarks>
/// Log output is asserted by inspecting the substitute's <c>ReceivedCalls()</c> rather than
/// <c>Received().Log&lt;object&gt;(...)</c>. The <c>LoggerExtensions</c> helpers
/// (<c>LogDebug</c>, <c>LogWarning</c>, ...) invoke <c>ILogger.Log&lt;TState&gt;</c> with a
/// framework-internal state type, so a matcher pinned to <c>TState = object</c> does not match
/// those calls: <c>Received</c> would fail and <c>DidNotReceive</c> would pass vacuously.
/// </remarks>
public class McpLoggingMiddlewareTests
{
    private const string McpPath = "/mcp";
    private const string CorrelationIdKey = "CorrelationId";
    private const string CorrelationIdValue = "test-correlation-id";

    private readonly ILogger<McpLoggingMiddleware> _logger;

    public McpLoggingMiddlewareTests()
    {
        _logger = Substitute.For<ILogger<McpLoggingMiddleware>>();

        // An unconfigured substitute returns false from IsEnabled; report every level as
        // enabled so the assertions hold whether or not the middleware guards its log calls.
        _logger.IsEnabled(Arg.Any<LogLevel>()).Returns(true);
    }

    /// <summary>One captured <c>ILogger.Log&lt;TState&gt;</c> call.</summary>
    /// <param name="Level">The level the entry was logged at.</param>
    /// <param name="Message">
    /// The logged state rendered with <c>ToString()</c>; for the standard logging extension
    /// methods this is the formatted message.
    /// </param>
    private sealed record LogEntry(LogLevel Level, string Message);

    /// <summary>
    /// Collects every <c>Log</c> call the substitute logger received, regardless of the
    /// generic state type the caller used. Other calls (e.g. <c>IsEnabled</c>) are ignored.
    /// </summary>
    /// <returns>The captured entries in call order.</returns>
    private IReadOnlyList<LogEntry> GetLogEntries() =>
        _logger.ReceivedCalls()
            .Where(call => call.GetMethodInfo().Name == nameof(ILogger.Log))
            .Select(call =>
            {
                // ILogger.Log(logLevel, eventId, state, exception, formatter)
                var args = call.GetArguments();
                return new LogEntry((LogLevel)args[0]!, args[2]?.ToString() ?? string.Empty);
            })
            .ToList();

    /// <summary>
    /// Builds a POST request to the MCP endpoint with the given JSON body, a writable
    /// in-memory response body, and a correlation id in <c>HttpContext.Items</c>.
    /// </summary>
    /// <param name="requestBody">The raw request body, encoded as UTF-8.</param>
    private static DefaultHttpContext CreateMcpContext(string requestBody)
    {
        var context = new DefaultHttpContext();
        context.Request.Method = "POST";
        context.Request.Path = McpPath;
        context.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes(requestBody));
        context.Response.Body = new MemoryStream();
        context.Items[CorrelationIdKey] = CorrelationIdValue;
        return context;
    }

    [Fact]
    public async Task InvokeAsync_ShouldSkipLogging_ForNonMcpRequests()
    {
        // Arrange
        var context = new DefaultHttpContext();
        context.Request.Method = "GET";
        context.Request.Path = "/api/tasks";

        var nextCalled = false;
        RequestDelegate next = _ =>
        {
            nextCalled = true;
            return Task.CompletedTask;
        };
        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        nextCalled.Should().BeTrue();
        GetLogEntries().Should().BeEmpty("non-MCP requests must not be logged");
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogRequestAndResponse_ForMcpRequests()
    {
        // Arrange
        var context = CreateMcpContext("{\"jsonrpc\":\"2.0\",\"method\":\"initialize\"}");

        RequestDelegate next = async ctx =>
        {
            ctx.Response.StatusCode = 200;
            var responseBytes = Encoding.UTF8.GetBytes("{\"jsonrpc\":\"2.0\",\"result\":{}}");
            await ctx.Response.Body.WriteAsync(responseBytes);
        };
        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        var entries = GetLogEntries();

        // Request is logged at Debug level.
        entries.Should().Contain(e =>
            e.Level == LogLevel.Debug && e.Message.Contains("MCP Request", StringComparison.Ordinal));

        // Response is logged at Debug level for 2xx status codes.
        entries.Should().Contain(e =>
            e.Level == LogLevel.Debug && e.Message.Contains("MCP Response", StringComparison.Ordinal));
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogErrorLevel_ForErrorResponses()
    {
        // Arrange
        var context = CreateMcpContext("{}");

        RequestDelegate next = ctx =>
        {
            ctx.Response.StatusCode = 500; // Error status
            return Task.CompletedTask;
        };
        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert: 5xx responses are logged at Error level.
        GetLogEntries().Should().Contain(e =>
            e.Level == LogLevel.Error && e.Message.Contains("MCP Response", StringComparison.Ordinal));
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogWarning_ForSlowRequests()
    {
        // Arrange
        var context = CreateMcpContext("{}");

        RequestDelegate next = async ctx =>
        {
            // The middleware's constructor takes no clock abstraction, so a real delay is the
            // only way to exceed the 1-second slow-request threshold. This makes the test slow.
            await Task.Delay(1100);
            ctx.Response.StatusCode = 200;
        };
        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        GetLogEntries().Should().Contain(e =>
            e.Level == LogLevel.Warning && e.Message.Contains("Slow MCP Request", StringComparison.Ordinal));
    }

    [Fact]
    public async Task InvokeAsync_ShouldSanitizeSensitiveData()
    {
        // Arrange
        var context = CreateMcpContext("{\"keyHash\":\"secret-key-hash\",\"password\":\"my-password\"}");

        RequestDelegate next = ctx =>
        {
            ctx.Response.StatusCode = 200;
            return Task.CompletedTask;
        };
        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        var loggedRequest = GetLogEntries()
            .Should().ContainSingle(e =>
                e.Level == LogLevel.Debug && e.Message.Contains("MCP Request", StringComparison.Ordinal))
            .Which.Message;

        loggedRequest.Should().Contain("[REDACTED]");
        loggedRequest.Should().NotContain("secret-key-hash");
        loggedRequest.Should().NotContain("my-password");
    }

    [Fact]
    public async Task InvokeAsync_ShouldIncludePerformanceMetrics()
    {
        // Arrange
        var context = CreateMcpContext("{}");

        RequestDelegate next = ctx =>
        {
            ctx.Response.StatusCode = 200;
            return Task.CompletedTask;
        };
        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        var loggedResponse = GetLogEntries()
            .Should().ContainSingle(e =>
                e.Level == LogLevel.Debug && e.Message.Contains("MCP Response", StringComparison.Ordinal))
            .Which.Message;

        loggedResponse.Should().MatchRegex(@"Duration:\s*\d+ms");
    }
}