1"""
2A directive for including a matplotlib plot in a Sphinx document.
3
4By default, in HTML output, `plot` will include a .png file with a
5link to a high-res .png and .pdf.  In LaTeX output, it will include a
6.pdf.
7
8The source code for the plot may be included in one of three ways:
9
10  1. **A path to a source file** as the argument to the directive::
11
12       .. plot:: path/to/plot.py
13
14     When a path to a source file is given, the content of the
15     directive may optionally contain a caption for the plot::
16
17       .. plot:: path/to/plot.py
18
19          This is the caption for the plot
20
     Additionally, one may specify the name of a function to call (with
22     no arguments) immediately after importing the module::
23
24       .. plot:: path/to/plot.py plot_function1
25
26  2. Included as **inline content** to the directive::
27
28       .. plot::
29
30          import matplotlib.pyplot as plt
31          import matplotlib.image as mpimg
32          import numpy as np
33          img = mpimg.imread('_static/stinkbug.png')
34          imgplot = plt.imshow(img)
35
36  3. Using **doctest** syntax::
37
38       .. plot::
39          A plotting example:
40          >>> import matplotlib.pyplot as plt
41          >>> plt.plot([1,2,3], [4,5,6])
42
43Options
44-------
45
46The ``plot`` directive supports the following options:
47
48    format : {'python', 'doctest'}
49        Specify the format of the input
50
51    include-source : bool
52        Whether to display the source code. The default can be changed
53        using the `plot_include_source` variable in conf.py
54
55    encoding : str
56        If this source file is in a non-UTF8 or non-ASCII encoding,
57        the encoding must be specified using the `:encoding:` option.
58        The encoding will not be inferred using the ``-*- coding -*-``
59        metacomment.
60
61    context : bool
62        If provided, the code will be run in the context of all
63        previous plot directives for which the `:context:` option was
64        specified.  This only applies to inline code plot directives,
65        not those run from files.
66
67    nofigs : bool
68        If specified, the code block will be run, but no figures will
69        be inserted.  This is usually useful with the ``:context:``
70        option.
71
72Additionally, this directive supports all of the options of the
73`image` directive, except for `target` (since plot will add its own
74target).  These include `alt`, `height`, `width`, `scale`, `align` and
75`class`.
76
77Configuration options
78---------------------
79
80The plot directive has the following configuration options:
81
82    plot_include_source
83        Default value for the include-source option
84
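    plot_pre_code
        Code that should be executed before each plot.  If None (the
        default), ``import numpy as np`` and
        ``from matplotlib import pyplot as plt`` are executed instead.
        For ``:context:`` plots it only runs while the shared context is
        still empty.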
87
    plot_basedir
        Base directory to which ``plot::`` file names are relative.  A
        relative base directory is itself interpreted with respect to the
        directory containing conf.py.  (If None or empty, file names are
        relative to the Sphinx source directory.)
92
93    plot_formats
94        File formats to generate. List of tuples or strings::
95
96            [(suffix, dpi), suffix, ...]
97
98        that determine the file format and the DPI. For entries whose
99        DPI was omitted, sensible defaults are chosen.
100
101    plot_html_show_formats
102        Whether to show links to the files in HTML.
103
104    plot_rcparams
105        A dictionary containing any non-standard rcParams that should
106        be applied before each plot.
107
108"""
109
110import sys, os, glob, shutil, imp, warnings, cStringIO, re, textwrap, \
111       traceback, exceptions
112
113from docutils.parsers.rst import directives
114from docutils import nodes
115from docutils.parsers.rst.directives.images import Image
116align = Image.align
117import sphinx
118
# (major, minor) of the running Sphinx as integers.  Any alphabetic suffix
# on a component is dropped before conversion; this is needed for beta
# versions, where a component can look like '6b1'.
sphinx_version = tuple(
    int(re.split('[a-z]', component)[0])
    for component in sphinx.__version__.split(".")[:2])
124
# Sphinx depends on either Jinja or Jinja2; bind format_template() to the
# one that can be imported, preferring Jinja2.
try:
    import jinja2
except ImportError:
    import jinja

    def format_template(template, **kw):
        # NOTE(review): unlike the Jinja2 variant this returns whatever
        # jinja.from_string() returns and never calls .render(); confirm
        # against the Jinja 1 API before relying on this fallback.
        return jinja.from_string(template, **kw)
else:
    def format_template(template, **kw):
        # Compile *template* and render it with the keyword arguments as
        # the template variables.
        return jinja2.Template(template).render(**kw)
134
135
136import matplotlib
137import matplotlib.cbook as cbook
138matplotlib.use('Agg')
139import matplotlib.pyplot as plt
140from matplotlib import _pylab_helpers
141
# Module-level version marker for this directive implementation.
__version__ = 2
143
144#------------------------------------------------------------------------------
145# Relative pathnames
146#------------------------------------------------------------------------------
147
# os.path.relpath is new in Python 2.6
# Use the standard-library relpath when it exists; otherwise define a
# platform-specific copy below.  NOTE(review): if the minimum supported
# Python for this module is 2.6 or newer, the ImportError branch is dead
# code and could be removed.
try:
    from os.path import relpath
except ImportError:
    # Copied from Python 2.7
    if 'posix' in sys.builtin_module_names:
        def relpath(path, start=os.path.curdir):
            """Return a relative version of a path

            POSIX variant: express *path* relative to *start* (the current
            directory by default), prefixing ``..`` components for every
            unshared component of *start*.  Raises ValueError if *path*
            is empty.
            """
            from os.path import sep, curdir, join, abspath, commonprefix, \
                 pardir

            if not path:
                raise ValueError("no path specified")

            start_list = abspath(start).split(sep)
            path_list = abspath(path).split(sep)

            # Work out how much of the filepath is shared by start and path.
            # (commonprefix is applied to the component lists here, not to
            # the raw strings, so the match is per path component.)
            i = len(commonprefix([start_list, path_list]))

            # Climb out of the unshared tail of start, then descend into
            # the unshared tail of path.
            rel_list = [pardir] * (len(start_list)-i) + path_list[i:]
            if not rel_list:
                return curdir
            return join(*rel_list)
    elif 'nt' in sys.builtin_module_names:
        def relpath(path, start=os.path.curdir):
            """Return a relative version of a path

            Windows variant: like the POSIX one, but path components are
            compared case-insensitively.  Raises ValueError if *path* is
            empty, or if the first components of *path* and *start* differ
            (UNC mixed with non-UNC paths, or paths on different drives).
            """
            from os.path import sep, curdir, join, abspath, commonprefix, \
                 pardir, splitunc

            if not path:
                raise ValueError("no path specified")
            start_list = abspath(start).split(sep)
            path_list = abspath(path).split(sep)
            if start_list[0].lower() != path_list[0].lower():
                unc_path, rest = splitunc(path)
                unc_start, rest = splitunc(start)
                if bool(unc_path) ^ bool(unc_start):
                    raise ValueError("Cannot mix UNC and non-UNC paths (%s and %s)"
                                                                        % (path, start))
                else:
                    raise ValueError("path is on drive %s, start on drive %s"
                                                        % (path_list[0], start_list[0]))
            # Work out how much of the filepath is shared by start and path.
            for i in range(min(len(start_list), len(path_list))):
                if start_list[i].lower() != path_list[i].lower():
                    break
            else:
                # No mismatch: every compared component is shared, so the
                # shared length is one past the last index checked.
                i += 1

            rel_list = [pardir] * (len(start_list)-i) + path_list[i:]
            if not rel_list:
                return curdir
            return join(*rel_list)
    else:
        # Neither POSIX nor Windows: no fallback implementation exists.
        raise RuntimeError("Unsupported platform (no relpath available!)")
204
205#------------------------------------------------------------------------------
206# Registration hook
207#------------------------------------------------------------------------------
208
def plot_directive(name, arguments, options, content, lineno,
                   content_offset, block_text, state, state_machine):
    # Function-style directive callable registered by setup().  Only the
    # values run() needs are forwarded; name, content_offset and block_text
    # are part of the calling signature but unused here.  run() returns the
    # list of system messages to insert into the document.
    messages = run(arguments, content, options, state_machine, state, lineno)
    return messages
# The module docstring documents the directive for users, so expose it as
# the directive callable's docstring.
plot_directive.__doc__ = __doc__
213
214def _option_boolean(arg):
215    if not arg or not arg.strip():
216        # no argument given, assume used as a flag
217        return True
218    elif arg.strip().lower() in ('no', '0', 'false'):
219        return False
220    elif arg.strip().lower() in ('yes', '1', 'true'):
221        return True
222    else:
223        raise ValueError('"%s" unknown boolean' % arg)
224
def _option_format(arg):
    # Validate the :format: option against the two supported input
    # formats using docutils' directives.choice.
    allowed = ('python', 'doctest')
    return directives.choice(arg, allowed)
227
def _option_align(arg):
    # Validate the :align: option; vertical (top/middle/bottom) and
    # horizontal (left/center/right) alignments are all accepted.
    alignments = ("top", "middle", "bottom", "left", "center", "right")
    return directives.choice(arg, alignments)
231
def mark_plot_labels(app, document):
    """
    Make plots referenceable by moving each explicit label from its
    "html_only"/"latex_only" wrapper node onto the figure node inside it.

    For every explicit target name whose node is such a wrapper and has a
    figure child, the label's id and name move to the first figure child.
    ``env.labels`` then maps the label to (docname, id, title).  The title
    is the text of the figure's first caption, or the label name itself if
    the figure has no caption.
    """
    for label_name, is_explicit in document.nametypes.items():
        if not is_explicit:
            continue
        label_id = document.nameids[label_name]
        if label_id is None:
            continue
        wrapper = document.ids[label_id]
        # NOTE(review): run() emits ``.. only:: html`` / ``.. only:: latex``
        # blocks; confirm these wrapper tagnames are the ones produced.
        if wrapper.tagname not in ('html_only', 'latex_only'):
            continue

        figure = next((child for child in wrapper
                       if child.tagname == 'figure'), None)
        if figure is None:
            continue
        caption = next((child for child in figure
                        if child.tagname == 'caption'), None)
        title = label_name if caption is None else caption.astext()

        wrapper['ids'].remove(label_id)
        wrapper['names'].remove(label_name)
        figure['ids'].append(label_id)
        figure['names'].append(label_name)
        document.settings.env.labels[label_name] = \
            document.settings.env.docname, label_id, title
261
def setup(app):
    """
    Sphinx extension entry point.

    Registers the ``plot`` directive with its options, the ``plot_*``
    configuration values, and the ``doctree-read`` handler that moves plot
    labels onto figure nodes.
    """
    # run() and run_code() read these back later as setup.app,
    # setup.config and setup.confdir.
    setup.app = app
    setup.config = app.config
    setup.confdir = app.confdir

    # Options shared with the image directive (passed through to figures)...
    image_options = {
        'alt': directives.unchanged,
        'height': directives.length_or_unitless,
        'width': directives.length_or_percentage_or_unitless,
        'scale': directives.nonnegative_int,
        'align': _option_align,
        'class': directives.class_option,
    }
    # ...plus the options specific to ``plot``.
    plot_only_options = {
        'include-source': _option_boolean,
        'format': _option_format,
        'context': directives.flag,
        'nofigs': directives.flag,
        'encoding': directives.encoding,
    }
    options = dict(image_options)
    options.update(plot_only_options)

    # (0, 2, False): presumably 0 required / up to 2 optional arguments,
    # matching run(), which reads at most arguments[0] and arguments[1].
    app.add_directive('plot', plot_directive, True, (0, 2, False), **options)

    # conf.py settings and their defaults, registered in this order.
    config_defaults = (
        ('plot_pre_code', None),
        ('plot_include_source', False),
        ('plot_formats', ['png', 'hires.png', 'pdf']),
        ('plot_basedir', None),
        ('plot_html_show_formats', True),
        ('plot_rcparams', {}),
    )
    for value_name, default in config_defaults:
        app.add_config_value(value_name, default, True)

    app.connect('doctree-read', mark_plot_labels)
289
290#------------------------------------------------------------------------------
291# Doctest handling
292#------------------------------------------------------------------------------
293
def contains_doctest(text):
    """
    Return True if *text* looks like doctest input rather than plain Python.

    Text that already compiles as Python is never treated as a doctest.
    Otherwise it counts as one when some line starts, after optional
    whitespace, with the ``>>>`` prompt.
    """
    try:
        compile(text, '<string>', 'exec')
    except SyntaxError:
        prompt = re.search(r'^\s*>>>', text, re.M)
        return prompt is not None
    # valid Python as-is
    return False
304
def unescape_doctest(text):
    """
    Turn doctest-formatted text into runnable Python source.

    Text that is not a doctest (per contains_doctest) is returned
    unchanged.  Otherwise:

    * each ``>>> `` / ``... `` prompt line contributes the code after the
      prompt;
    * any other non-blank line becomes a ``#`` comment, with surrounding
      whitespace stripped;
    * blank lines are kept blank.

    Every emitted line is terminated with a newline.
    """
    if not contains_doctest(text):
        return text

    prompt = re.compile(r'^\s*(>>>|\.\.\.) (.*)$')
    out = []
    for line in text.split("\n"):
        match = prompt.match(line)
        if match:
            out.append(match.group(2))
        elif line.strip():
            out.append("# " + line.strip())
        else:
            out.append("")
    return "".join(piece + "\n" for piece in out)
324
def split_code_at_show(text):
    """
    Split *text* into chunks that each end with a ``plt.show()`` line.

    The marker line, compared after stripping whitespace, is
    ``plt.show()`` for plain Python and ``>>> plt.show()`` for doctest
    input.  The marker stays as the last line of its chunk.  Code after the
    final marker becomes one more chunk only if it is not blank.
    """
    marker = '>>> plt.show()' if contains_doctest(text) else 'plt.show()'

    chunks = []
    current = []
    for line in text.split("\n"):
        current.append(line)
        if line.strip() == marker:
            chunks.append("\n".join(current))
            current = []

    tail = "\n".join(current)
    if tail.strip():
        chunks.append(tail)
    return chunks
346
347#------------------------------------------------------------------------------
348# Template
349#------------------------------------------------------------------------------
350
351
# reST skeleton that run() renders with format_template() once per code
# chunk.  It emits, in order:
#   * the optional source listing (``source_code``);
#   * an html-only section with the source link and per-format download
#     links, then one ``figure`` per image carrying the pass-through image
#     options and the caption;
#   * a latex-only section that includes each image's ``.pdf`` file.
# NOTE(review): the latex section always references ``.pdf``, so it
# assumes 'pdf' is one of the configured plot_formats.
TEMPLATE = """
{{ source_code }}

{{ only_html }}

   {% if source_link or (html_show_formats and not multi_image) %}
   (
   {%- if source_link -%}
   `Source code <{{ source_link }}>`__
   {%- endif -%}
   {%- if html_show_formats and not multi_image -%}
     {%- for img in images -%}
       {%- for fmt in img.formats -%}
         {%- if source_link or not loop.first -%}, {% endif -%}
         `{{ fmt }} <{{ dest_dir }}/{{ img.basename }}.{{ fmt }}>`__
       {%- endfor -%}
     {%- endfor -%}
   {%- endif -%}
   )
   {% endif %}

   {% for img in images %}
   .. figure:: {{ build_dir }}/{{ img.basename }}.png
      {%- for option in options %}
      {{ option }}
      {% endfor %}

      {% if html_show_formats and multi_image -%}
        (
        {%- for fmt in img.formats -%}
        {%- if not loop.first -%}, {% endif -%}
        `{{ fmt }} <{{ dest_dir }}/{{ img.basename }}.{{ fmt }}>`__
        {%- endfor -%}
        )
      {%- endif -%}

      {{ caption }}
   {% endfor %}

{{ only_latex }}

   {% for img in images %}
   .. image:: {{ build_dir }}/{{ img.basename }}.pdf
   {% endfor %}

"""
398
# %-style reST fragment (placeholders ``linkdir`` and ``basename``) for a
# plot whose code failed.  NOTE(review): nothing in this module appears to
# reference it; run() reports failures via a docutils system message, and
# the fragment uses the older ``htmlonly`` directive.
exception_template = """
.. htmlonly::

   [`source code <%(linkdir)s/%(basename)s.py>`__]

Exception occurred rendering plot.

"""
407
# Namespace shared by every plot directive that uses the :context: option.
# render_figures() runs context plots in this dict rather than in a fresh
# one, so names defined by earlier context plots stay visible to later ones.
plot_context = {}
411
class ImageFile(object):
    """A rendered figure: a base file name inside a directory, plus the
    file extensions ("formats") recorded for it."""

    def __init__(self, basename, dirname):
        self.basename = basename
        self.dirname = dirname
        # Extensions such as 'png' or 'hires.png', appended by callers as
        # files are found up to date or written.
        self.formats = []

    def filename(self, format):
        """Return the path of this image with extension *format*."""
        leaf = "%s.%s" % (self.basename, format)
        return os.path.join(self.dirname, leaf)

    def filenames(self):
        """Return the path for each extension in ``self.formats``, in order."""
        return list(map(self.filename, self.formats))
423
def out_of_date(original, derived):
    """
    Return True when *derived* needs to be regenerated from *original*.

    Both arguments are full file paths.  A missing *derived* file is always
    out of date.  Otherwise it is out of date only if *original* exists and
    was modified more recently than *derived*.
    """
    if not os.path.exists(derived):
        return True
    if not os.path.exists(original):
        return False
    return os.stat(derived).st_mtime < os.stat(original).st_mtime
432
class PlotError(RuntimeError):
    """Raised when plot code fails to run, a figure cannot be saved, or
    ``plot_formats`` holds an invalid entry.  For run or save failures the
    message is the formatted traceback."""
435
def run_code(code, code_path, ns=None, function_name=None):
    """
    Execute plot code in a namespace and optionally call a function.

    *code* is first passed through unescape_doctest, so doctest input is
    accepted.  If *function_name* is not None, ``function_name()`` is then
    executed in the same namespace.

    Parameters
    ----------
    code : str
        Python source or doctest text.
    code_path : str or None
        File the code belongs to.  If not None, the working directory is
        changed to that file's directory and the directory is prepended to
        sys.path while the code runs.  sys.argv is set to ``[code_path]``
        in either case.
    ns : dict, optional
        Namespace to execute in; a new dict is used if None.  If the
        namespace is empty, the ``plot_pre_code`` config value is executed
        in it first.  If that value is None, numpy and pyplot are imported
        as ``np`` and ``plt`` instead.
    function_name : str, optional
        Name of a callable in the namespace, invoked with no arguments.

    Returns
    -------
    dict
        The namespace after execution.

    Raises
    ------
    PlotError
        If executing the code or calling the function raises an Exception
        or SystemExit; the message is the formatted traceback.

    The working directory, sys.argv, sys.path and sys.stdout are restored
    in a ``finally`` block.
    """

    # Change the working directory to the directory of the example, so
    # it can get at its data files, if any.  Add its path to sys.path
    # so it can import any helper modules sitting beside it.

    pwd = os.getcwd()
    old_sys_path = list(sys.path)
    if code_path is not None:
        dirname = os.path.abspath(os.path.dirname(code_path))
        os.chdir(dirname)
        sys.path.insert(0, dirname)

    # Redirect stdout
    # (whatever the plot code prints goes into a buffer that is never read)
    stdout = sys.stdout
    sys.stdout = cStringIO.StringIO()

    # Reset sys.argv
    old_sys_argv = sys.argv
    sys.argv = [code_path]

    try:
        try:
            code = unescape_doctest(code)
            if ns is None:
                ns = {}
            if not ns:
                # Empty namespace: seed it with the configured pre-code
                # (or the default numpy/pyplot imports) first.
                if setup.config.plot_pre_code is None:
                    exec "import numpy as np\nfrom matplotlib import pyplot as plt\n" in ns
                else:
                    exec setup.config.plot_pre_code in ns
            exec code in ns
            if function_name is not None:
                exec function_name + "()" in ns
        except (Exception, SystemExit), err:
            # SystemExit is included so that a script calling sys.exit()
            # is reported as a plot error instead of propagating further.
            raise PlotError(traceback.format_exc())
    finally:
        os.chdir(pwd)
        sys.argv = old_sys_argv
        sys.path[:] = old_sys_path
        sys.stdout = stdout
    return ns
482
def clear_state(plot_rcparams):
    """Start a plot from a clean slate.

    Closes every pyplot figure, resets rcParams to matplotlib's defaults,
    then applies the *plot_rcparams* overrides (the ``plot_rcparams``
    config value) on top.
    """
    plt.close('all')
    matplotlib.rcdefaults()
    rc = matplotlib.rcParams
    rc.update(plot_rcparams)
487
def render_figures(code, code_path, output_dir, output_base, context,
                   function_name, config):
    """
    Run a pyplot script and save its figures in every configured format.

    Images are saved under *output_dir* with file names derived from
    *output_base*.  If every expected image file already exists and is not
    older than *code_path*, nothing is run and the cached images are
    returned.

    Parameters
    ----------
    code : str
        Python (or doctest) source; split into chunks at ``plt.show()``.
    code_path : str
        File the code comes from.  Used for the staleness check and passed
        to run_code().
    output_dir : str
        Directory in which image files are written.
    output_base : str
        Base file name for the images.
    context : bool
        If true, run in the shared ``plot_context`` namespace and do not
        reset figures/rcParams between chunks.
    function_name : str or None
        Optional zero-argument function run_code() calls after each chunk.
    config : Sphinx config object
        Supplies ``plot_formats`` and ``plot_rcparams``.

    Returns
    -------
    list of (code_piece, list of ImageFile)
        One entry per code chunk.  When a single cached image set is found,
        it is one entry holding the whole *code*.

    Raises
    ------
    PlotError
        If ``plot_formats`` has an invalid entry, or running the code or
        saving a figure fails.
    """
    # -- Parse format list into (suffix, dpi) pairs
    default_dpi = {'png': 80, 'hires.png': 200, 'pdf': 200}
    formats = []
    for fmt in config.plot_formats:
        if isinstance(fmt, str):
            formats.append((fmt, default_dpi.get(fmt, 80)))
        elif isinstance(fmt, (tuple, list)) and len(fmt) == 2:
            formats.append((str(fmt[0]), int(fmt[1])))
        else:
            # Wrap fmt in a 1-tuple: a bare tuple of the wrong length would
            # otherwise be used as the %-argument list and raise TypeError.
            raise PlotError('invalid image format "%r" in plot_formats'
                            % (fmt,))

    # -- Try to determine if all images already exist

    code_pieces = split_code_at_show(code)

    # Look for single-figure output files first
    all_exists = True
    img = ImageFile(output_base, output_dir)
    for fmt, _dpi in formats:
        if out_of_date(code_path, img.filename(fmt)):
            all_exists = False
            break
        img.formats.append(fmt)

    if all_exists:
        return [(code, [img])]

    # Then look for multi-figure output files.  Images are numbered
    # consecutively per chunk, and scanning a chunk stops at its first
    # missing image.  A chunk with no up-to-date image at all forces a
    # full rebuild.
    results = []
    all_exists = True
    for i, code_piece in enumerate(code_pieces):
        images = []
        for j in range(1000):
            if len(code_pieces) > 1:
                img = ImageFile('%s_%02d_%02d' % (output_base, i, j),
                                output_dir)
            else:
                img = ImageFile('%s_%02d' % (output_base, j), output_dir)
            for fmt, _dpi in formats:
                if out_of_date(code_path, img.filename(fmt)):
                    all_exists = False
                    break
                img.formats.append(fmt)

            # assume that if we have one, we have them all
            if not all_exists:
                all_exists = (j > 0)
                break
            images.append(img)
        if not all_exists:
            break
        results.append((code_piece, images))

    if all_exists:
        return results

    # We didn't find the files, so build them
    results = []
    if context:
        ns = plot_context
    else:
        ns = {}

    for i, code_piece in enumerate(code_pieces):
        if not context:
            clear_state(config.plot_rcparams)
        run_code(code_piece, code_path, ns, function_name)

        images = []
        fig_managers = _pylab_helpers.Gcf.get_all_fig_managers()
        for j, figman in enumerate(fig_managers):
            if len(fig_managers) == 1 and len(code_pieces) == 1:
                img = ImageFile(output_base, output_dir)
            elif len(code_pieces) == 1:
                img = ImageFile("%s_%02d" % (output_base, j), output_dir)
            else:
                img = ImageFile("%s_%02d_%02d" % (output_base, i, j),
                                output_dir)
            images.append(img)
            for fmt, dpi in formats:
                try:
                    figman.canvas.figure.savefig(img.filename(fmt),
                                                 dpi=dpi,
                                                 bbox_inches='tight')
                except (Exception, SystemExit):
                    # Same exceptions run_code() converts.  KeyboardInterrupt
                    # is deliberately not caught, so Ctrl-C aborts the build.
                    raise PlotError(traceback.format_exc())
                img.formats.append(fmt)

        results.append((code_piece, images))

    return results
589
def run(arguments, content, options, state_machine, state, lineno):
    """
    Implement the ``plot`` directive: render the figures and splice the
    generated reST back into the document.

    The plot source is one of:

    * the file named by ``arguments[0]``, with an optional function name in
      ``arguments[1]``; any directive content is then the figure caption;
    * the directive content itself, as inline code.

    Figures are rendered into a ``plot_directive`` directory next to the
    doctree directory and then copied into the builder's output directory.

    Returns a list of docutils system messages.  It is empty unless
    rendering raised PlotError.
    """
    document = state_machine.document
    config = document.settings.env.config
    nofigs = 'nofigs' in options

    options.setdefault('include-source', config.plot_include_source)
    context = 'context' in options

    rst_file = document.attributes['source']
    rst_dir = os.path.dirname(rst_file)

    if arguments:
        if not config.plot_basedir:
            source_file_name = os.path.join(setup.app.builder.srcdir,
                                            directives.uri(arguments[0]))
        else:
            source_file_name = os.path.join(setup.confdir, config.plot_basedir,
                                            directives.uri(arguments[0]))

        # When a file is given, any directive content is the caption (the
        # documented usage), not plot code.
        caption = '\n'.join(content)

        # If the optional function name is provided, use it
        if len(arguments) == 2:
            function_name = arguments[1]
        else:
            function_name = None

        # NOTE(review): the :encoding: option accepted by setup() is not
        # applied when reading the file here.
        with open(source_file_name, 'r') as fd:
            code = fd.read()
        output_base = os.path.basename(source_file_name)
    else:
        source_file_name = rst_file
        code = textwrap.dedent("\n".join(map(str, content)))
        # Number inline plots per document to give each a unique base name.
        counter = document.attributes.get('_plot_counter', 0) + 1
        document.attributes['_plot_counter'] = counter
        base, ext = os.path.splitext(os.path.basename(source_file_name))
        output_base = '%s-%d.py' % (base, counter)
        function_name = None
        caption = ''

    base, source_ext = os.path.splitext(output_base)
    if source_ext in ('.py', '.rst', '.txt'):
        output_base = base
    else:
        source_ext = ''

    # ensure that LaTeX includegraphics doesn't choke in foo.bar.pdf filenames
    output_base = output_base.replace('.', '-')

    # is it in doctest format?  An explicit :format: option wins.
    is_doctest = contains_doctest(code)
    if 'format' in options:
        is_doctest = options['format'] != 'python'

    # determine output directory name fragment
    source_rel_name = relpath(source_file_name, setup.confdir)
    source_rel_dir = os.path.dirname(source_rel_name)
    while source_rel_dir.startswith(os.path.sep):
        source_rel_dir = source_rel_dir[1:]

    # build_dir: where to place output files (temporarily)
    build_dir = os.path.join(os.path.dirname(setup.app.doctreedir),
                             'plot_directive',
                             source_rel_dir)
    # get rid of .. in paths, also changes pathsep
    # see note in Python docs for warning about symbolic links on Windows.
    # need to compare source and dest paths at end
    build_dir = os.path.normpath(build_dir)

    if not os.path.exists(build_dir):
        os.makedirs(build_dir)

    # dest_dir: final location in the builder's directory
    dest_dir = os.path.abspath(os.path.join(setup.app.builder.outdir,
                                            source_rel_dir))
    if not os.path.exists(dest_dir):
        os.makedirs(dest_dir)

    # how to link to files from the RST file
    dest_dir_link = os.path.join(relpath(setup.confdir, rst_dir),
                                 source_rel_dir).replace(os.path.sep, '/')
    build_dir_link = relpath(build_dir, rst_dir).replace(os.path.sep, '/')
    source_link = dest_dir_link + '/' + output_base + source_ext

    # make figures
    try:
        results = render_figures(code, source_file_name, build_dir, output_base,
                                 context, function_name, config)
        errors = []
    except PlotError as err:
        reporter = state.memo.reporter
        sm = reporter.system_message(
            2, "Exception occurred in plotting %s: %s" % (output_base, err),
            line=lineno)
        results = [(code, [])]
        errors = [sm]

    # Properly indent the caption
    caption = '\n'.join('      ' + line.strip()
                        for line in caption.split('\n'))

    # Loop-invariant template inputs: image options passed through to each
    # figure, and the only-directives selecting the output builder.
    opts = [':%s: %s' % (key, val) for key, val in options.items()
            if key in ('alt', 'height', 'width', 'scale', 'align', 'class')]
    only_html = ".. only:: html"
    only_latex = ".. only:: latex"

    # generate output restructuredtext
    total_lines = []
    for j, (code_piece, images) in enumerate(results):
        if options['include-source']:
            if is_doctest:
                lines = ['']
                lines += [row.rstrip() for row in code_piece.split('\n')]
            else:
                lines = ['.. code-block:: python', '']
                lines += ['    %s' % row.rstrip()
                          for row in code_piece.split('\n')]
            source_code = "\n".join(lines)
        else:
            source_code = ""

        if nofigs:
            images = []

        # Only the first chunk carries the link to the source script.
        if j == 0:
            src_link = source_link
        else:
            src_link = None

        result = format_template(
            TEMPLATE,
            dest_dir=dest_dir_link,
            build_dir=build_dir_link,
            source_link=src_link,
            multi_image=len(images) > 1,
            only_html=only_html,
            only_latex=only_latex,
            options=opts,
            images=images,
            source_code=source_code,
            html_show_formats=config.plot_html_show_formats,
            caption=caption)

        total_lines.extend(result.split("\n"))
        # blank separator line after each chunk
        total_lines.append("")

    if total_lines:
        state_machine.insert_input(total_lines, source=source_file_name)

    # copy image files to builder's output directory, if necessary
    # (dest_dir was created above)
    for code_piece, images in results:
        for img in images:
            for fn in img.filenames():
                destimg = os.path.join(dest_dir, os.path.basename(fn))
                if fn != destimg:
                    shutil.copyfile(fn, destimg)

    # copy script (if necessary): inline code has no file of its own, so
    # write it out for the "Source code" link.
    if source_file_name == rst_file:
        target_name = os.path.join(dest_dir, output_base + source_ext)
        with open(target_name, 'w') as f:
            f.write(unescape_doctest(code))

    return errors
768