# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys


def bool_env(varname: str, default: bool) -> bool:
  """Read an environment variable and interpret it as a boolean.

  True values are (case insensitive): 'y', 'yes', 't', 'true', 'on', and '1';
  false values are 'n', 'no', 'f', 'false', 'off', and '0'.

  Args:
    varname: the name of the variable
    default: the default boolean value, used when the variable is unset
  Returns:
    The parsed boolean value.
  Raises: ValueError if the environment variable is anything else.
  """
  # str(default) is 'True'/'False', which lower-cases into the accepted sets.
  val = os.getenv(varname, str(default))
  val = val.lower()
  if val in ('y', 'yes', 't', 'true', 'on', '1'):
    return True
  elif val in ('n', 'no', 'f', 'false', 'off', '0'):
    return False
  else:
    raise ValueError("invalid truth value %r for environment %r" % (val, varname))


def int_env(varname: str, default: int) -> int:
  """Read an environment variable and interpret it as an integer.

  Args:
    varname: the name of the variable
    default: the value returned when the variable is unset
  Raises: ValueError if the variable is set but is not a valid integer.
  """
  return int(os.getenv(varname, default))


class Config:
  """Registry of named configuration options.

  Options are declared with the ``DEFINE_*`` methods and read either through
  :meth:`read` or as attributes of ``self.FLAGS``. After
  :meth:`config_with_absl` has been called, reads and writes are forwarded
  to ``absl.flags.FLAGS`` instead of ``self.values``.
  """

  def __init__(self):
    # option name -> value; authoritative only while use_absl is False.
    self.values = {}
    # option name -> (opt_type, meta_args, meta_kwargs) given at definition.
    self.meta = {}
    self.FLAGS = NameSpace(self.read)
    self.use_absl = False
    self.omnistaging_enabled = bool_env('JAX_OMNISTAGING', True)
    self._omnistaging_disablers = []

  def update(self, name, val):
    """Set option ``name`` to ``val``.

    Raises:
      AttributeError: if ``name`` was never defined (non-absl mode).
    """
    if self.use_absl:
      setattr(self.absl_flags.FLAGS, name, val)
    else:
      # check_exists raises for unknown names, so no second check is needed.
      self.check_exists(name)
      self.values[name] = val

  def read(self, name):
    """Return the current value of option ``name``.

    Raises:
      AttributeError: if ``name`` was never defined (non-absl mode).
    """
    if self.use_absl:
      return getattr(self.absl_flags.FLAGS, name)
    else:
      self.check_exists(name)
      return self.values[name]

  def add_option(self, name, default, opt_type, meta_args, meta_kwargs):
    """Register a new option with its default value and definition metadata.

    Raises:
      Exception: if an option with this name is already defined.
    """
    if name in self.values:
      raise Exception("Config option {} already defined".format(name))
    self.values[name] = default
    self.meta[name] = (opt_type, meta_args, meta_kwargs)

  def check_exists(self, name):
    """Raise AttributeError unless ``name`` is a defined option."""
    # AttributeError (not a generic Exception) lets NameSpace.__getattr__
    # behave correctly with hasattr/getattr(..., default).
    if name not in self.values:
      raise AttributeError("Unrecognized config option: {}".format(name))

  def DEFINE_bool(self, name, default, *args, **kwargs):
    self.add_option(name, default, bool, args, kwargs)

  def DEFINE_integer(self, name, default, *args, **kwargs):
    self.add_option(name, default, int, args, kwargs)

  def DEFINE_string(self, name, default, *args, **kwargs):
    self.add_option(name, default, str, args, kwargs)

  def DEFINE_enum(self, name, default, *args, **kwargs):
    self.add_option(name, default, 'enum', args, kwargs)

  def config_with_absl(self):
    """Mirror every defined option as an absl flag and switch to absl mode."""
    # Run this before calling `app.run(main)` etc
    from absl import app, flags as absl_flags

    self.use_absl = True
    self.absl_flags = absl_flags
    absl_defs = {bool: absl_flags.DEFINE_bool,
                 int: absl_flags.DEFINE_integer,
                 str: absl_flags.DEFINE_string,
                 'enum': absl_flags.DEFINE_enum}

    for name, val in self.values.items():
      flag_type, meta_args, meta_kwargs = self.meta[name]
      absl_defs[flag_type](name, val, *meta_args, **meta_kwargs)

    # Once absl has parsed the command line, push the parsed values through.
    app.call_after_init(lambda: self.complete_absl_config(absl_flags))

  def complete_absl_config(self, absl_flags):
    """Apply the values held by ``absl_flags.FLAGS`` to every known option."""
    for name in self.values:
      self.update(name, getattr(absl_flags.FLAGS, name))

  def parse_flags_with_absl(self):
    """Configure via absl and parse known flags from ``sys.argv`` (once only)."""
    global already_configured_with_absl
    if not already_configured_with_absl:
      import absl.flags
      self.config_with_absl()
      absl.flags.FLAGS(sys.argv, known_only=True)
      self.complete_absl_config(absl.flags)
      already_configured_with_absl = True

      # Read this instance's value rather than the module-global FLAGS.
      if not self.read('jax_omnistaging'):
        self.disable_omnistaging()

  def register_omnistaging_disabler(self, disabler):
    """Record ``disabler``, or run it immediately if omnistaging is off."""
    if self.omnistaging_enabled:
      self._omnistaging_disablers.append(disabler)
    else:
      disabler()

  def enable_omnistaging(self):
    """No-op while enabled; raises if omnistaging was already disabled."""
    if not self.omnistaging_enabled:
      raise Exception("can't re-enable omnistaging after it's been disabled")

  def disable_omnistaging(self):
    """Run every registered disabler once and mark omnistaging as disabled."""
    if self.omnistaging_enabled:
      for disabler in self._omnistaging_disablers:
        disabler()
      self.omnistaging_enabled = False


class NameSpace:
  """Attribute-style view that forwards ``ns.name`` to ``getter(name)``."""

  def __init__(self, getter):
    self._getter = getter

  def __getattr__(self, name):
    return self._getter(name)


config = Config()
flags = config
FLAGS = flags.FLAGS

already_configured_with_absl = False

flags.DEFINE_bool(
    'jax_enable_checks',
    bool_env('JAX_ENABLE_CHECKS', False),
    help='Turn on invariant checking (core.skip_checks = False)'
)

flags.DEFINE_bool(
    'jax_omnistaging',
    bool_env('JAX_OMNISTAGING', True),
    help='Enable staging based on dynamic context rather than data dependence.'
)

flags.DEFINE_integer(
    'jax_tracer_error_num_traceback_frames',
    int_env('JAX_TRACER_ERROR_NUM_TRACEBACK_FRAMES', 5),
    help='Set the number of stack frames in JAX tracer error messages.'
)

flags.DEFINE_bool(
    'jax_check_tracer_leaks',
    bool_env('JAX_CHECK_TRACER_LEAKS', False),
    help=('Turn on checking for leaked tracers as soon as a trace completes. '
          'Enabling leak checking may have performance impacts: some caching '
          'is disabled, and other overheads may be added.'),
)