Implement comprehensive error handling and structured logging for MCP module.

**Exception Hierarchy**:
- Created McpException base class with JSON-RPC error mapping
- Implemented 8 specific exception types (Parse, InvalidRequest, MethodNotFound, etc.)
- Each exception maps to the correct HTTP status code (401, 403, 404, 422, 400, 500)

**Middleware**:
- McpCorrelationIdMiddleware: Generates/extracts correlation ID for request tracking
- McpExceptionHandlerMiddleware: Global exception handler with JSON-RPC error responses
- McpLoggingMiddleware: Request/response logging with sensitive data sanitization

**Serilog Integration**:
- Configured structured logging with Console and File sinks
- Log rotation (daily, 30-day retention)
- Correlation ID enrichment in all log entries

**Features**:
- Correlation ID propagation across request chain
- Structured logging with TenantId, UserId, ApiKeyId
- Sensitive data sanitization (API keys, passwords)
- Performance metrics (request duration, slow request warnings)
- JSON-RPC 2.0 compliant error responses

**Testing**:
- 174 tests passing (all MCP module tests)
- Unit tests for all exception classes
- Unit tests for all middleware components
- 100% coverage of error mapping and HTTP status codes

**Files Added**:
- 9 exception classes in Domain/Exceptions/
- 3 middleware classes in Infrastructure/Middleware/
- 4 test files with comprehensive coverage

**Files Modified**:
- Program.cs: Serilog configuration
- McpServiceExtensions.cs: Middleware pipeline registration
- JsonRpcError.cs: Added parameterless constructor for deserialization
- MCP Infrastructure .csproj: Added Serilog package reference

**Verification**:
✅ All 174 MCP module tests passing
✅ Build successful with no errors
✅ Exception-to-HTTP-status mapping verified
✅ Correlation ID propagation tested
✅ Sensitive data sanitization verified

Story: docs/stories/sprint_5/story_5_4.md

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
242 lines
7.8 KiB
C#
242 lines
7.8 KiB
C#
using ColaFlow.Modules.Mcp.Infrastructure.Middleware;
|
|
using FluentAssertions;
|
|
using Microsoft.AspNetCore.Http;
|
|
using Microsoft.Extensions.Logging;
|
|
using NSubstitute;
|
|
using System.Text;
|
|
|
|
namespace ColaFlow.Modules.Mcp.Tests.Infrastructure.Middleware;
|
|
|
|
/// <summary>
/// Unit tests for <see cref="McpLoggingMiddleware"/>: skipping non-MCP traffic,
/// log levels for success and error responses, slow-request warnings,
/// redaction of sensitive request fields, and duration metrics in the response entry.
/// </summary>
public class McpLoggingMiddlewareTests
{
    // Correlation ID stored in HttpContext.Items for every MCP request these tests build.
    private const string CorrelationId = "test-correlation-id";

    // xUnit creates a new instance of this class per [Fact], so each test gets a fresh substitute.
    private readonly ILogger<McpLoggingMiddleware> _logger = Substitute.For<ILogger<McpLoggingMiddleware>>();

