1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""A simplified time-slice of operations within a sequenced circuit."""
16
17from typing import (
18    AbstractSet,
19    Any,
20    Callable,
21    Dict,
22    FrozenSet,
23    Iterable,
24    Iterator,
25    overload,
26    Optional,
27    Tuple,
28    TYPE_CHECKING,
29    TypeVar,
30    Union,
31)
32
33from cirq import protocols, ops, value
34from cirq.ops import raw_types
35from cirq.protocols import circuit_diagram_info_protocol
36from cirq.type_workarounds import NotImplementedType
37
38if TYPE_CHECKING:
39    import cirq
40
# Type variable bound to `Moment`; used by `Moment.transform_qubits` so that its
# annotated return type is the same (sub)class as the receiving instance.
TSelf_Moment = TypeVar('TSelf_Moment', bound='Moment')
42
43
44def _default_breakdown(qid: 'cirq.Qid') -> Tuple[Any, Any]:
45    # Attempt to convert into a position on the complex plane.
46    try:
47        plane_pos = complex(qid)  # type: ignore
48        return plane_pos.real, plane_pos.imag
49    except TypeError:
50        return None, qid
51
52
class Moment:
    """A time-slice of operations within a circuit.

    Grouping operations into moments is intended to be a strong suggestion to
    whatever is scheduling operations on real hardware. Operations in the same
    moment should execute at the same time (to the extent possible; not all
    operations have the same duration) and it is expected that all operations
    in a moment should be completed before beginning the next moment.

    Moment can be indexed by qubit or list of qubits:

    *   `moment[qubit]` returns the Operation in the moment which touches the
            given qubit, or throws KeyError if there is no such operation.
    *   `moment[qubits]` returns another Moment which consists only of those
            operations which touch at least one of the given qubits. If there
            are no such operations, returns an empty Moment.
    """

    def __init__(self, *contents: 'cirq.OP_TREE') -> None:
        """Constructs a moment with the given operations.

        Args:
            contents: The operations applied within the moment.
                Will be flattened and frozen into a tuple before storing.

        Raises:
            ValueError: A qubit appears more than once.
        """
        from cirq.ops import op_tree

        self._operations = tuple(op_tree.flatten_to_ops(contents))

        # An internal dictionary to support efficient operation access by qubit.
        self._qubit_to_op: Dict['cirq.Qid', 'cirq.Operation'] = {}
        for op in self.operations:
            for q in op.qubits:
                # Check that operations don't overlap.
                if q in self._qubit_to_op:
                    raise ValueError(f'Overlapping operations: {self.operations}')
                self._qubit_to_op[q] = op

        self._qubits = frozenset(self._qubit_to_op.keys())
        # Lazily filled by `_measurement_key_objs_`; None means "not computed yet".
        self._measurement_key_objs: Optional[AbstractSet[value.MeasurementKey]] = None

    @property
    def operations(self) -> Tuple['cirq.Operation', ...]:
        """The operations in this moment, in insertion order."""
        return self._operations

    @property
    def qubits(self) -> FrozenSet['cirq.Qid']:
        """The set of qubits touched by any operation in this moment."""
        return self._qubits

    def operates_on_single_qubit(self, qubit: 'cirq.Qid') -> bool:
        """Determines if the moment has operations touching the given qubit.

        Args:
            qubit: The qubit that may or may not be touched by operations.

        Returns:
            Whether this moment has operations involving the qubit.
        """
        return qubit in self._qubit_to_op

    def operates_on(self, qubits: Iterable['cirq.Qid']) -> bool:
        """Determines if the moment has operations touching the given qubits.

        Args:
            qubits: The qubits that may or may not be touched by operations.

        Returns:
            Whether this moment has operations involving the qubits.
        """
        return not self.qubits.isdisjoint(qubits)

    def operation_at(self, qubit: raw_types.Qid) -> Optional['cirq.Operation']:
        """Returns the operation on a certain qubit for the moment.

        Args:
            qubit: The qubit on which the returned Operation operates
                on.

        Returns:
            The operation that operates on the qubit for that moment, or None
            if no operation in this moment touches the qubit.
        """
        return self._qubit_to_op.get(qubit)

    def with_operation(self, operation: 'cirq.Operation') -> 'cirq.Moment':
        """Returns an equal moment, but with the given op added.

        Args:
            operation: The operation to append.

        Returns:
            The new moment.

        Raises:
            ValueError: If the operation acts on a qubit that is already
                touched by an operation in this moment.
        """
        if any(q in self._qubits for q in operation.qubits):
            raise ValueError(f'Overlapping operations: {operation}')

        # Use private variables to facilitate a quick copy; this avoids the
        # flattening and overlap checks that `__init__` would redo.
        m = Moment()
        m._operations = self._operations + (operation,)
        m._qubits = self._qubits.union(operation.qubits)
        m._qubit_to_op = self._qubit_to_op.copy()
        for q in operation.qubits:
            m._qubit_to_op[q] = operation

        return m

    def with_operations(self, *contents: 'cirq.OP_TREE') -> 'cirq.Moment':
        """Returns a new moment with the given contents added.

        Args:
            contents: New operations to add to this moment. Will be flattened
                before being added.

        Returns:
            The new moment.

        Raises:
            ValueError: If any new operation acts on a qubit that is already
                touched by this moment or by an earlier new operation.
        """
        from cirq.ops import op_tree

        operations = list(self._operations)
        qubit_to_op = self._qubit_to_op.copy()
        for op in op_tree.flatten_to_ops(contents):
            if any(q in qubit_to_op for q in op.qubits):
                raise ValueError(f'Overlapping operations: {op}')
            operations.append(op)
            # Only the newly added op needs registering; existing entries
            # were already copied from this moment.
            for q in op.qubits:
                qubit_to_op[q] = op

        # Use private variables to facilitate a quick copy.
        m = Moment()
        m._operations = tuple(operations)
        m._qubit_to_op = qubit_to_op
        m._qubits = frozenset(qubit_to_op)
        return m

    def without_operations_touching(self, qubits: Iterable['cirq.Qid']) -> 'cirq.Moment':
        """Returns an equal moment, but without ops on the given qubits.

        Args:
            qubits: Operations that touch these will be removed.

        Returns:
            The new moment. If no operation touches `qubits`, this moment
            itself is returned unchanged.
        """
        qubits = frozenset(qubits)
        if not self.operates_on(qubits):
            return self
        return Moment(
            operation for operation in self.operations if qubits.isdisjoint(operation.qubits)
        )

    def _with_measurement_key_mapping_(self, key_map: Dict[str, str]):
        """Returns a moment whose measurement ops have their keys remapped."""
        return Moment(
            protocols.with_measurement_key_mapping(op, key_map)
            if protocols.is_measurement(op)
            else op
            for op in self.operations
        )

    def _measurement_key_names_(self) -> AbstractSet[str]:
        """Returns the string form of every measurement key in this moment."""
        return {str(key) for key in self._measurement_key_objs_()}

    def _measurement_key_objs_(self) -> AbstractSet[value.MeasurementKey]:
        """Returns all measurement keys of the operations, computed once and cached."""
        if self._measurement_key_objs is None:
            # Frozen so that callers cannot mutate the cached value.
            self._measurement_key_objs = frozenset(
                key for op in self.operations for key in protocols.measurement_key_objs(op)
            )
        return self._measurement_key_objs

    def _with_key_path_(self, path: Tuple[str, ...]):
        """Returns a moment whose measurement ops have `path` applied to their keys."""
        return Moment(
            protocols.with_key_path(op, path) if protocols.is_measurement(op) else op
            for op in self.operations
        )

    def __copy__(self):
        return type(self)(self.operations)

    def __bool__(self) -> bool:
        return bool(self.operations)

    def __eq__(self, other) -> bool:
        if not isinstance(other, type(self)):
            return NotImplemented

        # Operation order within a moment is irrelevant for equality.
        return sorted(self.operations, key=lambda op: op.qubits) == sorted(
            other.operations, key=lambda op: op.qubits
        )

    def _approx_eq_(self, other: Any, atol: Union[int, float]) -> bool:
        """See `cirq.protocols.SupportsApproximateEquality`."""
        if not isinstance(other, type(self)):
            return NotImplemented

        return protocols.approx_eq(
            sorted(self.operations, key=lambda op: op.qubits),
            sorted(other.operations, key=lambda op: op.qubits),
            atol=atol,
        )

    def __ne__(self, other) -> bool:
        return not self == other

    def __hash__(self):
        # Sorted for consistency with the order-insensitive `__eq__`.
        return hash((Moment, tuple(sorted(self.operations, key=lambda op: op.qubits))))

    def __iter__(self) -> Iterator['cirq.Operation']:
        return iter(self.operations)

    def __pow__(self, power):
        """Raises every operation to `power`.

        Returns NotImplemented if any operation cannot be raised to `power`.
        """
        if power == 1:
            return self
        new_ops = []
        for op in self.operations:
            new_op = protocols.pow(op, power, default=None)
            if new_op is None:
                return NotImplemented
            new_ops.append(new_op)
        return Moment(new_ops)

    def __len__(self) -> int:
        return len(self.operations)

    def __repr__(self) -> str:
        if not self.operations:
            return 'cirq.Moment()'

        block = '\n'.join([repr(op) + ',' for op in self.operations])
        indented = '    ' + '\n    '.join(block.split('\n'))

        return f'cirq.Moment(\n{indented}\n)'

    def __str__(self) -> str:
        return self.to_text_diagram()

    def _decompose_(self) -> 'cirq.OP_TREE':
        """See `cirq.SupportsDecompose`."""
        return self._operations

    def transform_qubits(
        self: TSelf_Moment,
        qubit_map: Union[Dict['cirq.Qid', 'cirq.Qid'], Callable[['cirq.Qid'], 'cirq.Qid']],
    ) -> TSelf_Moment:
        """Returns the same moment, but with different qubits.

        Args:
            qubit_map: A function or a dict mapping each current qubit into a
                desired new qubit.

        Returns:
            The receiving moment but with qubits transformed by the given
                function.
        """
        return self.__class__(op.transform_qubits(qubit_map) for op in self.operations)

    def _json_dict_(self) -> Dict[str, Any]:
        return protocols.obj_to_dict_helper(self, ['operations'])

    @classmethod
    def _from_json_dict_(cls, operations, **kwargs):
        return Moment(operations)

    def __add__(self, other: 'cirq.OP_TREE') -> 'cirq.Moment':
        """Returns a new moment with `other` added (see `with_operations`)."""
        from cirq.circuits import circuit

        if isinstance(other, circuit.AbstractCircuit):
            return NotImplemented  # Delegate to Circuit.__radd__.
        return self.with_operations(other)

    def __sub__(self, other: 'cirq.OP_TREE') -> 'cirq.Moment':
        """Returns a new moment with the operations in `other` removed.

        Raises:
            ValueError: If some operation in `other` is not in this moment.
        """
        from cirq.ops import op_tree

        must_remove = set(op_tree.flatten_to_ops(other))
        new_ops = []
        for op in self.operations:
            if op in must_remove:
                must_remove.remove(op)
            else:
                new_ops.append(op)
        if must_remove:
            raise ValueError(
                f"Subtracted missing operations from a moment.\n"
                f"Missing operations: {must_remove!r}\n"
                f"Moment: {self!r}"
            )
        return Moment(new_ops)

    # pylint: disable=function-redefined
    @overload
    def __getitem__(self, key: raw_types.Qid) -> 'cirq.Operation':
        pass

    @overload
    def __getitem__(self, key: Iterable[raw_types.Qid]) -> 'cirq.Moment':
        pass

    def __getitem__(self, key):
        """Returns the op on one qubit, or the sub-moment touching several qubits.

        Args:
            key: Either a single qubit, or an iterable of qubits.

        Returns:
            For a single qubit, the operation acting on it. For an iterable of
            qubits, a new Moment holding (in their original order) the
            operations that touch at least one of those qubits.

        Raises:
            KeyError: A single qubit was given that no operation acts on.
            TypeError: The key is neither a qubit nor an iterable of qubits.
        """
        if isinstance(key, raw_types.Qid):
            if key not in self._qubit_to_op:
                raise KeyError("Moment doesn't act on given qubit")
            return self._qubit_to_op[key]
        if isinstance(key, Iterable):
            qubits_to_keep = frozenset(key)
            # Walk the ordered operations (not the unordered qubit set) so the
            # result is deterministic and each multi-qubit op appears once.
            return Moment(
                op for op in self.operations if not qubits_to_keep.isdisjoint(op.qubits)
            )
        raise TypeError(
            f'Moment indices must be a qubit or an iterable of qubits, not {type(key)}'
        )

    # pylint: enable=function-redefined

    def to_text_diagram(
        self: 'cirq.Moment',
        *,
        xy_breakdown_func: Callable[['cirq.Qid'], Tuple[Any, Any]] = _default_breakdown,
        extra_qubits: Iterable['cirq.Qid'] = (),
        use_unicode_characters: bool = True,
        precision: Optional[int] = None,
        include_tags: bool = True,
    ) -> str:
        """Renders this moment as a 2D grid of qubits labelled with their ops.

        Args:
            xy_breakdown_func: A function to split qubits/qudits into x and y
                components. For example, the default breakdown turns
                `cirq.GridQubit(row, col)` into the tuple `(col, row)` and
                `cirq.LineQubit(x)` into `(x, 0)`.
            extra_qubits: Extra qubits/qudits to include in the diagram, even
                if they don't have any operations applied in the moment.
            use_unicode_characters: Whether or not the output should use fancy
                unicode characters or stick to plain ASCII. Unicode characters
                look nicer, but some environments don't draw them with the same
                width as ascii characters (which ruins the diagrams).
            precision: How precise numbers, such as angles, should be. Use None
                for infinite precision, or an integer for a certain number of
                digits of precision.
            include_tags: Whether or not to include operation tags in the
                diagram.

        Returns:
            The text diagram rendered into text.
        """
        from cirq.circuits.text_diagram_drawer import TextDiagramDrawer

        # Figure out where to place everything. The breakdown is computed once
        # per qubit and reused for both the axis keys and the final positions.
        qs = set(self.qubits) | set(extra_qubits)
        breakdowns = {q: xy_breakdown_func(q) for q in qs}
        points = set(breakdowns.values())
        x_keys = sorted({pt[0] for pt in points}, key=_SortByValFallbackToType)
        y_keys = sorted({pt[1] for pt in points}, key=_SortByValFallbackToType)
        # Offset by 2: row/column 0 holds the headers and 1 the separator lines.
        x_map = {x_key: x + 2 for x, x_key in enumerate(x_keys)}
        y_map = {y_key: y + 2 for y, y_key in enumerate(y_keys)}
        qubit_positions = {q: (x_map[a], y_map[b]) for q, (a, b) in breakdowns.items()}

        diagram = TextDiagramDrawer()

        def cleanup_key(key: Any) -> str:
            # Render integral floats (e.g. 2.0) as "2". `is_integer` is used
            # instead of `key == int(key)` because `int()` raises on inf/nan.
            if isinstance(key, float) and key.is_integer():
                return str(int(key))
            return str(key)

        # Add table headers.
        for key, x in x_map.items():
            diagram.write(x, 0, cleanup_key(key))
        for key, y in y_map.items():
            diagram.write(0, y, cleanup_key(key))
        diagram.horizontal_line(1, 0, len(x_map) + 2)
        diagram.vertical_line(1, 0, len(y_map) + 2)
        diagram.force_vertical_padding_after(0, 0)
        diagram.force_vertical_padding_after(1, 0)

        # Add operations.
        for op in self.operations:
            args = protocols.CircuitDiagramInfoArgs(
                known_qubits=op.qubits,
                known_qubit_count=len(op.qubits),
                use_unicode_characters=use_unicode_characters,
                qubit_map=None,
                precision=precision,
                include_tags=include_tags,
            )
            info = circuit_diagram_info_protocol._op_info_with_fallback(op, args=args)
            symbols = info._wire_symbols_including_formatted_exponent(args)
            for label, q in zip(symbols, op.qubits):
                x, y = qubit_positions[q]
                diagram.write(x, y, label)
            if info.connected:
                for q1, q2 in zip(op.qubits, op.qubits[1:]):
                    # Sort to get a more consistent orientation for diagonals.
                    # This reduces how often lines overlap in the diagram.
                    q1, q2 = sorted([q1, q2])

                    x1, y1 = qubit_positions[q1]
                    x2, y2 = qubit_positions[q2]
                    if x1 != x2:
                        diagram.horizontal_line(y1, x1, x2)
                    if y1 != y2:
                        diagram.vertical_line(x2, y1, y2)

        return diagram.render()

    def _commutes_(
        self, other: Any, *, atol: Union[int, float] = 1e-8
    ) -> Union[bool, NotImplementedType]:
        """Determines whether Moment commutes with the Operation.

        Args:
            other: An Operation object. Other types are not implemented yet.
                In case a different type is specified, NotImplemented is
                returned.
            atol: Absolute error tolerance. If all entries in v1@v2 - v2@v1
                have a magnitude less than this tolerance, v1 and v2 can be
                reported as commuting. Defaults to 1e-8.

        Returns:
            True: The Moment and Operation commute OR they don't have shared
                qubits.
            False: The two values do not commute.
            NotImplemented: In case we don't know how to check this, e.g.
                the parameter type is not supported yet.
        """
        if not isinstance(other, ops.Operation):
            return NotImplemented

        other_qubits = frozenset(other.qubits)
        for op in self.operations:
            # Ops sharing no qubits with `other` are treated as commuting.
            if other_qubits.isdisjoint(op.qubits):
                continue

            commutes = protocols.commutes(op, other, atol=atol, default=NotImplemented)

            # Test the sentinel before truthiness: evaluating NotImplemented in
            # a boolean context is deprecated (and a TypeError in Python 3.14).
            if commutes is NotImplemented or not commutes:
                return commutes

        return True
501
502
503class _SortByValFallbackToType:
504    def __init__(self, value):
505        self.value = value
506
507    def __lt__(self, other):
508        try:
509            return self.value < other.value
510        except TypeError:
511            t1 = type(self.value)
512            t2 = type(other.value)
513            return str(t1) < str(t2)
514