1# Copyright 2018 The Cirq Developers
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
import collections
from typing import Dict, List, Optional, Tuple

from cirq.circuits._box_drawing_character_data import box_draw_character, BoxDrawCharacterSet
20
21
class Block:
    """One mutable cell of a block diagram.

    A block carries optional text (``content``) and up to four line legs
    (``left``, ``right``, ``top``, ``bottom``) that meet at ``center``. The
    owning diagram picks the block's size and then calls ``render``.
    """

    def __init__(self):
        # Leg characters; an empty string means the leg is not drawn.
        self.left = ''
        self.right = ''
        self.top = ''
        self.bottom = ''
        # Junction character where the legs meet (maintained by draw_curve).
        self.center = ''
        # Text placed at the anchor point; may span several lines.
        self.content = ''
        # Anchor column as a fraction of (width - 1): 0 is the leftmost
        # column, 1 the rightmost.
        self.horizontal_alignment = 0
        # Character set passed to the previous draw_curve call, if any.
        self._prev_curve_grid_chars = None  # type: Optional[BoxDrawCharacterSet]

    def min_width(self) -> int:
        """Returns the fewest columns that can hold this block's contents.

        This is the length of the longest content line, raised to 1 when a
        vertical (top or bottom) leg must pass through the block.
        """
        widest = max(len(line) for line in self.content.split('\n'))
        # Only horizontal lines can cross a zero-width block.
        has_vertical_leg = bool(self.top or self.bottom)
        return max(widest, int(has_vertical_leg))

    def min_height(self) -> int:
        """Returns the fewest rows that can hold this block's contents.

        This is the number of content lines (0 when there is no content),
        raised to 1 when a horizontal (left or right) leg must pass through.
        """
        line_count = self.content.count('\n') + 1 if self.content else 0
        # Only vertical lines can cross a zero-height block.
        has_horizontal_leg = bool(self.left or self.right)
        return max(line_count, int(has_horizontal_leg))

    def draw_curve(
        self,
        grid_characters: BoxDrawCharacterSet,
        *,
        top: bool = False,
        left: bool = False,
        right: bool = False,
        bottom: bool = False,
        crossing_char: Optional[str] = None,
    ):
        """Adds line legs to the block using the given character set.

        Legs from an earlier draw_curve call are kept. The center character
        is recomputed so that old and new legs join up, even when the two
        calls used different character sets (provided a joining character
        exists for that combination).

        Args:
            grid_characters: The character set to draw the new legs with.
            top: Whether to draw the upward leg.
            left: Whether to draw the leftward leg.
            right: Whether to draw the rightward leg.
            bottom: Whether to draw the downward leg.
            crossing_char: Center character to use when all four legs are
                present. Handy for ASCII diagrams, where '+' is not always
                the clearest choice.
        """
        if not (top or left or right or bottom):
            return

        def sign(drawing_now: bool, existing: str) -> int:
            # +1 = leg added by this call, -1 = leg kept from before, 0 = none.
            if drawing_now:
                return +1
            return -1 if existing else 0

        # Classify every leg before any of them is overwritten below.
        top_sign = sign(top, self.top)
        bottom_sign = sign(bottom, self.bottom)
        left_sign = sign(left, self.left)
        right_sign = sign(right, self.right)

        if top or bottom:
            vertical = grid_characters.top_bottom
            if top:
                self.top = vertical
            if bottom:
                self.bottom = vertical
        if left or right:
            horizontal = grid_characters.left_right
            if left:
                self.left = horizontal
            if right:
                self.right = horizontal

        # The caller's crossing character is honored only for a full four-way
        # junction; otherwise look up a character that joins the legs present.
        four_way = all((self.top, self.bottom, self.left, self.right))
        if not (crossing_char and four_way):
            crossing_char = box_draw_character(
                self._prev_curve_grid_chars,
                grid_characters,
                top=top_sign,
                bottom=bottom_sign,
                left=left_sign,
                right=right_sign,
            )
        # Fall back to '' if no junction character was produced.
        self.center = crossing_char or ''
        self._prev_curve_grid_chars = grid_characters

    def render(self, width: int, height: int) -> List[str]:
        """Draws the block into a width-by-height rectangle of text.

        Args:
            width: Number of columns to produce. Must be at least
                ``min_width()``.
            height: Number of rows to produce. Must be at least
                ``min_height()``.

        Returns:
            The rendered rows, one string per line of output.
        """
        if not (width and height):
            return [''] * height

        canvas = [[' '] * width for _ in range(height)]
        # Anchor cell: where the legs meet and where content is aligned.
        anchor_x = int((width - 1) * self.horizontal_alignment)
        anchor_y = (height - 1) // 2

        # Horizontal legs run along the anchor row and overlap at the anchor.
        if self.left:
            canvas[anchor_y][: anchor_x + 1] = self.left * (anchor_x + 1)
        if self.right:
            canvas[anchor_y][anchor_x:] = self.right * (width - anchor_x)

        # Vertical legs run along the anchor column and overlap at the anchor.
        if self.top:
            for row in canvas[: anchor_y + 1]:
                row[anchor_x] = self.top
        if self.bottom:
            for row in canvas[anchor_y:]:
                row[anchor_x] = self.bottom

        # Content (or, failing that, the junction character) is centered
        # vertically on the anchor row. Each line is positioned horizontally
        # on the anchor column according to horizontal_alignment.
        text = self.content or self.center
        if text:
            text_lines = text.split('\n')
            first_row = anchor_y - (len(text_lines) - 1) // 2
            for row_index, text_line in enumerate(text_lines, start=first_row):
                left_edge = anchor_x - int((len(text_line) - 1) * self.horizontal_alignment)
                for col_index, ch in enumerate(text_line, start=left_edge):
                    canvas[row_index][col_index] = ch

        return list(map(''.join, canvas))
