1import enum
2import functools
3from typing import (
4    Any,
5    Callable,
6    Dict,
7    Iterable,
8    Iterator,
9    List,
10    Mapping,
11    MutableSet,
12    Optional,
13    Sequence,
14    Tuple,
15    TypeVar,
16    Union,
17)
18
19import attr
20import psutil
21
22from . import colors, utils
23
24
# Generic, unconstrained type variable. It is not referenced inside this
# module itself (presumably kept for importers).
T = TypeVar("T")


# Type variable bound to IntEnum: enum_next() returns the same enum class
# that it receives.
E = TypeVar("E", bound=enum.IntEnum)
29
30
def enum_next(e: E) -> E:
    """Return the member that follows ``e`` in value order, wrapping around.

    Members are ordered by their integer value. The values therefore do not
    have to start at 1 or be contiguous.

    >>> class Seasons(enum.IntEnum):
    ...     winter = 1
    ...     spring = 2
    ...     summer = 3
    ...     autumn = 4

    >>> enum_next(Seasons.winter).name
    'spring'
    >>> enum_next(Seasons.spring).name
    'summer'
    >>> enum_next(Seasons.autumn).name
    'winter'

    Zero-based or gapped values are supported as well:

    >>> class Levels(enum.IntEnum):
    ...     low = 0
    ...     mid = 5
    ...     high = 10

    >>> enum_next(Levels.low).name
    'mid'
    >>> enum_next(Levels.high).name
    'low'
    """
    # For 1..N enums this gives the same result as "(value % max) + 1".
    # Unlike that arithmetic, it does not raise ValueError on gapped values,
    # and it does not skip the first member of a zero-based enum.
    members = sorted(type(e))
    position = members.index(e)
    return members[(position + 1) % len(members)]
48
49
@enum.unique
class Flag(enum.IntFlag):
    """Column flag: one bit per optional column of the process table."""

    DATABASE = 1
    APPNAME = 2
    CLIENT = 4
    USER = 8
    CPU = 16
    MEM = 32
    READ = 64
    WRITE = 128
    TIME = 256
    WAIT = 512
    RELATION = 1024
    TYPE = 2048
    MODE = 4096
    IOWAIT = 8192
    PID = 16384

    @classmethod
    def all(cls) -> "Flag":
        """Return the flag with every column bit set."""
        return cls(sum(member.value for member in cls))

    @classmethod
    def from_options(
        cls,
        *,
        is_local: bool,
        noappname: bool,
        noclient: bool,
        nocpu: bool,
        nodb: bool,
        nomem: bool,
        nopid: bool,
        noread: bool,
        notime: bool,
        nouser: bool,
        nowait: bool,
        nowrite: bool,
        **kwargs: Any,
    ) -> "Flag":
        """Build a Flag value from command line options.

        Start from all columns and drop each column whose ``no<column>``
        option is true. When ``is_local`` is false, the CPU, MEM, READ,
        WRITE and IOWAIT columns are dropped too. Extra keyword arguments
        are accepted and ignored.
        """
        # Local names only: class-level attributes of an Enum would become
        # members.
        disabled_by_option = (
            (nodb, cls.DATABASE),
            (nouser, cls.USER),
            (nocpu, cls.CPU),
            (noclient, cls.CLIENT),
            (nomem, cls.MEM),
            (noread, cls.READ),
            (nowrite, cls.WRITE),
            (notime, cls.TIME),
            (nowait, cls.WAIT),
            (noappname, cls.APPNAME),
            (nopid, cls.PID),
        )
        result = cls.all()
        for switched_off, column in disabled_by_option:
            if switched_off:
                # The bit is known to be set here, so XOR clears it.
                result ^= column

        # Remove some columns when not running against a local pg server.
        if not is_local:
            for column in (cls.CPU, cls.MEM, cls.READ, cls.WRITE, cls.IOWAIT):
                # XOR only clears a bit that is still set (it may already
                # have been removed above).
                if result & column:
                    result ^= column
        return result
