diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareAsyncProcessor.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareAsyncProcessor.java index ffbc8ee046a..c279b187b1a 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareAsyncProcessor.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareAsyncProcessor.java @@ -37,14 +37,19 @@ import com.datastax.oss.driver.shaded.guava.common.cache.Cache; import com.datastax.oss.driver.shaded.guava.common.cache.CacheBuilder; import com.datastax.oss.driver.shaded.guava.common.collect.Iterables; +import com.datastax.oss.driver.shaded.guava.common.collect.Sets; import com.datastax.oss.protocol.internal.ProtocolConstants; +import com.google.common.base.Functions; import edu.umd.cs.findbugs.annotations.NonNull; import io.netty.util.concurrent.EventExecutor; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; import net.jcip.annotations.ThreadSafe; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -53,23 +58,73 @@ public class CqlPrepareAsyncProcessor implements RequestProcessor> { + public class CacheEntry { + + private CompletableFuture result; + private Set> futures; + private AtomicBoolean lock; + + public CacheEntry() { + + result = new CompletableFuture<>(); + futures = Sets.newHashSet(result); + lock = new AtomicBoolean(false); + } + + public void addFuture(CompletableFuture future) { + + futures.add(future); + } + + public void tryStart( + PrepareRequest request, + DefaultSession session, + InternalDriverContext context, + String sessionLogPrefix) { + // Guarantee that we'll only create one CqlPrepareHandler for each cache entry + if (lock.compareAndSet(false, true)) { + + new CqlPrepareHandler(request, session, context, sessionLogPrefix) + .handle() + .whenComplete( + (ps, t) -> { + if (t == null) { + for (CompletableFuture future : futures) { + future.complete(ps); + } + } else { + cache.invalidate(request); + for (CompletableFuture future : futures) { + future.completeExceptionally(t); + } + } + }); + } + } + + public PreparedStatement waitForResult() { + return this.result.join(); + } + } + private static final Logger LOG = LoggerFactory.getLogger(CqlPrepareAsyncProcessor.class); - protected final Cache> cache; + protected final Cache cache; public CqlPrepareAsyncProcessor() { this(Optional.empty()); } public CqlPrepareAsyncProcessor(@NonNull Optional context) { - this(CacheBuilder.newBuilder().weakValues().build(), context); + this(context, Functions.identity()); } protected CqlPrepareAsyncProcessor( - Cache> cache, - Optional context) { + Optional context, + Function, CacheBuilder> decorator) { - this.cache = cache; + CacheBuilder baseCache = CacheBuilder.newBuilder().weakValues(); + this.cache = decorator.apply(baseCache).build(); context.ifPresent( (ctx) -> { LOG.info("Adding handler to invalidate cached prepared statements on type changes"); @@ -108,11 +163,10 @@ private static boolean typeMatches(UserDefinedType oldType, DataType typeToCheck } private void onTypeChanged(TypeChangeEvent event) { - for (Map.Entry> entry : - this.cache.asMap().entrySet()) { + for (Map.Entry entry : this.cache.asMap().entrySet()) { try { - PreparedStatement stmt = entry.getValue().get(); + PreparedStatement stmt = entry.getValue().waitForResult(); if (Iterables.any( stmt.getResultSetDefinitions(), (def) -> typeMatches(event.oldType, def.getType())) || Iterables.any( @@ -141,25 +195,23 @@ public CompletionStage process( String sessionLogPrefix) { try { - CompletableFuture result = cache.getIfPresent(request); - if (result == null) { - CompletableFuture mine = new CompletableFuture<>(); - result = cache.get(request, () -> mine); - if (result == mine) { - new CqlPrepareHandler(request, session, context, sessionLogPrefix) - .handle() - .whenComplete( - (preparedStatement, error) -> { - if (error != null) { - mine.completeExceptionally(error); - cache.invalidate(request); // Make sure failure isn't cached indefinitely - } else { - mine.complete(preparedStatement); - } - }); - } - } - return result; + CompletableFuture rv = new CompletableFuture<>(); + CacheEntry entry = + cache.get( + request, + () -> { + CacheEntry newEntry = new CacheEntry(); + newEntry.addFuture(rv); + newEntry.tryStart(request, session, context, sessionLogPrefix); + return newEntry; + }); + + // We don't know whether we're dealing with a newly-created entry or one that was + // already cached so try the future insert again. We wind up duoing an extra hash op + // on the initial insert this way but that's a relatively small price to pay. + entry.addFuture(rv); + + return rv; } catch (ExecutionException e) { return CompletableFutures.failedFuture(e.getCause()); } @@ -170,7 +222,7 @@ public CompletionStage newFailure(RuntimeException error) { return CompletableFutures.failedFuture(error); } - public Cache> getCache() { + public Cache getCache() { return cache; } } diff --git a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareSyncProcessor.java b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareSyncProcessor.java index 0896df07140..cd4579e3910 100644 --- a/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareSyncProcessor.java +++ b/core/src/main/java/com/datastax/oss/driver/internal/core/cql/CqlPrepareSyncProcessor.java @@ -27,7 +27,6 @@ import com.datastax.oss.driver.internal.core.util.concurrent.BlockingOperation; import com.datastax.oss.driver.internal.core.util.concurrent.CompletableFutures; import com.datastax.oss.driver.shaded.guava.common.cache.Cache; -import java.util.concurrent.CompletableFuture; import net.jcip.annotations.ThreadSafe; @ThreadSafe @@ -62,7 +61,7 @@ public PreparedStatement process( asyncProcessor.process(request, session, context, sessionLogPrefix)); } - public Cache> getCache() { + public Cache getCache() { return asyncProcessor.getCache(); } diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/PreparedStatementCachingIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/PreparedStatementCachingIT.java index 05ac3bd0e92..473f4a18ca1 100644 --- a/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/PreparedStatementCachingIT.java +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/PreparedStatementCachingIT.java @@ -24,8 +24,6 @@ import com.datastax.oss.driver.api.core.config.DefaultDriverOption; import com.datastax.oss.driver.api.core.config.DriverConfigLoader; import com.datastax.oss.driver.api.core.context.DriverContext; -import com.datastax.oss.driver.api.core.cql.PrepareRequest; -import com.datastax.oss.driver.api.core.cql.PreparedStatement; import com.datastax.oss.driver.api.core.metrics.DefaultSessionMetric; import com.datastax.oss.driver.api.core.session.ProgrammaticArguments; import com.datastax.oss.driver.api.core.session.SessionBuilder; @@ -41,7 +39,6 @@ import com.datastax.oss.driver.internal.core.session.BuiltInRequestProcessors; import com.datastax.oss.driver.internal.core.session.RequestProcessor; import com.datastax.oss.driver.internal.core.session.RequestProcessorRegistry; -import com.datastax.oss.driver.shaded.guava.common.cache.CacheBuilder; import com.datastax.oss.driver.shaded.guava.common.cache.RemovalListener; import com.datastax.oss.driver.shaded.guava.common.util.concurrent.Uninterruptibles; import com.google.common.collect.ImmutableList; @@ -54,7 +51,6 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -119,12 +115,12 @@ private static class TestCqlPrepareAsyncProcessor extends CqlPrepareAsyncProcess private static final Logger LOG = LoggerFactory.getLogger(PreparedStatementCachingIT.TestCqlPrepareAsyncProcessor.class); - private static RemovalListener> - buildCacheRemoveCallback(@NonNull Optional context) { + private static RemovalListener buildCacheRemoveCallback( + @NonNull Optional context) { return (evt) -> { try { - CompletableFuture future = evt.getValue(); - ByteBuffer queryId = Uninterruptibles.getUninterruptibly(future).getId(); + CacheEntry entry = (CacheEntry) evt.getValue(); + ByteBuffer queryId = entry.waitForResult().getId(); context.ifPresent( ctx -> ctx.getEventBus().fire(new PreparedStatementRemovalEvent(queryId))); } catch (Exception e) { @@ -136,9 +132,7 @@ private static class TestCqlPrepareAsyncProcessor extends CqlPrepareAsyncProcess public TestCqlPrepareAsyncProcessor(@NonNull Optional context) { // Default CqlPrepareAsyncProcessor uses weak values here as well. We avoid doing so // to prevent cache entries from unexpectedly disappearing mid-test. - super( - CacheBuilder.newBuilder().removalListener(buildCacheRemoveCallback(context)).build(), - context); + super(context, builder -> builder.removalListener(buildCacheRemoveCallback(context))); } } diff --git a/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/PreparedStatementCancellationIT.java b/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/PreparedStatementCancellationIT.java new file mode 100644 index 00000000000..e2b7b28b178 --- /dev/null +++ b/integration-tests/src/test/java/com/datastax/oss/driver/core/cql/PreparedStatementCancellationIT.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.datastax.oss.driver.core.cql; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.fail; + +import com.datastax.oss.driver.api.core.CqlSession; +import com.datastax.oss.driver.api.core.cql.PreparedStatement; +import com.datastax.oss.driver.api.testinfra.ccm.CustomCcmRule; +import com.datastax.oss.driver.api.testinfra.session.SessionRule; +import com.datastax.oss.driver.api.testinfra.session.SessionUtils; +import com.datastax.oss.driver.categories.IsolatedTests; +import com.datastax.oss.driver.internal.core.context.DefaultDriverContext; +import com.datastax.oss.driver.internal.core.cql.CqlPrepareAsyncProcessor; +import com.datastax.oss.driver.shaded.guava.common.base.Predicates; +import com.datastax.oss.driver.shaded.guava.common.cache.Cache; +import com.datastax.oss.driver.shaded.guava.common.collect.Iterables; +import java.util.concurrent.CompletableFuture; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.rules.RuleChain; +import org.junit.rules.TestRule; + +@Category(IsolatedTests.class) +public class PreparedStatementCancellationIT { + + private CustomCcmRule ccmRule = CustomCcmRule.builder().build(); + + private SessionRule sessionRule = SessionRule.builder(ccmRule).build(); + + @Rule public TestRule chain = RuleChain.outerRule(ccmRule).around(sessionRule); + + @Before + public void setup() { + + CqlSession session = SessionUtils.newSession(ccmRule, sessionRule.keyspace()); + session.execute("DROP TABLE IF EXISTS test_table_1"); + session.execute("CREATE TABLE test_table_1 (k int primary key, v int)"); + session.execute("INSERT INTO test_table_1 (k,v) VALUES (1, 100)"); + session.execute("INSERT INTO test_table_1 (k,v) VALUES (2, 200)"); + session.execute("INSERT INTO test_table_1 (k,v) VALUES (3, 300)"); + session.close(); + } + + @After + public void teardown() { + + CqlSession session = SessionUtils.newSession(ccmRule, sessionRule.keyspace()); + session.execute("DROP TABLE test_table_1"); + session.close(); + } + + private CompletableFuture toCompletableFuture(CqlSession session, String cql) { + + return session.prepareAsync(cql).toCompletableFuture(); + } + + private CqlPrepareAsyncProcessor findProcessor(CqlSession session) { + + DefaultDriverContext context = (DefaultDriverContext) session.getContext(); + return (CqlPrepareAsyncProcessor) + Iterables.find( + context.getRequestProcessorRegistry().getProcessors(), + Predicates.instanceOf(CqlPrepareAsyncProcessor.class)); + } + + @Test + public void should_cache_valid_cql() throws Exception { + + CqlSession session = SessionUtils.newSession(ccmRule, sessionRule.keyspace()); + CqlPrepareAsyncProcessor processor = findProcessor(session); + Cache cache = processor.getCache(); + assertThat(cache.size()).isEqualTo(0); + + // Make multiple CompletableFuture requests for the specified CQL, then wait until + // the cached request finishes and confirm that all futures got the same values + String cql = "select v from test_table_1 where k = ?"; + CompletableFuture cf1 = toCompletableFuture(session, cql); + CompletableFuture cf2 = toCompletableFuture(session, cql); + assertThat(cache.size()).isEqualTo(1); + + CqlPrepareAsyncProcessor.CacheEntry entry = + (CqlPrepareAsyncProcessor.CacheEntry) Iterables.get(cache.asMap().values(), 0); + PreparedStatement stmt = entry.waitForResult(); + + // Waiting for results on the cache entry should prevent us from continuing until the CacheEntry + // future returns + assertThat(cf1.isDone()).isTrue(); + assertThat(cf2.isDone()).isTrue(); + + assertThat(cf1.join()).isEqualTo(stmt); + assertThat(cf2.join()).isEqualTo(stmt); + } + + @Test + public void should_not_cache_invalid_cql() throws Exception { + + CqlSession session = SessionUtils.newSession(ccmRule, sessionRule.keyspace()); + CqlPrepareAsyncProcessor processor = findProcessor(session); + Cache cache = processor.getCache(); + assertThat(cache.size()).isEqualTo(0); + + // Verify that we get the CompletableFuture even if the CQL is invalid but that nothing is + // cached + String cql = "select v fromfrom test_table_1 where k = ?"; + CompletableFuture cf = toCompletableFuture(session, cql); + + // join() here should throw exceptions due to the invalid syntax... for purposes of this test we + // can ignore this + try { + cf.join(); + fail(); + } catch (Exception e) { + } + + assertThat(cache.size()).isEqualTo(0); + } + + @Test + public void should_not_affect_cache_if_returned_futures_are_cancelled() throws Exception { + + CqlSession session = SessionUtils.newSession(ccmRule, sessionRule.keyspace()); + CqlPrepareAsyncProcessor processor = findProcessor(session); + Cache cache = processor.getCache(); + assertThat(cache.size()).isEqualTo(0); + + String cql = "select v from test_table_1 where k = ?"; + CompletableFuture cf = toCompletableFuture(session, cql); + + assertThat(cf.isCancelled()).isFalse(); + assertThat(cf.cancel(false)).isTrue(); + assertThat(cf.isCancelled()).isTrue(); + assertThat(cf.isCompletedExceptionally()).isTrue(); + + // Confirm that cancelling the CompletableFuture returned to the user does _not_ cancel the + // future used within the cache. CacheEntry very deliberately doesn't maintain a reference + // to it's contained CompletableFuture so we have to get at this by secondary effects. + assertThat(cache.size()).isEqualTo(1); + CqlPrepareAsyncProcessor.CacheEntry entry = + (CqlPrepareAsyncProcessor.CacheEntry) Iterables.get(cache.asMap().values(), 0); + + PreparedStatement rv = entry.waitForResult(); + assertThat(rv).isNotNull(); + assertThat(rv.getQuery()).isEqualTo(cql); + assertThat(cache.size()).isEqualTo(1); + } +}