1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import (
16    Any,
17    Callable,
18    cast,
19    Dict,
20    Iterable,
21    List,
22    Mapping,
23    NamedTuple,
24    Optional,
25    Sequence,
26    Tuple,
27    Union,
28)
29
30import numpy as np
31
32from cirq import value
33from cirq.circuits._block_diagram_drawer import BlockDiagramDrawer
34from cirq.circuits._box_drawing_character_data import (
35    BoxDrawCharacterSet,
36    NORMAL_BOX_CHARS,
37    BOLD_BOX_CHARS,
38    ASCII_BOX_CHARS,
39)
40
# A horizontal line segment on row ``y`` running from column ``x1`` to ``x2``.
# Coordinates may be fractional; `TextDiagramDrawer.render` doubles them and
# truncates to int, so e.g. 0.5 lands in the gap between cells 0 and 1.
# ``emphasize`` makes `pick_charset` choose the bold charset (Unicode only).
# The functional NamedTuple form keeps the public-looking typename
# 'HorizontalLine' in the repr while binding a private module-level name.
_HorizontalLine = NamedTuple(
    'HorizontalLine',
    [
        ('y', Union[int, float]),
        ('x1', Union[int, float]),
        ('x2', Union[int, float]),
        ('emphasize', bool),
    ],
)
# A vertical line segment in column ``x`` running from row ``y1`` to ``y2``.
# Same coordinate and ``emphasize`` conventions as `_HorizontalLine`; note that
# `TextDiagramDrawer.transpose` converts between the two field-for-field.
_VerticalLine = NamedTuple(
    'VerticalLine',
    [
        ('x', Union[int, float]),
        ('y1', Union[int, float]),
        ('y2', Union[int, float]),
        ('emphasize', bool),
    ],
)
# The text stored in one diagram cell. ``text`` is what `render` draws;
# ``transposed_text`` is the alternative that `transpose` swaps in as the new
# ``text`` (and vice versa).
_DiagramText = NamedTuple(
    'DiagramText',
    [
        ('text', str),
        ('transposed_text', str),
    ],
)
66
67
def pick_charset(use_unicode: bool, emphasize: bool) -> BoxDrawCharacterSet:
    """Selects the box-drawing character set for a line.

    Args:
        use_unicode: When falsy, plain ASCII characters are used and
            ``emphasize`` is ignored.
        emphasize: When truthy (and Unicode is enabled), the bold character
            set is used instead of the normal one.

    Returns:
        One of ``ASCII_BOX_CHARS``, ``BOLD_BOX_CHARS`` or ``NORMAL_BOX_CHARS``.
    """
    if use_unicode:
        return BOLD_BOX_CHARS if emphasize else NORMAL_BOX_CHARS
    # ASCII has no bold variant, so emphasis cannot be expressed.
    return ASCII_BOX_CHARS
