1# Copyright 2020 Microsoft Corporation
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#
15# Requires Python 2.6+ and Openssl 1.0+
16#
17
18import json
19import re
20import uuid
21import os
22import contextlib
23import subprocess
24
25import azurelinuxagent.common.conf as conf
26from azurelinuxagent.common.future import ustr
27from azurelinuxagent.common.utils import fileutil
28from azurelinuxagent.ga.exthandlers import ExtHandlerInstance
29
30from tests.tools import Mock, patch
31from tests.protocol.mocks import HttpRequestPredicates
32
33
class ExtensionCommandNames(object):
    """
    String constants naming the extension handler commands. ExtensionEmulator keys its
    table of actions by these values.
    """

    DISABLE = "disable"
    ENABLE = "enable"
    INSTALL = "install"
    UNINSTALL = "uninstall"
    UPDATE = "update"
40
41
class Actions(object):
    """
    Static helpers that produce ready-made action callables for ExtensionEmulator.
    """

    @staticmethod
    def succeed_action(*_, **__):
        """
        Accept any arguments and report success by returning 0.
        """
        return 0

    @staticmethod
    def generate_unique_fail():
        """
        Build a failing action whose return code is unique to this call.

        :return: a (return_code, fail_action) pair. return_code is a freshly generated
                 UUID string; fail_action accepts any arguments and returns return_code,
                 which lets a test tell exactly which action produced a given result.
        """
        unique_code = str(uuid.uuid4())
        return unique_code, (lambda *_, **__: unique_code)
67
68
def extension_emulator(name="OSTCExtensions.ExampleHandlerLinux", version="1.0.0",
    update_mode="UpdateWithInstall", report_heartbeat=False, continue_on_update_failure=False,
    install_action=Actions.succeed_action, uninstall_action=Actions.succeed_action,
    enable_action=Actions.succeed_action, disable_action=Actions.succeed_action,
    update_action=Actions.succeed_action):
    """
    Build an ExtensionEmulator, using defaults for anything the caller omits.

    With no arguments the emulator represents "OSTCExtensions.ExampleHandlerLinux" 1.0.0,
    with update mode "UpdateWithInstall", heartbeat reporting and continue-on-update-failure
    both False, and every action set to Actions.succeed_action (returns 0).
    """
    # Linter reports too many arguments, but this isn't an issue because all are defaulted;
    # no caller will have to actually provide all of the arguments listed.
    return ExtensionEmulator(
        name=name,
        version=version,
        update_mode=update_mode,
        report_heartbeat=report_heartbeat,
        continue_on_update_failure=continue_on_update_failure,
        install_action=install_action,
        uninstall_action=uninstall_action,
        enable_action=enable_action,
        disable_action=disable_action,
        update_action=update_action
    )
82
@contextlib.contextmanager
def enable_invocations(*emulators):
    """
    Context manager that lets ExtHandlersHandler objects reach the given emulators.

    While active, subprocess.Popen and ExtHandlerInstance.load_manifest are patched so
    that extension commands are routed to the matching emulator. The yielded
    InvocationRecord logs every intercepted extension command in invocation order.
    """
    record = InvocationRecord()

    # Both replacements capture the real Popen / load_manifest when they are built, so
    # they must be created before either patch is applied.
    popen_replacement = generate_patched_popen(record, *emulators)
    load_manifest_replacement = generate_mock_load_manifest(*emulators)

    with patch.object(ExtHandlerInstance, "load_manifest", load_manifest_replacement):
        with patch("subprocess.Popen", popen_replacement):
            yield record
