1# plots.py
2#
3# Plotting code for Variant Filtration with Neural Nets
4# This includes evaluation plots like Precision and Recall curves,
5# various flavors of Receiver Operating Characteristic (ROC curves),
6# As well as graphs of the metrics that are watched during neural net training.
7#
8# December 2016
9# Sam Friedman
10# sam@broadinstitute.org
11
12# Imports
13import os
14import math
15import matplotlib
16import numpy as np
17matplotlib.use('Agg') # Need this to write images from the GSA servers.  Order matters:
18import matplotlib.pyplot as plt # First import matplotlib, then use Agg, then import plt
19from sklearn.metrics import roc_curve, auc, roc_auc_score, precision_recall_curve, average_precision_score
20
image_ext = '.png'  # File extension appended to every saved figure path.

# Fallback palette: a random color is drawn from here for any label
# that has no fixed entry in key_colors below.
color_array = ['red', 'indigo', 'cyan', 'pink', 'purple']
# Fixed colors for known algorithm/label names so the same curve keeps
# the same color across different plots and runs.
key_colors = {
    'Neural Net':'green', 'CNN_SCORE':'green', 'CNN_2D':'green',
    'Heng Li Hard Filters':'lightblue',
    'GATK Hard Filters':'orange','GATK Signed Distance':'darksalmon',
    'VQSR gnomAD':'cornflowerblue', 'VQSR Single Sample':'blue', 'VQSLOD':'cornflowerblue',
    'Deep Variant':'magenta', 'QUAL':'magenta', 'DEEP_VARIANT_QUAL':'magenta',
    'Random Forest':'darkorange',
    'SNP':'cornflowerblue', 'NOT_SNP':'orange', 'INDEL':'green', 'NOT_INDEL':'red',
    'VQSLOD none':'cornflowerblue', 'VQSLOD strModel':'orange', 'VQSLOD default':'green',
    'REFERENCE':'green', 'HET_SNP':'cornflowerblue', 'HOM_SNP':'blue', 'HET_DELETION':'magenta',
    'HOM_DELETION':'violet', 'HET_INSERTION':'orange', 'HOM_INSERTION':'darkorange'
}

# Shared axis labels for the evaluation plots below.
precision_label = 'Precision | Positive Predictive Value | TP/(TP+FP)'
recall_label = 'Recall | Sensitivity | True Positive Rate | TP/(TP+FN)'
fallout_label = 'Fallout | 1 - Specificity | False Positive Rate | FP/(FP+TN)'
40
41
def get_fpr_tpr_roc(model, test_data, test_truth, labels, batch_size=32):
    """Compute per-class ROC statistics from a classification model.

    Runs the model's ``predict`` over the test data and delegates the
    curve computation to ``get_fpr_tpr_roc_pred``.

    Arguments:
        model: The model whose predictions to evaluate.
        test_data: Input testing data in the shape the model expects.
        test_truth: The true labels of the testing data
        labels: dict specifying the class labels.
        batch_size: Size of batches for prediction over the test data.

    Returns:
        dict, dict, dict: false positive rate, true positive rate, and area under ROC curve.
            The dicts all use label indices as keys. fpr and tpr dict's values are lists
            (the x and y coordinates that defines the ROC curves) and for AUC the value is a float.
    """
    predictions = model.predict(test_data, batch_size=batch_size, verbose=0)
    return get_fpr_tpr_roc_pred(predictions, test_truth, labels)
59
60
def get_fpr_tpr_roc_pred(y_pred, test_truth, labels):
    """Compute per-class ROC statistics from predictions and true labels.

    Arguments:
        y_pred: model predictions to evaluate.
        test_truth: The true labels of the testing data
        labels: dict specifying the class labels.

    Returns:
        dict, dict, dict: false positive rate, true positive rate, and area under ROC curve.
            The dicts all use label indices as keys. fpr and tpr dict's values are lists
            (the x and y coordinates that defines the ROC curves) and for AUC the value is a float.
    """
    fpr, tpr, roc_auc = {}, {}, {}

    # One ROC curve per class column, keyed by the label's column index.
    for idx in labels.values():
        fpr[idx], tpr[idx], _ = roc_curve(test_truth[:, idx], y_pred[:, idx])
        roc_auc[idx] = auc(fpr[idx], tpr[idx])

    return fpr, tpr, roc_auc
