Major Enhancements to SAML2 and OAuth2 Integration with Simplified Security Configurations (#2040)

* implement Saml2 login/logout

* changed: deprecation code

* relyingPartyRegistrations only enabled samle
This commit is contained in:
Ludy
2024-10-20 13:30:58 +02:00
committed by GitHub
parent 227d18a469
commit eff1843061
32 changed files with 1080 additions and 839 deletions

View File

@@ -1,27 +1,237 @@
package stirling.software.SPDF.config.security;
import java.io.IOException;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
import java.util.ArrayList;
import java.util.List;
import org.springframework.core.io.Resource;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import com.coveo.saml.SamlClient;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.SPdfApplication;
import stirling.software.SPDF.config.security.saml2.CertificateUtils;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.Provider;
import stirling.software.SPDF.model.provider.UnsupportedProviderException;
import stirling.software.SPDF.utils.UrlUtils;
@Slf4j
@AllArgsConstructor
public class CustomLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
private final ApplicationProperties applicationProperties;
@Override
public void onLogoutSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException, ServletException {
if (request.getParameter("userIsDisabled") != null) {
getRedirectStrategy()
.sendRedirect(request, response, "/login?erroroauth=userIsDisabled");
return;
if (!response.isCommitted()) {
// Handle user logout due to disabled account
if (request.getParameter("userIsDisabled") != null) {
response.sendRedirect(
request.getContextPath() + "/login?erroroauth=userIsDisabled");
return;
}
// Handle OAuth2 authentication error
if (request.getParameter("oauth2AuthenticationErrorWeb") != null) {
response.sendRedirect(
request.getContextPath() + "/login?erroroauth=userAlreadyExistsWeb");
return;
}
if (authentication != null) {
// Handle SAML2 logout redirection
if (authentication instanceof Saml2Authentication) {
getRedirect_saml2(request, response, authentication);
return;
}
// Handle OAuth2 logout redirection
else if (authentication instanceof OAuth2AuthenticationToken) {
getRedirect_oauth2(request, response, authentication);
return;
}
// Handle Username/Password logout
else if (authentication instanceof UsernamePasswordAuthenticationToken) {
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
return;
}
// Handle unknown authentication types
else {
log.error(
"authentication class unknown: "
+ authentication.getClass().getSimpleName());
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
return;
}
} else {
// Redirect to login page after logout
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
return;
}
}
}
// Redirect for SAML2 authentication logout
private void getRedirect_saml2(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException {
SAML2 samlConf = applicationProperties.getSecurity().getSaml2();
String registrationId = samlConf.getRegistrationId();
Saml2Authentication samlAuthentication = (Saml2Authentication) authentication;
CustomSaml2AuthenticatedPrincipal principal =
(CustomSaml2AuthenticatedPrincipal) samlAuthentication.getPrincipal();
String nameIdValue = principal.getName();
try {
// Read certificate from the resource
Resource certificateResource = samlConf.getSpCert();
X509Certificate certificate = CertificateUtils.readCertificate(certificateResource);
List<X509Certificate> certificates = new ArrayList<>();
certificates.add(certificate);
// Construct URLs required for SAML configuration
String serverUrl =
SPdfApplication.getStaticBaseUrl() + ":" + SPdfApplication.getStaticPort();
String relyingPartyIdentifier =
serverUrl + "/saml2/service-provider-metadata/" + registrationId;
String assertionConsumerServiceUrl = serverUrl + "/login/saml2/sso/" + registrationId;
String idpUrl = samlConf.getIdpSingleLogoutUrl();
String idpIssuer = samlConf.getIdpIssuer();
// Create SamlClient instance for SAML logout
SamlClient samlClient =
new SamlClient(
relyingPartyIdentifier,
assertionConsumerServiceUrl,
idpUrl,
idpIssuer,
certificates,
SamlClient.SamlIdpBinding.POST);
// Read private key for service provider
Resource privateKeyResource = samlConf.getPrivateKey();
RSAPrivateKey privateKey = CertificateUtils.readPrivateKey(privateKeyResource);
// Set service provider keys for the SamlClient
samlClient.setSPKeys(certificate, privateKey);
// Redirect to identity provider for logout
samlClient.redirectToIdentityProvider(response, null, nameIdValue);
} catch (Exception e) {
log.error(nameIdValue, e);
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
}
}
// Redirect for OAuth2 authentication logout
private void getRedirect_oauth2(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException {
String param = "logout=true";
String registrationId = null;
String issuer = null;
String clientId = null;
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (authentication instanceof OAuth2AuthenticationToken) {
OAuth2AuthenticationToken oauthToken = (OAuth2AuthenticationToken) authentication;
registrationId = oauthToken.getAuthorizedClientRegistrationId();
try {
// Get OAuth2 provider details from configuration
Provider provider = oauth.getClient().get(registrationId);
issuer = provider.getIssuer();
clientId = provider.getClientId();
} catch (UnsupportedProviderException e) {
log.error(e.getMessage());
}
} else {
registrationId = oauth.getProvider() != null ? oauth.getProvider() : "";
issuer = oauth.getIssuer();
clientId = oauth.getClientId();
}
String errorMessage = "";
// Handle different error scenarios during logout
if (request.getParameter("oauth2AuthenticationErrorWeb") != null) {
param = "erroroauth=oauth2AuthenticationErrorWeb";
} else if ((errorMessage = request.getParameter("error")) != null) {
param = "error=" + sanitizeInput(errorMessage);
} else if ((errorMessage = request.getParameter("erroroauth")) != null) {
param = "erroroauth=" + sanitizeInput(errorMessage);
} else if (request.getParameter("oauth2AutoCreateDisabled") != null) {
param = "error=oauth2AutoCreateDisabled";
} else if (request.getParameter("oauth2_admin_blocked_user") != null) {
param = "erroroauth=oauth2_admin_blocked_user";
} else if (request.getParameter("userIsDisabled") != null) {
param = "erroroauth=userIsDisabled";
} else if (request.getParameter("badcredentials") != null) {
param = "error=badcredentials";
}
getRedirectStrategy().sendRedirect(request, response, "/login?logout=true");
String redirect_url = UrlUtils.getOrigin(request) + "/login?" + param;
// Redirect based on OAuth2 provider
switch (registrationId.toLowerCase()) {
case "keycloak":
// Add Keycloak specific logout URL if needed
String logoutUrl =
issuer
+ "/protocol/openid-connect/logout"
+ "?client_id="
+ clientId
+ "&post_logout_redirect_uri="
+ response.encodeRedirectURL(redirect_url);
log.info("Redirecting to Keycloak logout URL: " + logoutUrl);
response.sendRedirect(logoutUrl);
break;
case "github":
// Add GitHub specific logout URL if needed
String githubLogoutUrl = "https://github.com/logout";
log.info("Redirecting to GitHub logout URL: " + githubLogoutUrl);
response.sendRedirect(githubLogoutUrl);
break;
case "google":
// Add Google specific logout URL if needed
// String googleLogoutUrl =
// "https://accounts.google.com/Logout?continue=https://appengine.google.com/_ah/logout?continue="
// + response.encodeRedirectURL(redirect_url);
log.info("Google does not have a specific logout URL");
// log.info("Redirecting to Google logout URL: " + googleLogoutUrl);
// response.sendRedirect(googleLogoutUrl);
// break;
default:
String defaultRedirectUrl = request.getContextPath() + "/login?" + param;
log.info("Redirecting to default logout URL: " + defaultRedirectUrl);
response.sendRedirect(defaultRedirectUrl);
break;
}
}
// Sanitize input to avoid potential security vulnerabilities
private String sanitizeInput(String input) {
return input.replaceAll("[^a-zA-Z0-9 ]", "");
}
}

View File

@@ -1,22 +1,36 @@
package stirling.software.SPDF.config.security;
import java.security.cert.X509Certificate;
import java.util.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Lazy;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.core.io.Resource;
import org.springframework.security.authentication.AuthenticationProvider;
import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
import org.springframework.security.authentication.dao.DaoAuthenticationProvider;
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.ClientRegistrations;
import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
import org.springframework.security.oauth2.core.user.OAuth2UserAuthority;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.web.SecurityFilterChain;
@@ -28,13 +42,20 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2AuthenticationFailureHandler;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2AuthenticationSuccessHandler;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2LogoutSuccessHandler;
import stirling.software.SPDF.config.security.oauth2.CustomOAuth2UserService;
import stirling.software.SPDF.config.security.saml.ConvertResponseToAuthentication;
import stirling.software.SPDF.config.security.saml.CustomSAMLAuthenticationFailureHandler;
import stirling.software.SPDF.config.security.saml.CustomSAMLAuthenticationSuccessHandler;
import stirling.software.SPDF.config.security.saml2.CertificateUtils;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticationFailureHandler;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticationSuccessHandler;
import stirling.software.SPDF.config.security.saml2.CustomSaml2ResponseAuthenticationConverter;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2.Client;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.User;
import stirling.software.SPDF.model.provider.GithubProvider;
import stirling.software.SPDF.model.provider.GoogleProvider;
import stirling.software.SPDF.model.provider.KeycloakProvider;
import stirling.software.SPDF.repository.JPATokenRepositoryImpl;
@Configuration
@@ -45,12 +66,6 @@ public class SecurityConfiguration {
@Autowired private CustomUserDetailsService userDetailsService;
@Autowired(required = false)
private GrantedAuthoritiesMapper userAuthoritiesMapper;
@Autowired(required = false)
private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
@Bean
public PasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder();
@@ -71,11 +86,8 @@ public class SecurityConfiguration {
@Autowired private FirstLoginFilter firstLoginFilter;
@Autowired private SessionPersistentRegistry sessionRegistry;
@Autowired private ConvertResponseToAuthentication convertResponseToAuthentication;
@Bean
public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
http.authenticationManager(authenticationManager(http));
if (loginEnabledValue) {
http.addFilterBefore(
@@ -94,128 +106,116 @@ public class SecurityConfiguration {
.sessionRegistry(sessionRegistry)
.expiredUrl("/login?logout=true"));
http.formLogin(
formLogin ->
formLogin
.loginPage("/login")
.successHandler(
new CustomAuthenticationSuccessHandler(
loginAttemptService, userService))
.defaultSuccessUrl("/")
.failureHandler(
new CustomAuthenticationFailureHandler(
loginAttemptService, userService))
.permitAll())
.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache()))
.logout(
logout ->
logout.logoutRequestMatcher(
new AntPathRequestMatcher("/logout"))
.logoutSuccessHandler(new CustomLogoutSuccessHandler())
.invalidateHttpSession(true) // Invalidate session
.deleteCookies("JSESSIONID", "remember-me"))
.rememberMe(
rememberMeConfigurer ->
rememberMeConfigurer // Use the configurator directly
.key("uniqueAndSecret")
.tokenRepository(persistentTokenRepository())
.tokenValiditySeconds(1209600) // 2 weeks
)
.authorizeHttpRequests(
authz ->
authz.requestMatchers(
req -> {
String uri = req.getRequestURI();
String contextPath = req.getContextPath();
http.authenticationProvider(daoAuthenticationProvider());
http.requestCache(requestCache -> requestCache.requestCache(new NullRequestCache()));
http.logout(
logout ->
logout.logoutRequestMatcher(new AntPathRequestMatcher("/logout"))
.logoutSuccessHandler(
new CustomLogoutSuccessHandler(applicationProperties))
.invalidateHttpSession(true) // Invalidate session
.deleteCookies("JSESSIONID", "remember-me"));
http.rememberMe(
rememberMeConfigurer ->
rememberMeConfigurer // Use the configurator directly
.key("uniqueAndSecret")
.tokenRepository(persistentTokenRepository())
.tokenValiditySeconds(1209600) // 2 weeks
);
http.authorizeHttpRequests(
authz ->
authz.requestMatchers(
req -> {
String uri = req.getRequestURI();
String contextPath = req.getContextPath();
// Remove the context path from the URI
String trimmedUri =
uri.startsWith(contextPath)
? uri.substring(
contextPath
.length())
: uri;
// Remove the context path from the URI
String trimmedUri =
uri.startsWith(contextPath)
? uri.substring(
contextPath.length())
: uri;
return trimmedUri.startsWith("/login")
|| trimmedUri.startsWith("/oauth")
|| trimmedUri.startsWith("/saml2")
|| trimmedUri.endsWith(".svg")
|| trimmedUri.startsWith(
"/register")
|| trimmedUri.startsWith("/error")
|| trimmedUri.startsWith("/images/")
|| trimmedUri.startsWith("/public/")
|| trimmedUri.startsWith("/css/")
|| trimmedUri.startsWith("/fonts/")
|| trimmedUri.startsWith("/js/")
|| trimmedUri.startsWith(
"/api/v1/info/status");
})
.permitAll()
.anyRequest()
.authenticated());
return trimmedUri.startsWith("/login")
|| trimmedUri.startsWith("/oauth")
|| trimmedUri.startsWith("/saml2")
|| trimmedUri.endsWith(".svg")
|| trimmedUri.startsWith("/register")
|| trimmedUri.startsWith("/error")
|| trimmedUri.startsWith("/images/")
|| trimmedUri.startsWith("/public/")
|| trimmedUri.startsWith("/css/")
|| trimmedUri.startsWith("/fonts/")
|| trimmedUri.startsWith("/js/")
|| trimmedUri.startsWith(
"/api/v1/info/status");
})
.permitAll()
.anyRequest()
.authenticated());
// Handle User/Password Logins
if (applicationProperties.getSecurity().isUserPass()) {
http.formLogin(
formLogin ->
formLogin
.loginPage("/login")
.successHandler(
new CustomAuthenticationSuccessHandler(
loginAttemptService, userService))
.failureHandler(
new CustomAuthenticationFailureHandler(
loginAttemptService, userService))
.defaultSuccessUrl("/")
.permitAll());
}
// Handle OAUTH2 Logins
if (applicationProperties.getSecurity().getOauth2() != null
&& applicationProperties.getSecurity().getOauth2().getEnabled()
&& !applicationProperties
.getSecurity()
.getLoginMethod()
.equalsIgnoreCase("normal")) {
if (applicationProperties.getSecurity().isOauth2Activ()) {
http.oauth2Login(
oauth2 ->
oauth2.loginPage("/oauth2")
/*
This Custom handler is used to check if the OAUTH2 user trying to log in, already exists in the database.
If user exists, login proceeds as usual. If user does not exist, then it is autocreated but only if 'OAUTH2AutoCreateUser'
is set as true, else login fails with an error message advising the same.
*/
oauth2 ->
oauth2.loginPage("/oauth2")
/*
This Custom handler is used to check if the OAUTH2 user trying to log in, already exists in the database.
If user exists, login proceeds as usual. If user does not exist, then it is autocreated but only if 'OAUTH2AutoCreateUser'
is set as true, else login fails with an error message advising the same.
*/
.successHandler(
new CustomOAuth2AuthenticationSuccessHandler(
loginAttemptService,
applicationProperties,
userService))
.failureHandler(
new CustomOAuth2AuthenticationFailureHandler())
// Add existing Authorities from the database
.userInfoEndpoint(
userInfoEndpoint ->
userInfoEndpoint
.oidcUserService(
new CustomOAuth2UserService(
applicationProperties,
userService,
loginAttemptService))
.userAuthoritiesMapper(
userAuthoritiesMapper()))
.permitAll());
}
// Handle SAML
if (applicationProperties.getSecurity().isSaml2Activ()) {
http.authenticationProvider(samlAuthenticationProvider());
http.saml2Login(
saml2 ->
saml2.loginPage("/saml2")
.successHandler(
new CustomOAuth2AuthenticationSuccessHandler(
new CustomSaml2AuthenticationSuccessHandler(
loginAttemptService,
applicationProperties,
userService))
.failureHandler(
new CustomOAuth2AuthenticationFailureHandler())
// Add existing Authorities from the database
.userInfoEndpoint(
userInfoEndpoint ->
userInfoEndpoint
.oidcUserService(
new CustomOAuth2UserService(
applicationProperties,
userService,
loginAttemptService))
.userAuthoritiesMapper(
userAuthoritiesMapper)))
.logout(
logout ->
logout.logoutSuccessHandler(
new CustomOAuth2LogoutSuccessHandler(
applicationProperties)));
}
// Handle SAML
if (applicationProperties.getSecurity().getSaml() != null
&& applicationProperties.getSecurity().getSaml().getEnabled()
&& !applicationProperties
.getSecurity()
.getLoginMethod()
.equalsIgnoreCase("normal")) {
http.saml2Login(
saml2 -> {
saml2.loginPage("/saml2")
.relyingPartyRegistrationRepository(
relyingPartyRegistrationRepository)
.successHandler(
new CustomSAMLAuthenticationSuccessHandler(
loginAttemptService,
userService,
applicationProperties))
.failureHandler(
new CustomSAMLAuthenticationFailureHandler());
})
new CustomSaml2AuthenticationFailureHandler())
.permitAll())
.addFilterBefore(
userAuthenticationFilter, Saml2WebSsoAuthenticationFilter.class);
}
@@ -231,39 +231,234 @@ public class SecurityConfiguration {
@Bean
@ConditionalOnProperty(
name = "security.saml.enabled",
name = "security.saml2.enabled",
havingValue = "true",
matchIfMissing = false)
public AuthenticationProvider samlAuthenticationProvider() {
OpenSaml4AuthenticationProvider authenticationProvider =
new OpenSaml4AuthenticationProvider();
authenticationProvider.setResponseAuthenticationConverter(convertResponseToAuthentication);
authenticationProvider.setResponseAuthenticationConverter(
new CustomSaml2ResponseAuthenticationConverter(userService));
return authenticationProvider;
}
// @Bean
// public AuthenticationProvider daoAuthenticationProvider() {
// DaoAuthenticationProvider provider = new DaoAuthenticationProvider();
// provider.setUserDetailsService(userDetailsService); // UserDetailsService
// provider.setPasswordEncoder(passwordEncoder()); // PasswordEncoder
// return provider;
// }
// Client Registration Repository for OAUTH2 OIDC Login
@Bean
@ConditionalOnProperty(
value = "security.oauth2.enabled",
havingValue = "true",
matchIfMissing = false)
public ClientRegistrationRepository clientRegistrationRepository() {
List<ClientRegistration> registrations = new ArrayList<>();
githubClientRegistration().ifPresent(registrations::add);
oidcClientRegistration().ifPresent(registrations::add);
googleClientRegistration().ifPresent(registrations::add);
keycloakClientRegistration().ifPresent(registrations::add);
if (registrations.isEmpty()) {
log.error("At least one OAuth2 provider must be configured");
System.exit(1);
}
return new InMemoryClientRegistrationRepository(registrations);
}
private Optional<ClientRegistration> googleClientRegistration() {
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (oauth == null || !oauth.getEnabled()) {
return Optional.empty();
}
Client client = oauth.getClient();
if (client == null) {
return Optional.empty();
}
GoogleProvider google = client.getGoogle();
return google != null && google.isSettingsValid()
? Optional.of(
ClientRegistration.withRegistrationId(google.getName())
.clientId(google.getClientId())
.clientSecret(google.getClientSecret())
.scope(google.getScopes())
.authorizationUri(google.getAuthorizationuri())
.tokenUri(google.getTokenuri())
.userInfoUri(google.getUserinfouri())
.userNameAttributeName(google.getUseAsUsername())
.clientName(google.getClientName())
.redirectUri("{baseUrl}/login/oauth2/code/" + google.getName())
.authorizationGrantType(
org.springframework.security.oauth2.core
.AuthorizationGrantType.AUTHORIZATION_CODE)
.build())
: Optional.empty();
}
private Optional<ClientRegistration> keycloakClientRegistration() {
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (oauth == null || !oauth.getEnabled()) {
return Optional.empty();
}
Client client = oauth.getClient();
if (client == null) {
return Optional.empty();
}
KeycloakProvider keycloak = client.getKeycloak();
return keycloak != null && keycloak.isSettingsValid()
? Optional.of(
ClientRegistrations.fromIssuerLocation(keycloak.getIssuer())
.registrationId(keycloak.getName())
.clientId(keycloak.getClientId())
.clientSecret(keycloak.getClientSecret())
.scope(keycloak.getScopes())
.userNameAttributeName(keycloak.getUseAsUsername())
.clientName(keycloak.getClientName())
.build())
: Optional.empty();
}
private Optional<ClientRegistration> githubClientRegistration() {
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (oauth == null || !oauth.getEnabled()) {
return Optional.empty();
}
Client client = oauth.getClient();
if (client == null) {
return Optional.empty();
}
GithubProvider github = client.getGithub();
return github != null && github.isSettingsValid()
? Optional.of(
ClientRegistration.withRegistrationId(github.getName())
.clientId(github.getClientId())
.clientSecret(github.getClientSecret())
.scope(github.getScopes())
.authorizationUri(github.getAuthorizationuri())
.tokenUri(github.getTokenuri())
.userInfoUri(github.getUserinfouri())
.userNameAttributeName(github.getUseAsUsername())
.clientName(github.getClientName())
.redirectUri("{baseUrl}/login/oauth2/code/" + github.getName())
.authorizationGrantType(
org.springframework.security.oauth2.core
.AuthorizationGrantType.AUTHORIZATION_CODE)
.build())
: Optional.empty();
}
private Optional<ClientRegistration> oidcClientRegistration() {
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (oauth == null
|| oauth.getIssuer() == null
|| oauth.getIssuer().isEmpty()
|| oauth.getClientId() == null
|| oauth.getClientId().isEmpty()
|| oauth.getClientSecret() == null
|| oauth.getClientSecret().isEmpty()
|| oauth.getScopes() == null
|| oauth.getScopes().isEmpty()
|| oauth.getUseAsUsername() == null
|| oauth.getUseAsUsername().isEmpty()) {
return Optional.empty();
}
return Optional.of(
ClientRegistrations.fromIssuerLocation(oauth.getIssuer())
.registrationId("oidc")
.clientId(oauth.getClientId())
.clientSecret(oauth.getClientSecret())
.scope(oauth.getScopes())
.userNameAttributeName(oauth.getUseAsUsername())
.clientName("OIDC")
.build());
}
@Bean
public AuthenticationManager authenticationManager(HttpSecurity http) throws Exception {
AuthenticationManagerBuilder authenticationManagerBuilder =
http.getSharedObject(AuthenticationManagerBuilder.class);
@ConditionalOnProperty(
name = "security.saml2.enabled",
havingValue = "true",
matchIfMissing = false)
public RelyingPartyRegistrationRepository relyingPartyRegistrations() throws Exception {
// authenticationManagerBuilder =
// authenticationManagerBuilder.authenticationProvider(
// daoAuthenticationProvider()); // Benutzername/Passwort
SAML2 samlConf = applicationProperties.getSecurity().getSaml2();
if (applicationProperties.getSecurity().getSaml() != null
&& applicationProperties.getSecurity().getSaml().getEnabled()) {
authenticationManagerBuilder.authenticationProvider(
samlAuthenticationProvider()); // SAML
}
return authenticationManagerBuilder.build();
Resource privateKeyResource = samlConf.getPrivateKey();
Resource certificateResource = samlConf.getSpCert();
Saml2X509Credential signingCredential =
new Saml2X509Credential(
CertificateUtils.readPrivateKey(privateKeyResource),
CertificateUtils.readCertificate(certificateResource),
Saml2X509CredentialType.SIGNING);
X509Certificate idpCert = CertificateUtils.readCertificate(samlConf.getidpCert());
Saml2X509Credential verificationCredential = Saml2X509Credential.verification(idpCert);
RelyingPartyRegistration rp =
RelyingPartyRegistration.withRegistrationId(samlConf.getRegistrationId())
.signingX509Credentials((c) -> c.add(signingCredential))
.assertingPartyDetails(
(details) ->
details.entityId(samlConf.getIdpIssuer())
.singleSignOnServiceLocation(
samlConf.getIdpSingleLoginUrl())
.verificationX509Credentials(
(c) -> c.add(verificationCredential))
.wantAuthnRequestsSigned(true))
.build();
return new InMemoryRelyingPartyRegistrationRepository(rp);
}
@Bean
public DaoAuthenticationProvider daoAuthenticationProvider() {
DaoAuthenticationProvider provider = new DaoAuthenticationProvider();
provider.setUserDetailsService(userDetailsService);
provider.setPasswordEncoder(passwordEncoder());
return provider;
}
/*
This following function is to grant Authorities to the OAUTH2 user from the values stored in the database.
This is required for the internal; 'hasRole()' function to give out the correct role.
*/
@Bean
@ConditionalOnProperty(
value = "security.oauth2.enabled",
havingValue = "true",
matchIfMissing = false)
GrantedAuthoritiesMapper userAuthoritiesMapper() {
return (authorities) -> {
Set<GrantedAuthority> mappedAuthorities = new HashSet<>();
authorities.forEach(
authority -> {
// Add existing OAUTH2 Authorities
mappedAuthorities.add(new SimpleGrantedAuthority(authority.getAuthority()));
// Add Authorities from database for existing user, if user is present.
if (authority instanceof OAuth2UserAuthority oauth2Auth) {
String useAsUsername =
applicationProperties
.getSecurity()
.getOauth2()
.getUseAsUsername();
Optional<User> userOpt =
userService.findByUsernameIgnoreCase(
(String) oauth2Auth.getAttributes().get(useAsUsername));
if (userOpt.isPresent()) {
User user = userOpt.get();
if (user != null) {
mappedAuthorities.add(
new SimpleGrantedAuthority(
userService.findRole(user).getAuthority()));
}
}
}
});
return mappedAuthorities;
};
}
@Bean

View File

@@ -22,6 +22,7 @@ import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.model.ApiKeyAuthenticationToken;
import stirling.software.SPDF.model.User;
@@ -111,7 +112,9 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
response.setStatus(HttpStatus.UNAUTHORIZED.value());
response.getWriter()
.write(
"Authentication required. Please provide a X-API-KEY in request header.\nThis is found in Settings -> Account Settings -> API Key\nAlternatively you can disable authentication if this is unexpected");
"Authentication required. Please provide a X-API-KEY in request header.\n"
+ "This is found in Settings -> Account Settings -> API Key\n"
+ "Alternatively you can disable authentication if this is unexpected");
return;
}
}
@@ -124,6 +127,8 @@ public class UserAuthenticationFilter extends OncePerRequestFilter {
username = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
username = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
username = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
} else if (principal instanceof String) {
username = (String) principal;
}

View File

@@ -20,6 +20,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Service;
import stirling.software.SPDF.config.interfaces.DatabaseBackupInterface;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.config.security.session.SessionPersistentRegistry;
import stirling.software.SPDF.controller.api.pipeline.UserServiceInterface;
import stirling.software.SPDF.model.AuthenticationType;
@@ -338,6 +339,10 @@ public class UserService implements UserServiceInterface {
} else if (principal instanceof OAuth2User) {
OAuth2User oAuth2User = (OAuth2User) principal;
usernameP = oAuth2User.getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
CustomSaml2AuthenticatedPrincipal saml2User =
(CustomSaml2AuthenticatedPrincipal) principal;
usernameP = saml2User.getName();
} else if (principal instanceof String) {
usernameP = (String) principal;
}

View File

@@ -51,8 +51,7 @@ public class CustomOAuth2AuthenticationFailureHandler
}
log.error("OAuth2 Authentication error: " + errorCode);
log.error("OAuth2AuthenticationException", exception);
getRedirectStrategy()
.sendRedirect(request, response, "/logout?erroroauth=" + errorCode);
getRedirectStrategy().sendRedirect(request, response, "/login?erroroauth=" + errorCode);
return;
}
log.error("Unhandled authentication exception", exception);

View File

@@ -75,6 +75,11 @@ public class CustomOAuth2AuthenticationSuccessHandler
throw new LockedException(
"Your account has been locked due to too many failed login attempts.");
}
if (userService.isUserDisabled(username)) {
getRedirectStrategy()
.sendRedirect(request, response, "/logout?userIsDisabled=true");
return;
}
if (userService.usernameExistsIgnoreCase(username)
&& userService.hasPassword(username)
&& !userService.isAuthenticationTypeByUsername(

View File

@@ -1,122 +0,0 @@
package stirling.software.SPDF.config.security.oauth2;
import java.io.IOException;
import org.springframework.security.core.Authentication;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.Provider;
import stirling.software.SPDF.model.provider.UnsupportedProviderException;
import stirling.software.SPDF.utils.UrlUtils;
@Slf4j
public class CustomOAuth2LogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
private final ApplicationProperties applicationProperties;
public CustomOAuth2LogoutSuccessHandler(ApplicationProperties applicationProperties) {
this.applicationProperties = applicationProperties;
}
@Override
public void onLogoutSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException, ServletException {
String param = "logout=true";
String registrationId = null;
String issuer = null;
String clientId = null;
if (authentication == null) {
if (request.getParameter("userIsDisabled") != null) {
response.sendRedirect(
request.getContextPath() + "/login?erroroauth=userIsDisabled");
} else {
super.onLogoutSuccess(request, response, authentication);
}
return;
}
OAUTH2 oauth = applicationProperties.getSecurity().getOauth2();
if (authentication instanceof OAuth2AuthenticationToken) {
OAuth2AuthenticationToken oauthToken = (OAuth2AuthenticationToken) authentication;
registrationId = oauthToken.getAuthorizedClientRegistrationId();
try {
Provider provider = oauth.getClient().get(registrationId);
issuer = provider.getIssuer();
clientId = provider.getClientId();
} catch (UnsupportedProviderException e) {
log.error(e.getMessage());
}
} else {
registrationId = oauth.getProvider() != null ? oauth.getProvider() : "";
issuer = oauth.getIssuer();
clientId = oauth.getClientId();
}
String errorMessage = "";
if (request.getParameter("oauth2AuthenticationErrorWeb") != null) {
param = "erroroauth=oauth2AuthenticationErrorWeb";
} else if ((errorMessage = request.getParameter("error")) != null) {
param = "error=" + sanitizeInput(errorMessage);
} else if ((errorMessage = request.getParameter("erroroauth")) != null) {
param = "erroroauth=" + sanitizeInput(errorMessage);
} else if (request.getParameter("oauth2AutoCreateDisabled") != null) {
param = "error=oauth2AutoCreateDisabled";
} else if (request.getParameter("oauth2_admin_blocked_user") != null) {
param = "erroroauth=oauth2_admin_blocked_user";
} else if (request.getParameter("userIsDisabled") != null) {
param = "erroroauth=userIsDisabled";
} else if (request.getParameter("badcredentials") != null) {
param = "error=badcredentials";
}
String redirect_url = UrlUtils.getOrigin(request) + "/login?" + param;
switch (registrationId.toLowerCase()) {
case "keycloak":
// Add Keycloak specific logout URL if needed
String logoutUrl =
issuer
+ "/protocol/openid-connect/logout"
+ "?client_id="
+ clientId
+ "&post_logout_redirect_uri="
+ response.encodeRedirectURL(redirect_url);
log.info("Redirecting to Keycloak logout URL: " + logoutUrl);
response.sendRedirect(logoutUrl);
break;
case "github":
// Add GitHub specific logout URL if needed
String githubLogoutUrl = "https://github.com/logout";
log.info("Redirecting to GitHub logout URL: " + githubLogoutUrl);
response.sendRedirect(githubLogoutUrl);
break;
case "google":
// Add Google specific logout URL if needed
// String googleLogoutUrl =
// "https://accounts.google.com/Logout?continue=https://appengine.google.com/_ah/logout?continue="
// + response.encodeRedirectURL(redirect_url);
log.info("Google does not have a specific logout URL");
// log.info("Redirecting to Google logout URL: " + googleLogoutUrl);
// response.sendRedirect(googleLogoutUrl);
// break;
default:
String defaultRedirectUrl = request.getContextPath() + "/login?" + param;
log.info("Redirecting to default logout URL: " + defaultRedirectUrl);
response.sendRedirect(defaultRedirectUrl);
break;
}
}
private String sanitizeInput(String input) {
return input.replaceAll("[^a-zA-Z0-9 ]", "");
}
}

View File

@@ -1,68 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.opensaml.saml.saml2.core.Assertion;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import lombok.extern.slf4j.Slf4j;
@Component
@Slf4j
public class ConvertResponseToAuthentication
implements Converter<ResponseToken, Saml2Authentication> {
private final Saml2AuthorityAttributeLookup saml2AuthorityAttributeLookup;
public ConvertResponseToAuthentication(
Saml2AuthorityAttributeLookup saml2AuthorityAttributeLookup) {
this.saml2AuthorityAttributeLookup = saml2AuthorityAttributeLookup;
}
@Override
public Saml2Authentication convert(ResponseToken responseToken) {
final Assertion assertion =
CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
final Map<String, List<Object>> attributes =
SamlAssertionUtils.getAssertionAttributes(assertion);
final String registrationId =
responseToken.getToken().getRelyingPartyRegistration().getRegistrationId();
final ScimSaml2AuthenticatedPrincipal principal =
new ScimSaml2AuthenticatedPrincipal(
assertion,
attributes,
saml2AuthorityAttributeLookup.getIdentityMappings(registrationId));
final Collection<? extends GrantedAuthority> assertionAuthorities =
getAssertionAuthorities(
attributes,
saml2AuthorityAttributeLookup.getAuthorityAttribute(registrationId));
return new Saml2Authentication(
principal, responseToken.getToken().getSaml2Response(), assertionAuthorities);
}
private static Collection<? extends GrantedAuthority> getAssertionAuthorities(
final Map<String, List<Object>> attributes, final String authoritiesAttributeName) {
if (attributes == null || attributes.isEmpty()) {
return Collections.emptySet();
}
final List<Object> groups = new ArrayList<>(attributes.get(authoritiesAttributeName));
return groups.stream()
.filter(String.class::isInstance)
.map(String.class::cast)
.map(String::toLowerCase)
.map(SimpleGrantedAuthority::new)
.collect(Collectors.toSet());
}
}

View File

@@ -1,51 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.io.IOException;
import org.springframework.security.authentication.BadCredentialsException;
import org.springframework.security.authentication.DisabledException;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class CustomSAMLAuthenticationFailureHandler extends SimpleUrlAuthenticationFailureHandler {
@Override
public void onAuthenticationFailure(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException exception)
throws IOException, ServletException {
if (exception instanceof BadCredentialsException) {
log.error("BadCredentialsException", exception);
getRedirectStrategy().sendRedirect(request, response, "/login?error=badcredentials");
return;
}
if (exception instanceof DisabledException) {
log.error("User is deactivated: ", exception);
getRedirectStrategy().sendRedirect(request, response, "/logout?userIsDisabled=true");
return;
}
if (exception instanceof LockedException) {
log.error("Account locked: ", exception);
getRedirectStrategy().sendRedirect(request, response, "/logout?error=locked");
return;
}
if (exception instanceof Saml2AuthenticationException) {
log.error("SAML2 Authentication error: ", exception);
getRedirectStrategy()
.sendRedirect(request, response, "/logout?error=saml2AuthenticationError");
return;
}
log.error("Unhandled authentication exception", exception);
super.onAuthenticationFailure(request, response, exception);
}
}

View File

@@ -1,108 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.io.IOException;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.userdetails.UserDetails;
import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.SavedRequest;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.config.security.LoginAttemptService;
import stirling.software.SPDF.config.security.UserService;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.OAUTH2;
import stirling.software.SPDF.model.AuthenticationType;
import stirling.software.SPDF.utils.RequestUriUtils;
@Slf4j
public class CustomSAMLAuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private LoginAttemptService loginAttemptService;
private UserService userService;
private ApplicationProperties applicationProperties;
public CustomSAMLAuthenticationSuccessHandler(
LoginAttemptService loginAttemptService,
UserService userService,
ApplicationProperties applicationProperties) {
this.loginAttemptService = loginAttemptService;
this.userService = userService;
this.applicationProperties = applicationProperties;
}
@Override
public void onAuthenticationSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws ServletException, IOException {
Object principal = authentication.getPrincipal();
String username = "";
if (principal instanceof OAuth2User) {
OAuth2User oauthUser = (OAuth2User) principal;
username = oauthUser.getName();
} else if (principal instanceof UserDetails) {
UserDetails oauthUser = (UserDetails) principal;
username = oauthUser.getUsername();
} else if (principal instanceof ScimSaml2AuthenticatedPrincipal) {
ScimSaml2AuthenticatedPrincipal samlPrincipal =
(ScimSaml2AuthenticatedPrincipal) principal;
username = samlPrincipal.getName();
}
// Get the saved request
HttpSession session = request.getSession(false);
String contextPath = request.getContextPath();
SavedRequest savedRequest =
(session != null)
? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST")
: null;
if (savedRequest != null
&& !RequestUriUtils.isStaticResource(contextPath, savedRequest.getRedirectUrl())) {
// Redirect to the original destination
super.onAuthenticationSuccess(request, response, authentication);
} else {
OAUTH2 oAuth = applicationProperties.getSecurity().getOauth2();
if (loginAttemptService.isBlocked(username)) {
if (session != null) {
session.removeAttribute("SPRING_SECURITY_SAVED_REQUEST");
}
throw new LockedException(
"Your account has been locked due to too many failed login attempts.");
}
if (userService.usernameExistsIgnoreCase(username)
&& userService.hasPassword(username)
&& !userService.isAuthenticationTypeByUsername(
username, AuthenticationType.OAUTH2)
&& oAuth.getAutoCreateUser()) {
response.sendRedirect(contextPath + "/logout?oauth2AuthenticationErrorWeb=true");
return;
}
try {
if (oAuth.getBlockRegistration()
&& !userService.usernameExistsIgnoreCase(username)) {
response.sendRedirect(contextPath + "/logout?oauth2_admin_blocked_user=true");
return;
}
if (principal instanceof OAuth2User) {
userService.processOAuth2PostLogin(username, oAuth.getAutoCreateUser());
}
response.sendRedirect(contextPath + "/");
return;
} catch (IllegalArgumentException e) {
response.sendRedirect(contextPath + "/logout?invalidUsername=true");
return;
}
}
}
}

View File

@@ -1,38 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.io.IOException;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.logout.SimpleUrlLogoutSuccessHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class SAMLLogoutSuccessHandler extends SimpleUrlLogoutSuccessHandler {
@Override
public void onLogoutSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws IOException, ServletException {
String redirectUrl = determineTargetUrl(request, response, authentication);
if (response.isCommitted()) {
log.debug("Response has already been committed. Unable to redirect to " + redirectUrl);
return;
}
getRedirectStrategy().sendRedirect(request, response, redirectUrl);
}
protected String determineTargetUrl(
HttpServletRequest request,
HttpServletResponse response,
Authentication authentication) {
// Default to the root URL
return "/";
}
}

View File

@@ -1,7 +0,0 @@
package stirling.software.SPDF.config.security.saml;
public interface Saml2AuthorityAttributeLookup {
String getAuthorityAttribute(String registrationId);
SimpleScimMappings getIdentityMappings(String registrationId);
}

View File

@@ -1,17 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import org.springframework.stereotype.Component;
@Component
public class Saml2AuthorityAttributeLookupImpl implements Saml2AuthorityAttributeLookup {
@Override
public String getAuthorityAttribute(String registrationId) {
return "authorityAttributeName";
}
@Override
public SimpleScimMappings getIdentityMappings(String registrationId) {
return new SimpleScimMappings();
}
}

View File

@@ -1,63 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.time.Instant;
import java.util.*;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.schema.*;
import org.opensaml.saml.saml2.core.Assertion;
public class SamlAssertionUtils {
public static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
Map<String, List<Object>> attributeMap = new LinkedHashMap<>();
assertion
.getAttributeStatements()
.forEach(
attributeStatement -> {
attributeStatement
.getAttributes()
.forEach(
attribute -> {
List<Object> attributeValues = new ArrayList<>();
attribute
.getAttributeValues()
.forEach(
xmlObject -> {
Object attributeValue =
getXmlObjectValue(
xmlObject);
if (attributeValue != null) {
attributeValues.add(
attributeValue);
}
});
attributeMap.put(
attribute.getName(), attributeValues);
});
});
return attributeMap;
}
public static Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject instanceof XSAny) {
return ((XSAny) xmlObject).getTextContent();
} else if (xmlObject instanceof XSString) {
return ((XSString) xmlObject).getValue();
} else if (xmlObject instanceof XSInteger) {
return ((XSInteger) xmlObject).getValue();
} else if (xmlObject instanceof XSURI) {
return ((XSURI) xmlObject).getURI();
} else if (xmlObject instanceof XSBoolean) {
return ((XSBoolean) xmlObject).getValue().getValue();
} else if (xmlObject instanceof XSDateTime) {
Instant dateTime = ((XSDateTime) xmlObject).getValue();
return (dateTime != null) ? Instant.ofEpochMilli(dateTime.toEpochMilli()) : null;
}
return null;
}
}