74
75
@value.value_equality(unhashable=True)
class TextDiagramDrawer:
    """A utility class for creating simple text diagrams.

    A diagram is a grid addressed by ``(x, y)`` coordinates, where ``x`` is
    the column and ``y`` is the row. It holds:

    * ``entries``: text written into individual cells, keyed by ``(x, y)``.
    * ``horizontal_lines`` / ``vertical_lines``: line segments between cells.
      Their coordinates may be fractional (e.g. ``0.5`` lies between cells
      0 and 1).
    * ``horizontal_padding`` / ``vertical_padding``: per-column / per-row
      overrides of the spacing inserted after that column / row when
      rendering.

    Equality (via ``value.value_equality``) is based on these five
    collections; see ``_value_equality_values_``.
    """

    def __init__(
        self,
        entries: Optional[Mapping[Tuple[int, int], _DiagramText]] = None,
        horizontal_lines: Optional[Iterable[_HorizontalLine]] = None,
        vertical_lines: Optional[Iterable[_VerticalLine]] = None,
        horizontal_padding: Optional[Mapping[int, Union[int, float]]] = None,
        vertical_padding: Optional[Mapping[int, Union[int, float]]] = None,
    ) -> None:
        """Initializes the drawer, copying any provided collections.

        Args:
            entries: Initial text entries, keyed by ``(x, y)``.
            horizontal_lines: Initial horizontal line segments.
            vertical_lines: Initial vertical line segments.
            horizontal_padding: Initial padding overrides after columns.
            vertical_padding: Initial padding overrides after rows.
        """
        # Copy the inputs so in-place operations on this drawer (write,
        # superimpose, shift, ...) never mutate the caller's containers.
        self.entries: Dict[Tuple[int, int], _DiagramText] = (
            dict() if entries is None else dict(entries)
        )
        self.horizontal_lines: List[_HorizontalLine] = (
            [] if horizontal_lines is None else list(horizontal_lines)
        )
        self.vertical_lines: List[_VerticalLine] = (
            [] if vertical_lines is None else list(vertical_lines)
        )
        self.horizontal_padding: Dict[int, Union[int, float]] = (
            dict() if horizontal_padding is None else dict(horizontal_padding)
        )
        self.vertical_padding: Dict[int, Union[int, float]] = (
            dict() if vertical_padding is None else dict(vertical_padding)
        )

    def _value_equality_values_(self):
        """Returns the collections that define this drawer's value."""
        attrs = (
            'entries',
            'horizontal_lines',
            'vertical_lines',
            'horizontal_padding',
            'vertical_padding',
        )
        return tuple(getattr(self, attr) for attr in attrs)

    def __bool__(self) -> bool:
        """True if any entry, line, or padding override is present."""
        return any(self._value_equality_values_())

    def write(self, x: int, y: int, text: str, transposed_text: Optional[str] = None):
        """Appends text to the cell at the given location.

        Text already present at ``(x, y)`` is kept; the new text is appended
        to it (and likewise for the transposed text).

        Args:
            x: The column in which to write the text.
            y: The row in which to write the text.
            text: The text to write at location (x, y).
            transposed_text: Optional text to write instead, if the text
                diagram is transposed. A falsy value (``None`` or ``''``)
                means ``text`` is used for the transposed form as well.
        """
        entry = self.entries.get((x, y), _DiagramText('', ''))
        self.entries[(x, y)] = _DiagramText(
            entry.text + text,
            entry.transposed_text + (transposed_text if transposed_text else text),
        )

    def content_present(self, x: int, y: int) -> bool:
        """Determines if a line or printed text is at the given location.

        A line only counts at cells strictly between its two endpoints; the
        endpoints themselves are not considered covered.
        """

        # Text?
        if (x, y) in self.entries:
            return True

        # Vertical line?
        if any(v.x == x and v.y1 < y < v.y2 for v in self.vertical_lines):
            return True

        # Horizontal line?
        if any(line_y == y and x1 < x < x2 for line_y, x1, x2, _ in self.horizontal_lines):
            return True

        return False

    def grid_line(self, x1: int, y1: int, x2: int, y2: int, emphasize: bool = False):
        """Adds a vertical or horizontal line from (x1, y1) to (x2, y2).

        Horizontal line is selected on equality in the second coordinate and
        vertical line is selected on equality in the first coordinate. When
        both coordinates are equal, a (degenerate) vertical line is added.

        Raises:
            ValueError: If line is neither horizontal nor vertical.
        """
        if x1 == x2:
            self.vertical_line(x1, y1, y2, emphasize)
        elif y1 == y2:
            self.horizontal_line(y1, x1, x2, emphasize)
        else:
            raise ValueError("Line is neither horizontal nor vertical")

    def vertical_line(
        self,
        x: Union[int, float],
        y1: Union[int, float],
        y2: Union[int, float],
        emphasize: bool = False,
    ) -> None:
        """Adds a line from (x, y1) to (x, y2), stored with y1 <= y2."""
        y1, y2 = sorted([y1, y2])
        self.vertical_lines.append(_VerticalLine(x, y1, y2, emphasize))

    def horizontal_line(
        self,
        y: Union[int, float],
        x1: Union[int, float],
        x2: Union[int, float],
        emphasize: bool = False,
    ) -> None:
        """Adds a line from (x1, y) to (x2, y), stored with x1 <= x2."""
        x1, x2 = sorted([x1, x2])
        self.horizontal_lines.append(_HorizontalLine(y, x1, x2, emphasize))

    def transpose(self) -> 'TextDiagramDrawer':
        """Returns the same diagram, but mirrored across its diagonal.

        Rows and columns are swapped, each entry's ``text`` and
        ``transposed_text`` are exchanged, horizontal lines become vertical
        lines (and vice versa), and row/column paddings are exchanged.
        """
        out = TextDiagramDrawer()
        out.entries = {
            (y, x): _DiagramText(v.transposed_text, v.text) for (x, y), v in self.entries.items()
        }
        # Field order (y, x1, x2, emph) maps directly onto (x, y1, y2, emph).
        out.vertical_lines = [_VerticalLine(*e) for e in self.horizontal_lines]
        out.horizontal_lines = [_HorizontalLine(*e) for e in self.vertical_lines]
        out.vertical_padding = self.horizontal_padding.copy()
        out.horizontal_padding = self.vertical_padding.copy()
        return out

    def width(self) -> int:
        """Determines how many entry columns are in the diagram.

        This is one more than the largest column used by any entry or line
        endpoint (truncated toward zero), or 0 for an empty diagram.
        """
        max_x = -1.0
        for x, _ in self.entries.keys():
            max_x = max(max_x, x)
        for v in self.vertical_lines:
            max_x = max(max_x, v.x)
        for h in self.horizontal_lines:
            max_x = max(max_x, h.x1, h.x2)
        return 1 + int(max_x)

    def height(self) -> int:
        """Determines how many entry rows are in the diagram.

        This is one more than the largest row used by any entry or line
        endpoint (truncated toward zero), or 0 for an empty diagram.
        """
        max_y = -1.0
        for _, y in self.entries.keys():
            max_y = max(max_y, y)
        for h in self.horizontal_lines:
            max_y = max(max_y, h.y)
        for v in self.vertical_lines:
            max_y = max(max_y, v.y1, v.y2)
        return 1 + int(max_y)

    def force_horizontal_padding_after(self, index: int, padding: Union[int, float]) -> None:
        """Change the padding after the given column."""
        self.horizontal_padding[index] = padding

    def force_vertical_padding_after(self, index: int, padding: Union[int, float]) -> None:
        """Change the padding after the given row."""
        self.vertical_padding[index] = padding

    def _transform_coordinates(
        self,
        func: Callable[
            [Union[int, float], Union[int, float]], Tuple[Union[int, float], Union[int, float]]
        ],
    ) -> None:
        """Helper method to transform either row or column coordinates, in place.

        ``func`` maps an ``(x, y)`` pair to a new ``(x, y)`` pair. For lines
        and padding keys it is evaluated with the other coordinate fixed at
        0, so it must transform x and y independently of each other.
        """

        def func_x(x: Union[int, float]) -> Union[int, float]:
            return func(x, 0)[0]

        def func_y(y: Union[int, float]) -> Union[int, float]:
            return func(0, y)[1]

        self.entries = {
            cast(Tuple[int, int], func(int(x), int(y))): v for (x, y), v in self.entries.items()
        }
        self.vertical_lines = [
            _VerticalLine(func_x(x), func_y(y1), func_y(y2), emph)
            for x, y1, y2, emph in self.vertical_lines
        ]
        self.horizontal_lines = [
            _HorizontalLine(func_y(y), func_x(x1), func_x(x2), emph)
            for y, x1, x2, emph in self.horizontal_lines
        ]
        self.horizontal_padding = {
            int(func_x(int(x))): padding for x, padding in self.horizontal_padding.items()
        }
        self.vertical_padding = {
            int(func_y(int(y))): padding for y, padding in self.vertical_padding.items()
        }

    def insert_empty_columns(self, x: int, amount: int = 1) -> None:
        """Inserts ``amount`` empty columns at column index ``x``.

        Everything at column ``x`` or later (entries, vertical lines,
        horizontal line endpoints and padding overrides) is shifted right by
        ``amount``. The new empty columns therefore sit between column
        ``x - 1`` and the content that used to be at column ``x``; horizontal
        lines crossing ``x`` are stretched.
        """

        def transform_columns(
            column: Union[int, float], row: Union[int, float]
        ) -> Tuple[Union[int, float], Union[int, float]]:
            return column + (amount if column >= x else 0), row

        self._transform_coordinates(transform_columns)

    def insert_empty_rows(self, y: int, amount: int = 1) -> None:
        """Inserts ``amount`` empty rows at row index ``y``.

        Everything at row ``y`` or later (entries, horizontal lines, vertical
        line endpoints and padding overrides) is shifted down by ``amount``.
        The new empty rows therefore sit between row ``y - 1`` and the
        content that used to be at row ``y``; vertical lines crossing ``y``
        are stretched.
        """

        def transform_rows(
            column: Union[int, float], row: Union[int, float]
        ) -> Tuple[Union[int, float], Union[int, float]]:
            return column, row + (amount if row >= y else 0)

        self._transform_coordinates(transform_rows)

    def render(
        self,
        horizontal_spacing: int = 1,
        vertical_spacing: int = 1,
        crossing_char: Optional[str] = None,
        use_unicode_characters: bool = True,
    ) -> str:
        """Outputs text containing the diagram.

        Diagram cell ``(x, y)`` maps to block ``(2x, 2y)`` of a
        ``BlockDiagramDrawer``. The odd block columns/rows between cells
        receive the padding, so fractional line coordinates such as ``0.5``
        land in those gaps.

        Args:
            horizontal_spacing: Padding after each column that has no entry
                in ``horizontal_padding``. Fractional paddings round up.
            vertical_spacing: Padding after each row that has no entry in
                ``vertical_padding``. Fractional paddings round down.
            crossing_char: Optional character forwarded to the block drawer
                for the interior segments of horizontal lines (presumably
                drawn where they cross other lines).
            use_unicode_characters: Whether to draw lines with Unicode box
                characters (bold when emphasized) instead of ASCII.

        Returns:
            The rendered diagram as a string.
        """

        block_diagram = BlockDiagramDrawer()

        w = self.width()
        h = self.height()

        # Communicate padding into block diagram. No padding is added after
        # the last column or row.
        for x in range(0, w - 1):
            block_diagram.set_col_min_width(
                x * 2 + 1,
                # Horizontal separation looks narrow, so partials round up.
                int(np.ceil(self.horizontal_padding.get(x, horizontal_spacing))),
            )
            block_diagram.set_col_min_width(x * 2, 1)
        for y in range(0, h - 1):
            block_diagram.set_row_min_height(
                y * 2 + 1,
                # Vertical separation looks wide, so partials round down.
                int(np.floor(self.vertical_padding.get(y, vertical_spacing))),
            )
            block_diagram.set_row_min_height(y * 2, 1)

        # Draw vertical lines.
        for x_b, y1_b, y2_b, emphasize in self.vertical_lines:
            x = int(x_b * 2)
            y1, y2 = int(min(y1_b, y2_b) * 2), int(max(y1_b, y2_b) * 2)
            charset = pick_charset(use_unicode_characters, emphasize)

            # Caps.
            block_diagram.mutable_block(x, y1).draw_curve(charset, bottom=True)
            block_diagram.mutable_block(x, y2).draw_curve(charset, top=True)

            # Span.
            for y in range(y1 + 1, y2):
                block_diagram.mutable_block(x, y).draw_curve(charset, top=True, bottom=True)

        # Draw horizontal lines.
        for y_b, x1_b, x2_b, emphasize in self.horizontal_lines:
            y = int(y_b * 2)
            x1, x2 = int(min(x1_b, x2_b) * 2), int(max(x1_b, x2_b) * 2)
            charset = pick_charset(use_unicode_characters, emphasize)

            # Caps.
            block_diagram.mutable_block(x1, y).draw_curve(charset, right=True)
            block_diagram.mutable_block(x2, y).draw_curve(charset, left=True)

            # Span.
            for x in range(x1 + 1, x2):
                block_diagram.mutable_block(x, y).draw_curve(
                    charset, left=True, right=True, crossing_char=crossing_char
                )

        # Place entries.
        for (x, y), v in self.entries.items():
            x *= 2
            y *= 2
            block_diagram.mutable_block(x, y).content = v.text

        return block_diagram.render()

    def copy(self) -> 'TextDiagramDrawer':
        """Returns a copy whose containers are independent of this drawer's.

        The contained NamedTuple items are immutable and shared.
        """
        return self.__class__(
            entries=self.entries,
            vertical_lines=self.vertical_lines,
            horizontal_lines=self.horizontal_lines,
            vertical_padding=self.vertical_padding,
            horizontal_padding=self.horizontal_padding,
        )

    def shift(self, dx: int = 0, dy: int = 0) -> 'TextDiagramDrawer':
        """Moves all content by ``(dx, dy)`` in place and returns ``self``."""
        self._transform_coordinates(lambda x, y: (x + dx, y + dy))
        return self

    def shifted(self, dx: int = 0, dy: int = 0) -> 'TextDiagramDrawer':
        """Returns a shifted copy, leaving this drawer unchanged."""
        return self.copy().shift(dx, dy)

    def superimpose(self, other: 'TextDiagramDrawer') -> 'TextDiagramDrawer':
        """Merges ``other`` into this drawer in place and returns ``self``.

        Entries and padding overrides from ``other`` replace this drawer's at
        the same key; lines from both drawers are kept.
        """
        self.entries.update(other.entries)
        self.horizontal_lines += other.horizontal_lines
        self.vertical_lines += other.vertical_lines
        self.horizontal_padding.update(other.horizontal_padding)
        self.vertical_padding.update(other.vertical_padding)
        return self

    def superimposed(self, other: 'TextDiagramDrawer') -> 'TextDiagramDrawer':
        """Returns a copy of this drawer with ``other`` superimposed on it."""
        return self.copy().superimpose(other)

    @classmethod
    def vstack(
        cls,
        diagrams: Sequence['TextDiagramDrawer'],
        padding_resolver: Optional[Callable[[Sequence[Optional[int]]], Optional[int]]] = None,
    ) -> 'TextDiagramDrawer':
        """Vertically stack text diagrams.

        Args:
            diagrams: The diagrams to stack, ordered from top to bottom. The
                first diagram keeps its row indices (starting at row 0, the
                first rendered line), and each subsequent diagram is shifted
                down by the total height of the diagrams before it.
            padding_resolver: A function that takes a list of paddings
                specified for a column (``None`` for diagrams that specify
                none) and returns the padding to use in the stacked diagram.
                If it returns ``None``, the padding from the last diagram
                that specified one for that column is kept. If None,
                defaults to raising ValueError if the diagrams to stack
                contain inconsistent padding in any column, including if
                some specify a padding and others don't.

        Raises:
            ValueError: Inconsistent padding cannot be resolved.

        Returns:
            The vertically stacked diagram.
        """

        if padding_resolver is None:
            padding_resolver = _same_element_or_throw_error

        stacked = cls()
        dy = 0
        for diagram in diagrams:
            stacked.superimpose(diagram.shifted(dy=dy))
            dy += diagram.height()
        # Column paddings are not shifted by stacking, so each column's
        # paddings from all diagrams must be reconciled.
        for x in stacked.horizontal_padding:
            resolved_padding = padding_resolver(
                tuple(
                    cast(Optional[int], diagram.horizontal_padding.get(x)) for diagram in diagrams
                )
            )
            if resolved_padding is not None:
                stacked.horizontal_padding[x] = resolved_padding
        return stacked

    @classmethod
    def hstack(
        cls,
        diagrams: Sequence['TextDiagramDrawer'],
        padding_resolver: Optional[Callable[[Sequence[Optional[int]]], Optional[int]]] = None,
    ) -> 'TextDiagramDrawer':
        """Horizontally stack text diagrams.

        Args:
            diagrams: The diagrams to stack, ordered from left to right. The
                first diagram keeps its column indices, and each subsequent
                diagram is shifted right by the total width of the diagrams
                before it.
            padding_resolver: A function that takes a list of paddings
                specified for a row (``None`` for diagrams that specify none)
                and returns the padding to use in the stacked diagram. If it
                returns ``None``, the padding from the last diagram that
                specified one for that row is kept. Defaults to raising
                ValueError if the diagrams to stack contain inconsistent
                padding in any row, including if some specify a padding and
                others don't.

        Raises:
            ValueError: Inconsistent padding cannot be resolved.

        Returns:
            The horizontally stacked diagram.
        """

        if padding_resolver is None:
            padding_resolver = _same_element_or_throw_error

        stacked = cls()
        dx = 0
        for diagram in diagrams:
            stacked.superimpose(diagram.shifted(dx=dx))
            dx += diagram.width()
        # Row paddings are not shifted by stacking, so each row's paddings
        # from all diagrams must be reconciled.
        for y in stacked.vertical_padding:
            resolved_padding = padding_resolver(
                tuple(cast(Optional[int], diagram.vertical_padding.get(y)) for diagram in diagrams)
            )
            if resolved_padding is not None:
                stacked.vertical_padding[y] = resolved_padding
        return stacked
456
457
458def _same_element_or_throw_error(elements: Sequence[Any]):
459    """Extract an element or throw an error.
460
461    Args:
462        elements: A sequence of something.
463
464    copies of it. Returns None on an empty sequence.
465
466    Raises:
467        ValueError: The sequence contains more than one unique element.
468
469    Returns:
470        The element when given a sequence containing only multiple copies of a
471        single element. None if elements is empty.
472    """
473    unique_elements = set(elements)
474    if len(unique_elements) > 1:
475        raise ValueError(f'len(set({elements})) > 1')
476    return unique_elements.pop() if elements else None
477