1# Copyright 2021 The Flax Authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Flax functional core: Scopes."""
16
17import contextlib
18import functools
19import hashlib
20import dataclasses
21from typing import Any, Callable, Container, Dict, Generic, Iterable, Mapping, Optional, Sequence, Set, Tuple, TypeVar, Union
22
23from . import tracers
24from flax import errors
25from .frozen_dict import freeze
26from .frozen_dict import FrozenDict
27from .frozen_dict import unfreeze
28import jax
29from jax import numpy as jnp
30from jax import random
31
# Generic value type carried by `Variable` and returned by `Scope.param`.
T = TypeVar('T')

# Loose aliases used only in annotations; nothing checks them at runtime.
PRNGKey = Array = Any

# Maps an rng sequence name (e.g. 'params') to the PRNG key of that sequence.
RNGSequences = Dict[str, PRNGKey]


# A filter selects a set of names (variable collections or rng sequences):
#   * bool: True selects every name, False selects none;
#   * str: selects exactly that name;
#   * Container[str]: selects every name it contains;
#   * DenyList: selects every name NOT selected by its `deny` filter.
# `in_filter` evaluates a filter against a single name.
Filter = Union[bool, str, Container[str], 'DenyList']
41
@dataclasses.dataclass(eq=True, frozen=True)
class DenyList:
  """A filter selecting every name that its ``deny`` filter does not select.

  Instances are frozen and compare by value, so two ``DenyList`` objects
  wrapping equal ``deny`` filters are equal (and hashable whenever ``deny``
  itself is hashable).
  """
  # The filter whose selected names are excluded.
  deny: Filter
45
46
# Collection filters and rng-sequence filters share the `Filter` representation.
CollectionFilter = PRNGSequenceFilter = Filter

# A collection maps variable names (and child scope names) to their values.
Collection = Mapping[str, Any]
MutableCollection = Dict[str, Any]

# A variable dict maps collection names (e.g. "params", "batch_stats") to
# collections, in read-only, frozen, and mutable flavours.
VariableDict = Mapping[str, Collection]
FrozenVariableDict = FrozenDict[str, Collection]
MutableVariableDict = Dict[str, MutableCollection]
56
57
def _fold_in_str(rng: PRNGKey, data: str) -> PRNGKey:
  """Derives a new PRNG key by folding a string into ``rng``.

  The string is UTF-8 encoded and hashed with SHA-1. The first four bytes of
  the digest, read as a big-endian unsigned integer, are folded into ``rng``
  with ``jax.random.fold_in``. Keys for different names can therefore be
  derived independently of each other (and in parallel) instead of by
  sequentially splitting a key.

  Args:
    rng: the PRNG key to fold the string into.
    data: the string to fold in.

  Returns:
    The derived PRNG key.
  """
  digest = hashlib.sha1(data.encode('utf-8')).digest()
  # Keep only 32 bits so the value fits in the uint32 handed to `fold_in`.
  folded_value = int.from_bytes(digest[:4], byteorder='big')
  return random.fold_in(rng, jnp.uint32(folded_value))
76
77
def in_filter(filter_like: Filter, col: str) -> bool:
  """Returns whether the filter ``filter_like`` selects the name ``col``.

  The same filter representation is used both for variable collections and
  for rng sequences.

  Args:
    filter_like: a filter: a bool (select everything / nothing), a single
      name, a container of names, or a ``DenyList`` (select every name its
      ``deny`` filter does not select).
    col: the name to test, for instance "params" or "batch_stats".

  Returns:
    True if ``filter_like`` selects ``col``.

  Raises:
    errors.InvalidFilterError: if ``filter_like`` is not a recognized filter.
  """
  # Strings are Containers too (substring membership), so the single-name
  # case has to be recognized before the generic container case.
  if isinstance(filter_like, str):
    selected = filter_like == col
  elif isinstance(filter_like, Container):
    selected = col in filter_like
  elif isinstance(filter_like, bool):
    selected = filter_like
  elif isinstance(filter_like, DenyList):
    selected = not in_filter(filter_like.deny, col)
  else:
    raise errors.InvalidFilterError(filter_like)
  return selected