View File

@@ -1,42 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.security.cert.CertificateException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.model.ApplicationProperties;
@Configuration
@Slf4j
public class SamlConfig {
@Autowired ApplicationProperties applicationProperties;
@Bean
@ConditionalOnProperty(
value = "security.saml.enabled",
havingValue = "true",
matchIfMissing = false)
public RelyingPartyRegistrationRepository relyingPartyRegistrationRepository()
throws CertificateException {
RelyingPartyRegistration registration =
RelyingPartyRegistrations.fromMetadataLocation(
applicationProperties
.getSecurity()
.getSaml()
.getIdpMetadataLocation())
.entityId(applicationProperties.getSecurity().getSaml().getEntityId())
.registrationId(
applicationProperties.getSecurity().getSaml().getRegistrationId())
.build();
return new InMemoryRelyingPartyRegistrationRepository(registration);
}
}

View File

@@ -1,89 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.opensaml.saml.saml2.core.Assertion;
import org.springframework.security.core.AuthenticatedPrincipal;
import org.springframework.util.Assert;
import com.unboundid.scim2.common.types.Email;
import com.unboundid.scim2.common.types.Name;
import com.unboundid.scim2.common.types.UserResource;
public class ScimSaml2AuthenticatedPrincipal implements AuthenticatedPrincipal, Serializable {
private static final long serialVersionUID = 1L;
private final transient UserResource userResource;
public ScimSaml2AuthenticatedPrincipal(
final Assertion assertion,
final Map<String, List<Object>> attributes,
final SimpleScimMappings attributeMappings) {
Assert.notNull(assertion, "assertion cannot be null");
Assert.notNull(assertion.getSubject(), "assertion subject cannot be null");
Assert.notNull(
assertion.getSubject().getNameID(), "assertion subject NameID cannot be null");
Assert.notNull(attributes, "attributes cannot be null");
Assert.notNull(attributeMappings, "attributeMappings cannot be null");
final Name name =
new Name()
.setFamilyName(
getAttribute(
attributes,
attributeMappings,
SimpleScimMappings::getFamilyName))
.setGivenName(
getAttribute(
attributes,
attributeMappings,
SimpleScimMappings::getGivenName));
final List<Email> emails = new ArrayList<>(1);
emails.add(
new Email()
.setValue(
getAttribute(
attributes,
attributeMappings,
SimpleScimMappings::getEmail))
.setPrimary(true));
userResource =
new UserResource()
.setUserName(assertion.getSubject().getNameID().getValue())
.setName(name)
.setEmails(emails);
}
private static String getAttribute(
final Map<String, List<Object>> attributes,
final SimpleScimMappings simpleScimMappings,
final Function<SimpleScimMappings, String> attributeMapper) {
final String key = attributeMapper.apply(simpleScimMappings);
final List<Object> values = attributes.getOrDefault(key, Collections.emptyList());
return values.stream()
.filter(String.class::isInstance)
.map(String.class::cast)
.findFirst()
.orElse(null);
}
@Override
public String getName() {
return this.userResource.getUserName();
}
public UserResource getUserResource() {
return this.userResource;
}
}