    [Fact]
    public async Task InvokeAsync_ShouldSkipLogging_ForNonMcpRequests()
    {
        // Arrange: an ordinary API call that does not target the /mcp endpoint.
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Method = "GET";
        httpContext.Request.Path = "/api/tasks";

        var pipelineContinued = false;
        RequestDelegate next = _ =>
        {
            pipelineContinued = true;
            return Task.CompletedTask;
        };

        var sut = new McpLoggingMiddleware(next, _logger);

        // Act
        await sut.InvokeAsync(httpContext);

        // Assert: the request is forwarded, but the logger is never touched.
        pipelineContinued.Should().BeTrue();
        _logger.DidNotReceive().Log(
            Arg.Any<LogLevel>(),
            Arg.Any<EventId>(),
            Arg.Any<object>(),
            Arg.Any<Exception>(),
            Arg.Any<Func<object, Exception?, string>>());
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogRequestAndResponse_ForMcpRequests()
    {
        // Arrange
        var httpContext = CreateMcpPostContext("{\"jsonrpc\":\"2.0\",\"method\":\"initialize\"}");

        RequestDelegate next = ctx =>
        {
            ctx.Response.StatusCode = 200;
            byte[] payload = Encoding.UTF8.GetBytes("{\"jsonrpc\":\"2.0\",\"result\":{}}");
            ctx.Response.Body.Write(payload, 0, payload.Length);
            return Task.CompletedTask;
        };

        var sut = new McpLoggingMiddleware(next, _logger);

        // Act
        await sut.InvokeAsync(httpContext);

        // Assert: both the inbound request and the 2xx response are logged at Debug.
        VerifyLogged(LogLevel.Debug, "MCP Request");
        VerifyLogged(LogLevel.Debug, "MCP Response");
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogErrorLevel_ForErrorResponses()
    {
        // Arrange: the downstream handler fails with a 5xx status.
        var httpContext = CreateMcpPostContext("{}");
        var sut = new McpLoggingMiddleware(RespondWithStatus(500), _logger);

        // Act
        await sut.InvokeAsync(httpContext);

        // Assert: the response entry is escalated to Error.
        VerifyLogged(LogLevel.Error, "MCP Response");
    }

    [Fact]
    public async Task InvokeAsync_ShouldLogWarning_ForSlowRequests()
    {
        // Arrange
        var httpContext = CreateMcpPostContext("{}");

        RequestDelegate next = async ctx =>
        {
            // Simulate a slow request (> 1 second).
            await Task.Delay(1100);
            ctx.Response.StatusCode = 200;
        };

        var sut = new McpLoggingMiddleware(next, _logger);

        // Act
        await sut.InvokeAsync(httpContext);

        // Assert
        VerifyLogged(LogLevel.Warning, "Slow MCP Request");
    }

    [Fact]
    public async Task InvokeAsync_ShouldSanitizeSensitiveData()
    {
        // Arrange: a body carrying an API key hash and a password.
        const string requestBody = "{\"keyHash\":\"secret-key-hash\",\"password\":\"my-password\"}";
        var httpContext = CreateMcpPostContext(requestBody);
        var readLoggedRequest = CaptureDebugMessage("MCP Request");
        var sut = new McpLoggingMiddleware(RespondWithStatus(200), _logger);

        // Act
        await sut.InvokeAsync(httpContext);

        // Assert: secrets are replaced by the redaction marker.
        var loggedRequest = readLoggedRequest();
        loggedRequest.Should().Contain("[REDACTED]");
        loggedRequest.Should().NotContain("secret-key-hash");
        loggedRequest.Should().NotContain("my-password");
    }

    [Fact]
    public async Task InvokeAsync_ShouldIncludePerformanceMetrics()
    {
        // Arrange
        var httpContext = CreateMcpPostContext("{}");
        var readLoggedResponse = CaptureDebugMessage("MCP Response");
        var sut = new McpLoggingMiddleware(RespondWithStatus(200), _logger);

        // Act
        await sut.InvokeAsync(httpContext);

        // Assert: the response entry reports elapsed time in milliseconds.
        var loggedResponse = readLoggedResponse();
        loggedResponse.Should().Contain("Duration:");
        loggedResponse.Should().MatchRegex(@"Duration:\s*\d+ms");
    }

    /// <summary>
    /// Builds a POST /mcp context with <paramref name="requestJson"/> as the UTF-8 request body,
    /// an empty in-memory response body and the test correlation ID in <c>Items</c>.
    /// </summary>
    private static DefaultHttpContext CreateMcpPostContext(string requestJson)
    {
        var httpContext = new DefaultHttpContext();
        httpContext.Request.Method = "POST";
        httpContext.Request.Path = "/mcp";
        httpContext.Request.Body = new MemoryStream(Encoding.UTF8.GetBytes(requestJson));
        httpContext.Response.Body = new MemoryStream();
        httpContext.Items["CorrelationId"] = CorrelationId;
        return httpContext;
    }

    /// <summary>Terminal pipeline step that only sets the response status code.</summary>
    private static RequestDelegate RespondWithStatus(int statusCode) => ctx =>
    {
        ctx.Response.StatusCode = statusCode;
        return Task.CompletedTask;
    };

    /// <summary>
    /// Asserts that the logger received an entry at <paramref name="level"/> whose state text
    /// contains <paramref name="fragment"/>.
    /// </summary>
    private void VerifyLogged(LogLevel level, string fragment) =>
        _logger.Received().Log(
            level,
            Arg.Any<EventId>(),
            Arg.Is<object>(state => state.ToString()!.Contains(fragment)),
            Arg.Any<Exception>(),
            Arg.Any<Func<object, Exception?, string>>());

    /// <summary>
    /// Hooks the logger so that the text of a Debug entry whose state contains
    /// <paramref name="fragment"/> is recorded. The returned accessor yields that text,
    /// or an empty string if no matching entry was logged.
    /// </summary>
    private Func<string> CaptureDebugMessage(string fragment)
    {
        var captured = string.Empty;
        _logger
            .When(logger => logger.Log(
                LogLevel.Debug,
                Arg.Any<EventId>(),
                Arg.Is<object>(state => state.ToString()!.Contains(fragment)),
                Arg.Any<Exception>(),
                Arg.Any<Func<object, Exception?, string>>()))
            .Do(call => captured = call.ArgAt<object>(2).ToString() ?? "");
        return () => captured;
    }
}
|