This commit is contained in:
Anthony Stirling
2023-08-13 22:46:18 +01:00
parent cadc8e499d
commit 35a998b934
6 changed files with 68 additions and 11 deletions

View File

@@ -9,6 +9,7 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.stereotype.Component;
@@ -23,10 +24,12 @@ import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import stirling.software.SPDF.model.Role;
@Component
public class UserBasedRateLimitingFilter extends OncePerRequestFilter {
private final Map<String, Bucket> buckets = new ConcurrentHashMap<>();
private final Map<String, Bucket> apiBuckets = new ConcurrentHashMap<>();
private final Map<String, Bucket> webBuckets = new ConcurrentHashMap<>();
@Autowired
private UserDetailsService userDetailsService;
@@ -39,7 +42,6 @@ public class UserBasedRateLimitingFilter extends OncePerRequestFilter {
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
if (!rateLimit) {
// If rateLimit is not enabled, just pass all requests without rate limiting
filterChain.doFilter(request, response);
@@ -47,7 +49,6 @@ public class UserBasedRateLimitingFilter extends OncePerRequestFilter {
}
String method = request.getMethod();
if (!"POST".equalsIgnoreCase(method)) {
// If the request is not a POST, just pass it through without rate limiting
filterChain.doFilter(request, response);
@@ -73,7 +74,34 @@ public class UserBasedRateLimitingFilter extends OncePerRequestFilter {
identifier = request.getRemoteAddr();
}
Bucket userBucket = buckets.computeIfAbsent(identifier, k -> createUserBucket());
Role userRole = getRoleFromAuthentication(SecurityContextHolder.getContext().getAuthentication());
if (request.getHeader("X-API-Key") != null) {
// It's an API call
processRequest(userRole.getApiCallsPerDay(), identifier, apiBuckets, request, response, filterChain);
} else {
// It's a Web UI call
processRequest(userRole.getWebCallsPerDay(), identifier, webBuckets, request, response, filterChain);
}
}
private Role getRoleFromAuthentication(Authentication authentication) {
if (authentication != null && authentication.isAuthenticated()) {
for (GrantedAuthority authority : authentication.getAuthorities()) {
try {
return Role.fromString(authority.getAuthority());
} catch (IllegalArgumentException ex) {
// Ignore and continue to next authority.
}
}
}
throw new IllegalStateException("User does not have a valid role.");
}
private void processRequest(int limitPerDay, String identifier, Map<String, Bucket> buckets,
HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
throws IOException, ServletException {
Bucket userBucket = buckets.computeIfAbsent(identifier, k -> createUserBucket(limitPerDay));
ConsumptionProbe probe = userBucket.tryConsumeAndReturnRemaining(1);
if (probe.isConsumed()) {
@@ -84,12 +112,11 @@ public class UserBasedRateLimitingFilter extends OncePerRequestFilter {
response.setStatus(HttpStatus.TOO_MANY_REQUESTS.value());
response.setHeader("X-Rate-Limit-Retry-After-Seconds", String.valueOf(waitForRefill));
response.getWriter().write("Rate limit exceeded for POST requests.");
return;
}
}
private Bucket createUserBucket() {
Bandwidth limit = Bandwidth.classic(1000, Refill.intervally(1000, Duration.ofDays(1)));
private Bucket createUserBucket(int limitPerDay) {
Bandwidth limit = Bandwidth.classic(limitPerDay, Refill.intervally(limitPerDay, Duration.ofDays(1)));
return Bucket.builder().addLimit(limit).build();
}
}

View File

@@ -19,9 +19,10 @@ public class InitialSetup {
String initialPassword = System.getenv("INITIAL_PASSWORD");
if(initialUsername != null && initialPassword != null) {
userService.saveUser(initialUsername, initialPassword, Role.ADMIN.getRoleId());
} else {
userService.saveUser("admin", "password", Role.ADMIN.getRoleId());
}
// else {
// userService.saveUser("admin", "password", Role.ADMIN.getRoleId());
// }
}
}
}