1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# coding: utf-8
19# pylint: disable=protected-access, logging-format-interpolation, invalid-name, no-member, too-many-branches
20"""Monitor outputs, weights, and gradients for debugging."""
21
22import re
23import ctypes
24import logging
25from math import sqrt
26
27from .ndarray import NDArray
28from .base import NDArrayHandle, py_str
29from . import ndarray
30
31
32class Monitor(object):
33    """Monitor inputs, outputs, weights, and gradients for debugging.
34
35    Parameters
36    ----------
37    interval : int
38        Number of batches between printing.
39    stat_func : function
40        A function that computes statistics of tensors.
41        Takes an `NDArray` and returns an `NDArray`. Defaults to mean
42        absolute value abs(x)/size(x).
43    pattern : str
44        A regular expression specifying which tensors to monitor.
45        Only tensors with names that match `name_pattern` will be included.
46        For example, '.*weight|.*output' will print all weights and outputs and
47        '.*backward.*' will print all gradients.
48    monitor_all : bool, default False
49        If true, monitor both input and output, otherwise monitor output only.
50    """
51    def __init__(self, interval, stat_func=None, pattern='.*', sort=False, monitor_all=False):
52        if stat_func is None:
53            def asum_stat(x):
54                """returns |x|/size(x), async execution."""
55                return ndarray.norm(x)/sqrt(x.size)
56            stat_func = asum_stat
57        self.stat_func = stat_func
58        self.interval = interval
59        self.activated = False
60        self.queue = []
61        self.step = 0
62        self.exes = []
63        self.re_prog = re.compile(pattern)
64        self.sort = sort
65        self.monitor_all = monitor_all
66        def stat_helper(name, array):
67            """wrapper for executor callback"""
68            array = ctypes.cast(array, NDArrayHandle)
69            array = NDArray(array, writable=False)
70            if not self.activated or not self.re_prog.match(py_str(name)):
71                return
72            self.queue.append((self.step, py_str(name), self.stat_func(array)))
73        self.stat_helper = stat_helper
74
75    def install(self, exe):
76        """install callback to executor.
77        Supports installing to multiple exes.
78
79        Parameters
80        ----------
81        exe : mx.executor.Executor
82            The Executor (returned by symbol.bind) to install to.
83        """
84        exe.set_monitor_callback(self.stat_helper, self.monitor_all)
85        self.exes.append(exe)
86
87    def tic(self):
88        """Start collecting stats for current batch.
89        Call before calling forward."""
90        if self.step % self.interval == 0:
91            for exe in self.exes:
92                for array in exe.arg_arrays:
93                    array.wait_to_read()
94                for array in exe.aux_arrays:
95                    array.wait_to_read()
96            self.queue = []
97            self.activated = True
98        self.step += 1
99
100
101    def toc(self):
102        """End collecting for current batch and return results.
103        Call after computation of current batch.
104
105        Returns
106        -------
107        res : list of """
108        if not self.activated:
109            return []
110        for exe in self.exes:
111            for array in exe.arg_arrays:
112                array.wait_to_read()
113            for array in exe.aux_arrays:
114                array.wait_to_read()
115        for exe in self.exes:
116            for name, array in zip(exe._symbol.list_arguments(), exe.arg_arrays):
117                if self.re_prog.match(name):
118                    self.queue.append((self.step, name, self.stat_func(array)))
119            for name, array in zip(exe._symbol.list_auxiliary_states(), exe.aux_arrays):
120                if self.re_prog.match(name):
121                    self.queue.append((self.step, name, self.stat_func(array)))
122        self.activated = False
123        res = []
124        if self.sort:
125            self.queue.sort(key=lambda x: x[1])
126        for n, k, v_list in self.queue:
127            if isinstance(v_list, NDArray):
128                v_list = [v_list]
129            assert isinstance(v_list, list)
130            s = ''
131            for v in v_list:
132                assert isinstance(v, NDArray)
133                if v.shape == (1,):
134                    s += str(v.asscalar()) + '\t'
135                else:
136                    s += str(v.asnumpy()) + '\t'
137            res.append((n, k, s))
138        self.queue = []
139        return res
140
141    def toc_print(self):
142        """End collecting and print results."""
143        res = self.toc()
144        for n, k, v in res:
145            logging.info('Batch: {:7d} {:30s} {:s}'.format(n, k, v))
146