1 package com.github.davidmoten.rx.jdbc;
2
3 import java.sql.ResultSet;
4 import java.sql.SQLException;
5 import java.sql.Statement;
6 import java.util.List;
7
8 import org.slf4j.Logger;
9 import org.slf4j.LoggerFactory;
10
11 import rx.Observable;
12 import rx.Observable.OnSubscribe;
13 import rx.Subscriber;
14 import rx.Subscription;
15 import rx.functions.Action0;
16 import rx.subscriptions.Subscriptions;
17
18
19
20
21 final class QueryUpdateOnSubscribe<T> implements OnSubscribe<T> {
22
23 private static final Logger log = LoggerFactory.getLogger(QueryUpdateOnSubscribe.class);
24
25 static final String BEGIN_TRANSACTION = "begin";
26
27
28
29
30 static final String ROLLBACK = "rollback";
31
32
33
34
35 static final String COMMIT = "commit";
36
37
38
39
40
41
42
43
44
45 static <T> Observable<T> execute(QueryUpdate<T> query, List<Parameter> parameters) {
46 return Observable.create(new QueryUpdateOnSubscribe<T>(query, parameters));
47 }
48
49
50
51
52 private final QueryUpdate<T> query;
53
54
55
56
57
58
59 private final List<Parameter> parameters;
60
61
62
63
64
65
66
67 private QueryUpdateOnSubscribe(QueryUpdate<T> query, List<Parameter> parameters) {
68 this.query = query;
69 this.parameters = parameters;
70 }
71
72 @Override
73 public void call(Subscriber<? super T> subscriber) {
74 final State state = new State();
75 try {
76 if (isBeginTransaction())
77 performBeginTransaction(subscriber);
78 else {
79 getConnection(state);
80 subscriber.add(createUnsubscriptionAction(state));
81 if (isCommit())
82 performCommit(subscriber, state);
83 else if (isRollback())
84 performRollback(subscriber, state);
85 else
86 performUpdate(subscriber, state);
87 }
88 } catch (Exception e) {
89 query.context().endTransactionObserve();
90 query.context().endTransactionSubscribe();
91 try {
92 close(state);
93 } finally {
94 handleException(e, subscriber);
95 }
96 }
97 }
98
99 private Subscription createUnsubscriptionAction(final State state) {
100 return Subscriptions.create(new Action0() {
101 @Override
102 public void call() {
103 close(state);
104 }
105 });
106 }
107
108 private boolean isBeginTransaction() {
109 return query.sql().equals(BEGIN_TRANSACTION);
110 }
111
112 @SuppressWarnings("unchecked")
113 private void performBeginTransaction(Subscriber<? super T> subscriber) {
114 query.context().beginTransactionObserve();
115 log.debug("beginTransaction emitting 1");
116 subscriber.onNext((T) Integer.valueOf(1));
117 log.debug("emitted 1");
118 complete(subscriber);
119 }
120
121
122
123
124 private void getConnection(State state) {
125 log.debug("getting connection");
126 state.con = query.context().connectionProvider().get();
127 log.debug("cp={}", query.context().connectionProvider());
128 }
129
130
131
132
133
134
135 private boolean isCommit() {
136 return query.sql().equals(COMMIT);
137 }
138
139
140
141
142
143
144 private boolean isRollback() {
145 return query.sql().equals(ROLLBACK);
146 }
147
148
149
150
151
152
153
154
155 @SuppressWarnings("unchecked")
156 private void performCommit(Subscriber<? super T> subscriber, State state) {
157 query.context().endTransactionObserve();
158 if (subscriber.isUnsubscribed())
159 return;
160
161 log.debug("committing");
162 Conditions.checkTrue(!Util.isAutoCommit(state.con));
163 Util.commit(state.con);
164
165
166 close(state);
167
168 if (subscriber.isUnsubscribed())
169 return;
170
171 subscriber.onNext((T) Integer.valueOf(1));
172 log.debug("committed");
173 complete(subscriber);
174 }
175
176
177
178
179
180
181
182
183 @SuppressWarnings("unchecked")
184 private void performRollback(Subscriber<? super T> subscriber, State state) {
185 log.debug("rolling back");
186 query.context().endTransactionObserve();
187 Conditions.checkTrue(!Util.isAutoCommit(state.con));
188 Util.rollback(state.con);
189
190
191 close(state);
192 subscriber.onNext((T) Integer.valueOf(0));
193 log.debug("rolled back");
194 complete(subscriber);
195 }
196
197
198
199
200
201
202
203
204 @SuppressWarnings("unchecked")
205 private void performUpdate(final Subscriber<? super T> subscriber, State state)
206 throws SQLException {
207 if (subscriber.isUnsubscribed()) {
208 return;
209 }
210 int keysOption;
211 if (query.returnGeneratedKeys()) {
212 keysOption = Statement.RETURN_GENERATED_KEYS;
213 } else {
214 keysOption = Statement.NO_GENERATED_KEYS;
215 }
216 state.ps = state.con.prepareStatement(query.sql(), keysOption);
217 Util.setParameters(state.ps, parameters, query.names());
218
219 if (subscriber.isUnsubscribed())
220 return;
221
222 int count;
223 try {
224 log.debug("executing sql={}, parameters {}", query.sql(), parameters);
225 count = state.ps.executeUpdate();
226 log.debug("executed ps={}", state.ps);
227 if (query.returnGeneratedKeys()) {
228 log.debug("getting generated keys");
229 ResultSet rs = state.ps.getGeneratedKeys();
230 log.debug("returned generated key result set {}" , rs);
231 state.rs = rs;
232 Observable<Parameter> params = Observable.just(new Parameter(state));
233 Observable<Object> depends = Observable.empty();
234 Observable<T> o = new QuerySelect(QuerySelect.RETURN_GENERATED_KEYS, params,
235 depends, query.context()).execute(query.returnGeneratedKeysFunction());
236 Subscriber<T> sub = createSubscriber(subscriber);
237 o.unsafeSubscribe(sub);
238 }
239 } catch (SQLException e) {
240 throw new SQLException("failed to execute sql=" + query.sql(), e);
241 }
242 if (!query.returnGeneratedKeys()) {
243
244
245 close(state);
246 if (subscriber.isUnsubscribed())
247 return;
248 log.debug("onNext");
249 subscriber.onNext((T) (Integer) count);
250 complete(subscriber);
251 }
252 }
253
254 private Subscriber<T> createSubscriber(final Subscriber<? super T> subscriber) {
255 return new Subscriber<T>(subscriber) {
256
257 @Override
258 public void onCompleted() {
259 complete(subscriber);
260 }
261
262 @Override
263 public void onError(Throwable e) {
264 subscriber.onError(e);
265 }
266
267 @Override
268 public void onNext(T t) {
269 subscriber.onNext(t);
270 }
271 };
272 }
273
274
275
276
277
278
279
280 private void complete(Subscriber<? super T> subscriber) {
281 if (!subscriber.isUnsubscribed()) {
282 log.debug("onCompleted");
283 subscriber.onCompleted();
284 } else
285 log.debug("unsubscribed");
286 }
287
288
289
290
291
292
293
294 private void handleException(Exception e, Subscriber<? super T> subscriber) {
295 log.debug("onError: ", e.getMessage());
296 if (subscriber.isUnsubscribed())
297 log.debug("unsubscribed");
298 else {
299 subscriber.onError(e);
300 }
301 }
302
303
304
305
306
307 private void close(State state) {
308
309 if (state.closed.compareAndSet(false, true)) {
310 Util.closeQuietly(state.ps);
311 if (isCommit() || isRollback())
312 Util.closeQuietly(state.con);
313 else
314 Util.closeQuietlyIfAutoCommit(state.con);
315 }
316 }
317
318 }