using ColaFlow.Modules.Identity.Application.Services;
using ColaFlow.Modules.Identity.Domain.Entities;
using ColaFlow.Modules.Identity.Infrastructure.Persistence;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging;
namespace ColaFlow.Modules.Identity.Infrastructure.Services;
/// <summary>
/// Database-backed rate limiting service implementation.
/// Persists rate limit state in PostgreSQL to survive server restarts.
/// Prevents email bombing attacks even after application restart.
/// </summary>
public class DatabaseEmailRateLimiter : IRateLimitService
{
    private readonly IdentityDbContext _context;
    private readonly ILogger<DatabaseEmailRateLimiter> _logger;

    /// <summary>
    /// Creates a rate limiter that stores its counters through the given context.
    /// </summary>
    /// <param name="context">EF Core context exposing the <c>EmailRateLimits</c> set.</param>
    /// <param name="logger">Logger used for rate limit diagnostics.</param>
    public DatabaseEmailRateLimiter(
        IdentityDbContext context,
        ILogger<DatabaseEmailRateLimiter> logger)
    {
        _context = context;
        _logger = logger;
    }

    /// <summary>
    /// Decides whether the operation identified by <paramref name="key"/> may proceed,
    /// and records the attempt when it is allowed.
    /// </summary>
    /// <param name="key">
    /// Rate limit key in the form <c>operation:email:tenantId</c>, for example
    /// <c>forgot-password:user@example.com:tenant-guid</c>. The tenant segment must parse as a GUID.
    /// </param>
    /// <param name="maxAttempts">Attempt count at which further requests inside the window are rejected.</param>
    /// <param name="window">Length of the rate limit window.</param>
    /// <param name="cancellationToken">Token forwarded to every database call.</param>
    /// <returns>
    /// <c>false</c> only when a record exists, its window has not expired and its attempt count
    /// has reached <paramref name="maxAttempts"/>; <c>true</c> otherwise. Malformed keys and an
    /// unrecoverable insert race also return <c>true</c> (the limiter fails open).
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="key"/> is <c>null</c>.</exception>
    public async Task<bool> IsAllowedAsync(
        string key,
        int maxAttempts,
        TimeSpan window,
        CancellationToken cancellationToken = default)
    {
        ArgumentNullException.ThrowIfNull(key);

        // Parse key format: "operation:email:tenantId"
        var parts = key.Split(':');
        if (parts.Length != 3)
        {
            _logger.LogWarning("Invalid rate limit key format: {Key}. Expected format: 'operation:email:tenantId'", key);
            return true; // Fail open (allow request) if key format is invalid
        }

        var operationType = parts[0];
        // Invariant casing keeps the stored key independent of the server's current culture.
        var email = parts[1].ToLowerInvariant();
        var tenantIdStr = parts[2];

        if (!Guid.TryParse(tenantIdStr, out var tenantId))
        {
            _logger.LogWarning("Invalid tenant ID in rate limit key: {Key}", key);
            return true; // Fail open
        }

        var rateLimit = await MatchingRecords(email, tenantId, operationType)
            .FirstOrDefaultAsync(cancellationToken);

        // No existing record - create one and allow the request.
        if (rateLimit is null)
        {
            var newRateLimit = EmailRateLimit.Create(email, tenantId, operationType);
            _context.EmailRateLimits.Add(newRateLimit);

            try
            {
                await _context.SaveChangesAsync(cancellationToken);

                _logger.LogInformation(
                    "Rate limit record created for {Email} - {Operation} (Attempt 1/{MaxAttempts})",
                    email, operationType, maxAttempts);

                return true;
            }
            catch (DbUpdateException ex)
            {
                // Treated as a race: another request is assumed to have inserted the same record.
                _logger.LogWarning(ex,
                    "Race condition detected while creating rate limit record for {Key}. Retrying...", key);

                // The failed insert is still tracked as Added; without detaching it, every later
                // SaveChangesAsync on this context would retry the insert and fail again.
                _context.Entry(newRateLimit).State = EntityState.Detached;

                // Re-fetch the record created by the concurrent request.
                rateLimit = await MatchingRecords(email, tenantId, operationType)
                    .FirstOrDefaultAsync(cancellationToken);

                if (rateLimit is null)
                {
                    _logger.LogError("Failed to fetch rate limit record after race condition for {Key}", key);
                    return true; // Fail open
                }

                // Fall through to the existing-record logic below.
            }
        }

        // Window expired - reset the record and allow.
        if (rateLimit.IsWindowExpired(window))
        {
            rateLimit.ResetAttempts();
            _context.EmailRateLimits.Update(rateLimit);
            await _context.SaveChangesAsync(cancellationToken);

            _logger.LogInformation(
                "Rate limit window expired for {Email} - {Operation}. Counter reset (Attempt 1/{MaxAttempts})",
                email, operationType, maxAttempts);

            return true;
        }

        // Window still active - reject once the attempt count has reached the limit.
        if (rateLimit.AttemptsCount >= maxAttempts)
        {
            var remainingTime = window - (DateTime.UtcNow - rateLimit.LastSentAt);

            _logger.LogWarning(
                "Rate limit EXCEEDED for {Email} - {Operation}: {Attempts}/{MaxAttempts} attempts. " +
                "Retry after {RemainingSeconds} seconds",
                email, operationType, rateLimit.AttemptsCount, maxAttempts,
                (int)remainingTime.TotalSeconds);

            return false;
        }

        // Still within the limit - record this attempt and allow.
        rateLimit.RecordAttempt();
        _context.EmailRateLimits.Update(rateLimit);
        await _context.SaveChangesAsync(cancellationToken);

        _logger.LogInformation(
            "Rate limit check passed for {Email} - {Operation} (Attempt {Attempts}/{MaxAttempts})",
            email, operationType, rateLimit.AttemptsCount, maxAttempts);

        return true;
    }

    /// <summary>
    /// Deletes rate limit records whose <c>LastSentAt</c> is older than the retention period
    /// (call this from a background job).
    /// </summary>
    /// <param name="retentionPeriod">Records last sent before <c>UtcNow - retentionPeriod</c> are removed.</param>
    /// <param name="cancellationToken">Token forwarded to the query and the save.</param>
    public async Task CleanupExpiredRecordsAsync(TimeSpan retentionPeriod, CancellationToken cancellationToken = default)
    {
        var cutoffDate = DateTime.UtcNow - retentionPeriod;

        var expiredRecords = await _context.EmailRateLimits
            .Where(r => r.LastSentAt < cutoffDate)
            .ToListAsync(cancellationToken);

        if (expiredRecords.Count > 0)
        {
            _context.EmailRateLimits.RemoveRange(expiredRecords);
            await _context.SaveChangesAsync(cancellationToken);

            _logger.LogInformation(
                "Cleaned up {Count} expired rate limit records older than {CutoffDate}",
                expiredRecords.Count, cutoffDate);
        }
    }

    // Builds the query selecting the rate limit record for one email / tenant / operation triple.
    private IQueryable<EmailRateLimit> MatchingRecords(string email, Guid tenantId, string operationType) =>
        _context.EmailRateLimits.Where(
            r => r.Email == email &&
                 r.TenantId == tenantId &&
                 r.OperationType == operationType);
}