View File

@@ -1,10 +0,0 @@
package stirling.software.SPDF.config.security.saml;
import lombok.Data;
@Data
public class SimpleScimMappings {
String givenName;
String familyName;
String email;
}

View File

@@ -0,0 +1,48 @@
package stirling.software.SPDF.config.security.saml2;
import java.io.ByteArrayInputStream;
import java.nio.charset.StandardCharsets;
import java.security.KeyFactory;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
import java.security.spec.PKCS8EncodedKeySpec;
import java.util.Base64;
import org.springframework.core.io.Resource;
import org.springframework.util.FileCopyUtils;
public class CertificateUtils {
public static X509Certificate readCertificate(Resource certificateResource) throws Exception {
String certificateString =
new String(
FileCopyUtils.copyToByteArray(certificateResource.getInputStream()),
StandardCharsets.UTF_8);
String certContent =
certificateString
.replace("-----BEGIN CERTIFICATE-----", "")
.replace("-----END CERTIFICATE-----", "")
.replaceAll("\\R", "")
.replaceAll("\\s+", "");
CertificateFactory cf = CertificateFactory.getInstance("X.509");
byte[] decodedCert = Base64.getDecoder().decode(certContent);
return (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(decodedCert));
}
public static RSAPrivateKey readPrivateKey(Resource privateKeyResource) throws Exception {
String privateKeyString =
new String(
FileCopyUtils.copyToByteArray(privateKeyResource.getInputStream()),
StandardCharsets.UTF_8);
String privateKeyContent =
privateKeyString
.replace("-----BEGIN PRIVATE KEY-----", "")
.replace("-----END PRIVATE KEY-----", "")
.replaceAll("\\R", "")
.replaceAll("\\s+", "");
KeyFactory kf = KeyFactory.getInstance("RSA");
byte[] decodedKey = Base64.getDecoder().decode(privateKeyContent);
return (RSAPrivateKey) kf.generatePrivate(new PKCS8EncodedKeySpec(decodedKey));
}
}

