1#  _________________________________________________________________________
2#
3#  PyUtilib: A Python utility library.
4#  Copyright (c) 2008 Sandia Corporation.
5#  This software is distributed under the BSD License.
6#  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
7#  the U.S. Government retains certain rights in this software.
8#  _________________________________________________________________________
9
10import sys
11from pyutilib.misc import misc
12
if sys.version_info[0] >= 3:
    # Python 3 dropped the ``xrange`` builtin; alias it to ``range`` so code
    # written against the Python 2 name keeps working unchanged.
    xrange = range
15
16
def _cross_exec(set_tuple):
    """
    Function used by cross() to generate the cross-product of a tuple of
    iterables.

    Returns a list of lists.  Each inner list picks one element from every
    member of ``set_tuple``, in order, with the last member varying fastest.
    If any member is empty the result is ``[]``.  If ``set_tuple`` itself is
    empty the result is ``[[]]``, the single empty combination.

    Each member of ``set_tuple`` is iterated exactly once, so one-shot
    iterators are accepted.  The product is built iteratively, so the
    number of members is not limited by the recursion depth.
    """
    resulting_set = [[]]
    for members in set_tuple:
        # Materialize first: the comprehension below walks this member once
        # per existing prefix, which would exhaust a one-shot iterator.
        members = list(members)
        resulting_set = [prefix + [val]
                         for prefix in resulting_set
                         for val in members]
    return resulting_set
32
33
def cross(set_tuple):
    """
    Return the cross-product of a sequence of iterables as a list of tuples.

    Each tuple takes one element from every member of ``set_tuple``, in
    order, with the last member varying fastest.
    """
    return [tuple(combination) for combination in _cross_exec(set_tuple)]
43
44#def tmp_cross(*args):
45#    ans = [[]]
46#    for arg in args:
47#      ans = [x+[y] for x in ans for y in arg]
48#    return ans
49
# The builtin next() exists on Python 2.6+ and 3.x, so one implementation
# serves both interpreters and no version-specific split is needed.


def cross_iter(*sets):
    """
    An iterator function that generates the cross product of the given
    sets, one tuple at a time, with the last set varying fastest.

    Derived from code developed by Steven Taschuk

    When a position other than the first is exhausted, it is restarted with
    ``iter(sets[i])``.  Those arguments must therefore be re-iterable (list,
    tuple, set, string, ...).  The first argument may be a one-shot
    iterator.  If any argument is empty, nothing is generated.  With no
    arguments, a single empty tuple is generated.
    """
    wheels = [iter(s) for s in sets]  # wheels like in an odometer
    try:
        digits = [next(wheel) for wheel in wheels]
    except StopIteration:
        # Some set is empty, so the cross product is empty.  Returning here,
        # rather than letting StopIteration escape the generator, avoids the
        # RuntimeError that PEP 479 raises on Python 3.7+.
        return
    ndigits = len(digits)
    while True:
        yield tuple(digits)
        # Advance the rightmost wheel.  On exhaustion, restart it and carry
        # into the wheel to its left, like an odometer.
        for i in range(ndigits - 1, -1, -1):
            try:
                digits[i] = next(wheels[i])
                break
            except StopIteration:
                if i == 0:
                    # The leftmost wheel ran out, so every combination has
                    # been produced.  Stopping here, instead of restarting
                    # it like the others, also lets the first argument be
                    # a one-shot iterator.
                    return
                wheels[i] = iter(sets[i])
                digits[i] = next(wheels[i])
        else:
            # Only reached when there are no sets: the single empty tuple
            # has already been produced.
            return


def flattened_cross_iter(*sets):
    """
    An iterator function that generates the cross product of the given
    sets in the same order as cross_iter(), passing each combination
    through misc.flatten_tuple() before it is generated.
    """
    for combination in cross_iter(*sets):
        yield misc.flatten_tuple(combination)
135