Skip to content

Commit decd2d4

Browse files
authored
Merge pull request #415 from 0xced/SelectedPreferredInstance-Fastest
Select the fastest (Kdc/Kpasswd) instance instead of at random
2 parents 95c800c + 3bf4b3d commit decd2d4

2 files changed

Lines changed: 80 additions & 13 deletions

File tree

Kerberos.NET/Client/Transport/KerberosTransportBase.cs

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System;
77
using System.Collections.Generic;
88
using System.Linq;
9+
using System.Net.NetworkInformation;
910
using System.Threading;
1011
using System.Threading.Tasks;
1112
using Kerberos.NET.Asn1;
@@ -19,15 +20,15 @@ namespace Kerberos.NET.Transport
1920
{
2021
public abstract class KerberosTransportBase : IKerberosTransport2, IDisposable
2122
{
22-
private static readonly Random Random = new Random();
23-
2423
protected KerberosTransportBase(ILoggerFactory logger)
2524
{
2625
this.ClientRealmService = new ClientDomainService(logger);
2726
}
2827

2928
private bool disposedValue;
3029

30+
private DnsRecord fastest;
31+
3132
public virtual bool TransportFailed { get; set; }
3233

3334
public virtual KerberosTransportException LastError { get; set; }
@@ -165,34 +166,58 @@ public void Dispose()
165166
protected virtual async Task<DnsRecord> LocatePreferredKdc(string domain, string servicePrefix)
166167
{
167168
var results = await this.LocateKdc(domain, servicePrefix);
168-
return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort);
169+
return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKerberosPort);
169170
}
170171

171172
protected virtual async Task<DnsRecord> LocatePreferredKpasswd(string domain, string servicePrefix)
172173
{
173174
var results = await this.LocateKpasswd(domain, servicePrefix);
174-
return SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort);
175+
return await SelectedPreferredInstance(domain, servicePrefix, results, ClientDomainService.DefaultKpasswdPort);
175176
}
176177

177-
protected virtual DnsRecord SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable<DnsRecord> results, int defaultPort)
178+
protected virtual async Task<DnsRecord> SelectedPreferredInstance(string domain, string servicePrefix, IEnumerable<DnsRecord> results, int defaultPort)
178179
{
179-
results = results.Where(r => r.Name.StartsWith(servicePrefix));
180+
if (results.Contains(fastest, DnsRecordComparer.Instance))
181+
{
182+
return fastest;
183+
}
180184

181-
var rand = Random.Next(0, results?.Count() ?? 0);
185+
fastest = await results.Where(r => r.Name.StartsWith(servicePrefix)).GetFastestAsync(PingAsync);
186+
return fastest ?? throw new KerberosTransportException($"Cannot locate SRV record for {domain}");
187+
}
182188

183-
var srv = results?.ElementAtOrDefault(rand);
189+
private async Task<DnsRecord> PingAsync(DnsRecord record, CancellationToken cancellationToken)
190+
{
191+
using var ping = new Ping();
192+
cancellationToken.Register(() => ping.SendAsyncCancel());
193+
var reply = await ping.SendPingAsync(record.Target, Convert.ToInt32(ConnectTimeout.TotalMilliseconds));
194+
return reply.Status == IPStatus.Success ? record : throw new PingException($"Ping {record.Target} returned {reply.Status}");
195+
}
184196

185-
if (srv == null)
197+
private class DnsRecordComparer : IEqualityComparer<DnsRecord>
198+
{
199+
public static readonly DnsRecordComparer Instance = new();
200+
201+
private DnsRecordComparer()
186202
{
187-
throw new KerberosTransportException($"Cannot locate SRV record for {domain}");
188203
}
189204

190-
if (srv.Port <= 0)
205+
public bool Equals(DnsRecord x, DnsRecord y)
191206
{
192-
srv.Port = defaultPort;
207+
if (ReferenceEquals(x, y)) return true;
208+
if (x is null) return false;
209+
if (y is null) return false;
210+
if (x.GetType() != y.GetType()) return false;
211+
return x.Target == y.Target && x.Port == y.Port;
193212
}
194213

195-
return srv;
214+
public int GetHashCode(DnsRecord obj)
215+
{
216+
unchecked
217+
{
218+
return ((obj.Target != null ? obj.Target.GetHashCode() : 0) * 397) ^ obj.Port;
219+
}
220+
}
196221
}
197222
}
198223
}

Kerberos.NET/TaskExtensions.cs

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// -----------------------------------------------------------------------
2+
// Licensed to The .NET Foundation under one or more agreements.
3+
// The .NET Foundation licenses this file to you under the MIT license.
4+
// -----------------------------------------------------------------------
5+
6+
using System;
7+
using System.Collections.Generic;
8+
using System.Linq;
9+
using System.Threading;
10+
using System.Threading.Tasks;
11+
12+
internal static class TaskExtensions
13+
{
14+
public static async Task<TResult> GetFastestAsync<TSource, TResult>(this IEnumerable<TSource> source, Func<TSource, CancellationToken, Task<TResult>> task, CancellationToken cancellationToken = default)
15+
{
16+
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
17+
var tasks = new HashSet<Task<TResult>>(source.Select(e => task(e, cts.Token)));
18+
if (tasks.Count == 0)
19+
{
20+
return default;
21+
}
22+
23+
var exceptions = new List<Exception>();
24+
do
25+
{
26+
var completedTask = await Task.WhenAny(tasks);
27+
if (completedTask.Status == TaskStatus.RanToCompletion)
28+
{
29+
cts.Cancel();
30+
return completedTask.Result;
31+
}
32+
33+
if (completedTask.Exception != null)
34+
{
35+
exceptions.AddRange(completedTask.Exception.InnerExceptions);
36+
}
37+
tasks.Remove(completedTask);
38+
} while (tasks.Count > 0);
39+
40+
throw new AggregateException(exceptions);
41+
}
42+
}

0 commit comments

Comments
 (0)