Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import io.netty.resolver.dns.DefaultAuthoritativeDnsServerCache;
import io.netty.resolver.dns.DefaultDnsCache;
import io.netty.resolver.dns.DefaultDnsCnameCache;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.resolver.dns.DnsNameResolverBuilder;
import io.netty.resolver.dns.NameServerComparator;
import io.netty.resolver.dns.NoopAuthoritativeDnsServerCache;
Expand Down Expand Up @@ -127,7 +126,7 @@ final class DefaultDnsClient implements DnsClient {
private static final Cancellable TERMINATED = () -> { };

private final EventLoopAwareNettyIoExecutor nettyIoExecutor;
private final DnsNameResolver resolver;
private final UnderlyingDnsResolver resolver;
private final MinTtlCache ttlCache;
private final long maxTTLNanos;
private final long ttlJitterNanos;
Expand Down Expand Up @@ -225,7 +224,16 @@ final class DefaultDnsClient implements DnsClient {
if (dnsServerAddressStreamProvider != null) {
builder.nameServerProvider(toNettyType(dnsServerAddressStreamProvider));
}
resolver = builder.build();
if (true /* hedging enabled */) { // need to wire this in.
DnsNameResolverBuilderUtils.consolidateCacheSize(id, builder, 0);
resolver = new HedgingDnsNameResolver(
// new UnderlyingDnsResolver.NettyDnsNameResolver(builder.build()), nettyIoExecutor);
// TODO: this is just for hacking together tests.
new UnderlyingDnsResolver.NettyDnsNameResolver(builder.build()), nettyIoExecutor,
HedgingDnsNameResolver.constantTracker(100), HedgingDnsNameResolver.alwaysAllowBudget());
} else {
resolver = new UnderlyingDnsResolver.NettyDnsNameResolver(builder.build());
}
this.resolutionTimeoutMillis = resolutionTimeout != null ? resolutionTimeout.toMillis() :
// Default value is chosen based on a combination of default "timeout" and "attempts" options of
// /etc/resolv.conf: https://man7.org/linux/man-pages/man5/resolv.conf.5.html
Expand Down Expand Up @@ -435,7 +443,7 @@ protected Future<DnsAnswer<HostAndPort>> doDnsQuery(final boolean scheduledQuery
final EventLoop eventLoop = nettyIoExecutor.eventLoopGroup().next();
final Promise<DnsAnswer<HostAndPort>> promise = eventLoop.newPromise();
final Future<List<DnsRecord>> resolveFuture =
resolver.resolveAll(new DefaultDnsQuestion(name, SRV));
resolver.resolveAllQuestion(new DefaultDnsQuestion(name, SRV));
final Future<?> timeoutFuture = resolutionTimeoutMillis == 0L ? null : eventLoop.schedule(() -> {
if (!promise.isDone() && promise.tryFailure(DnsNameResolverTimeoutException.newInstance(
name, resolutionTimeoutMillis, SRV.toString(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
/*
* Copyright © 2024 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.dns.discovery.netty;

import io.servicetalk.concurrent.Cancellable;

import io.netty.handler.codec.dns.DnsQuestion;
import io.netty.handler.codec.dns.DnsRecord;
import io.netty.resolver.dns.DnsNameResolver;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.servicetalk.transport.api.IoExecutor;
import io.servicetalk.transport.netty.internal.EventLoopAwareNettyIoExecutor;

import java.net.InetAddress;
import java.util.List;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import static io.servicetalk.transport.netty.internal.EventLoopAwareNettyIoExecutors.toEventLoopAwareNettyIoExecutor;
import static io.servicetalk.utils.internal.NumberUtils.ensurePositive;
import static java.lang.Math.max;
import static java.lang.Math.min;

final class HedgingDnsNameResolver implements UnderlyingDnsResolver {

private final UnderlyingDnsResolver delegate;
private final EventLoopAwareNettyIoExecutor executor;
private final PercentileTracker percentile;
private final Budget budget;

HedgingDnsNameResolver(DnsNameResolver delegate, IoExecutor executor) {
this(new NettyDnsNameResolver(delegate), executor);
}

HedgingDnsNameResolver(UnderlyingDnsResolver delegate, IoExecutor executor) {
this(delegate, executor, defaultTracker(), defaultBudget());
}

HedgingDnsNameResolver(UnderlyingDnsResolver delegate, IoExecutor executor,
PercentileTracker percentile, Budget budget) {
this.delegate = delegate;
this.executor = toEventLoopAwareNettyIoExecutor(executor).next();
this.percentile = percentile;
this.budget = budget;
}

@Override
public long queryTimeoutMillis() {
return delegate.queryTimeoutMillis();
}

@Override
public Future<List<DnsRecord>> resolveAllQuestion(DnsQuestion t) {
return setupHedge(delegate::resolveAllQuestion, t);
}

@Override
public Future<List<InetAddress>> resolveAll(String t) {
return setupHedge(delegate::resolveAll, t);
}

@Override
public void close() {
delegate.close();
}

private long currentTimeMillis() {
return executor.currentTime(TimeUnit.MILLISECONDS);
}

private <T, R> Future<R> setupHedge(Function<T, Future<R>> computation, T t) {
// Only add tokens for organic requests and not retries.
budget.deposit();
Future<R> underlyingResult = computation.apply(t);
final long delay = percentile.getValue();
if (delay == Long.MAX_VALUE) {
// basically forever: just return the value.
return underlyingResult;
} else {
final long startTimeMs = currentTimeMillis();
Promise<R> promise = executor.eventLoopGroup().next().newPromise();
Cancellable hedgeTimer = executor.schedule(() -> tryHedge(computation, t, underlyingResult, promise),
delay, TimeUnit.MILLISECONDS);
underlyingResult.addListener(completedFuture -> {
measureRequest(currentTimeMillis() - startTimeMs, completedFuture);
if (complete(underlyingResult, promise)) {
hedgeTimer.cancel();
}
});
return promise;
}
}

private <T, R> void tryHedge(
Function<T, Future<R>> computation, T t, Future<R> original, Promise<R> promise) {
if (!original.isDone() && budget.withdraw()) {
System.out.println("" + System.currentTimeMillis() + ": sending backup request.");
Future<R> backupResult = computation.apply(t);
final long startTime = currentTimeMillis();
backupResult.addListener(done -> {
if (complete(backupResult, promise)) {
original.cancel(true);
measureRequest(currentTimeMillis() - startTime, done);
}
});
promise.addListener(complete -> backupResult.cancel(true));
}
}

private void measureRequest(long durationMs, Future<?> future) {
// Cancelled responses don't count but we do consider failed responses because failure
// is a legitimate response.
if (!future.isCancelled()) {
percentile.addSample(durationMs);
}
}

private <T, R> boolean complete(Future<R> f, Promise<R> p) {
assert f.isDone();
if (f.isSuccess()) {
return p.trySuccess(f.getNow());
} else {
return p.tryFailure(f.cause());
}
}

interface PercentileTracker {
void addSample(long sample);

long getValue();
}

interface Budget {
void deposit();

boolean withdraw();
}

private static final class DefaultPercentileTracker implements PercentileTracker {

private final MovingVariance movingVariance;
private final double multiple;

DefaultPercentileTracker(final double multiple, final int historySize) {
movingVariance = new MovingVariance(historySize);
this.multiple = multiple;
}

@Override
public void addSample(long sample) {
int clipped = Math.max(0, sample > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) sample);
movingVariance.addSample(clipped);
}

@Override
public long getValue() {
return Math.round(movingVariance.mean() + movingVariance.stdev() * multiple);
}
}

private static final class DefaultBudgetImpl implements Budget {

private final int depositAmount;
private final int withDrawAmount;
private final int maxTokens;
private int tokens;

DefaultBudgetImpl(int depositAmount, int withDrawAmount, int maxTokens) {
this(depositAmount, withDrawAmount, maxTokens, 0);
}

DefaultBudgetImpl(int depositAmount, int withDrawAmount, int maxTokens, int initialTokens) {
this.depositAmount = depositAmount;
this.withDrawAmount = withDrawAmount;
this.maxTokens = maxTokens;
this.tokens = initialTokens;
}

@Override
public void deposit() {
tokens = max(maxTokens, tokens + depositAmount);
}

@Override
public boolean withdraw() {
if (tokens < withDrawAmount) {
return false;
} else {
tokens -= withDrawAmount;
return true;
}
}
}

private static PercentileTracker defaultTracker() {
return new DefaultPercentileTracker(3.0, 256);
}

private static Budget defaultBudget() {
// 5% extra load and a max burst of 5 hedges.
return new DefaultBudgetImpl(1, 20, 100);
}

static PercentileTracker constantTracker(int value) {
return new PercentileTracker() {
@Override
public void addSample(long sample) {
// noop
}

@Override
public long getValue() {
return value;
}
};
}

static Budget alwaysAllowBudget() {
return new Budget() {
@Override
public void deposit() {
// noop
}

@Override
public boolean withdraw() {
return true;
}
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package io.servicetalk.dns.discovery.netty;

import java.util.Arrays;

final class MovingVariance {

private final int size;
private final double invSize;
private final double invSizeBySizeMinus1;

// We initialize with the assumption that all previous sames were zero. This lets us know the sum (x[i]) == 0
// and that Var(x[n]) == 0. However, that means that until `size` samples the variance will be artificially low.
private final int[] xi;
private int ii;
private long sumXi;
private long varianceSizeSizeMinus1;

MovingVariance(final int size) {
this(size, Integer.MAX_VALUE);
}

MovingVariance(final int size, final int initialMean) {
if (size < 2) {
throw new IllegalArgumentException("Must allow at least two samples");
}
this.size = size;
this.invSize = 1.0 / size;
this.invSizeBySizeMinus1 = 1.0 / (size * (size - 1));
this.xi = new int[size];
Arrays.fill(xi, initialMean);
sumXi = ((long) initialMean) * size;
}

public double mean() {
return sumXi * invSize;
}

public double variance() {
return varianceSizeSizeMinus1 * invSizeBySizeMinus1;
}

public double stdev() {
return Math.sqrt(variance());
}

public void addSample(int sample) {
// Widen sample to a long so that we don't have to worry about overflows.
final long xn = sample;
final int i = getIndex();
final long x0 = xi[i];
xi[i] = sample;
final long oldSumXi = sumXi;
sumXi += xn - x0;
varianceSizeSizeMinus1 = varianceSizeSizeMinus1 + (size * (xn + x0) - sumXi - oldSumXi) * (xn - x0);
}

private int getIndex() {
final int result = ii;
if (++ii == size) {
ii = 0;
}
return result;
}
}
Loading