1# Copyright: See the LICENSE file.
2
3
4import collections
5import logging
6import warnings
7
8from . import builder, declarations, enums, errors, utils
9
# Logger for this module, registered under the 'factory.generate' name.
logger = logging.getLogger(name='factory.generate')
11
12# Factory metaclasses
13
14
def get_factory_bases(bases):
    """Return the entries of ``bases`` that are BaseFactory subclasses.

    The check is ``issubclass(base, BaseFactory)``, so BaseFactory itself and
    subclasses that do not use FactoryMetaClass are included too. The input
    order is preserved.

    Args:
        bases (list of class): candidate parent classes

    Returns:
        list of class: the factory classes among ``bases``
    """
    return [b for b in bases if issubclass(b, BaseFactory)]
18
19
def resolve_attribute(name, bases, default=None):
    """Return attribute ``name`` from the first of ``bases`` that has it.

    ``bases`` is scanned in the order given. Each candidate is tested with
    ``hasattr``, so an attribute a base merely inherits also counts.

    Args:
        name (str): the attribute to look up
        bases (iterable of class): the classes to search
        default (object): returned when no base has the attribute

    Returns:
        object: the attribute value, or ``default``
    """
    missing = object()
    owner = next((base for base in bases if hasattr(base, name)), missing)
    if owner is missing:
        return default
    return getattr(owner, name)
26
27
class FactoryMetaClass(type):
    """Metaclass of factories: collects declarations, dispatches calls.

    Defining a class with this metaclass builds its ``_meta`` options object,
    and calling the resulting factory class runs its ``Meta.strategy``.
    """

    def __call__(cls, **kwargs):
        """Run the factory's configured strategy instead of instantiating it.

        ``MyFactory(**kwargs)`` behaves like ``build``, ``create`` or ``stub``
        depending on ``Meta.strategy``.

        Returns:
            The object produced by the selected strategy.

        Raises:
            errors.UnknownStrategy: if ``Meta.strategy`` is none of the three
                known strategies.
        """
        strategy = cls._meta.strategy
        if strategy == enums.BUILD_STRATEGY:
            return cls.build(**kwargs)
        if strategy == enums.CREATE_STRATEGY:
            return cls.create(**kwargs)
        if strategy == enums.STUB_STRATEGY:
            return cls.stub(**kwargs)
        raise errors.UnknownStrategy('Unknown Meta.strategy: {}'.format(
            strategy))

    def __new__(mcs, class_name, bases, attrs):
        """Create a factory class and attach its resolved options.

        ``Meta`` and ``Params`` are removed from the class body. A fresh
        options object (``_options_class`` from the bases, defaulting to
        FactoryOptions) is stored as ``_meta``. After the class object exists,
        the options are filled from it.

        Args:
            class_name (str): name of the class being defined
            bases (list of class): its parent classes
            attrs (str => obj dict): the class body namespace

        Returns:
            The newly created class.
        """
        factory_parents = get_factory_bases(bases)
        base_factory = factory_parents[0] if factory_parents else None

        meta_definition = attrs.pop('Meta', None)
        params_definition = attrs.pop('Params', None)

        base_meta = resolve_attribute('_meta', bases)
        options_class = resolve_attribute('_options_class', bases, FactoryOptions)

        options = options_class()
        attrs['_meta'] = options

        new_class = super().__new__(mcs, class_name, bases, attrs)

        # Options need the class object itself (to scan its attributes), so
        # they are completed only after type.__new__ has run.
        options.contribute_to_class(
            new_class,
            meta=meta_definition,
            base_meta=base_meta,
            base_factory=base_factory,
            params=params_definition,
        )
        return new_class

    def __str__(cls):
        if not cls._meta.abstract:
            return f'<{cls.__name__} for {cls._meta.model}>'
        return f'<{cls.__name__} (abstract)>'
95
96
class BaseMeta:
    """Default ``class Meta`` that ``Factory.Meta`` inherits from.

    The factory is marked abstract, and the create strategy is the default.
    """
    abstract = True
    strategy = enums.CREATE_STRATEGY