View File

@@ -0,0 +1,45 @@
package stirling.software.SPDF.config.security.saml2;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticatedPrincipal;
public class CustomSaml2AuthenticatedPrincipal
implements Saml2AuthenticatedPrincipal, Serializable {
private final String name;
private final Map<String, List<Object>> attributes;
private final String nameId;
private final List<String> sessionIndexes;
public CustomSaml2AuthenticatedPrincipal(
String name,
Map<String, List<Object>> attributes,
String nameId,
List<String> sessionIndexes) {
this.name = name;
this.attributes = attributes;
this.nameId = nameId;
this.sessionIndexes = sessionIndexes;
}
@Override
public String getName() {
return this.name;
}
@Override
public Map<String, List<Object>> getAttributes() {
return this.attributes;
}
public String getNameId() {
return this.nameId;
}
public List<String> getSessionIndexes() {
return this.sessionIndexes;
}
}

View File

@@ -0,0 +1,38 @@
package stirling.software.SPDF.config.security.saml2;
import java.io.IOException;
import org.springframework.security.authentication.ProviderNotFoundException;
import org.springframework.security.core.AuthenticationException;
import org.springframework.security.saml2.core.Saml2Error;
import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class CustomSaml2AuthenticationFailureHandler extends SimpleUrlAuthenticationFailureHandler {
@Override
public void onAuthenticationFailure(
HttpServletRequest request,
HttpServletResponse response,
AuthenticationException exception)
throws IOException, ServletException {
if (exception instanceof Saml2AuthenticationException) {
Saml2Error error = ((Saml2AuthenticationException) exception).getSaml2Error();
getRedirectStrategy()
.sendRedirect(request, response, "/login?erroroauth=" + error.getErrorCode());
} else if (exception instanceof ProviderNotFoundException) {
getRedirectStrategy()
.sendRedirect(
request,
response,
"/login?erroroauth=not_authentication_provider_found");
}
log.error("AuthenticationException: " + exception);
}
}

