View Javadoc
1   package com.github.davidmoten.rx;
2   
3   import static com.github.davidmoten.util.Optional.absent;
4   import static com.github.davidmoten.util.Optional.of;
5   import static rx.Observable.just;
6   
7   import java.util.ArrayList;
8   import java.util.Arrays;
9   import java.util.List;
10  import java.util.concurrent.TimeUnit;
11  
12  import com.github.davidmoten.util.Optional;
13  import com.github.davidmoten.util.Preconditions;
14  
15  import rx.Observable;
16  import rx.Scheduler;
17  import rx.functions.Action1;
18  import rx.functions.Func1;
19  import rx.functions.Func2;
20  import rx.schedulers.Schedulers;
21  
22  /**
23   * Provides builder for the {@link Func1} parameter of
24   * {@link Observable#retryWhen(Func1)}. For example:
25   * 
26   * <pre>
27   * o.retryWhen(RetryWhen.maxRetries(4).delay(10, TimeUnit.SECONDS).action(log).build());
28   * </pre>
29   * 
30   * <p>
31   * or
32   * </p>
33   * 
34   * <pre>
35   * o.retryWhen(RetryWhen.exponentialBackoff(100, TimeUnit.MILLISECONDS).maxRetries(10).build());
36   * </pre>
37   */
38  public final class RetryWhen {
39  
40  	private static final long NO_MORE_DELAYS = -1;
41  
42  	private static Func1<Observable<? extends Throwable>, Observable<?>> notificationHandler(
43  			final Observable<Long> delays, final Scheduler scheduler, final Action1<? super ErrorAndDuration> action,
44  			final List<Class<? extends Throwable>> retryExceptions,
45  			final List<Class<? extends Throwable>> failExceptions,
46  			final Func1<? super Throwable, Boolean> exceptionPredicate) {
47  
48  		final Func1<ErrorAndDuration, Observable<ErrorAndDuration>> checkExceptions = createExceptionChecker(
49  				retryExceptions, failExceptions, exceptionPredicate);
50  
51  		return createNotificationHandler(delays, scheduler, action, checkExceptions);
52  	}
53  
54  	private static Func1<Observable<? extends Throwable>, Observable<?>> createNotificationHandler(
55  			final Observable<Long> delays, final Scheduler scheduler, final Action1<? super ErrorAndDuration> action,
56  			final Func1<ErrorAndDuration, Observable<ErrorAndDuration>> checkExceptions) {
57  		return new Func1<Observable<? extends Throwable>, Observable<?>>() {
58  
59  			@Override
60  			public Observable<ErrorAndDuration> call(Observable<? extends Throwable> errors) {
61  				return errors
62  						// zip with delays, use -1 to signal completion
63  						.zipWith(delays.concatWith(just(NO_MORE_DELAYS)), TO_ERROR_AND_DURATION)
64  						// check retry and non-retry exceptions
65  						.flatMap(checkExceptions)
66  						// perform user action (for example log that a
67  						// delay is happening)
68  						.doOnNext(callActionExceptForLast(action))
69  						// delay the time in ErrorAndDuration
70  						.flatMap(delay(scheduler));
71  			}
72  		};
73  	}
74  
75  	private static Action1<ErrorAndDuration> callActionExceptForLast(final Action1<? super ErrorAndDuration> action) {
76  		return new Action1<ErrorAndDuration>() {
77  
78  			@Override
79  			public void call(ErrorAndDuration e) {
80  				if (e.durationMs() != NO_MORE_DELAYS)
81  					action.call(e);
82  			}
83  
84  		};
85  	}
86  
87  	// TODO unit test
88  	private static Func1<ErrorAndDuration, Observable<ErrorAndDuration>> createExceptionChecker(
89  			final List<Class<? extends Throwable>> retryExceptions,
90  			final List<Class<? extends Throwable>> failExceptions,
91  			final Func1<? super Throwable, Boolean> exceptionPredicate) {
92  		return new Func1<ErrorAndDuration, Observable<ErrorAndDuration>>() {
93  
94  			@Override
95  			public Observable<ErrorAndDuration> call(ErrorAndDuration e) {
96  				if (!exceptionPredicate.call(e.throwable()))
97  					return Observable.error(e.throwable());
98  				for (Class<? extends Throwable> cls : failExceptions) {
99  					if (e.throwable().getClass().isAssignableFrom(cls))
100 						return Observable.error(e.throwable());
101 				}
102 				if (retryExceptions.size() > 0) {
103 					for (Class<? extends Throwable> cls : retryExceptions) {
104 						if (e.throwable().getClass().isAssignableFrom(cls))
105 							return just(e);
106 					}
107 					return Observable.error(e.throwable());
108 				} else {
109 					return just(e);
110 				}
111 			}
112 		};
113 	}
114 
115 	private static Func2<Throwable, Long, ErrorAndDuration> TO_ERROR_AND_DURATION = new Func2<Throwable, Long, ErrorAndDuration>() {
116 		@Override
117 		public ErrorAndDuration call(Throwable throwable, Long durationMs) {
118 			return new ErrorAndDuration(throwable, durationMs);
119 		}
120 	};
121 
122 	private static Func1<ErrorAndDuration, Observable<ErrorAndDuration>> delay(final Scheduler scheduler) {
123 		return new Func1<ErrorAndDuration, Observable<ErrorAndDuration>>() {
124 			@Override
125 			public Observable<ErrorAndDuration> call(ErrorAndDuration e) {
126 				if (e.durationMs() == NO_MORE_DELAYS)
127 					return Observable.error(e.throwable());
128 				else
129 					return Observable.timer(e.durationMs(), TimeUnit.MILLISECONDS, scheduler)
130 							.map(Functions.constant(e));
131 			}
132 		};
133 	}
134 
135 	// Builder factory methods
136 
137 	public static Builder retryWhenInstanceOf(Class<? extends Throwable>... classes) {
138 		return new Builder().retryWhenInstanceOf(classes);
139 	}
140 
141 	public static Builder failWhenInstanceOf(Class<? extends Throwable>... classes) {
142 		return new Builder().failWhenInstanceOf(classes);
143 	}
144 
145 	public static Builder retryIf(Func1<Throwable, Boolean> predicate) {
146 		return new Builder().retryIf(predicate);
147 	}
148 
149 	public static Builder delays(Observable<Long> delays, TimeUnit unit) {
150 		return new Builder().delays(delays, unit);
151 	}
152 
153 	public static Builder delaysInt(Observable<Integer> delays, TimeUnit unit) {
154 		return new Builder().delaysInt(delays, unit);
155 	}
156 
157 	public static Builder delay(long delay, final TimeUnit unit) {
158 		return new Builder().delay(delay, unit);
159 	}
160 
161 	public static Builder maxRetries(int maxRetries) {
162 		return new Builder().maxRetries(maxRetries);
163 	}
164 
165 	public static Builder scheduler(Scheduler scheduler) {
166 		return new Builder().scheduler(scheduler);
167 	}
168 
169 	public Builder action(Action1<? super ErrorAndDuration> action) {
170 		return new Builder().action(action);
171 	}
172 
173 	public static Builder exponentialBackoff(final long firstDelay, final TimeUnit unit, final double factor) {
174 		return new Builder().exponentialBackoff(firstDelay, unit, factor);
175 	}
176 
177 	public static Builder exponentialBackoff(long firstDelay, TimeUnit unit) {
178 		return new Builder().exponentialBackoff(firstDelay, unit);
179 	}
180 
181 	public static final class Builder {
182 
183 		private final List<Class<? extends Throwable>> retryExceptions = new ArrayList<Class<? extends Throwable>>();
184 		private final List<Class<? extends Throwable>> failExceptions = new ArrayList<Class<? extends Throwable>>();
185 		private Func1<? super Throwable, Boolean> exceptionPredicate = Functions.alwaysTrue();
186 
187 		private Observable<Long> delays = Observable.just(0L).repeat();
188 		private Optional<Integer> maxRetries = absent();
189 		private Optional<Scheduler> scheduler = of(Schedulers.computation());
190 		private Action1<? super ErrorAndDuration> action = Actions.doNothing1();
191 
192 		private Builder() {
193 			// must use static factory method to instantiate
194 		}
195 
196 		public Builder retryWhenInstanceOf(Class<? extends Throwable>... classes) {
197 			retryExceptions.addAll(Arrays.asList(classes));
198 			return this;
199 		}
200 
201 		public Builder failWhenInstanceOf(Class<? extends Throwable>... classes) {
202 			failExceptions.addAll(Arrays.asList(classes));
203 			return this;
204 		}
205 
206 		public Builder retryIf(Func1<Throwable, Boolean> predicate) {
207 			this.exceptionPredicate = predicate;
208 			return this;
209 		}
210 
211 		public Builder delays(Observable<Long> delays, TimeUnit unit) {
212 			this.delays = delays.map(toMillis(unit));
213 			return this;
214 		}
215 		
216 		private static class ToLongHolder {
217 			static final Func1<Integer, Long> INSTANCE = new Func1<Integer, Long>() {
218 				@Override
219 				public Long call(Integer n) {
220 					if (n == null) {
221 						return null;
222 					} else {
223 						return n.longValue();
224 					}
225 				}
226 			};
227 		}
228 		
229 		public Builder delaysInt(Observable<Integer> delays, TimeUnit unit) {
230 			return delays(delays.map(ToLongHolder.INSTANCE), unit);
231 		}
232 		
233 		public Builder delay(Long delay, final TimeUnit unit) {
234 			this.delays = Observable.just(delay).map(toMillis(unit)).repeat();
235 			return this;
236 		}
237 
238 		private static Func1<Long, Long> toMillis(final TimeUnit unit) {
239 			return new Func1<Long, Long>() {
240 
241 				@Override
242 				public Long call(Long t) {
243 					return unit.toMillis(t);
244 				}
245 			};
246 		}
247 
248 		public Builder maxRetries(int maxRetries) {
249 			this.maxRetries = of(maxRetries);
250 			return this;
251 		}
252 
253 		public Builder scheduler(Scheduler scheduler) {
254 			this.scheduler = of(scheduler);
255 			return this;
256 		}
257 
258 		public Builder action(Action1<? super ErrorAndDuration> action) {
259 			this.action = action;
260 			return this;
261 		}
262 
263 		public Builder exponentialBackoff(final long firstDelay, final TimeUnit unit, final double factor) {
264 			delays = Observable.range(1, Integer.MAX_VALUE)
265 					// make exponential
266 					.map(new Func1<Integer, Long>() {
267 						@Override
268 						public Long call(Integer n) {
269 							return Math.round(Math.pow(factor, n - 1) * unit.toMillis(firstDelay));
270 						}
271 					});
272 			return this;
273 		}
274 
275 		public Builder exponentialBackoff(long firstDelay, TimeUnit unit) {
276 			return exponentialBackoff(firstDelay, unit, 2);
277 		}
278 
279 		public Func1<Observable<? extends Throwable>, Observable<?>> build() {
280 			Preconditions.checkNotNull(delays);
281 			if (maxRetries.isPresent()) {
282 				delays = delays.take(maxRetries.get());
283 			}
284 			return notificationHandler(delays, scheduler.get(), action, retryExceptions, failExceptions,
285 					exceptionPredicate);
286 		}
287 
288 	}
289 
290 	public static final class ErrorAndDuration {
291 
292 		private final Throwable throwable;
293 		private final long durationMs;
294 
295 		public ErrorAndDuration(Throwable throwable, long durationMs) {
296 			this.throwable = throwable;
297 			this.durationMs = durationMs;
298 		}
299 
300 		public Throwable throwable() {
301 			return throwable;
302 		}
303 
304 		public long durationMs() {
305 			return durationMs;
306 		}
307 
308 	}
309 }