diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsClient.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsClient.java index b909dc4bf0..fcaa9f100c 100644 --- a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsClient.java +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/DefaultDnsClient.java @@ -47,7 +47,6 @@ import io.netty.resolver.dns.DefaultAuthoritativeDnsServerCache; import io.netty.resolver.dns.DefaultDnsCache; import io.netty.resolver.dns.DefaultDnsCnameCache; -import io.netty.resolver.dns.DnsNameResolver; import io.netty.resolver.dns.DnsNameResolverBuilder; import io.netty.resolver.dns.NameServerComparator; import io.netty.resolver.dns.NoopAuthoritativeDnsServerCache; @@ -127,7 +126,7 @@ final class DefaultDnsClient implements DnsClient { private static final Cancellable TERMINATED = () -> { }; private final EventLoopAwareNettyIoExecutor nettyIoExecutor; - private final DnsNameResolver resolver; + private final UnderlyingDnsResolver resolver; private final MinTtlCache ttlCache; private final long maxTTLNanos; private final long ttlJitterNanos; @@ -225,7 +224,16 @@ final class DefaultDnsClient implements DnsClient { if (dnsServerAddressStreamProvider != null) { builder.nameServerProvider(toNettyType(dnsServerAddressStreamProvider)); } - resolver = builder.build(); + if (true /* hedging enabled */) { // need to wire this in. + DnsNameResolverBuilderUtils.consolidateCacheSize(id, builder, 0); + resolver = new HedgingDnsNameResolver( +// new UnderlyingDnsResolver.NettyDnsNameResolver(builder.build()), nettyIoExecutor); + // TODO: this is just for hacking together tests. + new UnderlyingDnsResolver.NettyDnsNameResolver(builder.build()), nettyIoExecutor, + HedgingDnsNameResolver.constantTracker(100), HedgingDnsNameResolver.alwaysAllowBudget()); + } else { + resolver = new UnderlyingDnsResolver.NettyDnsNameResolver(builder.build()); + } this.resolutionTimeoutMillis = resolutionTimeout != null ? resolutionTimeout.toMillis() : // Default value is chosen based on a combination of default "timeout" and "attempts" options of // /etc/resolv.conf: https://man7.org/linux/man-pages/man5/resolv.conf.5.html @@ -435,7 +443,7 @@ protected Future> doDnsQuery(final boolean scheduledQuery final EventLoop eventLoop = nettyIoExecutor.eventLoopGroup().next(); final Promise> promise = eventLoop.newPromise(); final Future> resolveFuture = - resolver.resolveAll(new DefaultDnsQuestion(name, SRV)); + resolver.resolveAllQuestion(new DefaultDnsQuestion(name, SRV)); final Future timeoutFuture = resolutionTimeoutMillis == 0L ? null : eventLoop.schedule(() -> { if (!promise.isDone() && promise.tryFailure(DnsNameResolverTimeoutException.newInstance( name, resolutionTimeoutMillis, SRV.toString(), diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/HedgingDnsNameResolver.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/HedgingDnsNameResolver.java new file mode 100644 index 0000000000..c1e0a24e86 --- /dev/null +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/HedgingDnsNameResolver.java @@ -0,0 +1,245 @@ +/* + * Copyright © 2024 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.dns.discovery.netty; + +import io.servicetalk.concurrent.Cancellable; + +import io.netty.handler.codec.dns.DnsQuestion; +import io.netty.handler.codec.dns.DnsRecord; +import io.netty.resolver.dns.DnsNameResolver; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.servicetalk.transport.api.IoExecutor; +import io.servicetalk.transport.netty.internal.EventLoopAwareNettyIoExecutor; + +import java.net.InetAddress; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import static io.servicetalk.transport.netty.internal.EventLoopAwareNettyIoExecutors.toEventLoopAwareNettyIoExecutor; +import static io.servicetalk.utils.internal.NumberUtils.ensurePositive; +import static java.lang.Math.max; +import static java.lang.Math.min; + +final class HedgingDnsNameResolver implements UnderlyingDnsResolver { + + private final UnderlyingDnsResolver delegate; + private final EventLoopAwareNettyIoExecutor executor; + private final PercentileTracker percentile; + private final Budget budget; + + HedgingDnsNameResolver(DnsNameResolver delegate, IoExecutor executor) { + this(new NettyDnsNameResolver(delegate), executor); + } + + HedgingDnsNameResolver(UnderlyingDnsResolver delegate, IoExecutor executor) { + this(delegate, executor, defaultTracker(), defaultBudget()); + } + + HedgingDnsNameResolver(UnderlyingDnsResolver delegate, IoExecutor executor, + PercentileTracker percentile, Budget budget) { + this.delegate = delegate; + this.executor = toEventLoopAwareNettyIoExecutor(executor).next(); + this.percentile = percentile; + this.budget = budget; + } + + @Override + public long queryTimeoutMillis() { + return delegate.queryTimeoutMillis(); + } + + @Override + public Future> resolveAllQuestion(DnsQuestion t) { + return setupHedge(delegate::resolveAllQuestion, t); + } + + @Override + public Future> resolveAll(String t) { + return setupHedge(delegate::resolveAll, t); + } + + @Override + public void close() { + delegate.close(); + } + + private long currentTimeMillis() { + return executor.currentTime(TimeUnit.MILLISECONDS); + } + + private Future setupHedge(Function> computation, T t) { + // Only add tokens for organic requests and not retries. + budget.deposit(); + Future underlyingResult = computation.apply(t); + final long delay = percentile.getValue(); + if (delay == Long.MAX_VALUE) { + // basically forever: just return the value. + return underlyingResult; + } else { + final long startTimeMs = currentTimeMillis(); + Promise promise = executor.eventLoopGroup().next().newPromise(); + Cancellable hedgeTimer = executor.schedule(() -> tryHedge(computation, t, underlyingResult, promise), + delay, TimeUnit.MILLISECONDS); + underlyingResult.addListener(completedFuture -> { + measureRequest(currentTimeMillis() - startTimeMs, completedFuture); + if (complete(underlyingResult, promise)) { + hedgeTimer.cancel(); + } + }); + return promise; + } + } + + private void tryHedge( + Function> computation, T t, Future original, Promise promise) { + if (!original.isDone() && budget.withdraw()) { + System.out.println("" + System.currentTimeMillis() + ": sending backup request."); + Future backupResult = computation.apply(t); + final long startTime = currentTimeMillis(); + backupResult.addListener(done -> { + if (complete(backupResult, promise)) { + original.cancel(true); + measureRequest(currentTimeMillis() - startTime, done); + } + }); + promise.addListener(complete -> backupResult.cancel(true)); + } + } + + private void measureRequest(long durationMs, Future future) { + // Cancelled responses don't count but we do consider failed responses because failure + // is a legitimate response. + if (!future.isCancelled()) { + percentile.addSample(durationMs); + } + } + + private boolean complete(Future f, Promise p) { + assert f.isDone(); + if (f.isSuccess()) { + return p.trySuccess(f.getNow()); + } else { + return p.tryFailure(f.cause()); + } + } + + interface PercentileTracker { + void addSample(long sample); + + long getValue(); + } + + interface Budget { + void deposit(); + + boolean withdraw(); + } + + private static final class DefaultPercentileTracker implements PercentileTracker { + + private final MovingVariance movingVariance; + private final double multiple; + + DefaultPercentileTracker(final double multiple, final int historySize) { + movingVariance = new MovingVariance(historySize); + this.multiple = multiple; + } + + @Override + public void addSample(long sample) { + int clipped = Math.max(0, sample > Integer.MAX_VALUE ? Integer.MAX_VALUE : (int) sample); + movingVariance.addSample(clipped); + } + + @Override + public long getValue() { + return Math.round(movingVariance.mean() + movingVariance.stdev() * multiple); + } + } + + private static final class DefaultBudgetImpl implements Budget { + + private final int depositAmount; + private final int withDrawAmount; + private final int maxTokens; + private int tokens; + + DefaultBudgetImpl(int depositAmount, int withDrawAmount, int maxTokens) { + this(depositAmount, withDrawAmount, maxTokens, 0); + } + + DefaultBudgetImpl(int depositAmount, int withDrawAmount, int maxTokens, int initialTokens) { + this.depositAmount = depositAmount; + this.withDrawAmount = withDrawAmount; + this.maxTokens = maxTokens; + this.tokens = initialTokens; + } + + @Override + public void deposit() { + tokens = max(maxTokens, tokens + depositAmount); + } + + @Override + public boolean withdraw() { + if (tokens < withDrawAmount) { + return false; + } else { + tokens -= withDrawAmount; + return true; + } + } + } + + private static PercentileTracker defaultTracker() { + return new DefaultPercentileTracker(3.0, 256); + } + + private static Budget defaultBudget() { + // 5% extra load and a max burst of 5 hedges. + return new DefaultBudgetImpl(1, 20, 100); + } + + static PercentileTracker constantTracker(int value) { + return new PercentileTracker() { + @Override + public void addSample(long sample) { + // noop + } + + @Override + public long getValue() { + return value; + } + }; + } + + static Budget alwaysAllowBudget() { + return new Budget() { + @Override + public void deposit() { + // noop + } + + @Override + public boolean withdraw() { + return true; + } + }; + } +} diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/MovingVariance.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/MovingVariance.java new file mode 100644 index 0000000000..8f4e7527a2 --- /dev/null +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/MovingVariance.java @@ -0,0 +1,64 @@ +package io.servicetalk.dns.discovery.netty; + +import java.util.Arrays; + +final class MovingVariance { + + private final int size; + private final double invSize; + private final double invSizeBySizeMinus1; + + // We initialize with the assumption that all previous sames were zero. This lets us know the sum (x[i]) == 0 + // and that Var(x[n]) == 0. However, that means that until `size` samples the variance will be artificially low. + private final int[] xi; + private int ii; + private long sumXi; + private long varianceSizeSizeMinus1; + + MovingVariance(final int size) { + this(size, Integer.MAX_VALUE); + } + + MovingVariance(final int size, final int initialMean) { + if (size < 2) { + throw new IllegalArgumentException("Must allow at least two samples"); + } + this.size = size; + this.invSize = 1.0 / size; + this.invSizeBySizeMinus1 = 1.0 / (size * (size - 1)); + this.xi = new int[size]; + Arrays.fill(xi, initialMean); + sumXi = ((long) initialMean) * size; + } + + public double mean() { + return sumXi * invSize; + } + + public double variance() { + return varianceSizeSizeMinus1 * invSizeBySizeMinus1; + } + + public double stdev() { + return Math.sqrt(variance()); + } + + public void addSample(int sample) { + // Widen sample to a long so that we don't have to worry about overflows. + final long xn = sample; + final int i = getIndex(); + final long x0 = xi[i]; + xi[i] = sample; + final long oldSumXi = sumXi; + sumXi += xn - x0; + varianceSizeSizeMinus1 = varianceSizeSizeMinus1 + (size * (xn + x0) - sumXi - oldSumXi) * (xn - x0); + } + + private int getIndex() { + final int result = ii; + if (++ii == size) { + ii = 0; + } + return result; + } +} diff --git a/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/UnderlyingDnsResolver.java b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/UnderlyingDnsResolver.java new file mode 100644 index 0000000000..d101c40cb6 --- /dev/null +++ b/servicetalk-dns-discovery-netty/src/main/java/io/servicetalk/dns/discovery/netty/UnderlyingDnsResolver.java @@ -0,0 +1,50 @@ +package io.servicetalk.dns.discovery.netty; + +import io.netty.handler.codec.dns.DnsQuestion; +import io.netty.handler.codec.dns.DnsRecord; +import io.netty.resolver.dns.DnsNameResolver; +import io.netty.util.concurrent.Future; + +import java.io.Closeable; +import java.net.InetAddress; +import java.util.List; + +interface UnderlyingDnsResolver extends Closeable { + + Future> resolveAllQuestion(DnsQuestion t); + + Future> resolveAll(String t); + + long queryTimeoutMillis(); + + @Override + void close(); + + static final class NettyDnsNameResolver implements UnderlyingDnsResolver { + private final DnsNameResolver resolver; + + NettyDnsNameResolver(final DnsNameResolver resolver) { + this.resolver = resolver; + } + + @Override + public Future> resolveAllQuestion(DnsQuestion t) { + return resolver.resolveAll(t); + } + + @Override + public Future> resolveAll(String t) { + return resolver.resolveAll(t); + } + + @Override + public long queryTimeoutMillis() { + return resolver.queryTimeoutMillis(); + } + + @Override + public void close() { + resolver.close(); + } + } +} diff --git a/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/DefaultDnsClientTest.java b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/DefaultDnsClientTest.java index 06888c0d3c..5ec408920d 100644 --- a/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/DefaultDnsClientTest.java +++ b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/DefaultDnsClientTest.java @@ -44,6 +44,9 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; +import java.net.DatagramPacket; +import java.net.DatagramSocket; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; @@ -175,6 +178,50 @@ static Stream missingRecordStatus() { return Stream.of(ServiceDiscovererEvent.Status.EXPIRED, ServiceDiscovererEvent.Status.UNAVAILABLE); } + @Test + void hedging() throws Exception { + // should be bound now. + DatagramSocket datagramSocket = new DatagramSocket(new InetSocketAddress("127.0.0.1", 5657)); + Thread t = new Thread(() -> { + while (true) { + byte[] buf = new byte[2048]; + DatagramPacket packet = new DatagramPacket(buf, buf.length); + try { + datagramSocket.receive(packet); + } catch (IOException ex) { + System.out.println("Exception: " + ex); + return; + } + String packetStr = new String(buf, 0, packet.getLength()); + System.out.println("" + System.currentTimeMillis() + ": Received packet- " + packetStr); + packet.getLength(); + } + }); + t.start(); + + setup(builder -> builder.dnsServerAddressStreamProvider( + new SequentialDnsServerAddressStreamProvider((InetSocketAddress) datagramSocket.getLocalSocketAddress(), dnsServer.localAddress()))); +// setup(); + + final String targetDomain1 = "sd.domain.com"; + final String ip1 = nextIp(); + + recordStore.addIPv4Address(targetDomain1, DEFAULT_TTL, ip1); + CountDownLatch latch = new CountDownLatch(1); + recordStore.addStall(targetDomain1, latch); + + TestPublisherSubscriber> subscriber = dnsQuery(targetDomain1); + Subscription subscription = subscriber.awaitSubscription(); + subscription.request(Long.MAX_VALUE); + + Thread.sleep(100); // just add an actual delay so our println messages don't stack atop one another. + advanceTime(); + + List> signals = subscriber.takeOnNext(1); + assertHasEvent(signals, ip1, AVAILABLE); + } + + @Test void singleSrvSingleADiscover() throws Exception { setup(); @@ -1378,7 +1425,7 @@ private static void assertHasEvent(Collection timerExecutor = ExecutorExtension.withTestExecutor() + .setClassLevel(true); + + @RegisterExtension + static final ExecutorExtension ioExecutor = ExecutorExtension + .withExecutor(() -> createIoExecutor(1)) + .setClassLevel(true); + + HedgingDnsNameResolver.PercentileTracker percentileTracker; + HedgingDnsNameResolver.Budget budget; + + UnderlyingDnsResolver underlying; + HedgingDnsNameResolver resolver; + + void setup() { + if (percentileTracker == null) { + percentileTracker = mock(HedgingDnsNameResolver.PercentileTracker.class); + when(percentileTracker.getValue()).thenReturn(10L); + } + if (budget == null) { + budget = mock(HedgingDnsNameResolver.Budget.class); + when(budget.withdraw()).thenReturn(true); + } + underlying = mock(UnderlyingDnsResolver.class); + // DnsResolverIface delegate, Executor executor, EventLoop eventLoop, + // PercentileTracker percentile, Budget budget + resolver = new HedgingDnsNameResolver(underlying, + new DefaultDnsClientTest.NettyIoExecutorWithTestTimer(ioExecutor.executor(), timerExecutor.executor()), + percentileTracker, budget); + } + + @Test + void requestThatDoesntNeedHedge() throws Exception { + setup(); + Promise> p1 = newPromise(); + when(underlying.resolveAll(any())).thenReturn(p1, null); + Future> results = resolver.resolveAll("apple.com"); + assertThat(results.isDone(), equalTo(false)); + advanceTime(1); + List result = new ArrayList<>(); + p1.trySuccess(result); + assertThat(results.get(), equalTo(result)); + verify(budget).deposit(); + verify(budget, never()).withdraw(); + verify(percentileTracker).addSample(1); + } + + @Test + void requestWithHedgingAndFirstWins() throws Exception { + setup(); + Promise> p1 = newPromise(); + Promise> p2 = newPromise(); + when(underlying.resolveAll(any())).thenReturn(p1, p2); + Future> results = resolver.resolveAll("apple.com"); + assertThat(results.isDone(), equalTo(false)); + advanceTime(10); + List result = new ArrayList<>(); + p1.trySuccess(result); + assertThat(results.get(), equalTo(result)); + verify(budget).deposit(); + verify(budget, times(1)).withdraw(); + verify(percentileTracker).addSample(10); + assertThat(p2.isCancelled(), equalTo(true)); + } + + @Test + void requestWithHedgingAndSecondWins() throws Exception { + setup(); + Promise> p1 = newPromise(); + Promise> p2 = newPromise(); + when(underlying.resolveAll(any())).thenReturn(p1, p2); + + Future> results = resolver.resolveAll("apple.com"); + assertThat(results.isDone(), equalTo(false)); + advanceTime(10); + + // Hedging should have started and the new timer should be set. + advanceTime(5); + List result = new ArrayList<>(); + p2.trySuccess(result); + assertThat(results.get(), equalTo(result)); + verify(budget).deposit(); + verify(budget, times(1)).withdraw(); + verify(percentileTracker).addSample(5); // only add the successful sample. + assertThat(p1.isCancelled(), equalTo(true)); + } + + @Test + void requestWhenNoHedgingBudget() throws Exception { + setup(); + when(budget.withdraw()).thenReturn(false); + Promise> p1 = newPromise(); + when(underlying.resolveAll(any())).thenReturn(p1); + + Future> results = resolver.resolveAll("apple.com"); + assertThat(results.isDone(), equalTo(false)); + advanceTime(10); + + verify(budget).deposit(); + verify(budget, times(1)).withdraw(); + verify(underlying, times(1)).resolveAll("apple.com"); + } + + private Promise newPromise() { + return ioExecutor.executor().eventLoopGroup().next().newPromise(); + } + + private static void advanceTime(int advance) throws Exception { + // To make sure that the time is advanced after all prior work on the EvenLoop is complete, we advance it from + // the EventLoop too. + ioExecutor.executor().submit(() -> { + LOGGER.debug("Advance time by {}s.", advance); + timerExecutor.executor().advanceTimeBy(advance, MILLISECONDS); + }).toFuture().get(); + } +} diff --git a/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/MovingVarianceTest.java b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/MovingVarianceTest.java new file mode 100644 index 0000000000..b574aa1c75 --- /dev/null +++ b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/MovingVarianceTest.java @@ -0,0 +1,55 @@ +package io.servicetalk.dns.discovery.netty; + +import org.junit.jupiter.api.RepeatedTest; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.ThreadLocalRandom; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.closeTo; +import static org.hamcrest.Matchers.equalTo; + +class MovingVarianceTest { + + @Test + void twoSamples() { + MovingVariance m = new MovingVariance(2); + m.addSample(1); + m.addSample(1); + assertThat(m.variance(), equalTo(0.0)); + } + + @RepeatedTest(1000) + void jquikish() { + ThreadLocalRandom r = ThreadLocalRandom.current(); + final int size = r.nextInt(2, 100); + MovingVariance m = new MovingVariance(size); + int[] samples = new int[size]; + for (int i = 0; i < size; i++) { + int xi = r.nextInt(-100, 100); + samples[i] = xi; + m.addSample(xi); + } + + assertThat(m.variance(), closeTo(variance(samples), 0.0001)); + assertThat(m.mean(), closeTo(mean(samples), 0.0001)); + } + + private double variance(int[] values) { + final double mean = mean(values); + double accumulator = 0; + for (int value : values) { + double diff = value - mean; + accumulator += diff * diff; + } + return accumulator / (values.length - 1); + } + + double mean(int[] values) { + long sum = 0; + for (double v : values) { + sum += v; + } + return ((double) sum) / values.length; + } +} diff --git a/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/TestRecordStore.java b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/TestRecordStore.java index 50b797ee15..a2a424401e 100644 --- a/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/TestRecordStore.java +++ b/servicetalk-dns-discovery-netty/src/test/java/io/servicetalk/dns/discovery/netty/TestRecordStore.java @@ -50,6 +50,7 @@ final class TestRecordStore implements RecordStore { private static final int SRV_DEFAULT_PRIORITY = 10; private final Set failSet = new HashSet<>(); + private final Map stalledRecords = new HashMap<>(); private final Map timeouts = new ConcurrentHashMap<>(); private final Map>> recordsToReturnByDomain = new ConcurrentHashMap<>(); @@ -93,6 +94,10 @@ public int hashCode() { } } + public synchronized void addStall(final String dnsRecordName, CountDownLatch latch) { + stalledRecords.put(dnsRecordName, latch); + } + public synchronized void addFail(final ServFail fail) { failSet.add(fail); } @@ -245,10 +250,30 @@ public synchronized Set getRecords(final QuestionRecord question } } final String domain = questionRecord.getDomainName(); - if (failSet.contains(ServFail.of(questionRecord))) { - throw new DnsException(SERVER_FAILURE); + + // TODO: the blocking doesn't work as expected because we can't get any more messages through for the + // backup request. + final CountDownLatch latch; + synchronized (this) { + latch = stalledRecords.remove(domain); + } + if (latch != null) { + try { + latch.await(); + } catch (InterruptedException cause) { + DnsException ex = new DnsException(SERVER_FAILURE); + ex.initCause(cause); + throw ex; + } } - final Map> recordsToReturn = recordsToReturnByDomain.get(domain); + final Map> recordsToReturn; + synchronized (this) { + if (failSet.contains(ServFail.of(questionRecord))) { + throw new DnsException(SERVER_FAILURE); + } + recordsToReturn = recordsToReturnByDomain.get(domain); + } + LOGGER.debug("Getting {} records for {}", questionRecord.getRecordType(), domain); if (recordsToReturn != null) { final List recordsForType = recordsToReturn.get(questionRecord.getRecordType());