1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import os
16import sys
17
def bool_env(varname: str, default: bool) -> bool:
  """Interpret an environment variable as a boolean flag.

  The text is compared case-insensitively: 'y', 'yes', 't', 'true', 'on' and
  '1' mean True; 'n', 'no', 'f', 'false', 'off' and '0' mean False. When the
  variable is not set, ``str(default)`` is interpreted in its place.

  Args:
    varname: name of the environment variable to look up.
    default: value to fall back on when the variable is not set.

  Returns:
    The boolean that the (lower-cased) text stands for.

  Raises:
    ValueError: if the text is not one of the spellings listed above.
  """
  truthy = ('y', 'yes', 't', 'true', 'on', '1')
  falsy = ('n', 'no', 'f', 'false', 'off', '0')

  # Lower-case once so both membership tests are case-insensitive.
  text = os.getenv(varname, str(default)).lower()
  if text in truthy:
    return True
  if text in falsy:
    return False
  raise ValueError("invalid truth value %r for environment %r" % (text, varname))
37
def int_env(varname: str, default: int) -> int:
  """Interpret an environment variable as an integer.

  Args:
    varname: name of the environment variable to look up.
    default: value converted and returned when the variable is not set.

  Returns:
    ``int(...)`` applied to the variable's text, or to ``default`` when the
    variable is not set.
  """
  raw = os.getenv(varname, default)
  return int(raw)
41
42
class Config:
  """Registry of named configuration options, optionally backed by absl flags.

  Options are registered through the ``DEFINE_*`` methods. Until
  ``config_with_absl`` has run, values are stored in and read from
  ``self.values``; afterwards reads and updates are forwarded to
  ``absl.flags.FLAGS``. The instance also tracks whether omnistaging is
  enabled and holds the callbacks to run when it gets disabled.
  """

  def __init__(self):
    # name -> current value (used while absl is not in charge).
    self.values = {}
    # name -> (option type, extra positional args, extra keyword args), kept
    # so the options can be re-declared as absl flags later.
    self.meta = {}
    # Attribute-style read access: FLAGS.<name> calls self.read('<name>').
    self.FLAGS = NameSpace(self.read)
    self.use_absl = False
    self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
    self._omnistaging_disablers = []

  def update(self, name, val):
    """Set option ``name`` to ``val``.

    Raises:
      AttributeError: if absl is not in use and ``name`` was never defined.
    """
    if self.use_absl:
      setattr(self.absl_flags.FLAGS, name, val)
    else:
      # check_exists raises for unknown names, so no second membership test
      # is needed before storing.
      self.check_exists(name)
      self.values[name] = val

  def read(self, name):
    """Return the current value of option ``name``.

    Raises:
      AttributeError: if absl is not in use and ``name`` was never defined.
    """
    if self.use_absl:
      return getattr(self.absl_flags.FLAGS, name)
    else:
      self.check_exists(name)
      return self.values[name]

  def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
    """Register a new option together with the metadata given at definition.

    Args:
      name: option name; must not already be registered.
      default: initial value of the option.
      opt_type: ``bool``, ``int``, ``str`` or the string ``'enum'``; selects
        the absl ``DEFINE_*`` function used by ``config_with_absl``.
      meta_args: extra positional arguments forwarded to that function.
      meta_kwargs: extra keyword arguments forwarded to that function.

    Raises:
      Exception: if ``name`` is already defined.
    """
    if name in self.values:
      raise Exception("Config option {} already defined".format(name))
    self.values[name] = default
    self.meta[name] = (opt_type, meta_args, meta_kwargs)

  def check_exists(self, name):
    """Raise AttributeError unless ``name`` is a registered option."""
    if name not in self.values:
      raise AttributeError("Unrecognized config option: {}".format(name))

  def DEFINE_bool(self, name, default, *args, **kwargs):
    """Register a boolean option; extra arguments are kept for absl."""
    self.add_option(name, default, bool, args, kwargs)

  def DEFINE_integer(self, name, default, *args, **kwargs):
    """Register an integer option; extra arguments are kept for absl."""
    self.add_option(name, default, int, args, kwargs)

  def DEFINE_string(self, name, default, *args, **kwargs):
    """Register a string option; extra arguments are kept for absl."""
    self.add_option(name, default, str, args, kwargs)

  def DEFINE_enum(self, name, default, *args, **kwargs):
    """Register an enum option; extra arguments are kept for absl."""
    self.add_option(name, default, 'enum', args, kwargs)

  def config_with_absl(self):
    """Re-declare every registered option as an absl flag and switch to absl.

    After this call ``read`` and ``update`` go through ``absl.flags.FLAGS``.
    """
    # Run this before calling `app.run(main)` etc
    from absl import app, flags as absl_flags

    self.use_absl = True
    self.absl_flags = absl_flags
    absl_defs = { bool: absl_flags.DEFINE_bool,
                  int:  absl_flags.DEFINE_integer,
                  str:  absl_flags.DEFINE_string,
                  'enum': absl_flags.DEFINE_enum }

    for name, val in self.values.items():
      flag_type, meta_args, meta_kwargs = self.meta[name]
      absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)

    app.call_after_init(lambda: self.complete_absl_config(absl_flags))

  def complete_absl_config(self, absl_flags):
    """Pass each option's value from ``absl_flags.FLAGS`` through ``update``."""
    for name in self.values:
      self.update(name, getattr(absl_flags.FLAGS, name))

  def parse_flags_with_absl(self):
    """Parse ``sys.argv`` with absl once; later calls do nothing.

    Uses the module-level ``already_configured_with_absl`` guard and, after
    parsing, disables omnistaging when ``FLAGS.jax_omnistaging`` is false.
    """
    global already_configured_with_absl
    if not already_configured_with_absl:
      import absl.flags
      self.config_with_absl()
      absl.flags.FLAGS(sys.argv, known_only=True)
      self.complete_absl_config(absl.flags)
      already_configured_with_absl = True

      if not FLAGS.jax_omnistaging:
        self.disable_omnistaging()


  def register_omnistaging_disabler(self, disabler):
    """Queue ``disabler`` for ``disable_omnistaging``, or run it right away.

    If omnistaging is already disabled the callback is invoked immediately
    instead of being stored.
    """
    if self.omnistaging_enabled:
      self._omnistaging_disablers.append(disabler)
    else:
      disabler()

  def enable_omnistaging(self):
    """Accept the request only while omnistaging is still enabled.

    Raises:
      Exception: if omnistaging has already been disabled.
    """
    if not self.omnistaging_enabled:
      raise Exception("can't re-enable omnistaging after it's been disabled")

  def disable_omnistaging(self):
    """Run every registered disabler once and mark omnistaging as disabled.

    Calling this again after omnistaging is disabled has no effect.
    """
    if self.omnistaging_enabled:
      for disabler in self._omnistaging_disablers:
        disabler()
      self.omnistaging_enabled = False