129
130
class SortKey(enum.Enum):
    """Attribute the process list can be sorted by."""

    cpu = enum.auto()
    mem = enum.auto()
    read = enum.auto()
    write = enum.auto()
    duration = enum.auto()

    @classmethod
    def default(cls) -> "SortKey":
        """Return the sort key used when none is given: ``duration``."""
        return cls["duration"]
141
142
@enum.unique
class QueryDisplayMode(enum.IntEnum):
    """Display mode for query text (values are unique integers 1..3)."""

    truncate = 1
    wrap_noindent = 2
    wrap = 3

    @classmethod
    def default(cls) -> "QueryDisplayMode":
        """Return the initial display mode, ``wrap_noindent``."""
        initial = cls.wrap_noindent
        return initial
152
153
@enum.unique
class QueryMode(enum.Enum):
    """Kind of queries listed; each value is its human-readable label."""

    activities = "running queries"
    waiting = "waiting queries"
    blocking = "blocking queries"

    @classmethod
    def default(cls) -> "QueryMode":
        """Return the initial mode, ``activities`` (running queries)."""
        initial = cls.activities
        return initial
163
164
@enum.unique
class DurationMode(enum.IntEnum):
    """Duration mode (presumably which start time durations are measured from).

    Values are the contiguous integers 1..3, as enum_next() requires for
    cycling through modes.
    """

    query = 1
    transaction = 2
    backend = 3

    @classmethod
    def default(cls) -> "DurationMode":
        """Return the initial duration mode, ``query``.

        Provided for consistency with SortKey, QueryDisplayMode and QueryMode,
        which all expose a ``default()`` classmethod.
        """
        return cls.query
170
171
_color_key_marker = str(id(object()))  # sentinel default for Column.color_key: numeric string, won't match a real color key
173
174
@attr.s(auto_attribs=True, frozen=True, slots=True)
class Column:
    """A column in stats table.

    >>> c = Column("pid", "PID", "%-6s", True, SortKey.cpu, max_width=6,
    ...            transform=lambda v: str(v)[::-1])
    >>> c.title_render()
    'PID   '
    >>> c.title_color(SortKey.cpu)
    'cyan'
    >>> c.title_color(SortKey.duration)
    'green'
    >>> c.render('1234')
    '4321  '
    >>> c.render('12345678')
    '876543'
    >>> c.color_key
    'pid'
    """

    # Column identifier; also used as color_key when none is given.
    key: str = attr.ib(repr=False)
    # Title shown in the header.
    name: str
    # %-style template with exactly one placeholder, used for the title and
    # for every rendered value.
    template_h: str = attr.ib()
    # Not read within this class.
    mandatory: bool = False
    # Sort key this column corresponds to, if any (see title_color()).
    sort_key: Optional[SortKey] = None
    # Maximum number of characters kept from the transformed value;
    # None keeps everything.
    max_width: Optional[int] = attr.ib(default=None, repr=False)
    # Turns a raw value into the text inserted in template_h.
    transform: Callable[[Any], str] = attr.ib(default=str, repr=False)
    # Either a fixed color key or a callable computing one from the value.
    color_key: Union[str, Callable[[Any], str]] = attr.ib(
        default=_color_key_marker, repr=False
    )

    @template_h.validator
    def _template_h_is_a_format_string_(self, attribute: Any, value: str) -> None:
        """Validate template_h attribute.

        >>> Column("k", "a", "b%%aa")
        Traceback (most recent call last):
            ...
        ValueError: template_h must be a format string with one placeholder
        >>> Column("k", "a", "baad")
        Traceback (most recent call last):
            ...
        ValueError: template_h must be a format string with one placeholder
        >>> Column("k", "a", "%s is good")  # doctest: +ELLIPSIS
        Column(name='a', template_h='%s is good', ...)
        """
        if value.count("%") == 1:
            return
        raise ValueError(
            f"{attribute.name} must be a format string with one placeholder"
        )

    def __attrs_post_init__(self) -> None:
        """Fall back to the column key when no color_key was supplied."""
        uses_default_color = self.color_key == _color_key_marker
        if uses_default_color:
            # The class is frozen; bypass the attrs __setattr__ guard.
            object.__setattr__(self, "color_key", self.key)

    def title_render(self) -> str:
        """Return the column title formatted with template_h."""
        title = self.name
        return self.template_h % title

    def title_color(self, sort_by: SortKey) -> str:
        """Return "cyan" when this column is the active sort key, else "green"."""
        # TODO: define a Color enum
        return "cyan" if self.sort_key == sort_by else "green"

    def render(self, value: Any) -> str:
        """Transform value, cut it to max_width and format it with template_h."""
        text = self.transform(value)[: self.max_width]
        return self.template_h % text

    def color(self, value: Any) -> str:
        """Return the color key for value (computed when color_key is callable)."""
        key_or_func = self.color_key
        return key_or_func(value) if callable(key_or_func) else key_or_func