100
101
class OptionDefault:
    """Describe one allowed ``class Meta`` option and how it is resolved.

    Attributes:
        name: str, the option name (an attribute of ``class Meta``)
        value: object, the fallback used when nothing else provides a value
        inherit: bool, whether the parent factory's options are consulted
            when the factory's own ``class Meta`` does not set the option
        checker: callable or None, called as ``checker(meta, value)`` on the
            resolved value so invalid options fail at declaration time
    """
    def __init__(self, name, value, inherit=False, checker=None):
        self.name, self.value = name, value
        self.inherit, self.checker = inherit, checker

    def apply(self, meta, base_meta):
        """Resolve this option's value.

        Precedence, highest first: ``meta``, then ``base_meta`` (only when
        ``inherit`` is set), then the default ``value``.

        Returns:
            object: the resolved value (after the checker, if any, accepted it)
        """
        resolved = self.value
        if self.inherit and base_meta is not None:
            resolved = getattr(base_meta, self.name, resolved)
        if meta is not None:
            resolved = getattr(meta, self.name, resolved)

        check = self.checker
        if check is not None:
            check(meta, resolved)

        return resolved

    def __str__(self):
        return (
            f'{self.__class__.__name__}'
            f'({self.name!r}, {self.value!r}, inherit={self.inherit!r})'
        )
