Skip to content

JAVA-3055 CqlPrepareAsyncProcessor must handle cancellations of the returned Future #2003

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,19 @@
import com.datastax.oss.driver.shaded.guava.common.cache.Cache;
import com.datastax.oss.driver.shaded.guava.common.cache.CacheBuilder;
import com.datastax.oss.driver.shaded.guava.common.collect.Iterables;
import com.datastax.oss.driver.shaded.guava.common.collect.Sets;
import com.datastax.oss.protocol.internal.ProtocolConstants;
import com.google.common.base.Functions;
import edu.umd.cs.findbugs.annotations.NonNull;
import io.netty.util.concurrent.EventExecutor;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import net.jcip.annotations.ThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -53,23 +58,73 @@
public class CqlPrepareAsyncProcessor
implements RequestProcessor<PrepareRequest, CompletionStage<PreparedStatement>> {

public class CacheEntry {

private CompletableFuture<PreparedStatement> result;
private Set<CompletableFuture<PreparedStatement>> futures;
private AtomicBoolean lock;

public CacheEntry() {

result = new CompletableFuture<>();
futures = Sets.newHashSet(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One observation, HashSet is not thread safe, and futures get added to it in the threads calling prepare|Async (main thread in tests) concurrently.

From what I can tell, Set is used here because in process we want to ensure the future gets added before starting when creating the initial cache entry, and then it's added after the entry is created to ensure it's accounted for, right?

Can probably resolve this by using ConcurrentHashMap or Collections.synchronizedSet

lock = new AtomicBoolean(false);
}
Comment on lines +63 to +72
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small nit: could probably simplify this by initializing all of these in the variable declaration and then getting rid of the constructor. Should also make these fields all final.


public void addFuture(CompletableFuture<PreparedStatement> future) {

futures.add(future);
}

public void tryStart(
PrepareRequest request,
DefaultSession session,
InternalDriverContext context,
String sessionLogPrefix) {
// Guarantee that we'll only create one CqlPrepareHandler for each cache entry
if (lock.compareAndSet(false, true)) {

new CqlPrepareHandler(request, session, context, sessionLogPrefix)
.handle()
.whenComplete(
(ps, t) -> {
if (t == null) {
for (CompletableFuture<PreparedStatement> future : futures) {
future.complete(ps);
}
} else {
cache.invalidate(request);
for (CompletableFuture<PreparedStatement> future : futures) {
future.completeExceptionally(t);
}
}
});
}
}

public PreparedStatement waitForResult() {
return this.result.join();
}
}

private static final Logger LOG = LoggerFactory.getLogger(CqlPrepareAsyncProcessor.class);

protected final Cache<PrepareRequest, CompletableFuture<PreparedStatement>> cache;
protected final Cache<PrepareRequest, CacheEntry> cache;

public CqlPrepareAsyncProcessor() {
this(Optional.empty());
}

public CqlPrepareAsyncProcessor(@NonNull Optional<? extends DefaultDriverContext> context) {
this(CacheBuilder.newBuilder().weakValues().build(), context);
this(context, Functions.identity());
}

protected CqlPrepareAsyncProcessor(
Cache<PrepareRequest, CompletableFuture<PreparedStatement>> cache,
Optional<? extends DefaultDriverContext> context) {
Optional<? extends DefaultDriverContext> context,
Function<CacheBuilder<Object, Object>, CacheBuilder<Object, Object>> decorator) {

this.cache = cache;
CacheBuilder<Object, Object> baseCache = CacheBuilder.newBuilder().weakValues();
this.cache = decorator.apply(baseCache).build();
context.ifPresent(
(ctx) -> {
LOG.info("Adding handler to invalidate cached prepared statements on type changes");
Expand Down Expand Up @@ -108,11 +163,10 @@ private static boolean typeMatches(UserDefinedType oldType, DataType typeToCheck
}

private void onTypeChanged(TypeChangeEvent event) {
for (Map.Entry<PrepareRequest, CompletableFuture<PreparedStatement>> entry :
this.cache.asMap().entrySet()) {
for (Map.Entry<PrepareRequest, CacheEntry> entry : this.cache.asMap().entrySet()) {

try {
PreparedStatement stmt = entry.getValue().get();
PreparedStatement stmt = entry.getValue().waitForResult();
if (Iterables.any(
stmt.getResultSetDefinitions(), (def) -> typeMatches(event.oldType, def.getType()))
|| Iterables.any(
Expand Down Expand Up @@ -141,25 +195,23 @@ public CompletionStage<PreparedStatement> process(
String sessionLogPrefix) {

try {
CompletableFuture<PreparedStatement> result = cache.getIfPresent(request);
if (result == null) {
CompletableFuture<PreparedStatement> mine = new CompletableFuture<>();
result = cache.get(request, () -> mine);
if (result == mine) {
new CqlPrepareHandler(request, session, context, sessionLogPrefix)
.handle()
.whenComplete(
(preparedStatement, error) -> {
if (error != null) {
mine.completeExceptionally(error);
cache.invalidate(request); // Make sure failure isn't cached indefinitely
} else {
mine.complete(preparedStatement);
}
});
}
}
return result;
CompletableFuture<PreparedStatement> rv = new CompletableFuture<>();
Copy link
Contributor

@tolbertam tolbertam Feb 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're creating a brand new future, If the CacheEntry already exists and is completed, won't we return a CompletableFuture that never completes?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrote a test that reproduces this, so it does seem to be a problem:

  @Test
  public void should_complete_if_already_prepared() throws Exception {
    CqlSession session = SessionUtils.newSession(ccmRule, sessionRule.keyspace());
    CqlPrepareAsyncProcessor processor = findProcessor(session);
    Cache<?, ?> cache = processor.getCache();
    assertThat(cache.size()).isEqualTo(0);

    // Prepare a statement and then wait for it to complete
    String cql = "select v from test_table_1 where k = ?";
    CompletableFuture<PreparedStatement> cf1 = toCompletableFuture(session, cql);
    assertThat(cache.size()).isEqualTo(1);

    CqlPrepareAsyncProcessor.CacheEntry entry =
        (CqlPrepareAsyncProcessor.CacheEntry) Iterables.get(cache.asMap().values(), 0);
    PreparedStatement stmt = entry.waitForResult();
    assertThat(cf1.isDone()).isTrue();
    assertThat(cf1.join()).isEqualTo(stmt);

    // Prepare the same prepared statement, which should be completed immediately since it was previously prepared.
    CompletableFuture<PreparedStatement> cf2 = toCompletableFuture(session, cql);
    // cache should not grow
    assertThat(cache.size()).isEqualTo(1);
    CqlPrepareAsyncProcessor.CacheEntry newEntry =
        (CqlPrepareAsyncProcessor.CacheEntry) Iterables.get(cache.asMap().values(), 0);
    // Strictly the same entry in the cache.
    assertThat(entry).isSameAs(newEntry);
    // Note: made futures public just to test this, not necessary, just for demonstrating CacheEntry has this future
    assertThat(newEntry.futures).contains(cf2);
    // Future should be complete (where the test fails)
    assertThat(cf2.isDone()).isTrue();
    assertThat(cf2.join()).isEqualTo(stmt);
  }

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another possible concern here: Futures get added in process which will be handled in the application threads, while whenComplete gets handled in an io thread. I think even fixing the issue above, technically without additional coordination/synchronization you could have a race where a new Future gets added in process but never gets marked completed in whenComplete.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with both points. This impl won't work as it stands... I'm wondering if there's anything that can be salvaged here or if this is just an outright dumpster fire and it's time to move on. 😞

CacheEntry entry =
cache.get(
request,
() -> {
CacheEntry newEntry = new CacheEntry();
newEntry.addFuture(rv);
newEntry.tryStart(request, session, context, sessionLogPrefix);
return newEntry;
});

// We don't know whether we're dealing with a newly-created entry or one that was
// already cached so try the future insert again. We wind up duoing an extra hash op
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duoing -> doing

// on the initial insert this way but that's a relatively small price to pay.
entry.addFuture(rv);

return rv;
} catch (ExecutionException e) {
return CompletableFutures.failedFuture(e.getCause());
}
Expand All @@ -170,7 +222,7 @@ public CompletionStage<PreparedStatement> newFailure(RuntimeException error) {
return CompletableFutures.failedFuture(error);
}

public Cache<PrepareRequest, CompletableFuture<PreparedStatement>> getCache() {
public Cache<PrepareRequest, CacheEntry> getCache() {
return cache;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import com.datastax.oss.driver.internal.core.util.concurrent.BlockingOperation;
import com.datastax.oss.driver.internal.core.util.concurrent.CompletableFutures;
import com.datastax.oss.driver.shaded.guava.common.cache.Cache;
import java.util.concurrent.CompletableFuture;
import net.jcip.annotations.ThreadSafe;

@ThreadSafe
Expand Down Expand Up @@ -62,7 +61,7 @@ public PreparedStatement process(
asyncProcessor.process(request, session, context, sessionLogPrefix));
}

public Cache<PrepareRequest, CompletableFuture<PreparedStatement>> getCache() {
public Cache<PrepareRequest, CqlPrepareAsyncProcessor.CacheEntry> getCache() {
return asyncProcessor.getCache();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import com.datastax.oss.driver.api.core.config.DefaultDriverOption;
import com.datastax.oss.driver.api.core.config.DriverConfigLoader;
import com.datastax.oss.driver.api.core.context.DriverContext;
import com.datastax.oss.driver.api.core.cql.PrepareRequest;
import com.datastax.oss.driver.api.core.cql.PreparedStatement;
import com.datastax.oss.driver.api.core.metrics.DefaultSessionMetric;
import com.datastax.oss.driver.api.core.session.ProgrammaticArguments;
import com.datastax.oss.driver.api.core.session.SessionBuilder;
Expand All @@ -41,7 +39,6 @@
import com.datastax.oss.driver.internal.core.session.BuiltInRequestProcessors;
import com.datastax.oss.driver.internal.core.session.RequestProcessor;
import com.datastax.oss.driver.internal.core.session.RequestProcessorRegistry;
import com.datastax.oss.driver.shaded.guava.common.cache.CacheBuilder;
import com.datastax.oss.driver.shaded.guava.common.cache.RemovalListener;
import com.datastax.oss.driver.shaded.guava.common.util.concurrent.Uninterruptibles;
import com.google.common.collect.ImmutableList;
Expand All @@ -54,7 +51,6 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -119,12 +115,12 @@ private static class TestCqlPrepareAsyncProcessor extends CqlPrepareAsyncProcess
private static final Logger LOG =
LoggerFactory.getLogger(PreparedStatementCachingIT.TestCqlPrepareAsyncProcessor.class);

private static RemovalListener<PrepareRequest, CompletableFuture<PreparedStatement>>
buildCacheRemoveCallback(@NonNull Optional<DefaultDriverContext> context) {
private static RemovalListener<Object, Object> buildCacheRemoveCallback(
@NonNull Optional<DefaultDriverContext> context) {
return (evt) -> {
try {
CompletableFuture<PreparedStatement> future = evt.getValue();
ByteBuffer queryId = Uninterruptibles.getUninterruptibly(future).getId();
CacheEntry entry = (CacheEntry) evt.getValue();
ByteBuffer queryId = entry.waitForResult().getId();
context.ifPresent(
ctx -> ctx.getEventBus().fire(new PreparedStatementRemovalEvent(queryId)));
} catch (Exception e) {
Expand All @@ -136,9 +132,7 @@ private static class TestCqlPrepareAsyncProcessor extends CqlPrepareAsyncProcess
public TestCqlPrepareAsyncProcessor(@NonNull Optional<DefaultDriverContext> context) {
// Default CqlPrepareAsyncProcessor uses weak values here as well. We avoid doing so
// to prevent cache entries from unexpectedly disappearing mid-test.
super(
CacheBuilder.newBuilder().removalListener(buildCacheRemoveCallback(context)).build(),
context);
super(context, builder -> builder.removalListener(buildCacheRemoveCallback(context)));
}
}

Expand Down
Loading