245
246
@attr.s(auto_attribs=True, slots=True)
class UI:
    """State of the UI."""

    # Columns to display for each query mode, in display order.
    columns_by_querymode: Mapping[QueryMode, Tuple[Column, ...]]
    min_duration: float = 0.0
    duration_mode: DurationMode = attr.ib(
        default=DurationMode.query, converter=DurationMode
    )
    query_display_mode: QueryDisplayMode = attr.ib(
        default=QueryDisplayMode.default(), converter=QueryDisplayMode
    )
    sort_key: SortKey = attr.ib(default=SortKey.default(), converter=SortKey)
    query_mode: QueryMode = attr.ib(default=QueryMode.activities, converter=QueryMode)
    refresh_time: Union[float, int] = 2
    # While True, evolve() leaves the state untouched.
    in_pause: bool = False
    # Remaining ticks of interactive mode; None outside interactive mode.
    interactive_timeout: Optional[int] = None

    @classmethod
    def make(
        cls,
        flag: Flag = Flag.all(),
        *,
        max_db_length: int = 16,
        **kwargs: Any,
    ) -> "UI":
        """Build a UI with the columns enabled by ``flag``.

        The ``query`` and ``state`` columns are always created. Every other
        column is created only when its bit is set in ``flag``.

        :param flag: columns to enable.
        :param max_db_length: width of the DATABASE column. Database names
            are passed to ``utils.ellipsis`` with this same width so they
            match the column template.
        :param kwargs: extra attributes forwarded to the UI constructor.
        """
        possible_columns: Dict[str, Column] = {}

        def add_column(key: str, **column_attrs: Any) -> None:
            # Each column key may only be registered once.
            assert key not in possible_columns, f"duplicated key {key}"
            possible_columns[key] = Column(key, **column_attrs)

        if Flag.APPNAME & flag:
            add_column(
                key="application_name",
                name="APP",
                template_h="%16s ",
                max_width=16,
            )
        if Flag.CLIENT & flag:
            add_column(
                key="client",
                name="CLIENT",
                template_h="%16s ",
                max_width=16,
            )
        if Flag.CPU & flag:
            add_column(
                key="cpu",
                name="CPU%",
                template_h="%6s ",
                sort_key=SortKey.cpu,
            )
        if Flag.DATABASE & flag:
            add_column(
                key="database",
                name="DATABASE",
                template_h=f"%-{max_db_length}s ",
                # Use the column width rather than a hard-coded 16: the
                # "%-Ns" template pads but never truncates, so a mismatch
                # either overflows the column or cuts names too early.
                transform=functools.lru_cache()(
                    lambda v: utils.ellipsis(v, width=max_db_length) if v else "",
                ),
                sort_key=None,
            )
        if Flag.IOWAIT & flag:
            add_column(
                key="io_wait",
                name="IOW",
                template_h="%4s ",
                transform=utils.yn,
                color_key=colors.wait,
            )
        if Flag.MEM & flag:
            add_column(
                key="mem",
                name="MEM%",
                template_h="%4s ",
                sort_key=SortKey.mem,
                transform=lambda v: str(round(v, 1)),
            )
        if Flag.MODE & flag:
            add_column(
                key="mode",
                name="MODE",
                template_h="%16s ",
                max_width=16,
                color_key=colors.lock_mode,
            )
        if Flag.PID & flag:
            add_column(key="pid", name="PID", template_h="%-6s ")
        add_column(key="query", name="Query", template_h=" %2s")
        if Flag.READ & flag:
            add_column(
                key="read",
                name="READ/s",
                template_h="%8s ",
                sort_key=SortKey.read,
                transform=utils.naturalsize,
            )
        if Flag.RELATION & flag:
            add_column(
                key="relation",
                name="RELATION",
                template_h="%9s ",
                max_width=9,
            )
        add_column(
            key="state",
            name="state",
            template_h=" %17s  ",
            transform=utils.short_state,
            color_key=colors.short_state,
        )
        if Flag.TIME & flag:
            add_column(
                key="duration",
                name="TIME+",
                template_h="%9s ",
                sort_key=SortKey.duration,
                transform=lambda v: utils.format_duration(v)[0],
                color_key=lambda v: utils.format_duration(v)[1],
            )
        if Flag.TYPE & flag:
            add_column(key="type", name="TYPE", template_h="%16s ", max_width=16)
        if Flag.USER & flag:
            add_column(key="user", name="USER", template_h="%16s ", max_width=16)
        if Flag.WAIT & flag:
            add_column(
                key="wait",
                name="Waiting",
                template_h="%16s ",
                transform=utils.wait_status,
                color_key=colors.wait,
                max_width=16,
            )
        if Flag.WRITE & flag:
            add_column(
                key="write",
                name="WRITE/s",
                template_h="%8s ",
                sort_key=SortKey.write,
                transform=utils.naturalsize,
            )

        # Display order of column keys for each query mode.
        columns_key_by_querymode: Mapping[QueryMode, List[str]] = {
            QueryMode.activities: [
                "pid",
                "database",
                "application_name",
                "user",
                "client",
                "cpu",
                "mem",
                "read",
                "write",
                "duration",
                "wait",
                "io_wait",
                "state",
                "query",
            ],
            QueryMode.waiting: [
                "pid",
                "database",
                "application_name",
                "user",
                "client",
                "relation",
                "type",
                "mode",
                "duration",
                "state",
                "query",
            ],
            QueryMode.blocking: [
                "pid",
                "database",
                "application_name",
                "user",
                "client",
                "relation",
                "type",
                "mode",
                "duration",
                "wait",
                "state",
                "query",
            ],
        }

        def make_columns_for(query_mode: QueryMode) -> Iterator[Column]:
            # Columns disabled by ``flag`` were never registered: skip them.
            for key in columns_key_by_querymode[query_mode]:
                if key in possible_columns:
                    yield possible_columns[key]

        columns_by_querymode = {qm: tuple(make_columns_for(qm)) for qm in QueryMode}
        return cls(columns_by_querymode=columns_by_querymode, **kwargs)

    def interactive(self) -> bool:
        """Return True while interactive mode is active (a timeout is set)."""
        return self.interactive_timeout is not None

    def start_interactive(self) -> None:
        """Start interactive mode.

        >>> ui = UI.make()
        >>> ui.start_interactive()
        >>> ui.interactive_timeout
        3
        """
        self.interactive_timeout = 3

    def end_interactive(self) -> None:
        """End interactive mode.

        >>> ui = UI.make()
        >>> ui.start_interactive()
        >>> ui.interactive_timeout
        3
        >>> ui.end_interactive()
        >>> ui.interactive_timeout
        """
        self.interactive_timeout = None

    def tick_interactive(self) -> None:
        """Decrement the interactive timeout.

        Interactive mode ends (timeout reset to None) when the timeout
        reaches zero.

        :raises RuntimeError: if interactive mode is not active.

        >>> ui = UI.make()
        >>> ui.tick_interactive()
        Traceback (most recent call last):
            ...
        RuntimeError: cannot tick interactive mode
        >>> ui.start_interactive()
        >>> ui.interactive_timeout
        3
        >>> ui.tick_interactive()
        >>> ui.interactive_timeout
        2
        >>> ui.tick_interactive()
        >>> ui.interactive_timeout
        1
        >>> ui.tick_interactive()
        >>> ui.interactive_timeout
        >>> ui.tick_interactive()
        Traceback (most recent call last):
            ...
        RuntimeError: cannot tick interactive mode
        """
        if self.interactive_timeout is None:
            raise RuntimeError("cannot tick interactive mode")
        assert self.interactive_timeout > 0, self.interactive_timeout
        # A value of 0 is mapped to None, which leaves interactive mode.
        self.interactive_timeout = (self.interactive_timeout - 1) or None

    def toggle_pause(self) -> None:
        """Toggle 'in_pause' attribute.

        >>> ui = UI.make()
        >>> ui.in_pause
        False
        >>> ui.toggle_pause()
        >>> ui.in_pause
        True
        >>> ui.toggle_pause()
        >>> ui.in_pause
        False
        """
        self.in_pause = not self.in_pause

    def evolve(self, **changes: Any) -> None:
        """Apply 'changes' to this UI in place; nothing is returned.

        Changes are silently ignored while the UI is paused. Only
        duration_mode, query_display_mode, sort_key, query_mode and
        refresh_time may be changed (other names trigger an AssertionError).
        Each value goes through the field's converter, if it has one.

        >>> ui = UI.make()
        >>> ui.query_mode.value
        'running queries'
        >>> ui.evolve(query_mode=QueryMode.blocking, sort_key=SortKey.write)
        >>> ui.query_mode.value
        'blocking queries'
        >>> ui.sort_key.name
        'write'
        """
        if self.in_pause:
            return
        forbidden = set(changes) - {
            "duration_mode",
            "query_display_mode",
            "sort_key",
            "query_mode",
            "refresh_time",
        }
        assert not forbidden, forbidden
        fields = attr.fields(self.__class__)
        for field_name, value in changes.items():
            field = getattr(fields, field_name)
            # setattr() on attrs classes does not run converters by itself.
            if field.converter:
                value = field.converter(value)
            setattr(self, field_name, value)

    def column(self, key: str) -> Column:
        """Return the column matching 'key'.

        :raises ValueError: if no column of the current query mode has 'key'.

        >>> ui = UI.make()
        >>> ui.column("cpu")  # doctest: +ELLIPSIS
        Column(name='CPU%', template_h='%6s ', mandatory=False, sort_key=...)
        >>> ui.column("gloups")
        Traceback (most recent call last):
          ...
        ValueError: gloups
        """
        for column in self.columns_by_querymode[self.query_mode]:
            if column.key == key:
                return column
        raise ValueError(key)

    def columns(self) -> Tuple[Column, ...]:
        """Return the tuple of Column for current mode.

        >>> flag = Flag.PID | Flag.DATABASE | Flag.APPNAME | Flag.RELATION
        >>> ui = UI.make(flag=flag)
        >>> [c.name for c in ui.columns()]
        ['PID', 'DATABASE', 'APP', 'state', 'Query']
        """
        return self.columns_by_querymode[self.query_mode]
