1#!/usr/bin/env python
2
3""" MultiQC functions to plot a linegraph """
4
5from __future__ import print_function, division
6from collections import OrderedDict
7import base64
8import inspect
9import io
10import logging
11import os
12import random
13import re
14import sys
15
16from multiqc.utils import config, report, util_functions
17
# Module-level logger, named after this module
logger = logging.getLogger(__name__)

try:
    # Import matplotlib but avoid the default X environment: the "Agg" backend
    # is selected before pyplot is imported.
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    logger.debug("Using matplotlib version {}".format(matplotlib.__version__))
except Exception as e:
    # MatPlotLib can break in a variety of ways. Fake an error message and continue without it if so.
    # The lack of the library will be handled when plots are attempted
    print("##### ERROR! MatPlotLib library could not be loaded!    #####", file=sys.stderr)
    print("##### Flat plots will instead be plotted as interactive #####", file=sys.stderr)
    # Keep the exception details on the same stream as the banner above so
    # that they do not end up mixed into regular stdout output.
    print(e, file=sys.stderr)

# Alphabet used to build random fallback plot IDs
letters = "abcdefghijklmnopqrstuvwxyz"

# Load the template so that we can access its configuration
# Do this lazily to mitigate import-spaghetti when running unit tests
_template_mod = None
40
41
def get_template_mod():
    """Return the template module, loading it on first use.

    The module is obtained from ``config.avail_templates[config.template].load()``
    and cached in the module-level ``_template_mod`` so that later calls
    reuse the same object.
    """
    global _template_mod
    # Already loaded by an earlier call: reuse the cached module
    if _template_mod:
        return _template_mod
    _template_mod = config.avail_templates[config.template].load()
    return _template_mod