135
136
class FactoryOptions:
    """Options and declarations resolved for a single factory class.

    An instance is stored as ``_meta`` on each factory class. It is filled in
    by :meth:`contribute_to_class` once the class object exists. Subclasses
    can extend :meth:`_build_default_options` to accept more
    ``class Meta`` options.
    """

    def __init__(self):
        self.factory = None
        self.base_factory = None
        self.base_declarations = {}
        self.parameters = {}
        self.parameters_dependencies = {}
        self.pre_declarations = builder.DeclarationSet()
        self.post_declarations = builder.DeclarationSet()

        # Sequence counter, created lazily by _initialize_counter(); it may be
        # the very same object as counter_reference's counter.
        self._counter = None
        # The FactoryOptions owning the counter this factory uses (possibly
        # self); computed by contribute_to_class().
        self.counter_reference = None

    @property
    def declarations(self):
        """Return all declarations, with parameters expanded.

        Starts from a copy of ``base_declarations``. Each parameter's
        ``as_declarations()`` output is then merged in, in the order given by
        ``utils.sort_ordered_objects``. Each parameter sees the declarations
        merged so far.
        """
        base_declarations = dict(self.base_declarations)
        for name, param in utils.sort_ordered_objects(self.parameters.items(), getter=lambda item: item[1]):
            base_declarations.update(param.as_declarations(name, base_declarations))
        return base_declarations

    def _build_default_options(self):
        """Provide the default value for all allowed fields.

        Custom FactoryOptions classes should override this method and extend
        the list returned by the parent implementation.

        Returns:
            list of OptionDefault: one entry per allowed ``class Meta`` option.
        """

        def is_model(meta, value):
            # Reject a factory class given as Meta.model.
            if isinstance(value, FactoryMetaClass):
                raise TypeError(
                    "%s is already a %s"
                    % (repr(value), Factory.__name__)
                )

        return [
            OptionDefault('model', None, inherit=True, checker=is_model),
            OptionDefault('abstract', False, inherit=False),
            OptionDefault('strategy', enums.CREATE_STRATEGY, inherit=True),
            OptionDefault('inline_args', (), inherit=True),
            OptionDefault('exclude', (), inherit=True),
            OptionDefault('rename', {}, inherit=True),
        ]

    def _fill_from_meta(self, meta, base_meta):
        """Set every allowed option as an attribute of ``self``.

        Each value is resolved by its OptionDefault from ``meta`` and
        ``base_meta``.

        Raises:
            TypeError: if ``meta`` itself defines public attributes that are
                not allowed options.
        """
        # Only the Meta's own public attributes are checked; vars() does not
        # include attributes it inherits (e.g. from BaseMeta).
        if meta is None:
            meta_attrs = {}
        else:
            meta_attrs = {
                k: v
                for (k, v) in vars(meta).items()
                if not k.startswith('_')
            }

        for option in self._build_default_options():
            # Options become plain attributes; they must not clash with
            # existing FactoryOptions attributes.
            assert not hasattr(self, option.name), "Can't override field %s." % option.name
            value = option.apply(meta, base_meta)
            meta_attrs.pop(option.name, None)
            setattr(self, option.name, value)

        if meta_attrs:
            # Some attributes in the Meta aren't allowed here
            raise TypeError(
                "'class Meta' for %r got unknown attribute(s) %s"
                % (self.factory, ','.join(sorted(meta_attrs.keys()))))

    def contribute_to_class(self, factory, meta=None, base_meta=None, base_factory=None, params=None):
        """Bind these options to ``factory`` and collect its declarations.

        Args:
            factory (class): the newly created factory class
            meta (class or None): the factory's own ``class Meta``
            base_meta (object or None): the ``_meta`` found on the parents
            base_factory (class or None): the first factory parent, if any
            params (class or None): the factory's ``class Params``

        Raises:
            TypeError: on unknown ``class Meta`` attributes or invalid options.
            errors.CyclicDefinitionError: if parameters depend on each other
                in a loop.
        """

        self.factory = factory
        self.base_factory = base_factory

        self._fill_from_meta(meta=meta, base_meta=base_meta)

        self.model = self.get_model_class()
        if self.model is None:
            self.abstract = True

        self.counter_reference = self._get_counter_reference()

        # Scan the inheritance chain, starting from the furthest point,
        # excluding the current class, to retrieve all declarations.
        for parent in reversed(self.factory.__mro__[1:]):
            if not hasattr(parent, '_meta'):
                continue
            self.base_declarations.update(parent._meta.base_declarations)
            self.parameters.update(parent._meta.parameters)

        for k, v in vars(self.factory).items():
            if self._is_declaration(k, v):
                self.base_declarations[k] = v

        if params is not None:
            for k, v in utils.sort_ordered_objects(vars(params).items(), getter=lambda item: item[1]):
                if not k.startswith('_'):
                    self.parameters[k] = declarations.SimpleParameter.wrap(v)

        self._check_parameter_dependencies(self.parameters)

        self.pre_declarations, self.post_declarations = builder.parse_declarations(self.declarations)

    def _get_counter_reference(self):
        """Identify which factory should be used for a shared counter.

        The base factory's counter reference is reused when both factories
        have a model and ours is a subclass of the base factory's model.
        Otherwise this factory owns its counter.
        """

        if (self.model is not None
                and self.base_factory is not None
                and self.base_factory._meta.model is not None
                and issubclass(self.model, self.base_factory._meta.model)):
            return self.base_factory._meta.counter_reference
        else:
            return self

    def _initialize_counter(self):
        """Initialize our counter pointer.

        If we're the top-level factory, instantiate a new counter
        Otherwise, point to the top-level factory's counter.
        """
        if self._counter is not None:
            return

        if self.counter_reference is self:
            self._counter = _Counter(seq=self.factory._setup_next_sequence())
        else:
            self.counter_reference._initialize_counter()
            self._counter = self.counter_reference._counter

    def next_sequence(self):
        """Retrieve a new sequence ID.

        On first use, the counter is initialized. If this factory owns it,
        ``_setup_next_sequence()`` supplies the start value; otherwise the
        counter of ``counter_reference`` is shared. The current value is
        then returned and the counter advanced.
        """
        self._initialize_counter()
        return self._counter.next()

    def reset_sequence(self, value=None, force=False):
        """Reset the (possibly shared) sequence counter.

        Args:
            value (int or None): next sequence value; when None, it is
                recomputed by the counter owner's ``_setup_next_sequence()``.
            force (bool): allow resetting a counter owned by another factory.

        Raises:
            ValueError: if the counter belongs to another factory and
                ``force`` is False.
        """
        self._initialize_counter()

        if self.counter_reference is not self and not force:
            raise ValueError(
                "Can't reset a sequence on descendant factory %r; reset sequence on %r or use `force=True`."
                % (self.factory, self.counter_reference.factory))

        if value is None:
            value = self.counter_reference.factory._setup_next_sequence()
        self._counter.reset(value)

    def prepare_arguments(self, attributes):
        """Convert an attributes dict to a (args, kwargs) tuple.

        Raises:
            KeyError: if a name listed in ``inline_args`` is not present
                after filtering and renaming.
        """
        kwargs = dict(attributes)
        # 1. Extension points
        kwargs = self.factory._adjust_kwargs(**kwargs)

        # 2. Remove hidden objects
        kwargs = {
            k: v for k, v in kwargs.items()
            if k not in self.exclude and k not in self.parameters and v is not declarations.SKIP
        }

        # 3. Rename fields
        for old_name, new_name in self.rename.items():
            if old_name in kwargs:
                kwargs[new_name] = kwargs.pop(old_name)

        # 4. Extract inline args
        args = tuple(
            kwargs.pop(arg_name)
            for arg_name in self.inline_args
        )

        return args, kwargs

    def instantiate(self, step, args, kwargs):
        """Produce the object with the factory hook matching the step's strategy.

        build -> ``factory._build``, create -> ``factory._create``,
        stub -> a ``StubObject`` holding ``kwargs`` (``args`` are ignored).
        """
        model = self.get_model_class()

        if step.builder.strategy == enums.BUILD_STRATEGY:
            return self.factory._build(model, *args, **kwargs)
        elif step.builder.strategy == enums.CREATE_STRATEGY:
            return self.factory._create(model, *args, **kwargs)
        else:
            assert step.builder.strategy == enums.STUB_STRATEGY
            return StubObject(**kwargs)

    def use_postgeneration_results(self, step, instance, results):
        """Forward post-generation results to the factory's hook."""
        self.factory._after_postgeneration(
            instance,
            create=step.builder.strategy == enums.CREATE_STRATEGY,
            results=results,
        )

    def _is_declaration(self, name, value):
        """Determines if a class attribute is a field value declaration.

        Based on the name and value of the class attribute, return ``True`` if
        it looks like a declaration of a default field value, ``False`` if it
        is private (name starts with '_') or a classmethod or staticmethod.
        Objects with a builder phase are declarations even when private.
        """
        if isinstance(value, (classmethod, staticmethod)):
            return False
        elif enums.get_builder_phase(value):
            # All objects with a defined 'builder phase' are declarations.
            return True
        return not name.startswith("_")

    def _check_parameter_dependencies(self, parameters):
        """Compute direct dependencies between parameters and reject cycles.

        Parameters only expose reverse dependencies (``get_revdeps()``); they
        are reversed here into standard dependencies.

        Args:
            parameters (str => declarations.Parameter dict): the parameters

        Returns:
            defaultdict(set): for each parameter name, the parameters whose
                ``get_revdeps()`` listed it.

        Raises:
            errors.CyclicDefinitionError: if parameters depend on each other,
                directly or transitively, in a loop.
        """
        # Direct links as reported by each parameter.
        direct_revdeps = {}
        # Actual, direct dependencies
        deps = collections.defaultdict(set)

        for name, parameter in parameters.items():
            if isinstance(parameter, declarations.Parameter):
                field_revdeps = parameter.get_revdeps(parameters)
                if not field_revdeps:
                    continue
                direct_revdeps[name] = set(field_revdeps)
                for dep in field_revdeps:
                    deps[dep].add(name)

        # A full reachability walk is needed: a single pass in declaration
        # order misses loops such as a -> b -> c -> a.
        cyclic = self._find_cyclic_parameters(direct_revdeps)
        if cyclic:
            raise errors.CyclicDefinitionError(
                "Cyclic definition detected on %r; Params around %s"
                % (self.factory, ', '.join(cyclic)))
        return deps

    @staticmethod
    def _find_cyclic_parameters(revdeps):
        """Return the names (in ``revdeps`` order) that can reach themselves.

        Args:
            revdeps (str => set of str dict): direct links between parameters;
                names missing from the dict have no outgoing links.

        Returns:
            list of str: names lying on a dependency cycle.
        """
        cyclic = []
        for name, direct in revdeps.items():
            reachable = set()
            pending = list(direct)
            while pending:
                current = pending.pop()
                if current in reachable:
                    continue
                reachable.add(current)
                pending.extend(revdeps.get(current, ()))
            if name in reachable:
                cyclic.append(name)
        return cyclic

    def get_model_class(self):
        """Extension point for loading model classes.

        This can be overridden in framework-specific subclasses to hook into
        existing model repositories, for instance.
        """
        return self.model

    def __str__(self):
        # ``factory`` stays None until contribute_to_class() runs (e.g. for
        # BaseFactory._meta); show it as-is instead of raising AttributeError.
        factory_name = getattr(self.factory, '__name__', self.factory)
        return "<%s for %s>" % (self.__class__.__name__, factory_name)

    def __repr__(self):
        return str(self)
