Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,46 +22,20 @@ public class ClientHolder {

private final AtomicBoolean connected = new AtomicBoolean(false);
private final AtomicReference<CurrentConnection> connectionHolder = new AtomicReference<>();
private final Uni<RabbitMQClient> connection;
private final AtomicReference<Context> rootContext;
private final AtomicReference<CompletionStage<RabbitMQClient>> connectionStage = new AtomicReference<>();

private final Vertx vertx;
private final RabbitMQConnectorCommonConfiguration configuration;

public ClientHolder(RabbitMQClient client,
RabbitMQConnectorCommonConfiguration configuration,
Vertx vertx,
Context root) {
this.client = client;
this.configuration = configuration;
this.vertx = vertx;
this.connection = Uni.createFrom().deferred(() -> client.start()
.onSubscription().invoke(() -> {
connected.set(true);
log.connectionEstablished(configuration.getChannel());
})
.onItem().transform(ignored -> {
connectionHolder
.set(new CurrentConnection(client, root == null ? Vertx.currentContext() : root));

// handle the case we are already disconnected.
if (!client.isConnected() || connectionHolder.get() == null) {
// Throwing the exception would trigger a retry.
connectionHolder.set(null);
throw ex.illegalStateConnectionDisconnected();
}
return client;
})
.onFailure().invoke(log::unableToConnectToBroker)
.onFailure().invoke(t -> {
connectionHolder.set(null);
log.unableToRecoverFromConnectionDisruption(t);
}))
.memoize().until(() -> {
CurrentConnection connection = connectionHolder.get();
if (connection == null) {
return true;
}
return !connection.client.isConnected();
});

this.rootContext = new AtomicReference<>(root);
}

public static CompletionStage<Void> runOnContext(Context context, IncomingRabbitMQMessage<?> msg,
Expand Down Expand Up @@ -89,6 +63,17 @@ public Context getContext() {
}
}

public void ensureContext(Context context) {
if (context == null) {
return;
}
rootContext.compareAndSet(null, context);
CurrentConnection connection = connectionHolder.get();
if (connection != null && connection.context == null) {
connectionHolder.compareAndSet(connection, new CurrentConnection(connection.client, context));
}
}

public RabbitMQClient client() {
return client;
}
Expand All @@ -112,7 +97,29 @@ public Vertx getVertx() {

@CheckReturnValue
public Uni<RabbitMQClient> getOrEstablishConnection() {
return connection;
CompletionStage<RabbitMQClient> existing = connectionStage.get();
if (existing != null) {
if (!existing.toCompletableFuture().isDone() || client.isConnected()) {
return Uni.createFrom().completionStage(existing);
}
connectionStage.compareAndSet(existing, null);
}

for (;;) {
CompletionStage<RabbitMQClient> current = connectionStage.get();
if (current != null) {
return Uni.createFrom().completionStage(current);
}
CompletionStage<RabbitMQClient> created = createConnectionUni().subscribeAsCompletionStage();
if (connectionStage.compareAndSet(null, created)) {
created.whenComplete((result, error) -> {
if (error != null) {
connectionStage.compareAndSet(created, null);
}
});
return Uni.createFrom().completionStage(created);
}
}
}

private static class CurrentConnection {
Expand All @@ -126,4 +133,32 @@ private CurrentConnection(RabbitMQClient client, Context context) {
}
}

private Uni<RabbitMQClient> createConnectionUni() {
return Uni.createFrom().deferred(() -> client.start()
.onSubscription().invoke(() -> {
connected.set(true);
log.connectionEstablished(configuration.getChannel());
})
.onItem().transform(ignored -> {
Context context = rootContext.get();
if (context == null) {
context = Vertx.currentContext();
}
connectionHolder.set(new CurrentConnection(client, context));

// handle the case we are already disconnected.
if (!client.isConnected() || connectionHolder.get() == null) {
// Throwing the exception would trigger a retry.
connectionHolder.set(null);
throw ex.illegalStateConnectionDisconnected();
}
return client;
})
.onFailure().invoke(log::unableToConnectToBroker)
.onFailure().invoke(t -> {
connectionHolder.set(null);
log.unableToRecoverFromConnectionDisruption(t);
}));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,26 @@ public <V> CompletionStage<Void> handle(IncomingRabbitMQMessage<V> message, Meta
public IncomingRabbitMQMessage(RabbitMQMessage delegate, ClientHolder holder,
RabbitMQFailureHandler onNack,
RabbitMQAckHandler onAck, String contentTypeOverride) {
this(delegate.getDelegate(), holder, onNack, onAck, contentTypeOverride);
this(delegate.getDelegate(), holder, holder.getContext(), onNack, onAck, contentTypeOverride);
}

public IncomingRabbitMQMessage(RabbitMQMessage delegate, ClientHolder holder, Context context,
RabbitMQFailureHandler onNack,
RabbitMQAckHandler onAck, String contentTypeOverride) {
this(delegate.getDelegate(), holder, context, onNack, onAck, contentTypeOverride);
}

IncomingRabbitMQMessage(io.vertx.rabbitmq.RabbitMQMessage msg, ClientHolder holder,
RabbitMQFailureHandler onNack, RabbitMQAckHandler onAck, String contentTypeOverride) {
this(msg, holder, holder.getContext(), onNack, onAck, contentTypeOverride);
}

IncomingRabbitMQMessage(io.vertx.rabbitmq.RabbitMQMessage msg, ClientHolder holder, Context context,
RabbitMQFailureHandler onNack, RabbitMQAckHandler onAck, String contentTypeOverride) {
this.message = msg;
this.deliveryTag = msg.envelope().getDeliveryTag();
this.holder = holder;
this.context = holder.getContext();
this.context = context != null ? context : holder.getContext();
this.contentTypeOverride = contentTypeOverride;
this.rabbitMQMetadata = new IncomingRabbitMQMetadata(this.message);
this.onNack = onNack;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
package io.smallrye.reactive.messaging.rabbitmq;

import static io.smallrye.reactive.messaging.annotations.ConnectorAttribute.Direction.*;
import static io.smallrye.reactive.messaging.annotations.ConnectorAttribute.Direction.INCOMING;
import static io.smallrye.reactive.messaging.annotations.ConnectorAttribute.Direction.INCOMING_AND_OUTGOING;
import static io.smallrye.reactive.messaging.annotations.ConnectorAttribute.Direction.OUTGOING;
import static io.smallrye.reactive.messaging.rabbitmq.i18n.RabbitMQExceptions.ex;
import static io.smallrye.reactive.messaging.rabbitmq.i18n.RabbitMQLogging.log;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Flow;
import java.util.concurrent.atomic.AtomicInteger;

import jakarta.annotation.Priority;
import jakarta.enterprise.context.ApplicationScoped;
Expand Down Expand Up @@ -38,6 +43,7 @@
import io.smallrye.reactive.messaging.rabbitmq.fault.RabbitMQFailureHandler;
import io.smallrye.reactive.messaging.rabbitmq.internals.IncomingRabbitMQChannel;
import io.smallrye.reactive.messaging.rabbitmq.internals.OutgoingRabbitMQChannel;
import io.smallrye.reactive.messaging.rabbitmq.internals.RabbitMQClientHelper;
import io.vertx.mutiny.core.Vertx;
import io.vertx.mutiny.rabbitmq.RabbitMQClient;
import io.vertx.rabbitmq.RabbitMQOptions;
Expand All @@ -64,6 +70,7 @@
@ConnectorAttribute(name = "reconnect-interval", direction = INCOMING_AND_OUTGOING, description = "The interval (in seconds) between two reconnection attempts", type = "int", alias = "rabbitmq-reconnect-interval", defaultValue = "10")
@ConnectorAttribute(name = "network-recovery-interval", direction = INCOMING_AND_OUTGOING, description = "How long (ms) will automatic recovery wait before attempting to reconnect", type = "int", defaultValue = "5000")
@ConnectorAttribute(name = "user", direction = INCOMING_AND_OUTGOING, description = "The user name to use when connecting to the broker", type = "string", defaultValue = "guest")
@ConnectorAttribute(name = "shared-connection-name", direction = INCOMING_AND_OUTGOING, description = "Optional identifier allowing multiple channels to share the same RabbitMQ connection when set to the same value", type = "string")
@ConnectorAttribute(name = "include-properties", direction = INCOMING_AND_OUTGOING, description = "Whether to include properties when a broker message is passed on the event bus", type = "boolean", defaultValue = "false")
@ConnectorAttribute(name = "requested-channel-max", direction = INCOMING_AND_OUTGOING, description = "The initially requested maximum channel number", type = "int", defaultValue = "2047")
@ConnectorAttribute(name = "requested-heartbeat", direction = INCOMING_AND_OUTGOING, description = "The initially requested heartbeat interval (seconds), zero for none", type = "int", defaultValue = "60")
Expand Down Expand Up @@ -164,7 +171,8 @@ public class RabbitMQConnector implements InboundConnector, OutboundConnector, H
Instance<RabbitMQFailureHandler.Factory> failureHandlerFactories;
private List<IncomingRabbitMQChannel> incomings = new CopyOnWriteArrayList<>();
private List<OutgoingRabbitMQChannel> outgoings = new CopyOnWriteArrayList<>();
private Map<String, RabbitMQClient> clients = new ConcurrentHashMap<>();
private Map<String, ClientRegistration> clientRegistrations = new ConcurrentHashMap<>();
private Map<String, SharedClient> sharedClients = new ConcurrentHashMap<>();

@Inject
@Any
Expand Down Expand Up @@ -263,27 +271,28 @@ public void terminate(
outgoing.terminate();
}

clients.forEach((channel, rabbitMQClient) -> rabbitMQClient.stopAndAwait());
clients.clear();
List<String> registeredChannels = new ArrayList<>(clientRegistrations.keySet());
for (String channel : registeredChannels) {
releaseClient(channel);
}
sharedClients.clear();
}

public Vertx vertx() {
return executionHolder.vertx();
}

public void registerClient(String channel, RabbitMQClient client) {
RabbitMQClient old = clients.put(channel, client);
if (old != null) {
old.stopAndForget();
}
}

public void reportIncomingFailure(String channel, Throwable reason) {
log.failureReported(channel, reason);
RabbitMQClient client = clients.remove(channel);
if (client != null) {
// Called on vertx context, we can't block: stop clients without waiting
client.stopAndForget();
ClientRegistration registration = clientRegistrations.remove(channel);
if (registration == null) {
return;
}

if (registration.shared) {
releaseSharedClient(registration.key, false);
} else {
stopClient(registration.holder.client(), false);
}
}

Expand All @@ -306,4 +315,115 @@ public Instance<CredentialsProvider> credentialsProviders() {
public Instance<Map<String, ?>> configMaps() {
return configMaps;
}

public ClientHolder getClientHolder(RabbitMQConnectorCommonConfiguration config, io.vertx.mutiny.core.Context context) {
ClientRegistration existing = clientRegistrations.get(config.getChannel());
if (existing != null) {
return existing.holder;
}

return config.getSharedConnectionName()
.map(name -> getOrCreateSharedHolder(config, context, name))
.orElseGet(() -> createAndRegisterHolder(config, context, config.getChannel(), false));
}

private ClientHolder createAndRegisterHolder(RabbitMQConnectorCommonConfiguration config,
io.vertx.mutiny.core.Context context, String key, boolean shared) {
ClientHolder holder = new ClientHolder(RabbitMQClientHelper.createClient(this, config), config, vertx(), context);
clientRegistrations.put(config.getChannel(), new ClientRegistration(holder, shared, key));
return holder;
}

private ClientHolder getOrCreateSharedHolder(RabbitMQConnectorCommonConfiguration config,
io.vertx.mutiny.core.Context context, String name) {
RabbitMQOptions options = RabbitMQClientHelper.buildClientOptions(this, config);
String fingerprint = RabbitMQClientHelper.computeConnectionFingerprint(options);
SharedClient shared = sharedClients.compute(name, (key, existing) -> {
if (existing != null) {
if (!existing.fingerprint.equals(fingerprint)) {
throw ex.illegalStateSharedConnectionConfigMismatch(name);
}
existing.retain();
if (context != null) {
existing.holder.ensureContext(context);
}
return existing;
}
return new SharedClient(name, new ClientHolder(
RabbitMQClient.create(vertx(), options),
config,
vertx(),
context), fingerprint);
});
clientRegistrations.put(config.getChannel(), new ClientRegistration(shared.holder, true, name));
return shared.holder;
}

public void releaseClient(String channel) {
ClientRegistration registration = clientRegistrations.remove(channel);
if (registration == null) {
return;
}

if (registration.shared) {
releaseSharedClient(registration.key, true);
} else {
stopClient(registration.holder.client(), true);
}
}

private void releaseSharedClient(String sharedName, boolean await) {
SharedClient shared = sharedClients.get(sharedName);
if (shared == null) {
return;
}
if (shared.release()) {
sharedClients.remove(sharedName, shared);
stopClient(shared.holder.client(), await);
}
}

private void stopClient(RabbitMQClient client, boolean await) {
if (client == null) {
return;
}
if (await) {
client.stopAndAwait();
} else {
client.stopAndForget();
}
}

private static final class ClientRegistration {
final ClientHolder holder;
final boolean shared;
final String key;

private ClientRegistration(ClientHolder holder, boolean shared, String key) {
this.holder = holder;
this.shared = shared;
this.key = key;
}
}

private static final class SharedClient {
final String name;
final ClientHolder holder;
final String fingerprint;
final AtomicInteger references = new AtomicInteger(1);

private SharedClient(String name, ClientHolder holder, String fingerprint) {
this.name = name;
this.holder = holder;
this.fingerprint = fingerprint;
}

private void retain() {
references.incrementAndGet();
}

private boolean release() {
return references.decrementAndGet() == 0;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,7 @@ public interface RabbitMQExceptions {

@Message(id = 16009, value = "Unable to create a client, probably a config error")
IllegalStateException illegalStateUnableToCreateClient(@Cause Throwable t);

@Message(id = 16010, value = "Shared connection '%s' has mismatched configuration; ensure all channels using the same shared-connection-name have identical connection settings")
IllegalStateException illegalStateSharedConnectionConfigMismatch(String sharedConnectionName);
}
Loading