47
48
def plot(data, pconfig=None):
    """Plot a line graph with X,Y data.

    :param data: 2D dict, first keys as sample names, then x:y data pairs.
        A list of such dicts may be given, one per dataset.
    :param pconfig: optional dict with config key:value pairs. See CONTRIBUTING.md.
        The dict is updated in place (user overrides, defaults, categories).
    :return: HTML and JS, ready to be inserted into the page
    """
    # Don't just use {} as the default argument as it's mutable. See:
    # http://python-guide-pt-br.readthedocs.io/en/latest/writing/gotchas/
    if pconfig is None:
        pconfig = {}

    # Allow user to overwrite any given config for this plot
    if "id" in pconfig and pconfig["id"] and pconfig["id"] in config.custom_plot_config:
        for k, v in config.custom_plot_config[pconfig["id"]].items():
            pconfig[k] = v

    # Given one dataset - turn it into a list
    if not isinstance(data, list):
        data = [data]

    # Validate config if linting
    if config.lint:
        # Get module name
        modname = ""
        callstack = inspect.stack()
        for n in callstack:
            if "multiqc/modules/" in n[1] and "base_module.py" not in n[1]:
                callpath = n[1].split("multiqc/modules/", 1)[-1]
                modname = ">{}< ".format(callpath)
                break
        # Look for essential missing pconfig keys
        for k in ["id", "title", "ylab"]:
            if k not in pconfig:
                errmsg = "LINT: {}Linegraph pconfig was missing key '{}'".format(modname, k)
                logger.error(errmsg)
                report.lint_errors.append(errmsg)
        # Check plot title format
        if not re.match(r"^[^:]*\S: \S[^:]*$", pconfig.get("title", "")):
            errmsg = "LINT: {} Linegraph title did not match format 'Module: Plot Name' (found '{}')".format(
                modname, pconfig.get("title", "")
            )
            logger.error(errmsg)
            report.lint_errors.append(errmsg)

    # Smooth dataset if requested in config
    if pconfig.get("smooth_points", None) is not None:
        sumcounts = pconfig.get("smooth_points_sumcounts", True)
        for i, d in enumerate(data):
            # A list gives one setting per dataset, anything else applies to all
            sumc = sumcounts[i] if type(sumcounts) is list else sumcounts
            data[i] = smooth_line_data(d, pconfig["smooth_points"], sumc)

    # Add sane plotting config defaults
    for plot_line in pconfig.get("yPlotLines", []):
        plot_line.setdefault("width", 2)

    # Add initial axis labels if defined in `data_labels` but not main config.
    # TypeError is tolerated as well: an entry of `data_labels` may be a plain
    # dataset name instead of a dict, in which case there is no label to copy.
    if pconfig.get("ylab") is None:
        try:
            pconfig["ylab"] = pconfig["data_labels"][0]["ylab"]
        except (KeyError, IndexError, TypeError):
            pass
    if pconfig.get("xlab") is None:
        try:
            pconfig["xlab"] = pconfig["data_labels"][0]["xlab"]
        except (KeyError, IndexError, TypeError):
            pass

    # Generate the data dict structure expected by HighCharts series
    plotdata = list()
    for data_index, d in enumerate(data):
        thisplotdata = list()

        # Per-dataset overrides from data_labels (e.g. ymax). Missing or too
        # short `data_labels` simply means that there are no overrides.
        try:
            dataset_label = pconfig["data_labels"][data_index]
        except (KeyError, IndexError, TypeError):
            dataset_label = None

        for s in sorted(d.keys()):

            # Ensure any overwriting conditionals from data_labels (e.g. ymax) are taken in consideration
            series_config = pconfig.copy()
            if isinstance(dataset_label, dict):  # if not a dict: only dataset name is provided
                series_config.update(dataset_label)

            pairs = list()
            maxval = 0
            if "categories" in series_config:
                # Categorical x axis: keep the values in dict order and
                # record the category names in the main config
                pconfig["categories"] = list()
                for k in d[s].keys():
                    pconfig["categories"].append(k)
                    pairs.append(d[s][k])
                    try:
                        maxval = max(maxval, d[s][k])
                    except TypeError:
                        # Non-comparable value (e.g. None): ignore for the maximum
                        pass
            else:
                for k in sorted(d[s].keys()):
                    # Drop points outside of the configured axis limits
                    if k is not None:
                        if "xmax" in series_config and float(k) > float(series_config["xmax"]):
                            continue
                        if "xmin" in series_config and float(k) < float(series_config["xmin"]):
                            continue
                    if d[s][k] is not None:
                        if "ymax" in series_config and float(d[s][k]) > float(series_config["ymax"]):
                            continue
                        if "ymin" in series_config and float(d[s][k]) < float(series_config["ymin"]):
                            continue
                    pairs.append([k, d[s][k]])
                    try:
                        maxval = max(maxval, d[s][k])
                    except TypeError:
                        pass
            if maxval > 0 or series_config.get("hide_empty") is not True:
                this_series = {"name": s, "data": pairs}
                # Optional per-sample colour
                try:
                    this_series["color"] = series_config["colors"][s]
                except (KeyError, IndexError, TypeError):
                    pass
                thisplotdata.append(this_series)
        plotdata.append(thisplotdata)

    # Add on annotation data series
    try:
        if pconfig.get("extra_series"):
            extra_series = pconfig["extra_series"]
            # Normalise to a list (one entry per dataset) of lists of series
            if isinstance(extra_series, dict):
                extra_series = [[extra_series]]
            elif isinstance(extra_series, list) and isinstance(extra_series[0], dict):
                extra_series = [extra_series]
            for i, es in enumerate(extra_series):
                for s in es:
                    plotdata[i].append(s)
    except (KeyError, IndexError):
        pass

    # Make a plot - template custom, or interactive or flat
    try:
        return get_template_mod().linegraph(plotdata, pconfig)
    except (AttributeError, TypeError):
        if config.plots_force_flat or (
            not config.plots_force_interactive and plotdata and len(plotdata[0]) > config.plots_flat_numseries
        ):
            try:
                return matplotlib_linegraph(plotdata, pconfig)
            except Exception as e:
                logger.error("############### Error making MatPlotLib figure! Falling back to HighCharts.")
                logger.debug(e, exc_info=True)
                return highcharts_linegraph(plotdata, pconfig)
        else:
            # Use MatPlotLib to generate static plots if requested
            if config.export_plots:
                matplotlib_linegraph(plotdata, pconfig)
            # Return HTML for HighCharts dynamic plot
            return highcharts_linegraph(plotdata, pconfig)
