1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A simplified time-slice of operations within a sequenced circuit."""
16
17from typing import (
18    Any,
19    Callable,
20    Dict,
21    FrozenSet,
22    Iterable,
23    Iterator,
24    overload,
25    Optional,
26    Tuple,
27    TYPE_CHECKING,
28    TypeVar,
29    Union,
30)
31
32from cirq import protocols, ops
33from cirq.ops import raw_types
34from cirq.protocols import circuit_diagram_info_protocol
35from cirq.type_workarounds import NotImplementedType
36
37if TYPE_CHECKING:
38    import cirq
39
# Type variable bound to Moment so that methods such as `transform_qubits` can
# be annotated as returning the same (sub)class they were called on.
TSelf_Moment = TypeVar('TSelf_Moment', bound='Moment')
41
42
43def _default_breakdown(qid: 'cirq.Qid') -> Tuple[Any, Any]:
44    # Attempt to convert into a position on the complex plane.
45    try:
46        plane_pos = complex(qid)  # type: ignore
47        return plane_pos.real, plane_pos.imag
48    except TypeError:
49        return None, qid
50
51
class Moment:
    """A time-slice of operations within a circuit.

    Grouping operations into moments is intended to be a strong suggestion to
    whatever is scheduling operations on real hardware. Operations in the same
    moment should execute at the same time (to the extent possible; not all
    operations have the same duration) and it is expected that all operations
    in a moment should be completed before beginning the next moment.

    Moment can be indexed by qubit or list of qubits:

    *   `moment[qubit]` returns the Operation in the moment which touches the
            given qubit, or throws KeyError if there is no such operation.
    *   `moment[qubits]` returns another Moment which consists only of those
            operations which touch at least one of the given qubits, in the
            order they appear in this moment. If there are no such operations,
            returns an empty Moment.
    """

    def __init__(self, *contents: 'cirq.OP_TREE') -> None:
        """Constructs a moment with the given operations.

        Args:
            contents: The operations applied within the moment.
                Will be flattened and frozen into a tuple before storing.

        Raises:
            ValueError: A qubit appears more than once.
        """
        from cirq.ops import op_tree

        self._operations = tuple(op_tree.flatten_to_ops(contents))

        # An internal dictionary to support efficient operation access by qubit.
        self._qubit_to_op: Dict['cirq.Qid', 'cirq.Operation'] = {}
        for op in self.operations:
            for q in op.qubits:
                # Check that operations don't overlap.
                if q in self._qubit_to_op:
                    raise ValueError(f'Overlapping operations: {self.operations}')
                self._qubit_to_op[q] = op

        self._qubits = frozenset(self._qubit_to_op.keys())

    @property
    def operations(self) -> Tuple['cirq.Operation', ...]:
        """The operations in this moment, in insertion order."""
        return self._operations

    @property
    def qubits(self) -> FrozenSet['cirq.Qid']:
        """The set of qubits touched by the operations in this moment."""
        return self._qubits

    def operates_on_single_qubit(self, qubit: 'cirq.Qid') -> bool:
        """Determines if the moment has operations touching the given qubit.

        Args:
            qubit: The qubit that may or may not be touched by operations.

        Returns:
            Whether this moment has operations involving the qubit.
        """
        return qubit in self._qubit_to_op

    def operates_on(self, qubits: Iterable['cirq.Qid']) -> bool:
        """Determines if the moment has operations touching the given qubits.

        Args:
            qubits: The qubits that may or may not be touched by operations.

        Returns:
            Whether this moment has operations involving the qubits.
        """
        return not self._qubits.isdisjoint(qubits)

    def operation_at(self, qubit: raw_types.Qid) -> Optional['cirq.Operation']:
        """Returns the operation on a certain qubit for the moment.

        Args:
            qubit: The qubit on which the returned Operation operates.

        Returns:
            The operation that operates on the qubit for that moment, or None
            if no operation in this moment touches the qubit.
        """
        return self._qubit_to_op.get(qubit)

    def with_operation(self, operation: 'cirq.Operation') -> 'cirq.Moment':
        """Returns an equal moment, but with the given op added.

        Args:
            operation: The operation to append.

        Returns:
            The new moment.

        Raises:
            ValueError: If the operation touches a qubit already in use.
        """
        if any(q in self._qubits for q in operation.qubits):
            raise ValueError(f'Overlapping operations: {operation}')

        # Use private variables to facilitate a quick copy.
        m = Moment()
        m._operations = self._operations + (operation,)
        m._qubits = self._qubits.union(operation.qubits)
        m._qubit_to_op = self._qubit_to_op.copy()
        for q in operation.qubits:
            m._qubit_to_op[q] = operation

        return m

    def with_operations(self, *contents: 'cirq.OP_TREE') -> 'cirq.Moment':
        """Returns a new moment with the given contents added.

        Args:
            contents: New operations to add to this moment.

        Returns:
            The new moment.

        Raises:
            ValueError: If a new operation touches a qubit already in use.
        """
        from cirq.ops import op_tree

        operations = list(self._operations)
        # Start from the existing mapping and only register the new ops; the
        # existing entries are already correct.
        qubit_to_op = self._qubit_to_op.copy()
        for op in op_tree.flatten_to_ops(contents):
            if any(q in qubit_to_op for q in op.qubits):
                raise ValueError(f'Overlapping operations: {op}')
            operations.append(op)
            for q in op.qubits:
                qubit_to_op[q] = op

        # Use private variables to facilitate a quick copy.
        m = Moment()
        m._operations = tuple(operations)
        m._qubit_to_op = qubit_to_op
        m._qubits = frozenset(qubit_to_op)

        return m

    def without_operations_touching(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Moment':
        """Returns an equal moment, but without ops on the given qubits.

        Args:
            qubits: Operations that touch these will be removed.

        Returns:
            The new moment (or this moment itself if nothing is removed).
        """
        qubits = frozenset(qubits)
        if not self.operates_on(qubits):
            return self
        return Moment(
            operation for operation in self.operations if qubits.isdisjoint(operation.qubits)
        )

    def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
        return Moment(
            protocols.with_measurement_key_mapping(op, key_map)
            if protocols.is_measurement(op)
            else op
            for op in self.operations
        )

    def _with_key_path_(self, path: Tuple[str, ...]):
        return Moment(
            protocols.with_key_path(op, path) if protocols.is_measurement(op) else op
            for op in self.operations
        )

    def __copy__(self):
        return type(self)(self.operations)

    def __bool__(self) -> bool:
        return bool(self.operations)

    def __eq__(self, other) -> bool:
        if not isinstance(other, type(self)):
            return NotImplemented

        # Operation order within a moment is not significant for equality.
        return sorted(self.operations, key=lambda op: op.qubits) == sorted(
            other.operations, key=lambda op: op.qubits
        )

    def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool:
        """See `cirq.protocols.SupportsApproximateEquality`."""
        if not isinstance(other, type(self)):
            return NotImplemented

        return protocols.approx_eq(
            sorted(self.operations, key=lambda op: op.qubits),
            sorted(other.operations, key=lambda op: op.qubits),
            atol=atol,
        )

    def __ne__(self, other) -> bool:
        return not self == other

    def __hash__(self):
        return hash((Moment, tuple(sorted(self.operations, key=lambda op: op.qubits))))

    def __iter__(self) -> Iterator['cirq.Operation']:
        return iter(self.operations)

    def __pow__(self, power):
        if power == 1:
            return self
        new_ops = []
        for op in self.operations:
            new_op = protocols.pow(op, power, default=None)
            if new_op is None:
                return NotImplemented
            new_ops.append(new_op)
        return Moment(new_ops)

    def __len__(self) -> int:
        return len(self.operations)

    def __repr__(self) -> str:
        if not self.operations:
            return 'cirq.Moment()'

        block = '\n'.join([repr(op) + ',' for op in self.operations])
        indented = '    ' + '\n    '.join(block.split('\n'))

        return f'cirq.Moment(\n{indented}\n)'

    def __str__(self) -> str:
        return self.to_text_diagram()

    def _decompose_(self) -> 'cirq.OP_TREE':
        """See `cirq.SupportsDecompose`."""
        return self._operations

    def transform_qubits(
        self: TSelf_Moment,
        qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']],
    ) -> TSelf_Moment:
        """Returns the same moment, but with different qubits.

        Args:
            qubit_map: A function or a dict mapping each current qubit into a
                desired new qubit.

        Returns:
            The receiving moment but with qubits transformed by the given
            function.
        """
        return self.__class__(op.transform_qubits(qubit_map) for op in self.operations)

    def _json_dict_(self) -> Dict[str, Any]:
        return protocols.obj_to_dict_helper(self, ['operations'])

    @classmethod
    def _from_json_dict_(cls, operations, **kwargs):
        return Moment(operations)

    def __add__(self, other: 'cirq.OP_TREE') -> 'cirq.Moment':
        from cirq.circuits import circuit

        if isinstance(other, circuit.AbstractCircuit):
            return NotImplemented  # Delegate to Circuit.__radd__.
        return self.with_operations(other)

    def __sub__(self, other: 'cirq.OP_TREE') -> 'cirq.Moment':
        from cirq.ops import op_tree

        must_remove = set(op_tree.flatten_to_ops(other))
        new_ops = []
        for op in self.operations:
            if op in must_remove:
                must_remove.remove(op)
            else:
                new_ops.append(op)
        if must_remove:
            raise ValueError(
                f"Subtracted missing operations from a moment.\n"
                f"Missing operations: {must_remove!r}\n"
                f"Moment: {self!r}"
            )
        return Moment(new_ops)

    # pylint: disable=function-redefined
    @overload
    def __getitem__(self, key: raw_types.Qid) -> 'cirq.Operation':
        pass

    @overload
    def __getitem__(self, key: Iterable[raw_types.Qid]) -> 'cirq.Moment':
        pass

    def __getitem__(self, key):
        if isinstance(key, raw_types.Qid):
            if key not in self._qubit_to_op:
                raise KeyError("Moment doesn't act on given qubit")
            return self._qubit_to_op[key]
        if isinstance(key, Iterable):
            qubits_to_keep = frozenset(key)
            # Walk self.operations (rather than an unordered set of ops) so the
            # result is deterministic and keeps this moment's operation order.
            # Each operation is visited once, so no de-duplication is needed.
            return Moment(
                op for op in self.operations if not qubits_to_keep.isdisjoint(op.qubits)
            )
        raise TypeError(
            f'Moment indices must be a qubit or an iterable of qubits, not {type(key)}'
        )

    def to_text_diagram(
        self: 'cirq.Moment',
        *,
        xy_breakdown_func: Callable[['cirq.Qid'], Tuple[Any, Any]] = _default_breakdown,
        extra_qubits: Iterable['cirq.Qid'] = (),
        use_unicode_characters: bool = True,
        precision: Optional[int] = None,
        include_tags: bool = True,
    ) -> str:
        """Renders this moment as a 2D text grid of qubit positions.

        Args:
            xy_breakdown_func: A function to split qubits/qudits into x and y
                components. For example, the default breakdown turns
                `cirq.GridQubit(row, col)` into the tuple `(col, row)` and
                `cirq.LineQubit(x)` into `(x, 0)`.
            extra_qubits: Extra qubits/qudits to include in the diagram, even
                if they don't have any operations applied in the moment.
            use_unicode_characters: Whether or not the output should use fancy
                unicode characters or stick to plain ASCII. Unicode characters
                look nicer, but some environments don't draw them with the same
                width as ascii characters (which ruins the diagrams).
            precision: How precise numbers, such as angles, should be. Use None
                for infinite precision, or an integer for a certain number of
                digits of precision.
            include_tags: Whether or not to include operation tags in the
                diagram.

        Returns:
            The text diagram rendered into text.
        """

        # Figure out where to place everything. Break each qubit down once.
        qs = set(self.qubits) | set(extra_qubits)
        breakdowns = {q: xy_breakdown_func(q) for q in qs}
        x_keys = sorted({pt[0] for pt in breakdowns.values()}, key=_SortByValFallbackToType)
        y_keys = sorted({pt[1] for pt in breakdowns.values()}, key=_SortByValFallbackToType)
        # Rows/columns 0 and 1 hold the headers and separator lines.
        x_map = {x_key: x + 2 for x, x_key in enumerate(x_keys)}
        y_map = {y_key: y + 2 for y, y_key in enumerate(y_keys)}
        qubit_positions = {}
        for q in qs:
            a, b = breakdowns[q]
            qubit_positions[q] = x_map[a], y_map[b]

        from cirq.circuits.text_diagram_drawer import TextDiagramDrawer

        diagram = TextDiagramDrawer()

        def cleanup_key(key: Any) -> Any:
            # Render integral floats like 2.0 as '2'. is_integer() returns
            # False (instead of raising) for inf and nan.
            if isinstance(key, float) and key.is_integer():
                return str(int(key))
            return str(key)

        # Add table headers.
        for key, x in x_map.items():
            diagram.write(x, 0, cleanup_key(key))
        for key, y in y_map.items():
            diagram.write(0, y, cleanup_key(key))
        diagram.horizontal_line(1, 0, len(x_map) + 2)
        diagram.vertical_line(1, 0, len(y_map) + 2)
        diagram.force_vertical_padding_after(0, 0)
        diagram.force_vertical_padding_after(1, 0)

        # Add operations.
        for op in self.operations:
            args = protocols.CircuitDiagramInfoArgs(
                known_qubits=op.qubits,
                known_qubit_count=len(op.qubits),
                use_unicode_characters=use_unicode_characters,
                qubit_map=None,
                precision=precision,
                include_tags=include_tags,
            )
            info = circuit_diagram_info_protocol._op_info_with_fallback(op, args=args)
            symbols = info._wire_symbols_including_formatted_exponent(args)
            for label, q in zip(symbols, op.qubits):
                x, y = qubit_positions[q]
                diagram.write(x, y, label)
            if info.connected:
                for q1, q2 in zip(op.qubits, op.qubits[1:]):
                    # Sort to get a more consistent orientation for diagonals.
                    # This reduces how often lines overlap in the diagram.
                    q1, q2 = sorted([q1, q2])

                    x1, y1 = qubit_positions[q1]
                    x2, y2 = qubit_positions[q2]
                    if x1 != x2:
                        diagram.horizontal_line(y1, x1, x2)
                    if y1 != y2:
                        diagram.vertical_line(x2, y1, y2)

        return diagram.render()

    def _commutes_(
        self, other: Any, *, atol: Union[int, float] = 1e-8
    ) -> Union[bool, NotImplementedType]:
        """Determines whether Moment commutes with the Operation.

        Args:
            other: An Operation object. Other types are not implemented yet.
                In case a different type is specified, NotImplemented is
                returned.
            atol: Absolute error tolerance. If all entries in v1@v2 - v2@v1
                have a magnitude less than this tolerance, v1 and v2 can be
                reported as commuting. Defaults to 1e-8.

        Returns:
            True: The Moment and Operation commute OR they don't have shared
            qubits.
            False: The two values do not commute.
            NotImplemented: In case we don't know how to check this, e.g.
                the parameter type is not supported yet.
        """
        if not isinstance(other, ops.Operation):
            return NotImplemented

        other_qubits = set(other.qubits)
        for op in self.operations:
            # Operations on disjoint qubits always commute with `other`.
            if other_qubits.isdisjoint(op.qubits):
                continue

            commutes = protocols.commutes(op, other, atol=atol, default=NotImplemented)

            # Test identity before truthiness: bool(NotImplemented) is
            # deprecated since Python 3.9 and raises TypeError since 3.14.
            if commutes is NotImplemented or not commutes:
                return commutes

        return True
485
486
487class _SortByValFallbackToType:
488    def __init__(self, value):
489        self.value = value
490
491    def __lt__(self, other):
492        try:
493            return self.value < other.value
494        except TypeError:
495            t1 = type(self.value)
496            t2 = type(other.value)
497            return str(t1) < str(t2)
498