1 package com.github.davidmoten.aws.lw.client;
2
3 import java.io.ByteArrayOutputStream;
4 import java.io.IOException;
5 import java.io.OutputStream;
6 import java.util.List;
7 import java.util.concurrent.Callable;
8 import java.util.concurrent.CopyOnWriteArrayList;
9 import java.util.concurrent.ExecutorService;
10 import java.util.concurrent.Future;
11 import java.util.concurrent.TimeUnit;
12 import java.util.function.Function;
13 import java.util.stream.Collectors;
14
15 import com.github.davidmoten.aws.lw.client.internal.Retries;
16 import com.github.davidmoten.aws.lw.client.internal.util.Preconditions;
17 import com.github.davidmoten.aws.lw.client.xml.builder.Xml;
18
19
20 public final class MultipartOutputStream extends OutputStream {
21
22 private final Client s3;
23 private final String bucket;
24 private final String key;
25 private final String uploadId;
26 private final ExecutorService executor;
27 private final ByteArrayOutputStream bytes;
28 private final byte[] singleByte = new byte[1];
29 private final long partTimeoutMs;
30 private final Retries<Void> retries;
31 private final int partSize;
32 private final List<Future<String>> futures = new CopyOnWriteArrayList<>();
33 private int nextPart = 1;
34
35 MultipartOutputStream(Client s3, String bucket, String key,
36 Function<? super Request, ? extends Request> transformCreate, ExecutorService executor,
37 long partTimeoutMs, Retries<Void> retries, int partSize) {
38 Preconditions.checkNotNull(s3);
39 Preconditions.checkNotNull(bucket);
40 Preconditions.checkNotNull(key);
41 Preconditions.checkNotNull(transformCreate);
42 Preconditions.checkNotNull(executor);
43 Preconditions.checkArgument(partTimeoutMs > 0);
44 Preconditions.checkNotNull(retries);
45 Preconditions.checkArgument(partSize >= 5 * 1024 * 1024);
46 this.s3 = s3;
47 this.bucket = bucket;
48 this.key = key;
49 this.executor = executor;
50 this.partTimeoutMs = partTimeoutMs;
51 this.retries = retries;
52 this.partSize = partSize;
53 this.bytes = new ByteArrayOutputStream();
54 this.uploadId = transformCreate.apply(s3
55 .path(bucket, key)
56 .query("uploads")
57 .method(HttpMethod.POST))
58 .responseAsXml()
59 .content("UploadId");
60 }
61
62 public void abort() {
63 futures.forEach(f -> f.cancel(true));
64 s3
65 .path(bucket, key)
66 .query("uploadId", uploadId)
67 .method(HttpMethod.DELETE)
68 .execute();
69 }
70
71 @Override
72 public void write(byte[] b, int off, int len) throws IOException {
73 while (len > 0) {
74 int remaining = partSize - bytes.size();
75 int n = Math.min(remaining, len);
76 bytes.write(b, off, n);
77 off += n;
78 len -= n;
79 if (bytes.size() == partSize) {
80 submitPart();
81 }
82 }
83 }
84
85 @Override
86 public void write(byte[] b) throws IOException {
87 write(b, 0, b.length);
88 }
89
90 private void submitPart() {
91 int part = nextPart;
92 nextPart++;
93 byte[] body = bytes.toByteArray();
94 bytes.reset();
95 Future<String> future = executor.submit(() -> retry(() -> s3
96 .path(bucket, key)
97 .method(HttpMethod.PUT)
98 .query("partNumber", "" + part)
99 .query("uploadId", uploadId)
100 .requestBody(body)
101 .readTimeout(partTimeoutMs, TimeUnit.MILLISECONDS)
102 .responseExpectStatusCode(200)
103 .firstHeader("ETag")
104 .get()
105 .replace("\"", ""),
106 "on part " + part));
107 futures.add(future);
108 }
109
110 private <T> T retry(Callable<T> callable, String description) {
111
112 return retries.call(callable, x -> false);
113 }
114
115 @Override
116 public void close() throws IOException {
117
118 if (bytes.size() > 0) {
119 submitPart();
120 }
121 List<String> etags = futures
122 .stream()
123 .map(future -> getResult(future))
124 .collect(Collectors.toList());
125
126 Xml xml = Xml
127 .create("CompleteMultipartUpload")
128 .attribute("xmlns", "http:s3.amazonaws.com/doc/2006-03-01/");
129 for (int i = 0; i < etags.size(); i++) {
130 xml = xml
131 .element("Part")
132 .element("ETag").content(etags.get(i))
133 .up()
134 .element("PartNumber").content(String.valueOf(i + 1))
135 .up().up();
136 }
137 String xmlFinal = xml.toString();
138 retry(() -> {
139 s3.path(bucket, key)
140 .method(HttpMethod.POST)
141 .query("uploadId", uploadId)
142 .header("Content-Type", "application/xml")
143 .unsignedPayload()
144 .requestBody(xmlFinal)
145 .execute();
146 return null;
147 }, "while completing multipart upload");
148 }
149
150 private String getResult(Future<String> future) {
151 try {
152 return future.get(partTimeoutMs, TimeUnit.MILLISECONDS);
153 } catch (Throwable e) {
154 abort();
155 throw new RuntimeException(e);
156 }
157 }
158
159 @Override
160 public void write(int b) throws IOException {
161 singleByte[0] = (byte) b;
162 write(singleByte, 0, 1);
163 }
164 }