200
201
def highcharts_linegraph(plotdata, pconfig=None):
    """
    Build the HTML needed for a HighCharts line graph. Should be
    called by linegraph.plot(), which properly formats input data.

    :param plotdata: list of datasets, each one a list of series
    :param pconfig: optional plot config dict. Its "id" is set / sanitised in place.
    :return: HTML string with the switch buttons and the plot container

    Besides returning the HTML, the plot is registered in ``report.plot_data``
    under its ID and ``report.num_hc_plots`` is incremented.
    """
    if pconfig is None:
        pconfig = {}

    # Get the plot ID
    if pconfig.get("id") is None:
        pconfig["id"] = "mqc_hcplot_" + "".join(random.sample(letters, 10))

    # Sanitise plot ID and check for duplicates
    pconfig["id"] = report.save_htmlid(pconfig["id"])

    # Build the HTML for the page
    html = '<div class="mqc_hcplot_plotgroup">'

    # Log Switch
    if pconfig.get("logswitch") is True:
        c_active = "active"
        l_active = ""
        if pconfig.get("logswitch_active") is True:
            c_active = ""
            l_active = "active"
        c_label = pconfig.get("cpswitch_counts_label", "Counts")
        l_label = pconfig.get("logswitch_label", "Log10")
        html += '<div class="btn-group hc_switch_group"> \n'
        html += '<button class="btn btn-default btn-sm {c_a}" data-action="set_numbers" data-target="{id}" data-ylab="{c_l}">{c_l}</button> \n'.format(
            id=pconfig["id"], c_a=c_active, c_l=c_label
        )
        html += '<button class="btn btn-default btn-sm {l_a}" data-action="set_log" data-target="{id}" data-ylab="{l_l}">{l_l}</button> \n'.format(
            id=pconfig["id"], l_a=l_active, l_l=l_label
        )
        html += "</div> "
        if len(plotdata) > 1:
            html += " &nbsp; &nbsp; "

    # Buttons to cycle through different datasets
    if len(plotdata) > 1:
        html += '<div class="btn-group hc_switch_group">\n'
        for k, _ in enumerate(plotdata):
            active = "active" if k == 0 else ""
            # Every attribute below is optional: data_labels may be missing,
            # shorter than the number of datasets, or hold plain names
            # instead of dicts. Fall back to defaults in all of these cases.
            try:
                name = pconfig["data_labels"][k]["name"]
            except (KeyError, IndexError, TypeError):
                name = k + 1
            try:
                ylab = 'data-ylab="{}"'.format(pconfig["data_labels"][k]["ylab"])
            except (KeyError, IndexError, TypeError):
                # Use the dataset name as y label, unless it is just the number
                ylab = 'data-ylab="{}"'.format(name) if name != k + 1 else ""
            try:
                ymax = 'data-ymax="{}"'.format(pconfig["data_labels"][k]["ymax"])
            except (KeyError, IndexError, TypeError):
                ymax = ""
            try:
                xlab = 'data-xlab="{}"'.format(pconfig["data_labels"][k]["xlab"])
            except (KeyError, IndexError, TypeError):
                xlab = ""
            html += '<button class="btn btn-default btn-sm {a}" data-action="set_data" {y} {ym} {x} data-newdata="{k}" data-target="{id}">{n}</button>\n'.format(
                a=active, id=pconfig["id"], n=name, y=ylab, ym=ymax, x=xlab, k=k
            )
        html += "</div>\n\n"

    # The plot div
    html += '<div class="hc-plot-wrapper"><div id="{id}" class="hc-plot not_rendered hc-line-plot"><small>loading..</small></div></div></div> \n'.format(
        id=pconfig["id"]
    )

    report.num_hc_plots += 1

    # Store the data and config under the plot ID
    report.plot_data[pconfig["id"]] = {"plot_type": "xy_line", "datasets": plotdata, "config": pconfig}

    return html
