SSO Refactoring (#2818)

# Description of Changes

* Refactoring of SSO code around OAuth & SAML 2
* Enabling auto-login with SAML 2 via the new `SSOAutoLogin` property
* Correcting typos & general cleanup

---

## Checklist

### General

- [x] I have read the [Contribution
Guidelines](https://github.com/Stirling-Tools/Stirling-PDF/blob/main/CONTRIBUTING.md)
- [x] I have read the [Stirling-PDF Developer
Guide](https://github.com/Stirling-Tools/Stirling-PDF/blob/main/DeveloperGuide.md)
(if applicable)
- [x] I have read the [How to add new languages to
Stirling-PDF](https://github.com/Stirling-Tools/Stirling-PDF/blob/main/HowToAddNewLanguage.md)
(if applicable)
- [x] I have performed a self-review of my own code
- [x] My changes generate no new warnings

### Documentation

- [x] I have updated relevant docs on [Stirling-PDF's doc
repo](https://github.com/Stirling-Tools/Stirling-Tools.github.io/blob/main/docs/)
(if functionality has heavily changed)
- [x] I have read the section [Add New Translation
Tags](https://github.com/Stirling-Tools/Stirling-PDF/blob/main/HowToAddNewLanguage.md#add-new-translation-tags)
(for new translation tags only)

### UI Changes (if applicable)

- [ ] Screenshots or videos demonstrating the UI changes are attached
(e.g., as comments or direct attachments in the PR)

### Testing (if applicable)

- [x] I have tested my changes locally. Refer to the [Testing
Guide](https://github.com/Stirling-Tools/Stirling-PDF/blob/main/DeveloperGuide.md#6-testing)
for more details.
This commit is contained in:
Dario Ghunney Ware
2025-02-24 22:18:34 +00:00
committed by GitHub
parent 16295c7bb9
commit 4c701b2e69
85 changed files with 1462 additions and 1144 deletions

View File

@@ -36,7 +36,7 @@ import stirling.software.SPDF.model.AuthenticationType;
import stirling.software.SPDF.model.Role;
import stirling.software.SPDF.model.User;
import stirling.software.SPDF.model.api.user.UsernameAndPass;
import stirling.software.SPDF.model.provider.UnsupportedProviderException;
import stirling.software.SPDF.model.exception.UnsupportedProviderException;
@Controller
@Tag(name = "User", description = "User APIs")
@@ -126,7 +126,7 @@ public class UserController {
return new RedirectView("/change-creds?messageType=notAuthenticated", true);
}
Optional<User> userOpt = userService.findByUsernameIgnoreCase(principal.getName());
if (userOpt == null || userOpt.isEmpty()) {
if (userOpt.isEmpty()) {
return new RedirectView("/change-creds?messageType=userNotFound", true);
}
User user = userOpt.get();
@@ -154,7 +154,7 @@ public class UserController {
return new RedirectView("/account?messageType=notAuthenticated", true);
}
Optional<User> userOpt = userService.findByUsernameIgnoreCase(principal.getName());
if (userOpt == null || userOpt.isEmpty()) {
if (userOpt.isEmpty()) {
return new RedirectView("/account?messageType=userNotFound", true);
}
User user = userOpt.get();
@@ -176,7 +176,7 @@ public class UserController {
for (Map.Entry<String, String[]> entry : paramMap.entrySet()) {
updates.put(entry.getKey(), entry.getValue()[0]);
}
log.debug("Processed updates: " + updates);
log.debug("Processed updates: {}", updates);
// Assuming you have a method in userService to update the settings for a user
userService.updateUserSettings(principal.getName(), updates);
// Redirect to a page of your choice after updating
@@ -199,7 +199,7 @@ public class UserController {
Optional<User> userOpt = userService.findByUsernameIgnoreCase(username);
if (userOpt.isPresent()) {
User user = userOpt.get();
if (user != null && user.getUsername().equalsIgnoreCase(username)) {
if (user.getUsername().equalsIgnoreCase(username)) {
return new RedirectView("/addUsers?messageType=usernameExists", true);
}
}
@@ -276,7 +276,7 @@ public class UserController {
Authentication authentication)
throws SQLException, UnsupportedProviderException {
Optional<User> userOpt = userService.findByUsernameIgnoreCase(username);
if (!userOpt.isPresent()) {
if (userOpt.isEmpty()) {
return new RedirectView("/addUsers?messageType=userNotFound", true);
}
if (!userService.usernameExistsIgnoreCase(username)) {
@@ -295,20 +295,20 @@ public class UserController {
List<Object> principals = sessionRegistry.getAllPrincipals();
String userNameP = "";
for (Object principal : principals) {
List<SessionInformation> sessionsInformations =
List<SessionInformation> sessionsInformation =
sessionRegistry.getAllSessions(principal, false);
if (principal instanceof UserDetails) {
userNameP = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
userNameP = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
userNameP = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
userNameP = ((CustomSaml2AuthenticatedPrincipal) principal).name();
} else if (principal instanceof String) {
userNameP = (String) principal;
}
if (userNameP.equalsIgnoreCase(username)) {
for (SessionInformation sessionsInformation : sessionsInformations) {
sessionRegistry.expireSession(sessionsInformation.getSessionId());
for (SessionInformation sessionInfo : sessionsInformation) {
sessionRegistry.expireSession(sessionInfo.getSessionId());
}
}
}

View File

@@ -1,8 +1,15 @@
package stirling.software.SPDF.controller.web;
import static stirling.software.SPDF.utils.validation.Validator.validateProvider;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.springframework.security.access.prepost.PreAuthorize;
@@ -24,12 +31,16 @@ import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.model.*;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2.Client;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.provider.GithubProvider;
import stirling.software.SPDF.model.Authority;
import stirling.software.SPDF.model.Role;
import stirling.software.SPDF.model.SessionEntity;
import stirling.software.SPDF.model.User;
import stirling.software.SPDF.model.provider.GitHubProvider;
import stirling.software.SPDF.model.provider.GoogleProvider;
import stirling.software.SPDF.model.provider.KeycloakProvider;
import stirling.software.SPDF.repository.UserRepository;
@@ -39,12 +50,12 @@ import stirling.software.SPDF.repository.UserRepository;
@Tag(name = "Account Security", description = "Account Security APIs")
public class AccountWebController {
public static final String OAUTH_2_AUTHORIZATION = "/oauth2/authorization/";
private final ApplicationProperties applicationProperties;
private final SessionPersistentRegistry sessionPersistentRegistry;
private final UserRepository // Assuming you have a repository for user operations
userRepository;
// Assuming you have a repository for user operations
private final UserRepository userRepository;
public AccountWebController(
ApplicationProperties applicationProperties,
@@ -61,132 +72,127 @@ public class AccountWebController {
if (authentication != null && authentication.isAuthenticated()) {
return "redirect:/";
}
Map<String, String> providerList = new HashMap<>();
Security securityProps = applicationProperties.getSecurity();
OAUTH2 oauth = securityProps.getOauth2();
if (oauth != null) {
if (oauth.getEnabled()) {
if (oauth.isSettingsValid()) {
providerList.put("/oauth2/authorization/oidc", oauth.getProvider());
String firstChar = String.valueOf(oauth.getProvider().charAt(0));
String clientName =
oauth.getProvider().replaceFirst(firstChar, firstChar.toUpperCase());
providerList.put(OAUTH_2_AUTHORIZATION + oauth.getProvider(), clientName);
}
Client client = oauth.getClient();
if (client != null) {
GoogleProvider google = client.getGoogle();
if (google.isSettingsValid()) {
if (validateProvider(google)) {
providerList.put(
"/oauth2/authorization/" + google.getName(),
google.getClientName());
OAUTH_2_AUTHORIZATION + google.getName(), google.getClientName());
}
GithubProvider github = client.getGithub();
if (github.isSettingsValid()) {
GitHubProvider github = client.getGithub();
if (validateProvider(github)) {
providerList.put(
"/oauth2/authorization/" + github.getName(),
github.getClientName());
OAUTH_2_AUTHORIZATION + github.getName(), github.getClientName());
}
KeycloakProvider keycloak = client.getKeycloak();
if (keycloak.isSettingsValid()) {
if (validateProvider(keycloak)) {
providerList.put(
"/oauth2/authorization/" + keycloak.getName(),
OAUTH_2_AUTHORIZATION + keycloak.getName(),
keycloak.getClientName());
}
}
}
}
SAML2 saml2 = securityProps.getSaml2();
if (securityProps.isSaml2Activ()
&& applicationProperties.getSystem().getEnableAlphaFunctionality()) {
providerList.put("/saml2/authenticate/" + saml2.getRegistrationId(), "SAML 2");
if (securityProps.isSaml2Active()
&& applicationProperties.getSystem().getEnableAlphaFunctionality()
&& applicationProperties.getEnterpriseEdition().isEnabled()) {
String samlIdp = saml2.getProvider();
String saml2AuthenticationPath = "/saml2/authenticate/" + saml2.getRegistrationId();
if (applicationProperties.getEnterpriseEdition().isSsoAutoLogin()) {
return "redirect:"
+ request.getRequestURL()
+ saml2AuthenticationPath;
} else {
providerList.put(saml2AuthenticationPath, samlIdp + " (SAML 2)");
}
}
// Remove any null keys/values from the providerList
providerList
.entrySet()
.removeIf(entry -> entry.getKey() == null || entry.getValue() == null);
model.addAttribute("providerlist", providerList);
model.addAttribute("providerList", providerList);
model.addAttribute("loginMethod", securityProps.getLoginMethod());
boolean altLogin = providerList.size() > 0 ? securityProps.isAltLogin() : false;
boolean altLogin = !providerList.isEmpty() ? securityProps.isAltLogin() : false;
model.addAttribute("altLogin", altLogin);
model.addAttribute("currentPage", "login");
String error = request.getParameter("error");
if (error != null) {
switch (error) {
case "badcredentials":
error = "login.invalid";
break;
case "locked":
error = "login.locked";
break;
case "oauth2AuthenticationError":
error = "userAlreadyExistsOAuthMessage";
break;
default:
break;
case "badCredentials" -> error = "login.invalid";
case "locked" -> error = "login.locked";
case "oauth2AuthenticationError" -> error = "userAlreadyExistsOAuthMessage";
}
model.addAttribute("error", error);
}
String erroroauth = request.getParameter("erroroauth");
if (erroroauth != null) {
switch (erroroauth) {
case "oauth2AutoCreateDisabled":
erroroauth = "login.oauth2AutoCreateDisabled";
break;
case "invalidUsername":
erroroauth = "login.invalid";
break;
case "userAlreadyExistsWeb":
erroroauth = "userAlreadyExistsWebMessage";
break;
case "oauth2AuthenticationErrorWeb":
erroroauth = "login.oauth2InvalidUserType";
break;
case "invalid_token_response":
erroroauth = "login.oauth2InvalidTokenResponse";
break;
case "authorization_request_not_found":
erroroauth = "login.oauth2RequestNotFound";
break;
case "access_denied":
erroroauth = "login.oauth2AccessDenied";
break;
case "invalid_user_info_response":
erroroauth = "login.oauth2InvalidUserInfoResponse";
break;
case "invalid_request":
erroroauth = "login.oauth2invalidRequest";
break;
case "invalid_id_token":
erroroauth = "login.oauth2InvalidIdToken";
break;
case "oauth2_admin_blocked_user":
erroroauth = "login.oauth2AdminBlockedUser";
break;
case "userIsDisabled":
erroroauth = "login.userIsDisabled";
break;
case "invalid_destination":
erroroauth = "login.invalid_destination";
break;
case "relying_party_registration_not_found":
erroroauth = "login.relyingPartyRegistrationNotFound";
break;
String errorOAuth = request.getParameter("errorOAuth");
if (errorOAuth != null) {
switch (errorOAuth) {
case "oAuth2AutoCreateDisabled" -> errorOAuth = "login.oAuth2AutoCreateDisabled";
case "invalidUsername" -> errorOAuth = "login.invalid";
case "userAlreadyExistsWeb" -> errorOAuth = "userAlreadyExistsWebMessage";
case "oAuth2AuthenticationErrorWeb" -> errorOAuth = "login.oauth2InvalidUserType";
case "invalid_token_response" -> errorOAuth = "login.oauth2InvalidTokenResponse";
case "authorization_request_not_found" ->
errorOAuth = "login.oauth2RequestNotFound";
case "access_denied" -> errorOAuth = "login.oauth2AccessDenied";
case "invalid_user_info_response" ->
errorOAuth = "login.oauth2InvalidUserInfoResponse";
case "invalid_request" -> errorOAuth = "login.oauth2invalidRequest";
case "invalid_id_token" -> errorOAuth = "login.oauth2InvalidIdToken";
case "oAuth2AdminBlockedUser" -> errorOAuth = "login.oAuth2AdminBlockedUser";
case "userIsDisabled" -> errorOAuth = "login.userIsDisabled";
case "invalid_destination" -> errorOAuth = "login.invalid_destination";
case "relying_party_registration_not_found" ->
errorOAuth = "login.relyingPartyRegistrationNotFound";
// Valid InResponseTo was not available from the validation context, unable to
// evaluate
case "invalid_in_response_to":
erroroauth = "login.invalid_in_response_to";
break;
case "not_authentication_provider_found":
erroroauth = "login.not_authentication_provider_found";
break;
default:
break;
case "invalid_in_response_to" -> errorOAuth = "login.invalid_in_response_to";
case "not_authentication_provider_found" ->
errorOAuth = "login.not_authentication_provider_found";
}
model.addAttribute("erroroauth", erroroauth);
model.addAttribute("errorOAuth", errorOAuth);
}
if (request.getParameter("messageType") != null) {
model.addAttribute("messageType", "changedCredsMessage");
}
if (request.getParameter("logout") != null) {
model.addAttribute("logoutMessage", "You have been logged out.");
}
return "login";
}
@@ -230,13 +236,11 @@ public class AccountWebController {
.plus(maxInactiveInterval, ChronoUnit.SECONDS);
if (now.isAfter(expirationTime)) {
sessionPersistentRegistry.expireSession(sessionEntity.getSessionId());
hasActiveSession = false;
} else {
hasActiveSession = !sessionEntity.isExpired();
}
lastRequest = sessionEntity.getLastRequest();
} else {
hasActiveSession = false;
// No session, set default last request time
lastRequest = new Date(0);
}
@@ -273,53 +277,41 @@ public class AccountWebController {
})
.collect(Collectors.toList());
String messageType = request.getParameter("messageType");
String deleteMessage = null;
String deleteMessage;
if (messageType != null) {
switch (messageType) {
case "deleteCurrentUser":
deleteMessage = "deleteCurrentUserMessage";
break;
case "deleteUsernameExists":
deleteMessage = "deleteUsernameExistsMessage";
break;
default:
break;
}
deleteMessage =
switch (messageType) {
case "deleteCurrentUser" -> "deleteCurrentUserMessage";
case "deleteUsernameExists" -> "deleteUsernameExistsMessage";
default -> null;
};
model.addAttribute("deleteMessage", deleteMessage);
String addMessage = null;
switch (messageType) {
case "usernameExists":
addMessage = "usernameExistsMessage";
break;
case "invalidUsername":
addMessage = "invalidUsernameMessage";
break;
case "invalidPassword":
addMessage = "invalidPasswordMessage";
break;
default:
break;
}
String addMessage;
addMessage =
switch (messageType) {
case "usernameExists" -> "usernameExistsMessage";
case "invalidUsername" -> "invalidUsernameMessage";
case "invalidPassword" -> "invalidPasswordMessage";
default -> null;
};
model.addAttribute("addMessage", addMessage);
}
String changeMessage = null;
String changeMessage;
if (messageType != null) {
switch (messageType) {
case "userNotFound":
changeMessage = "userNotFoundMessage";
break;
case "downgradeCurrentUser":
changeMessage = "downgradeCurrentUserMessage";
break;
case "disabledCurrentUser":
changeMessage = "disabledCurrentUserMessage";
break;
default:
changeMessage = messageType;
break;
}
changeMessage =
switch (messageType) {
case "userNotFound" -> "userNotFoundMessage";
case "downgradeCurrentUser" -> "downgradeCurrentUserMessage";
case "disabledCurrentUser" -> "disabledCurrentUserMessage";
default -> messageType;
};
model.addAttribute("changeMessage", changeMessage);
}
model.addAttribute("users", sortedUsers);
model.addAttribute("currentUsername", authentication.getName());
model.addAttribute("roleDetails", roleDetails);
@@ -340,78 +332,51 @@ public class AccountWebController {
if (authentication != null && authentication.isAuthenticated()) {
Object principal = authentication.getPrincipal();
String username = null;
if (principal instanceof UserDetails) {
// Cast the principal object to UserDetails
UserDetails userDetails = (UserDetails) principal;
// Retrieve username and other attributes
// Retrieve username and other attributes and add login attributes to the model
if (principal instanceof UserDetails userDetails) {
username = userDetails.getUsername();
// Add oAuth2 Login attributes to the model
model.addAttribute("oAuth2Login", false);
}
if (principal instanceof OAuth2User) {
// Cast the principal object to OAuth2User
OAuth2User userDetails = (OAuth2User) principal;
// Retrieve username and other attributes
username =
userDetails.getAttribute(
applicationProperties.getSecurity().getOauth2().getUseAsUsername());
// Add oAuth2 Login attributes to the model
if (principal instanceof OAuth2User userDetails) {
username = userDetails.getName();
model.addAttribute("oAuth2Login", true);
}
if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
// Cast the principal object to OAuth2User
CustomSaml2AuthenticatedPrincipal userDetails =
(CustomSaml2AuthenticatedPrincipal) principal;
// Retrieve username and other attributes
username = userDetails.getName();
// Add oAuth2 Login attributes to the model
model.addAttribute("oAuth2Login", true);
if (principal instanceof CustomSaml2AuthenticatedPrincipal userDetails) {
username = userDetails.name();
model.addAttribute("saml2Login", true);
}
if (username != null) {
// Fetch user details from the database
Optional<User> user =
userRepository
.findByUsernameIgnoreCaseWithSettings( // Assuming findByUsername
// method exists
username);
if (!user.isPresent()) {
Optional<User> user = userRepository.findByUsernameIgnoreCaseWithSettings(username);
if (user.isEmpty()) {
return "redirect:/error";
}
// Convert settings map to JSON string
ObjectMapper objectMapper = new ObjectMapper();
String settingsJson;
try {
settingsJson = objectMapper.writeValueAsString(user.get().getSettings());
} catch (JsonProcessingException e) {
// Handle JSON conversion error
log.error("exception", e);
log.error("Error converting settings map", e);
return "redirect:/error";
}
String messageType = request.getParameter("messageType");
if (messageType != null) {
switch (messageType) {
case "notAuthenticated":
messageType = "notAuthenticatedMessage";
break;
case "userNotFound":
messageType = "userNotFoundMessage";
break;
case "incorrectPassword":
messageType = "incorrectPasswordMessage";
break;
case "usernameExists":
messageType = "usernameExistsMessage";
break;
case "invalidUsername":
messageType = "invalidUsernameMessage";
break;
default:
break;
case "notAuthenticated" -> messageType = "notAuthenticatedMessage";
case "userNotFound" -> messageType = "userNotFoundMessage";
case "incorrectPassword" -> messageType = "incorrectPasswordMessage";
case "usernameExists" -> messageType = "usernameExistsMessage";
case "invalidUsername" -> messageType = "invalidUsernameMessage";
}
model.addAttribute("messageType", messageType);
}
// Add attributes to the model
model.addAttribute("username", username);
model.addAttribute("messageType", messageType);
model.addAttribute("role", user.get().getRolesAsString());
model.addAttribute("settings", settingsJson);
model.addAttribute("changeCredsFlag", user.get().isFirstLogin());
@@ -432,19 +397,12 @@ public class AccountWebController {
}
if (authentication != null && authentication.isAuthenticated()) {
Object principal = authentication.getPrincipal();
if (principal instanceof UserDetails) {
// Cast the principal object to UserDetails
UserDetails userDetails = (UserDetails) principal;
// Retrieve username and other attributes
if (principal instanceof UserDetails userDetails) {
String username = userDetails.getUsername();
// Fetch user details from the database
Optional<User> user =
userRepository
.findByUsernameIgnoreCase( // Assuming findByUsername method exists
username);
if (!user.isPresent()) {
// Handle error appropriately
// Example redirection in case of error
Optional<User> user = userRepository.findByUsernameIgnoreCase(username);
if (user.isEmpty()) {
// Handle error appropriately, example redirection in case of error
return "redirect:/error";
}
String messageType = request.getParameter("messageType");
@@ -467,7 +425,7 @@ public class AccountWebController {
}
model.addAttribute("messageType", messageType);
}
// Add attributes to the model
model.addAttribute("username", username);
}
} else {