1 package com.github.davidmoten.rx.testing;
2
3 import static com.github.davidmoten.util.Optional.of;
4
5 import java.util.ArrayList;
6 import java.util.Arrays;
7 import java.util.Collection;
8 import java.util.Collections;
9 import java.util.List;
10 import java.util.concurrent.CountDownLatch;
11 import java.util.concurrent.TimeUnit;
12 import java.util.concurrent.atomic.AtomicLong;
13
14 import junit.framework.TestCase;
15 import junit.framework.TestSuite;
16
17 import org.junit.runner.RunWith;
18 import org.junit.runners.Suite;
19 import org.junit.runners.Suite.SuiteClasses;
20
21 import rx.Observable;
22 import rx.Subscriber;
23 import rx.functions.Action0;
24 import rx.functions.Func1;
25
26 import com.github.davidmoten.util.Optional;
27 import com.github.davidmoten.util.Preconditions;
28
29
30
31
32 public final class TestingHelper {
33
34 private static final Optional<Long> ABSENT = Optional.absent();
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49 public static <T, R> Builder<T, R> function(Func1<Observable<T>, Observable<R>> function) {
50 return new Builder<T, R>().function(function);
51 }
52
53
54
55
56
57
58
59
60
61
62 public static class Builder<T, R> {
63
64 private final List<Case<T, R>> cases = new ArrayList<Case<T, R>>();
65 private Func1<Observable<T>, Observable<R>> function;
66 private long waitForUnusbscribeMs = 100;
67 private long waitForTerminalEventMs = 10000;
68 private long waitForMoreTerminalEventsMs = 50;
69
70 private Builder() {
71
72 }
73
74
75
76
77
78
79
80
81 public Builder<T, R> function(Func1<Observable<T>, Observable<R>> function) {
82 Preconditions.checkNotNull(function, "function cannot be null");
83 this.function = function;
84 return this;
85 }
86
87
88
89
90
91
92
93
94
95
96
97 public Builder<T, R> waitForUnsubscribe(long duration, TimeUnit unit) {
98 Preconditions.checkNotNull(unit, "unit cannot be null");
99 waitForUnusbscribeMs = unit.toMillis(duration);
100 return this;
101 }
102
103
104
105
106
107
108
109
110
111
112
113 public Builder<T, R> waitForTerminalEvent(long duration, TimeUnit unit) {
114 Preconditions.checkNotNull(unit, "unit cannot be null");
115 waitForTerminalEventMs = unit.toMillis(duration);
116 return this;
117 }
118
119
120
121
122
123
124
125
126
127
128
129 public Builder<T, R> waitForMoreTerminalEvents(long duration, TimeUnit unit) {
130 Preconditions.checkNotNull(unit, "unit cannot be null");
131 waitForMoreTerminalEventsMs = unit.toMillis(duration);
132 return this;
133 }
134
135
136
137
138
139
140
141
142 public CaseBuilder<T, R> name(String name) {
143 Preconditions.checkNotNull(name, "name cannot be null");
144 return new CaseBuilder<T, R>(this, Observable.<T> empty(), name);
145 }
146
147
148
149
150
151
152
153
154
155
156 public TestSuite testSuite(Class<?> cls) {
157 Preconditions.checkNotNull(cls, "cls cannot be null");
158 return new TestSuiteFromCases<T, R>(cls, new ArrayList<Case<T, R>>(this.cases));
159 }
160
161 private Builder<T, R> expect(Observable<T> from, Optional<List<R>> expected,
162 boolean ordered, Optional<Long> expectSize, boolean checkSourceUnsubscribed,
163 String name, Optional<Integer> unsubscribeAfter,
164 Optional<Class<? extends Throwable>> expectError,
165 Optional<Class<? extends RuntimeException>> expectException) {
166 cases.add(new Case<T, R>(from, expected, ordered, expectSize, checkSourceUnsubscribed,
167 function, name, unsubscribeAfter, expectError, waitForUnusbscribeMs,
168 waitForTerminalEventMs, waitForMoreTerminalEventsMs, expectException));
169 return this;
170 }
171 }
172
173 public static class CaseBuilder<T, R> {
174 private final Builder<T, R> builder;
175 private String name;
176 private Observable<T> from = Observable.empty();
177 private boolean checkSourceUnsubscribed = true;
178 private Optional<Integer> unsubscribeAfter = Optional.absent();
179
180 private CaseBuilder(Builder<T, R> builder, Observable<T> from, String name) {
181 Preconditions.checkNotNull(builder);
182 Preconditions.checkNotNull(from);
183 Preconditions.checkNotNull(name);
184 this.builder = builder;
185 this.from = from;
186 this.name = name;
187 }
188
189 public CaseBuilder<T, R> name(String name) {
190 Preconditions.checkNotNull(name, "name cannot be null");
191 this.name = name;
192 return this;
193 }
194
195 public CaseBuilder<T, R> fromEmpty() {
196 from = Observable.empty();
197 return this;
198 }
199
200 public CaseBuilder<T, R> from(T... source) {
201 Preconditions.checkNotNull(source, "source cannot be null");
202 from = Observable.from(source);
203 return this;
204 }
205
206 public CaseBuilder<T, R> from(Observable<T> source) {
207 Preconditions.checkNotNull(source, "source cannot be null");
208 from = source;
209 return this;
210 }
211
212 public CaseBuilder<T, R> fromError() {
213 from = Observable.error(new TestingException());
214 return this;
215 }
216
217 public CaseBuilder<T, R> fromErrorAfter(T... source) {
218 Preconditions.checkNotNull(source, "source cannot be null");
219 from = Observable.from(source).concatWith(Observable.<T> error(new TestingException()));
220 return this;
221 }
222
223 public CaseBuilder<T, R> fromErrorAfter(Observable<T> source) {
224 Preconditions.checkNotNull(source, "source cannot be null");
225 from = source;
226 return this;
227 }
228
229 public CaseBuilder<T, R> skipUnsubscribedCheck() {
230 this.checkSourceUnsubscribed = false;
231 return this;
232 }
233
234 public Builder<T, R> expectEmpty() {
235 return expect(Collections.<R> emptyList());
236 }
237
238 public Builder<T, R> expectError() {
239 return expectError(TestingException.class);
240 }
241
242 @SuppressWarnings("unchecked")
243 public Builder<T, R> expectError(Class<? extends Throwable> cls) {
244 Preconditions.checkNotNull(cls, "cls cannot be null");
245 return builder.expect(from, Optional.<List<R>> absent(), true, ABSENT,
246 checkSourceUnsubscribed, name, unsubscribeAfter,
247 (Optional<Class<? extends Throwable>>) (Optional<?>) of(cls),
248 Optional.<Class<? extends RuntimeException>> absent());
249 }
250
251 public Builder<T, R> expect(R... source) {
252 Preconditions.checkNotNull(source, "source cannot be null");
253 return expect(Arrays.asList(source));
254 }
255
256 public Builder<T, R> expectSize(long n) {
257 return builder.expect(from, Optional.<List<R>> absent(), true, of(n),
258 checkSourceUnsubscribed, name, unsubscribeAfter,
259 Optional.<Class<? extends Throwable>> absent(),
260 Optional.<Class<? extends RuntimeException>> absent());
261 }
262
263 public Builder<T, R> expect(List<R> source) {
264 Preconditions.checkNotNull(source, "source cannot be null");
265 return expect(source, true);
266 }
267
268 private Builder<T, R> expect(List<R> items, boolean ordered) {
269 return builder.expect(from, of(items), ordered, ABSENT, checkSourceUnsubscribed, name,
270 unsubscribeAfter, Optional.<Class<? extends Throwable>> absent(),
271 Optional.<Class<? extends RuntimeException>> absent());
272 }
273
274 public Builder<T, R> expectAnyOrder(R... source) {
275 Preconditions.checkNotNull(source, "source cannot be null");
276 return expect(Arrays.asList(source), false);
277 }
278
279 public CaseBuilder<T, R> unsubscribeAfter(int n) {
280 unsubscribeAfter = of(n);
281 return this;
282 }
283
284 @SuppressWarnings("unchecked")
285 public Builder<T, R> expectException(Class<? extends RuntimeException> cls) {
286 return builder.expect(from, Optional.<List<R>> absent(), true, ABSENT,
287 checkSourceUnsubscribed, name, unsubscribeAfter,
288 Optional.<Class<? extends Throwable>> absent(),
289 (Optional<Class<? extends RuntimeException>>) (Optional<?>) Optional.of(cls));
290 }
291
292 }
293
294 private static class Case<T, R> {
295 final String name;
296 final Observable<T> from;
297 final Optional<List<R>> expected;
298 final boolean checkSourceUnsubscribed;
299 final Func1<Observable<T>, Observable<R>> function;
300 final Optional<Integer> unsubscribeAfter;
301 final boolean ordered;
302 final Optional<Long> expectSize;
303 final Optional<Class<? extends Throwable>> expectError;
304 final long waitForUnusbscribeMs;
305 final long waitForTerminalEventMs;
306 final long waitForMoreTerminalEventsMs;
307 final Optional<Class<? extends RuntimeException>> expectedException;
308
309 Case(Observable<T> from, Optional<List<R>> expected, boolean ordered,
310 Optional<Long> expectSize, boolean checkSourceUnsubscribed,
311 Func1<Observable<T>, Observable<R>> function, String name,
312 Optional<Integer> unsubscribeAfter,
313 Optional<Class<? extends Throwable>> expectError, long waitForUnusbscribeMs,
314 long waitForTerminalEventMs, long waitForMoreTerminalEventsMs,
315 Optional<Class<? extends RuntimeException>> expectedException) {
316 Preconditions.checkNotNull(from);
317 Preconditions.checkNotNull(expected);
318 Preconditions.checkNotNull(expectSize);
319 Preconditions.checkNotNull(function);
320 Preconditions.checkNotNull(name);
321 Preconditions.checkNotNull(unsubscribeAfter);
322 Preconditions.checkNotNull(expectError);
323 Preconditions.checkNotNull(expectedException);
324 this.from = from;
325 this.expected = expected;
326 this.ordered = ordered;
327 this.expectSize = expectSize;
328 this.checkSourceUnsubscribed = checkSourceUnsubscribed;
329 this.function = function;
330 this.name = name;
331 this.unsubscribeAfter = unsubscribeAfter;
332 this.expectError = expectError;
333 this.waitForUnusbscribeMs = waitForUnusbscribeMs;
334 this.waitForTerminalEventMs = waitForTerminalEventMs;
335 this.waitForMoreTerminalEventsMs = waitForMoreTerminalEventsMs;
336 this.expectedException = expectedException;
337 }
338 }
339
340 private static <T, R> void runTest(Case<T, R> c, TestType testType) {
341 try {
342 CountDownLatch sourceUnsubscribeLatch = new CountDownLatch(1);
343 MyTestSubscriber<R> sub = createTestSubscriber(testType, c.unsubscribeAfter);
344 c.function.call(c.from.doOnUnsubscribe(countDown(sourceUnsubscribeLatch)))
345 .subscribe(sub);
346 if (c.unsubscribeAfter.isPresent()) {
347 waitForUnsubscribe(sourceUnsubscribeLatch, c.waitForUnusbscribeMs,
348 TimeUnit.MILLISECONDS);
349
350
351 } else {
352 sub.awaitTerminalEvent(c.waitForTerminalEventMs, TimeUnit.MILLISECONDS);
353 if (c.expectError.isPresent()) {
354 sub.assertError(c.expectError.get());
355
356 pause(c.waitForMoreTerminalEventsMs, TimeUnit.MILLISECONDS);
357 if (sub.numOnCompletedEvents() > 0)
358 throw new UnexpectedOnCompletedException();
359 } else {
360 sub.assertNoErrors();
361
362 pause(c.waitForMoreTerminalEventsMs, TimeUnit.MILLISECONDS);
363 if (sub.numOnCompletedEvents() > 1)
364 throw new TooManyOnCompletedException();
365 sub.assertNoErrors();
366 }
367 }
368
369 if (c.expected.isPresent())
370 sub.assertReceivedOnNext(c.expected.get(), c.ordered);
371 if (c.expectSize.isPresent())
372 sub.assertReceivedCountIs(c.expectSize.get());
373 sub.assertUnsubscribed();
374 if (c.checkSourceUnsubscribed)
375 waitForUnsubscribe(sourceUnsubscribeLatch, c.waitForUnusbscribeMs,
376 TimeUnit.MILLISECONDS);
377 if (c.expectedException.isPresent())
378 throw new ExpectedExceptionNotThrownException();
379 } catch (RuntimeException e) {
380 if (!c.expectedException.isPresent() || !c.expectedException.get().isInstance(e))
381 throw e;
382
383 }
384 }
385
386 private static Action0 countDown(final CountDownLatch latch) {
387 return new Action0() {
388 @Override
389 public void call() {
390 latch.countDown();
391 }
392 };
393 }
394
395 private static <T> void waitForUnsubscribe(CountDownLatch latch, long duration, TimeUnit unit) {
396 try {
397 if (!latch.await(duration, unit))
398 throw new UnsubscriptionFromSourceTimeoutException();
399 } catch (InterruptedException e) {
400
401 }
402 }
403
404 public static class UnsubscriptionFromSourceTimeoutException extends RuntimeException {
405 private static final long serialVersionUID = -1142604414390722544L;
406 }
407
408 private static void pause(long duration, TimeUnit unit) {
409 try {
410 Thread.sleep(unit.toMillis(duration));
411 } catch (InterruptedException e) {
412
413 }
414 }
415
416 private static final class MyTestSubscriber<T> extends Subscriber<T> {
417
418 private final List<T> next = new ArrayList<T>();
419 private final Optional<Long> onStartRequest;
420 private final Optional<Long> onNextRequest;
421 private final Optional<Integer> unsubscribeAfter;
422 private final CountDownLatch terminalLatch;
423 private int completed = 0;
424 private int count = 0;
425 private int errors = 0;
426 private final AtomicLong expected = new AtomicLong();
427 private Optional<Throwable> lastError = Optional.absent();
428 private Optional<Long> onNextRequest2;
429
430 MyTestSubscriber(Optional<Integer> unsubscribeAfter, final Optional<Long> onStartRequest,
431 final Optional<Long> onNextRequest, final Optional<Long> onNextRequest2) {
432 this.unsubscribeAfter = unsubscribeAfter;
433 this.onStartRequest = onStartRequest;
434 this.onNextRequest = onNextRequest;
435 this.onNextRequest2 = onNextRequest2;
436 this.terminalLatch = new CountDownLatch(1);
437 }
438
439 MyTestSubscriber(Optional<Integer> unsubscribeAfter) {
440 this(unsubscribeAfter, ABSENT, ABSENT, ABSENT);
441 }
442
443 @Override
444 public void onStart() {
445 if (!onStartRequest.isPresent())
446
447 expected.set(Long.MAX_VALUE);
448 else
449 expected.set(0);
450 if (onStartRequest.isPresent())
451 requestMore(onStartRequest.get());
452 }
453
454 private void requestMore(long n) {
455 if (expected.get() != Long.MAX_VALUE) {
456 if (n > 0)
457 expected.addAndGet(n);
458
459 request(n);
460 }
461 }
462
463 @Override
464 public void onCompleted() {
465 completed++;
466 terminalLatch.countDown();
467 }
468
469 @Override
470 public void onError(Throwable e) {
471 errors++;
472 lastError = of(e);
473 terminalLatch.countDown();
474 }
475
476 @Override
477 public void onNext(T t) {
478 final long exp;
479 if (expected.get() != Long.MAX_VALUE)
480 exp = expected.decrementAndGet();
481 else
482 exp = expected.get();
483 next.add(t);
484 count++;
485 if (exp < 0)
486 onError(new DeliveredMoreThanRequestedException());
487 else if (unsubscribeAfter.isPresent() && count == unsubscribeAfter.get())
488 unsubscribe();
489 else {
490 if (onNextRequest.isPresent())
491 requestMore(onNextRequest.get());
492 if (onNextRequest2.isPresent())
493 requestMore(onNextRequest2.get());
494 }
495 }
496
497 void assertError(Class<?> cls) {
498 if (errors != 1 || !cls.isInstance(lastError.get()))
499 throw new ExpectedErrorNotReceivedException();
500 }
501
502 void assertReceivedCountIs(long count) {
503 if (count != next.size())
504 throw new WrongOnNextCountException();
505 }
506
507 void awaitTerminalEvent(long duration, TimeUnit unit) {
508 try {
509 if (!terminalLatch.await(duration, unit))
510 throw new TerminalEventTimeoutException();
511 } catch (InterruptedException e) {
512
513 }
514 }
515
516 void assertReceivedOnNext(List<T> expected, boolean ordered) {
517 if (!TestingHelper.equals(expected, next, ordered))
518 throw new UnexpectedOnNextException("expected=" + expected + ", actual=" + next);
519 }
520
521 void assertUnsubscribed() {
522 if (!isUnsubscribed())
523 throw new DownstreamUnsubscriptionDidNotOccurException();
524 }
525
526 int numOnCompletedEvents() {
527 return completed;
528 }
529
530 void assertNoErrors() {
531 if (errors > 0) {
532 lastError.get().printStackTrace();
533 throw new UnexpectedOnErrorException();
534 }
535 }
536
537 }
538
539 public static class TerminalEventTimeoutException extends RuntimeException {
540 private static final long serialVersionUID = -7355281653999339840L;
541 }
542
543 public static class ExpectedErrorNotReceivedException extends RuntimeException {
544 private static final long serialVersionUID = -567146145612029349L;
545 }
546
547 public static class ExpectedExceptionNotThrownException extends RuntimeException {
548 private static final long serialVersionUID = -104410457605712970L;
549 }
550
551 public static class WrongOnNextCountException extends RuntimeException {
552 private static final long serialVersionUID = 984672575527784559L;
553 }
554
555 public static class UnexpectedOnCompletedException extends RuntimeException {
556 private static final long serialVersionUID = 7164517608988798969L;
557 }
558
559 public static class UnexpectedOnErrorException extends RuntimeException {
560 private static final long serialVersionUID = -813740137771756205L;
561 }
562
563 public static class TooManyOnCompletedException extends RuntimeException {
564 private static final long serialVersionUID = -405328882928962333L;
565 }
566
567 public static class DownstreamUnsubscriptionDidNotOccurException extends RuntimeException {
568 private static final long serialVersionUID = 7218646111664183642L;
569 }
570
571 public static class UnexpectedOnNextException extends RuntimeException {
572 private static final long serialVersionUID = -3656406263739222767L;
573
574 public UnexpectedOnNextException(String message) {
575 super(message);
576 }
577
578 }
579
580 private static enum TestType {
581 WITHOUT_BACKP, BACKP_INITIAL_REQUEST_MAX, BACKP_INITIAL_REQUEST_MAX_THEN_BY_ONE, BACKP_ONE_BY_ONE, BACKP_TWO_BY_TWO, BACKP_REQUEST_ZERO, BACKP_FIVE_BY_FIVE, BACKP_FIFTY_BY_FIFTY, BACKP_THOUSAND_BY_THOUSAND, BACKP_REQUEST_OVERFLOW;
582 }
583
584 private static <T> MyTestSubscriber<T> createTestSubscriber(Optional<Integer> unsubscribeAfter,
585 long onStartRequest, Optional<Long> onNextRequest) {
586 return new MyTestSubscriber<T>(unsubscribeAfter, of(onStartRequest), onNextRequest, ABSENT);
587 }
588
589 private static <T> MyTestSubscriber<T> createTestSubscriber(TestType testType,
590 final Optional<Integer> unsubscribeAfter) {
591
592 if (testType == TestType.WITHOUT_BACKP)
593 return new MyTestSubscriber<T>(unsubscribeAfter);
594 else if (testType == TestType.BACKP_INITIAL_REQUEST_MAX)
595 return createTestSubscriber(unsubscribeAfter, Long.MAX_VALUE, ABSENT);
596 else if (testType == TestType.BACKP_INITIAL_REQUEST_MAX_THEN_BY_ONE)
597 return createTestSubscriber(unsubscribeAfter, Long.MAX_VALUE, of(1L));
598 else if (testType == TestType.BACKP_ONE_BY_ONE)
599 return createTestSubscriber(unsubscribeAfter, 1L, of(1L));
600 else if (testType == TestType.BACKP_REQUEST_ZERO)
601 return new MyTestSubscriber<T>(unsubscribeAfter, of(1L), of(0L), of(1L));
602 else if (testType == TestType.BACKP_REQUEST_OVERFLOW)
603 return new MyTestSubscriber<T>(unsubscribeAfter, of(1L), of(Long.MAX_VALUE / 3 * 2),
604 of(Long.MAX_VALUE / 3 * 2));
605 else if (testType == TestType.BACKP_TWO_BY_TWO)
606 return createTestSubscriberWithBackpNbyN(unsubscribeAfter, 2);
607 else if (testType == TestType.BACKP_FIVE_BY_FIVE)
608 return createTestSubscriberWithBackpNbyN(unsubscribeAfter, 5);
609 else if (testType == TestType.BACKP_FIFTY_BY_FIFTY)
610 return createTestSubscriberWithBackpNbyN(unsubscribeAfter, 50);
611 else if (testType == TestType.BACKP_THOUSAND_BY_THOUSAND)
612 return createTestSubscriberWithBackpNbyN(unsubscribeAfter, 1000);
613 else
614 throw new RuntimeException(testType + " not implemented");
615
616 }
617
618 private static <T> MyTestSubscriber<T> createTestSubscriberWithBackpNbyN(
619 final Optional<Integer> unsubscribeAfter, final long requestSize) {
620 return new MyTestSubscriber<T>(unsubscribeAfter, of(requestSize), ABSENT, of(requestSize));
621 }
622
623 @RunWith(Suite.class)
624 @SuiteClasses({})
625 private static class TestSuiteFromCases<T, R> extends TestSuite {
626
627 TestSuiteFromCases(Class<?> cls, List<Case<T, R>> cases) {
628 super(cls);
629 for (Case<T, R> c : cases) {
630 for (TestType testType : TestType.values())
631 if (testType != TestType.BACKP_REQUEST_OVERFLOW)
632 addTest(new MyTestCase<T, R>(c.name + "_" + testType.name(), c, testType));
633 }
634 }
635 }
636
637 private static class MyTestCase<T, R> extends TestCase {
638
639 private final Case<T, R> c;
640 private final TestType testType;
641
642 MyTestCase(String name, Case<T, R> c, TestType testType) {
643 super(name);
644 this.c = c;
645 this.testType = testType;
646 }
647
648 @Override
649 protected void runTest() throws Throwable {
650 TestingHelper.runTest(c, testType);
651 }
652
653 }
654
655 private static <T> boolean equals(Collection<T> a, Collection<T> b, boolean ordered) {
656 if (a == null)
657 return b == null;
658 else if (b == null)
659 return a == null;
660 else if (a.size() != b.size())
661 return false;
662 else if (ordered)
663 return a.equals(b);
664 else {
665 List<T> list = new ArrayList<T>(a);
666 for (T t : b) {
667 if (!list.remove(t))
668 return false;
669 }
670 return true;
671 }
672 }
673
674 private static class TestingException extends RuntimeException {
675
676 private static final long serialVersionUID = 4467514769366847747L;
677
678 }
679
680
681
682
683
684 public static class DeliveredMoreThanRequestedException extends RuntimeException {
685 private static final long serialVersionUID = 1369440545774454215L;
686
687 public DeliveredMoreThanRequestedException() {
688 super("more items arrived than requested");
689 }
690 }
691
692
693
694
695 public static class AssertionException extends RuntimeException {
696 private static final long serialVersionUID = -6846674323693517388L;
697
698 public AssertionException(String message) {
699 super(message);
700 }
701 }
702
703
704
705
706
707
708
709
710
711 public static <T> Func1<Observable<T>, TestSubscriber2<T>> test() {
712 return TestSubscriber2.test();
713 }
714
715
716
717
718
719
720
721
722
723
724
725
726 public static <T> Func1<Observable<T>, TestSubscriber2<T>> testWithRequest(
727 long initialRequest) {
728 return TestSubscriber2.testWithRequest(initialRequest);
729 }
730
731 }