383
384
385# Factory base classes
386
387
388class _Counter:
389    """Simple, naive counter.
390
391    Attributes:
392        for_class (obj): the class this counter related to
393        seq (int): the next value
394    """
395
396    def __init__(self, seq):
397        self.seq = seq
398
399    def next(self):
400        value = self.seq
401        self.seq += 1
402        return value
403
404    def reset(self, next_value=0):
405        self.seq = next_value
406
407
class BaseFactory:
    """Factory base support for sequences, attributes and stubs.

    Provides the public generation API (``build``/``create``/``stub``, their
    ``*_batch`` variants and the ``generate*`` helpers). It also provides the
    ``_setup_next_sequence``, ``_adjust_kwargs``, ``_build``, ``_create`` and
    ``_after_postgeneration`` customization points.
    """

    # Backwards compatibility
    UnknownStrategy = errors.UnknownStrategy
    UnsupportedStrategy = errors.UnsupportedStrategy

    def __new__(cls, *args, **kwargs):
        """Would be called if trying to instantiate the class.

        Raises:
            errors.FactoryError: always; factories are used via classmethods.
        """
        raise errors.FactoryError('You cannot instantiate BaseFactory')

    _meta = FactoryOptions()

    # NOTE(review): not read anywhere in this module; the live sequence state
    # is kept on ``_meta._counter``. Presumably kept for backwards
    # compatibility.
    _counter = None

    @classmethod
    def reset_sequence(cls, value=None, force=False):
        """Reset the sequence counter.

        Args:
            value (int or None): the new 'next' sequence value; if None,
                recompute the next value from _setup_next_sequence().
            force (bool): whether to force-reset parent sequence counters
                in a factory inheritance chain.

        Raises:
            ValueError: if this factory shares a counter owned by another
                factory and ``force`` is False.
        """
        cls._meta.reset_sequence(value, force=force)

    @classmethod
    def _setup_next_sequence(cls):
        """Set up an initial sequence value for Sequence attributes.

        Returns:
            int: the first available ID to use for instances of this factory.
        """
        return 0

    @classmethod
    def _adjust_kwargs(cls, **kwargs):
        """Extension point for custom kwargs adjustment.

        Returns:
            dict: the keyword arguments to use (unchanged by default).
        """
        return kwargs

    @classmethod
    def _generate(cls, strategy, params):
        """Generate the object.

        Args:
            strategy: the strategy to use
            params (dict): attributes to use for generating the object

        Raises:
            errors.FactoryError: if the factory is abstract.
        """
        if cls._meta.abstract:
            raise errors.FactoryError(
                "Cannot generate instances of abstract factory %(f)s; "
                "Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
                "is either not set or False." % dict(f=cls.__name__))

        step = builder.StepBuilder(cls._meta, params, strategy)
        return step.build()

    @classmethod
    def _after_postgeneration(cls, instance, create, results=None):
        """Hook called after post-generation declarations have been handled.

        Args:
            instance (object): the generated object
            create (bool): True when the strategy was 'create', False otherwise
            results (dict or None): result of post-generation declarations
        """
        pass

    @classmethod
    def _build(cls, model_class, *args, **kwargs):
        """Actually build an instance of the model_class.

        Customization point, will be called once the full set of args and kwargs
        has been computed.

        Args:
            model_class (type): the class for which an instance should be
                built
            args (tuple): arguments to use when building the class
            kwargs (dict): keyword arguments to use when building the class
        """
        return model_class(*args, **kwargs)

    @classmethod
    def _create(cls, model_class, *args, **kwargs):
        """Actually create an instance of the model_class.

        Customization point, will be called once the full set of args and kwargs
        has been computed.

        Args:
            model_class (type): the class for which an instance should be
                created
            args (tuple): arguments to use when creating the class
            kwargs (dict): keyword arguments to use when creating the class
        """
        return model_class(*args, **kwargs)

    @classmethod
    def build(cls, **kwargs):
        """Build an instance of the associated class, with overridden attrs."""
        return cls._generate(enums.BUILD_STRATEGY, kwargs)

    @classmethod
    def build_batch(cls, size, **kwargs):
        """Build a batch of instances of the given class, with overridden attrs.

        Args:
            size (int): the number of instances to build

        Returns:
            object list: the built instances
        """
        return [cls.build(**kwargs) for _ in range(size)]

    @classmethod
    def create(cls, **kwargs):
        """Create an instance of the associated class, with overridden attrs."""
        return cls._generate(enums.CREATE_STRATEGY, kwargs)

    @classmethod
    def create_batch(cls, size, **kwargs):
        """Create a batch of instances of the given class, with overridden attrs.

        Args:
            size (int): the number of instances to create

        Returns:
            object list: the created instances
        """
        return [cls.create(**kwargs) for _ in range(size)]

    @classmethod
    def stub(cls, **kwargs):
        """Retrieve a stub of the associated class, with overridden attrs.

        This will return an object whose attributes are those defined in this
        factory's declarations or in the extra kwargs.
        """
        return cls._generate(enums.STUB_STRATEGY, kwargs)

    @classmethod
    def stub_batch(cls, size, **kwargs):
        """Stub a batch of instances of the given class, with overridden attrs.

        Args:
            size (int): the number of instances to stub

        Returns:
            object list: the stubbed instances
        """
        return [cls.stub(**kwargs) for _ in range(size)]

    @classmethod
    def generate(cls, strategy, **kwargs):
        """Generate a new instance.

        The instance will be created with the given strategy (one of
        BUILD_STRATEGY, CREATE_STRATEGY, STUB_STRATEGY).

        Args:
            strategy (str): the strategy to use for generating the instance.

        Returns:
            object: the generated instance
        """
        # NOTE(review): the assert is stripped under ``python -O``.
        assert strategy in (enums.STUB_STRATEGY, enums.BUILD_STRATEGY, enums.CREATE_STRATEGY)
        action = getattr(cls, strategy)
        return action(**kwargs)

    @classmethod
    def generate_batch(cls, strategy, size, **kwargs):
        """Generate a batch of instances.

        The instances will be created with the given strategy (one of
        BUILD_STRATEGY, CREATE_STRATEGY, STUB_STRATEGY).

        Args:
            strategy (str): the strategy to use for generating the instance.
            size (int): the number of instances to generate

        Returns:
            object list: the generated instances
        """
        # NOTE(review): the assert is stripped under ``python -O``.
        assert strategy in (enums.STUB_STRATEGY, enums.BUILD_STRATEGY, enums.CREATE_STRATEGY)
        batch_action = getattr(cls, '%s_batch' % strategy)
        return batch_action(size, **kwargs)

    @classmethod
    def simple_generate(cls, create, **kwargs):
        """Generate a new instance.

        The instance will be either 'built' or 'created'.

        Args:
            create (bool): whether to 'build' or 'create' the instance.

        Returns:
            object: the generated instance
        """
        strategy = enums.CREATE_STRATEGY if create else enums.BUILD_STRATEGY
        return cls.generate(strategy, **kwargs)

    @classmethod
    def simple_generate_batch(cls, create, size, **kwargs):
        """Generate a batch of instances.

        These instances will be either 'built' or 'created'.

        Args:
            create (bool): whether to 'build' or 'create' the instances.
            size (int): the number of instances to generate

        Returns:
            object list: the generated instances
        """
        strategy = enums.CREATE_STRATEGY if create else enums.BUILD_STRATEGY
        return cls.generate_batch(strategy, size, **kwargs)
