Skip to content
Closed
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,9 @@ internal class PortBlocker : IDisposable
private const int MaxAttempts = 16;
private Socket _shadowSocket;
public Socket MainSocket { get; }
public Socket SecondarySocket => _shadowSocket;

public int Port;

public PortBlocker(Func<Socket> socketFactory)
{
Expand All @@ -126,7 +129,11 @@ public PortBlocker(Func<Socket> socketFactory)
_shadowSocket = new Socket(shadowAddress.AddressFamily, MainSocket.SocketType, MainSocket.ProtocolType);
success = TryBindWithoutReuseAddress(_shadowSocket, shadowEndPoint, out _);

if (success) break;
if (success)
{
Port = port;
break;
}
}
catch (SocketException)
{
Expand Down
6 changes: 6 additions & 0 deletions src/libraries/System.Net.Sockets/ref/System.Net.Sockets.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@

namespace System.Net.Sockets
{
public enum ConnectAlgorithm
{
Default = 0,
Parallel = 1,
}
public enum IOControlCode : long
{
[System.Runtime.Versioning.SupportedOSPlatformAttribute("windows")]
Expand Down Expand Up @@ -343,6 +348,7 @@ public void Connect(string host, int port) { }
public System.Threading.Tasks.ValueTask ConnectAsync(System.Net.IPAddress[] addresses, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public bool ConnectAsync(System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e) { throw null; }
public static bool ConnectAsync(System.Net.Sockets.SocketType socketType, System.Net.Sockets.ProtocolType protocolType, System.Net.Sockets.SocketAsyncEventArgs e, System.Net.Sockets.ConnectAlgorithm connectAlgorithm) { throw null; }
public System.Threading.Tasks.Task ConnectAsync(string host, int port) { throw null; }
public System.Threading.Tasks.ValueTask ConnectAsync(string host, int port, System.Threading.CancellationToken cancellationToken) { throw null; }
public void Disconnect(bool reuseSocket) { }
Expand Down
3 changes: 3 additions & 0 deletions src/libraries/System.Net.Sockets/src/Resources/Strings.resx
Original file line number Diff line number Diff line change
Expand Up @@ -315,4 +315,7 @@
<data name="net_sockets_address_small" xml:space="preserve">
<value>Provided SocketAddress is too small for given AddressFamily.</value>
</data>
<data name="net_sockets_invalid_connect_algorithm" xml:space="preserve">
<value>Provided ConnectAlgorithm {0} is not valid.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

<ItemGroup Condition="'$(TargetPlatformIdentifier)' != ''">
<!-- All configurations -->
<Compile Include="System\Net\Sockets\ConnectAlgorithm.cs" />
<Compile Include="System\Net\Sockets\SocketReceiveFromResult.cs" />
<Compile Include="System\Net\Sockets\SocketReceiveMessageFromResult.cs" />
<Compile Include="System\Net\Sockets\SocketsTelemetry.cs" />
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

namespace System.Net.Sockets
{
/// <summary>
/// Specifies the algorithm used to establish a socket connection.
/// </summary>
public enum ConnectAlgorithm
{
/// <summary>
/// The default connection mechanism, typically sequential processing.
/// </summary>
Default = 0,

/// <summary>
/// Uses a Happy Eyeballs-like algorithm to connect, attempting connections in parallel to improve speed and reliability.
/// </summary>
Parallel = 1,
}
}
13 changes: 10 additions & 3 deletions src/libraries/System.Net.Sockets/src/System/Net/Sockets/Socket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2918,7 +2918,7 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaMul
e.StartOperationConnect(saeaMultiConnectCancelable, userSocket);
try
{
pending = e.DnsConnectAsync(dnsEP, default, default, cancellationToken);
pending = e.DnsConnectAsync(dnsEP, default, default, default, cancellationToken);
}
catch
{
Expand Down Expand Up @@ -2981,9 +2981,16 @@ internal bool ConnectAsync(SocketAsyncEventArgs e, bool userSocket, bool saeaMul
return pending;
}

public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e)
public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e) =>
ConnectAsync(socketType, protocolType, e, ConnectAlgorithm.Default);
public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType, SocketAsyncEventArgs e, ConnectAlgorithm connectAlgorithm)
{
ArgumentNullException.ThrowIfNull(e);
if (connectAlgorithm != ConnectAlgorithm.Default &&
connectAlgorithm != ConnectAlgorithm.Parallel)
{
throw new ArgumentException(SR.Format(SR.net_sockets_invalid_connect_algorithm, connectAlgorithm), nameof(connectAlgorithm));
}

if (e.HasMultipleBuffers)
{
Expand All @@ -3005,7 +3012,7 @@ public static bool ConnectAsync(SocketType socketType, ProtocolType protocolType
e.StartOperationConnect(saeaMultiConnectCancelable: true, userSocket: false);
try
{
pending = e.DnsConnectAsync(dnsEP, socketType, protocolType, cancellationToken: default);
pending = e.DnsConnectAsync(dnsEP, socketType, protocolType, connectAlgorithm, cancellationToken: default);
}
catch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -676,9 +676,10 @@ internal void FinishOperationAsyncFailure(SocketError socketError, int bytesTran
/// <param name="endPoint">The DNS end point to which to connect.</param>
/// <param name="socketType">The SocketType to use to construct new sockets, if necessary.</param>
/// <param name="protocolType">The ProtocolType to use to construct new sockets, if necessary.</param>
/// <param name="connectAlgorithm">Connect strategy.</param>
/// <param name="cancellationToken">The CancellationToken.</param>
/// <returns>true if the operation is pending; otherwise, false if it's already completed.</returns>
internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, ProtocolType protocolType, CancellationToken cancellationToken)
internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, ProtocolType protocolType, ConnectAlgorithm connectAlgorithm, CancellationToken cancellationToken)
{
Debug.Assert(endPoint.AddressFamily == AddressFamily.Unspecified ||
endPoint.AddressFamily == AddressFamily.InterNetwork ||
Expand All @@ -691,9 +692,15 @@ internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, Proto
cancellationToken = _multipleConnectCancellation.Token;
}

// We can do parallel connect only if socket was not specified and when there is at least one address of each AF.
bool parallelConnect = connectAlgorithm == ConnectAlgorithm.Parallel &&
_currentSocket == null &&
endPoint.AddressFamily == AddressFamily.Unspecified &&
Socket.OSSupportsIPv6 && Socket.OSSupportsIPv4;

// In .NET 5 and earlier, the APM implementation allowed for synchronous exceptions from this to propagate
// synchronously. This call is made here rather than in the Core async method below to preserve that behavior.
Task<IPAddress[]> addressesTask = Dns.GetHostAddressesAsync(endPoint.Host, endPoint.AddressFamily, cancellationToken);
Task<IPAddress[]> addressesTask = Dns.GetHostAddressesAsync(endPoint.Host, parallelConnect ? AddressFamily.InterNetwork : endPoint.AddressFamily, cancellationToken);

// Initialize the internal event args instance. It needs to be initialized with `this` instance's buffer
// so that it may be used as part of receives during a connect.
Expand All @@ -705,16 +712,30 @@ internal bool DnsConnectAsync(DnsEndPoint endPoint, SocketType socketType, Proto
// by a try/catch. Thus we ignore the result. We avoid an "async void" method so as to skip the implicit SynchronizationContext
// interactions async void methods entail.
#pragma warning disable CA2025
_ = Core(internalArgs, addressesTask, endPoint.Port, socketType, protocolType, cancellationToken);
if (parallelConnect)
{
var state = new ParallelMultiConnectSocketState(this);
var internalArgsV6 = new MultiConnectSocketAsyncEventArgs();
internalArgsV6.CopyBufferFrom(this);

Task<IPAddress[]> addressesTask6 = Dns.GetHostAddressesAsync(endPoint.Host, AddressFamily.InterNetworkV6, cancellationToken);
_ = Core(internalArgs, addressesTask, endPoint.Port, socketType, protocolType, state, cancellationToken);
_ = Core(internalArgsV6, addressesTask6, endPoint.Port, socketType, protocolType, state, cancellationToken);
return true;
}
else
{
_ = Core(internalArgs, addressesTask, endPoint.Port, socketType, protocolType, null, cancellationToken);
}
#pragma warning restore

// Determine whether the async operation already completed and stored the results into `this`.
// If we reached this point and the operation hasn't yet stored the results, then it's considered
// pending. If by the time we get here it has stored the results, it's considered completed.
// The callback won't invoke the Completed event if it gets there first.
return internalArgs.ReachedCoordinationPointFirst();
// Determine whether the async operation already completed and stored the results into `this`.
// If we reached this point and the operation hasn't yet stored the results, then it's considered
// pending. If by the time we get here it has stored the results, it's considered completed.
// The callback won't invoke the Completed event if it gets there first.
return internalArgs.ReachedCoordinationPointFirst();

async Task Core(MultiConnectSocketAsyncEventArgs internalArgs, Task<IPAddress[]> addressesTask, int port, SocketType socketType, ProtocolType protocolType, CancellationToken cancellationToken)
async Task Core(MultiConnectSocketAsyncEventArgs internalArgs, Task<IPAddress[]> addressesTask, int port, SocketType socketType, ProtocolType protocolType, ParallelMultiConnectSocketState? parallelState, CancellationToken cancellationToken)
{
Socket? tempSocketIPv4 = null, tempSocketIPv6 = null;
Exception? caughtException = null;
Expand Down Expand Up @@ -843,35 +864,52 @@ caughtException is OperationCanceledException ||
}
}

// Store the results.
if (caughtException != null)
if (parallelState != null)
{
SetResults(caughtException, 0, SocketFlags.None);
_currentSocket?.UpdateStatusAfterSocketError(_socketError);
// If we do parallel connect use SetResults from there to arbiter competing results.
if (caughtException != null)
{
parallelState.SetResults(null, _socketError, 0, SocketFlags.None, caughtException);
}
else
{
parallelState.SetResults(internalArgs.ConnectSocket, internalArgs.SocketError, internalArgs.BytesTransferred, internalArgs.SocketFlags, null);
}
internalArgs.Dispose();

}
else
{
SetResults(SocketError.Success, internalArgs.BytesTransferred, internalArgs.SocketFlags);
_connectSocket = _currentSocket = internalArgs.ConnectSocket!;
}
// Store the results.
if (caughtException != null)
{
SetResults(caughtException, 0, SocketFlags.None);
_currentSocket?.UpdateStatusAfterSocketError(_socketError);
}
else
{
SetResults(SocketError.Success, internalArgs.BytesTransferred, internalArgs.SocketFlags);
_connectSocket = _currentSocket = internalArgs.ConnectSocket!;
}

// Complete the operation.
if (SocketsTelemetry.Log.IsEnabled()) LogBytesTransferEvents(_connectSocket?.SocketType, SocketAsyncOperation.Connect, internalArgs.BytesTransferred);
// Complete the operation.
if (SocketsTelemetry.Log.IsEnabled()) LogBytesTransferEvents(_connectSocket?.SocketType, SocketAsyncOperation.Connect, internalArgs.BytesTransferred);

Complete();
Complete();

// Clean up after our temporary arguments.
internalArgs.Dispose();
// Clean up after our temporary arguments.
internalArgs.Dispose();

// If the caller is treating this operation as pending, own the completion.
if (!internalArgs.ReachedCoordinationPointFirst())
{
// Regardless of _flowExecutionContext, context will have been flown through this async method, as that's part
// of what async methods do. As such, we're already on whatever ExecutionContext is the right one to invoke
// the completion callback. This method may have even mutated the ExecutionContext, in which case for telemetry
// we need those mutations to be surfaced as part of this callback, so that logging performed here sees those
// mutations (e.g. to the current Activity).
OnCompleted(this);
// If the caller is treating this operation as pending, own the completion.
if (!internalArgs.ReachedCoordinationPointFirst())
{
// Regardless of _flowExecutionContext, context will have been flown through this async method, as that's part
// of what async methods do. As such, we're already on whatever ExecutionContext is the right one to invoke
// the completion callback. This method may have even mutated the ExecutionContext, in which case for telemetry
// we need those mutations to be surfaced as part of this callback, so that logging performed here sees those
// mutations (e.g. to the current Activity).
OnCompleted(this);
}
}
}
}
Expand All @@ -891,11 +929,61 @@ public MultiConnectSocketAsyncEventArgs() : base(unsafeSuppressExecutionContextF
public short Version => _mrvtsc.Version;
public void Reset() => _mrvtsc.Reset();

protected override void OnCompleted(SocketAsyncEventArgs e) => _mrvtsc.SetResult(true);
protected override void OnCompleted(SocketAsyncEventArgs e) =>_mrvtsc.SetResult(true);

public bool ReachedCoordinationPointFirst() => !Interlocked.Exchange(ref _isCompleted, true);
}

private sealed class ParallelMultiConnectSocketState
{
private bool _isCompleted;
private int _count;
private SocketAsyncEventArgs _saea;

public ParallelMultiConnectSocketState(SocketAsyncEventArgs saea)
{
_saea = saea;
}
public bool ReachedCoordinationPointFirst() => !Interlocked.Exchange(ref _isCompleted, true);

public void SetResults(Socket? socket, SocketError socketError, int bytesTransferred, SocketFlags flags, Exception? exception)
{
int count = Interlocked.Increment(ref _count);
bool firstFinal = false;

if (socketError == SocketError.Success)
{
firstFinal = ReachedCoordinationPointFirst();
if (firstFinal)
{
_saea._connectSocket = _saea._currentSocket = socket;
_saea.SetResults(SocketError.Success, bytesTransferred, flags);
return;
}
}
else if (count == 2) // We ignore failures on first socket since we have one more pending.
{
firstFinal = ReachedCoordinationPointFirst();
if (firstFinal)
{
// We ignore failures on first socket since we have one more pending.
_saea.SetResults(exception!, 0, SocketFlags.None);
_saea._currentSocket?.UpdateStatusAfterSocketError(socketError);
}
}

if (firstFinal)
{
// If this is the first final result, we need to complete the operation and release underlying SocketAsyncEventArgs
_saea.Complete();
//if (SocketsTelemetry.Log.IsEnabled()) LogBytesTransferEvents(socket?.SocketType, SocketAsyncOperation.Connect, bytesTransferred);
// signal caller we are done.
_saea.OnCompleted(_saea);
}
}
}


internal void FinishOperationSyncSuccess(int bytesTransferred, SocketFlags flags)
{
SetResults(SocketError.Success, bytesTransferred, flags);
Expand Down
Loading
Loading