1from collections import defaultdict
2from itertools import chain
3from operator import itemgetter
4from typing import Dict, Iterable, List, Optional, Tuple
5
6from .align import Align, AlignMethod
7from .console import Console, ConsoleOptions, RenderableType, RenderResult
8from .constrain import Constrain
9from .measure import Measurement
10from .padding import Padding, PaddingDimensions
11from .table import Table
12from .text import TextType
13from .jupyter import JupyterMixin
14
15
class Columns(JupyterMixin):
    """Display renderables in neat columns.

    Args:
        renderables (Iterable[RenderableType], optional): Any number of Rich renderables (including str).
        padding (PaddingDimensions, optional): Optional padding around cells. Defaults to (0, 1).
        width (int, optional): The desired width of the columns, or None to auto detect. Defaults to None.
        expand (bool, optional): Expand columns to full width. Defaults to False.
        equal (bool, optional): Arrange into equal sized columns. Defaults to False.
        column_first (bool, optional): Align items from top to bottom (rather than left to right). Defaults to False.
        right_to_left (bool, optional): Start column from right hand side. Defaults to False.
        align (str, optional): Align value ("left", "right", or "center") or None for default. Defaults to None.
        title (TextType, optional): Optional title for Columns. Defaults to None.
    """

    def __init__(
        self,
        renderables: Optional[Iterable[RenderableType]] = None,
        padding: PaddingDimensions = (0, 1),
        *,
        width: Optional[int] = None,
        expand: bool = False,
        equal: bool = False,
        column_first: bool = False,
        right_to_left: bool = False,
        align: Optional[AlignMethod] = None,
        title: Optional[TextType] = None,
    ) -> None:
        # Copy into a fresh list so add_renderable() never mutates the caller's object.
        self.renderables = list(renderables or [])
        self.width = width
        self.padding = padding
        self.expand = expand
        self.equal = equal
        self.column_first = column_first
        self.right_to_left = right_to_left
        self.align = align
        self.title = title

    def add_renderable(self, renderable: RenderableType) -> None:
        """Add a renderable to the columns.

        Args:
            renderable (RenderableType): Any renderable object.
        """
        self.renderables.append(renderable)

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:
        """Lay the renderables out in a grid and yield it as a single table.

        The number of columns is either derived from ``self.width`` or found by
        starting with one column per renderable and reducing the count until
        the widest cell of every column (plus padding) fits in
        ``options.max_width``. At least one column is always used.
        """
        render_str = console.render_str
        renderables = [
            render_str(renderable) if isinstance(renderable, str) else renderable
            for renderable in self.renderables
        ]
        if not renderables:
            return
        _top, right, _bottom, left = Padding.unpack(self.padding)
        # The wider of the left/right padding is used as the gap between
        # adjacent columns (the grid below is built with collapse_padding=True).
        width_padding = max(left, right)
        max_width = options.max_width
        widths: Dict[int, int] = defaultdict(int)
        column_count = len(renderables)

        get_measurement = Measurement.get
        renderable_widths = [
            get_measurement(console, options, renderable).maximum
            for renderable in renderables
        ]
        if self.equal:
            # Treat every cell as being as wide as the widest renderable.
            renderable_widths = [max(renderable_widths)] * len(renderable_widths)

        def iter_renderables(
            column_count: int,
        ) -> Iterable[Tuple[int, Optional[RenderableType]]]:
            """Yield (width, renderable) pairs in row-major grid order.

            The final row is padded with ``(0, None)`` so the number of items
            yielded is always a multiple of ``column_count``.
            """
            item_count = len(renderables)
            if self.column_first:
                width_renderables = list(zip(renderable_widths, renderables))

                # Share the items between columns; the leftmost columns take the remainder.
                column_lengths: List[int] = [item_count // column_count] * column_count
                for col_no in range(item_count % column_count):
                    column_lengths[col_no] += 1

                # Fill an index grid top-to-bottom, one column at a time.
                row_count = (item_count + column_count - 1) // column_count
                cells = [[-1] * column_count for _ in range(row_count)]
                row = col = 0
                for index in range(item_count):
                    cells[row][col] = index
                    column_lengths[col] -= 1
                    if column_lengths[col]:
                        row += 1
                    else:
                        col += 1
                        row = 0
                # Read the grid back row by row. Empty slots (-1) can only be at
                # the end of the last row, so the first one ends the items.
                for index in chain.from_iterable(cells):
                    if index == -1:
                        break
                    yield width_renderables[index]
            else:
                yield from zip(renderable_widths, renderables)
            # Pad the last row with empty cells so every row is complete.
            if item_count % column_count:
                for _ in range(column_count - (item_count % column_count)):
                    yield 0, None

        table = Table.grid(padding=self.padding, collapse_padding=True, pad_edge=False)
        table.expand = self.expand
        table.title = self.title

        if self.width is not None:
            # Fixed column width: fit as many columns as the available width allows.
            # Guard the divisor so width=0 with zero padding cannot divide by zero.
            column_count = max_width // max(self.width + width_padding, 1)
        else:
            # Start with one column per renderable and shrink until the widest
            # cell of each column, plus inter-column padding, fits in max_width.
            while column_count > 1:
                widths.clear()
                column_no = 0
                for renderable_width, _ in iter_renderables(column_count):
                    widths[column_no] = max(widths[column_no], renderable_width)
                    total_width = sum(widths.values()) + width_padding * (
                        len(widths) - 1
                    )
                    if total_width > max_width:
                        column_count = len(widths) - 1
                        break
                    else:
                        column_no = (column_no + 1) % column_count
                else:
                    break

        # Always lay out at least one column. A console narrower than a single
        # column would otherwise leave column_count == 0, which divides by zero
        # in iter_renderables() and passes a zero step to range() below.
        column_count = max(column_count, 1)
        if self.width is not None:
            for _ in range(column_count):
                table.add_column(width=self.width)

        get_renderable = itemgetter(1)
        _renderables = [
            get_renderable(_renderable)
            for _renderable in iter_renderables(column_count)
        ]
        if self.equal:
            _renderables = [
                None
                if renderable is None
                else Constrain(renderable, renderable_widths[0])
                for renderable in _renderables
            ]
        if self.align:
            align = self.align
            _Align = Align
            _renderables = [
                None if renderable is None else _Align(renderable, align)
                for renderable in _renderables
            ]

        right_to_left = self.right_to_left
        add_row = table.add_row
        for start in range(0, len(_renderables), column_count):
            row = _renderables[start : start + column_count]
            if right_to_left:
                row = row[::-1]
            add_row(*row)
        yield table
172
173
if __name__ == "__main__":  # pragma: no cover
    import os

    # Demo: list the current directory as columns three ways -- row-first,
    # column-first, and column-first laid out right-to-left.
    console = Console()

    files = [f"{i} {s}" for i, s in enumerate(sorted(os.listdir()))]
    columns = Columns(files, padding=(0, 1), expand=False, equal=False)
    console.print(columns)
    console.rule()
    columns.column_first = True
    console.print(columns)
    columns.right_to_left = True
    console.rule()
    console.print(columns)
190