using ColaFlow.Modules.Mcp.Infrastructure.Middleware;
using FluentAssertions;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using NSubstitute;
using System.Text;

namespace ColaFlow.Modules.Mcp.Tests.Infrastructure.Middleware;

/// <summary>
/// Unit tests for McpLoggingMiddleware.
/// </summary>
/// <remarks>
/// Log calls are verified against the low-level <c>ILogger.Log</c> method with the
/// state argument matched as <see cref="object"/> and inspected through
/// <c>ToString()</c>.
/// NOTE(review): this assumes the NSubstitute version in use matches
/// <c>Log&lt;TState&gt;</c> calls made with the framework's own state type against
/// <c>object</c>-typed matchers - confirm against the package version.
/// </remarks>
public class McpLoggingMiddlewareTests
{
    private readonly ILogger<McpLoggingMiddleware> _logger;

    public McpLoggingMiddlewareTests()
    {
        _logger = Substitute.For<ILogger<McpLoggingMiddleware>>();
    }

    /// <summary>
    /// Builds a POST /mcp context with the given request body, an empty in-memory
    /// response body and a fixed correlation id stored in <c>Items</c>.
    /// </summary>
    /// <param name="requestBody">Text written (UTF-8) into the request body stream.</param>
    /// <returns>The configured <see cref="DefaultHttpContext"/>.</returns>
    private static DefaultHttpContext CreateMcpContext(string requestBody)
    {
        var context = new DefaultHttpContext();
        context.Request.Method = "POST";
        context.Request.Path = "/mcp";
        context.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes(requestBody));
        context.Response.Body = new MemoryStream();
        context.Items["CorrelationId"] = "test-correlation-id";
        return context;
    }

    [Fact]
    public async Task InvokeAsync_ShouldSkipLogging_ForNonMcpRequests()
    {
        // Arrange
        var context = new DefaultHttpContext();
        context.Request.Method = "GET";
        context.Request.Path = "/api/tasks";

        var nextCalled = false;
        RequestDelegate next = (HttpContext ctx) =>
        {
            nextCalled = true;
            return Task.CompletedTask;
        };

        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        nextCalled.Should().BeTrue();

        // Logger should not be called for non-MCP requests
        _logger.DidNotReceive().Log(
            Arg.Any<LogLevel>(),
            Arg.Any<EventId>(),
            Arg.Any<object>(),
            Arg.Any<Exception?>(),
            Arg.Any<Func<object, Exception?, string>>());
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogRequestAndResponse_ForMcpRequests()
    {
        // Arrange
        var context = CreateMcpContext("{\"jsonrpc\":\"2.0\",\"method\":\"initialize\"}");

        RequestDelegate next = (HttpContext ctx) =>
        {
            ctx.Response.StatusCode = 200;
            var responseBytes = Encoding.UTF8.GetBytes("{\"jsonrpc\":\"2.0\",\"result\":{}}");
            ctx.Response.Body.Write(responseBytes, 0, responseBytes.Length);
            return Task.CompletedTask;
        };

        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        // Should log request (Debug level)
        _logger.Received().Log(
            LogLevel.Debug,
            Arg.Any<EventId>(),
            Arg.Is<object>(o => o.ToString()!.Contains("MCP Request")),
            Arg.Any<Exception?>(),
            Arg.Any<Func<object, Exception?, string>>());

        // Should log response (Debug level for 2xx status)
        _logger.Received().Log(
            LogLevel.Debug,
            Arg.Any<EventId>(),
            Arg.Is<object>(o => o.ToString()!.Contains("MCP Response")),
            Arg.Any<Exception?>(),
            Arg.Any<Func<object, Exception?, string>>());
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogErrorLevel_ForErrorResponses()
    {
        // Arrange
        var context = CreateMcpContext("{}");

        RequestDelegate next = (HttpContext ctx) =>
        {
            ctx.Response.StatusCode = 500; // Error status
            return Task.CompletedTask;
        };

        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        // Should log response at Error level for 5xx status
        _logger.Received().Log(
            LogLevel.Error,
            Arg.Any<EventId>(),
            Arg.Is<object>(o => o.ToString()!.Contains("MCP Response")),
            Arg.Any<Exception?>(),
            Arg.Any<Func<object, Exception?, string>>());
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogWarning_ForSlowRequests()
    {
        // Arrange
        var context = CreateMcpContext("{}");

        RequestDelegate next = async (HttpContext ctx) =>
        {
            // Simulate slow request (> 1 second). The delay is real wall-clock time,
            // so this test takes slightly over a second to run.
            await Task.Delay(1100);
            ctx.Response.StatusCode = 200;
        };

        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        // Should log warning for slow requests
        _logger.Received().Log(
            LogLevel.Warning,
            Arg.Any<EventId>(),
            Arg.Is<object>(o => o.ToString()!.Contains("Slow MCP Request")),
            Arg.Any<Exception?>(),
            Arg.Any<Func<object, Exception?, string>>());
    }

    [Fact]
    public async Task InvokeAsync_ShouldSanitizeSensitiveData()
    {
        // Arrange
        var requestBody = "{\"keyHash\":\"secret-key-hash\",\"password\":\"my-password\"}";
        var context = CreateMcpContext(requestBody);

        // Capture the rendered text of the Debug-level "MCP Request" entry so the
        // assertions below can inspect exactly what was written to the log.
        var loggedRequest = string.Empty;
        _logger.When(x => x.Log(
                LogLevel.Debug,
                Arg.Any<EventId>(),
                Arg.Is<object>(o => o.ToString()!.Contains("MCP Request")),
                Arg.Any<Exception?>(),
                Arg.Any<Func<object, Exception?, string>>()))
            .Do(callInfo =>
            {
                var state = callInfo.ArgAt<object>(2);
                loggedRequest = state.ToString() ?? "";
            });

        RequestDelegate next = (HttpContext ctx) =>
        {
            ctx.Response.StatusCode = 200;
            return Task.CompletedTask;
        };

        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        loggedRequest.Should().Contain("[REDACTED]");
        loggedRequest.Should().NotContain("secret-key-hash");
        loggedRequest.Should().NotContain("my-password");
    }

    [Fact]
    public async Task InvokeAsync_ShouldIncludePerformanceMetrics()
    {
        // Arrange
        var context = CreateMcpContext("{}");

        // Capture the rendered text of the Debug-level "MCP Response" entry.
        var loggedResponse = string.Empty;
        _logger.When(x => x.Log(
                LogLevel.Debug,
                Arg.Any<EventId>(),
                Arg.Is<object>(o => o.ToString()!.Contains("MCP Response")),
                Arg.Any<Exception?>(),
                Arg.Any<Func<object, Exception?, string>>()))
            .Do(callInfo =>
            {
                var state = callInfo.ArgAt<object>(2);
                loggedResponse = state.ToString() ?? "";
            });

        RequestDelegate next = (HttpContext ctx) =>
        {
            ctx.Response.StatusCode = 200;
            return Task.CompletedTask;
        };

        var middleware = new McpLoggingMiddleware(next, _logger);

        // Act
        await middleware.InvokeAsync(context);

        // Assert
        loggedResponse.Should().Contain("Duration:");
        loggedResponse.Should().MatchRegex(@"Duration:\s*\d+ms");
    }
}