Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import ai.giskard.security.AuthoritiesConstants;
import ai.giskard.security.ee.GiskardAuthConfigurer;
import ai.giskard.security.ee.jwt.TokenProvider;
import ai.giskard.service.ApiKeyService;
import ai.giskard.service.ee.LicenseService;
import org.springframework.beans.factory.annotation.Autowired;
import lombok.RequiredArgsConstructor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.http.HttpMethod;
Expand All @@ -29,11 +30,11 @@
@EnableWebSecurity
@EnableMethodSecurity(securedEnabled = true)
@Configuration
@RequiredArgsConstructor
public class SecurityConfiguration {
@Autowired
private TokenProvider tokenProvider;
@Autowired
private LicenseService licenseService;
private final TokenProvider tokenProvider;
private final LicenseService licenseService;
private final ApiKeyService apiKeyService;


@Bean
Expand Down Expand Up @@ -73,7 +74,7 @@ public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
antMatcher("/api/admin/**"),
antMatcher("/management/**")
).hasAuthority(AuthoritiesConstants.ADMIN)
.requestMatchers(antMatcher("/api/v2/settings/ml-worker-connect")).hasAuthority(AuthoritiesConstants.API)
.requestMatchers(antMatcher("/public-api/**")).hasAuthority(AuthoritiesConstants.API)
.requestMatchers(antMatcher("/api/**")).authenticated()
)
.sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS))
Expand All @@ -88,6 +89,6 @@ public PasswordEncoder passwordEncoder() {
}