570
571
@attr.s(auto_attribs=True, frozen=True, slots=True)
class Host:
    """Immutable record: hostname, user, host, port and database name.

    Fields are positional in this order: hostname, user, host, port, dbname.
    """

    hostname: str
    user: str
    host: str
    port: int
    dbname: str
579
580
@attr.s(auto_attribs=True, slots=True)
class DBInfo:
    """Database size figures (mutable record)."""

    # Total size; unit not visible here (presumably bytes).
    total_size: int
    # Presumably the size evolution between two samples — confirm with caller.
    size_ev: int
585
586
@attr.s(auto_attribs=True, slots=True)
class MemoryInfo:
    """Memory (or swap) usage: percentage, used amount and total amount."""

    percent: float
    # Amounts; unit not visible here (presumably bytes).
    used: int
    total: int

    @classmethod
    def default(cls) -> "MemoryInfo":
        """Return an all-zero MemoryInfo."""
        return cls(percent=0.0, used=0, total=0)
596
597
@attr.s(auto_attribs=True, slots=True)
class LoadAverage:
    """Load averages (presumably the usual 1, 5 and 15 minute values)."""

    avg1: float
    avg5: float
    avg15: float

    @classmethod
    def default(cls) -> "LoadAverage":
        """Return an all-zero LoadAverage."""
        return cls(avg1=0.0, avg5=0.0, avg15=0.0)
