1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#===----------------------------------------------------------------------===##
4#
5# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
6# See https://llvm.org/LICENSE.txt for license information.
7# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
8#
9#===----------------------------------------------------------------------===##
10"""Tests for revert_checker.
11
12Note that these tests require having LLVM's git history available, since our
13repository has a few interesting instances of edge-cases.
14"""
15
16import os
17import logging
18import unittest
19from typing import List
20
21import revert_checker
22
23# pylint: disable=protected-access
24
25
def get_llvm_project_path() -> str:
  """Locates the root of the llvm-project checkout this test lives in.

  The root is taken to be two directory levels above the directory that
  contains this file. Symlinks are resolved in the returned path.
  """
  this_dir = os.path.dirname(__file__)
  # os.pardir is the platform's parent-directory token ('..').
  two_levels_up = os.path.join(this_dir, os.pardir, os.pardir)
  return os.path.realpath(two_levels_up)
30
31
32class _SilencingFilter(logging.Filter):
33  """Silences all log messages.
34
35  Also collects info about log messages that would've been emitted.
36  """
37
38  def __init__(self) -> None:
39    self.messages: List[str] = []
40
41  def filter(self, record: logging.LogRecord) -> bool:
42    self.messages.append(record.getMessage())
43    return False
44
45
class Test(unittest.TestCase):
  """Checks revert_checker against fixed points in LLVM's git history."""

  def silence_logging(self) -> _SilencingFilter:
    """Installs a record-swallowing filter on the root logger for this test.

    Logger-level filters only see records logged directly on that logger
    (e.g. via the module-level `logging.error`); records propagated from
    child loggers bypass them. The filter is removed again on test cleanup.

    Returns:
      The installed filter, whose `messages` list accumulates the text of
      every record it suppressed.
    """
    swallower = _SilencingFilter()
    root_logger = logging.getLogger()
    root_logger.addFilter(swallower)
    self.addCleanup(root_logger.removeFilter, swallower)
    return swallower

  def test_log_stream_with_known_sha_range(self) -> None:
    # The expected entries start at `root_sha`; `end_at_sha` itself does not
    # appear among them.
    root_sha = 'e241573d5972d34a323fa5c64774c4207340beb3'
    first_message = '\n'.join((
        '[mlir] NFC: remove IntegerValueSet / MutableIntegerSet',
        '',
        'Summary:',
        '- these are unused and really not needed now given flat affine',
        '  constraints',
        '',
        'Differential Revision: https://reviews.llvm.org/D75792',
    ))
    second_message = ('[NFC] use hasAnyOperatorName and '
                      'hasAnyOverloadedOperatorName functions in clang-tidy '
                      'matchers')
    expected = [
        revert_checker._LogEntry(root_sha, first_message),
        revert_checker._LogEntry('97572fa6e9daecd648873496fd11f7d1e25a55f0',
                                 second_message),
    ]

    actual = list(
        revert_checker._log_stream(
            get_llvm_project_path(),
            root_sha=root_sha,
            end_at_sha='a7a37517751ffb0f5529011b4ba96e67fcb27510',
        ))
    self.assertEqual(expected, actual)

  def test_reverted_noncommit_object_is_a_nop(self) -> None:
    captured = self.silence_logging()
    # This commit's message claims to revert an object that is not a commit.
    # It sits between `across_ref` (its parent) and `root` (itself).
    noncommit_reverter = 'c9944df916e41b1014dff5f6f75d52297b48ecdc'
    reverts = revert_checker.find_reverts(
        git_dir=get_llvm_project_path(),
        across_ref=noncommit_reverter + '~',
        root=noncommit_reverter)
    self.assertEqual(reverts, [])

    # Instead of a revert, a complaint about the unresolvable object is logged.
    expected_prefix = ('Failed to resolve reverted object '
                       'edd18355be574122aaa9abf58c15d8c50fb085a1')
    self.assertTrue(
        any(msg.startswith(expected_prefix) for msg in captured.messages),
        captured.messages)

  def test_known_reverts_across_arbitrary_llvm_rev(self) -> None:
    found = revert_checker.find_reverts(
        git_dir=get_llvm_project_path(),
        across_ref='c47f971694be0159ffddfee8a75ae515eba91439',
        root='9f981e9adf9c8d29bb80306daf08d2770263ade6')
    # Note that `root` is itself one of the expected reverts.
    self.assertEqual(found, [
        revert_checker.Revert(
            sha='4e0fe038f438ae1679eae9e156e1f248595b2373',
            reverted_sha='65b21282c710afe9c275778820c6e3c1cf46734b'),
        revert_checker.Revert(
            sha='9f981e9adf9c8d29bb80306daf08d2770263ade6',
            reverted_sha='4060016fce3e6a0b926ee9fc59e440a612d3a2ec'),
    ])
115
116
if __name__ == '__main__':
  # Run every test in this module when the file is executed as a script.
  unittest.main()
119