628
629
class Factory(BaseFactory, metaclass=FactoryMetaClass):
    """Base class for user-defined factories, with build and create support.

    Different ORMs can be supported by overriding the creation hooks
    (``_build`` / ``_create``). Its own ``Meta`` derives from ``BaseMeta``,
    so ``Factory`` itself is abstract.
    """

    class Meta(BaseMeta):
        pass


# Backwards compatibility: keep ``Factory.AssociatedClassError`` usable as an
# alias of ``errors.AssociatedClassError``.
Factory.AssociatedClassError = errors.AssociatedClassError
643
644
class StubObject:
    """A bare container: each keyword argument becomes an attribute."""
    def __init__(self, **kwargs):
        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])
650
651
class StubFactory(Factory):
    """Factory producing StubObject instances.

    Its default strategy is stubbing, ``build()`` also stubs, and
    ``create()`` is unsupported.
    """

    class Meta:
        model = StubObject
        strategy = enums.STUB_STRATEGY

    @classmethod
    def build(cls, **kwargs):
        """Stub instead of building; both yield a StubObject here."""
        return cls.stub(**kwargs)

    @classmethod
    def create(cls, **kwargs):
        """Always raise: stubs cannot be created."""
        raise errors.UnsupportedStrategy()
665
666
class BaseDictFactory(Factory):
    """Abstract factory for dict-like models built from keyword arguments."""
    class Meta:
        abstract = True

    @classmethod
    def _build(cls, model_class, *args, **kwargs):
        """Return ``model_class(**kwargs)``; positional args are rejected."""
        if not args:
            return model_class(**kwargs)
        raise ValueError(
            "DictFactory %r does not support Meta.inline_args." % cls)

    @classmethod
    def _create(cls, model_class, *args, **kwargs):
        """Creating is identical to building for dict-like models."""
        return cls._build(model_class, *args, **kwargs)