102
103
def filter_to_set(x: Filter) -> Set[str]:
  """Enumerates the names selected by a finite filter.

  ``True`` and ``DenyList`` filters select an unbounded set of names and are
  rejected by an assertion.

  Args:
    x: a finite filter: ``False``, a single name, or an iterable of names.

  Returns:
    A new set holding every name selected by ``x``.

  Raises:
    errors.InvalidFilterError: if ``x`` is not a recognized filter.
  """
  assert not (x is True or isinstance(x, DenyList)), 'Infinite set'
  if x is False:
    names = ()
  elif isinstance(x, str):
    # A lone string names one collection; do not iterate its characters.
    names = (x,)
  elif isinstance(x, Iterable):
    names = x
  else:
    raise errors.InvalidFilterError(x)
  return set(names)
121
122
def union_filters(a: Filter, b: Filter) -> Filter:
  """Returns a filter selecting every name selected by ``a`` or by ``b``.

  ``DenyList`` operands are combined with set identities, e.g.
  ``not A or not B == not (A and B)``.

  Args:
    a: a filter.
    b: a filter.

  Returns:
    The union of the two filters. For instance,
    `union_filters('f1', ['f2']) = {'f1', 'f2'}`.
  """
  if a is True or b is True:
    return True
  a_denies = isinstance(a, DenyList)
  b_denies = isinstance(b, DenyList)
  if a_denies and b_denies:
    # (not A) | (not B) == not (A & B)
    return DenyList(intersect_filters(a.deny, b.deny))
  if a_denies or b_denies:
    deny_list, other = (a, b) if a_denies else (b, a)
    # (not A) | B == not (A - B)
    return DenyList(subtract_filters(deny_list.deny, other))
  return filter_to_set(a).union(filter_to_set(b))
146
147
def subtract_filters(a: Filter, b: Filter) -> Filter:
  """Returns a filter selecting the names selected by ``a`` but not by ``b``.

  Args:
    a: a filter.
    b: a filter.

  Returns:
    A filter matching the values in ``a`` that are not in ``b``.
  """
  if b is True:
    return False
  if a is True:
    return DenyList(b)
  deny_flags = (isinstance(a, DenyList), isinstance(b, DenyList))
  if deny_flags == (True, True):
    # (not A) - (not B) == B - A
    return subtract_filters(b.deny, a.deny)
  if deny_flags == (True, False):
    # (not A) - B == not (A | B)
    return DenyList(union_filters(a.deny, b))
  if deny_flags == (False, True):
    # A - (not B) == A & B
    return intersect_filters(a, b.deny)
  return filter_to_set(a) - filter_to_set(b)
171
172
def intersect_filters(a: Filter, b: Filter) -> Filter:
  """Returns a filter selecting the names selected by both ``a`` and ``b``.

  Args:
    a: a filter.
    b: a filter.

  Returns:
    The intersection of the two filters. For instance,
    `intersect_filters('f1', ['f1', 'f2']) = {'f1'}`.
  """
  if a is True:
    return b
  if b is True:
    return a
  a_denies = isinstance(a, DenyList)
  b_denies = isinstance(b, DenyList)
  if a_denies and b_denies:
    # (not A) & (not B) == not (A | B)
    return DenyList(union_filters(b.deny, a.deny))
  if a_denies or b_denies:
    allowed, deny_list = (b, a) if a_denies else (a, b)
    # A & (not B) == A - B
    return subtract_filters(allowed, deny_list.deny)
  return filter_to_set(a).intersection(filter_to_set(b))
197
198
def group_collections(
    xs: VariableDict,
    col_filters: Sequence[CollectionFilter]) -> Sequence[MutableVariableDict]:
  """Groups variables by collection filters.

  Iteratively applies the filters in `col_filters` to `xs`, and adds the result
  of applying each filter to the output sequence. Each key in `xs` is only added
  to the output once: a collection claimed by an earlier filter is not offered
  to later filters.

  Args:
    xs: a dictionary of variables, keyed by collections (strings).
    col_filters: a list of collection filters.

  Returns:
    A tuple S with `len(S) == len(col_filters)`. Each `S[i]` is a dict holding
    the collections of `xs` selected by `col_filters[i]` among the keys not
    already claimed by `col_filters[:i]`. Collections selected by no filter
    do not appear in the output.
  """
  cols = xs.keys()
  groups = []
  for col_filter in col_filters:
    remaining_cols = []
    group = {}
    for col in cols:
      if in_filter(col_filter, col):
        # An identity tree_map rebuilds the container structure (leaves are
        # reused), so the group does not share dict nodes with `xs`.
        # `jax.tree_util.tree_map` is used instead of the deprecated
        # `jax.tree_map` alias, which newer JAX releases no longer provide.
        group[col] = jax.tree_util.tree_map(lambda x: x, xs[col])
      else:
        remaining_cols.append(col)
    cols = remaining_cols
    groups.append(group)
  return tuple(groups)