607
608
@attr.s(auto_attribs=True, frozen=True, slots=True)
class IOCounter:
    """Immutable I/O counter: operation count, bytes and chars.

    Presumably mirrors psutil's *_count / *_bytes / *_chars io counters.
    """

    count: int
    bytes: int
    # Optional; defaults to 0.
    chars: int = 0

    @classmethod
    def default(cls) -> "IOCounter":
        """Return an all-zero IOCounter."""
        return cls(count=0, bytes=0)
618
619
@attr.s(auto_attribs=True, frozen=True, slots=True)
class SystemInfo:
    """Immutable snapshot of memory, swap, load and I/O counters."""

    memory: MemoryInfo
    swap: MemoryInfo
    load: LoadAverage
    io_read: IOCounter
    io_write: IOCounter
    max_iops: int = 0

    @classmethod
    def default(
        cls,
        *,
        memory: Optional[MemoryInfo] = None,
        swap: Optional[MemoryInfo] = None,
        load: Optional[LoadAverage] = None,
    ) -> "SystemInfo":
        """Zero-value builder.

        Missing memory, swap or load values get their own zero default.
        I/O counters are always zero and max_iops keeps its default of 0.

        >>> SystemInfo.default()  # doctest: +NORMALIZE_WHITESPACE
        SystemInfo(memory=MemoryInfo(percent=0.0, used=0, total=0),
                   swap=MemoryInfo(percent=0.0, used=0, total=0),
                   load=LoadAverage(avg1=0.0, avg5=0.0, avg15=0.0),
                   io_read=IOCounter(count=0, bytes=0, chars=0),
                   io_write=IOCounter(count=0, bytes=0, chars=0),
                   max_iops=0)
        """
        return cls(
            memory=memory or MemoryInfo.default(),
            swap=swap or MemoryInfo.default(),
            load=load or LoadAverage.default(),
            io_read=IOCounter.default(),
            io_write=IOCounter.default(),
        )
