View Javadoc
1   package com.github.davidmoten.aws.lw.client;
2   
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.Function;
import java.util.stream.Collectors;

import com.github.davidmoten.aws.lw.client.internal.Retries;
import com.github.davidmoten.aws.lw.client.internal.util.Preconditions;
import com.github.davidmoten.aws.lw.client.xml.builder.Xml;
18  
19  // NotThreadSafe
20  public final class MultipartOutputStream extends OutputStream {
21  
22      private final Client s3;
23      private final String bucket;
24      private final String key;
25      private final String uploadId;
26      private final ExecutorService executor;
27      private final ByteArrayOutputStream bytes;
28      private final byte[] singleByte = new byte[1]; // for reuse in write(int) method
29      private final long partTimeoutMs;
30      private final Retries<Void> retries;
31      private final int partSize;
32      private final List<Future<String>> futures = new CopyOnWriteArrayList<>();
33      private int nextPart = 1;
34  
35      MultipartOutputStream(Client s3, String bucket, String key,
36              Function<? super Request, ? extends Request> transformCreate, ExecutorService executor,
37              long partTimeoutMs, Retries<Void> retries, int partSize) {
38          Preconditions.checkNotNull(s3);
39          Preconditions.checkNotNull(bucket);
40          Preconditions.checkNotNull(key);
41          Preconditions.checkNotNull(transformCreate);
42          Preconditions.checkNotNull(executor);
43          Preconditions.checkArgument(partTimeoutMs > 0);
44          Preconditions.checkNotNull(retries);
45          Preconditions.checkArgument(partSize >= 5 * 1024 * 1024);
46          this.s3 = s3;
47          this.bucket = bucket;
48          this.key = key;
49          this.executor = executor;
50          this.partTimeoutMs = partTimeoutMs;
51          this.retries = retries;
52          this.partSize = partSize;
53          this.bytes = new ByteArrayOutputStream();
54          this.uploadId = transformCreate.apply(s3 //
55                  .path(bucket, key) //
56                  .query("uploads") //
57                  .method(HttpMethod.POST)) //
58                  .responseAsXml() //
59                  .content("UploadId");
60      }
61  
62      public void abort() {
63          futures.forEach(f -> f.cancel(true));
64          s3 //
65                  .path(bucket, key) //
66                  .query("uploadId", uploadId) //
67                  .method(HttpMethod.DELETE) //
68                  .execute();
69      }
70  
71      @Override
72      public void write(byte[] b, int off, int len) throws IOException {
73          while (len > 0) {
74              int remaining = partSize - bytes.size();
75              int n = Math.min(remaining, len);
76              bytes.write(b, off, n);
77              off += n;
78              len -= n;
79              if (bytes.size() == partSize) {
80                  submitPart();
81              }
82          }
83      }
84  
85      @Override
86      public void write(byte[] b) throws IOException {
87          write(b, 0, b.length);
88      }
89  
90      private void submitPart() {
91          int part = nextPart;
92          nextPart++;
93          byte[] body = bytes.toByteArray();
94          bytes.reset();
95          Future<String> future = executor.submit(() -> retry(() -> s3 //
96                  .path(bucket, key) //
97                  .method(HttpMethod.PUT) //
98                  .query("partNumber", "" + part) //
99                  .query("uploadId", uploadId) //
100                 .requestBody(body) //
101                 .readTimeout(partTimeoutMs, TimeUnit.MILLISECONDS) //
102                 .responseExpectStatusCode(200) //
103                 .firstHeader("ETag") //
104                 .get() //
105                 .replace("\"", ""), //
106                 "on part " + part));
107         futures.add(future);
108     }
109 
110     private <T> T retry(Callable<T> callable, String description) {
111         //TODO use description
112         return retries.call(callable, x -> false);
113     }
114 
115     @Override
116     public void close() throws IOException {
117         // submit whatever's left
118         if (bytes.size() > 0) {
119             submitPart();
120         }
121         List<String> etags = futures //
122                 .stream() //
123                 .map(future -> getResult(future)) //
124                 .collect(Collectors.toList());
125 
126         Xml xml = Xml //
127                 .create("CompleteMultipartUpload") //
128                 .attribute("xmlns", "http:s3.amazonaws.com/doc/2006-03-01/");
129         for (int i = 0; i < etags.size(); i++) {
130             xml = xml //
131                     .element("Part") //
132                     .element("ETag").content(etags.get(i)) //
133                     .up() //
134                     .element("PartNumber").content(String.valueOf(i + 1)) //
135                     .up().up();
136         }
137         String xmlFinal = xml.toString();
138         retry(() -> {
139             s3.path(bucket, key) //
140                     .method(HttpMethod.POST) //
141                     .query("uploadId", uploadId) //
142                     .header("Content-Type", "application/xml") //
143                     .unsignedPayload() //
144                     .requestBody(xmlFinal) //
145                     .execute();
146             return null;
147         }, "while completing multipart upload");
148     }
149 
150     private String getResult(Future<String> future) {
151         try {
152             return future.get(partTimeoutMs, TimeUnit.MILLISECONDS);
153         } catch (Throwable e) {
154             abort();
155             throw new RuntimeException(e);
156         }
157     }
158 
159     @Override
160     public void write(int b) throws IOException {
161         singleByte[0] = (byte) b;
162         write(singleByte, 0, 1);
163     }
164 }