# plots.py
#
# Plotting code for Variant Filtration with Neural Nets
# This includes evaluation plots like Precision and Recall curves,
# various flavors of Receiver Operating Characteristic (ROC curves),
# As well as graphs of the metrics that are watched during neural net training.
#
# December 2016
# Sam Friedman
# sam@broadinstitute.org

# Imports
import os
import math
import matplotlib
import numpy as np
matplotlib.use('Agg')  # Need this to write images from the GSA servers. Order matters:
import matplotlib.pyplot as plt  # First import matplotlib, then use Agg, then import plt
from sklearn.metrics import roc_curve, auc, roc_auc_score, precision_recall_curve, average_precision_score

image_ext = '.png'

# Fallback palette used when a label has no entry in key_colors.
color_array = ['red', 'indigo', 'cyan', 'pink', 'purple']
# Fixed colors for well-known algorithm/label names so plots are comparable across runs.
key_colors = {
    'Neural Net':'green', 'CNN_SCORE':'green', 'CNN_2D':'green',
    'Heng Li Hard Filters':'lightblue',
    'GATK Hard Filters':'orange','GATK Signed Distance':'darksalmon',
    'VQSR gnomAD':'cornflowerblue', 'VQSR Single Sample':'blue', 'VQSLOD':'cornflowerblue',
    'Deep Variant':'magenta', 'QUAL':'magenta', 'DEEP_VARIANT_QUAL':'magenta',
    'Random Forest':'darkorange',
    'SNP':'cornflowerblue', 'NOT_SNP':'orange', 'INDEL':'green', 'NOT_INDEL':'red',
    'VQSLOD none':'cornflowerblue', 'VQSLOD strModel':'orange', 'VQSLOD default':'green',
    'REFERENCE':'green', 'HET_SNP':'cornflowerblue', 'HOM_SNP':'blue', 'HET_DELETION':'magenta',
    'HOM_DELETION':'violet', 'HET_INSERTION':'orange', 'HOM_INSERTION':'darkorange'
}

precision_label = 'Precision | Positive Predictive Value | TP/(TP+FP)'
recall_label = 'Recall | Sensitivity | True Positive Rate | TP/(TP+FN)'
fallout_label = 'Fallout | 1 - Specificity | False Positive Rate | FP/(FP+TN)'


def get_fpr_tpr_roc(model, test_data, test_truth, labels, batch_size=32):
    """Get false positive and true positive rates from a classification model.

    Arguments:
        model: The model whose predictions to evaluate.
        test_data: Input testing data in the shape the model expects.
        test_truth: The true labels of the testing data
        labels: dict specifying the class labels.
        batch_size: Size of batches for prediction over the test data.

    Returns:
        dict, dict, dict: false positive rate, true positive rate, and area under ROC curve.
            The dicts all use label indices as keys. fpr and tpr dict's values are lists
            (the x and y coordinates that defines the ROC curves) and for AUC the value is a float.
    """
    y_pred = model.predict(test_data, batch_size=batch_size, verbose=0)
    return get_fpr_tpr_roc_pred(y_pred, test_truth, labels)


def get_fpr_tpr_roc_pred(y_pred, test_truth, labels):
    """Get false positive and true positive rates from predictions and true labels.

    Arguments:
        y_pred: model predictions to evaluate.
        test_truth: The true labels of the testing data
        labels: dict specifying the class labels.

    Returns:
        dict, dict, dict: false positive rate, true positive rate, and area under ROC curve.
            The dicts all use label indices as keys. fpr and tpr dict's values are lists
            (the x and y coordinates that defines the ROC curves) and for AUC the value is a float.
    """
    fpr = dict()
    tpr = dict()
    roc_auc = dict()

    # One ROC curve per class; each class is a column of the one-hot truth/prediction arrays.
    for label_idx in labels.values():
        fpr[label_idx], tpr[label_idx], _ = roc_curve(test_truth[:, label_idx], y_pred[:, label_idx])
        roc_auc[label_idx] = auc(fpr[label_idx], tpr[label_idx])

    return fpr, tpr, roc_auc


def plot_roc_per_class(model, test_data, test_truth, labels, title, batch_size=32, prefix='./figures/'):
    """Plot a per class ROC curve.

    Arguments:
        model: The model whose predictions to evaluate.
        test_data: Input testing data in the shape the model expects.
        test_truth: The true labels of the testing data
        labels: dict specifying the class labels.
        title: the title to display on the plot.
        batch_size: Size of batches for prediction over the test data.
        prefix: path specifying where to save the plot.
    """
    fpr, tpr, roc_auc = get_fpr_tpr_roc(model, test_data, test_truth, labels, batch_size)

    lw = 3
    fig = plt.figure(figsize=(28, 22))
    matplotlib.rcParams.update({'font.size': 34})

    for key in labels.keys():
        # Known labels get a stable color; unknown ones get a random fallback color.
        color = key_colors.get(key, np.random.choice(color_array))
        plt.plot(fpr[labels[key]], tpr[labels[key]], color=color, lw=lw,
                 label=str(key)+' area under ROC: %0.3f' % roc_auc[labels[key]])

    plt.plot([0, 1], [0, 1], 'k:', lw=0.5)  # diagonal chance line for reference
    plt.xlim([0.0, 1.0])
    plt.ylim([-0.02, 1.03])
    plt.xlabel(fallout_label)
    plt.ylabel(recall_label)
    plt.title('ROC:' + title + '\n')

    matplotlib.rcParams.update({'font.size': 56})
    plt.legend(loc="lower right")
    figure_path = prefix+"per_class_roc_"+title+image_ext
    os.makedirs(os.path.dirname(figure_path), exist_ok=True)
    plt.savefig(figure_path)
    plt.close(fig)  # release the figure so repeated calls don't accumulate memory
    print('Saved figure at:', figure_path)


def plot_metric_history(history, title, prefix='./figures/'):
    """Plot metric history throughout training.

    One subplot per training metric; validation curves (keys prefixed 'val_')
    are overlaid on their training counterpart rather than plotted separately.

    Arguments:
        history: History object returned by Keras fit function.
        title: the title to display on the plot.
        prefix: path specifying where to save the plot.
    """
    num_plots = len([k for k in history.history.keys() if 'val' not in k])

    row = 0
    col = 0
    rows = 4
    # Enough columns to hold every metric (at least 2 so axes stays 2-D).
    cols = max(2, int(math.ceil(num_plots / float(rows))))

    f, axes = plt.subplots(rows, cols, sharex=True, figsize=(36, 24))
    for k in sorted(history.history.keys()):
        if 'val' not in k:
            axes[row, col].plot(history.history[k])
            axes[row, col].set_ylabel(str(k))
            axes[row, col].set_xlabel('epoch')
            if 'val_'+k in history.history:
                axes[row, col].plot(history.history['val_'+k])
                labels = ['train', 'valid']
            else:
                labels = [k]
            axes[row, col].legend(labels, loc='upper left')

            # Fill the grid column by column; stop if every subplot slot is used.
            row += 1
            if row == rows:
                row = 0
                col += 1
                if col == cols:
                    break

    axes[0, 1].set_title(title)
    figure_path = prefix+"metric_history_"+title+image_ext
    os.makedirs(os.path.dirname(figure_path), exist_ok=True)
    plt.savefig(figure_path)
    plt.close(f)  # release the figure so repeated calls don't accumulate memory
    print('Saved figure at:', figure_path)


def weight_path_to_title(wp):
    """Get a title from a model's weight path

    Arguments:
        wp: path to model's weights.

    Returns:
        str: a reformatted string
    """
    return wp.split('/')[-1].replace('__', '-').split('.')[0]