277
278
def matplotlib_linegraph(plotdata, pconfig=None):
    """
    Plot a line graph with Matplot lib and return a HTML string. Either embeds a base64
    encoded image within HTML or writes the plot and links to it. Should be called by
    linegraph.plot(), which properly formats the input data.

    :param plotdata: list of datasets, each one a list of series dicts
        (with at least "name" and "data" keys)
    :param pconfig: optional plot config dict. Its "id" is set / sanitised in place.
    :return: HTML string with one image container per dataset

    For every dataset the plotted values are also written to a data file, and
    the figure is saved to ``config.plots_dir`` when ``config.export_plots`` is set.
    ``report.num_mpl_plots`` is incremented once per call.
    """
    if pconfig is None:
        pconfig = {}

    # Plot group ID
    if pconfig.get("id") is None:
        pconfig["id"] = "mqc_mplplot_" + "".join(random.sample(letters, 10))

    # Sanitise plot ID and check for duplicates
    pconfig["id"] = report.save_htmlid(pconfig["id"])

    # Individual plot IDs and dataset names (the names are reused for the buttons)
    pids = []
    names = []
    for k, _ in enumerate(plotdata):
        try:
            name = pconfig["data_labels"][k]["name"]
        except (KeyError, IndexError, TypeError):
            # No usable label for this dataset: fall back to its number
            name = k + 1
        pid = "mqc_{}_{}".format(pconfig["id"], name)
        pid = report.save_htmlid(pid, skiplint=True)
        pids.append(pid)
        names.append(name)

    html = (
        '<p class="text-info"><small><span class="glyphicon glyphicon-picture" aria-hidden="true"></span> '
        + "Flat image plot. Toolbox functions such as highlighting / hiding samples will not work "
        + '(see the <a href="http://multiqc.info/docs/#flat--interactive-plots" target="_blank">docs</a>).</small></p>'
    )
    html += '<div class="mqc_mplplot_plotgroup" id="{}">'.format(pconfig["id"])

    # Same defaults as HighCharts for consistency
    default_colors = [
        "#7cb5ec",
        "#434348",
        "#90ed7d",
        "#f7a35c",
        "#8085e9",
        "#f15c80",
        "#e4d354",
        "#2b908f",
        "#f45b5b",
        "#91e8e1",
    ]

    # Buttons to cycle through different datasets
    if len(plotdata) > 1 and not config.simple_output:
        html += '<div class="btn-group mpl_switch_group mqc_mplplot_bargraph_switchds">\n'
        for k, (pid, name) in enumerate(zip(pids, names)):
            active = "active" if k == 0 else ""
            html += '<button class="btn btn-default btn-sm {a}" data-target="#{pid}">{n}</button>\n'.format(
                a=active, pid=pid, n=name
            )
        html += "</div>\n\n"

    # Go through datasets creating plots
    for pidx, pdata in enumerate(plotdata):

        # Plot ID
        pid = pids[pidx]

        # Save plot data to file
        fdata = OrderedDict()
        lastcats = None
        sharedcats = True
        for d in pdata:
            fdata[d["name"]] = OrderedDict()
            cats_checked = False
            for i, x in enumerate(d["data"]):
                if type(x) is list:
                    fdata[d["name"]][str(x[0])] = x[1]
                    # Check to see if all categories are the same. The x values of
                    # a series do not change while looping over its points, so
                    # the comparison is only done once per series.
                    if not cats_checked:
                        cats_checked = True
                        series_cats = [pt[0] for pt in d["data"]]
                        if lastcats is None:
                            lastcats = series_cats
                        elif lastcats != series_cats:
                            sharedcats = False
                else:
                    # Plain value: label it with its category name if there
                    # is one, otherwise with its position
                    try:
                        fdata[d["name"]][pconfig["categories"][i]] = x
                    except (KeyError, IndexError):
                        fdata[d["name"]][str(i)] = x

        # Custom tsv output if the x axis varies
        if not sharedcats and config.data_format == "tsv":
            # Two rows per series: its x values, then its name and y values
            parts = []
            for d in pdata:
                parts.append("\t" + "\t".join([str(x[0]) for x in d["data"]]))
                parts.append("\n{}\t".format(d["name"]))
                parts.append("\t".join([str(x[1]) for x in d["data"]]))
                parts.append("\n")
            fout = "".join(parts)
            with io.open(os.path.join(config.data_dir, "{}.txt".format(pid)), "w", encoding="utf-8") as f:
                print(fout.encode("utf-8", "ignore").decode("utf-8"), file=f)
        else:
            util_functions.write_data_file(fdata, pid)

        # Set up figure
        fig = plt.figure(figsize=(14, 6), frameon=False)
        # The figure is closed in the `finally` below, so that it is also
        # released when something goes wrong while drawing or saving it.
        try:
            axes = fig.add_subplot(111)

            # Go through data series
            for idx, d in enumerate(pdata):

                # Default colour, cycling through the palette
                default_color = default_colors[idx % len(default_colors)]

                # Line style
                linestyle = "solid"
                if d.get("dashStyle", None) == "Dash":
                    linestyle = "dashed"

                # Reformat data (again)
                try:
                    axes.plot(
                        [x[0] for x in d["data"]],
                        [x[1] for x in d["data"]],
                        label=d["name"],
                        color=d.get("color", default_color),
                        linestyle=linestyle,
                        linewidth=1,
                        marker=None,
                    )
                except TypeError:
                    # Categorical data on x axis
                    axes.plot(
                        d["data"],
                        label=d["name"],
                        color=d.get("color", default_color),
                        linestyle=linestyle,
                        linewidth=1,
                        marker=None,
                    )

            # Tidy up axes
            axes.tick_params(labelsize=8, direction="out", left=False, right=False, top=False, bottom=False)
            axes.set_xlabel(pconfig.get("xlab", ""))
            axes.set_ylabel(pconfig.get("ylab", ""))

            # Dataset specific y label
            try:
                axes.set_ylabel(pconfig["data_labels"][pidx]["ylab"])
            except (KeyError, IndexError, TypeError):
                pass

            # Axis limits
            default_ylimits = axes.get_ylim()
            ymin = default_ylimits[0]
            if "ymin" in pconfig:
                ymin = pconfig["ymin"]
            elif "yFloor" in pconfig:
                ymin = max(pconfig["yFloor"], default_ylimits[0])
            ymax = default_ylimits[1]
            if "ymax" in pconfig:
                ymax = pconfig["ymax"]
            elif "yCeiling" in pconfig:
                ymax = min(pconfig["yCeiling"], default_ylimits[1])
            if (ymax - ymin) < pconfig.get("yMinRange", 0):
                ymax = ymin + pconfig["yMinRange"]
            axes.set_ylim((ymin, ymax))

            # Dataset specific ymax
            try:
                axes.set_ylim((ymin, pconfig["data_labels"][pidx]["ymax"]))
            except (KeyError, IndexError, TypeError, ValueError):
                pass

            default_xlimits = axes.get_xlim()
            xmin = default_xlimits[0]
            if "xmin" in pconfig:
                xmin = pconfig["xmin"]
            elif "xFloor" in pconfig:
                xmin = max(pconfig["xFloor"], default_xlimits[0])
            xmax = default_xlimits[1]
            if "xmax" in pconfig:
                xmax = pconfig["xmax"]
            elif "xCeiling" in pconfig:
                xmax = min(pconfig["xCeiling"], default_xlimits[1])
            if (xmax - xmin) < pconfig.get("xMinRange", 0):
                xmax = xmin + pconfig["xMinRange"]
            axes.set_xlim((xmin, xmax))

            # Plot title
            if "title" in pconfig:
                plt.text(
                    0.5, 1.05, pconfig["title"], horizontalalignment="center", fontsize=16, transform=axes.transAxes
                )
            axes.grid(True, zorder=10, which="both", axis="y", linestyle="-", color="#dedede", linewidth=1)

            # X axis categories, if specified
            if "categories" in pconfig:
                axes.set_xticks(list(range(len(pconfig["categories"]))))
                axes.set_xticklabels(pconfig["categories"])

            # Axis lines
            xlim = axes.get_xlim()
            axes.plot([xlim[0], xlim[1]], [0, 0], linestyle="-", color="#dedede", linewidth=2)
            axes.set_axisbelow(True)
            axes.spines["right"].set_visible(False)
            axes.spines["top"].set_visible(False)
            axes.spines["bottom"].set_visible(False)
            axes.spines["left"].set_visible(False)

            # Background colours, if specified
            if "yPlotBands" in pconfig:
                xlim = axes.get_xlim()
                for pb in pconfig["yPlotBands"]:
                    axes.barh(
                        pb["from"],
                        xlim[1],
                        height=pb["to"] - pb["from"],
                        left=xlim[0],
                        color=pb["color"],
                        linewidth=0,
                        zorder=0,
                        align="edge",
                    )
            if "xPlotBands" in pconfig:
                ylim = axes.get_ylim()
                for pb in pconfig["xPlotBands"]:
                    axes.bar(
                        pb["from"],
                        ylim[1],
                        width=pb["to"] - pb["from"],
                        bottom=ylim[0],
                        color=pb["color"],
                        linewidth=0,
                        zorder=0,
                        align="edge",
                    )

            # Tight layout - makes sure that legend fits in and stuff
            if len(pdata) <= 15:
                axes.legend(
                    loc="lower center",
                    bbox_to_anchor=(0, -0.22, 1, 0.102),
                    ncol=5,
                    mode="expand",
                    fontsize=8,
                    frameon=False,
                )
                plt.tight_layout(rect=[0, 0.08, 1, 0.92])
            else:
                plt.tight_layout(rect=[0, 0, 1, 0.92])

            # Should this plot be hidden on report load?
            hidediv = ""
            if pidx > 0:
                hidediv = ' style="display:none;"'

            # Save the plot to the data directory if export is requested
            if config.export_plots:
                for fformat in config.export_plot_formats:
                    # Make the directory if it doesn't already exist
                    plot_dir = os.path.join(config.plots_dir, fformat)
                    if not os.path.exists(plot_dir):
                        os.makedirs(plot_dir)
                    # Save the plot
                    plot_fn = os.path.join(plot_dir, "{}.{}".format(pid, fformat))
                    fig.savefig(plot_fn, format=fformat, bbox_inches="tight")

            # Output the figure to a base64 encoded string
            if getattr(get_template_mod(), "base64_plots", True) is True:
                img_buffer = io.BytesIO()
                fig.savefig(img_buffer, format="png", bbox_inches="tight")
                b64_img = base64.b64encode(img_buffer.getvalue()).decode("utf8")
                img_buffer.close()
                html += '<div class="mqc_mplplot" id="{}"{}><img src="data:image/png;base64,{}" /></div>'.format(
                    pid, hidediv, b64_img
                )

            # Save to a file and link <img>
            else:
                plot_relpath = os.path.join(config.plots_dir_name, "png", "{}.png".format(pid))
                html += '<div class="mqc_mplplot" id="{}"{}><img src="{}" /></div>'.format(pid, hidediv, plot_relpath)
        finally:
            plt.close(fig)

    # Close wrapping div
    html += "</div>"

    report.num_mpl_plots += 1

    return html
