# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=protected-access, logging-format-interpolation, invalid-name, no-member, too-many-branches
"""Monitor outputs, weights, and gradients for debugging."""

import re
import ctypes
import logging
from math import sqrt

from .ndarray import NDArray
from .base import NDArrayHandle, py_str
from . import ndarray


class Monitor(object):
    """Monitor inputs, outputs, weights, and gradients for debugging.

    Parameters
    ----------
    interval : int
        Number of batches between printing. Collection is switched on by
        the first call to `tic` and then by every `interval`-th call.
    stat_func : function
        A function that computes statistics of tensors.
        Takes an `NDArray` and returns an `NDArray` (or a list of
        `NDArray`). Defaults to ``norm(x) / sqrt(x.size)``, which, with the
        default L2 norm, is the root-mean-square of `x`.
    pattern : str
        A regular expression specifying which tensors to monitor.
        Only tensors whose names match `pattern` (checked with `re.match`,
        i.e. anchored at the start of the name) will be included.
        For example, '.*weight|.*output' will print all weights and outputs and
        '.*backward.*' will print all gradients.
    sort : bool, default False
        If True, the results returned by `toc` are sorted by tensor name.
    monitor_all : bool, default False
        If true, monitor both input and output, otherwise monitor output only.
    """
    def __init__(self, interval, stat_func=None, pattern='.*', sort=False, monitor_all=False):
        if stat_func is None:
            def asum_stat(x):
                """Return norm(x)/sqrt(x.size) as an NDArray (not synchronized here)."""
                return ndarray.norm(x)/sqrt(x.size)
            stat_func = asum_stat
        self.stat_func = stat_func
        self.interval = interval
        self.activated = False
        # Pending (step, name, stat) tuples, drained by `toc`.
        self.queue = []
        self.step = 0
        self.exes = []
        self.re_prog = re.compile(pattern)
        self.sort = sort
        self.monitor_all = monitor_all

        def stat_helper(name, array):
            """Executor monitor callback: queue stats for a matching tensor.

            `name` is decoded with `py_str`; `array` is a raw handle that is
            cast to `NDArrayHandle` and wrapped as a read-only `NDArray`.
            """
            # NOTE(review): the handle is wrapped *before* the activation and
            # pattern checks, presumably so the NDArray wrapper takes ownership
            # of (and eventually releases) the handle even when the value is
            # discarded. Keep this order unless the ownership contract of
            # set_monitor_callback is confirmed.
            array = ctypes.cast(array, NDArrayHandle)
            array = NDArray(array, writable=False)
            if not self.activated:
                return
            name = py_str(name)
            if not self.re_prog.match(name):
                return
            self.queue.append((self.step, name, self.stat_func(array)))
        self.stat_helper = stat_helper

    def install(self, exe):
        """install callback to executor.
        Supports installing to multiple exes.

        Parameters
        ----------
        exe : mx.executor.Executor
            The Executor (returned by symbol.bind) to install to.
        """
        exe.set_monitor_callback(self.stat_helper, self.monitor_all)
        self.exes.append(exe)

    def _wait_for_arrays(self):
        """Block until every argument and auxiliary array of each installed executor is readable."""
        for exe in self.exes:
            for array in exe.arg_arrays:
                array.wait_to_read()
            for array in exe.aux_arrays:
                array.wait_to_read()

    def tic(self):
        """Start collecting stats for current batch.
        Call before calling forward."""
        if self.step % self.interval == 0:
            # Finish pending work on the executors' arrays before switching
            # collection on and resetting the queue.
            self._wait_for_arrays()
            self.queue = []
            self.activated = True
        self.step += 1

    @staticmethod
    def _format_stats(values):
        """Render each NDArray in `values` as text, each followed by a tab."""
        parts = []
        for v in values:
            assert isinstance(v, NDArray)
            # Single-element results print as a plain scalar; anything else
            # prints as its NumPy representation.
            text = str(v.asscalar()) if v.shape == (1,) else str(v.asnumpy())
            parts.append(text + '\t')
        return ''.join(parts)

    def toc(self):
        """End collecting for current batch and return results.
        Call after computation of current batch.

        In addition to values queued by the executor callback, this records
        stats for every argument and auxiliary array whose name matches
        `pattern`.

        Returns
        -------
        res : list of (int, str, str)
            One ``(step, name, stats)`` tuple per collected tensor, where
            ``stats`` is the tab-terminated text form of each value produced
            by `stat_func`. Empty if collection was not active.
        """
        if not self.activated:
            return []
        self._wait_for_arrays()
        for exe in self.exes:
            for name, array in zip(exe._symbol.list_arguments(), exe.arg_arrays):
                if self.re_prog.match(name):
                    self.queue.append((self.step, name, self.stat_func(array)))
            for name, array in zip(exe._symbol.list_auxiliary_states(), exe.aux_arrays):
                if self.re_prog.match(name):
                    self.queue.append((self.step, name, self.stat_func(array)))
        self.activated = False
        if self.sort:
            self.queue.sort(key=lambda x: x[1])
        res = []
        for step, name, stats in self.queue:
            # stat_func may return a single NDArray or a list of them.
            if isinstance(stats, NDArray):
                stats = [stats]
            assert isinstance(stats, list)
            res.append((step, name, self._format_stats(stats)))
        self.queue = []
        return res

    def toc_print(self):
        """End collecting and log each result with `logging.info`."""
        res = self.toc()
        for n, k, v in res:
            logging.info('Batch: {:7d} {:30s} {:s}'.format(n, k, v))