229
230
class Variable(Generic[T]):
  """Mutable handle on one variable stored in a ``Scope``.

  A variable is addressed by its collection (e.g. "batch_stats") and its name
  (e.g. "moving_mean"). Reading ``value`` fetches the current content from the
  owning scope; assigning to ``value`` writes it back through the scope, which
  rejects writes to collections that are not mutable.
  """

  def __init__(self, scope: 'Scope', collection: str, name: str):
    """Creates a handle; the stored variable itself is not accessed.

    Args:
      scope: the scope that stores the variable.
      collection: the variable's collection (e.g. "params").
      name: the variable's name (e.g. "dense").
    """
    self.scope, self.collection, self.name = scope, collection, name

  @property
  def value(self) -> T:
    """The variable's current content, as stored in the owning scope."""
    owner = self.scope
    return owner.get_variable(self.collection, self.name)

  @value.setter
  def value(self, value: T):
    """Writes ``value`` into the owning scope."""
    owner = self.scope
    owner.put_variable(self.collection, self.name, value)

  def is_mutable(self) -> bool:
    """Returns whether the scope permits writes to this variable's collection."""
    collection = self.collection
    return self.scope.is_mutable_collection(collection)
264
265
class Scope:
  """A Scope allows easy access to variables and manages RNGs of a neural network layer.

  Scopes are purely functional and encapsulated in
  :class:`flax.linen.module.Module`, so users writing neural network code
  usually do not interact with ``Scopes`` directly.

  See `core design tests
  <https://github.com/google/flax/tree/master/tests/core/design>`_
  for a number of examples using ``Scopes``.
  """

  def __init__(self,
               variables: MutableVariableDict,
               rngs: Optional[Dict[str, PRNGKey]] = None,
               name: Optional[str] = None,
               mutable: CollectionFilter = False,
               parent: Optional['Scope'] = None,
               path: Tuple[str, ...] = ()):
    """Initializes a Scope.

    Args:
      variables: VariableDict to initialize the Scope with. It is used as the
        scope's variable storage directly (it is not copied).
      rngs: RNGs used in this scope or one of the child scopes.
      name: name of this scope.
      mutable: A CollectionFilter determining which variables are mutable.
        `is_mutable_collection` consults the root scope's filter.
      parent: The parent scope, or None for a root scope.
      path: The path in the variable tree from the root scope to this scope.
    """
    self._variables = variables
    self.parent = parent
    self.name = name
    self.path = path
    self.rngs = rngs if rngs else {}
    self.mutable = mutable

    self.root = parent.root if parent else self
    # Trace level at construction; checked by `_validate_trace_level` before
    # child creation, rng draws and variable writes.
    self.trace_level = tracers.trace_level(tracers.current_trace())

    # Number of keys drawn so far from each rng sequence (see `make_rng`).
    self.rng_counters = {key: 0 for key in self.rngs}
    # Names of child scopes and variables already claimed in this scope.
    self.reservations = set()

    self._children = {}

    self._invalid = False

  @property
  def path_text(self) -> str:
    """Returns the path as a human readable string with slashes between parts."""
    return '/' + '/'.join(self.path)

  @property
  def invalid(self) -> bool:
    """Returns true if this scope is invalidated as a result of `Scope.temporary`."""
    return self._invalid

  def _check_valid(self):
    """Raises `errors.InvalidScopeError` if this scope has been invalidated."""
    if self._invalid:
      raise errors.InvalidScopeError(self.name)

  @contextlib.contextmanager
  def temporary(self):
    """Returns a context manager that will invalidate this Scope when leaving the context.

    The scope is invalidated even when the body of the ``with`` block raises.
    """
    try:
      yield self
    finally:
      self.invalidate()

  def invalidate(self):
    """Invalidates the Scope."""
    self._invalid = True

  def mutable_variables(self) -> VariableDict:
    """Returns an immutable copy of the mutable variables belonging to this Scope.

    Mutability is decided by this scope's own `mutable` filter. Child scopes
    created by `push` are constructed without one (i.e. with False).
    """
    self._populate_collections()
    xs = {k: v for k, v in self._variables.items()
          if in_filter(self.mutable, k)}
    return freeze(xs)

  def variables(self) -> VariableDict:
    """Returns an immutable copy of the variables belonging to this Scope."""
    self._populate_collections()
    return freeze(self._variables)

  def _validate_trace_level(self):
    """Checks the trace level recorded at construction via `tracers.check_trace_level`."""
    tracers.check_trace_level(self.trace_level)

  def rewound(self, rewind_rngs: bool = False) -> 'Scope':
    """Returns a rewound version of this Scope.

    Args:
      rewind_rngs: if true, reset the RNG counter of this scope.

    Returns:
      A rewound version of this scope, which means reservations and children are
      emptied, and the rng counter is optionally rewound. Variables, rngs, name,
      mutability, parent and path are shared with this scope.
    """
    self._check_valid()
    # Pass `path` along so errors raised from the rewound scope (e.g. via
    # `path_text`) report the correct location instead of the root.
    scope = Scope(self._variables, self.rngs, self.name, self.mutable,
                  self.parent, path=self.path)
    if not rewind_rngs:
      # Share the counter dict so keys drawn from the rewound scope keep
      # advancing the same sequences instead of repeating earlier keys.
      scope.rng_counters = self.rng_counters
    return scope

  def reserve(self, name: str):
    """Reserves a name for a child Scope or Variable.

    Args:
      name: the name to reserve.

    Raises:
      TypeError: if `name` is not a string.
      ValueError: if `name` has already been reserved in this scope.
    """
    if not isinstance(name, str):
      raise TypeError(f'The type of scope "{name}" should be string but '
                      f'it is {type(name)}')
    if name in self.reservations:
      raise ValueError(f'Duplicate use of scope name: "{name}"')
    self.reservations.add(name)

  def default_name(self, prefix: str) -> str:
    """Generates an unreserved name with the given prefix.

    Candidates are tried in the order `prefix0`, `prefix1`, ...; the returned
    name is not reserved by this call.

    Args:
      prefix: prefix to use for generating an unreserved name.

    Returns:
      The generated name.
    """
    i = 0
    while True:
      name = f'{prefix}{i}'
      if name not in self.reservations:
        return name
      i += 1

  def push(self,
           name: Optional[str] = None,
           prefix: str = '',
           reuse=False) -> 'Scope':
    """Creates a child Scope.

    Args:
      name: optional name of the child.
      prefix: prefix used for generating the name if `name` is `None`.
      reuse: if True will return a pre-existing child scope with the given name
        instead of throwing an error.

    Returns:
      The child scope.
    """
    self._check_valid()
    self._validate_trace_level()
    if name is None:
      name = self.default_name(prefix)
    if reuse and name in self._children:
      return self._children[name]
    self.reserve(name)
    # Each child gets its own keys, derived deterministically from its name.
    rngs = {key: _fold_in_str(rng, name) for key, rng in self.rngs.items()}
    scope = Scope({},
                  name=name,
                  rngs=rngs,
                  parent=self,
                  path=self.path + (name,))
    self._children[name] = scope
    return scope

  def child(self,
            fn: Callable[..., Any],
            name: Optional[str] = None,
            prefix: Optional[str] = None,
            named_call: bool = True,
            **partial_kwargs) -> Callable[..., Any]:
    """Partially applies a child scope to fn.

    When calling the returned function multiple times variables will be reused.

    Args:
      fn: the function to partially apply the child Scope to.
      name: optional name of the child.
      prefix: prefix used for generating name if it is `None`.
      named_call: if true, `fn` will be wrapped with `lift.named_call`. The XLA
        profiler will use this to name tag the computation.
      **partial_kwargs: additional kwargs partially applied to `fn`.

    Returns:
      The function with a partially applied scope.
    """
    if name is None:
      if prefix is None:
        prefix = fn.__name__ + '_' if hasattr(fn, '__name__') else ''
      name = self.default_name(prefix)
    scope = self.push(name)
    if named_call:
      # We import named_call at runtime to avoid a circular import issue.
      from . import lift  # type: ignore
      fn = lift.named_call(fn, name)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      kwargs = dict(partial_kwargs, **kwargs)
      return fn(scope.rewound(), *args, **kwargs)

    return wrapper

  def is_mutable_collection(self, col: str) -> bool:
    """Returns true if the collection `col` is mutable (per the root's filter)."""
    return in_filter(self.root.mutable, col)

  def _mutable_collection(self, col: str) -> MutableCollection:
    """Returns the collection `col` as a mutable object.

    A missing collection is created lazily. For child scopes it is linked into
    the parent's collection under this scope's name, so writes made here are
    visible from the parent (and ultimately the root).
    """
    assert self.is_mutable_collection(col), f'Collection {col} is not mutable'
    if col not in self._variables:
      if self.parent:
        parent_col = self.parent._mutable_collection(col)
        if self.name not in parent_col:
          parent_col[self.name] = {}
        self._variables[col] = parent_col[self.name]
      else:
        self._variables[col] = {}
    return self._variables[col]

  def _collection(self, col: str) -> Collection:
    """Returns a collection of variables of collection `col`.

    This is a read-only lookup: when the collection does not exist an empty
    FrozenDict is returned and nothing is created in the parent.
    """
    if col not in self._variables:
      if self.parent:
        parent_col = self.parent._collection(col)
        if self.name not in parent_col:
          return FrozenDict()
        self._variables[col] = parent_col[self.name]
      else:
        return FrozenDict()
    return self._variables[col]

  def has_rng(self, name: str) -> bool:
    """Returns true if a PRNGSequence with name `name` exists."""
    return name in self.rngs

  def make_rng(self, name: str) -> PRNGKey:
    """Generates a PRNGKey from a PRNGSequence with name `name`.

    Every call advances the sequence's counter, so successive calls return
    different keys.

    Raises:
      errors.InvalidRngError: if no sequence named `name` exists.
    """
    if not self.has_rng(name):
      raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"')
    self._check_valid()
    self._validate_trace_level()
    self.rng_counters[name] += 1
    return random.fold_in(self.rngs[name], self.rng_counters[name])

  def get_variable(self, col: str, name: str, default: T = None) -> T:
    """Retrieves the value of a Variable.

    Args:
      col: the variable collection.
      name: the name of the variable.
      default: the default value to return if the variable does not exist in
        this scope.

    Returns:
      The value of the input variable, or the default value if the variable
      doesn't exist in this scope.
    """
    variables = self._collection(col)
    if name in variables:
      return variables[name]
    else:
      return default

  def has_variable(self, col: str, name: str) -> bool:
    """Returns true if the given variable exists in this scope.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
    """
    variables = self._collection(col)
    return name in variables

  def put_variable(self, col: str, name: str, value: Any):
    """Updates the value of the given variable if it is mutable, or an error otherwise.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
      value: the new value of the given variable.

    Raises:
      errors.ModifyScopeVariableError: if `col` is not mutable.
    """
    self._check_valid()
    self._validate_trace_level()
    if not self.is_mutable_collection(col):
      raise errors.ModifyScopeVariableError(col, name, self.path_text)
    variables = self._mutable_collection(col)
    variables[name] = value

  def variable(self, col: str, name: str, init_fn: Callable[..., T],
               *init_args) -> Variable[T]:
    """Creates a variable if it doesn't exist yet in this scope and returns it.

    Args:
      col: the collection of the variable.
      name: the name of the variable.
      init_fn: a function called as `init_fn(*init_args)` to compute the
        initial value. Unlike `param`, no PRNGKey is passed to it.
      *init_args: the arguments to evaluate init_fn on lazily.

    Returns:
      The variable.

    Raises:
      errors.ScopeVariableNotFoundError: if the variable does not exist and
        `col` is not mutable.
    """
    self.reserve(name)
    if not self.has_variable(col, name):
      if not self.is_mutable_collection(col):
        raise errors.ScopeVariableNotFoundError(name, col, self.path_text)
      init_value = init_fn(*init_args)
      self.put_variable(col, name, init_value)
    return Variable(self, col, name)

  def param(self, name: str, init_fn: Callable[..., T], *init_args) -> T:
    """Creates a parameter if it doesn't exist yet in this scope and returns it.

    If the parameter exists already, the existing value is simply returned
    (after checking that its shapes match what `init_fn` would produce).

    Args:
      name: the name of the parameter.
      init_fn: a function taking a PRNGKey plus any other number of positional
        arguments.
      *init_args: the arguments to evaluate init_fn on lazily.

    Returns:
      The value of the parameter.

    Raises:
      errors.ScopeParamShapeError: if an existing parameter's shape differs
        from the shape `init_fn` would produce.
      errors.ScopeParamNotFoundError: if the parameter does not exist and the
        'params' collection is not mutable.
    """
    self.reserve(name)
    if self.has_variable('params', name):
      abs_rng = jax.ShapeDtypeStruct((2,), jnp.uint32)
      value = self.get_variable('params', name)
      # Validate that the shape of the init_fn output is the same as the shape
      # of the existing parameter. This is to make sure that the hparams set up
      # in a Flax Module match the shapes coming in during apply, and if not,
      # catch it with an error message. `jax.eval_shape` evaluates `init_fn`
      # abstractly, so no actual initialization is computed here.
      abs_value = jax.eval_shape(lambda rng: init_fn(rng, *init_args), abs_rng)
      abs_value_flat = jax.tree_util.tree_leaves(abs_value)
      value_flat = jax.tree_util.tree_leaves(value)
      for val, abs_val in zip(value_flat, abs_value_flat):
        # NOTE: We could check dtype consistency here as well but its
        # usefulness is less obvious. We might intentionally change the dtype
        # for inference to a half float type for example.
        if jnp.shape(val) != jnp.shape(abs_val):
          raise errors.ScopeParamShapeError(name, self.path_text,
                                            jnp.shape(val), jnp.shape(abs_val))
    else:
      if not self.is_mutable_collection('params'):
        raise errors.ScopeParamNotFoundError(name, self.path_text)
      value = init_fn(self.make_rng('params'), *init_args)
      self.put_variable('params', name, value)

    return value

  def _populate_collections(self):
    """Loads every collection known at the root into this scope, where present."""
    collections = self.root._variables.keys()
    for col in collections:
      self._collection(col)