654
655
class LockType(enum.Enum):
    """Type of lockable object.

    Member names follow the ``locktype`` values of the pg_locks view:
    https://www.postgresql.org/docs/current/view-pg-locks.html
    """

    relation = enum.auto()
    extend = enum.auto()
    page = enum.auto()
    tuple = enum.auto()
    transactionid = enum.auto()
    virtualxid = enum.auto()
    object = enum.auto()
    userlock = enum.auto()
    advisory = enum.auto()

    def __str__(self) -> str:
        """Return the bare member name, so views render e.g. "relation"."""
        return str(self.name)
675
676
def locktype(value: Union[str, LockType]) -> LockType:
    """Convert a pg_locks ``locktype`` name into a LockType member.

    A value that already is a LockType member is returned unchanged. This
    makes the converter safe to re-apply, e.g. when attrs re-runs converters
    on existing field values (``attr.evolve``).

    :raises ValueError: if ``value`` is not a known lock type name.
    """
    if isinstance(value, LockType):
        return value
    try:
        return LockType[value]
    except KeyError as exc:
        raise ValueError(f"invalid lock type {exc}") from None
682
683
@attr.s(auto_attribs=True, slots=True)
class BaseProcess:
    """Fields shared by every kind of process row.

    Subclasses add their own fields. Some subclasses redeclare
    ``is_parallel_worker`` to move it in the generated __init__ or to give
    it a default.
    """

    pid: int
    application_name: str
    database: Optional[str]
    user: str
    client: str
    duration: Optional[float]
    state: str
    query: Optional[str]
    is_parallel_worker: bool
