diff --git a/service/src/main/java/eu/gaiax/difs/aas/cache/DataCache.java b/service/src/main/java/eu/gaiax/difs/aas/cache/DataCache.java new file mode 100644 index 0000000000000000000000000000000000000000..d6621861d018270db55dd937b0321c22e38152bd --- /dev/null +++ b/service/src/main/java/eu/gaiax/difs/aas/cache/DataCache.java @@ -0,0 +1,16 @@ +package eu.gaiax.difs.aas.cache; + +import java.util.Collection; +import java.util.Map; + +public interface DataCache<K, V> { + + void clean(); + V get(K key); + Map<K, V> getAll(Collection<? extends K> keys); + void put(K key, V value); + void putAll(Map<? extends K, ? extends V> entries); + void remove(K key); + long estimatedSize(); + +} diff --git a/service/src/main/java/eu/gaiax/difs/aas/cache/TriConsumer.java b/service/src/main/java/eu/gaiax/difs/aas/cache/TriConsumer.java new file mode 100644 index 0000000000000000000000000000000000000000..52f0374adf1297414891730c8fd69cb16d906479 --- /dev/null +++ b/service/src/main/java/eu/gaiax/difs/aas/cache/TriConsumer.java @@ -0,0 +1,13 @@ +package eu.gaiax.difs.aas.cache; + +@FunctionalInterface +public interface TriConsumer<K, V> { + + void apply(K key, V value, boolean replaced); + + //default <V> TriFunction<A, B, C, V> andThen( + // Function<? super R, ? extends V> after) { + // Objects.requireNonNull(after); + // return (A a, B b, C c) -> after.apply(apply(a, b, c)); + //} +} \ No newline at end of file diff --git a/service/src/main/java/eu/gaiax/difs/aas/cache/caffeine/CaffeineDataCache.java b/service/src/main/java/eu/gaiax/difs/aas/cache/caffeine/CaffeineDataCache.java new file mode 100644 index 0000000000000000000000000000000000000000..d9c8a410354760e0c0fe25c1b7ee11b6ba59e5db --- /dev/null +++ b/service/src/main/java/eu/gaiax/difs/aas/cache/caffeine/CaffeineDataCache.java @@ -0,0 +1,85 @@ +package eu.gaiax.difs.aas.cache.caffeine; + +import java.time.Duration; +import java.util.Collection; +import java.util.Map; + +import org.checkerframework.checker.nullness.qual.Nullable; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.RemovalCause; +import com.github.benmanes.caffeine.cache.RemovalListener; + +import eu.gaiax.difs.aas.cache.DataCache; +import eu.gaiax.difs.aas.cache.TriConsumer; + +public class CaffeineDataCache<K, V> implements DataCache<K, V> { + + private Cache<K, V> dataCache; + + @SuppressWarnings("unchecked") + public CaffeineDataCache(int cacheSize, Duration ttl, TriConsumer<K, V> synchronizer) { + Caffeine<K, V> cache = (Caffeine<K, V>) Caffeine.newBuilder().expireAfterAccess(ttl); + if (cacheSize > 0) { + cache = cache.maximumSize(cacheSize); + } + if (synchronizer != null) { + cache = cache.removalListener(new DataListener<>(synchronizer)); + } + dataCache = cache.build(); + } + + + @Override + public void clean() { + dataCache.cleanUp(); + } + + @Override + public V get(K key) { + return dataCache.getIfPresent(key); + } + + @Override + public Map<K, V> getAll(Collection<? extends K> keys) { + return dataCache.getAllPresent(keys); + } + + @Override + public void put(K key, V value) { + dataCache.put(key, value); + } + + @Override + public void putAll(Map<? extends K, ? extends V> entries) { + dataCache.putAll(entries); + } + + @Override + public void remove(K key) { + dataCache.invalidate(key); + } + + @Override + public long estimatedSize() { + return dataCache.estimatedSize(); + } + + + private static class DataListener<K, V> implements RemovalListener<K, V> { + + private TriConsumer<K, V> synchronizer; + + DataListener(TriConsumer<K, V> synchronizer) { + this.synchronizer = synchronizer; + } + + @Override + public void onRemoval(@Nullable K key, @Nullable V value, RemovalCause cause) { + synchronizer.apply(key, value, cause == RemovalCause.REPLACED); + } + + } + +} diff --git a/service/src/main/java/eu/gaiax/difs/aas/cache/hazelcast/HazelcastDataCache.java b/service/src/main/java/eu/gaiax/difs/aas/cache/hazelcast/HazelcastDataCache.java new file mode 100644 index 0000000000000000000000000000000000000000..93e956d977df3ffc5f0799712c224f675e417db3 --- /dev/null +++ b/service/src/main/java/eu/gaiax/difs/aas/cache/hazelcast/HazelcastDataCache.java @@ -0,0 +1,52 @@ +package eu.gaiax.difs.aas.cache.hazelcast; + +import java.util.Collection; +import java.util.Map; + +import eu.gaiax.difs.aas.cache.DataCache; + +public class HazelcastDataCache<K, V> implements DataCache<K, V> { + + @Override + public void clean() { + // TODO Auto-generated method stub + + } + + @Override + public V get(K key) { + // TODO Auto-generated method stub + return null; + } + + @Override + public Map<K, V> getAll(Collection<? extends K> keys) { + // TODO Auto-generated method stub + return null; + } + + @Override + public void put(K key, V value) { + // TODO Auto-generated method stub + + } + + @Override + public void putAll(Map<? extends K, ? extends V> map) { + // TODO Auto-generated method stub + + } + + @Override + public void remove(K key) { + // TODO Auto-generated method stub + + } + + @Override + public long estimatedSize() { + // TODO Auto-generated method stub + return 0; + } + +} diff --git a/service/src/main/java/eu/gaiax/difs/aas/service/SsiAuthorizationService.java b/service/src/main/java/eu/gaiax/difs/aas/service/SsiAuthorizationService.java index d1d4ed0fa43522426111a20bab8e8e83034835e4..b53c54d78c4efc1c469a914e6c6e69cd832ab06c 100644 --- a/service/src/main/java/eu/gaiax/difs/aas/service/SsiAuthorizationService.java +++ b/service/src/main/java/eu/gaiax/difs/aas/service/SsiAuthorizationService.java @@ -3,7 +3,6 @@ package eu.gaiax.difs.aas.service; import java.time.Duration; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import javax.annotation.PostConstruct; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -17,61 +16,62 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService; import org.springframework.util.Assert; -import com.github.benmanes.caffeine.cache.Cache; -import com.github.benmanes.caffeine.cache.Caffeine; -import com.github.benmanes.caffeine.cache.RemovalCause; -import com.github.benmanes.caffeine.cache.RemovalListener; +import eu.gaiax.difs.aas.cache.DataCache; +import eu.gaiax.difs.aas.cache.caffeine.CaffeineDataCache; -public class SsiAuthorizationService implements OAuth2AuthorizationService, RemovalListener<String, OAuth2Authorization> { +public class SsiAuthorizationService implements OAuth2AuthorizationService { private static final Logger log = LoggerFactory.getLogger(SsiAuthorizationService.class); - private final int cacheSize; - private final Duration ttl; - - private Cache<String, OAuth2Authorization> authorizations; - + private final DataCache<String, OAuth2Authorization> authorizations; private final Map<String, String> codes; public SsiAuthorizationService(int cacheSize, Duration ttl) { - this.cacheSize = cacheSize; - this.ttl = ttl; + this.authorizations = new CaffeineDataCache<>(cacheSize, ttl, this::synchronize); this.codes = new ConcurrentHashMap<>(); } - @PostConstruct - public void init() { - Caffeine<Object, Object> cache = Caffeine.newBuilder().expireAfterAccess(ttl); - if (cacheSize > 0) { - cache = cache.maximumSize(cacheSize); - } - authorizations = cache.removalListener(this).build(); + public void synchronize(String key, OAuth2Authorization value, boolean replaced) { + boolean removed = false; + log.debug("synchronize; got key: {}, authorization: {}, replaced: {}", key, printAuth(value), replaced); + if (replaced) { + // + } else { + OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = value.getToken(OAuth2AuthorizationCode.class); + if (authorizationCode != null) { + removed = codes.remove(authorizationCode.getToken().getTokenValue()) != null; + } + } + log.debug("synchronize.exit; removed: {}", removed); } - + @Override public void save(OAuth2Authorization authorization) { Assert.notNull(authorization, "authorization cannot be null"); + log.debug("save.enter; got authorization: {}", printAuth(authorization)); OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization.getToken(OAuth2AuthorizationCode.class); if (authorizationCode != null) { codes.put(authorizationCode.getToken().getTokenValue(), authorization.getId()); } this.authorizations.put(authorization.getId(), authorization); + log.debug("save.exit; authorizations: {}, codes: {}", authorizations.estimatedSize(), codes.size()); } @Override public void remove(OAuth2Authorization authorization) { Assert.notNull(authorization, "authorization cannot be null"); - log.debug("remove.enter; got authorization: {}", authorization); - this.authorizations.invalidate(authorization.getId()); + log.debug("remove.enter; got authorization: {}", printAuth(authorization)); + this.authorizations.remove(authorization.getId()); + log.debug("remove.exit; authorizations: {}, codes: {}", authorizations.estimatedSize(), codes.size()); } @Nullable @Override public OAuth2Authorization findById(String id) { Assert.hasText(id, "id cannot be empty"); - return this.authorizations.getIfPresent(id); + return this.authorizations.get(id); } @Nullable @@ -84,7 +84,7 @@ public class SsiAuthorizationService implements OAuth2AuthorizationService, Remo String id = codes.get(token); if (id != null) { OAuth2Authorization authorization = findById(id); - log.debug("findByToken.exit; returning codes: {}", authorization); + log.debug("findByToken.exit; returning auth from codes: {}", printAuth(authorization)); return authorization; } } @@ -106,20 +106,25 @@ public class SsiAuthorizationService implements OAuth2AuthorizationService, Remo //} log.info("findByToken.exit; no authorization found for token: {}, type: {}; authorizations size: {}, codes size: {}", token, tkType, authorizations.estimatedSize(), codes.size()); + if (token.startsWith("${")) { + try { + throw new Exception("debug"); + } catch (Exception ex) { + ex.printStackTrace(); + } + } return null; } - - @Override - public void onRemoval(@org.checkerframework.checker.nullness.qual.Nullable String key, - @org.checkerframework.checker.nullness.qual.Nullable OAuth2Authorization value, RemovalCause cause) { - boolean removed = false; - log.debug("onRemoval.enter; got key: {}, authorization: {}, cause: {}", key, value, cause); - OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = value.getToken(OAuth2AuthorizationCode.class); - if (authorizationCode != null) { - removed = codes.remove(authorizationCode.getToken().getTokenValue()) != null; - } - log.debug("onRemoval.exit; removed: {}", removed); + + private String printAuth(OAuth2Authorization authorization) { + return "[id: " + authorization.getId() + ", principalName: " + authorization.getPrincipalName() + + ", registeredClientId: " + authorization.getRegisteredClientId() + ", accessToken: " + (authorization.getAccessToken() == null ? + null : authorization.getAccessToken().getToken().getTokenType().getValue() + ":" + authorization.getAccessToken().getToken().getTokenValue()) + + ", attributes: " + authorization.getAttributes() + ", authorizationGrantType: " + (authorization.getAuthorizationGrantType() == null ? + null : authorization.getAuthorizationGrantType().getValue()) + + ", refreshToken: " + authorization.getRefreshToken() + "]"; } + /* private static boolean isComplete(OAuth2Authorization authorization) { return authorization.getAccessToken() != null; diff --git a/service/src/main/java/eu/gaiax/difs/aas/service/SsiBrokerService.java b/service/src/main/java/eu/gaiax/difs/aas/service/SsiBrokerService.java index bc9a548230b1b17e0f53cf6e6e05afd77b8eeb22..3b3501d1216458798a5ce7d49805e0cb0e02a11a 100644 --- a/service/src/main/java/eu/gaiax/difs/aas/service/SsiBrokerService.java +++ b/service/src/main/java/eu/gaiax/difs/aas/service/SsiBrokerService.java @@ -174,7 +174,7 @@ public class SsiBrokerService extends SsiClaimsService { String error = (String) response.get(OAuth2ParameterNames.ERROR); if (error == null) { - Collection<String> requestedScopes = (Collection<String>) claimsCache.getIfPresent(requestId).get(OAuth2ParameterNames.SCOPE); + Collection<String> requestedScopes = (Collection<String>) claimsCache.get(requestId).get(OAuth2ParameterNames.SCOPE); Set<String> requestedClaims = scopeProperties.getScopes().entrySet().stream() .filter(e -> requestedScopes.contains(e.getKey())).flatMap(e -> e.getValue().stream()).collect(Collectors.toSet()); // special handling for auth_time.. @@ -187,7 +187,7 @@ public class SsiBrokerService extends SsiClaimsService { try { verifier.verify(JWTClaimsSet.parse(response), null); } catch(ParseException | BadJWTException ex) { - claimsCache.invalidate(requestId); + claimsCache.remove(requestId); throw new ResponseStatusException(HttpStatus.BAD_REQUEST, "invalid_response: " + ex.getMessage()); } @@ -225,7 +225,7 @@ public class SsiBrokerService extends SsiClaimsService { public boolean setAdditionalParameters(String requestId, Map<String, Object> additionalParams) { log.debug("setAdditionalParameters.enter; got request: {} params: {}", requestId, additionalParams); boolean result = true; - Map<String, Object> request = claimsCache.getIfPresent(requestId); + Map<String, Object> request = claimsCache.get(requestId); if (request == null) { // throw error? result = false; @@ -241,7 +241,7 @@ public class SsiBrokerService extends SsiClaimsService { private boolean addAuthData(String requestId, Map<String, Object> data) { log.debug("addAuthData.enter; got request: {} claims size: {}", requestId, data.size()); boolean result = true; - Map<String, Object> request = claimsCache.getIfPresent(requestId); + Map<String, Object> request = claimsCache.get(requestId); if (request == null) { // throw error? result = false; @@ -254,7 +254,7 @@ public class SsiBrokerService extends SsiClaimsService { } private Boolean isValidRequest(String requestId) { - Map<String, Object> request = claimsCache.getIfPresent(requestId); + Map<String, Object> request = claimsCache.get(requestId); if (request == null) { return null; } @@ -306,7 +306,7 @@ public class SsiBrokerService extends SsiClaimsService { } public Map<String, Object> getUserClaims(String requestId, boolean required) { - Map<String, Object> userClaims = claimsCache.getIfPresent(requestId); + Map<String, Object> userClaims = claimsCache.get(requestId); if (userClaims == null) { log.warn("getUserClaims; no claims found for request: {}, required: {}", requestId, required); if (required) { @@ -325,7 +325,7 @@ public class SsiBrokerService extends SsiClaimsService { } public Set<String> getUserScopes(String requestId) { - Map<String, Object> userClaims = claimsCache.getIfPresent(requestId); + Map<String, Object> userClaims = claimsCache.get(requestId); if (userClaims == null) { log.warn("getUserScopes; no claims found for request: {}", requestId); throw new OAuth2AuthenticationException(INVALID_REQUEST); @@ -342,7 +342,7 @@ public class SsiBrokerService extends SsiClaimsService { } public Map<String, Object> getAdditionalParameters(String requestId) { - Map<String, Object> userClaims = claimsCache.getIfPresent(requestId); + Map<String, Object> userClaims = claimsCache.get(requestId); if (userClaims == null) { // log it.. return null; diff --git a/service/src/main/java/eu/gaiax/difs/aas/service/SsiClaimsService.java b/service/src/main/java/eu/gaiax/difs/aas/service/SsiClaimsService.java index 251876e79001c44df62bae77d5e1fd2275feb8d6..abfa062ea1797642c3d8f9260f2afdb3f18d489f 100644 --- a/service/src/main/java/eu/gaiax/difs/aas/service/SsiClaimsService.java +++ b/service/src/main/java/eu/gaiax/difs/aas/service/SsiClaimsService.java @@ -13,9 +13,8 @@ import javax.annotation.PostConstruct; import org.springframework.beans.factory.annotation.Value; import org.springframework.security.oauth2.core.OAuth2AuthenticationException; -import com.github.benmanes.caffeine.cache.Cache; -import com.github.benmanes.caffeine.cache.Caffeine; - +import eu.gaiax.difs.aas.cache.DataCache; +import eu.gaiax.difs.aas.cache.caffeine.CaffeineDataCache; import eu.gaiax.difs.aas.client.TrustServiceClient; import eu.gaiax.difs.aas.generated.model.AccessRequestStatusDto; @@ -32,7 +31,7 @@ public abstract class SsiClaimsService { protected final TrustServiceClient trustServiceClient; - protected Cache<String, Map<String, Object>> claimsCache; + protected DataCache<String, Map<String, Object>> claimsCache; public SsiClaimsService(TrustServiceClient trustServiceClient) { this.trustServiceClient = trustServiceClient; @@ -40,11 +39,7 @@ public abstract class SsiClaimsService { @PostConstruct public void init() { - Caffeine<Object, Object> cache = Caffeine.newBuilder().expireAfterAccess(ttl); - if (cacheSize > 0) { - cache = cache.maximumSize(cacheSize); - } - claimsCache = cache.build(); + claimsCache = new CaffeineDataCache<>(cacheSize, ttl, null); } protected Map<String, Object> loadTrustedClaims(String policy, String requestId) { diff --git a/service/src/main/java/eu/gaiax/difs/aas/service/SsiIatService.java b/service/src/main/java/eu/gaiax/difs/aas/service/SsiIatService.java index ccfb7775327d3ce40c9038d816d02f1a448b3a6a..0d23ce3109b7dbe36e1dfecab0652a4a165db466 100644 --- a/service/src/main/java/eu/gaiax/difs/aas/service/SsiIatService.java +++ b/service/src/main/java/eu/gaiax/difs/aas/service/SsiIatService.java @@ -89,7 +89,7 @@ public class SsiIatService extends SsiClaimsService { public AccessResponseDto getIatProofResult(String requestId) { log.debug("getIatProofResult.enter; got request: {}", requestId); AccessResponseDto accessResponseDto; - Map<String, Object> iatClaims = claimsCache.getIfPresent(requestId); + Map<String, Object> iatClaims = claimsCache.get(requestId); if (iatClaims == null) { iatClaims = loadTrustedClaims(GET_IAT_PROOF_RESULT, requestId); //addAuthData(requestId, iatClaims); @@ -106,7 +106,7 @@ public class SsiIatService extends SsiClaimsService { public Map<String, Object> getIatProofClaims(String subjectId, String scope, Map<String, Object> params) { log.debug("getIatProofClaims.enter; got params: {}", params); - Map<String, Object> iatClaims = claimsCache.getIfPresent(subjectId); + Map<String, Object> iatClaims = claimsCache.get(subjectId); if (iatClaims == null) { List<String> scopes = Arrays.asList(scope.split(" ")); params.put(OAuth2ParameterNames.SCOPE, scopes); @@ -130,7 +130,7 @@ public class SsiIatService extends SsiClaimsService { String entity = null; String subject = null; Collection<String> scopes = null; - Map<String, Object> iatRequest = claimsCache.getIfPresent(requestId); + Map<String, Object> iatRequest = claimsCache.get(requestId); if (iatRequest == null) { log.info("mapToIatAccessResponse; no data found for requestId: {}", requestId); } else {