621
622
def _unfreeze_variables(variables, mutable):
  """Builds a new dict from ``variables`` with per-collection mutability applied.

  Collections selected by the ``mutable`` filter are passed through
  ``unfreeze``; every other collection is passed through ``freeze``.
  """
  return {
      col: (unfreeze(value) if in_filter(mutable, col) else freeze(value))
      for col, value in variables.items()
  }
631
632
def bind(variables: VariableDict,
         rngs: Optional[RNGSequences] = None,
         mutable: CollectionFilter = False):
  """Binds variables and rngs to a new ``Scope``.

  ``bind`` provides a ``Scope`` instance without transforming a function
  with ``apply``. This is particularly useful for debugging and
  interactive use cases like notebooks, where a function would limit
  the ability to split up code into different cells.

  A ``Scope`` instance is a stateful object. Note that idiomatic JAX is
  functional and therefore a ``Scope`` does not mix well with vanilla JAX
  APIs. Therefore, we recommend using ``apply`` when code should be reusable
  and compatible across the JAX software ecosystem.

  Args:
    variables: the variable dict to bind. Collections selected by ``mutable``
      are passed through ``unfreeze``; all others through ``freeze``.
    rngs: optional dict mapping rng sequence names to ``jax.PRNGKey`` values.
    mutable: filter selecting which collections may be modified.

  Returns:
    A new root ``Scope`` holding the (re-frozen/unfrozen) variables, ``rngs``
    and ``mutable``.

  Raises:
    errors.ApplyScopeInvalidVariablesError: if ``variables`` is not a valid
      variable dict.
    errors.InvalidRngError: if ``rngs`` is given but is not a dict mapping
      strings to PRNG keys.
  """
  if not _is_valid_variables(variables):
    raise errors.ApplyScopeInvalidVariablesError()
  if rngs is not None and not _is_valid_rngs(rngs):
    raise errors.InvalidRngError(
      'rngs should be a dictionary mapping strings to `jax.PRNGKey`.')
  new_variables = _unfreeze_variables(variables, mutable)
  return Scope(new_variables, rngs=rngs, mutable=mutable)