695
696
@attr.s(auto_attribs=True, frozen=True, slots=True)
class RunningProcess(BaseProcess):
    """Process for a running query."""

    wait: Union[bool, None, str]
    # Redeclared from BaseProcess. attrs then places it among this class's
    # own attributes, i.e. after ``wait`` in the positional __init__ order.
    is_parallel_worker: bool
703
704
@attr.s(auto_attribs=True, frozen=True, slots=True)
class WaitingProcess(BaseProcess):
    """Process for a waiting query."""

    # Lock information from pg_locks view
    # https://www.postgresql.org/docs/current/view-pg-locks.html
    mode: str
    # Given as a lock type name and converted to LockType by locktype().
    type: LockType = attr.ib(converter=locktype)
    relation: str

    # TODO: update queries to select/compute this column.
    # Not an __init__ argument: always False for now.
    is_parallel_worker: bool = attr.ib(default=False, init=False)
717
718
@attr.s(auto_attribs=True, frozen=True, slots=True)
class BlockingProcess(BaseProcess):
    """Process for a blocking query."""

    # Lock information from pg_locks view
    # https://www.postgresql.org/docs/current/view-pg-locks.html
    mode: str
    # Given as a lock type name and converted to LockType by locktype().
    type: LockType = attr.ib(converter=locktype)
    relation: str
    wait: Union[bool, None, str]

    # TODO: update queries to select/compute this column.
    # Not an __init__ argument: always False for now.
    is_parallel_worker: bool = attr.ib(default=False, init=False)
732
733
@attr.s(auto_attribs=True, frozen=True, slots=True)
class SystemProcess:
    """Immutable snapshot of system-level metrics for one process."""

    meminfo: Tuple[int, ...]
    io_read: IOCounter
    io_write: IOCounter
    io_time: float
    mem_percent: float
    cpu_percent: float
    cpu_times: Tuple[float, ...]
    # Presumably the I/O amounts since the previous sample — confirm.
    read_delta: float
    write_delta: float
    io_wait: bool
    # Underlying psutil handle, when one is available.
    psutil_proc: Optional[psutil.Process]
747
748
@attr.s(auto_attribs=True, frozen=True, slots=True)
class LocalRunningProcess(RunningProcess):
    """Running process enriched with local system metrics."""

    cpu: float
    mem: float
    read: float
    write: float
    io_wait: bool

    @classmethod
    def from_process(
        cls, process: RunningProcess, **kwargs: Union[float, str]
    ) -> "LocalRunningProcess":
        """Build an instance from ``process`` plus extra field values.

        ``kwargs`` provides the fields RunningProcess lacks (cpu, mem, read,
        write, io_wait). A key that names an existing field overrides it.
        """
        values = attr.asdict(process)
        values.update(kwargs)
        return cls(**values)