100
def generate_put_handler(*emulators):
    """
    Build an HTTP PUT handler that stores uploaded handler status blobs on the emulators.

    For use with tests.protocol.mocks.mock_wire_protocol. The handler parses the request
    body (args[0]) as JSON and appends each entry of
    aggregateStatus.handlerAggregateStatus to the status_blobs of the emulator matching
    its handlerName/handlerVersion. An Exception is raised for an entry with no matching
    emulator. Requests that HttpRequestPredicates identifies as host plugin status
    requests are ignored.
    """

    def mock_put_handler(url, *args, **_):
        if HttpRequestPredicates.is_host_plugin_status_request(url):
            return

        aggregate = json.loads(args[0]).get("aggregateStatus", {})

        for status in aggregate.get("handlerAggregateStatus", []):
            handler_name = status.get("handlerName")
            handler_version = status.get("handlerVersion")

            try:
                emulator = _first_matching_emulator(emulators, handler_name, handler_version)
            except StopIteration:
                # Tests will want to know that the agent is running an extension they didn't specifically allocate.
                raise Exception("Extension running, but not present in emulators: {0}, {1}".format(handler_name, handler_version))

            emulator.status_blobs.append(status)

    return mock_put_handler
127
class InvocationRecord(object):
    """
    Ordered log of the extension commands intercepted by the patched Popen.

    Each entry is an (extension name, extension version, command name) tuple. Entries
    are kept in the order they were added.
    """

    def __init__(self):
        self._queue = []

    def add(self, ext_name, ext_ver, ext_cmd):
        """
        Append one invocation to the end of the record.
        """
        self._queue.append((ext_name, ext_ver, ext_cmd))

    def compare(self, *expected_cmds):
        """
        Verifies that the recorded invocations match the provided command list exactly, in order.

        Each cmd in expected_cmds should be a tuple of the form (ExtensionEmulator, ExtensionCommandNames object).
        Recorded invocations are popped from the record as they are checked, so this call
        consumes the record.

        :raises Exception: if a recorded invocation differs from the expected one, if fewer
            invocations were recorded than expected, or if recorded invocations remain once
            every expected one has been matched.
        """
        for (expected_ext_emulator, command_name) in expected_cmds:

            # Check up front rather than catching IndexError from pop(), so the failure is
            # reported without an unrelated chained IndexError traceback.
            if not self._queue:
                raise Exception("No more invocations recorded. Expected ({0}, {1}, {2}).".format(expected_ext_emulator.name,
                    expected_ext_emulator.version, command_name))

            (ext_name, ext_ver, ext_cmd) = self._queue.pop(0)

            if not expected_ext_emulator.matches(ext_name, ext_ver) or command_name != ext_cmd:
                raise Exception("Unexpected invocation: have ({0}, {1}, {2}), but expected ({3}, {4}, {5})".format(
                    ext_name, ext_ver, ext_cmd, expected_ext_emulator.name, expected_ext_emulator.version, command_name
                ))

        if self._queue:
            raise Exception("Invocation recorded, but not expected: ({0}, {1}, {2})".format(
                *self._queue[0]
            ))
