1 package org.broadinstitute.hellbender.tools; 2 3 import java.io.File; 4 import java.io.FileOutputStream; 5 import java.io.IOException; 6 import java.io.InputStream; 7 import java.io.OutputStream; 8 import java.net.URI; 9 import java.nio.charset.Charset; 10 import java.nio.charset.StandardCharsets; 11 import java.util.ArrayList; 12 import java.util.List; 13 import java.util.concurrent.ExecutionException; 14 import java.util.concurrent.ExecutorService; 15 import java.util.concurrent.Executors; 16 import java.util.concurrent.Future; 17 import java.util.concurrent.ThreadFactory; 18 19 import com.fasterxml.jackson.databind.DeserializationFeature; 20 import com.fasterxml.jackson.databind.MapperFeature; 21 import com.fasterxml.jackson.databind.ObjectMapper; 22 import com.google.common.util.concurrent.ThreadFactoryBuilder; 23 24 import org.apache.commons.io.Charsets; 25 import org.apache.commons.io.IOUtils; 26 import org.apache.http.Header; 27 import org.apache.http.HttpEntity; 28 import org.apache.http.client.methods.CloseableHttpResponse; 29 import org.apache.http.client.methods.HttpGet; 30 import org.apache.http.impl.client.CloseableHttpClient; 31 import org.apache.http.util.EntityUtils; 32 import org.broadinstitute.barclay.argparser.Advanced; 33 import org.broadinstitute.barclay.argparser.Argument; 34 import org.broadinstitute.barclay.argparser.CommandLineProgramProperties; 35 import org.broadinstitute.barclay.argparser.ExperimentalFeature; 36 import org.broadinstitute.hellbender.cmdline.CommandLineProgram; 37 import org.broadinstitute.hellbender.cmdline.StandardArgumentDefinitions; 38 import org.broadinstitute.hellbender.cmdline.programgroups.ExampleProgramGroup; 39 import org.broadinstitute.hellbender.exceptions.UserException; 40 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetClass; 41 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetErrorResponse; 42 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetFormat; 43 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetRequestBuilder; 44 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetRequestField; 45 import org.broadinstitute.hellbender.tools.htsgetreader.HtsgetResponse; 46 import org.broadinstitute.hellbender.utils.HttpUtils; 47 import org.broadinstitute.hellbender.utils.SimpleInterval; 48 import org.broadinstitute.hellbender.utils.Utils; 49 50 /** 51 * A tool that downloads a file hosted on an htsget server to a local file 52 * 53 * <h3>Usage example</h3> 54 * <pre> 55 * gatk HtsgetReader \ 56 * --url htsget-server.org \ 57 * --id A1.bam \ 58 * --reference-name chr1 59 * -O output.bam 60 * </pre> 61 */ 62 63 @ExperimentalFeature 64 @CommandLineProgramProperties( 65 summary = "Download a file using htsget", 66 oneLineSummary = "Download a file using htsget", 67 programGroup = ExampleProgramGroup.class 68 ) 69 public class HtsgetReader extends CommandLineProgram { 70 71 public static final String URL_LONG_NAME = "url"; 72 public static final String ID_LONG_NAME = "id"; 73 public static final String FORMAT_LONG_NAME = "format"; 74 public static final String CLASS_LONG_NAME = "class"; 75 public static final String FIELDS_LONG_NAME = "field"; 76 public static final String TAGS_LONG_NAME = "tag"; 77 public static final String NOTAGS_LONG_NAME = "notag"; 78 public static final String NUM_THREADS_LONG_NAME = "reader-threads"; 79 public static final String CHECK_MD5_LONG_NAME = "check-md5"; 80 81 @Argument(doc = "Output file.", 82 fullName = StandardArgumentDefinitions.OUTPUT_SHORT_NAME, 83 shortName = StandardArgumentDefinitions.OUTPUT_LONG_NAME) 84 private File outputFile; 85 86 @Argument(doc = "URL of htsget endpoint.", 87 fullName = URL_LONG_NAME, 88 shortName = URL_LONG_NAME) 89 private URI endpoint; 90 91 @Argument(doc = "ID of record to request.", 92 fullName = ID_LONG_NAME, 93 shortName = ID_LONG_NAME) 94 private String id; 95 96 @Argument(doc = "Format to request record data in.", 97 fullName = FORMAT_LONG_NAME, 98 shortName = FORMAT_LONG_NAME, 99 optional = true) 100 private HtsgetFormat format; 101 102 @Argument(doc = "Class of data to request.", 103 fullName = CLASS_LONG_NAME, 104 shortName = CLASS_LONG_NAME, 105 optional = true) 106 private HtsgetClass dataClass; 107 108 @Argument(doc = "The interval and reference sequence to request", 109 fullName = StandardArgumentDefinitions.INTERVALS_LONG_NAME, 110 shortName = StandardArgumentDefinitions.INTERVALS_SHORT_NAME, 111 optional = true) 112 private SimpleInterval interval; 113 114 @Argument(doc = "A field to include, default: all", 115 fullName = FIELDS_LONG_NAME, 116 shortName = FIELDS_LONG_NAME, 117 optional = true) 118 private List<HtsgetRequestField> fields; 119 120 @Argument(doc = "A tag which should be included.", 121 fullName = TAGS_LONG_NAME, 122 shortName = TAGS_LONG_NAME, 123 optional = true) 124 private List<String> tags; 125 126 @Argument(doc = "A tag which should be excluded.", 127 fullName = NOTAGS_LONG_NAME, 128 shortName = NOTAGS_LONG_NAME, 129 optional = true) 130 private List<String> notags; 131 132 @Advanced 133 @Argument(fullName = NUM_THREADS_LONG_NAME, 134 shortName = NUM_THREADS_LONG_NAME, 135 doc = "How many simultaneous threads to use when reading data from an htsget response;" + 136 "higher values may improve performance when network latency is an issue.", 137 optional = true, 138 minValue = 1) 139 private int readerThreads = 1; 140 141 @Argument(fullName = CHECK_MD5_LONG_NAME, shortName = CHECK_MD5_LONG_NAME, doc = "Boolean determining whether to calculate the md5 digest of the assembled file " 142 + "and validate it against the provided md5 hash, if it exists.", optional = true) 143 private boolean checkMd5 = false; 144 145 private ExecutorService executorService; 146 147 private CloseableHttpClient client; 148 149 @Override onStartup()150 public void onStartup() { 151 if (this.readerThreads > 1) { 152 logger.info("Initializing with " + this.readerThreads + " threads"); 153 final ThreadFactory threadFactory = new ThreadFactoryBuilder() 154 .setNameFormat("htsgetReader-thread-%d") 155 .setDaemon(true).build(); 156 this.executorService = Executors.newFixedThreadPool(readerThreads, threadFactory); 157 } 158 this.client = HttpUtils.getClient(); 159 } 160 161 @Override onShutdown()162 public void onShutdown() { 163 if (this.executorService != null) { 164 this.executorService.shutdownNow(); 165 } 166 super.onShutdown(); 167 } 168 169 /** 170 * Downloads data blocks provided by response to outputFile in serial 171 */ getData(final HtsgetResponse response)172 private void getData(final HtsgetResponse response) { 173 try (final OutputStream ostream = new FileOutputStream(this.outputFile)) { 174 response.getBlocks().forEach(b -> { 175 try (final InputStream istream = b.getData()) { 176 IOUtils.copy(istream, ostream); 177 } catch (final IOException e) { 178 throw new UserException("Failed to copy data block to output file", e); 179 } 180 }); 181 } catch (final IOException e) { 182 throw new UserException("Could not create output file: " + outputFile, e); 183 } 184 } 185 186 /** 187 * Downloads data blocks provided by response to outputFile in parallel, using 188 * the number of threads specified by user 189 */ getDataParallel(final HtsgetResponse response)190 private void getDataParallel(final HtsgetResponse response) { 191 final List<Future<InputStream>> futures = new ArrayList<>(response.getBlocks().size()); 192 response.getBlocks().forEach(b -> futures.add(this.executorService.submit(b::getData))); 193 194 try (final OutputStream ostream = new FileOutputStream(this.outputFile)) { 195 futures.forEach(f -> { 196 try (final InputStream istream = f.get()) { 197 IOUtils.copy(istream, ostream); 198 } catch (final IOException e) { 199 throw new UserException("Error while copying data block to output file", e); 200 } catch (final ExecutionException | InterruptedException e) { 201 throw new UserException("Error while waiting to download block", e); 202 } 203 }); 204 } catch (final IOException e) { 205 throw new UserException("Could not create output file", e); 206 } 207 } 208 209 /** 210 * Checks md5 digest provided in response, if one exists, against calculated md5 211 * hash of downloaded file, warning user if they differ 212 */ checkMd5(final HtsgetResponse resp)213 private void checkMd5(final HtsgetResponse resp) { 214 final String expectedMd5 = resp.getMd5(); 215 if (expectedMd5 == null) { 216 logger.warn("No md5 digest provided by response"); 217 } else { 218 try { 219 final String actualMd5 = Utils.calculateFileMD5(outputFile); 220 if (!actualMd5.equals(expectedMd5)) { 221 throw new UserException("Expected md5: " + expectedMd5 + " did not match actual md5: " + actualMd5); 222 } 223 } catch (final IOException e) { 224 throw new UserException("Unable to calculate md5 digest", e); 225 } 226 } 227 } 228 getObjectMapper()229 private ObjectMapper getObjectMapper() { 230 final ObjectMapper mapper = new ObjectMapper(); 231 mapper.enable(DeserializationFeature.UNWRAP_ROOT_VALUE); 232 mapper.configure(MapperFeature.ACCEPT_CASE_INSENSITIVE_PROPERTIES, true); 233 return mapper; 234 } 235 236 @Override doWork()237 public Object doWork() { 238 // construct request from command line args and convert to URI 239 final HtsgetRequestBuilder req = new HtsgetRequestBuilder(endpoint, id) 240 .withFormat(format) 241 .withDataClass(dataClass) 242 .withInterval(interval) 243 .withFields(fields) 244 .withTags(tags) 245 .withNotags(notags); 246 final URI reqURI = req.toURI(); 247 248 final HttpGet getReq = new HttpGet(reqURI); 249 try (final CloseableHttpResponse resp = this.client.execute(getReq)) { 250 // get content of response 251 final HttpEntity entity = resp.getEntity(); 252 final Header encodingHeader = entity.getContentEncoding(); 253 final Charset encoding = encodingHeader == null 254 ? StandardCharsets.UTF_8 255 : Charsets.toCharset(encodingHeader.getValue()); 256 final String jsonBody = EntityUtils.toString(entity, encoding); 257 258 final ObjectMapper mapper = this.getObjectMapper(); 259 260 if (resp.getStatusLine() == null) { 261 throw new UserException(String.format("htsget server response did not contain status line for request %s", reqURI)); 262 } 263 final int statusCode = resp.getStatusLine().getStatusCode(); 264 if (400 <= statusCode && statusCode < 500) { 265 final HtsgetErrorResponse err = mapper.readValue(jsonBody, HtsgetErrorResponse.class); 266 throw new UserException(String.format("Invalid request %s, received error code: %d, error type: %s, message: %s", 267 reqURI, 268 statusCode, 269 err.getError(), 270 err.getMessage())); 271 } else if (statusCode == 200) { 272 final HtsgetResponse response = mapper.readValue(jsonBody, HtsgetResponse.class); 273 274 if (this.readerThreads > 1) { 275 this.getDataParallel(response); 276 } else { 277 this.getData(response); 278 } 279 280 logger.info("Successfully wrote to: " + outputFile); 281 282 if (checkMd5) { 283 this.checkMd5(response); 284 } 285 } else { 286 throw new UserException(String.format("Unrecognized status code: %d for request %s", statusCode, reqURI)); 287 } 288 } catch (final IOException e) { 289 throw new UserException(String.format("IOException during htsget download for %s", reqURI), e); 290 } 291 return null; 292 } 293 } 294