155
156
class BlockDiagramDrawer:
    """Aligns text and curve data placed onto an abstract 2d grid of blocks.

    Blocks are addressed by non-negative integer coordinates (x, y): x picks
    a column and y picks a row. When rendered, every column is as wide as its
    widest block and every row is as tall as its tallest block. Explicit
    per-column, per-row and global minimums can raise those sizes.
    """

    def __init__(self):
        # Blocks are created lazily the first time a coordinate is accessed.
        self._blocks = collections.defaultdict(Block)  # type: Dict[Tuple[int, int], Block]
        # Explicit per-column / per-row minimum sizes; unset entries read as 0.
        self._min_widths = collections.defaultdict(int)  # type: Dict[int, int]
        self._min_heights = collections.defaultdict(int)  # type: Dict[int, int]

        # Populate the origin. This keeps the default render span at least
        # 1x1 and ensures the max() calls over these keys in render() never
        # see an empty collection.
        _ = self._blocks[(0, 0)]
        _ = self._min_widths[0]
        _ = self._min_heights[0]

    def mutable_block(self, x: int, y: int) -> Block:
        """Returns the block at (x, y) so it can be edited.

        A fresh empty block is created if (x, y) has not been accessed yet.
        Accessed blocks count toward the default render span.

        Args:
            x: Column index. Must be non-negative.
            y: Row index. Must be non-negative.

        Raises:
            IndexError: If x or y is negative.
        """
        if x < 0 or y < 0:
            raise IndexError('x < 0 or y < 0')
        return self._blocks[(x, y)]

    def set_col_min_width(self, x: int, min_width: int):
        """Sets a minimum width for blocks in the column with coordinate x.

        Raises:
            IndexError: If x is negative.
        """
        if x < 0:
            raise IndexError('x < 0')
        self._min_widths[x] = min_width

    def set_row_min_height(self, y: int, min_height: int):
        """Sets a minimum height for blocks in the row with coordinate y.

        Raises:
            IndexError: If y is negative.
        """
        if y < 0:
            raise IndexError('y < 0')
        self._min_heights[y] = min_height

    def render(
        self,
        *,
        block_span_x: Optional[int] = None,
        block_span_y: Optional[int] = None,
        min_block_width: int = 0,
        min_block_height: int = 0,
    ) -> str:
        """Outputs text containing the diagram.

        Args:
            block_span_x: The width of the diagram in blocks. Set to None to
                default to using the smallest width that would include all
                accessed blocks and columns with a specified minimum width.
                Columns at or beyond this span are not drawn. A span of 0
                produces no columns.
            block_span_y: The height of the diagram in blocks. Set to None to
                default to using the smallest height that would include all
                accessed blocks and rows with a specified minimum height.
                Rows at or beyond this span are not drawn. A span of 0
                produces no rows.
            min_block_width: A global minimum width for all blocks.
            min_block_height: A global minimum height for all blocks.

        Returns:
            The diagram as a string, with trailing whitespace stripped from
            every line.
        """
        # Determine desired size of diagram in blocks.
        if block_span_x is None:
            block_span_x = 1 + max(
                max(x for x, _ in self._blocks),
                max(self._min_widths),
            )
        if block_span_y is None:
            block_span_y = 1 + max(
                max(y for _, y in self._blocks),
                max(self._min_heights),
            )
        cols = range(block_span_x)
        rows = range(block_span_y)

        # Read blocks without creating new entries in the defaultdict.
        empty = Block()

        def block(x: int, y: int) -> Block:
            return self._blocks.get((x, y), empty)

        # Determine the width of every column and the height of every row.
        # default=0 keeps a zero-sized span in the other dimension from
        # raising ValueError on an empty max().
        widths = {
            x: max(
                max((block(x, y).min_width() for y in rows), default=0),
                self._min_widths.get(x, 0),
                min_block_width,
            )
            for x in cols
        }
        heights = {
            y: max(
                max((block(x, y).min_height() for x in cols), default=0),
                self._min_heights.get(y, 0),
                min_block_height,
            )
            for y in rows
        }

        # Get the individually rendered blocks.
        block_renders = {
            (x, y): block(x, y).render(widths[x], heights[y]) for x in cols for y in rows
        }

        # Paste together the matching line of every block in a row of blocks.
        out_lines = []  # type: List[str]
        for y in rows:
            for by in range(heights[y]):
                line = ''.join(block_renders[x, y][by] for x in cols)
                out_lines.append(line.rstrip())

        # Then paste together the rows.
        return '\n'.join(out_lines)
266