161
162def _first_matching_emulator(emulators, name, version):
163    for ext in emulators:
164        if ext.matches(name, version):
165            return ext
166
167    raise StopIteration
168
class ExtensionEmulator(object):
    """
    A wrapper class for the possible actions and options that an extension might support.

    Each action given to the constructor is wrapped by _extend_func, so invoking it stands
    in for subprocess.Popen. The wrapper runs the action, writes a status file for the
    extension, and returns a Mock whose poll() and wait() return the action's return value.
    """

    # Name of a settings file such as "3.settings"; captures the sequence number. It is
    # matched against the bare file name (not the full path), so characters in the
    # directory path are never interpreted as regex syntax. The trailing "$" keeps names
    # such as "3.settings.bak" from being counted.
    _SETTINGS_FILE_REGEX = r'(?P<seq>[0-9]+)\.settings$'

    def __init__(self, name, version,
        update_mode, report_heartbeat,
        continue_on_update_failure,
        install_action, uninstall_action,
        enable_action, disable_action,
        update_action):
        # Linter reports too many arguments, but this constructor has its access mediated by
        # a factory method; the calls affected by the number of arguments here is very
        # limited in scope.

        self.name = name
        self.version = version

        self.update_mode = update_mode
        self.report_heartbeat = report_heartbeat
        self.continue_on_update_failure = continue_on_update_failure

        self._actions = {
            ExtensionCommandNames.INSTALL: ExtensionEmulator._extend_func(install_action),
            ExtensionCommandNames.UNINSTALL: ExtensionEmulator._extend_func(uninstall_action),
            ExtensionCommandNames.UPDATE: ExtensionEmulator._extend_func(update_action),
            ExtensionCommandNames.ENABLE: ExtensionEmulator._extend_func(enable_action),
            ExtensionCommandNames.DISABLE: ExtensionEmulator._extend_func(disable_action)
        }

        self._status_blobs = []

    @property
    def actions(self):
        """
        A read-only property designed to allow inspection for the emulated extension's
        actions. `actions` maps an ExtensionCommandNames object to a mock wrapping the
        function this emulator was initialized with.
        """
        return self._actions

    @property
    def status_blobs(self):
        """
        The list of status blobs for the extension this object is emulating. The list is
        filled from uploads to the HTTP PUT /status endpoint.
        """
        return self._status_blobs

    @staticmethod
    def _latest_settings_sequence(config_dir):
        """
        Return the highest sequence number among the "<seq>.settings" regular files in
        config_dir, or 0 when there are none.

        :raises OSError: if config_dir cannot be listed.
        """
        seq = 0
        for filename in os.listdir(config_dir):
            if not os.path.isfile(os.path.join(config_dir, filename)):
                continue

            match = re.match(ExtensionEmulator._SETTINGS_FILE_REGEX, filename)
            if match:
                seq = max(seq, int(match.group("seq")))

        return seq

    @staticmethod
    def _extend_func(func):
        """
        Convert a function such that its returned value mimics a Popen object (i.e. with
        correct return values for poll() and wait() calls).

        After calling func, the wrapper writes "<ext_dir>/status/<seq>.status". Here
        ext_dir is os.path.dirname(cmd), and seq is the highest "<seq>.settings" number
        found in "<ext_dir>/config" (0 if there are none). A return value of 0 is reported
        as success. Any other value is reported as an error that carries the return value
        as its exit_code.
        """

        def wrapped_func(cmd, *args, **kwargs):
            return_value = func(cmd, *args, **kwargs)

            # NOTE(review): assumes cmd is a string whose directory part is the extension's
            # directory (i.e. no '/' in the arguments after the script path).
            ext_dir = os.path.dirname(cmd)
            seq = ExtensionEmulator._latest_settings_sequence(os.path.join(ext_dir, "config"))

            status_file = os.path.join(ext_dir, "status", "{seq}.status".format(seq=seq))

            if return_value == 0:
                status_contents = [{ "status": {"status": "success"} }]
            else:
                status_contents = [{ "status": {"status": "error", "substatus": {"exit_code": return_value}} }]

            fileutil.write_file(status_file, json.dumps(status_contents))

            return Mock(**{
                "poll.return_value": return_value,
                "wait.return_value": return_value
            })

        # Wrap the function in a mock to allow invocation reflection a la .assert_not_called(), etc.
        return Mock(wraps=wrapped_func)

    def matches(self, name, version):
        """
        Return True when this emulator represents the extension with the given name and version.
        """
        return self.name == name and self.version == version
266
def generate_patched_popen(invocation_record, *emulators):
    """
    Create a mock popen function able to invoke the proper action for an extension
    emulator in emulators.

    Commands recognized as extension commands are recorded in invocation_record and then
    dispatched to the matching emulator's action. All other commands, including
    non-string commands such as argument lists, are passed to the real subprocess.Popen
    captured when this function is called.

    :param invocation_record: InvocationRecord that receives each intercepted extension command.
    :param emulators: emulators eligible to handle intercepted commands.
    :return: a function with Popen's call signature.
    """
    original_popen = subprocess.Popen

    def patched_popen(cmd, *args, **kwargs):
        # Only a string command can name an extension handler. Anything else (e.g. a list
        # of arguments) is an ordinary process launch and must reach the real Popen.
        if not isinstance(cmd, (str, ustr)):
            return original_popen(cmd, *args, **kwargs)

        try:
            ext_name, ext_version, command_name = _extract_extension_info_from_command(cmd)
        except ValueError:
            # Not an extension command: run it for real.
            return original_popen(cmd, *args, **kwargs)

        # Recorded before the lookup, so an unexpected extension still shows up in the record.
        invocation_record.add(ext_name, ext_version, command_name)

        try:
            matching_ext = _first_matching_emulator(emulators, ext_name, ext_version)
        except StopIteration:
            raise Exception("Extension('{name}', '{version}') not listed as a parameter. Is it being emulated?".format(
                name=ext_name, version=ext_version
            ))

        # Kept outside the try above, so that a StopIteration raised by an action is not
        # misreported as a missing emulator.
        return matching_ext.actions[command_name](cmd, *args, **kwargs)

    return patched_popen