682
683
class DictFactory(BaseDictFactory):
    """Factory producing plain ``dict`` objects from its declarations."""
    class Meta:
        model = dict
687
688
class BaseListFactory(Factory):
    """Abstract factory for list-like models built from ordered values."""
    class Meta:
        abstract = True

    @classmethod
    def _build(cls, model_class, *args, **kwargs):
        """Return ``model_class(values)``; positional args are rejected."""
        if args:
            raise ValueError(
                "ListFactory %r does not support Meta.inline_args." % cls)

        # kwargs were built from a list, so insertion order already matches
        # list order; the values can be passed through without sorting.
        return model_class(kwargs.values())

    @classmethod
    def _create(cls, model_class, *args, **kwargs):
        """Creating is identical to building for list-like models."""
        return cls._build(model_class, *args, **kwargs)
708
709
class ListFactory(BaseListFactory):
    """Factory producing plain ``list`` objects from its declarations."""
    class Meta:
        model = list
713
714
def use_strategy(new_strategy):
    """Force the use of a different strategy (deprecated).

    This is an alternative to setting ``strategy`` in the factory's
    ``class Meta``. Calling it emits a DeprecationWarning, attributed to the
    caller.

    Args:
        new_strategy: the strategy to store in ``klass._meta.strategy``.

    Returns:
        A class decorator that sets the strategy on the decorated factory's
        ``_meta`` in place and returns the same class.
    """
    warnings.warn(
        "use_strategy() is deprecated and will be removed in the future.",
        DeprecationWarning,
        stacklevel=2,
    )

    def wrapped_class(klass):
        klass._meta.strategy = new_strategy
        return klass
    return wrapped_class
730