View Javadoc
1   package com.github.davidmoten.rx2.internal.flowable;
2   
3   import java.util.HashMap;
4   import java.util.LinkedList;
5   import java.util.Map;
6   import java.util.Queue;
7   import java.util.concurrent.atomic.AtomicInteger;
8   import java.util.concurrent.atomic.AtomicLong;
9   import java.util.concurrent.atomic.AtomicReference;
10  
11  import org.reactivestreams.Subscriber;
12  import org.reactivestreams.Subscription;
13  
14  import com.github.davidmoten.guavamini.Preconditions;
15  
16  import io.reactivex.Flowable;
17  import io.reactivex.FlowableSubscriber;
18  import io.reactivex.exceptions.Exceptions;
19  import io.reactivex.functions.BiFunction;
20  import io.reactivex.functions.Function;
21  import io.reactivex.internal.fuseable.SimpleQueue;
22  import io.reactivex.internal.queue.MpscLinkedQueue;
23  import io.reactivex.internal.subscriptions.SubscriptionHelper;
24  import io.reactivex.internal.util.BackpressureHelper;
25  
26  public final class FlowableMatch<A, B, K, C> extends Flowable<C> {
27  
    // first source; its items are keyed with aKey
    private final Flowable<A> a;
    // second source; its items are keyed with bKey
    private final Flowable<B> b;
    // extracts the matching key from an item of the first source
    private final Function<? super A, ? extends K> aKey;
    // extracts the matching key from an item of the second source
    private final Function<? super B, ? extends K> bKey;
    // builds the downstream value from a matched (a, b) pair
    private final BiFunction<? super A, ? super B, C> combiner;
    // batch size used for every request made to the sources (validated >= 1
    // by the constructor)
    private final long requestSize;
34  
35      public FlowableMatch(Flowable<A> a, Flowable<B> b, Function<? super A, ? extends K> aKey,
36              Function<? super B, ? extends K> bKey, BiFunction<? super A, ? super B, C> combiner, long requestSize) {
37          Preconditions.checkNotNull(a, "a should not be null");
38          Preconditions.checkNotNull(b, "b should not be null");
39          Preconditions.checkNotNull(aKey, "aKey cannot be null");
40          Preconditions.checkNotNull(bKey, "bKey cannot be null");
41          Preconditions.checkNotNull(combiner, "combiner cannot be null");
42          Preconditions.checkArgument(requestSize >= 1, "requestSize must be >=1");
43          this.a = a;
44          this.b = b;
45          this.aKey = aKey;
46          this.bKey = bKey;
47          this.combiner = combiner;
48          this.requestSize = requestSize;
49      }
50  
51      @Override
52      protected void subscribeActual(Subscriber<? super C> child) {
53          MatchCoordinator<A, B, K, C> coordinator = new MatchCoordinator<A, B, K, C>(aKey, bKey, combiner, requestSize,
54                  child);
55          child.onSubscribe(coordinator);
56          coordinator.subscribe(a, b);
57      }
58  
    /**
     * Sink for events coming from the upstream subscribers. In this file the
     * offered object is a raw item, an {@code ItemA} wrapper, a {@code Source}
     * marker (completion) or a {@code MyError} wrapper.
     */
    interface Receiver {
        /**
         * Accepts one upstream event.
         *
         * @param item
         *            the event to hand over
         */
        void offer(Object item);
    }
62  
63      @SuppressWarnings("serial")
64      private static final class MatchCoordinator<A, B, K, C> extends AtomicInteger implements Receiver, Subscription {
65          private final Map<K, Queue<A>> as = new HashMap<K, Queue<A>>();
66          private final Map<K, Queue<B>> bs = new HashMap<K, Queue<B>>();
67          private final Function<? super A, ? extends K> aKey;
68          private final Function<? super B, ? extends K> bKey;
69          private final BiFunction<? super A, ? super B, C> combiner;
70          private final long requestSize;
71          private final transient SimpleQueue<Object> queue;
72          private final Subscriber<? super C> child;
73          private final AtomicLong requested = new AtomicLong(0);
74  
75          // mutable fields, guarded by `this` atomics
76          private int requestFromA = 0;
77          private int requestFromB = 0;
78  
79          // completion state machine
80          private int completed = COMPLETED_NONE;
81          // completion states
82          private static final int COMPLETED_NONE = 0;
83          private static final int COMPLETED_A = 1;
84          private static final int COMPLETED_B = 2;
85          private static final int COMPLETED_BOTH = 3;
86  
87          private MySubscriber<A, K> aSub;
88          private MySubscriber<B, K> bSub;
89  
90          private volatile boolean cancelled = false;
91  
92          MatchCoordinator(Function<? super A, ? extends K> aKey, Function<? super B, ? extends K> bKey,
93                  BiFunction<? super A, ? super B, C> combiner, long requestSize, Subscriber<? super C> child) {
94              this.aKey = aKey;
95              this.bKey = bKey;
96              this.combiner = combiner;
97              this.requestSize = requestSize;
98              this.queue = new MpscLinkedQueue<Object>();
99              this.child = child;
100         }
101 
102         public void subscribe(Flowable<A> a, Flowable<B> b) {
103             aSub = new MySubscriber<A, K>(Source.A, this, requestSize);
104             bSub = new MySubscriber<B, K>(Source.B, this, requestSize);
105             a.subscribe(aSub);
106             b.subscribe(bSub);
107         }
108 
109         @Override
110         public void request(long n) {
111             if (SubscriptionHelper.validate(n)) {
112                 BackpressureHelper.add(requested, n);
113                 drain();
114             }
115         }
116 
117         @Override
118         public void cancel() {
119             if (!cancelled) {
120                 cancelled = true;
121                 cancelAll();
122             }
123         }
124 
125         void cancelAll() {
126             aSub.cancel();
127             bSub.cancel();
128         }
129 
130         void drain() {
131             if (getAndIncrement() != 0) {
132                 // work already in progress
133                 // so exit
134                 return;
135             }
136             int missed = 1;
137             while (true) {
138                 long r = requested.get();
139                 long emitted = 0;
140                 while (emitted != r) {
141                     if (cancelled) {
142                         return;
143                     }
144                     // note no null values on the queue
145                     Object v;
146                     try {
147                         v = queue.poll();
148                     } catch (Exception e) {
149                         Exceptions.throwIfFatal(e);
150                         clear();
151                         child.onError(e);
152                         return;
153                     }
154                     if (v == null) {
155                         // queue is empty
156                         break;
157                     } else if (v instanceof ItemA) {
158                         Emitted em = handleItem(((ItemA) v).value, Source.A);
159                         if (em == Emitted.FINISHED) {
160                             return;
161                         } else if (em == Emitted.ONE) {
162                             emitted += 1;
163                         }
164                     } else if (v instanceof Source) {
165                         // source completed
166                         Status status = handleCompleted((Source) v);
167                         if (status == Status.FINISHED) {
168                             return;
169                         }
170                     } else if (v instanceof MyError) {
171                         // v must be an error
172                         clear();
173                         child.onError(((MyError) v).error);
174                         return;
175                     } else {
176                         // is onNext from B
177                         Emitted em = handleItem(v, Source.B);
178                         if (em == Emitted.FINISHED) {
179                             return;
180                         } else if (em == Emitted.ONE) {
181                             emitted += 1;
182                         }
183                     }
184                     if (r == emitted) {
185                         break;
186                     }
187                 }
188                 // reduce requested by emitted which will always be positive
189                 BackpressureHelper.produced(requested, emitted);
190                 missed = this.addAndGet(-missed);
191                 if (missed == 0) {
192                     return;
193                 }
194             }
195         }
196 
197         private Emitted handleItem(Object value, Source source) {
198             final Emitted result;
199 
200             // logic duplication occurs below
201             // would be nice to simplify without making code
202             // unreadable. A bit of a toss-up.
203             if (source == Source.A) {
204                 // look for match
205                 @SuppressWarnings("unchecked")
206                 A a = (A) value;
207                 K key;
208                 try {
209                     key = aKey.apply(a);
210                 } catch (Throwable e) {
211                     clear();
212                     child.onError(e);
213                     return Emitted.FINISHED;
214                 }
215                 Queue<B> q = bs.get(key);
216                 if (q == null) {
217                     // cache value
218                     add(as, key, a);
219                     result = Emitted.NONE;
220                 } else {
221                     // emit match
222                     B b = poll(bs, q, key);
223                     C c;
224                     try {
225                         c = combiner.apply(a, b);
226                     } catch (Throwable e) {
227                         clear();
228                         child.onError(e);
229                         return Emitted.FINISHED;
230                     }
231                     child.onNext(c);
232                     result = Emitted.ONE;
233                 }
234                 // if the other source has completed and there
235                 // is nothing to match with then we should stop
236                 if (completed == COMPLETED_B && bs.isEmpty()) {
237                     // can finish
238                     clear();
239                     child.onComplete();
240                     return Emitted.FINISHED;
241                 } else {
242                     requestFromA += 1;
243                 }
244             } else {
245                 // look for match
246                 @SuppressWarnings("unchecked")
247                 B b = (B) value;
248                 K key;
249                 try {
250                     key = bKey.apply(b);
251                 } catch (Throwable e) {
252                     clear();
253                     child.onError(e);
254                     return Emitted.FINISHED;
255                 }
256                 Queue<A> q = as.get(key);
257                 if (q == null) {
258                     // cache value
259                     add(bs, key, b);
260                     result = Emitted.NONE;
261                 } else {
262                     // emit match
263                     A a = poll(as, q, key);
264                     C c;
265                     try {
266                         c = combiner.apply(a, b);
267                     } catch (Throwable e) {
268                         clear();
269                         child.onError(e);
270                         return Emitted.FINISHED;
271                     }
272                     child.onNext(c);
273                     result = Emitted.ONE;
274                 }
275                 // if the other source has completed and there
276                 // is nothing to match with then we should stop
277                 if (completed == COMPLETED_A && as.isEmpty()) {
278                     // can finish
279                     clear();
280                     child.onComplete();
281                     return Emitted.FINISHED;
282                 } else {
283                     requestFromB += 1;
284                 }
285             }
286             // requests are batched so that each source gets a turn
287             checkToRequestMore();
288             return result;
289         }
290 
        /**
         * Outcome of handling one item: {@code ONE} when a value was emitted
         * to the child, {@code NONE} when the item was only cached, and
         * {@code FINISHED} when a terminal event was sent to the child.
         */
        private enum Emitted {
            ONE, NONE, FINISHED;
        }
294 
295         private Status handleCompleted(Source source) {
296             completed(source);
297             final boolean done;
298             if (source == Source.A) {
299                 aSub.cancel();
300                 done = (completed == COMPLETED_BOTH) || (completed == COMPLETED_A && as.isEmpty());
301             } else {
302                 bSub.cancel();
303                 done = (completed == COMPLETED_BOTH) || (completed == COMPLETED_B && bs.isEmpty());
304             }
305             if (done) {
306                 clear();
307                 child.onComplete();
308                 return Status.FINISHED;
309             } else {
310                 checkToRequestMore();
311                 return Status.KEEP_GOING;
312             }
313         }
314 
        /**
         * Result of handling a source completion: {@code FINISHED} when the
         * child has been completed, {@code KEEP_GOING} when draining should
         * continue.
         */
        private enum Status {
            FINISHED, KEEP_GOING;
        }
318 
319         private void checkToRequestMore() {
320             if (requestFromA == requestSize && completed == COMPLETED_B) {
321                 requestFromA = 0;
322                 aSub.request(requestSize);
323             } else if (requestFromB == requestSize && completed == COMPLETED_A) {
324                 requestFromB = 0;
325                 bSub.request(requestSize);
326             } else if (requestFromA == requestSize && requestFromB == requestSize) {
327                 requestFromA = 0;
328                 requestFromB = 0;
329                 aSub.request(requestSize);
330                 bSub.request(requestSize);
331             }
332         }
333 
334         private void completed(Source source) {
335             if (source == Source.A) {
336                 if (completed == COMPLETED_NONE) {
337                     completed = COMPLETED_A;
338                 } else if (completed == COMPLETED_B) {
339                     completed = COMPLETED_BOTH;
340                 }
341             } else {
342                 if (completed == COMPLETED_NONE) {
343                     completed = COMPLETED_B;
344                 } else if (completed == COMPLETED_A) {
345                     completed = COMPLETED_BOTH;
346                 }
347             }
348         }
349 
350         private void clear() {
351             as.clear();
352             bs.clear();
353             queue.clear();
354             aSub.cancel();
355             bSub.cancel();
356         }
357 
358         private static <K, T> void add(Map<K, Queue<T>> map, K key, T value) {
359             Queue<T> q = map.get(key);
360             if (q == null) {
361                 q = new LinkedList<T>();
362                 map.put(key, q);
363             }
364             q.offer(value);
365         }
366 
367         private static <K, T> T poll(Map<K, Queue<T>> map, Queue<T> q, K key) {
368             T t = q.poll();
369             if (q.isEmpty()) {
370                 map.remove(key);
371             }
372             return t;
373         }
374 
        /**
         * Puts an upstream event on the queue and runs the drain loop. Called
         * by the upstream subscribers for every item, completion and error.
         *
         * @param item
         *            the event to enqueue
         */
        @Override
        public void offer(Object item) {
            queue.offer(item);
            drain();
        }
380 
381     }
382 
383     @SuppressWarnings("serial")
384     private static final class MySubscriber<T, K> extends AtomicReference<Subscription> implements FlowableSubscriber<T>, Subscription {
385 
386         private final Receiver receiver;
387         private final Source source;
388         private final long requestSize;
389 
390         MySubscriber(Source source, Receiver receiver, long requestSize) {
391             this.source = source;
392             this.receiver = receiver;
393             this.requestSize = requestSize;
394         }
395 
396         @Override
397         public void onSubscribe(Subscription subscription) {
398             if (SubscriptionHelper.setOnce(this, subscription)) {
399                 subscription.request(requestSize);
400             }
401         }
402 
403         @Override
404         public void request(long n) {
405             get().request(n);
406         }
407 
408         @Override
409         public void cancel() {
410             SubscriptionHelper.cancel(this);
411         }
412 
413         @Override
414         public void onNext(T t) {
415             if (source == Source.A) {
416                 receiver.offer(new ItemA(t));
417             } else {
418                 receiver.offer(t);
419             }
420         }
421 
422         @Override
423         public void onComplete() {
424             receiver.offer(source);
425         }
426 
427         @Override
428         public void onError(Throwable e) {
429             receiver.offer(new MyError(e));
430         }
431 
432     }
433 
    /**
     * Queue event that carries an upstream error, so the drain loop can tell
     * it apart from ordinary items.
     */
    private static final class MyError {
        // the error reported by the upstream source
        final Throwable error;

        MyError(Throwable error) {
            this.error = error;
        }
    }
441 
    /**
     * Queue event wrapping an item of source A. Items of source B are queued
     * unwrapped, so the wrapper is what identifies an item as coming from A.
     */
    private static final class ItemA {
        // the wrapped item of source A
        final Object value;

        ItemA(Object value) {
            this.value = value;
        }
    }
449 
    /**
     * Identifies one of the two sources. An instance placed on the event
     * queue marks the completion of that source.
     */
    private enum Source {
        A, B;
    }
453 
454 }