84
85
def plot_roc_per_class(model, test_data, test_truth, labels, title, batch_size=32, prefix='./figures/'):
    """Plot a per class ROC curve and save it as a PNG under prefix.

    Arguments:
        model: The model whose predictions to evaluate.
        test_data: Input testing data in the shape the model expects.
        test_truth: The true labels of the testing data
        labels: dict specifying the class labels.
        title: the title to display on the plot; also embedded in the saved file name.
        batch_size: Size of batches for prediction over the test data.
        prefix: path specifying where to save the plot.
    """
    fpr, tpr, roc_auc = get_fpr_tpr_roc(model, test_data, test_truth, labels, batch_size)

    lw = 3
    fig = plt.figure(figsize=(28, 22))
    matplotlib.rcParams.update({'font.size': 34})

    for key in labels.keys():
        # Known labels get a stable color so curves are comparable across plots;
        # anything else falls back to a random pick from the generic palette.
        if key in key_colors:
            color = key_colors[key]
        else:
            color = np.random.choice(color_array)
        plt.plot(fpr[labels[key]], tpr[labels[key]], color=color, lw=lw,
                 label=str(key)+' area under ROC: %0.3f'%roc_auc[labels[key]])

    plt.plot([0, 1], [0, 1], 'k:', lw=0.5)  # Diagonal chance line for reference.
    plt.xlim([0.0, 1.0])
    plt.ylim([-0.02, 1.03])  # Slight y padding so curves at 0 and 1 are not clipped.
    plt.xlabel(fallout_label)
    plt.ylabel(recall_label)
    plt.title('ROC:'+ title + '\n')

    matplotlib.rcParams.update({'font.size': 56})
    plt.legend(loc="lower right")
    figure_path = prefix+"per_class_roc_"+title+image_ext
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(os.path.dirname(figure_path), exist_ok=True)
    plt.savefig(figure_path)
    plt.close(fig)  # Release the figure; repeated calls would otherwise leak memory.
    print('Saved figure at:', figure_path)
126
127
def plot_metric_history(history, title, prefix='./figures/'):
    """Plot metric history throughout training and save it as a PNG under prefix.

    One subplot per training metric; the matching validation metric
    (the 'val_'-prefixed key) is overlaid when present.

    Arguments:
        history: History object returned by Keras fit function.
        title: the title to display on the plot; also embedded in the saved file name.
        prefix: path specifying where to save the plot.
    """
    num_plots = len([k for k in history.history.keys() if 'val' not in k])

    row = 0
    col = 0
    rows = 4
    # Enough columns to fit every metric; at least 2 so axes[0, 1] below is valid.
    cols = max(2, int(math.ceil(num_plots/float(rows))))

    f, axes = plt.subplots(rows, cols, sharex=True, figsize=(36, 24))
    for k in sorted(history.history.keys()):
        if 'val' in k:
            continue  # Validation metrics are drawn alongside their training metric.

        axes[row, col].plot(history.history[k])
        axes[row, col].set_ylabel(str(k))
        axes[row, col].set_xlabel('epoch')
        if 'val_'+k in history.history:
            axes[row, col].plot(history.history['val_'+k])
            labels = ['train', 'valid']
        else:
            labels = [k]
        axes[row, col].legend(labels, loc='upper left')

        # Fill the grid top-to-bottom, then left-to-right. Stop once the last
        # column is full. (The previous guard `row*col >= rows*cols` ran only
        # when row had just been reset to 0, so it could never fire.)
        row += 1
        if row == rows:
            row = 0
            col += 1
            if col == cols:
                break

    axes[0, 1].set_title(title)
    figure_path = prefix+"metric_history_"+title+image_ext
    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(os.path.dirname(figure_path), exist_ok=True)
    plt.savefig(figure_path)
    plt.close(f)  # Release the figure; repeated calls would otherwise leak memory.
    print('Saved figure at:', figure_path)
168
169
def weight_path_to_title(wp):
    """Derive a human-readable title from a model's weight path.

    Takes the final path component, keeps everything before the first '.',
    and turns double underscores into dashes.

    Arguments:
        wp: path to model's weights.

    Returns:
        str: a reformatted string
    """
    basename = wp.split('/')[-1]
    stem = basename.split('.')[0]
    return stem.replace('__', '-')
180
181