762
763
@attr.s(auto_attribs=True, slots=True)
class SelectableProcesses:
    """Selectable list of processes.

    ``focused`` holds the PID under the cursor. ``pinned`` holds the PIDs
    explicitly selected with toggle_pin_focused().

    >>> @attr.s(auto_attribs=True)
    ... class Proc:
    ...     pid: int

    >>> w = SelectableProcesses(list(map(Proc, [456, 123, 789])))
    >>> len(w)
    3

    Nothing focused at initialization:
    >>> w.focused

    >>> w.focus_next()
    >>> w.focused
    456
    >>> w.focus_next()
    >>> w.focused
    123
    >>> w.focus_prev()
    >>> w.focused
    456
    >>> w.focus_prev()
    >>> w.focused
    789
    >>> w.focused = 789
    >>> w.focus_next()
    >>> w.focused
    456
    >>> w.focus_prev()
    >>> w.focused
    789
    >>> w.set_items(sorted(w.items))
    >>> w.focused
    789
    >>> w.focus_prev()
    >>> w.focused
    456
    >>> w.focus_next()
    >>> w.focused
    789

    >>> w.selected, w.focused
    ([789], 789)
    >>> w.toggle_pin_focused()
    >>> w.focus_next()
    >>> w.toggle_pin_focused()
    >>> w.selected, w.focused
    ([123, 789], 123)
    >>> w.toggle_pin_focused()
    >>> w.focus_next()
    >>> w.toggle_pin_focused()
    >>> w.selected, w.focused
    ([456, 789], 456)
    >>> w.reset()
    >>> w.selected, w.focused
    ([], None)
    """

    items: List[BaseProcess]
    focused: Optional[int] = None
    pinned: MutableSet[int] = attr.ib(default=attr.Factory(set))

    def __len__(self) -> int:
        return len(self.items)

    def __iter__(self) -> Iterator[BaseProcess]:
        return iter(self.items)

    @property
    def selected(self) -> List[int]:
        """Return the PIDs to act on.

        The pinned PIDs come first, sorted in ascending order. When nothing
        is pinned, the result is the focused PID alone, or an empty list if
        nothing is focused.
        """
        if self.pinned:
            # Sort so the result does not depend on set iteration order.
            return sorted(self.pinned)
        if self.focused is not None:
            return [self.focused]
        return []

    def reset(self) -> None:
        """Clear both the focus and the pinned PIDs."""
        self.focused = None
        self.pinned.clear()

    def set_items(self, new_items: Sequence[BaseProcess]) -> None:
        """Replace the items in place (the list object itself is kept)."""
        self.items[:] = list(new_items)

    def _position(self) -> Optional[int]:
        """Return the index of the focused PID in items, or None."""
        if self.focused is None:
            return None
        for idx, proc in enumerate(self.items):
            if proc.pid == self.focused:
                return idx
        return None

    def focus_next(self) -> None:
        """Move focus to the next item, wrapping to the first one."""
        if not self.items:
            return
        idx = self._position()
        if idx is None or idx == len(self.items) - 1:
            next_idx = 0
        else:
            next_idx = idx + 1
        self.focused = self.items[next_idx].pid

    def focus_prev(self) -> None:
        """Move focus to the previous item, wrapping to the last one."""
        if not self.items:
            return
        # With no (or an unknown) focus, idx is 0 and items[-1] is the last
        # item. Index 0 also wraps to the last item.
        idx = self._position() or 0
        self.focused = self.items[idx - 1].pid

    def toggle_pin_focused(self) -> None:
        """Pin the focused PID, or unpin it if it is already pinned."""
        assert self.focused is not None
        try:
            self.pinned.remove(self.focused)
        except KeyError:
            self.pinned.add(self.focused)
883
884
# Shape of the activity data. Either an iterable of processes on its own, or
# a (processes, SystemInfo) pair.
ActivityStats = Union[
    Iterable[WaitingProcess],
    Iterable[RunningProcess],
    Tuple[Iterable[WaitingProcess], SystemInfo],
    Tuple[Iterable[BlockingProcess], SystemInfo],
    Tuple[Iterable[LocalRunningProcess], SystemInfo],
]
892