Admin panel - Enhanced User Management & Fix: #1630 (#1658)

* Prevents SSO login due to faulty verification

* add translation & fix show error message

* Update settings.yml.template

---------

Co-authored-by: Anthony Stirling <77850077+Frooodle@users.noreply.github.com>
This commit is contained in:
Ludy
2024-08-16 12:57:37 +02:00
committed by GitHub
parent 2cbe34ea24
commit 29fcbf30d7
61 changed files with 1318 additions and 221 deletions

View File

@@ -3,9 +3,8 @@ package stirling.software.SPDF.config.security;
import java.io.IOException;
import java.util.Optional;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.DisabledException;
import org.springframework.security.authentication.InternalAuthenticationServiceException;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.AuthenticationException;
@@ -15,17 +14,16 @@ import org.springframework.security.web.authentication.SimpleUrlAuthenticationFa
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.model.User;
@Slf4j
public class CustomAuthenticationFailureHandler extends SimpleUrlAuthenticationFailureHandler {
private LoginAttemptService loginAttemptService;
private UserService userService;
private static final Logger logger =
LoggerFactory.getLogger(CustomAuthenticationFailureHandler.class);
public CustomAuthenticationFailureHandler(
final LoginAttemptService loginAttemptService, UserService userService) {
this.loginAttemptService = loginAttemptService;
@@ -39,14 +37,17 @@ public class CustomAuthenticationFailureHandler extends SimpleUrlAuthenticationF
AuthenticationException exception)
throws IOException, ServletException {
if (exception instanceof DisabledException) {
log.error("User is deactivated: ", exception);
getRedirectStrategy().sendRedirect(request, response, "/logout?userIsDisabled=true");
return;
}
String ip = request.getRemoteAddr();
logger.error("Failed login attempt from IP: {}", ip);
log.error("Failed login attempt from IP: {}", ip);
String contextPath = request.getContextPath();
if (exception.getClass().isAssignableFrom(InternalAuthenticationServiceException.class)
|| "Password must not be null".equalsIgnoreCase(exception.getMessage())) {
response.sendRedirect(contextPath + "/login?error=oauth2AuthenticationError");
if (exception instanceof LockedException) {
getRedirectStrategy().sendRedirect(request, response, "/login?error=locked");
return;
}
@@ -54,20 +55,25 @@ public class CustomAuthenticationFailureHandler extends SimpleUrlAuthenticationF
Optional<User> optUser = userService.findByUsernameIgnoreCase(username);
if (username != null && optUser.isPresent() && !isDemoUser(optUser)) {
logger.info(
log.info(
"Remaining attempts for user {}: {}",
optUser.get().getUsername(),
username,
loginAttemptService.getRemainingAttempts(username));
loginAttemptService.loginFailed(username);
if (loginAttemptService.isBlocked(username)
|| exception.getClass().isAssignableFrom(LockedException.class)) {
response.sendRedirect(contextPath + "/login?error=locked");
if (loginAttemptService.isBlocked(username) || exception instanceof LockedException) {
getRedirectStrategy().sendRedirect(request, response, "/login?error=locked");
return;
}
}
if (exception.getClass().isAssignableFrom(BadCredentialsException.class)
|| exception.getClass().isAssignableFrom(UsernameNotFoundException.class)) {
response.sendRedirect(contextPath + "/login?error=badcredentials");
if (exception instanceof BadCredentialsException
|| exception instanceof UsernameNotFoundException) {
getRedirectStrategy().sendRedirect(request, response, "/login?error=badcredentials");
return;
}
if (exception instanceof InternalAuthenticationServiceException
|| "Password must not be null".equalsIgnoreCase(exception.getMessage())) {
getRedirectStrategy()
.sendRedirect(request, response, "/login?error=oauth2AuthenticationError");
return;
}

View File

@@ -10,15 +10,20 @@ import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.utils.RequestUriUtils;
@Slf4j
public class CustomAuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private LoginAttemptService loginAttemptService;
private UserService userService;
public CustomAuthenticationSuccessHandler(LoginAttemptService loginAttemptService) {
public CustomAuthenticationSuccessHandler(
LoginAttemptService loginAttemptService, UserService userService) {
this.loginAttemptService = loginAttemptService;
this.userService = userService;
}
@Override
@@ -27,6 +32,10 @@ public class CustomAuthenticationSuccessHandler
throws ServletException, IOException {
String userName = request.getParameter("username");
if (userService.isUserDisabled(userName)) {
getRedirectStrategy().sendRedirect(request, response, "/logout?userIsDisabled=true");
return;
}
loginAttemptService.loginSucceeded(userName);
// Get the saved request

View File

@@ -2,32 +2,26 @@ package stirling.software.SPDF.config.security;
import java.io.IOException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
@Autowired SessionRegistry sessionRegistry;
@Override
public void onLogoutSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException, ServletException {
HttpSession session = request.getSession(false);
if (session != null) {
String sessionId = session.getId();
sessionRegistry.removeSessionInformation(sessionId);
session.invalidate();
logger.debug("Session invalidated: " + sessionId);
if (request.getParameter("userIsDisabled") != null) {
getRedirectStrategy()
.sendRedirect(request, response, "/login?erroroauth=userIsDisabled");
return;
}
response.sendRedirect(request.getContextPath() + "/login?logout=true");
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
}
}

View File

@@ -3,8 +3,6 @@ package stirling.software.SPDF.config.security;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
@@ -17,8 +15,6 @@ public class LoginAttemptService {
@Autowired ApplicationProperties applicationProperties;
private static final Logger logger = LoggerFactory.getLogger(LoginAttemptService.class);
private int MAX_ATTEMPT;
private long ATTEMPT_INCREMENT_TIME;
private ConcurrentHashMap<String, AttemptCounter> attemptsCache;

View File

@@ -18,8 +18,6 @@ import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.core.session.SessionRegistryImpl;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -37,6 +35,7 @@ import stirling.software.SPDF.config.security.oauth2.CustomOAuth2AuthenticationF
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2AuthenticationSuccessHandler;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2LogoutSuccessHandler;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2UserService;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2.Client;
@@ -47,7 +46,7 @@ import stirling.software.SPDF.model.provider.KeycloakProvider;
import stirling.software.SPDF.repository.JPATokenRepositoryImpl;
@Configuration
@EnableWebSecurity()
@EnableWebSecurity
@EnableMethodSecurity
public class SecurityConfiguration {
@@ -73,11 +72,7 @@ public class SecurityConfiguration {
@Autowired private LoginAttemptService loginAttemptService;
@Autowired private FirstLoginFilter firstLoginFilter;
@Bean
public SessionRegistry sessionRegistry() {
return new SessionRegistryImpl();
}
@Autowired private SessionPersistentRegistry sessionRegistry;
@Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
@@ -94,7 +89,7 @@ public class SecurityConfiguration {
.sessionCreationPolicy(SessionCreationPolicy.IF_REQUIRED)
.maximumSessions(10)
.maxSessionsPreventsLogin(false)
.sessionRegistry(sessionRegistry())
.sessionRegistry(sessionRegistry)
.expiredUrl("/login?logout=true"));
http.formLogin(
@@ -103,7 +98,7 @@ public class SecurityConfiguration {
.loginPage("/login")
.successHandler(
new CustomAuthenticationSuccessHandler(
loginAttemptService))
loginAttemptService, userService))
.defaultSuccessUrl("/")
.failureHandler(
new CustomAuthenticationFailureHandler(
@@ -160,7 +155,11 @@ public class SecurityConfiguration {
// Handle OAUTH2 Logins
if (applicationProperties.getSecurity().getOAUTH2() != null
&& applicationProperties.getSecurity().getOAUTH2().getEnabled()) {
&& applicationProperties.getSecurity().getOAUTH2().getEnabled()
&& !applicationProperties
.getSecurity()
.getLoginMethod()
.equalsIgnoreCase("normal")) {
http.oauth2Login(
oauth2 ->
@@ -191,10 +190,8 @@ public class SecurityConfiguration {
.logout(
logout ->
logout.logoutSuccessHandler(
new CustomOAuth2LogoutSuccessHandler(
this.applicationProperties,
sessionRegistry()))
.invalidateHttpSession(true));
new CustomOAuth2LogoutSuccessHandler(
applicationProperties)));
}
} else {
http.csrf(csrf -> csrf.disable())

View File

@@ -1,6 +1,7 @@
package stirling.software.SPDF.config.security;
import java.io.IOException;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
@@ -9,8 +10,9 @@ import org.springframework.http.HttpStatus;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UserDetailsService;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Component;
import org.springframework.web.filter.OncePerRequestFilter;
@@ -18,15 +20,16 @@ import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.model.ApiKeyAuthenticationToken;
@Component
public class UserAuthenticationFilter extends OncePerRequestFilter {
@Autowired private UserDetailsService userDetailsService;
@Autowired @Lazy private UserService userService;
@Autowired private SessionPersistentRegistry sessionPersistentRegistry;
@Autowired
@Qualifier("loginEnabled")
public boolean loginEnabledValue;
@@ -87,6 +90,43 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
}
}
// Check if the authenticated user is disabled and invalidate their session if so
if (authentication != null && authentication.isAuthenticated()) {
Object principal = authentication.getPrincipal();
String username = null;
if (principal instanceof UserDetails) {
username = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
username = ((OAuth2User) principal).getName();
} else if (principal instanceof String) {
username = (String) principal;
}
List<SessionInformation> sessionsInformations =
sessionPersistentRegistry.getAllSessions(principal, false);
if (username != null) {
boolean isUserExists = userService.usernameExistsIgnoreCase(username);
boolean isUserDisabled = userService.isUserDisabled(username);
if (!isUserExists || isUserDisabled) {
for (SessionInformation sessionsInformation : sessionsInformations) {
sessionsInformation.expireNow();
sessionPersistentRegistry.expireSession(sessionsInformation.getSessionId());
}
}
if (!isUserExists) {
response.sendRedirect(request.getContextPath() + "/logout?badcredentials=true");
return;
}
if (isUserDisabled) {
response.sendRedirect(request.getContextPath() + "/logout?userIsDisabled=true");
return;
}
}
}
filterChain.doFilter(request, response);
}

View File

@@ -15,12 +15,15 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.core.userdetails.UsernameNotFoundException;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Service;
import stirling.software.SPDF.config.DatabaseBackupInterface;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.controller.api.pipeline.UserServiceInterface;
import stirling.software.SPDF.model.AuthenticationType;
import stirling.software.SPDF.model.Authority;
@@ -40,6 +43,8 @@ public class UserService implements UserServiceInterface {
@Autowired private MessageSource messageSource;
@Autowired private SessionPersistentRegistry sessionRegistry;
@Autowired DatabaseBackupInterface databaseBackupHelper;
// Handle OAUTH2 login and user auto creation.
@@ -48,7 +53,7 @@ public class UserService implements UserServiceInterface {
if (!isUsernameValid(username)) {
return false;
}
Optional<User> existingUser = userRepository.findByUsernameIgnoreCase(username);
Optional<User> existingUser = findByUsernameIgnoreCase(username);
if (existingUser.isPresent()) {
return true;
}
@@ -90,8 +95,7 @@ public class UserService implements UserServiceInterface {
public User addApiKeyToUser(String username) {
User user =
userRepository
.findByUsernameIgnoreCase(username)
findByUsernameIgnoreCase(username)
.orElseThrow(() -> new UsernameNotFoundException("User not found"));
user.setApiKey(generateApiKey());
@@ -104,8 +108,7 @@ public class UserService implements UserServiceInterface {
public String getApiKeyForUser(String username) {
User user =
userRepository
.findByUsernameIgnoreCase(username)
findByUsernameIgnoreCase(username)
.orElseThrow(() -> new UsernameNotFoundException("User not found"));
return user.getApiKey();
}
@@ -131,12 +134,17 @@ public class UserService implements UserServiceInterface {
}
public boolean validateApiKeyForUser(String username, String apiKey) {
Optional<User> userOpt = userRepository.findByUsernameIgnoreCase(username);
Optional<User> userOpt = findByUsernameIgnoreCase(username);
return userOpt.isPresent() && apiKey.equals(userOpt.get().getApiKey());
}
public void saveUser(String username, AuthenticationType authenticationType)
throws IllegalArgumentException, IOException {
saveUser(username, authenticationType, Role.USER.getRoleId());
}
public void saveUser(String username, AuthenticationType authenticationType, String role)
throws IllegalArgumentException, IOException {
if (!isUsernameValid(username)) {
throw new IllegalArgumentException(getInvalidUsernameMessage());
}
@@ -144,7 +152,7 @@ public class UserService implements UserServiceInterface {
user.setUsername(username);
user.setEnabled(true);
user.setFirstLogin(false);
user.addAuthority(new Authority(Role.USER.getRoleId(), user));
user.addAuthority(new Authority(role, user));
user.setAuthenticationType(authenticationType);
userRepository.save(user);
databaseBackupHelper.exportDatabase();
@@ -186,7 +194,7 @@ public class UserService implements UserServiceInterface {
}
public void deleteUser(String username) {
Optional<User> userOpt = userRepository.findByUsernameIgnoreCase(username);
Optional<User> userOpt = findByUsernameIgnoreCase(username);
if (userOpt.isPresent()) {
for (Authority authority : userOpt.get().getAuthorities()) {
if (authority.getAuthority().equals(Role.INTERNAL_API_USER.getRoleId())) {
@@ -195,21 +203,20 @@ public class UserService implements UserServiceInterface {
}
userRepository.delete(userOpt.get());
}
invalidateUserSessions(username);
}
public boolean usernameExists(String username) {
return userRepository.findByUsername(username).isPresent();
return findByUsername(username).isPresent();
}
public boolean usernameExistsIgnoreCase(String username) {
return userRepository.findByUsernameIgnoreCase(username).isPresent();
return findByUsernameIgnoreCase(username).isPresent();
}
public boolean hasUsers() {
long userCount = userRepository.count();
if (userRepository
.findByUsernameIgnoreCase(Role.INTERNAL_API_USER.getRoleId())
.isPresent()) {
if (findByUsernameIgnoreCase(Role.INTERNAL_API_USER.getRoleId()).isPresent()) {
userCount -= 1;
}
return userCount > 0;
@@ -217,7 +224,7 @@ public class UserService implements UserServiceInterface {
public void updateUserSettings(String username, Map<String, String> updates)
throws IOException {
Optional<User> userOpt = userRepository.findByUsernameIgnoreCase(username);
Optional<User> userOpt = findByUsernameIgnoreCase(username);
if (userOpt.isPresent()) {
User user = userOpt.get();
Map<String, String> settingsMap = user.getSettings();
@@ -268,10 +275,17 @@ public class UserService implements UserServiceInterface {
databaseBackupHelper.exportDatabase();
}
public void changeRole(User user, String newRole) {
public void changeRole(User user, String newRole) throws IOException {
Authority userAuthority = this.findRole(user);
userAuthority.setAuthority(newRole);
authorityRepository.save(userAuthority);
databaseBackupHelper.exportDatabase();
}
public void changeUserEnabled(User user, Boolean enbeled) throws IOException {
user.setEnabled(enbeled);
userRepository.save(user);
databaseBackupHelper.exportDatabase();
}
public boolean isPasswordCorrect(User user, String currentPassword) {
@@ -295,14 +309,40 @@ public class UserService implements UserServiceInterface {
}
public boolean hasPassword(String username) {
Optional<User> user = userRepository.findByUsernameIgnoreCase(username);
Optional<User> user = findByUsernameIgnoreCase(username);
return user.isPresent() && user.get().hasPassword();
}
public boolean isAuthenticationTypeByUsername(
String username, AuthenticationType authenticationType) {
Optional<User> user = userRepository.findByUsernameIgnoreCase(username);
Optional<User> user = findByUsernameIgnoreCase(username);
return user.isPresent()
&& authenticationType.name().equalsIgnoreCase(user.get().getAuthenticationType());
}
public boolean isUserDisabled(String username) {
Optional<User> userOpt = findByUsernameIgnoreCase(username);
return userOpt.map(user -> !user.isEnabled()).orElse(false);
}
public void invalidateUserSessions(String username) {
String usernameP = "";
for (Object principal : sessionRegistry.getAllPrincipals()) {
for (SessionInformation sessionsInformation :
sessionRegistry.getAllSessions(principal, false)) {
if (principal instanceof UserDetails) {
UserDetails userDetails = (UserDetails) principal;
usernameP = userDetails.getUsername();
} else if (principal instanceof OAuth2User) {
OAuth2User oAuth2User = (OAuth2User) principal;
usernameP = oAuth2User.getName();
} else if (principal instanceof String) {
usernameP = (String) principal;
}
if (usernameP.equalsIgnoreCase(username)) {
sessionRegistry.expireSession(sessionsInformation.getSessionId());
}
}
}
}
}

View File

@@ -2,8 +2,8 @@ package stirling.software.SPDF.config.security.oauth2;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.DisabledException;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@@ -13,19 +13,34 @@ import org.springframework.security.web.authentication.SimpleUrlAuthenticationFa
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class CustomOAuth2AuthenticationFailureHandler
extends SimpleUrlAuthenticationFailureHandler {
private static final Logger logger =
LoggerFactory.getLogger(CustomOAuth2AuthenticationFailureHandler.class);
@Override
public void onAuthenticationFailure(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException exception)
throws IOException, ServletException {
if (exception instanceof BadCredentialsException) {
log.error("BadCredentialsException", exception);
getRedirectStrategy().sendRedirect(request, response, "/login?error=badcredentials");
return;
}
if (exception instanceof DisabledException) {
log.error("User is deactivated: ", exception);
getRedirectStrategy().sendRedirect(request, response, "/logout?userIsDisabled=true");
return;
}
if (exception instanceof LockedException) {
log.error("Account locked: ", exception);
getRedirectStrategy().sendRedirect(request, response, "/logout?error=locked");
return;
}
if (exception instanceof OAuth2AuthenticationException) {
OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
@@ -34,17 +49,13 @@ public class CustomOAuth2AuthenticationFailureHandler
if (error.getErrorCode().equals("Password must not be null")) {
errorCode = "userAlreadyExistsWeb";
}
logger.error("OAuth2 Authentication error: " + errorCode);
log.error("OAuth2 Authentication error: " + errorCode);
log.error("OAuth2AuthenticationException", exception);
getRedirectStrategy()
.sendRedirect(request, response, "/logout?erroroauth=" + errorCode);
return;
} else if (exception instanceof LockedException) {
logger.error("Account locked: ", exception);
getRedirectStrategy().sendRedirect(request, response, "/logout?error=locked");
return;
} else {
logger.error("Unhandled authentication exception", exception);
super.onAuthenticationFailure(request, response, exception);
}
log.error("Unhandled authentication exception", exception);
super.onAuthenticationFailure(request, response, exception);
}
}

View File

@@ -2,10 +2,9 @@ package stirling.software.SPDF.config.security.oauth2;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.SavedRequest;
@@ -26,9 +25,6 @@ public class CustomOAuth2AuthenticationSuccessHandler
private LoginAttemptService loginAttemptService;
private static final Logger logger =
LoggerFactory.getLogger(CustomOAuth2AuthenticationSuccessHandler.class);
private ApplicationProperties applicationProperties;
private UserService userService;
@@ -46,6 +42,17 @@ public class CustomOAuth2AuthenticationSuccessHandler
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws ServletException, IOException {
Object principal = authentication.getPrincipal();
String username = "";
if (principal instanceof OAuth2User) {
OAuth2User oauthUser = (OAuth2User) principal;
username = oauthUser.getName();
} else if (principal instanceof UserDetails) {
UserDetails oauthUser = (UserDetails) principal;
username = oauthUser.getUsername();
}
// Get the saved request
HttpSession session = request.getSession(false);
String contextPath = request.getContextPath();
@@ -59,11 +66,8 @@ public class CustomOAuth2AuthenticationSuccessHandler
// Redirect to the original destination
super.onAuthenticationSuccess(request, response, authentication);
} else {
OAuth2User oauthUser = (OAuth2User) authentication.getPrincipal();
OAUTH2 oAuth = applicationProperties.getSecurity().getOAUTH2();
String username = oauthUser.getName();
if (loginAttemptService.isBlocked(username)) {
if (session != null) {
session.removeAttribute("SPRING_SECURITY_SAVED_REQUEST");
@@ -78,15 +82,21 @@ public class CustomOAuth2AuthenticationSuccessHandler
&& oAuth.getAutoCreateUser()) {
response.sendRedirect(contextPath + "/logout?oauth2AuthenticationErrorWeb=true");
return;
} else {
try {
userService.processOAuth2PostLogin(username, oAuth.getAutoCreateUser());
response.sendRedirect(contextPath + "/");
return;
} catch (IllegalArgumentException e) {
response.sendRedirect(contextPath + "/logout?invalidUsername=true");
}
try {
if (oAuth.getBlockRegistration()
&& !userService.usernameExistsIgnoreCase(username)) {
response.sendRedirect(contextPath + "/logout?oauth2_admin_blocked_user=true");
return;
}
if (principal instanceof OAuth2User) {
userService.processOAuth2PostLogin(username, oAuth.getAutoCreateUser());
}
response.sendRedirect(contextPath + "/");
return;
} catch (IllegalArgumentException e) {
response.sendRedirect(contextPath + "/logout?invalidUsername=true");
return;
}
}
}

View File

@@ -2,34 +2,26 @@ package stirling.software.SPDF.config.security.oauth2;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.Provider;
import stirling.software.SPDF.model.provider.UnsupportedProviderException;
import stirling.software.SPDF.utils.UrlUtils;
@Slf4j
public class CustomOAuth2LogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
private static final Logger logger =
LoggerFactory.getLogger(CustomOAuth2LogoutSuccessHandler.class);
private final SessionRegistry sessionRegistry;
private final ApplicationProperties applicationProperties;
public CustomOAuth2LogoutSuccessHandler(
ApplicationProperties applicationProperties, SessionRegistry sessionRegistry) {
this.sessionRegistry = sessionRegistry;
public CustomOAuth2LogoutSuccessHandler(ApplicationProperties applicationProperties) {
this.applicationProperties = applicationProperties;
}
@@ -42,6 +34,15 @@ public class CustomOAuth2LogoutSuccessHandler extends SimpleUrlLogoutSuccessHand
String issuer = null;
String clientId = null;
if (authentication == null) {
if (request.getParameter("userIsDisabled") != null) {
response.sendRedirect(
request.getContextPath() + "/login?erroroauth=userIsDisabled");
} else {
super.onLogoutSuccess(request, response, authentication);
}
return;
}
OAUTH2 oauth = applicationProperties.getSecurity().getOAUTH2();
if (authentication instanceof OAuth2AuthenticationToken) {
@@ -53,9 +54,8 @@ public class CustomOAuth2LogoutSuccessHandler extends SimpleUrlLogoutSuccessHand
issuer = provider.getIssuer();
clientId = provider.getClientId();
} catch (UnsupportedProviderException e) {
logger.error(e.getMessage());
log.error(e.getMessage());
}
} else {
registrationId = oauth.getProvider() != null ? oauth.getProvider() : "";
issuer = oauth.getIssuer();
@@ -70,18 +70,16 @@ public class CustomOAuth2LogoutSuccessHandler extends SimpleUrlLogoutSuccessHand
param = "erroroauth=" + sanitizeInput(errorMessage);
} else if (request.getParameter("oauth2AutoCreateDisabled") != null) {
param = "error=oauth2AutoCreateDisabled";
} else if (request.getParameter("oauth2_admin_blocked_user") != null) {
param = "erroroauth=oauth2_admin_blocked_user";
} else if (request.getParameter("userIsDisabled") != null) {
param = "erroroauth=userIsDisabled";
} else if (request.getParameter("badcredentials") != null) {
param = "error=badcredentials";
}
String redirect_url = UrlUtils.getOrigin(request) + "/login?" + param;
HttpSession session = request.getSession(false);
if (session != null) {
String sessionId = session.getId();
sessionRegistry.removeSessionInformation(sessionId);
session.invalidate();
logger.info("Session invalidated: " + sessionId);
}
switch (registrationId.toLowerCase()) {
case "keycloak":
// Add Keycloak specific logout URL if needed
@@ -92,13 +90,13 @@ public class CustomOAuth2LogoutSuccessHandler extends SimpleUrlLogoutSuccessHand
+ clientId
+ "&post_logout_redirect_uri="
+ response.encodeRedirectURL(redirect_url);
logger.info("Redirecting to Keycloak logout URL: " + logoutUrl);
log.info("Redirecting to Keycloak logout URL: " + logoutUrl);
response.sendRedirect(logoutUrl);
break;
case "github":
// Add GitHub specific logout URL if needed
String githubLogoutUrl = "https://github.com/logout";
logger.info("Redirecting to GitHub logout URL: " + githubLogoutUrl);
log.info("Redirecting to GitHub logout URL: " + githubLogoutUrl);
response.sendRedirect(githubLogoutUrl);
break;
case "google":
@@ -106,13 +104,14 @@ public class CustomOAuth2LogoutSuccessHandler extends SimpleUrlLogoutSuccessHand
// String googleLogoutUrl =
// "https://accounts.google.com/Logout?continue=https://appengine.google.com/_ah/logout?continue="
// + response.encodeRedirectURL(redirect_url);
// logger.info("Redirecting to Google logout URL: " + googleLogoutUrl);
log.info("Google does not have a specific logout URL");
// log.info("Redirecting to Google logout URL: " + googleLogoutUrl);
// response.sendRedirect(googleLogoutUrl);
// break;
default:
String redirectUrl = request.getContextPath() + "/login?" + param;
logger.info("Redirecting to default logout URL: " + redirectUrl);
response.sendRedirect(redirectUrl);
String defaultRedirectUrl = request.getContextPath() + "/login?" + param;
log.info("Redirecting to default logout URL: " + defaultRedirectUrl);
response.sendRedirect(defaultRedirectUrl);
break;
}
}

View File

@@ -0,0 +1,26 @@
package stirling.software.SPDF.config.security.session;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import jakarta.servlet.http.HttpSessionEvent;
import jakarta.servlet.http.HttpSessionListener;
import lombok.extern.slf4j.Slf4j;
@Component
@Slf4j
public class CustomHttpSessionListener implements HttpSessionListener {
@Autowired private SessionPersistentRegistry sessionPersistentRegistry;
@Override
public void sessionCreated(HttpSessionEvent se) {
log.info("Session created: " + se.getSession().getId());
}
@Override
public void sessionDestroyed(HttpSessionEvent se) {
log.info("Session destroyed: " + se.getSession().getId());
sessionPersistentRegistry.expireSession(se.getSession().getId());
}
}

View File

@@ -0,0 +1,183 @@
package stirling.software.SPDF.config.security.session;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Date;
import java.util.List;
import java.util.Optional;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.security.core.session.SessionRegistry;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Component;
import jakarta.transaction.Transactional;
import stirling.software.SPDF.model.SessionEntity;
@Component
public class SessionPersistentRegistry implements SessionRegistry {
private final SessionRepository sessionRepository;
@Value("${server.servlet.session.timeout:30m}")
private Duration defaultMaxInactiveInterval;
public SessionPersistentRegistry(SessionRepository sessionRepository) {
this.sessionRepository = sessionRepository;
}
@Override
public List<Object> getAllPrincipals() {
List<SessionEntity> sessions = sessionRepository.findAll();
List<Object> principals = new ArrayList<>();
for (SessionEntity session : sessions) {
principals.add(session.getPrincipalName());
}
return principals;
}
@Override
public List<SessionInformation> getAllSessions(
Object principal, boolean includeExpiredSessions) {
List<SessionInformation> sessionInformations = new ArrayList<>();
String principalName = null;
if (principal instanceof UserDetails) {
principalName = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
principalName = ((OAuth2User) principal).getName();
} else if (principal instanceof String) {
principalName = (String) principal;
}
if (principalName != null) {
List<SessionEntity> sessionEntities =
sessionRepository.findByPrincipalName(principalName);
for (SessionEntity sessionEntity : sessionEntities) {
if (includeExpiredSessions || !sessionEntity.isExpired()) {
sessionInformations.add(
new SessionInformation(
sessionEntity.getPrincipalName(),
sessionEntity.getSessionId(),
sessionEntity.getLastRequest()));
}
}
}
return sessionInformations;
}
@Override
@Transactional
public void registerNewSession(String sessionId, Object principal) {
String principalName = null;
if (principal instanceof UserDetails) {
principalName = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
principalName = ((OAuth2User) principal).getName();
} else if (principal instanceof String) {
principalName = (String) principal;
}
if (principalName != null) {
SessionEntity sessionEntity = new SessionEntity();
sessionEntity.setSessionId(sessionId);
sessionEntity.setPrincipalName(principalName);
sessionEntity.setLastRequest(new Date()); // Set lastRequest to the current date
sessionEntity.setExpired(false);
sessionRepository.save(sessionEntity);
}
}
@Override
@Transactional
public void removeSessionInformation(String sessionId) {
sessionRepository.deleteById(sessionId);
}
@Override
@Transactional
public void refreshLastRequest(String sessionId) {
Optional<SessionEntity> sessionEntityOpt = sessionRepository.findById(sessionId);
if (sessionEntityOpt.isPresent()) {
SessionEntity sessionEntity = sessionEntityOpt.get();
sessionEntity.setLastRequest(new Date()); // Update lastRequest to the current date
sessionRepository.save(sessionEntity);
}
}
@Override
public SessionInformation getSessionInformation(String sessionId) {
Optional<SessionEntity> sessionEntityOpt = sessionRepository.findById(sessionId);
if (sessionEntityOpt.isPresent()) {
SessionEntity sessionEntity = sessionEntityOpt.get();
return new SessionInformation(
sessionEntity.getPrincipalName(),
sessionEntity.getSessionId(),
sessionEntity.getLastRequest());
}
return null;
}
// Retrieve all non-expired sessions
public List<SessionEntity> getAllSessionsNotExpired() {
return sessionRepository.findByExpired(false);
}
// Retrieve all sessions
public List<SessionEntity> getAllSessions() {
return sessionRepository.findAll();
}
// Mark a session as expired
public void expireSession(String sessionId) {
Optional<SessionEntity> sessionEntityOpt = sessionRepository.findById(sessionId);
if (sessionEntityOpt.isPresent()) {
SessionEntity sessionEntity = sessionEntityOpt.get();
sessionEntity.setExpired(true); // Set expired to true
sessionRepository.save(sessionEntity);
}
}
// Get the maximum inactive interval for sessions
public int getMaxInactiveInterval() {
return (int) defaultMaxInactiveInterval.getSeconds();
}
// Retrieve a session entity by session ID
public SessionEntity getSessionEntity(String sessionId) {
return sessionRepository.findBySessionId(sessionId);
}
// Update session details by principal name
public void updateSessionByPrincipalName(
String principalName, boolean expired, Date lastRequest) {
sessionRepository.saveByPrincipalName(expired, lastRequest, principalName);
}
// Find the latest session for a given principal name
public Optional<SessionEntity> findLatestSession(String principalName) {
List<SessionEntity> allSessions = sessionRepository.findByPrincipalName(principalName);
if (allSessions.isEmpty()) {
return Optional.empty();
}
// Sort sessions by lastRequest in descending order
Collections.sort(
allSessions,
new Comparator<SessionEntity>() {
@Override
public int compare(SessionEntity s1, SessionEntity s2) {
// Sort by lastRequest in descending order
return s2.getLastRequest().compareTo(s1.getLastRequest());
}
});
// The first session in the list is the latest session for the given principal name
return Optional.of(allSessions.get(0));
}
}

View File

@@ -0,0 +1,20 @@
package stirling.software.SPDF.config.security.session;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.core.session.SessionRegistryImpl;
@Configuration
public class SessionRegistryConfig {
@Bean
public SessionRegistryImpl sessionRegistry() {
return new SessionRegistryImpl();
}
@Bean
public SessionPersistentRegistry sessionPersistentRegistry(
SessionRepository sessionRepository) {
return new SessionPersistentRegistry(sessionRepository);
}
}

View File

@@ -0,0 +1,31 @@
package stirling.software.SPDF.config.security.session;
import java.util.Date;
import java.util.List;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Modifying;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.springframework.stereotype.Repository;
import jakarta.transaction.Transactional;
import stirling.software.SPDF.model.SessionEntity;
@Repository
public interface SessionRepository extends JpaRepository<SessionEntity, String> {
List<SessionEntity> findByPrincipalName(String principalName);
List<SessionEntity> findByExpired(boolean expired);
SessionEntity findBySessionId(String sessionId);
@Modifying
@Transactional
@Query(
"UPDATE SessionEntity s SET s.expired = :expired, s.lastRequest = :lastRequest WHERE s.principalName = :principalName")
void saveByPrincipalName(
@Param("expired") boolean expired,
@Param("lastRequest") Date lastRequest,
@Param("principalName") String principalName);
}

View File

@@ -0,0 +1,35 @@
package stirling.software.SPDF.config.security.session;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Date;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.security.core.session.SessionInformation;
import org.springframework.stereotype.Component;
@Component
public class SessionScheduled {
@Autowired private SessionPersistentRegistry sessionPersistentRegistry;
@Scheduled(cron = "0 0/5 * * * ?")
public void expireSessions() {
Instant now = Instant.now();
for (Object principal : sessionPersistentRegistry.getAllPrincipals()) {
List<SessionInformation> sessionInformations =
sessionPersistentRegistry.getAllSessions(principal, false);
for (SessionInformation sessionInformation : sessionInformations) {
Date lastRequest = sessionInformation.getLastRequest();
int maxInactiveInterval = sessionPersistentRegistry.getMaxInactiveInterval();
Instant expirationTime =
lastRequest.toInstant().plus(maxInactiveInterval, ChronoUnit.SECONDS);
if (now.isAfter(expirationTime)) {
sessionPersistentRegistry.expireSession(sessionInformation.getSessionId());
}
}
}
}
}