140
141
class NameSpace:
  """Attribute-style front end for a one-argument lookup callable.

  Any attribute that normal lookup does not find on the instance is resolved
  by calling the callable with the attribute's name.
  """

  def __init__(self, getter):
    # Callable invoked as getter(name) for attributes not found normally.
    self._getter = getter

  def __getattr__(self, name):
    """Return whatever the lookup callable yields for ``name``."""
    lookup = self._getter
    return lookup(name)
148
149
# The module-wide option registry. `flags` is a second name for the same
# object, and `FLAGS` gives attribute-style reads of the registered options.
config = Config()
flags = config
FLAGS = flags.FLAGS

# Guard read and set by Config.parse_flags_with_absl so parsing happens once.
already_configured_with_absl = False

# Each option's default is taken from the upper-case environment variable
# named in the call, with the literal as fallback when it is not set.
flags.DEFINE_bool('jax_enable_checks',
                  bool_env('JAX_ENABLE_CHECKS', False),
                  help='Turn on invariant checking (core.skip_checks = False)')

flags.DEFINE_bool('jax_omnistaging',
                  bool_env('JAX_OMNISTAGING', True),
                  help='Enable staging based on dynamic context rather than data dependence.')

flags.DEFINE_integer('jax_tracer_error_num_traceback_frames',
                     int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
                     help='Set the number of stack frames in JAX tracer error messages.')

flags.DEFINE_bool('jax_check_tracer_leaks',
                  bool_env('JAX_CHECK_TRACER_LEAKS', False),
                  help=('Turn on checking for leaked tracers as soon as a trace completes. '
                        'Enabling leak checking may have performance impacts: some caching '
                        'is disabled, and other overheads may be added.'))
181