293
def generate_mock_load_manifest(*emulators):
    """
    Create a replacement for ExtHandlerInstance.load_manifest.

    The replacement calls the real load_manifest (captured when this function is called).
    It then overwrites the "continueOnUpdateFailure", "reportHeartbeat" and "updateMode"
    entries of the manifest's "handlerManifest" data with the matching emulator's
    settings. It raises an Exception if no emulator matches the handler's name and version.
    """

    real_load_manifest = ExtHandlerInstance.load_manifest

    def mock_load_manifest(self):
        handler_name = self.ext_handler.name
        handler_version = self.ext_handler.properties.version

        try:
            emulator = _first_matching_emulator(emulators, handler_name, handler_version)
        except StopIteration:
            raise Exception("Extension('{name}', '{version}') not listed as a parameter. Is it being emulated?".format(
                name=handler_name, version=handler_version
            ))

        manifest = real_load_manifest(self)

        overrides = {
            "continueOnUpdateFailure": emulator.continue_on_update_failure,
            "reportHeartbeat": emulator.report_heartbeat,
            "updateMode": emulator.update_mode
        }
        manifest.data["handlerManifest"].update(overrides)

        return manifest

    return mock_load_manifest
318
def _extract_extension_info_from_command(command):
    """
    Parse a command into a tuple of extension info.

    The command must contain "<lib_dir>/<name>-<version>/<script> -<command>", where
    lib_dir is conf.get_lib_dir(). <name> is one or more dot-separated runs of letters,
    <version> is dot-separated digits, and <command> is letters.

    :return: a (name, version, command) tuple of strings.
    :raises ValueError: if command is not a string or does not contain the expected format.
    """
    # ValueError (not a bare Exception) so callers such as the patched Popen can treat a
    # non-string command like any other non-extension command and fall back.
    if not isinstance(command, (str, ustr)):
        raise ValueError("Cannot extract extension info from non-string commands")

    # Group layout of the expected command; this lets us grab what we want after a match
    template = r'(?<={base_dir}/)(?P<name>{ext_name})-(?P<ver>{ext_ver})(?:/{script_file} -)(?P<cmd>{ext_cmd})'

    # The lib dir is a literal path. Escape it so characters such as '.', '+' or '(' are
    # matched literally and not treated as regex syntax.
    base_dir_regex = re.escape(conf.get_lib_dir())
    script_file_regex = r'[^\s]+'
    ext_cmd_regex = r'[a-zA-Z]+'
    # Any number of dot-separated segments, e.g. "Publisher.Type" or
    # "Microsoft.Azure.Extensions.CustomScript".
    ext_name_regex = r'[a-zA-Z]+(?:\.[a-zA-Z]+)*'
    ext_ver_regex = r'[0-9]+(?:\.[0-9]+)*'

    full_regex = template.format(
        ext_name=ext_name_regex,
        ext_ver=ext_ver_regex,
        base_dir=base_dir_regex, script_file=script_file_regex,
        ext_cmd=ext_cmd_regex
    )

    match_obj = re.search(full_regex, command)

    if not match_obj:
        raise ValueError("Command does not match the desired format: {0}".format(command))

    return match_obj.group('name', 'ver', 'cmd')