# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax functional core: Scopes."""

import contextlib
import functools
import hashlib
import dataclasses
from typing import Any, Callable, Container, Dict, Generic, Iterable, Mapping, Optional, Sequence, Set, Tuple, TypeVar, Union

from . import tracers
from flax import errors
from .frozen_dict import freeze
from .frozen_dict import FrozenDict
from .frozen_dict import unfreeze
import jax
from jax import numpy as jnp
from jax import random

T = TypeVar('T')

PRNGKey = Any
Array = Any

RNGSequences = Dict[str, PRNGKey]


Filter = Union[bool, str, Container[str], 'DenyList']

@dataclasses.dataclass(frozen=True, eq=True)
class DenyList:
  """The complement of a filter.

  A `DenyList` matches every collection except those matched by `deny` (see
  `in_filter`).
  """
  deny: Filter


CollectionFilter = Filter
PRNGSequenceFilter = Filter

Collection = Mapping[str, Any]
MutableCollection = Dict[str, Any]

VariableDict = Mapping[str, Collection]
FrozenVariableDict = FrozenDict[str, Collection]
MutableVariableDict = Dict[str, MutableCollection]


def _fold_in_str(rng: PRNGKey, data: str) -> PRNGKey:
  """Folds a string into a jax.random.PRNGKey using its SHA-1 hash.

  This is faster than splitting a PRNGKey because it allows generating new PRNG
  keys in parallel that are independent of each other.

  Args:
    rng: the rng to fold the string into.
    data: the string to be folded in.

  Returns:
    The newly generated PRNG key.
  """
  m = hashlib.sha1()
  m.update(data.encode('utf-8'))
  d = m.digest()
  # Only the first 4 bytes of the digest are used so the value fits a uint32.
  hash_int = int.from_bytes(d[:4], byteorder='big')
  return random.fold_in(rng, jnp.uint32(hash_int))


def in_filter(filter_like: Filter, col: str) -> bool:
  """Checks whether a filter can be applied to a collection.

  Used for both collections and rng sequence filters.

  Args:
    filter_like: a filter (either a boolean, a string, a container of strings,
      or a `DenyList`) for a collection.
    col: a collection, which is a string identifying a dictionary of data, for
      instance "params" or "batch_stats".

  Returns:
    True if either `filter_like` is True, equal to `col`, a container
    containing `col`, or a `DenyList` whose `deny` filter does not match `col`.

  Raises:
    errors.InvalidFilterError: if `filter_like` is none of the supported types.
  """
  # NOTE: `str` must be tested before `Container` because a string is itself a
  # container (of sub-strings), which would turn equality into substring match.
  if isinstance(filter_like, str):
    return col == filter_like
  if isinstance(filter_like, Container):
    return col in filter_like
  if isinstance(filter_like, bool):
    return filter_like
  if isinstance(filter_like, DenyList):
    return not in_filter(filter_like.deny, col)
  raise errors.InvalidFilterError(filter_like)


def filter_to_set(x: Filter) -> Set[str]:
  """Converts a Filter into a set of collections, fails on the infinite set.

  Args:
    x: a filter (boolean, string, or iterable of strings). `True` and
      `DenyList` filters describe an unbounded set and are rejected.

  Returns:
    The input filter represented as a set of strings.

  Raises:
    errors.InvalidFilterError: if `x` is not a supported finite filter.
  """
  assert x is not True and not isinstance(x, DenyList), 'Infinite set'
  if x is False:
    return set()
  if isinstance(x, str):
    return {x}
  if isinstance(x, Iterable):
    return set(x)
  raise errors.InvalidFilterError(x)


def union_filters(a: Filter, b: Filter) -> Filter:
  """Takes the union of two filters (similar to a logical or).

  Args:
    a: a filter.
    b: a filter.

  Returns:
    The union of the two input filters. For instance,
    `union_filters('f1', ['f2']) = {'f1', 'f2'}`.
  """
  if a is True or b is True:
    return True
  if isinstance(a, DenyList) and isinstance(b, DenyList):
    # (not A) or (not B) == not (A and B)
    return DenyList(intersect_filters(a.deny, b.deny))
  if isinstance(b, DenyList):
    # Normalize so that the DenyList (if any) is always `a`.
    a, b = b, a
  if isinstance(a, DenyList):
    # (not A) or B == not (A - B)
    return DenyList(subtract_filters(a.deny, b))

  a = filter_to_set(a)
  b = filter_to_set(b)
  return a.union(b)


def subtract_filters(a: Filter, b: Filter) -> Filter:
  """Returns the subtraction of b from a.

  Args:
    a: a filter.
    b: a filter.

  Returns:
    A filter matching with values in a that are not in b.
  """
  if b is True:
    return False
  if a is True:
    return DenyList(b)
  if isinstance(a, DenyList) and isinstance(b, DenyList):
    # (not A) - (not B) == B - A
    return subtract_filters(b.deny, a.deny)
  if isinstance(a, DenyList):
    # (not A) - B == not (A or B)
    return DenyList(union_filters(a.deny, b))
  if isinstance(b, DenyList):
    # A - (not B) == A and B
    return intersect_filters(a, b.deny)
  a = filter_to_set(a)
  b = filter_to_set(b)
  return a - b


def intersect_filters(a: Filter, b: Filter) -> Filter:
  """Take the intersection of two filters (similar to a logical and).

  Args:
    a: a filter.
    b: a filter.

  Returns:
    The intersection of the two input filters. For instance,
    `intersect_filters('f1', ['f1', 'f2']) = {'f1'}`.
  """
  if a is True:
    return b
  if b is True:
    return a
  if isinstance(a, DenyList) and isinstance(b, DenyList):
    # (not A) and (not B) == not (A or B)
    return DenyList(union_filters(b.deny, a.deny))
  if isinstance(b, DenyList):
    # Normalize so that the DenyList (if any) is always `a`.
    b, a = a, b
  if isinstance(a, DenyList):
    # (not A) and B == B - A
    return subtract_filters(b, a.deny)
  a = filter_to_set(a)
  b = filter_to_set(b)
  return a.intersection(b)


def group_collections(
    xs: VariableDict,
    col_filters: Sequence[CollectionFilter]) -> Sequence[MutableVariableDict]:
  """Groups variables by collection filters.

  Iteratively applies the filters in `col_filters` to `xs`, and adds the result
  of applying each filter to the output sequence. Each key in `xs` is only added
  to the output once.

  Args:
    xs: a dictionary of variables, keyed by collections (strings).
    col_filters: a list of collection filters.

  Returns:
    A sequence S with `len(S) == len(col_filters)`. Each `S[i]` is the result of
    applying filter `col_filters[i]` to the remaining keys in `xs`.
  """
  cols = xs.keys()
  groups = []
  for col_filter in col_filters:
    remaining_cols = []
    group = {}
    for col in cols:
      if in_filter(col_filter, col):
        # The identity tree_map rebuilds the containers of the collection so
        # the group does not alias the containers held by `xs`.
        group[col] = jax.tree_util.tree_map(lambda x: x, xs[col])
      else:
        remaining_cols.append(col)
    # Collections already claimed by a filter are not offered to later ones.
    cols = remaining_cols
    groups.append(group)
  return tuple(groups)


class Variable(Generic[T]):
  """A Variable object allows mutable access to a variable in a VariableDict.

  Variables are identified by a collection (e.g., "batch_stats") and a name
  (e.g., "moving_mean"). The value property gives access to the variable's
  content and can be assigned to for mutation.
  """

  def __init__(self, scope: 'Scope', collection: str, name: str):
    """Initializes a variable.

    Args:
      scope: The scope in which the variable is stored.
      collection: The collection of the variable (e.g., "params").
      name: The name of the variable (e.g., "dense").
    """
    self.scope = scope
    self.collection = collection
    self.name = name

  @property
  def value(self) -> T:
    """Returns the value of this Variable."""
    return self.scope.get_variable(self.collection, self.name)

  @value.setter
  def value(self, value: T):
    """Updates the value of this Variable."""
    self.scope.put_variable(self.collection, self.name, value)

  def is_mutable(self) -> bool:
    """Checks if this Variable is mutable."""
    return self.scope.is_mutable_collection(self.collection)


class Scope:
  """A Scope allows easy access to variables and manages RNGS of a neural network layer.

  Scopes are purely functional and encapsulated in
  :class:`flax.linen.module.Module`, so users writing neural network code
  generally do not interact with ``Scopes`` directly.

  See `core design tests
  <https://github.com/google/flax/tree/master/tests/core/design>`_
  for a number of examples using ``Scopes``.
  """

  def __init__(self,
               variables: MutableVariableDict,
               rngs: Optional[Dict[str, PRNGKey]] = None,
               name: Optional[str] = None,
               mutable: CollectionFilter = False,
               parent: Optional['Scope'] = None,
               path: Tuple[str, ...] = ()):
    """Initializes a Scope.

    Args:
      variables: VariableDict to initialize the Scope with.
      rngs: RNGs used in this scope or one of the child scopes.
      name: name of this scope.
      mutable: A CollectionFilter determining which variables are mutable.
      parent: The parent scope.
      path: The path in the variable tree from the root scope to this scope.
    """
    self._variables = variables
    self.parent = parent
    self.name = name
    self.path = path
    self.rngs = rngs if rngs else {}
    self.mutable = mutable

    # A scope without a parent is its own root.
    self.root = parent.root if parent else self
    # Remembered so later mutations can be checked against the trace level
    # that was active when the scope was created.
    self.trace_level = tracers.trace_level(tracers.current_trace())

    # Number of keys handed out so far for each RNG sequence.
    self.rng_counters = {key: 0 for key in self.rngs}
    # Names already taken by child scopes or variables of this scope.
    self.reservations = set()

    self._children = {}

    self._invalid = False

  @property
  def path_text(self) -> str:
    """Returns the path as a human readable string with slashes between parts."""
    return '/' + '/'.join(self.path)

  @property
  def invalid(self) -> bool:
    """Returns true if this scope is invalidated as a result of `Scope.temporary`."""
    return self._invalid

  def _check_valid(self):
    """Raises `errors.InvalidScopeError` if this scope has been invalidated."""
    if self._invalid:
      raise errors.InvalidScopeError(self.name)

  @contextlib.contextmanager
  def temporary(self):
    """Returns a context manager that will invalidate this Scope when leaving the context."""
    try:
      yield self
    finally:
      self.invalidate()

  def invalidate(self):
    """Invalidates the Scope."""
    self._invalid = True

  def mutable_variables(self) -> VariableDict:
    """Returns an immutable copy of the mutable variables belonging to this Scope."""
    self._populate_collections()
    xs = {k: v for k, v in self._variables.items()
          if in_filter(self.mutable, k)}
    return freeze(xs)

  def variables(self) -> VariableDict:
    """Returns an immutable copy of the variables belonging to this Scope."""
    self._populate_collections()
    return freeze(self._variables)

  def _validate_trace_level(self):
    """Checks the trace level recorded at construction via `tracers`."""
    tracers.check_trace_level(self.trace_level)

  def rewound(self, rewind_rngs: bool = False) -> 'Scope':
    """Returns a rewound version of this Scope.

    Args:
      rewind_rngs: if true, reset the RNG counter of this scope.

    Returns:
      A rewound version of this scope, which means reservations and children are
      emptied, and the rng counter is optionally rewound.
    """
    self._check_valid()
    # `path` must be forwarded explicitly; otherwise the rewound scope (and any
    # child pushed from it) would report an empty path in error messages.
    scope = Scope(self._variables, self.rngs, self.name, self.mutable,
                  self.parent, path=self.path)
    if not rewind_rngs:
      # Share (not copy) the counters so RNG use keeps advancing the original.
      scope.rng_counters = self.rng_counters
    return scope

  def reserve(self, name: str):
    """Reserves a name for a child Scope or Variable.

    Args:
      name: the name to reserve.

    Raises:
      TypeError: if `name` is not a string.
      ValueError: if `name` has already been reserved in this scope.
    """
    if not isinstance(name, str):
      raise TypeError(f'The type of scope "{name}" should be string but '
                      f'it is {type(name)}')
    if name in self.reservations:
      raise ValueError(f'Duplicate use of scope name: "{name}"')
    self.reservations.add(name)

  def default_name(self, prefix: str) -> str:
    """Generates an unreserved name with the given prefix.

    Args:
      prefix: prefix to use for generating an unreserved name.

    Returns:
      The generated name.
    """
    i = 0
    while True:
      name = f'{prefix}{i}'
      if name not in self.reservations:
        return name
      i += 1

  def push(self,
           name: Optional[str] = None,
           prefix: str = '',
           reuse: bool = False) -> 'Scope':
    """Creates a child Scope.

    Args:
      name: optional name of the child.
      prefix: prefix used for generating the name if `name` is `None`.
      reuse: if True will return a pre-existing child scope with the given name
        instead of throwing an error.

    Returns:
      The child scope.
    """
    self._check_valid()
    self._validate_trace_level()
    if name is None:
      name = self.default_name(prefix)
    if reuse and name in self._children:
      return self._children[name]
    self.reserve(name)
    # Each child derives its own RNGs by folding its name into the parent's.
    rngs = {key: _fold_in_str(rng, name) for key, rng in self.rngs.items()}
    # The child starts with no variables; collections are resolved lazily from
    # the parent by `_collection` / `_mutable_collection`.
    scope = Scope({},
                  name=name,
                  rngs=rngs,
                  parent=self,
                  path=self.path + (name,))
    self._children[name] = scope
    return scope

  def child(self,
            fn: Callable[..., Any],
            name: Optional[str] = None,
            prefix: Optional[str] = None,
            named_call: bool = True,
            **partial_kwargs) -> Callable[..., Any]:
    """Partially applies a child scope to fn.

    When calling the returned function multiple times variables will be reused.

    Args:
      fn: the function to partially apply the child Scope to.
      name: optional name of the child.
      prefix: prefix used for generating name if it is `None`.
      named_call: if true, `fn` will be wrapped with `lift.named_call`. The XLA
        profiler will use this to name tag the computation.
      **partial_kwargs: additional kwargs partially applied to `fn`.

    Returns:
      The function with a partially applied scope.
    """
    if name is None:
      if prefix is None:
        prefix = fn.__name__ + '_' if hasattr(fn, '__name__') else ''
      name = self.default_name(prefix)
    scope = self.push(name)
    if named_call:
      # We import named_call at runtime to avoid a circular import issue.
      from . import lift  # type: ignore
      fn = lift.named_call(fn, name)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      # Call-time kwargs take precedence over the partially applied ones.
      kwargs = dict(partial_kwargs, **kwargs)
      return fn(scope.rewound(), *args, **kwargs)

    return wrapper

  def is_mutable_collection(self, col: str) -> bool:
    """Returns true if the collection `col` is mutable."""
    # Mutability is decided by the root scope's filter, not this scope's own.
    return in_filter(self.root.mutable, col)

  def _mutable_collection(self, col: str) -> MutableCollection:
    """Returns the collection `col` as a mutable object.

    If the collection is not present in this scope yet it is created, and for
    a non-root scope it is also registered under this scope's name in the
    parent's collection so both refer to the same dict.
    """
    assert self.is_mutable_collection(col), f'Collection {col} is not mutable'
    if col not in self._variables:
      if self.parent:
        parent_col = self.parent._mutable_collection(col)
        if self.name not in parent_col:
          parent_col[self.name] = {}
        self._variables[col] = parent_col[self.name]
      else:
        self._variables[col] = {}
    return self._variables[col]

  def _collection(self, col: str) -> Collection:
    """Returns a collection of variables of collection `col`.

    The collection is looked up through the parent when this scope has not
    cached it yet. An empty `FrozenDict` is returned (and nothing is cached)
    when the collection does not exist for this scope.
    """
    if col not in self._variables:
      if self.parent:
        parent_col = self.parent._collection(col)
        if self.name not in parent_col:
          return FrozenDict()
        self._variables[col] = parent_col[self.name]
      else:
        return FrozenDict()
    return self._variables[col]

  def has_rng(self, name: str) -> bool:
    """Returns true if a PRNGSequence with name `name` exists."""
    return name in self.rngs

  def make_rng(self, name: str) -> PRNGKey:
    """Generates a PRNGKey from a PRNGSequence with name `name`.

    Args:
      name: the name of the PRNG sequence.

    Returns:
      A new key obtained by folding the incremented sequence counter into the
      sequence's base key.

    Raises:
      errors.InvalidRngError: if no PRNG sequence named `name` exists.
    """
    if not self.has_rng(name):
      raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
    self._check_valid()
    self._validate_trace_level()
    self.rng_counters[name] += 1
    return random.fold_in(self.rngs[name], self.rng_counters[name])

  def get_variable(self, col: str, name: str, default: T = None) -> T:
    """Retrieves the value of a Variable.

    Args:
      col: the variable collection.
      name: the name of the variable.
      default: the default value to return if the variable does not exist in
        this scope.

    Returns:
      The value of the input variable, or the default value if the variable
      doesn't exist in this scope.
    """
    variables = self._collection(col)
    if name in variables:
      return variables[name]
    else:
      return default

  def has_variable(self, col: str, name: str) -> bool:
    """Returns true if the given variable exists in this scope.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
    """
    variables = self._collection(col)
    return name in variables

  def put_variable(self, col: str, name: str, value: Any):
    """Updates the value of the given variable if it is mutable, or an error otherwise.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
      value: the new value of the given variable.

    Raises:
      errors.ModifyScopeVariableError: if the collection `col` is not mutable.
    """
    self._check_valid()
    self._validate_trace_level()
    if not self.is_mutable_collection(col):
      raise errors.ModifyScopeVariableError(col, name, self.path_text)
    variables = self._mutable_collection(col)
    variables[name] = value

  def variable(self, col: str, name: str, init_fn: Callable[..., T],
               *init_args) -> Variable[T]:
    """Creates a variable if it doesn't exist yet in this scope and returns it.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
      init_fn: a function called as `init_fn(*init_args)` to produce the
        initial value. Unlike `param`, no PRNGKey is passed automatically.
      *init_args: the arguments to evaluate init_fn on lazily.

    Returns:
      The variable.

    Raises:
      errors.ScopeVariableNotFoundError: if the variable does not exist and the
        collection `col` is not mutable.
    """
    self.reserve(name)
    if not self.has_variable(col, name):
      if not self.is_mutable_collection(col):
        raise errors.ScopeVariableNotFoundError(name, col, self.path_text)
      init_value = init_fn(*init_args)
      self.put_variable(col, name, init_value)
    return Variable(self, col, name)

  def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
    """Creates a parameter if it doesn't exist yet in this scope and returns it.

    If the parameter exists already, the existing value is simply returned.

    Args:
      name: the name of the parameter.
      init_fn: a function taking a PRNGKey plus any other number of positional
        arguments.
      *init_args: the arguments to evaluate init_fn on lazily.

    Returns:
      The parameters.

    Raises:
      errors.ScopeParamShapeError: if the parameter exists but a leaf's shape
        differs from the shape `init_fn` would produce.
      errors.ScopeParamNotFoundError: if the parameter does not exist and the
        'params' collection is not mutable.
    """
    self.reserve(name)
    if self.has_variable('params', name):
      abs_rng = jax.ShapeDtypeStruct((2,), jnp.uint32)
      value = self.get_variable('params', name)
      # Validate that the shape of the init_fn output is the same as the shape
      # of the existing parameter. This is to make sure that the hparams set up
      # in a Flax Module match the shapes coming in during apply, and if not,
      # catch it with an error message.
      # NOTE: We could consider moving this to `self.`
      abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng)
      abs_value_flat = jax.tree_util.tree_leaves(abs_value)
      value_flat = jax.tree_util.tree_leaves(value)
      for val, abs_val in zip(value_flat, abs_value_flat):
        # NOTE: We could check dtype consistency here as well but its
        # usefulness is less obvious. We might intentionally change the dtype
        # for inference to a half float type for example.
        if jnp.shape(val) != jnp.shape(abs_val):
          raise errors.ScopeParamShapeError(name, self.path_text,
                                            jnp.shape(val), jnp.shape(abs_val))
    else:
      if not self.is_mutable_collection('params'):
        raise errors.ScopeParamNotFoundError(name, self.path_text)
      value = init_fn(self.make_rng('params'), *init_args)
      self.put_variable('params', name, value)

    return value

  def _populate_collections(self):
    """Resolves every collection known to the root scope into this scope."""
    collections = self.root._variables.keys()
    for col in collections:
      self._collection(col)


def _unfreeze_variables(variables, mutable):
  """Returns a new dict with mutable collections unfrozen and the rest frozen.

  Args:
    variables: a variable dict keyed by collection name.
    mutable: the filter determining which collections are unfrozen.
  """
  new_variables = {}
  for key, value in variables.items():
    if in_filter(mutable, key):
      new_variables[key] = unfreeze(value)
    else:
      new_variables[key] = freeze(value)
  return new_variables


def bind(variables: VariableDict,
         rngs: Optional[RNGSequences] = None,
         mutable: CollectionFilter = False):
  """Bind variables and rngs to a new ``Scope``.

  bind provides a ``Scope`` instance without transforming a function
  with ``apply``. This is particularly useful for debugging and
  interactive use cases like notebooks where a function would limit
  the ability to split up code into different cells.

  A ``Scope`` instance is a stateful object. Note that idiomatic JAX is functional
  and therefore a ``Scope`` does not mix well with vanilla JAX APIs. Therefore,
  we recommend using ``apply`` when code should be reusable and compatible
  across the JAX software ecosystem.

  Args:
    variables: the variable dict to bind.
    rngs: optional dictionary mapping PRNG sequence names to PRNG keys.
    mutable: the filter determining which variable collections are mutable.

  Returns:
    A new root ``Scope``.

  Raises:
    errors.ApplyScopeInvalidVariablesError: if `variables` is not a valid
      variable dict.
    errors.InvalidRngError: if `rngs` is given but is not a dictionary mapping
      strings to PRNG keys.
  """
  if not _is_valid_variables(variables):
    raise errors.ApplyScopeInvalidVariablesError()
  if rngs is not None and not _is_valid_rngs(rngs):
    raise errors.InvalidRngError(
        'rngs should be a dictionary mapping strings to `jax.PRNGKey`.')
  new_variables = _unfreeze_variables(variables, mutable)
  return Scope(new_variables, rngs=rngs, mutable=mutable)


def apply(fn: Callable[..., Any],
          mutable: CollectionFilter = False) -> Callable[..., Any]:
  """Functionalize a `Scope` function.

  Args:
    fn: a function taking a `Scope` as its first argument.
    mutable: the filter determining which variable collections are mutable.

  Returns:
    `fn` with the scope partially applied. The returned function takes the
    variable dict as its first argument and an optional `rngs` keyword
    argument. It returns `fn`'s output, or a `(output, mutated_variables)`
    tuple when `mutable` is not `False`.
  """

  @functools.wraps(fn)
  def wrapper(variables: VariableDict,
              *args,
              rngs: Optional[RNGSequences] = None,
              **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
    # The root scope is invalidated on exit so it cannot be used afterwards.
    with bind(variables, rngs=rngs, mutable=mutable).temporary() as root:
      y = fn(root, *args, **kwargs)
    if mutable is not False:
      return y, root.mutable_variables()
    else:
      return y

  return wrapper


def init(fn: Callable[..., Any],
         mutable: CollectionFilter = True) -> Callable[..., Any]:
  """Functionalize a `Scope` function for initialization.

  Args:
    fn: a function taking a `Scope` as its first argument.
    mutable: the filter determining which variable collections are mutable.

  Returns:
    `fn` with the scope partially applied. The returned function takes either
    a single PRNG key (used for the 'params' sequence) or a dictionary of PRNG
    keys as its first argument, and raises `ValueError` for anything else.
  """

  @functools.wraps(fn)
  def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]:
    if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs):
      raise ValueError('First argument passed to an init function should be a '
                       '`jax.PRNGKey` or a dictionary mapping strings to '
                       '`jax.PRNGKey`.')
    if not isinstance(rngs, dict):
      # A bare key is shorthand for the 'params' PRNG sequence.
      rngs = {'params': rngs}
    return apply(fn, mutable=mutable)({}, *args, rngs=rngs, **kwargs)

  return wrapper


def _is_valid_collection(col: VariableDict) -> bool:
  """Returns true if `col` is a dict or FrozenDict whose keys are all strings."""
  if not isinstance(col, (FrozenDict, dict)):
    return False
  for name in col.keys():
    # Any value can be stored in a collection so only keys can be verified.
    if not isinstance(name, str):
      return False
  return True


def _is_valid_variables(variables: VariableDict) -> bool:
  """Checks whether the given variable dict is valid.

  Args:
    variables: A variable dict.

  Returns:
    True if `variables` is a valid variable dict.
  """
  for name, col in variables.items():
    if not isinstance(name, str):
      return False
    if not _is_valid_collection(col):
      return False
  return True


def _is_valid_rng(rng: Array) -> bool:
  """Returns true if `rng` is a `jnp.ndarray` of shape (2,) and dtype uint32."""
  if not isinstance(rng, jnp.ndarray):
    return False
  if rng.shape != (2,) or rng.dtype != jnp.uint32:
    return False
  return True


def _is_valid_rngs(rngs: RNGSequences) -> bool:
  """Returns true if `rngs` is a dict mapping strings to valid PRNG keys."""
  if not isinstance(rngs, dict):
    return False
  for key, val in rngs.items():
    if not isinstance(key, str):
      return False
    if not _is_valid_rng(val):
      return False
  return True