View Javadoc
1   package com.github.davidmoten.rx2.internal.flowable;
2   
3   import java.util.concurrent.Callable;
4   import java.util.concurrent.atomic.AtomicInteger;
5   import java.util.concurrent.atomic.AtomicLong;
6   
7   import org.reactivestreams.Subscriber;
8   import org.reactivestreams.Subscription;
9   
10  import com.github.davidmoten.guavamini.Preconditions;
11  import com.github.davidmoten.rx2.StateMachine.Emitter;
12  import com.github.davidmoten.rx2.functions.Consumer3;
13  
14  import io.reactivex.BackpressureStrategy;
15  import io.reactivex.Flowable;
16  import io.reactivex.FlowableSubscriber;
17  import io.reactivex.exceptions.Exceptions;
18  import io.reactivex.functions.BiConsumer;
19  import io.reactivex.functions.Function3;
20  import io.reactivex.internal.functions.ObjectHelper;
21  import io.reactivex.internal.fuseable.SimplePlainQueue;
22  import io.reactivex.internal.queue.SpscLinkedArrayQueue;
23  import io.reactivex.internal.subscriptions.SubscriptionHelper;
24  import io.reactivex.internal.util.BackpressureHelper;
25  import io.reactivex.plugins.RxJavaPlugins;
26  
27  public final class FlowableStateMachine<State, In, Out> extends Flowable<Out> {
28  
29      private final Flowable<In> source;
30      private final Callable<? extends State> initialState;
31      private final Function3<? super State, ? super In, ? super Emitter<Out>, ? extends State> transition;
32      private final BiConsumer<? super State, ? super Emitter<Out>> completionAction;
33      private final Consumer3<? super State, ? super Throwable, ? super Emitter<Out>> errorAction;
34      private final BackpressureStrategy backpressureStrategy;
35      private final int requestBatchSize;
36  
37      public FlowableStateMachine(Flowable<In> source, //
38              Callable<? extends State> initialState, //
39              Function3<? super State, ? super In, ? super Emitter<Out>, ? extends State> transition, //
40              BiConsumer<? super State, ? super Emitter<Out>> completionAction, //
41              Consumer3<? super State, ? super Throwable, ? super Emitter<Out>> errorAction, //
42              BackpressureStrategy backpressureStrategy, //
43              int requestBatchSize) {
44          Preconditions.checkNotNull(initialState);
45          Preconditions.checkNotNull(transition);
46          Preconditions.checkNotNull(backpressureStrategy);
47          Preconditions.checkArgument(requestBatchSize > 0,
48                  "initialRequest must be greater than zero");
49          this.source = source;
50          this.initialState = initialState;
51          this.transition = transition;
52          this.completionAction = completionAction;
53          this.errorAction = errorAction;
54          this.backpressureStrategy = backpressureStrategy;
55          this.requestBatchSize = requestBatchSize;
56      }
57  
58      @Override
59      protected void subscribeActual(Subscriber<? super Out> child) {
60          source.subscribe(new StateMachineSubscriber<State, In, Out>(initialState, transition,
61                  completionAction, errorAction, backpressureStrategy, requestBatchSize, child));
62      }
63  
64      @SuppressWarnings("serial")
65      private static final class StateMachineSubscriber<State, In, Out> extends AtomicInteger
66              implements FlowableSubscriber<In>, Subscription, Emitter<Out> {
67          private final Callable<? extends State> initialState;
68          private final Function3<? super State, ? super In, ? super Emitter<Out>, ? extends State> transition;
69          private final BiConsumer<? super State, ? super Emitter<Out>> completionAction;
70          private final Consumer3<? super State, ? super Throwable, ? super Emitter<Out>> errorAction;
71          @SuppressWarnings("unused")
72          private final BackpressureStrategy backpressureStrategy; // TODO
73                                                                   // implement
74          private final int requestBatchSize;
75          private final SimplePlainQueue<Out> queue = new SpscLinkedArrayQueue<Out>(16);
76          private final Subscriber<? super Out> child;
77          private final AtomicLong requested = new AtomicLong();
78  
79          private Subscription parent;
80          private volatile boolean cancelled;
81          private State state;
82          private boolean done;
83          private volatile boolean done_;
84          private Throwable error_;
85          private boolean drainCalled;
86          private long count; // counts down arrival of last request batch
87  
88          // indicates to drain method that we can request more if needed
89          private volatile boolean requestsArrived = true;
90  
91          StateMachineSubscriber( //
92                  Callable<? extends State> initialState,
93                  Function3<? super State, ? super In, ? super Emitter<Out>, ? extends State> transition, //
94                  BiConsumer<? super State, ? super Emitter<Out>> completionAction, //
95                  Consumer3<? super State, ? super Throwable, ? super Emitter<Out>> errorAction, //
96                  BackpressureStrategy backpressureStrategy, //
97                  int requestBatchSize, //
98                  Subscriber<? super Out> child) {
99              this.initialState = initialState;
100             this.transition = transition;
101             this.completionAction = completionAction;
102             this.errorAction = errorAction;
103             this.backpressureStrategy = backpressureStrategy;
104             this.requestBatchSize = requestBatchSize;
105             this.child = child;
106             this.count = requestBatchSize;
107         }
108 
109         @Override
110         public void onSubscribe(Subscription parent) {
111             if (SubscriptionHelper.validate(this.parent, parent)) {
112                 this.parent = parent;
113                 child.onSubscribe(this);
114             }
115         }
116 
117         @Override
118         public void onNext(In t) {
119             if (done) {
120                 return;
121             }
122             if (!createdState()) {
123                 return;
124             }
125             if (--count == 0) {
126                 requestsArrived = true;
127                 count = requestBatchSize;
128             }
129             try {
130                 drainCalled = false;
131                 state = ObjectHelper.requireNonNull(transition.apply(state, t, this),
132                         "intermediate state cannot be null");
133             } catch (Throwable e) {
134                 Exceptions.throwIfFatal(e);
135                 onError(e);
136                 return;
137             }
138             if (!drainCalled) {
139                 drain();
140             }
141         }
142 
143         private boolean createdState() {
144             if (state == null) {
145                 try {
146                     state = ObjectHelper.requireNonNull(initialState.call(),
147                             "initial state cannot be null");
148                     return true;
149                 } catch (Throwable e) {
150                     Exceptions.throwIfFatal(e);
151                     done = true;
152                     onError_(e);
153                     return false;
154                 }
155             } else {
156                 return true;
157             }
158         }
159 
160         @Override
161         public void onError(Throwable e) {
162             if (done) {
163                 RxJavaPlugins.onError(e);
164                 return;
165             }
166             done = true;
167             if (!createdState()) {
168                 return;
169             }
170             if (errorAction != null) {
171                 try {
172                     errorAction.accept(state, e, this);
173                 } catch (Throwable err) {
174                     Exceptions.throwIfFatal(e);
175                     onError_(err);
176                     return;
177                 }
178             } else {
179                 onError_(e);
180             }
181         }
182 
183         @Override
184         public void onComplete() {
185             if (done) {
186                 return;
187             }
188             if (!createdState()) {
189                 return;
190             }
191             try {
192                 if (completionAction != null) {
193                     completionAction.accept(state, this);
194                 } else {
195                     onComplete_();
196                 }
197                 done = true;
198             } catch (Throwable e) {
199                 Exceptions.throwIfFatal(e);
200                 onError(e);
201                 return;
202             }
203         }
204 
205         @Override
206         public void request(long n) {
207             if (SubscriptionHelper.validate(n)) {
208                 BackpressureHelper.add(requested, n);
209                 drain();
210             }
211         }
212 
213         @Override
214         public void cancel() {
215             cancelled = true;
216             parent.cancel();
217         }
218 
219         @Override
220         public void cancel_() {
221             cancel();
222         }
223 
224         @Override
225         public void onNext_(Out t) {
226             if (done_) {
227                 return;
228             }
229             queue.offer(t);
230             drain();
231         }
232 
233         @Override
234         public void onError_(Throwable e) {
235             if (done_) {
236                 RxJavaPlugins.onError(e);
237                 return;
238             }
239             error_ = e;
240             done_ = true;
241             drain();
242         }
243 
244         @Override
245         public void onComplete_() {
246             if (done_) {
247                 return;
248             }
249             done_ = true;
250             drain();
251         }
252 
253         public void drain() {
254             drainCalled = true;
255             if (getAndIncrement() == 0) {
256                 int missed = 1;
257                 while (true) {
258                     boolean reqsArrived = requestsArrived;
259                     long r = requested.get();
260                     long e = 0;
261                     while (e != r) {
262                         if (cancelled) {
263                             return;
264                         }
265                         boolean d = done_;
266                         Out t = queue.poll();
267                         if (t == null) {
268                             if (d) {
269                                 Throwable err = error_;
270                                 if (err != null) {
271                                     cancel();
272                                     queue.clear();
273                                     child.onError(err);
274                                 } else {
275                                     cancel();
276                                     queue.clear();
277                                     child.onComplete();
278                                 }
279                                 return;
280                             } else {
281                                 break;
282                             }
283                         } else {
284                             child.onNext(t);
285                             e++;
286                         }
287                     }
288                     if (e != 0 && r != Long.MAX_VALUE) {
289                         requested.addAndGet(-e);
290                     }
291                     if (e != r && reqsArrived) {
292                         requestsArrived = false;
293                         parent.request(requestBatchSize);
294                     }
295                     missed = addAndGet(-missed);
296                     if (missed == 0) {
297                         return;
298                     }
299                 }
300             }
301         }
302     }
303 }