563
564
def smooth_line_data(data, numpoints, sumcounts=True):
    """
    Function to take an x-y dataset and use binning to smooth to a maximum number of datapoints.
    Each datapoint in a smoothed dataset corresponds to the first point in a bin.

    :param data: dict of sample name -> dict of x:y pairs
    :param numpoints: maximum number of points to keep per sample
    :param sumcounts: accepted for backwards compatibility, not used by this function
    :return: dict of sample name -> smoothed x:y pairs. Samples that already have
        at most `numpoints` points are returned as they are (same object).

    Examples to show the idea:

    d=[0 1 2 3 4 5 6 7 8 9], numpoints=6
    we want to keep the first and the last element, thus excluding the last element from the binning:
    binsize = (len(d)-1)/(numpoints-1) = 9/5 = 1.8
    taking points in indices rounded from multiples of 1.8: [0, 1.8, 3.6, 5.4, 7.2, 9],
    ...which evaluates to first_element_in_bin_indices=[0, 2, 4, 5, 7, 9]
    picking up the elements: [0 _ 2 _ 4 5 _ 7 _ 9]

    d=[0 1 2 3 4 5 6 7 8 9], numpoints=9
    binsize = 9/8 = 1.125
    indices: [0.0, 1.125, 2.25, 3.375, 4.5, 5.625, 6.75, 7.875, 9] -> [0, 1, 2, 3, 5, 6, 7, 8, 9]
    picking up the elements: [0 1 2 3 _ 5 6 7 8 9]

    d=[0 1 2 3 4 5 6 7 8 9], numpoints=3
    binsize = (len(d)-1)/(numpoints-1) = 9/2 = 4.5
    indices: [0.0, 4.5, 9] -> [0, 5, 9]
    picking up the elements: [0 _ _ _ _ 5 _ _ _ 9]

    Note: indices are rounded with the built-in round(). An index that lies exactly
    half way (the 4.5 in the last two examples) is rounded to 5 on Python 2 as shown,
    but to the even neighbour 4 on Python 3.
    """
    smoothed_data = dict()
    for s_name, d in data.items():
        # Check that we need to smooth this data
        if len(d) <= numpoints or len(d) == 0:
            smoothed_data[s_name] = d
            continue

        if numpoints == 1:
            # A single bin: keep its first point. The general formula below
            # would divide by zero here.
            first_element_indices = {0}
        else:
            binsize = (len(d) - 1) / (numpoints - 1)
            # A set makes the membership test below O(1) per point
            first_element_indices = set(round(binsize * i) for i in range(numpoints))
        smoothed_d = OrderedDict(xy for i, xy in enumerate(d.items()) if i in first_element_indices)
        smoothed_data[s_name] = smoothed_d

    return smoothed_data
602