1 package org.broadinstitute.hellbender.tools;
2 
3 import java.io.File;
4 import java.io.FileOutputStream;
5 import java.io.IOException;
6 import java.io.InputStream;
7 import java.io.OutputStream;
8 import java.net.URI;
9 import java.nio.charset.Charset;
10 import java.nio.charset.StandardCharsets;
11 import java.util.ArrayList;
12 import java.util.List;
13 import java.util.concurrent.ExecutionException;
14 import java.util.concurrent.ExecutorService;
15 import java.util.concurrent.Executors;
16 import java.util.concurrent.Future;
17 import java.util.concurrent.ThreadFactory;
18 
19 import com.fasterxml.jackson.databind.DeserializationFeature;
20 import com.fasterxml.jackson.databind.MapperFeature;
21 import com.fasterxml.jackson.databind.ObjectMapper;
22 import com.google.common.util.concurrent.ThreadFactoryBuilder;
23 
24 import org.apache.commons.io.Charsets;
25 import org.apache.commons.io.IOUtils;
26 import org.apache.http.Header;
27 import org.apache.http.HttpEntity;
28 import org.apache.http.client.methods.CloseableHttpResponse;
29 import org.apache.http.client.methods.HttpGet;
30 import org.apache.http.impl.client.CloseableHttpClient;
31 import org.apache.http.util.EntityUtils;
32 import org.broadinstitute.barclay.argparser.Advanced;
33 import org.broadinstitute.barclay.argparser.Argument;
34 import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
35 import org.broadinstitute.barclay.argparser.ExperimentalFeature;
36 import org.broadinstitute.hellbender.cmdline.CommandLineProgram;
37 import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions;
38 import org.broadinstitute.hellbender.cmdline.programgroups.ExampleProgramGroup;
39 import org.broadinstitute.hellbender.exceptions.UserException;
40 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetClass;
41 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetErrorResponse;
42 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetFormat;
43 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetRequestBuilder;
44 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetRequestField;
45 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetResponse;
46 import org.broadinstitute.hellbender.utils.HttpUtils;
47 import org.broadinstitute.hellbender.utils.SimpleInterval;
48 import org.broadinstitute.hellbender.utils.Utils;
49 
50 /**
51  * A tool that downloads a file hosted on an htsget server to a local file
52  *
53  * <h3>Usage example</h3>
54  * <pre>
55  * gatk HtsgetReader \
56  *   --url htsget-server.org \
57  *   --id A1.bam \
58  *   --reference-name chr1
59  *   -O output.bam
60  * </pre>
61  */
62 
63 @ExperimentalFeature
64 @CommandLineProgramProperties(
65         summary = "Download a file using htsget",
66         oneLineSummary = "Download a file using htsget",
67         programGroup = ExampleProgramGroup.class
68 )
69 public class HtsgetReader extends CommandLineProgram {
70 
71     public static final String URL_LONG_NAME = "url";
72     public static final String ID_LONG_NAME = "id";
73     public static final String FORMAT_LONG_NAME = "format";
74     public static final String CLASS_LONG_NAME = "class";
75     public static final String FIELDS_LONG_NAME = "field";
76     public static final String TAGS_LONG_NAME = "tag";
77     public static final String NOTAGS_LONG_NAME = "notag";
78     public static final String NUM_THREADS_LONG_NAME = "reader-threads";
79     public static final String CHECK_MD5_LONG_NAME = "check-md5";
80 
81     @Argument(doc = "Output file.",
82             fullName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME,
83             shortName = StandardArgumentDefinitions.OUTPUT_LONG_NAME)
84     private File outputFile;
85 
86     @Argument(doc = "URL of htsget endpoint.",
87             fullName = URL_LONG_NAME,
88             shortName = URL_LONG_NAME)
89     private URI endpoint;
90 
91     @Argument(doc = "ID of record to request.",
92             fullName = ID_LONG_NAME,
93             shortName = ID_LONG_NAME)
94     private String id;
95 
96     @Argument(doc = "Format to request record data in.",
97             fullName = FORMAT_LONG_NAME,
98             shortName = FORMAT_LONG_NAME,
99             optional = true)
100     private HtsgetFormat format;
101 
102     @Argument(doc = "Class of data to request.",
103             fullName = CLASS_LONG_NAME,
104             shortName = CLASS_LONG_NAME,
105             optional = true)
106     private HtsgetClass dataClass;
107 
108     @Argument(doc = "The interval and reference sequence to request",
109             fullName = StandardArgumentDefinitions.INTERVALS_LONG_NAME,
110             shortName = StandardArgumentDefinitions.INTERVALS_SHORT_NAME,
111             optional = true)
112     private SimpleInterval interval;
113 
114     @Argument(doc = "A field to include, default: all",
115             fullName = FIELDS_LONG_NAME,
116             shortName = FIELDS_LONG_NAME,
117             optional = true)
118     private List<HtsgetRequestField> fields;
119 
120     @Argument(doc = "A tag which should be included.",
121             fullName = TAGS_LONG_NAME,
122             shortName = TAGS_LONG_NAME,
123             optional = true)
124     private List<String> tags;
125 
126     @Argument(doc = "A tag which should be excluded.",
127             fullName = NOTAGS_LONG_NAME,
128             shortName = NOTAGS_LONG_NAME,
129             optional = true)
130     private List<String> notags;
131 
132     @Advanced
133     @Argument(fullName = NUM_THREADS_LONG_NAME,
134             shortName = NUM_THREADS_LONG_NAME,
135             doc = "How many simultaneous threads to use when reading data from an htsget response;" +
136                     "higher values may improve performance when network latency is an issue.",
137             optional = true,
138             minValue = 1)
139     private int readerThreads = 1;
140 
141     @Argument(fullName = CHECK_MD5_LONG_NAME, shortName = CHECK_MD5_LONG_NAME, doc = "Boolean determining whether to calculate the md5 digest of the assembled file "
142             + "and validate it against the provided md5 hash, if it exists.", optional = true)
143     private boolean checkMd5 = false;
144 
145     private ExecutorService executorService;
146 
147     private CloseableHttpClient client;
148 
149     @Override
onStartup()150     public void onStartup() {
151         if (this.readerThreads > 1) {
152             logger.info("Initializing with " + this.readerThreads + " threads");
153             final ThreadFactory threadFactory = new ThreadFactoryBuilder()
154                 .setNameFormat("htsgetReader-thread-%d")
155                 .setDaemon(true).build();
156             this.executorService = Executors.newFixedThreadPool(readerThreads, threadFactory);
157         }
158         this.client = HttpUtils.getClient();
159     }
160 
161     @Override
onShutdown()162     public void onShutdown() {
163         if (this.executorService != null) {
164             this.executorService.shutdownNow();
165         }
166         super.onShutdown();
167     }
168 
169     /**
170      * Downloads data blocks provided by response to outputFile in serial
171      */
getData(final HtsgetResponse response)172     private void getData(final HtsgetResponse response) {
173         try (final OutputStream ostream = new FileOutputStream(this.outputFile)) {
174             response.getBlocks().forEach(b -> {
175                 try (final InputStream istream = b.getData()) {
176                     IOUtils.copy(istream, ostream);
177                 } catch (final IOException e) {
178                     throw new UserException("Failed to copy data block to output file", e);
179                 }
180             });
181         } catch (final IOException e) {
182             throw new UserException("Could not create output file: " + outputFile, e);
183         }
184     }
185 
186     /**
187      * Downloads data blocks provided by response to outputFile in parallel, using
188      * the number of threads specified by user
189      */
getDataParallel(final HtsgetResponse response)190     private void getDataParallel(final HtsgetResponse response) {
191         final List<Future<InputStream>> futures = new ArrayList<>(response.getBlocks().size());
192         response.getBlocks().forEach(b -> futures.add(this.executorService.submit(b::getData)));
193 
194         try (final OutputStream ostream = new FileOutputStream(this.outputFile)) {
195             futures.forEach(f -> {
196                 try (final InputStream istream = f.get()) {
197                     IOUtils.copy(istream, ostream);
198                 } catch (final IOException e) {
199                     throw new UserException("Error while copying data block to output file", e);
200                 } catch (final ExecutionException | InterruptedException e) {
201                     throw new UserException("Error while waiting to download block", e);
202                 }
203             });
204         } catch (final IOException e) {
205             throw new UserException("Could not create output file", e);
206         }
207     }
208 
209     /**
210      * Checks md5 digest provided in response, if one exists, against calculated md5
211      * hash of downloaded file, warning user if they differ
212      */
checkMd5(final HtsgetResponse resp)213     private void checkMd5(final HtsgetResponse resp) {
214         final String expectedMd5 = resp.getMd5();
215         if (expectedMd5 == null) {
216             logger.warn("No md5 digest provided by response");
217         } else {
218             try {
219                 final String actualMd5 = Utils.calculateFileMD5(outputFile);
220                 if (!actualMd5.equals(expectedMd5)) {
221                     throw new UserException("Expected md5: " + expectedMd5 + " did not match actual md5: " + actualMd5);
222                 }
223             } catch (final IOException e) {
224                 throw new UserException("Unable to calculate md5 digest", e);
225             }
226         }
227     }
228 
getObjectMapper()229     private ObjectMapper getObjectMapper() {
230         final ObjectMapper mapper = new ObjectMapper();
231         mapper.enable(DeserializationFeature.UNWRAP_ROOT_VALUE);
232         mapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true);
233         return mapper;
234     }
235 
236     @Override
doWork()237     public Object doWork() {
238         // construct request from command line args and convert to URI
239         final HtsgetRequestBuilder req = new HtsgetRequestBuilder(endpoint, id)
240             .withFormat(format)
241             .withDataClass(dataClass)
242             .withInterval(interval)
243             .withFields(fields)
244             .withTags(tags)
245             .withNotags(notags);
246         final URI reqURI = req.toURI();
247 
248         final HttpGet getReq = new HttpGet(reqURI);
249         try (final CloseableHttpResponse resp = this.client.execute(getReq)) {
250             // get content of response
251             final HttpEntity entity = resp.getEntity();
252             final Header encodingHeader = entity.getContentEncoding();
253             final Charset encoding = encodingHeader == null
254                 ? StandardCharsets.UTF_8
255                 : Charsets.toCharset(encodingHeader.getValue());
256             final String jsonBody = EntityUtils.toString(entity, encoding);
257 
258             final ObjectMapper mapper = this.getObjectMapper();
259 
260             if (resp.getStatusLine() == null) {
261                 throw new UserException(String.format("htsget server response did not contain status line for request %s", reqURI));
262             }
263             final int statusCode = resp.getStatusLine().getStatusCode();
264             if (400 <= statusCode && statusCode < 500) {
265                 final HtsgetErrorResponse err = mapper.readValue(jsonBody, HtsgetErrorResponse.class);
266                 throw new UserException(String.format("Invalid request %s, received error code: %d, error type: %s, message: %s",
267                         reqURI,
268                         statusCode,
269                         err.getError(),
270                         err.getMessage()));
271             } else if (statusCode == 200) {
272                 final HtsgetResponse response = mapper.readValue(jsonBody, HtsgetResponse.class);
273 
274                 if (this.readerThreads > 1) {
275                     this.getDataParallel(response);
276                 } else {
277                     this.getData(response);
278                 }
279 
280                 logger.info("Successfully wrote to: " + outputFile);
281 
282                 if (checkMd5) {
283                     this.checkMd5(response);
284                 }
285             } else {
286                 throw new UserException(String.format("Unrecognized status code: %d for request %s", statusCode, reqURI));
287             }
288         } catch (final IOException e) {
289             throw new UserException(String.format("IOException during htsget download for %s", reqURI), e);
290         }
291         return null;
292     }
293 }
294