import sys
import threading
import typing as tp  # NOQA

from chainer import types  # NOQA


if types.TYPE_CHECKING:
    import numpy  # NOQA

    from chainer.graph_optimizations import static_graph  # NOQA


class GlobalConfig(object):

    """The plain object that represents the global configuration of Chainer.

    Entries are stored as instance attributes. The class-level ``None``
    values below carry type comments for static checkers only; they are not
    listed by :meth:`show`, which prints the instance ``__dict__``.

    """

    debug = None  # type: bool
    cudnn_deterministic = None  # type: bool
    warn_nondeterministic = None  # type: bool
    enable_backprop = None  # type: bool
    keep_graph_on_report = None  # type: bool
    train = None  # type: bool
    type_check = None  # type: bool
    use_cudnn = None  # type: str
    use_cudnn_tensor_core = None  # type: str
    autotune = None  # type: bool
    schedule_func = None  # type: tp.Optional[static_graph.StaticScheduleFunction] # NOQA
    use_ideep = None  # type: str
    lazy_grad_sum = None  # type: bool
    cudnn_fast_batch_normalization = None  # type: bool
    dtype = None  # type: numpy.dtype
    in_recomputing = None  # type: bool
    use_static_graph = None  # type: bool
    _will_recompute = None  # type: bool
    compute_mode = None  # type: str

    def show(self, file=sys.stdout):
        """show(file=sys.stdout)

        Prints the global config entries.

        The entries are sorted in the lexicographical order of the entry names.

        Args:
            file: Output file-like object.

        """
        keys = sorted(self.__dict__)
        _print_attrs(self, keys, file)


class LocalConfig(object):

    """Thread-local configuration of Chainer.

    This class implements the local configuration. When a value is set to this
    object, the configuration is only updated in the current thread. When a
    user tries to access an attribute and there is no local value, it
    automatically retrieves a value from the global configuration.

    """

    def __init__(self, global_config):
        # Bypass this class's own ``__setattr__``, which would otherwise store
        # these two bookkeeping attributes in the thread-local namespace.
        super(LocalConfig, self).__setattr__('_global', global_config)
        super(LocalConfig, self).__setattr__('_local', threading.local())

    def __delattr__(self, name):
        # Removes only the current thread's override; the global value is
        # left untouched.
        delattr(self._local, name)

    def __getattr__(self, name):
        # Python calls ``__getattr__`` only when normal lookup fails, so this
        # handles configuration entries: the thread-local value wins,
        # otherwise fall back to the global configuration.
        dic = self._local.__dict__
        if name in dic:
            return dic[name]
        return getattr(self._global, name)

    def __setattr__(self, name, value):
        # Every assignment goes to the current thread's namespace only.
        setattr(self._local, name, value)

    def show(self, file=sys.stdout):
        """show(file=sys.stdout)

        Prints the config entries.

        The entries are sorted in the lexicographical order of the entry names.

        Args:
            file: Output file-like object.

        .. admonition:: Example

           You can easily print the list of configurations used in
           the current thread.

              >>> chainer.config.show()  # doctest: +SKIP
              debug           False
              enable_backprop True
              train           True
              type_check      True

        """
        keys = sorted(set(self._global.__dict__) | set(self._local.__dict__))
        _print_attrs(self, keys, file)


def _print_attrs(obj, keys, file):
    """Writes one ``name value`` line per key, with values column-aligned.

    Args:
        obj: Object whose attributes are printed (looked up with ``getattr``).
        keys: Sequence of attribute names, printed in the given order.
        file: Output file-like object.

    """
    if not keys:
        # ``max()`` below raises ``ValueError`` on an empty sequence; with no
        # entries there is simply nothing to print.
        return
    max_len = max(len(key) for key in keys)
    for key in keys:
        spacer = ' ' * (max_len - len(key))
        file.write(u'{} {}{}\n'.format(key, spacer, getattr(obj, key)))


global_config = GlobalConfig()
'''Global configuration of Chainer.

It is an instance of :class:`chainer.configuration.GlobalConfig`.
See :ref:`configuration` for details.
'''


config = LocalConfig(global_config)
'''Thread-local configuration of Chainer.

It is an instance of :class:`chainer.configuration.LocalConfig`, and is
referring to :data:`~chainer.global_config` as its default configuration.
See :ref:`configuration` for details.
'''


class _ConfigContext(object):

    """Context manager that temporarily sets one entry of a ``LocalConfig``.

    On exit, the previous thread-local value is restored if the entry had
    one; otherwise the thread-local override is deleted, so that lookups fall
    back to the global configuration again.

    """

    is_local = False
    old_value = None

    def __init__(self, config, name, value):
        self.config = config
        self.name = name
        self.value = value

    def __enter__(self):
        name = self.name
        value = self.value
        config = self.config
        # Remember whether this thread already overrides the entry, so that
        # ``__exit__`` can restore that override instead of deleting it.
        is_local = hasattr(config._local, name)
        if is_local:
            self.old_value = getattr(config, name)
        self.is_local = is_local

        setattr(config, name, value)

    def __exit__(self, typ, value, traceback):
        if self.is_local:
            setattr(self.config, self.name, self.old_value)
        else:
            delattr(self.config, self.name)


def using_config(name, value, config=config):
    """using_config(name, value, config=chainer.config)

    Context manager to temporarily change the thread-local configuration.

    Args:
        name (str): Name of the configuration to change.
        value: Temporary value of the configuration entry.
        config (~chainer.configuration.LocalConfig): Configuration object.
            Chainer's thread-local configuration is used by default.

    .. seealso::
        :ref:`configuration`

    """
    return _ConfigContext(config, name, value)