private SecurityConfigurerAdapter<DefaultSecurityFilterChain, HttpSecurity> securityConfigurerAdapter() {
return new GiskardAuthConfigurer(licenseService, tokenProvider);
return new GiskardAuthConfigurer(licenseService, apiKeyService, tokenProvider);
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package ai.giskard.config;

import ai.giskard.domain.ApiKey;
import ai.giskard.ml.MLWorkerID;
import ai.giskard.security.ee.ApiKeyAuthFilter;
import ai.giskard.security.ee.jwt.TokenProvider;
import ai.giskard.service.ApiKeyService;
import ai.giskard.service.FileLocationService;
import ai.giskard.service.ee.FeatureFlag;
import ai.giskard.service.ee.LicenseService;
Expand Down Expand Up @@ -35,6 +38,7 @@ public class WebSocketChannelInterceptor implements ChannelInterceptor {
private final LicenseService licenseService;
private final MLWorkerWSService mlWorkerWSService;
private final FileLocationService fileLocationService;
private final ApiKeyService apiKeyService;

private boolean isValidateTokenForInternalMLWorker(List<String> internalTokenHeaders) {
if (internalTokenHeaders != null && !internalTokenHeaders.isEmpty()
Expand Down Expand Up @@ -66,23 +70,48 @@ private Message<?> processConnectMessage(Message<?> message, StompHeaderAccessor
}

if (licenseService.hasFeature(FeatureFlag.AUTH)) {
List<String> apiKeyHeaders = accessor.getNativeHeader("api-key");
List<String> jwtHeaders = accessor.getNativeHeader("jwt");
if (jwtHeaders == null || jwtHeaders.isEmpty() || !StringUtils.hasText(jwtHeaders.get(0))) {
log.warn("Missing JWT token");
throw new AccessDeniedException("Missing JWT token");
} else if (!tokenProvider.validateToken(jwtHeaders.get(0))) {
log.warn("Invalid JWT token");
throw new AccessDeniedException("Invalid JWT token");
if (jwtHeaders != null) {
// Websocket connection is coming from the UI
extractUserFromJWTtoken(accessor, jwtHeaders);
} else if (apiKeyHeaders != null) {
// Websocket connection is coming from the ML Worker
extractUserFromAPIkey(accessor, apiKeyHeaders);
}
Authentication authentication = tokenProvider.getAuthentication(jwtHeaders.get(0));
accessor.setUser(authentication);
} else {
accessor.setUser(getDummyAuthentication());
}

return message;
}

private void extractUserFromAPIkey(StompHeaderAccessor accessor, List<String> apiKeyHeaders) {
if (apiKeyHeaders.isEmpty() || !StringUtils.hasText(apiKeyHeaders.get(0))) {
log.warn("Missing API key header");
throw new AccessDeniedException("Missing API key");
}
String apiKey = apiKeyHeaders.get(0);
if (!ApiKey.doesStringLookLikeApiKey(apiKey) || apiKeyService.getKey(apiKey).isEmpty()) {
log.warn("Invalid API key");
throw new AccessDeniedException("Invalid API key");
}
Authentication authentication = ApiKeyAuthFilter.getAuthentication(apiKeyService.getKey(apiKey).orElseThrow());
accessor.setUser(authentication);
}

private void extractUserFromJWTtoken(StompHeaderAccessor accessor, List<String> jwtHeaders) {
if (jwtHeaders.isEmpty() || !StringUtils.hasText(jwtHeaders.get(0))) {
log.warn("Missing JWT token");
throw new AccessDeniedException("Missing JWT token");
} else if (!tokenProvider.validateToken(jwtHeaders.get(0))) {
log.warn("Invalid JWT token");
throw new AccessDeniedException("Invalid JWT token");
}
Authentication authentication = tokenProvider.getAuthentication(jwtHeaders.get(0));
accessor.setUser(authentication);
}

@Override
public Message<?> preSend(Message<?> message, MessageChannel channel) {
StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package ai.giskard.config;

import ai.giskard.security.ee.jwt.TokenProvider;
import ai.giskard.service.ApiKeyService;
import ai.giskard.service.FileLocationService;
import ai.giskard.service.ee.LicenseService;
import ai.giskard.service.ml.MLWorkerWSService;
Expand All @@ -17,6 +18,7 @@ public class WebSocketSecurityConfig extends AbstractSecurityWebSocketMessageBro
private final LicenseService licenseService;
private final MLWorkerWSService mlWorkerWSService;
private final FileLocationService fileLocationService;
private final ApiKeyService apiKeyService;

@Override
protected boolean sameOriginDisabled() {
Expand All @@ -26,7 +28,7 @@ protected boolean sameOriginDisabled() {
@Override
protected void customizeClientInboundChannel(ChannelRegistration registration) {
registration.interceptors(
new WebSocketChannelInterceptor(tokenProvider, licenseService, mlWorkerWSService, fileLocationService)
new WebSocketChannelInterceptor(tokenProvider, licenseService, mlWorkerWSService, fileLocationService, apiKeyService)
);
}

Expand Down
38 changes: 38 additions & 0 deletions backend/src/main/java/ai/giskard/domain/ApiKey.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package ai.giskard.domain;

import jakarta.persistence.*;
import lombok.Getter;
import lombok.NoArgsConstructor;
import org.apache.commons.lang3.RandomStringUtils;

import java.util.UUID;

@Entity(name = "api_keys")
@Getter
@NoArgsConstructor
public class ApiKey extends AbstractAuditingEntity {

public static final String PREFIX = "gsk-";
public static final int KEY_LENGTH = 32;
@Id
@GeneratedValue(strategy = GenerationType.UUID)
private UUID id;

@ManyToOne
private User user;

@Column(name = "api_key", length = KEY_LENGTH, unique = true, nullable = false)
private String key;

private String name;

public ApiKey(User user) {
this.user = user;
id = UUID.randomUUID();
key = PREFIX + RandomStringUtils.randomAlphanumeric(KEY_LENGTH - PREFIX.length()); //NOSONAR
}

public static boolean doesStringLookLikeApiKey(String str) {
return str != null && str.length() == KEY_LENGTH && str.startsWith(PREFIX);
}
}
11 changes: 11 additions & 0 deletions backend/src/main/java/ai/giskard/repository/ApiKeyRepository.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package ai.giskard.repository;

import ai.giskard.domain.ApiKey;
import org.springframework.stereotype.Repository;

import java.util.UUID;

@Repository
public interface ApiKeyRepository extends MappableJpaRepository<ApiKey, UUID> {
void deleteApiKeyByIdAndUserLogin(UUID keyId, String login);
}
50 changes: 50 additions & 0 deletions backend/src/main/java/ai/giskard/security/ee/ApiKeyAuthFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package ai.giskard.security.ee;

import ai.giskard.domain.ApiKey;
import ai.giskard.domain.Role;
import ai.giskard.security.GiskardUser;
import ai.giskard.service.ApiKeyService;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import lombok.RequiredArgsConstructor;
import org.springframework.http.HttpHeaders;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.SimpleGrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.web.filter.GenericFilterBean;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Optional;

import static ai.giskard.security.AuthoritiesConstants.API;

@RequiredArgsConstructor
public class ApiKeyAuthFilter extends GenericFilterBean {
private final ApiKeyService apiKeyService;

public static Authentication getAuthentication(ApiKey apiKey) {
Collection<SimpleGrantedAuthority> authorities = new ArrayList<>();
authorities.add(new SimpleGrantedAuthority(API));
for (Role role : apiKey.getUser().getRoles()) {
authorities.add(new SimpleGrantedAuthority(role.getName()));
}

GiskardUser principal = new GiskardUser(apiKey.getUser().getId(), apiKey.getUser().getLogin(), "", authorities);

return new UsernamePasswordAuthenticationToken(principal, apiKey.getKey(), authorities);
}

@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
String apiKey = ((HttpServletRequest) request).getHeader(HttpHeaders.AUTHORIZATION).substring(7);
Optional<ApiKey> foundKey = apiKeyService.getKey(apiKey);
foundKey.ifPresent(key -> SecurityContextHolder.getContext().setAuthentication(getAuthentication(key)));
chain.doFilter(request, response);
}
}
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
package ai.giskard.security.ee;

import ai.giskard.security.ee.jwt.TokenProvider;
import ai.giskard.service.ApiKeyService;
import ai.giskard.service.ee.LicenseService;
import lombok.RequiredArgsConstructor;
import org.springframework.security.config.annotation.SecurityConfigurerAdapter;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.web.DefaultSecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;

@RequiredArgsConstructor
public class GiskardAuthConfigurer extends SecurityConfigurerAdapter<DefaultSecurityFilterChain, HttpSecurity> {
private final LicenseService licenseService;
private final ApiKeyService apiKeyService;
private final TokenProvider tokenProvider;

public GiskardAuthConfigurer(LicenseService licenseService, TokenProvider tokenProvider) {
this.licenseService = licenseService;
this.tokenProvider = tokenProvider;
}

@Override
public void configure(HttpSecurity http) {
GiskardAuthFilter customFilter = new GiskardAuthFilter(licenseService, tokenProvider);
GiskardAuthFilter customFilter = new GiskardAuthFilter(licenseService, apiKeyService, tokenProvider);
http.addFilterBefore(customFilter, UsernamePasswordAuthenticationFilter.class);

}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
package ai.giskard.security.ee;

import ai.giskard.domain.ApiKey;
import ai.giskard.security.ee.jwt.JWTFilter;
import ai.giskard.security.ee.jwt.TokenProvider;
import ai.giskard.service.ApiKeyService;
import ai.giskard.service.ee.FeatureFlag;
import ai.giskard.service.ee.LicenseService;
import org.springframework.web.filter.GenericFilterBean;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.http.HttpHeaders;
import org.springframework.web.filter.GenericFilterBean;

import java.io.IOException;

import static ai.giskard.security.ee.jwt.JWTFilter.AUTHORIZATION_HEADER;

/**
* This filter is applied for every request and will check for authentication.
Expand All @@ -25,24 +27,35 @@ public class GiskardAuthFilter extends GenericFilterBean {

private final JWTFilter jwtFilter;
private final NoAuthFilter noAuthFilter;
private final ApiKeyAuthFilter apiKeyAuthFilter;
private final NoLicenseAuthFilter noLicenseAuthFilter;

public GiskardAuthFilter(LicenseService licenseService, TokenProvider tokenProvider) {
public GiskardAuthFilter(LicenseService licenseService, ApiKeyService apiKeyService, TokenProvider tokenProvider) {
this.licenseService = licenseService;

this.jwtFilter = new JWTFilter(tokenProvider);
this.noAuthFilter = new NoAuthFilter();
this.noLicenseAuthFilter = new NoLicenseAuthFilter();
this.apiKeyAuthFilter = new ApiKeyAuthFilter(apiKeyService);
}

@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
HttpServletRequest httpServletRequest = (HttpServletRequest) request;
if (!this.licenseService.getCurrentLicense().isActive()) {
this.noLicenseAuthFilter.doFilter(request, response, chain);
} else if (this.licenseService.hasFeature(FeatureFlag.AUTH) || httpServletRequest.getHeader(AUTHORIZATION_HEADER) != null) {
} else if (this.licenseService.hasFeature(FeatureFlag.AUTH) || httpServletRequest.getHeader(HttpHeaders.AUTHORIZATION) != null) {
String authHeader = httpServletRequest.getHeader(HttpHeaders.AUTHORIZATION);
Comment thread
andreybavt marked this conversation as resolved.
if (authHeader != null) {
// remove the "Bearer " prefix
authHeader = authHeader.substring(7);
}
// even if AUTH isn't enabled (no multi-user support), check the token for python client/ML Worker requests
this.jwtFilter.doFilter(request, response, chain);
if (ApiKey.doesStringLookLikeApiKey(authHeader)) {
this.apiKeyAuthFilter.doFilter(request, response, chain);
} else {
this.jwtFilter.doFilter(request, response, chain);
}
} else {
this.noAuthFilter.doFilter(request, response, chain);
}
Expand Down
15 changes: 7 additions & 8 deletions backend/src/main/java/ai/giskard/security/ee/jwt/JWTFilter.java
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package ai.giskard.security.ee.jwt;

import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.GenericFilterBean;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletRequest;
import org.springframework.http.HttpHeaders;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.GenericFilterBean;

import java.io.IOException;

/**
Expand All @@ -18,8 +19,6 @@
*/
public class JWTFilter extends GenericFilterBean {

public static final String AUTHORIZATION_HEADER = "Authorization";

private final TokenProvider tokenProvider;

public JWTFilter(TokenProvider tokenProvider) {
Expand All @@ -39,7 +38,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
}

private String resolveToken(HttpServletRequest request) {
String bearerToken = request.getHeader(AUTHORIZATION_HEADER);
String bearerToken = request.getHeader(HttpHeaders.AUTHORIZATION);
if (StringUtils.hasText(bearerToken) && bearerToken.startsWith("Bearer ")) {
return bearerToken.substring(7);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
package ai.giskard.security.ee.jwt;

public enum JWTTokenType {
UI, API, INVITATION
UI, INVITATION
}
Loading