View File

@@ -0,0 +1,91 @@
package stirling.software.SPDF.config.security.saml2;
import java.io.IOException;
import org.springframework.security.authentication.LockedException;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
import org.springframework.security.web.savedrequest.SavedRequest;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import lombok.AllArgsConstructor;
import stirling.software.SPDF.config.security.LoginAttemptService;
import stirling.software.SPDF.config.security.UserService;
import stirling.software.SPDF.model.ApplicationProperties;
import stirling.software.SPDF.model.ApplicationProperties.Security.SAML2;
import stirling.software.SPDF.model.AuthenticationType;
import stirling.software.SPDF.utils.RequestUriUtils;
@AllArgsConstructor
public class CustomSaml2AuthenticationSuccessHandler
extends SavedRequestAwareAuthenticationSuccessHandler {
private LoginAttemptService loginAttemptService;
private ApplicationProperties applicationProperties;
private UserService userService;
@Override
public void onAuthenticationSuccess(
HttpServletRequest request, HttpServletResponse response, Authentication authentication)
throws ServletException, IOException {
Object principal = authentication.getPrincipal();
if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
String username = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
// Get the saved request
HttpSession session = request.getSession(false);
String contextPath = request.getContextPath();
SavedRequest savedRequest =
(session != null)
? (SavedRequest) session.getAttribute("SPRING_SECURITY_SAVED_REQUEST")
: null;
if (savedRequest != null
&& !RequestUriUtils.isStaticResource(
contextPath, savedRequest.getRedirectUrl())) {
// Redirect to the original destination
super.onAuthenticationSuccess(request, response, authentication);
} else {
SAML2 saml2 = applicationProperties.getSecurity().getSaml2();
if (loginAttemptService.isBlocked(username)) {
if (session != null) {
session.removeAttribute("SPRING_SECURITY_SAVED_REQUEST");
}
throw new LockedException(
"Your account has been locked due to too many failed login attempts.");
}
if (userService.usernameExistsIgnoreCase(username)
&& userService.hasPassword(username)
&& !userService.isAuthenticationTypeByUsername(
username, AuthenticationType.OAUTH2)
&& saml2.getAutoCreateUser()) {
response.sendRedirect(
contextPath + "/logout?oauth2AuthenticationErrorWeb=true");
return;
}
try {
if (saml2.getBlockRegistration()
&& !userService.usernameExistsIgnoreCase(username)) {
response.sendRedirect(
contextPath + "/login?erroroauth=oauth2_admin_blocked_user");
return;
}
userService.processOAuth2PostLogin(username, saml2.getAutoCreateUser());
response.sendRedirect(contextPath + "/");
return;
} catch (IllegalArgumentException e) {
response.sendRedirect(contextPath + "/logout?invalidUsername=true");
return;
}
}
} else {
super.onAuthenticationSuccess(request, response, authentication);
}
}
}

