1 package com.github.davidmoten.rx;
2
3 import static com.github.davidmoten.util.Optional.absent;
4 import static com.github.davidmoten.util.Optional.of;
5 import static rx.Observable.just;
6
7 import java.util.ArrayList;
8 import java.util.Arrays;
9 import java.util.List;
10 import java.util.concurrent.TimeUnit;
11
12 import com.github.davidmoten.util.Optional;
13 import com.github.davidmoten.util.Preconditions;
14
15 import rx.Observable;
16 import rx.Scheduler;
17 import rx.functions.Action1;
18 import rx.functions.Func1;
19 import rx.functions.Func2;
20 import rx.schedulers.Schedulers;
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38 public final class RetryWhen {
39
40 private static final long NO_MORE_DELAYS = -1;
41
42 private static Func1<Observable<? extends Throwable>, Observable<?>> notificationHandler(
43 final Observable<Long> delays, final Scheduler scheduler, final Action1<? super ErrorAndDuration> action,
44 final List<Class<? extends Throwable>> retryExceptions,
45 final List<Class<? extends Throwable>> failExceptions,
46 final Func1<? super Throwable, Boolean> exceptionPredicate) {
47
48 final Func1<ErrorAndDuration, Observable<ErrorAndDuration>> checkExceptions = createExceptionChecker(
49 retryExceptions, failExceptions, exceptionPredicate);
50
51 return createNotificationHandler(delays, scheduler, action, checkExceptions);
52 }
53
54 private static Func1<Observable<? extends Throwable>, Observable<?>> createNotificationHandler(
55 final Observable<Long> delays, final Scheduler scheduler, final Action1<? super ErrorAndDuration> action,
56 final Func1<ErrorAndDuration, Observable<ErrorAndDuration>> checkExceptions) {
57 return new Func1<Observable<? extends Throwable>, Observable<?>>() {
58
59 @Override
60 public Observable<ErrorAndDuration> call(Observable<? extends Throwable> errors) {
61 return errors
62
63 .zipWith(delays.concatWith(just(NO_MORE_DELAYS)), TO_ERROR_AND_DURATION)
64
65 .flatMap(checkExceptions)
66
67
68 .doOnNext(callActionExceptForLast(action))
69
70 .flatMap(delay(scheduler));
71 }
72 };
73 }
74
75 private static Action1<ErrorAndDuration> callActionExceptForLast(final Action1<? super ErrorAndDuration> action) {
76 return new Action1<ErrorAndDuration>() {
77
78 @Override
79 public void call(ErrorAndDuration e) {
80 if (e.durationMs() != NO_MORE_DELAYS)
81 action.call(e);
82 }
83
84 };
85 }
86
87
88 private static Func1<ErrorAndDuration, Observable<ErrorAndDuration>> createExceptionChecker(
89 final List<Class<? extends Throwable>> retryExceptions,
90 final List<Class<? extends Throwable>> failExceptions,
91 final Func1<? super Throwable, Boolean> exceptionPredicate) {
92 return new Func1<ErrorAndDuration, Observable<ErrorAndDuration>>() {
93
94 @Override
95 public Observable<ErrorAndDuration> call(ErrorAndDuration e) {
96 if (!exceptionPredicate.call(e.throwable()))
97 return Observable.error(e.throwable());
98 for (Class<? extends Throwable> cls : failExceptions) {
99 if (e.throwable().getClass().isAssignableFrom(cls))
100 return Observable.error(e.throwable());
101 }
102 if (retryExceptions.size() > 0) {
103 for (Class<? extends Throwable> cls : retryExceptions) {
104 if (e.throwable().getClass().isAssignableFrom(cls))
105 return just(e);
106 }
107 return Observable.error(e.throwable());
108 } else {
109 return just(e);
110 }
111 }
112 };
113 }
114
115 private static Func2<Throwable, Long, ErrorAndDuration> TO_ERROR_AND_DURATION = new Func2<Throwable, Long, ErrorAndDuration>() {
116 @Override
117 public ErrorAndDuration call(Throwable throwable, Long durationMs) {
118 return new ErrorAndDuration(throwable, durationMs);
119 }
120 };
121
122 private static Func1<ErrorAndDuration, Observable<ErrorAndDuration>> delay(final Scheduler scheduler) {
123 return new Func1<ErrorAndDuration, Observable<ErrorAndDuration>>() {
124 @Override
125 public Observable<ErrorAndDuration> call(ErrorAndDuration e) {
126 if (e.durationMs() == NO_MORE_DELAYS)
127 return Observable.error(e.throwable());
128 else
129 return Observable.timer(e.durationMs(), TimeUnit.MILLISECONDS, scheduler)
130 .map(Functions.constant(e));
131 }
132 };
133 }
134
135
136
137 public static Builder retryWhenInstanceOf(Class<? extends Throwable>... classes) {
138 return new Builder().retryWhenInstanceOf(classes);
139 }
140
141 public static Builder failWhenInstanceOf(Class<? extends Throwable>... classes) {
142 return new Builder().failWhenInstanceOf(classes);
143 }
144
145 public static Builder retryIf(Func1<Throwable, Boolean> predicate) {
146 return new Builder().retryIf(predicate);
147 }
148
149 public static Builder delays(Observable<Long> delays, TimeUnit unit) {
150 return new Builder().delays(delays, unit);
151 }
152
153 public static Builder delaysInt(Observable<Integer> delays, TimeUnit unit) {
154 return new Builder().delaysInt(delays, unit);
155 }
156
157 public static Builder delay(long delay, final TimeUnit unit) {
158 return new Builder().delay(delay, unit);
159 }
160
161 public static Builder maxRetries(int maxRetries) {
162 return new Builder().maxRetries(maxRetries);
163 }
164
165 public static Builder scheduler(Scheduler scheduler) {
166 return new Builder().scheduler(scheduler);
167 }
168
169 public Builder action(Action1<? super ErrorAndDuration> action) {
170 return new Builder().action(action);
171 }
172
173 public static Builder exponentialBackoff(final long firstDelay, final TimeUnit unit, final double factor) {
174 return new Builder().exponentialBackoff(firstDelay, unit, factor);
175 }
176
177 public static Builder exponentialBackoff(long firstDelay, TimeUnit unit) {
178 return new Builder().exponentialBackoff(firstDelay, unit);
179 }
180
181 public static final class Builder {
182
183 private final List<Class<? extends Throwable>> retryExceptions = new ArrayList<Class<? extends Throwable>>();
184 private final List<Class<? extends Throwable>> failExceptions = new ArrayList<Class<? extends Throwable>>();
185 private Func1<? super Throwable, Boolean> exceptionPredicate = Functions.alwaysTrue();
186
187 private Observable<Long> delays = Observable.just(0L).repeat();
188 private Optional<Integer> maxRetries = absent();
189 private Optional<Scheduler> scheduler = of(Schedulers.computation());
190 private Action1<? super ErrorAndDuration> action = Actions.doNothing1();
191
192 private Builder() {
193
194 }
195
196 public Builder retryWhenInstanceOf(Class<? extends Throwable>... classes) {
197 retryExceptions.addAll(Arrays.asList(classes));
198 return this;
199 }
200
201 public Builder failWhenInstanceOf(Class<? extends Throwable>... classes) {
202 failExceptions.addAll(Arrays.asList(classes));
203 return this;
204 }
205
206 public Builder retryIf(Func1<Throwable, Boolean> predicate) {
207 this.exceptionPredicate = predicate;
208 return this;
209 }
210
211 public Builder delays(Observable<Long> delays, TimeUnit unit) {
212 this.delays = delays.map(toMillis(unit));
213 return this;
214 }
215
216 private static class ToLongHolder {
217 static final Func1<Integer, Long> INSTANCE = new Func1<Integer, Long>() {
218 @Override
219 public Long call(Integer n) {
220 if (n == null) {
221 return null;
222 } else {
223 return n.longValue();
224 }
225 }
226 };
227 }
228
229 public Builder delaysInt(Observable<Integer> delays, TimeUnit unit) {
230 return delays(delays.map(ToLongHolder.INSTANCE), unit);
231 }
232
233 public Builder delay(Long delay, final TimeUnit unit) {
234 this.delays = Observable.just(delay).map(toMillis(unit)).repeat();
235 return this;
236 }
237
238 private static Func1<Long, Long> toMillis(final TimeUnit unit) {
239 return new Func1<Long, Long>() {
240
241 @Override
242 public Long call(Long t) {
243 return unit.toMillis(t);
244 }
245 };
246 }
247
248 public Builder maxRetries(int maxRetries) {
249 this.maxRetries = of(maxRetries);
250 return this;
251 }
252
253 public Builder scheduler(Scheduler scheduler) {
254 this.scheduler = of(scheduler);
255 return this;
256 }
257
258 public Builder action(Action1<? super ErrorAndDuration> action) {
259 this.action = action;
260 return this;
261 }
262
263 public Builder exponentialBackoff(final long firstDelay, final TimeUnit unit, final double factor) {
264 delays = Observable.range(1, Integer.MAX_VALUE)
265
266 .map(new Func1<Integer, Long>() {
267 @Override
268 public Long call(Integer n) {
269 return Math.round(Math.pow(factor, n - 1) * unit.toMillis(firstDelay));
270 }
271 });
272 return this;
273 }
274
275 public Builder exponentialBackoff(long firstDelay, TimeUnit unit) {
276 return exponentialBackoff(firstDelay, unit, 2);
277 }
278
279 public Func1<Observable<? extends Throwable>, Observable<?>> build() {
280 Preconditions.checkNotNull(delays);
281 if (maxRetries.isPresent()) {
282 delays = delays.take(maxRetries.get());
283 }
284 return notificationHandler(delays, scheduler.get(), action, retryExceptions, failExceptions,
285 exceptionPredicate);
286 }
287
288 }
289
290 public static final class ErrorAndDuration {
291
292 private final Throwable throwable;
293 private final long durationMs;
294
295 public ErrorAndDuration(Throwable throwable, long durationMs) {
296 this.throwable = throwable;
297 this.durationMs = durationMs;
298 }
299
300 public Throwable throwable() {
301 return throwable;
302 }
303
304 public long durationMs() {
305 return durationMs;
306 }
307
308 }
309 }