655
656
def apply(fn: Callable[..., Any],
          mutable: CollectionFilter = False) -> Callable[..., Any]:
  """Turns a `Scope` function into a pure function of a variable dict.

  Args:
    fn: a function whose first argument is a `Scope`.
    mutable: the filter selecting which variable collections `fn` may modify.

  Returns:
    A function ``(variables, *args, rngs=None, **kwargs)`` that binds
    ``variables`` and ``rngs`` to a temporary root scope, calls ``fn`` with it
    and returns ``fn``'s output -- paired with the frozen mutable variables
    unless ``mutable`` is False.
  """

  @functools.wraps(fn)
  def wrapper(variables: VariableDict,
              *args,
              rngs: Optional[RNGSequences] = None,
              **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:
    root = bind(variables, rngs=rngs, mutable=mutable)
    # The root scope is invalidated when the block exits (even on error), so
    # later use through its checked methods raises.
    with root.temporary():
      output = fn(root, *args, **kwargs)
    if mutable is False:
      return output
    return output, root.mutable_variables()

  return wrapper
682
683
def init(fn: Callable[..., Any],
         mutable: CollectionFilter = True) -> Callable[..., Any]:
  """Turns a `Scope` function into a pure initialization function.

  Args:
    fn: a function whose first argument is a `Scope`.
    mutable: the filter selecting which variable collections may be created
      or modified; by default every collection is mutable.

  Returns:
    A function ``(rngs, *args, **kwargs)``. ``rngs`` is either a single PRNG
    key (used as the 'params' sequence) or a dict mapping sequence names to
    keys; the remaining arguments are forwarded to ``fn``. Its result is what
    ``apply(fn, mutable=mutable)`` returns when started from empty variables.
  """

  @functools.wraps(fn)
  def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]:
    if not (_is_valid_rng(rngs) or _is_valid_rngs(rngs)):
      raise ValueError('First argument passed to an init function should be a '
                       '`jax.PRNGKey` or a dictionary mapping strings to '
                       '`jax.PRNGKey`.')
    rng_dict = rngs if isinstance(rngs, dict) else {'params': rngs}
    initializer = apply(fn, mutable=mutable)
    # Initialization always starts from an empty variable dict.
    return initializer({}, *args, rngs=rng_dict, **kwargs)

  return wrapper
