Skip to content
Snippets Groups Projects
Commit ec658765 authored by Martin Lowe's avatar Martin Lowe :flag_ca:
Browse files

Merge branch 'malowe/master/3' into 'master'

Add CSRF filter + response validation to the core Quarkus lib

See merge request eclipsefdn/webdev/eclipsefdn-api-common!3
parents 82b48092 7dec6db7
No related branches found
No related tags found
1 merge request!3Add CSRF filter + response validation to the core Quarkus lib
Showing
with 719 additions and 235 deletions
package org.eclipsefoundation.core.exception;
/**
* Represents an unauthorized request with no redirect (standard UnauthorizedException gets routed
* to OIDC login page when active, which is not desired).
*
* @author Martin Lowe
*/
public class FinalUnauthorizedException extends RuntimeException {
public FinalUnauthorizedException(String message) {
super(message);
}
/** */
private static final long serialVersionUID = 1L;
}
package org.eclipsefoundation.core.helper;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import javax.annotation.PostConstruct;
import javax.inject.Singleton;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.eclipsefoundation.core.exception.FinalUnauthorizedException;
import org.eclipsefoundation.core.model.AdditionalUserData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.undertow.util.HexConverter;
/**
* Helper class for interacting with CSRF tokens within the server. Generates secure CSRF tokens and compares them to
* the copy that exists within the current session object.
*
* @author Martin Lowe
*
*/
@Singleton
public final class CSRFHelper {
public static final Logger LOGGER = LoggerFactory.getLogger(CSRFHelper.class);
public static final String CSRF_HEADER_NAME = "x-csrf-token";
@ConfigProperty(name = "security.token.salt", defaultValue = "short-salt")
String salt;
@ConfigProperty(name = "security.csrf.enabled", defaultValue = "false")
boolean csrfEnabled;
// cryptographically secure random number generator
private SecureRandom rnd;
@PostConstruct
void init() {
// create a secure Random impl using salt + timestamp bytes
rnd = new SecureRandom(Long.toString(System.currentTimeMillis()).getBytes());
}
/**
* Generate a new CSRF token that has been hardened to make it more difficult to predict.
*
* @return a cryptographically-secure CSRF token to use in a session.
*/
public String getNewCSRFToken() {
// use a random value salted with a configured static value
byte[] bytes = rnd.generateSeed(24);
String secureRnd = new String(bytes);
// create a secure random secret to embed in the user session
String preHash = secureRnd + salt;
// create new digest to hash the result
MessageDigest md;
try {
md = MessageDigest.getInstance("SHA-256");
} catch (NoSuchAlgorithmException e) {
throw new IllegalStateException("Could not find SHA-256 algorithm to encode CSRF token", e);
}
// hash the results using the message digest
byte[] array = md.digest(preHash.getBytes());
// convert back to a hex string to act as a token
return HexConverter.convertToHexString(array);
}
/**
* Compares the passed CSRF token to the token for the current user session.
*
* @param aud session data for current user
* @param passedCSRF the passed CSRF header data
* @throws FinalUnauthorizedException when CSRF token is missing in the user data, the passed header value, or does
* not match
*/
public void compareCSRF(AdditionalUserData aud, String passedCSRF) {
if (csrfEnabled) {
LOGGER.debug("Comparing following tokens:\n{}\n{}", aud == null ? null : aud.getCsrf(), passedCSRF);
if (aud == null || aud.getCsrf() == null) {
throw new FinalUnauthorizedException(
"CSRF token not generated for current request and is required, refusing request");
} else if (passedCSRF == null) {
throw new FinalUnauthorizedException("No CSRF token passed for current request, refusing request");
} else if (!passedCSRF.equals(aud.getCsrf())) {
throw new FinalUnauthorizedException("CSRF tokens did not match, refusing request");
}
}
}
}
package org.eclipsefoundation.core.model;
import javax.enterprise.context.SessionScoped;
@SessionScoped
public class AdditionalUserData {
private String csrf;
/** @return the csrf */
public String getCsrf() {
return csrf;
}
/** @param csrf the csrf to set */
public void setCsrf(String csrf) {
this.csrf = csrf;
}
}
package org.eclipsefoundation.core.request;
import java.io.IOException;
import javax.inject.Inject;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.ext.Provider;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.eclipsefoundation.core.exception.FinalUnauthorizedException;
import org.eclipsefoundation.core.helper.CSRFHelper;
import org.eclipsefoundation.core.model.AdditionalUserData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.undertow.httpcore.HttpMethodNames;
/**
* Creates a security layer in front of mutation requests to require CSRF tokens (if enabled). This layer does not
* perform the check of the token in-case there are other conditions that would rebuff the request.
*
* @author Martin Lowe
*/
@Provider
public class CSRFSecurityFilter implements ContainerRequestFilter {
public static final Logger LOGGER = LoggerFactory.getLogger(CSRFSecurityFilter.class);
@ConfigProperty(name = "security.csrf.enabled", defaultValue = "false")
boolean csrfEnabled;
@Inject
CSRFHelper csrf;
@Inject
AdditionalUserData aud;
@Override
public void filter(ContainerRequestContext requestContext) throws IOException {
if (csrfEnabled) {
// check if the HTTP method indicates a mutation
String method = requestContext.getMethod();
if (HttpMethodNames.DELETE.equals(method) || HttpMethodNames.POST.equals(method)
|| HttpMethodNames.PUT.equals(method)) {
// check csrf token presence (not value)
String token = requestContext.getHeaderString(CSRFHelper.CSRF_HEADER_NAME);
if (token == null || "".equals(token.trim())) {
throw new FinalUnauthorizedException("No CSRF token passed for mutation call, refusing connection");
} else {
// run comparison. If error, exception will be thrown
csrf.compareCSRF(aud, token);
}
}
}
}
}
/* Copyright (c) 2019 Eclipse Foundation and others.
* This program and the accompanying materials are made available
* under the terms of the Eclipse Public License 2.0
* which is available at http://www.eclipse.org/legal/epl-v20.html,
* SPDX-License-Identifier: EPL-2.0
*/
package org.eclipsefoundation.core.resource.mapper;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.ExceptionMapper;
import javax.ws.rs.ext.Provider;
import org.eclipsefoundation.core.exception.FinalUnauthorizedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Exception mapper to allow 403 to be thrown past auth barrier. Typical unauthorized exceptions
* cause redirects through OIDC layers which isn't always wanted
*
* @author Martin Lowe
*/
@Provider
public class FinalUnauthorizedMapper implements ExceptionMapper<FinalUnauthorizedException> {
private static final Logger LOGGER = LoggerFactory.getLogger(FinalUnauthorizedMapper.class);
@Override
public Response toResponse(FinalUnauthorizedException exception) {
LOGGER.error(exception.getMessage(), exception);
// return an empty response with a server error response
return Response.status(403).build();
}
}
package org.eclipsefoundation.core.response;
import java.io.IOException;
import javax.inject.Inject;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerResponseContext;
import javax.ws.rs.container.ContainerResponseFilter;
import javax.ws.rs.ext.Provider;
import org.eclipse.microprofile.config.inject.ConfigProperty;
import org.eclipsefoundation.core.helper.CSRFHelper;
import org.eclipsefoundation.core.model.AdditionalUserData;
/**
* Injects the CSRF header token into the response when enabled for a server.
*
* @author Martin Lowe
*/
@Provider
public class CSRFHeaderFilter implements ContainerResponseFilter {
@ConfigProperty(name = "security.csrf.enabled", defaultValue = "false")
boolean csrfEnabled;
@Inject
CSRFHelper csrf;
@Inject
AdditionalUserData aud;
@Override
public void filter(ContainerRequestContext requestContext, ContainerResponseContext responseContext)
throws IOException {
// only attach if CSRF is enabled for the current runtime
if (csrfEnabled) {
// generate a new token if none is yet present
if (aud.getCsrf() == null) {
aud.setCsrf(csrf.getNewCSRFToken());
}
// attach the current CSRF token as a header on the request
responseContext.getHeaders().add(CSRFHelper.CSRF_HEADER_NAME, aud.getCsrf());
}
}
}
## OAUTH CONFIG ## OAUTH CONFIG
quarkus.oauth2.enabled=true quarkus.oauth2.enabled=false
quarkus.oauth2.introspection-url=http://accounts.eclipse.org/oauth2/introspect quarkus.oauth2.introspection-url=http://accounts.eclipse.org/oauth2/introspect
eclipse.oauth.override=false eclipse.oauth.override=false
......
package org.eclipsefoundation.core.authenticated.helper;
import javax.inject.Inject;
import org.eclipsefoundation.core.exception.FinalUnauthorizedException;
import org.eclipsefoundation.core.helper.CSRFHelper;
import org.eclipsefoundation.core.model.AdditionalUserData;
import org.eclipsefoundation.core.test.AuthenticatedTestProfile;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.TestProfile;
/**
* Test CSRF functionality from the helper directly using the authentication secured test profile.
*
* @author Martin Lowe
*/
@QuarkusTest
@TestProfile(AuthenticatedTestProfile.class)
class CSRFHelperTest {
@Inject
CSRFHelper csrf;
@Test
void compareCSRF_validToken() {
// generate a token to use in test
String csrfToken = csrf.getNewCSRFToken();
// create session object with given CSRF token
AdditionalUserData aud = new AdditionalUserData();
aud.setCsrf(csrfToken);
// this should not throw as the tokens match
Assertions.assertDoesNotThrow(() -> csrf.compareCSRF(aud, csrfToken));
}
@Test
void compareCSRF_invalidToken() {
// generate a token to use in test
String csrfToken = csrf.getNewCSRFToken();
// create session object with given CSRF token
AdditionalUserData aud = new AdditionalUserData();
aud.setCsrf(csrfToken);
// this should throw as the tokens are not the same
Assertions.assertThrows(FinalUnauthorizedException.class, () -> csrf.compareCSRF(aud, "some-other-value"));
}
@Test
void compareCSRF_noSubmittedToken() {
// generate a token to use in test
String csrfToken = csrf.getNewCSRFToken();
// create session object with given CSRF token
AdditionalUserData aud = new AdditionalUserData();
aud.setCsrf(csrfToken);
// this should throw as the tokens are not the same
Assertions.assertThrows(FinalUnauthorizedException.class, () -> csrf.compareCSRF(aud, null));
// reset token value as its cleared between requests
aud.setCsrf(csrfToken);
Assertions.assertThrows(FinalUnauthorizedException.class, () -> csrf.compareCSRF(aud, ""));
}
@Test
void compareCSRF_noGeneratedToken() {
// simulates a session object with no CSRF data (no previous calls)
AdditionalUserData aud = new AdditionalUserData();
String sampleCSRF = csrf.getNewCSRFToken();
Assertions.assertThrows(FinalUnauthorizedException.class, () -> csrf.compareCSRF(aud, null));
Assertions.assertThrows(FinalUnauthorizedException.class, () -> csrf.compareCSRF(aud, ""));
Assertions.assertThrows(FinalUnauthorizedException.class, () -> csrf.compareCSRF(aud, sampleCSRF));
}
}
package org.eclipsefoundation.core.authenticated.request;
import static io.restassured.RestAssured.given;
import org.eclipsefoundation.core.helper.CSRFHelper;
import org.eclipsefoundation.core.test.AuthenticatedTestProfile;
import org.junit.jupiter.api.Test;
import io.quarkus.test.junit.QuarkusTest;
import io.quarkus.test.junit.TestProfile;
import io.restassured.filter.session.SessionFilter;
import io.restassured.response.Response;
/**
* Test the CSRF security filter which can block requests based on presence of CSRF token. This makes use of the
* authenticated test profile to reduce complexity of testing other facets of the core lib that have no interactions
* with security.
*
* @author Martin Lowe
*
*/
@QuarkusTest
@TestProfile(AuthenticatedTestProfile.class)
class CSRFSecurityFilterTest {
@Test
void validateNoToken() {
// expect rebuff as no CSRF token was passed
given().when().get("/test").then().statusCode(403);
given().when().post("/test").then().statusCode(403);
given().when().put("/test").then().statusCode(403);
given().when().delete("/test").then().statusCode(403);
}
@Test
void validateWrongToken() {
// do a good request to trigger the build of the header internally
given().when().get("/test/unguarded").then().statusCode(200);
// expect rebuff as no CSRF token was passed
given().header(CSRFHelper.CSRF_HEADER_NAME, "bad-header-value").when().get("/test").then().statusCode(403);
given().header(CSRFHelper.CSRF_HEADER_NAME, "bad-header-value").when().post("/test").then().statusCode(403);
given().header(CSRFHelper.CSRF_HEADER_NAME, "bad-header-value").when().put("/test").then().statusCode(403);
given().header(CSRFHelper.CSRF_HEADER_NAME, "bad-header-value").when().delete("/test").then().statusCode(403);
}
@Test
void validateRightCSRFToken() {
SessionFilter sessionFilter = new SessionFilter();
// do a good request to trigger the build of the header internally
Response r = given().filter(sessionFilter).when().get("/test/unguarded");
String expectedHeader = r.getHeader(CSRFHelper.CSRF_HEADER_NAME);
// expect rebuff as no CSRF token was passed
given().filter(sessionFilter).header(CSRFHelper.CSRF_HEADER_NAME, expectedHeader).when().post("/test").then()
.statusCode(200);
given().filter(sessionFilter).header(CSRFHelper.CSRF_HEADER_NAME, expectedHeader).when().delete("/test").then()
.statusCode(200);
given().filter(sessionFilter).header(CSRFHelper.CSRF_HEADER_NAME, expectedHeader).when().put("/test").then()
.statusCode(200);
given().filter(sessionFilter).header(CSRFHelper.CSRF_HEADER_NAME, expectedHeader).when().get("/test").then()
.statusCode(200);
}
}
package org.eclipsefoundation.core.config;
import javax.enterprise.context.ApplicationScoped;
import io.quarkus.security.identity.AuthenticationRequestContext;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.security.identity.SecurityIdentityAugmentor;
import io.quarkus.security.runtime.QuarkusSecurityIdentity;
import io.smallrye.mutiny.Uni;
/**
* Custom override for test classes that ignores current login state and sets
* all users as admin always. This should only ever be used in testing.
*
* @author Martin Lowe
*/
@ApplicationScoped
public class MockRoleAugmentor implements SecurityIdentityAugmentor {
@Override
public int priority() {
return 0;
}
@Override
public Uni<SecurityIdentity> augment(SecurityIdentity identity, AuthenticationRequestContext context) {
// create a new builder and copy principal, attributes, credentials and roles
// from the original
QuarkusSecurityIdentity.Builder builder = QuarkusSecurityIdentity.builder()
.setPrincipal(identity.getPrincipal()).addAttributes(identity.getAttributes())
.addCredentials(identity.getCredentials()).addRoles(identity.getRoles());
// add custom role source here
builder.addRole("marketplace_admin_access");
return context.runBlocking(builder::build);
}
}
\ No newline at end of file
/* Copyright (c) 2019 Eclipse Foundation and others.
* This program and the accompanying materials are made available
* under the terms of the Eclipse Public License 2.0
* which is available at http://www.eclipse.org/legal/epl-v20.html,
* SPDX-License-Identifier: EPL-2.0
*/
package org.eclipsefoundation.core.model;
/**
* Wraps RequestWrapper for access outside of request scope in tests.
*
* @author Martin Lowe
*/
public class RequestWrapperMock extends RequestWrapper {
}
package org.eclipsefoundation.core.test;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import io.quarkus.test.junit.QuarkusTestProfile;
/**
* Used to enable authentication profile in testing for tests that use this profile. Note that tests that use this
* profile should be grouped in a single package to ensure back to back runs to not increase test run time. More
* available on https://quarkus.io/blog/quarkus-test-profiles/.
*
* @author Martin Lowe
*/
public class AuthenticatedTestProfile implements QuarkusTestProfile {
// private immutable copy of the configs for auth state
private static final Map<String, String> CONFIG_OVERRIDES;
static {
Map<String, String> tmp = new HashMap<>();
tmp.put("quarkus.oauth2.enabled", "true");
tmp.put("security.csrf.enabled", "true");
tmp.put("security.token.salt", "sample-salt-value-64^%$6DG54$DG46%Eas6egf54s%1#g5");
CONFIG_OVERRIDES = Collections.unmodifiableMap(tmp);
}
@Override
public Map<String, String> getConfigOverrides() {
return CONFIG_OVERRIDES;
}
}
package org.eclipsefoundation.core.test;
import javax.inject.Inject;
import javax.ws.rs.DELETE;
import javax.ws.rs.GET;
import javax.ws.rs.HeaderParam;
import javax.ws.rs.POST;
import javax.ws.rs.PUT;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;
import org.eclipsefoundation.core.helper.CSRFHelper;
import org.eclipsefoundation.core.model.AdditionalUserData;
/**
* Test resource for testing basic responses and security/authentication measures.
*
* @author Martin Lowe
*
*/
@Path("/test")
@Produces(MediaType.APPLICATION_JSON)
public class TestResource {
@Inject
CSRFHelper csrf;
@Inject
AdditionalUserData aud;
/**
* Basic sample GET call that can be gated through CSRF if enabled to protect the data further.
*
* @param passedCsrf the passed CSRF header value
* @return empty ok response if CSRF is disabled or properly passed, 403 response otherwise.
*/
@GET
public Response get(@HeaderParam(value = CSRFHelper.CSRF_HEADER_NAME) String passedCsrf) {
// check CSRF manually as its a get request
csrf.compareCSRF(aud, passedCsrf);
return Response.ok().build();
}
/**
* Basic sample GET call that is not gated through CSRF. This could represent a user data endpoint, or a straight
* CSRF endpoint to trigger the header to be returned if enabled.
*
* @return empty ok response
*/
@GET
@Path("unguarded")
public Response get() {
// check CSRF manually as its a get request
return Response.ok().build();
}
/**
* Basic POST call that can be used to assist in validating filters.
*
* @return empty ok response if CSRF is disabled or properly passed, 403 response otherwise.
*/
@POST
public Response post() {
return Response.ok().build();
}
/**
* Basic PUT call that can be used to assist in validating filters.
*
* @return empty ok response if CSRF is disabled or properly passed, 403 response otherwise.
*/
@PUT
public Response put() {
return Response.ok().build();
}
/**
* Basic DELETE call that can be used to assist in validating filters.
*
* @return empty ok response if CSRF is disabled or properly passed, 403 response otherwise.
*/
@DELETE
public Response delete() {
return Response.ok().build();
}
}
...@@ -17,52 +17,60 @@ import org.eclipsefoundation.persistence.model.RDBMSQuery; ...@@ -17,52 +17,60 @@ import org.eclipsefoundation.persistence.model.RDBMSQuery;
import io.quarkus.runtime.StartupEvent; import io.quarkus.runtime.StartupEvent;
/** /**
* Interface for classes communicating with MongoDB. Assumes that reactive * Interface for classes communicating with MongoDB. Assumes that reactive stream asynchronous calls are used rather
* stream asynchronous calls are used rather than blocking methods. * than blocking methods.
* *
* @author Martin Lowe * @author Martin Lowe
*/ */
public interface PersistenceDao extends HealthCheck { public interface PersistenceDao extends HealthCheck {
/** /**
* Retrieves a list of typed results given the query passed. * Retrieves a list of typed results given the query passed.
* *
* @param q the query object for the current operation * @param q the query object for the current operation
* @return a future result set of objects of type set in query * @return a future result set of objects of type set in query
*/ */
<T extends BareNode> List<T> get(RDBMSQuery<T> q); <T extends BareNode> List<T> get(RDBMSQuery<T> q);
/** /**
* Adds a list of typed documents to the currently active database and schema, * Adds a list of typed documents to the currently active database and schema, using the query object to access the
* using the query object to access the document type. * document type.
* *
* @param <T> the type of document to post * @param <T> the type of document to post
* @param q the query object for the current operation * @param q the query object for the current operation
* @param documents the list of typed documents to add to the database instance. * @param documents the list of typed documents to add to the database instance.
* @return a future Void result indicating success on return. * @return a future Void result indicating success on return.
*/ */
<T extends BareNode> List<T> add(RDBMSQuery<T> q, List<T> documents); <T extends BareNode> List<T> add(RDBMSQuery<T> q, List<T> documents);
/** /**
* Deletes documents that match the given query. * Deletes documents that match the given query.
* *
* @param <T> the type of document that is being deleted * @param <T> the type of document that is being deleted
* @param q the query object for the current operation * @param q the query object for the current operation
* @return a future deletion result indicating whether the operation was * @return a future deletion result indicating whether the operation was successful
* successful */
*/ <T extends BareNode> void delete(RDBMSQuery<T> q);
<T extends BareNode> void delete(RDBMSQuery<T> q);
/** /**
* Counts the number of filtered results of the given document type present. * Counts the number of filtered results of the given document type present.
* *
* @param q the query object for the current operation * @param q the query object for the current operation
* @return a long result representing the number of results available for the * @return a long result representing the number of results available for the given query and docuement type.
* given query and docuement type. */
*/ Long count(RDBMSQuery<?> q);
Long count(RDBMSQuery<?> q);
default void startup(@Observes StartupEvent event) { /**
// intentionally empty * Retrieves a reference of an object to be used in operations on the server. This object is a proxy meant to help
} * build FK relationships, but can be used in other operations as well.
*
* @param id the ID of the object to retrieve
* @param type the type of object that should be retrieved
* @return a reference to the DB object if found, null otherwise
*/
<T extends BareNode> T getReference(Object id, Class<T> type);
default void startup(@Observes StartupEvent event) {
// intentionally empty
}
} }
...@@ -32,145 +32,153 @@ import org.slf4j.LoggerFactory; ...@@ -32,145 +32,153 @@ import org.slf4j.LoggerFactory;
* @author Martin Lowe * @author Martin Lowe
*/ */
public class DefaultHibernateDao implements PersistenceDao { public class DefaultHibernateDao implements PersistenceDao {
private static final Logger LOGGER = LoggerFactory.getLogger(DefaultHibernateDao.class); private static final Logger LOGGER = LoggerFactory.getLogger(DefaultHibernateDao.class);
@Inject @Inject
EntityManager em; EntityManager em;
@ConfigProperty(name = "eclipse.db.default.limit") @ConfigProperty(name = "eclipse.db.default.limit")
int defaultLimit; int defaultLimit;
@ConfigProperty(name = "eclipse.db.default.limit.max") @ConfigProperty(name = "eclipse.db.default.limit.max")
int defaultMax; int defaultMax;
@ConfigProperty(name = "eclipse.db.maintenance", defaultValue = "false") @ConfigProperty(name = "eclipse.db.maintenance", defaultValue = "false")
boolean maintenanceFlag; boolean maintenanceFlag;
@Override @Override
public <T extends BareNode> List<T> get(RDBMSQuery<T> q) { public <T extends BareNode> List<T> get(RDBMSQuery<T> q) {
if (maintenanceFlag) { if (maintenanceFlag) {
throw new MaintenanceException(); throw new MaintenanceException();
} }
if (LOGGER.isDebugEnabled()) { if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Querying DB using the following query: {}", q.getFilter().getSelectSql()); LOGGER.debug("Querying DB using the following query: {}", q.getFilter().getSelectSql());
} }
// build base query // build base query
TypedQuery<T> query = em.createQuery(q.getFilter().getSelectSql(), q.getDocType()); TypedQuery<T> query = em.createQuery(q.getFilter().getSelectSql(), q.getDocType());
// add ordinal parameters // add ordinal parameters
int ord = 1; int ord = 1;
for (Clause c : q.getFilter().getClauses()) { for (Clause c : q.getFilter().getClauses()) {
for (Object param : c.getParams()) { for (Object param : c.getParams()) {
query.setParameter(ord++, param); query.setParameter(ord++, param);
} }
} }
// check if result set should be limited // check if result set should be limited
if (q.getDTOFilter().useLimit()) { if (q.getDTOFilter().useLimit()) {
query = query.setFirstResult(getOffset(q)).setMaxResults(getLimit(q)); query = query.setFirstResult(getOffset(q)).setMaxResults(getLimit(q));
} }
// run the query // run the query
return query.getResultList(); return query.getResultList();
} }
@Transactional @Transactional
@Override @Override
public <T extends BareNode> List<T> add(RDBMSQuery<T> q, List<T> documents) { public <T extends BareNode> List<T> add(RDBMSQuery<T> q, List<T> documents) {
if (maintenanceFlag) { if (maintenanceFlag) {
throw new MaintenanceException(); throw new MaintenanceException();
} }
if (LOGGER.isDebugEnabled()) { if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Adding {} documents to DB of type {}", documents.size(), q.getDocType().getSimpleName()); LOGGER.debug("Adding {} documents to DB of type {}", documents.size(), q.getDocType().getSimpleName());
} }
// for each doc, check if update or create // for each doc, check if update or create
List<T> updatedDocs = new ArrayList<>(documents.size()); List<T> updatedDocs = new ArrayList<>(documents.size());
for (T doc : documents) { for (T doc : documents) {
T ref = doc; T ref = doc;
if (doc.getId() != null) { if (doc.getId() != null) {
// ensure this object exists before merging on it // ensure this object exists before merging on it
if (em.find(q.getDocType(), doc.getId()) != null) { if (em.find(q.getDocType(), doc.getId()) != null) {
LOGGER.debug("Merging document with existing document with id '{}'", doc.getId()); LOGGER.debug("Merging document with existing document with id '{}'", doc.getId());
ref = em.merge(doc); ref = em.merge(doc);
} else { } else {
LOGGER.debug("Persisting new document with id '{}'", doc.getId()); LOGGER.debug("Persisting new document with id '{}'", doc.getId());
em.persist(doc); em.persist(doc);
} }
} else { } else {
LOGGER.debug("Persisting new document with generated UUID ID"); LOGGER.debug("Persisting new document with generated UUID ID");
em.persist(doc); em.persist(doc);
} }
// add the ref to the output list // add the ref to the output list
updatedDocs.add(ref); updatedDocs.add(ref);
} }
return updatedDocs; return updatedDocs;
} }
@Transactional @Transactional
@Override @Override
public <T extends BareNode> void delete(RDBMSQuery<T> q) { public <T extends BareNode> void delete(RDBMSQuery<T> q) {
if (maintenanceFlag) { if (maintenanceFlag) {
throw new MaintenanceException(); throw new MaintenanceException();
} }
if (LOGGER.isDebugEnabled()) { if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Removing documents from DB using the following query: {}", q); LOGGER.debug("Removing documents from DB using the following query: {}", q);
} }
// retrieve results for the given deletion query to delete using entity manager // retrieve results for the given deletion query to delete using entity manager
List<T> results = get(q); List<T> results = get(q);
if (results.isEmpty()) { if (results.isEmpty()) {
throw new NoResultException("Could not find any documents with given filters"); throw new NoResultException("Could not find any documents with given filters");
} }
// remove all matched documents // remove all matched documents
results.forEach(em::remove); results.forEach(em::remove);
} }
@Transactional @Transactional
@Override @Override
public Long count(RDBMSQuery<?> q) { public Long count(RDBMSQuery<?> q) {
if (maintenanceFlag) { if (maintenanceFlag) {
throw new MaintenanceException(); throw new MaintenanceException();
} }
if (LOGGER.isDebugEnabled()) { if (LOGGER.isDebugEnabled()) {
LOGGER.debug("Counting documents in DB that match the following query: {}", q.getFilter().getCountSql()); LOGGER.debug("Counting documents in DB that match the following query: {}", q.getFilter().getCountSql());
} }
// build base query // build base query
TypedQuery<Long> query = em.createQuery(q.getFilter().getCountSql(), Long.class); TypedQuery<Long> query = em.createQuery(q.getFilter().getCountSql(), Long.class);
// add ordinal parameters // add ordinal parameters
int ord = 1; int ord = 1;
for (Clause c : q.getFilter().getClauses()) { for (Clause c : q.getFilter().getClauses()) {
for (Object param : c.getParams()) { for (Object param : c.getParams()) {
query.setParameter(ord++, param); query.setParameter(ord++, param);
} }
} }
return query.getSingleResult(); return query.getSingleResult();
} }
private int getLimit(RDBMSQuery<?> q) { @Override
return q.getLimit() > 0 ? Math.min(q.getLimit(), defaultMax) : defaultLimit; public <T extends BareNode> T getReference(Object id, Class<T> type) {
} if (maintenanceFlag) {
throw new MaintenanceException();
private int getOffset(RDBMSQuery<?> q) { }
// allow for manual offsetting return em.getReference(type, id);
int manualOffset = q.getManualOffset(); }
if (manualOffset > 0) {
return manualOffset; private int getLimit(RDBMSQuery<?> q) {
} return q.getLimit() > 0 ? Math.min(q.getLimit(), defaultMax) : defaultLimit;
// if first page, no offset }
if (q.getPage() <= 1) {
return 0; private int getOffset(RDBMSQuery<?> q) {
} // allow for manual offsetting
int limit = getLimit(q); int manualOffset = q.getManualOffset();
return (limit * q.getPage()) - limit; if (manualOffset > 0) {
} return manualOffset;
}
@Override // if first page, no offset
public HealthCheckResponse call() { if (q.getPage() <= 1) {
HealthCheckResponseBuilder b = HealthCheckResponse.named("DB readiness"); return 0;
if (maintenanceFlag) { }
return b.down().withData("error", "Maintenance flag is set").build(); int limit = getLimit(q);
} return (limit * q.getPage()) - limit;
return b.up().build(); }
}
@Override
public HealthCheckResponse call() {
HealthCheckResponseBuilder b = HealthCheckResponse.named("DB readiness");
if (maintenanceFlag) {
return b.down().withData("error", "Maintenance flag is set").build();
}
return b.up().build();
}
} }
...@@ -42,4 +42,9 @@ public class PlaceholderPersistenceDao implements PersistenceDao { ...@@ -42,4 +42,9 @@ public class PlaceholderPersistenceDao implements PersistenceDao {
throw new IllegalStateException("Placeholder DAO should not be used in running instances"); throw new IllegalStateException("Placeholder DAO should not be used in running instances");
} }
@Override
public <T extends BareNode> T getReference(Object id, Class<T> type) {
throw new IllegalStateException("Placeholder DAO should not be used in running instances");
}
} }
...@@ -41,7 +41,7 @@ public abstract class BareNode { ...@@ -41,7 +41,7 @@ public abstract class BareNode {
if (getClass() != obj.getClass()) { if (getClass() != obj.getClass()) {
return false; return false;
} }
NodeBase other = (NodeBase) obj; BareNode other = (BareNode) obj;
return super.equals(obj) && Objects.equals(getId(), other.getId()); return super.equals(obj) && Objects.equals(getId(), other.getId());
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment