1import os.path
2import platform
3from rich.containers import Lines
4import textwrap
5from abc import ABC, abstractmethod
6from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union
7
8from pygments.lexers import get_lexer_by_name, guess_lexer_for_filename
9from pygments.style import Style as PygmentsStyle
10from pygments.styles import get_style_by_name
11from pygments.token import (
12    Comment,
13    Error,
14    Generic,
15    Keyword,
16    Name,
17    Number,
18    Operator,
19    String,
20    Token,
21    Whitespace,
22)
23from pygments.util import ClassNotFound
24
25from ._loop import loop_first
26from .color import Color, blend_rgb
27from .console import Console, ConsoleOptions, JustifyMethod, RenderResult
28from .jupyter import JupyterMixin
29from .measure import Measurement
30from .segment import Segment
31from .style import Style
32from .text import Text
33
# A Pygments token type, treated as a tuple of name parts (ANSISyntaxTheme
# slices it with tuple(...) / [:-1] to walk up the token hierarchy).
TokenType = Tuple[str, ...]

# True when platform.system() reports "Windows".
WINDOWS = platform.system() == "Windows"
# Theme used by Syntax and Syntax.from_path when no theme argument is given.
DEFAULT_THEME = "monokai"

# The following styles are based on https://github.com/pygments/pygments/blob/master/pygments/formatters/terminal.py
# with a few modifications. They map Pygments token types to Rich styles that
# use the standard named ANSI colors instead of RGB values. Token types that are
# missing from a map inherit the style of their closest listed ancestor (see
# ANSISyntaxTheme.get_style_for_token).

# Style map for the "ansi_light" theme; uses the non-bright ANSI colors for most tokens.
ANSI_LIGHT: Dict[TokenType, Style] = {
    Token: Style(),
    Whitespace: Style(color="white"),
    Comment: Style(dim=True),
    Comment.Preproc: Style(color="cyan"),
    Keyword: Style(color="blue"),
    Keyword.Type: Style(color="cyan"),
    Operator.Word: Style(color="magenta"),
    Name.Builtin: Style(color="cyan"),
    Name.Function: Style(color="green"),
    Name.Namespace: Style(color="cyan", underline=True),
    Name.Class: Style(color="green", underline=True),
    Name.Exception: Style(color="cyan"),
    Name.Decorator: Style(color="magenta", bold=True),
    Name.Variable: Style(color="red"),
    Name.Constant: Style(color="red"),
    Name.Attribute: Style(color="cyan"),
    Name.Tag: Style(color="bright_blue"),
    String: Style(color="yellow"),
    Number: Style(color="blue"),
    Generic.Deleted: Style(color="bright_red"),
    Generic.Inserted: Style(color="green"),
    Generic.Heading: Style(bold=True),
    Generic.Subheading: Style(color="magenta", bold=True),
    Generic.Prompt: Style(bold=True),
    Generic.Error: Style(color="bright_red"),
    Error: Style(color="red", underline=True),
}

# Style map for the "ansi_dark" theme; mostly the "bright_" variants of the
# colors used by ANSI_LIGHT.
ANSI_DARK: Dict[TokenType, Style] = {
    Token: Style(),
    Whitespace: Style(color="bright_black"),
    Comment: Style(dim=True),
    Comment.Preproc: Style(color="bright_cyan"),
    Keyword: Style(color="bright_blue"),
    Keyword.Type: Style(color="bright_cyan"),
    Operator.Word: Style(color="bright_magenta"),
    Name.Builtin: Style(color="bright_cyan"),
    Name.Function: Style(color="bright_green"),
    Name.Namespace: Style(color="bright_cyan", underline=True),
    Name.Class: Style(color="bright_green", underline=True),
    Name.Exception: Style(color="bright_cyan"),
    Name.Decorator: Style(color="bright_magenta", bold=True),
    Name.Variable: Style(color="bright_red"),
    Name.Constant: Style(color="bright_red"),
    Name.Attribute: Style(color="bright_cyan"),
    Name.Tag: Style(color="bright_blue"),
    String: Style(color="yellow"),
    Number: Style(color="bright_blue"),
    Generic.Deleted: Style(color="bright_red"),
    Generic.Inserted: Style(color="bright_green"),
    Generic.Heading: Style(bold=True),
    Generic.Subheading: Style(color="bright_magenta", bold=True),
    Generic.Prompt: Style(bold=True),
    Generic.Error: Style(color="bright_red"),
    Error: Style(color="red", underline=True),
}

# Theme names served by ANSISyntaxTheme; any other theme name passed to
# Syntax.get_theme is treated as a Pygments style name.
RICH_SYNTAX_THEMES = {"ansi_light": ANSI_LIGHT, "ansi_dark": ANSI_DARK}
101
102
class SyntaxTheme(ABC):
    """Abstract interface translating Pygments token types into Rich styles.

    Concrete themes supply a per-token style and a style describing the
    background that highlighted code is drawn on.
    """

    @abstractmethod
    def get_background_style(self) -> Style:
        """Return the style that carries the theme's background (``bgcolor``)."""
        raise NotImplementedError  # pragma: no cover

    @abstractmethod
    def get_style_for_token(self, token_type: TokenType) -> Style:
        """Return the Rich style used to render tokens of ``token_type``."""
        raise NotImplementedError  # pragma: no cover
115
116
class PygmentsSyntaxTheme(SyntaxTheme):
    """Theme whose colors come from a Pygments style class.

    Args:
        theme: The registered name of a Pygments style, or a Pygments style
            class. Unknown style names fall back to Pygments' ``"default"`` style.
    """

    def __init__(self, theme: Union[str, Type[PygmentsStyle]]) -> None:
        self._style_cache: Dict[TokenType, Style] = {}
        if not isinstance(theme, str):
            self._pygments_style_class = theme
        else:
            try:
                self._pygments_style_class = get_style_by_name(theme)
            except ClassNotFound:
                self._pygments_style_class = get_style_by_name("default")

        self._background_color = self._pygments_style_class.background_color
        self._background_style = Style(bgcolor=self._background_color)

    def get_style_for_token(self, token_type: TokenType) -> Style:
        """Convert the Pygments style for ``token_type`` to a Rich style.

        Results are memoized per token type. Tokens the Pygments style cannot
        resolve (``KeyError``) map to a null style; a missing foreground falls
        back to black and a missing background to the theme background.
        """
        cached = self._style_cache.get(token_type)
        if cached is not None:
            return cached

        try:
            token_style = self._pygments_style_class.style_for_token(token_type)
        except KeyError:
            rich_style = Style.null()
        else:
            foreground = token_style["color"]
            background = token_style["bgcolor"]
            rich_style = Style(
                color=("#" + foreground) if foreground else "#000000",
                bgcolor=("#" + background) if background else self._background_color,
                bold=token_style["bold"],
                italic=token_style["italic"],
                underline=token_style["underline"],
            )
        self._style_cache[token_type] = rich_style
        return rich_style

    def get_background_style(self) -> Style:
        """Return a style whose ``bgcolor`` is the Pygments style's background."""
        return self._background_style
157
158
class ANSISyntaxTheme(SyntaxTheme):
    """Theme driven by an explicit token-type → style mapping.

    Args:
        style_map: Styles keyed by Pygments token type. A token type that is not
            in the map uses the style of its most specific ancestor that is.
    """

    def __init__(self, style_map: Dict[TokenType, Style]) -> None:
        self.style_map = style_map
        self._missing_style = Style.null()
        self._background_style = Style.null()
        self._style_cache: Dict[TokenType, Style] = {}

    def get_style_for_token(self, token_type: TokenType) -> Style:
        """Return the style for ``token_type``, searching up the token hierarchy.

        Candidates are tried from most to least specific, e.g.
        ("foo", "bar", "baz"), then ("foo", "bar"), then ("foo",). A null style
        is returned when none match. Results are memoized per token type.
        """
        cached = self._style_cache.get(token_type)
        if cached is not None:
            return cached

        parts = tuple(token_type)
        lookup = self.style_map.get
        resolved = self._missing_style
        for depth in range(len(parts), 0, -1):
            candidate = lookup(parts[:depth])
            if candidate is not None:
                resolved = candidate
                break
        self._style_cache[token_type] = resolved
        return resolved

    def get_background_style(self) -> Style:
        """Return the (null) background style; ANSI themes set no background."""
        return self._background_style
190
191
class Syntax(JupyterMixin):
    """Construct a Syntax object to render syntax highlighted code.

    Args:
        code (str): Code to highlight.
        lexer_name (str): Lexer to use (see https://pygments.org/docs/lexers/). May be a
            Pygments lexer alias (e.g. "python") or a lexer's display name (e.g. "Python").
        theme (str, optional): Color theme, aka Pygments style (see https://pygments.org/docs/styles/#getting-a-list-of-available-styles). Defaults to "monokai".
        dedent (bool, optional): Enable stripping of initial whitespace. Defaults to False.
        line_numbers (bool, optional): Enable rendering of line numbers. Defaults to False.
        start_line (int, optional): Starting number for line numbers. Defaults to 1.
        line_range (Tuple[int, int], optional): If given should be a tuple of the start and end line to render.
        highlight_lines (Set[int]): A set of line numbers to highlight.
        code_width: Width of code to render (not including line numbers), or ``None`` to use all available width.
        tab_size (int, optional): Size of tabs. Defaults to 4.
        word_wrap (bool, optional): Enable word wrapping.
        background_color (str, optional): Optional background color, or None to use theme color. Defaults to None.
        indent_guides (bool, optional): Show indent guides. Defaults to False.
    """

    _pygments_style_class: Type[PygmentsStyle]
    _theme: SyntaxTheme

    @classmethod
    def get_theme(cls, name: Union[str, SyntaxTheme]) -> SyntaxTheme:
        """Get a syntax theme instance.

        Args:
            name: A SyntaxTheme (returned unchanged), a key of RICH_SYNTAX_THEMES
                ("ansi_light" / "ansi_dark"), or a Pygments style name.

        Returns:
            SyntaxTheme: The resolved theme.
        """
        if isinstance(name, SyntaxTheme):
            return name
        theme: SyntaxTheme
        if name in RICH_SYNTAX_THEMES:
            theme = ANSISyntaxTheme(RICH_SYNTAX_THEMES[name])
        else:
            theme = PygmentsSyntaxTheme(name)
        return theme

    def __init__(
        self,
        code: str,
        lexer_name: str,
        *,
        theme: Union[str, SyntaxTheme] = DEFAULT_THEME,
        dedent: bool = False,
        line_numbers: bool = False,
        start_line: int = 1,
        line_range: Optional[Tuple[int, int]] = None,
        highlight_lines: Optional[Set[int]] = None,
        code_width: Optional[int] = None,
        tab_size: int = 4,
        word_wrap: bool = False,
        background_color: Optional[str] = None,
        indent_guides: bool = False,
    ) -> None:
        self.code = code
        self.lexer_name = lexer_name
        self.dedent = dedent
        self.line_numbers = line_numbers
        self.start_line = start_line
        self.line_range = line_range
        self.highlight_lines = highlight_lines or set()
        self.code_width = code_width
        self.tab_size = tab_size
        self.word_wrap = word_wrap
        self.background_color = background_color
        # Explicit background override; an empty Style leaves the theme background alone.
        self.background_style = (
            Style(bgcolor=background_color) if background_color else Style()
        )
        self.indent_guides = indent_guides

        self._theme = self.get_theme(theme)

    @classmethod
    def from_path(
        cls,
        path: str,
        encoding: str = "utf-8",
        theme: Union[str, SyntaxTheme] = DEFAULT_THEME,
        dedent: bool = False,
        line_numbers: bool = False,
        line_range: Optional[Tuple[int, int]] = None,
        start_line: int = 1,
        highlight_lines: Optional[Set[int]] = None,
        code_width: Optional[int] = None,
        tab_size: int = 4,
        word_wrap: bool = False,
        background_color: Optional[str] = None,
        indent_guides: bool = False,
    ) -> "Syntax":
        """Construct a Syntax object from a file.

        The lexer is looked up from the file extension (as a Pygments alias) first;
        if that fails it is guessed from the file name and contents. When neither
        succeeds the lexer name "default" is used.

        Args:
            path (str): Path to file to highlight.
            encoding (str): Encoding of file.
            theme (str, optional): Color theme, aka Pygments style (see https://pygments.org/docs/styles/#getting-a-list-of-available-styles). Defaults to "monokai".
            dedent (bool, optional): Enable stripping of initial whitespace. Defaults to False.
            line_numbers (bool, optional): Enable rendering of line numbers. Defaults to False.
            start_line (int, optional): Starting number for line numbers. Defaults to 1.
            line_range (Tuple[int, int], optional): If given should be a tuple of the start and end line to render.
            highlight_lines (Set[int]): A set of line numbers to highlight.
            code_width: Width of code to render (not including line numbers), or ``None`` to use all available width.
            tab_size (int, optional): Size of tabs. Defaults to 4.
            word_wrap (bool, optional): Enable word wrapping of code.
            background_color (str, optional): Optional background color, or None to use theme color. Defaults to None.
            indent_guides (bool, optional): Show indent guides. Defaults to False.

        Returns:
            [Syntax]: A Syntax object that may be printed to the console
        """
        with open(path, "rt", encoding=encoding) as code_file:
            code = code_file.read()

        lexer = None
        lexer_name = "default"
        try:
            _, ext = os.path.splitext(path)
            if ext:
                extension = ext.lstrip(".").lower()
                lexer = get_lexer_by_name(extension)
                lexer_name = lexer.name
        except ClassNotFound:
            pass

        if lexer is None:
            try:
                lexer_name = guess_lexer_for_filename(path, code).name
            except ClassNotFound:
                pass

        return cls(
            code,
            lexer_name,
            theme=theme,
            dedent=dedent,
            line_numbers=line_numbers,
            line_range=line_range,
            start_line=start_line,
            highlight_lines=highlight_lines,
            code_width=code_width,
            tab_size=tab_size,
            word_wrap=word_wrap,
            background_color=background_color,
            indent_guides=indent_guides,
        )

    def _get_base_style(self) -> Style:
        """Get the base style: the theme background overlaid with any explicit background."""
        default_style = self._theme.get_background_style() + self.background_style
        return default_style

    def _get_token_color(self, token_type: TokenType) -> Optional[Color]:
        """Get a color (if any) for the given token.

        Args:
            token_type (TokenType): A token type tuple from Pygments.

        Returns:
            Optional[Color]: Color from theme, or None for no color.
        """
        style = self._theme.get_style_for_token(token_type)
        return style.color

    def _get_lexer(self) -> Any:
        """Resolve ``self.lexer_name`` to a configured Pygments lexer instance.

        The name is tried first as a Pygments alias (e.g. "python"). If that fails
        it is matched against lexer display names (e.g. "Batchfile"). ``from_path``
        stores display names, and some of them do not lower-case to a valid alias.

        Returns:
            A Pygments lexer instance, or None if no lexer matches the name.
        """
        from pygments.lexers import find_lexer_class

        options = {"stripnl": False, "ensurenl": True, "tabsize": self.tab_size}
        try:
            return get_lexer_by_name(self.lexer_name, **options)
        except ClassNotFound:
            lexer_class = find_lexer_class(self.lexer_name)
        if lexer_class is None:
            return None
        return lexer_class(**options)

    def highlight(
        self, code: str, line_range: Optional[Tuple[int, int]] = None
    ) -> Text:
        """Highlight code and return a Text instance.

        Args:
            code (str): Code to highlight.
            line_range(Tuple[int, int], optional): Optional line range to highlight.
                Lines outside the range are appended unstyled.

        Returns:
            Text: A text instance containing highlighted syntax.
        """

        base_style = self._get_base_style()
        justify: JustifyMethod = (
            "default" if base_style.transparent_background else "left"
        )

        text = Text(
            justify=justify,
            style=base_style,
            tab_size=self.tab_size,
            no_wrap=not self.word_wrap,
        )
        _get_theme_style = self._theme.get_style_for_token
        lexer = self._get_lexer()
        if lexer is None:
            # Unknown lexer: render the code as plain text.
            text.append(code)
        else:
            if line_range:
                # More complicated path to only stylize a portion of the code
                # This speeds up further operations as there are less spans to process
                line_start, line_end = line_range

                def line_tokenize() -> Iterable[Tuple[Any, str]]:
                    """Split tokens to one per line."""
                    for token_type, token in lexer.get_tokens(code):
                        while token:
                            line_token, new_line, token = token.partition("\n")
                            yield token_type, line_token + new_line

                def tokens_to_spans() -> Iterable[Tuple[str, Optional[Style]]]:
                    """Convert tokens to spans."""
                    tokens = iter(line_tokenize())
                    line_no = 0
                    _line_start = line_start - 1

                    # Skip over tokens until line start
                    while line_no < _line_start:
                        try:
                            _token_type, token = next(tokens)
                        except StopIteration:
                            # line_range starts past the end of the code. Letting
                            # StopIteration escape a generator raises RuntimeError (PEP 479).
                            break
                        yield (token, None)
                        if token.endswith("\n"):
                            line_no += 1
                    # Generate spans until line end
                    for token_type, token in tokens:
                        yield (token, _get_theme_style(token_type))
                        if token.endswith("\n"):
                            line_no += 1
                            if line_no >= line_end:
                                break

                text.append_tokens(tokens_to_spans())

            else:
                text.append_tokens(
                    (token, _get_theme_style(token_type))
                    for token_type, token in lexer.get_tokens(code)
                )
            if self.background_color is not None:
                # Override any per-token background set by the theme.
                text.stylize(f"on {self.background_color}")
        return text

    def _get_line_numbers_color(self, blend: float = 0.3) -> Color:
        """Blend the text color toward the background to get a muted line-number color.

        Args:
            blend (float): Cross-fade amount passed to ``blend_rgb``.

        Returns:
            Color: The blended color, or a default/system color when blending is not possible.
        """
        background_style = self._get_base_style()
        background_color = background_style.bgcolor
        if background_color is None or background_color.is_system_defined:
            return Color.default()
        foreground_color = self._get_token_color(Token.Text)
        if foreground_color is None or foreground_color.is_system_defined:
            return foreground_color or Color.default()
        new_color = blend_rgb(
            background_color.get_truecolor(),
            foreground_color.get_truecolor(),
            cross_fade=blend,
        )
        return Color.from_triplet(new_color)

    @property
    def _numbers_column_width(self) -> int:
        """Get the number of characters used to render the numbers column.

        This is the width of the largest line number plus 2 (for the pointer), or 0
        when line numbers are disabled.
        """
        column_width = 0
        if self.line_numbers:
            column_width = len(str(self.start_line + self.code.count("\n"))) + 2
        return column_width

    def _get_number_styles(self, console: Console) -> Tuple[Style, Style, Style]:
        """Get background, number, and highlight styles for line numbers."""
        background_style = self._get_base_style()
        if background_style.transparent_background:
            return Style.null(), Style(dim=True), Style.null()
        if console.color_system in ("256", "truecolor"):
            number_style = Style.chain(
                background_style,
                self._theme.get_style_for_token(Token.Text),
                Style(color=self._get_line_numbers_color()),
                self.background_style,
            )
            highlight_number_style = Style.chain(
                background_style,
                self._theme.get_style_for_token(Token.Text),
                Style(bold=True, color=self._get_line_numbers_color(0.9)),
                self.background_style,
            )
        else:
            number_style = background_style + Style(dim=True)
            highlight_number_style = background_style + Style(dim=False)
        return background_style, number_style, highlight_number_style

    def __rich_measure__(
        self, console: "Console", options: "ConsoleOptions"
    ) -> "Measurement":
        # With line numbers, __rich_console__ draws a gutter of
        # _numbers_column_width cells (pointer + number) plus one separating
        # space before every code line, so the measurement must include it.
        gutter_width = self._numbers_column_width + 1 if self.line_numbers else 0
        if self.code_width is not None:
            width = self.code_width + gutter_width
            return Measurement(gutter_width, width)
        return Measurement(gutter_width, options.max_width)

    def __rich_console__(
        self, console: Console, options: ConsoleOptions
    ) -> RenderResult:

        transparent_background = self._get_base_style().transparent_background
        # Width left for the code itself once the line-number gutter is removed.
        code_width = (
            (
                (options.max_width - self._numbers_column_width - 1)
                if self.line_numbers
                else options.max_width
            )
            if self.code_width is None
            else self.code_width
        )

        line_offset = 0
        if self.line_range:
            start_line, end_line = self.line_range
            line_offset = max(0, start_line - 1)

        # Normalize to a trailing newline for the lexer; it is removed again
        # below if the original code did not end with one.
        ends_on_nl = self.code.endswith("\n")
        code = self.code if ends_on_nl else self.code + "\n"
        code = textwrap.dedent(code) if self.dedent else code
        code = code.expandtabs(self.tab_size)
        text = self.highlight(code, self.line_range)

        (
            background_style,
            number_style,
            highlight_number_style,
        ) = self._get_number_styles(console)

        if not self.line_numbers and not self.word_wrap and not self.line_range:
            if not ends_on_nl:
                text.remove_suffix("\n")
            # Simple case of just rendering text
            style = (
                self._get_base_style()
                + self._theme.get_style_for_token(Comment)
                + Style(dim=True)
                + self.background_style
            )
            if self.indent_guides and not options.ascii_only:
                text = text.with_indent_guides(self.tab_size, style=style)
                text.overflow = "crop"
            if style.transparent_background:
                yield from console.render(
                    text, options=options.update(width=code_width)
                )
            else:
                syntax_lines = console.render_lines(
                    text,
                    options.update(width=code_width, height=None),
                    style=self.background_style,
                    pad=True,
                    new_lines=True,
                )
                for syntax_line in syntax_lines:
                    yield from syntax_line
            return

        lines: Union[List[Text], Lines] = text.split("\n", allow_blank=ends_on_nl)
        if self.line_range:
            lines = lines[line_offset:end_line]

        if self.indent_guides and not options.ascii_only:
            style = (
                self._get_base_style()
                + self._theme.get_style_for_token(Comment)
                + Style(dim=True)
                + self.background_style
            )
            lines = (
                Text("\n")
                .join(lines)
                .with_indent_guides(self.tab_size, style=style)
                .split("\n", allow_blank=True)
            )

        numbers_column_width = self._numbers_column_width
        render_options = options.update(width=code_width)

        highlight_line = self.highlight_lines.__contains__
        _Segment = Segment
        # Blank gutter used for continuation lines of a wrapped source line.
        padding = _Segment(" " * numbers_column_width + " ", background_style)
        new_line = _Segment("\n")

        line_pointer = "> " if options.legacy_windows else "❱ "

        for line_no, line in enumerate(lines, self.start_line + line_offset):
            if self.word_wrap:
                wrapped_lines = console.render_lines(
                    line,
                    render_options.update(height=None),
                    style=background_style,
                    pad=not transparent_background,
                )

            else:
                segments = list(line.render(console, end=""))
                if options.no_wrap:
                    wrapped_lines = [segments]
                else:
                    wrapped_lines = [
                        _Segment.adjust_line_length(
                            segments,
                            render_options.max_width,
                            style=background_style,
                            pad=not transparent_background,
                        )
                    ]
            if self.line_numbers:
                for first, wrapped_line in loop_first(wrapped_lines):
                    if first:
                        line_column = str(line_no).rjust(numbers_column_width - 2) + " "
                        if highlight_line(line_no):
                            yield _Segment(line_pointer, Style(color="red"))
                            yield _Segment(line_column, highlight_number_style)
                        else:
                            yield _Segment("  ", highlight_number_style)
                            yield _Segment(line_column, number_style)
                    else:
                        yield padding
                    yield from wrapped_line
                    yield new_line
            else:
                for wrapped_line in wrapped_lines:
                    yield from wrapped_line
                    yield new_line
612
613
if __name__ == "__main__":  # pragma: no cover
    # Command-line entry point: render a file (or stdin with "-") to the console.

    import argparse
    import sys

    parser = argparse.ArgumentParser(
        description="Render syntax to the console with Rich"
    )
    parser.add_argument(
        "path",
        metavar="PATH",
        help="path to file, or - for stdin",
    )
    parser.add_argument(
        "-c",
        "--force-color",
        dest="force_color",
        action="store_true",
        default=None,
        help="force color for non-terminals",
    )
    parser.add_argument(
        "-i",
        "--indent-guides",
        dest="indent_guides",
        action="store_true",
        default=False,
        help="display indent guides",
    )
    parser.add_argument(
        "-l",
        "--line-numbers",
        dest="line_numbers",
        action="store_true",
        help="render line numbers",
    )
    parser.add_argument(
        "-w",
        "--width",
        type=int,
        dest="width",
        default=None,
        help="width of output (default will auto-detect)",
    )
    parser.add_argument(
        "-r",
        "--wrap",
        dest="word_wrap",
        action="store_true",
        default=False,
        help="word wrap long lines",
    )
    parser.add_argument(
        "-s",
        "--soft-wrap",
        action="store_true",
        dest="soft_wrap",
        default=False,
        help="enable soft wrapping mode",
    )
    parser.add_argument(
        "-t", "--theme", dest="theme", default=DEFAULT_THEME, help="pygments theme"
    )
    parser.add_argument(
        "-b",
        "--background-color",
        dest="background_color",
        default=None,
        help="Override background color",
    )
    # Default None so that, for files, we can tell "not given" (auto-detect the
    # lexer) apart from an explicit choice.
    parser.add_argument(
        "-x",
        "--lexer",
        default=None,
        dest="lexer_name",
        help="Lexer name (for files, overrides the lexer detected from the file name)",
    )
    args = parser.parse_args()

    # Console is already imported at module level (from .console).
    console = Console(force_terminal=args.force_color, width=args.width)

    if args.path == "-":
        code = sys.stdin.read()
        syntax = Syntax(
            code=code,
            lexer_name=args.lexer_name or "default",
            line_numbers=args.line_numbers,
            word_wrap=args.word_wrap,
            theme=args.theme,
            background_color=args.background_color,
            indent_guides=args.indent_guides,
        )
    else:
        syntax = Syntax.from_path(
            args.path,
            line_numbers=args.line_numbers,
            word_wrap=args.word_wrap,
            theme=args.theme,
            background_color=args.background_color,
            indent_guides=args.indent_guides,
        )
        if args.lexer_name:
            # Previously --lexer was silently ignored when a file path was given.
            syntax.lexer_name = args.lexer_name
    console.print(syntax, soft_wrap=args.soft_wrap)
718