707
708
def _is_valid_collection(col: VariableDict):
  """Returns True iff ``col`` is a dict or FrozenDict whose keys are all strings.

  Values are not inspected, because a collection may hold arbitrary values.
  """
  if isinstance(col, (FrozenDict, dict)):
    return all(isinstance(name, str) for name in col.keys())
  return False
717
718
def _is_valid_variables(variables: VariableDict) -> bool:
  """Checks whether ``variables`` is a well-formed variable dict.

  Every key must be a string collection name and every value a valid
  collection (see `_is_valid_collection`). Checking stops at the first
  offending entry.

  Args:
    variables: A variable dict.

  Returns:
    True if `variables` is a valid variable dict.
  """
  return all(
      isinstance(name, str) and _is_valid_collection(col)
      for name, col in variables.items())
734
735
def _is_valid_rng(rng: Array):
  """Returns True iff ``rng`` is a JAX array of shape (2,) and dtype uint32."""
  if (isinstance(rng, jnp.ndarray) and rng.shape == (2,)
      and rng.dtype == jnp.uint32):
    return True
  return False
742
743
def _is_valid_rngs(rngs: RNGSequences):
  """Returns True iff ``rngs`` is a dict mapping strings to valid PRNG keys.

  Keys are checked with `isinstance(key, str)` and values with
  `_is_valid_rng`; checking stops at the first offending entry.
  """
  if not isinstance(rngs, dict):
    return False
  return all(
      isinstance(key, str) and _is_valid_rng(val)
      for key, val in rngs.items())
753