View Javadoc
1   package com.github.davidmoten.aws.lw.client.internal;
2   
3   import java.io.IOException;
4   import java.io.UncheckedIOException;
5   import java.util.concurrent.Callable;
6   import java.util.function.Predicate;
7   
8   import com.github.davidmoten.aws.lw.client.MaxAttemptsExceededException;
9   import com.github.davidmoten.aws.lw.client.internal.util.Preconditions;
10  
11  public final class Retries<T> {
12  
13      private final long initialIntervalMs;
14      private final int maxAttempts;
15      private final double backoffFactor;
16      private final long maxIntervalMs;
17      private final double jitter;
18      private final Predicate<? super T> valueShouldRetry;
19      private final Predicate<? super Throwable> throwableShouldRetry;
20  
21      public Retries(long initialIntervalMs, int maxAttempts, double backoffFactor, double jitter, long maxIntervalMs,
22              Predicate<? super T> valueShouldRetry, Predicate<? super Throwable> throwableShouldRetry) {
23          Preconditions.checkArgument(jitter >= 0 && jitter <= 1, "jitter must be between 0 and 1 inclusive");
24          this.initialIntervalMs = initialIntervalMs;
25          this.maxAttempts = maxAttempts;
26          this.backoffFactor = backoffFactor;
27          this.jitter = jitter;
28          this.maxIntervalMs = maxIntervalMs;
29          this.valueShouldRetry = valueShouldRetry;
30          this.throwableShouldRetry = throwableShouldRetry;
31      }
32  
33      public static <T> Retries<T> create(Predicate<? super T> valueShouldRetry,
34              Predicate<? super Throwable> throwableShouldRetry) {
35          return new Retries<T>( //
36                  100, //
37                  4, //
38                  2.0, //
39                  0.0, // no jitter
40                  20000, //
41                  valueShouldRetry, //
42                  throwableShouldRetry);
43      }
44  
45      public T call(Callable<T> callable) {
46          return call(callable, valueShouldRetry);
47      }
48  
49      public <S> S call(Callable<S> callable, Predicate<? super S> valueShouldRetry) {
50          long intervalMs = initialIntervalMs;
51          int attempt = 0;
52          while (true) {
53              S value;
54              try {
55                  attempt++;
56                  value = callable.call();
57                  if (!valueShouldRetry.test(value)) {
58                      return value;
59                  }
60                  if (reachedMaxAttempts(attempt, maxAttempts)) {
61                      // note that caller is not aware that maxAttempts were reached, the caller just
62                      // receives the last error response
63                      return value;
64                  }
65              } catch (Throwable t) {
66                  if (!throwableShouldRetry.test(t)) {
67                      rethrow(t);
68                  }
69                  if (reachedMaxAttempts(attempt, maxAttempts)) {
70                      throw new MaxAttemptsExceededException("exceeded max attempts " + maxAttempts, t);
71                  }
72              }
73              sleep(intervalMs);
74              //calculate the interval for the next retry
75              intervalMs = Math.round(backoffFactor * intervalMs);
76              if (maxIntervalMs > 0) {
77                  intervalMs = Math.min(maxIntervalMs, intervalMs);
78              }
79              // apply jitter (if 0 then no change)
80              intervalMs = Math.round((1 - jitter * Math.random()) * intervalMs);
81          }
82      }
83  
84      // VisibleForTesting
85      static boolean reachedMaxAttempts(int attempt, int maxAttempts) {
86          return maxAttempts > 0 && attempt >= maxAttempts;
87      }
88  
89      static void sleep(long intervalMs) {
90          try {
91              Thread.sleep(intervalMs);
92          } catch (InterruptedException e) {
93              throw new RuntimeException(e);
94          }
95      }
96  
97      public <S> Retries<S> withValueShouldRetry(Predicate<? super S> valueShouldRetry) {
98          return new Retries<S>(initialIntervalMs, maxAttempts, backoffFactor, jitter, maxIntervalMs, valueShouldRetry,
99                  throwableShouldRetry);
100     }
101 
102     public Retries<T> withInitialIntervalMs(long initialIntervalMs) {
103         return new Retries<T>(initialIntervalMs, maxAttempts, backoffFactor, jitter, maxIntervalMs, valueShouldRetry,
104                 throwableShouldRetry);
105     }
106 
107     public Retries<T> withMaxAttempts(int maxAttempts) {
108         return new Retries<T>(initialIntervalMs, maxAttempts, backoffFactor, jitter, maxIntervalMs, valueShouldRetry,
109                 throwableShouldRetry);
110     }
111 
112     public Retries<T> withBackoffFactor(double backoffFactor) {
113         return new Retries<T>(initialIntervalMs, maxAttempts, backoffFactor, jitter, maxIntervalMs, valueShouldRetry,
114                 throwableShouldRetry);
115     }
116 
117     public Retries<T> withMaxIntervalMs(long maxIntervalMs) {
118         return new Retries<T>(initialIntervalMs, maxAttempts, backoffFactor, jitter, maxIntervalMs, valueShouldRetry,
119                 throwableShouldRetry);
120     }
121     
122     public Retries<T> withJitter(double jitter) {
123         return new Retries<T>(initialIntervalMs, maxAttempts, backoffFactor, jitter, maxIntervalMs, valueShouldRetry,
124                 throwableShouldRetry);
125     }
126 
127     public Retries<T> withThrowableShouldRetry(Predicate<? super Throwable> throwableShouldRetry) {
128         return new Retries<T>(initialIntervalMs, maxAttempts, backoffFactor, jitter, maxIntervalMs, valueShouldRetry,
129                 throwableShouldRetry);
130     }
131 
132     public Retries<T> copy() {
133         return new Retries<>(initialIntervalMs, maxAttempts, backoffFactor, jitter, maxIntervalMs, valueShouldRetry,
134                 throwableShouldRetry);
135     }
136 
137     // VisibleForTesting
138     static void rethrow(Throwable t) throws Error {
139         if (t instanceof RuntimeException) {
140             throw (RuntimeException) t;
141         } else if (t instanceof Error) {
142             throw (Error) t;
143         } else if (t instanceof IOException) {
144             throw new UncheckedIOException((IOException) t);
145         } else {
146             throw new RuntimeException(t);
147         }
148     }
149 
150 }