1import sys
2import threading
3import typing as tp  # NOQA
4
5from chainer import types  # NOQA
6
7
8if types.TYPE_CHECKING:
9    import numpy  # NOQA
10
11    from chainer.graph_optimizations import static_graph  # NOQA
12
13
class GlobalConfig(object):

    """The plain object that represents the global configuration of Chainer."""

    # Class-level declarations of the known configuration entries (with their
    # intended types for static checking). They all default to ``None``;
    # ``show`` only reports entries that were assigned on the instance.
    debug = None  # type: bool
    cudnn_deterministic = None  # type: bool
    warn_nondeterministic = None  # type: bool
    enable_backprop = None  # type: bool
    keep_graph_on_report = None  # type: bool
    train = None  # type: bool
    type_check = None  # type: bool
    use_cudnn = None  # type: str
    use_cudnn_tensor_core = None  # type: str
    autotune = None  # type: bool
    schedule_func = None  # type: tp.Optional[static_graph.StaticScheduleFunction] # NOQA
    use_ideep = None  # type: str
    lazy_grad_sum = None  # type: bool
    cudnn_fast_batch_normalization = None  # type: bool
    dtype = None  # type: numpy.dtype
    in_recomputing = None  # type: bool
    use_static_graph = None  # type: bool
    _will_recompute = None  # type: bool
    compute_mode = None  # type: str

    def show(self, file=None):
        """show(file=sys.stdout)

        Prints the global config entries.

        The entries are sorted in the lexicographical order of the entry name.
        Only entries assigned on this instance are listed; the class-level
        ``None`` declarations are not.

        Args:
            file: Output file-like object. If omitted or ``None``, the
                current ``sys.stdout`` is used. It is looked up at call time,
                so redirections such as :func:`contextlib.redirect_stdout`
                are honored.

        """
        if file is None:
            # Resolve lazily: a ``sys.stdout`` default would be frozen at
            # import time and ignore later redirections.
            file = sys.stdout
        keys = sorted(self.__dict__)
        _print_attrs(self, keys, file)
51
52
class LocalConfig(object):

    """Thread-local configuration of Chainer.

    This class implements the local configuration. When a value is set to this
    object, the configuration is only updated in the current thread. When a
    user tries to access an attribute and there is no local value, it
    automatically retrieves a value from the global configuration.

    """

    def __init__(self, global_config):
        # Bypass our own ``__setattr__``, which would otherwise store these
        # bookkeeping attributes in the thread-local storage.
        super(LocalConfig, self).__setattr__('_global', global_config)
        super(LocalConfig, self).__setattr__('_local', threading.local())

    def __delattr__(self, name):
        # Removes only the current thread's override, so the global value
        # becomes visible again for this thread.
        delattr(self._local, name)

    def __getattr__(self, name):
        # Python calls ``__getattr__`` only when normal lookup fails. After
        # ``__init__`` the two bookkeeping attributes are always found
        # normally. Reaching here for them means ``__init__`` never ran on
        # this instance (e.g. inside ``copy.copy``), and touching
        # ``self._local`` would recurse forever, so fail cleanly instead.
        if name in ('_global', '_local'):
            raise AttributeError(name)
        local_values = self._local.__dict__
        if name in local_values:
            return local_values[name]
        return getattr(self._global, name)

    def __setattr__(self, name, value):
        # Every assignment is a per-thread override.
        setattr(self._local, name, value)

    def show(self, file=None):
        """show(file=sys.stdout)

        Prints the config entries.

        The entries are sorted in the lexicographical order of the entry names.

        Args:
            file: Output file-like object. If omitted or ``None``, the
                current ``sys.stdout`` is used. It is looked up at call time,
                so redirections such as :func:`contextlib.redirect_stdout`
                are honored.

        .. admonition:: Example

           You can easily print the list of configurations used in
           the current thread.

              >>> chainer.config.show()  # doctest: +SKIP
              debug           False
              enable_backprop True
              train           True
              type_check      True

        """
        if file is None:
            file = sys.stdout
        # Union of entries set on the global config and overrides set in
        # the current thread; values are read through ``self`` so local
        # overrides win.
        keys = sorted(set(self._global.__dict__) | set(self._local.__dict__))
        _print_attrs(self, keys, file)
104
105
106def _print_attrs(obj, keys, file):
107    max_len = max(len(key) for key in keys)
108    for key in keys:
109        spacer = ' ' * (max_len - len(key))
110        file.write(u'{} {}{}\n'.format(key, spacer, getattr(obj, key)))
111
112
# Process-wide configuration object. ``GlobalConfig.show`` reports only the
# entries assigned on this instance, not the class-level ``None``
# declarations. NOTE(review): the concrete default values are presumably
# assigned elsewhere in the package (not visible in this file) — confirm.
global_config = GlobalConfig()
'''Global configuration of Chainer.

It is an instance of :class:`chainer.configuration.GlobalConfig`.
See :ref:`configuration` for details.
'''
119
120
# Per-thread view layered over ``global_config``. Attribute writes go into
# thread-local storage. Reads fall back to ``global_config`` when the current
# thread has no override of its own.
config = LocalConfig(global_config)
'''Thread-local configuration of Chainer.

It is an instance of :class:`chainer.configuration.LocalConfig`, and is
referring to :data:`~chainer.global_config` as its default configuration.
See :ref:`configuration` for details.
'''
128
129
130class _ConfigContext(object):
131
132    is_local = False
133    old_value = None
134
135    def __init__(self, config, name, value):
136        self.config = config
137        self.name = name
138        self.value = value
139
140    def __enter__(self):
141        name = self.name
142        value = self.value
143        config = self.config
144        is_local = hasattr(config._local, name)
145        if is_local:
146            self.old_value = getattr(config, name)
147            self.is_local = is_local
148
149        setattr(config, name, value)
150
151    def __exit__(self, typ, value, traceback):
152        if self.is_local:
153            setattr(self.config, self.name, self.old_value)
154        else:
155            delattr(self.config, self.name)
156
157
def using_config(name, value, config=config):
    """using_config(name, value, config=chainer.config)

    Context manager to temporarily change the thread-local configuration.

    Inside the ``with`` block, the entry ``name`` of ``config`` holds
    ``value`` for the current thread. When the block exits, the entry's
    previous state is restored.

    Args:
        name (str): Name of the configuration to change.
        value: Temporary value of the configuration entry.
        config (~chainer.configuration.LocalConfig): Configuration object.
            Chainer's thread-local configuration is used by default.

    Returns:
        A context manager object; its ``__enter__`` yields ``None``.

    .. seealso::
        :ref:`configuration`

    """
    context = _ConfigContext(config, name, value)
    return context
174