View File

@@ -0,0 +1,86 @@
package stirling.software.SPDF.config.security.saml2;
import java.util.*;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.schema.XSBoolean;
import org.opensaml.core.xml.schema.XSString;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AuthnStatement;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
import org.springframework.security.saml2.provider.service.authentication.Saml2Authentication;
import org.springframework.stereotype.Component;
import lombok.extern.slf4j.Slf4j;
import stirling.software.SPDF.config.security.UserService;
import stirling.software.SPDF.model.User;
@Component
@Slf4j
public class CustomSaml2ResponseAuthenticationConverter
implements Converter<ResponseToken, Saml2Authentication> {
private UserService userService;
public CustomSaml2ResponseAuthenticationConverter(UserService userService) {
this.userService = userService;
}
@Override
public Saml2Authentication convert(ResponseToken responseToken) {
// Extract the assertion from the response
Assertion assertion = responseToken.getResponse().getAssertions().get(0);
// Extract the NameID
String nameId = assertion.getSubject().getNameID().getValue();
Optional<User> userOpt = userService.findByUsernameIgnoreCase(nameId);
SimpleGrantedAuthority simpleGrantedAuthority = new SimpleGrantedAuthority("ROLE_USER");
if (userOpt.isPresent()) {
User user = userOpt.get();
if (user != null) {
simpleGrantedAuthority =
new SimpleGrantedAuthority(userService.findRole(user).getAuthority());
}
}
// Extract the SessionIndexes
List<String> sessionIndexes = new ArrayList<>();
for (AuthnStatement authnStatement : assertion.getAuthnStatements()) {
sessionIndexes.add(authnStatement.getSessionIndex());
}
// Extract the Attributes
Map<String, List<Object>> attributes = extractAttributes(assertion);
// Create the custom principal
CustomSaml2AuthenticatedPrincipal principal =
new CustomSaml2AuthenticatedPrincipal(nameId, attributes, nameId, sessionIndexes);
// Create the Saml2Authentication
return new Saml2Authentication(
principal,
responseToken.getToken().getSaml2Response(),
Collections.singletonList(simpleGrantedAuthority));
}
private Map<String, List<Object>> extractAttributes(Assertion assertion) {
Map<String, List<Object>> attributes = new HashMap<>();
for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
for (Attribute attribute : attributeStatement.getAttributes()) {
String attributeName = attribute.getName();
List<Object> values = new ArrayList<>();
for (XMLObject xmlObject : attribute.getAttributeValues()) {
log.info("BOOL: " + ((XSBoolean) xmlObject).getValue());
values.add(((XSString) xmlObject).getValue());
}
attributes.put(attributeName, values);
}
}
return attributes;
}
}

View File

@@ -16,6 +16,7 @@ import org.springframework.security.oauth2.core.user.OAuth2User;
import org.springframework.stereotype.Component;
import jakarta.transaction.Transactional;
import stirling.software.SPDF.config.security.saml2.CustomSaml2AuthenticatedPrincipal;
import stirling.software.SPDF.model.SessionEntity;
@Component
@@ -50,6 +51,8 @@ public class SessionPersistentRegistry implements SessionRegistry {
principalName = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
principalName = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
principalName = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
} else if (principal instanceof String) {
principalName = (String) principal;
}
@@ -79,6 +82,8 @@ public class SessionPersistentRegistry implements SessionRegistry {
principalName = ((UserDetails) principal).getUsername();
} else if (principal instanceof OAuth2User) {
principalName = ((OAuth2User) principal).getName();
} else if (principal instanceof CustomSaml2AuthenticatedPrincipal) {
principalName = ((CustomSaml2AuthenticatedPrincipal) principal).getName();
} else if (principal instanceof String) {
principalName = (String) principal;
}