1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import time
7
8from msrest import Deserializer
9from msrest.polling import PollingMethod, LROPoller
10from msrestazure.azure_exceptions import CloudError
11
12from ._constants import get_acr_models, get_finished_run_status, get_succeeded_run_status
13
14
def get_run_with_polling(cmd,
                         client,
                         run_id,
                         registry_name,
                         resource_group_name):
    """Return an LROPoller that waits for an ACR run to reach a finished status.

    :param cmd: The CLI command context; used to resolve the ACR model classes.
    :param client: The runs operations client. Its ``get`` is called once here with
        ``raw=True`` to obtain the initial response handed to the poller.
    :param run_id: The ID of the run to watch.
    :param registry_name: The name of the container registry that owns the run.
    :param resource_group_name: The resource group containing the registry.
    :return: An ``LROPoller`` driven by :class:`RunPolling` whose result is produced by
        deserializing responses as the ``Run`` model.
    """
    # Hand every class exported by the ACR models module to the deserializer so it can
    # resolve 'Run' (and, presumably, any model types nested inside it).
    model_classes = {}
    for model_name, candidate in get_acr_models(cmd).__dict__.items():
        if isinstance(candidate, type):
            model_classes[model_name] = candidate
    run_deserializer = Deserializer(model_classes)

    def deserialize_run(response):
        return run_deserializer('Run', response)

    first_response = client.get(resource_group_name, registry_name, run_id, raw=True)
    poll_strategy = RunPolling(cmd=cmd, registry_name=registry_name, run_id=run_id)
    return LROPoller(client=client,
                     initial_response=first_response,
                     deserialization_callback=deserialize_run,
                     polling_method=poll_strategy)
36
37
class RunPolling(PollingMethod):  # pylint: disable=too-many-instance-attributes
    """Polling method that waits for an ACR run to reach a finished status.

    ``initialize`` records the first response and derives the initial status. ``run``
    then repeatedly sleeps and re-requests the run from the same URL until its status
    is one of ``get_finished_run_status(cmd)``. A finished run whose status is not in
    ``get_succeeded_run_status(cmd)`` causes ``run`` to raise ``CLIError``.
    """

    def __init__(self, cmd, registry_name, run_id, timeout=30):
        """
        :param cmd: The CLI command context; used to look up run status values and models.
        :param registry_name: Registry name; only used in the failure message.
        :param run_id: ID of the run being watched; used in the failure message.
        :param timeout: Seconds to sleep between status polls. Despite the name this is a
            polling interval, not an overall deadline.
        """
        self._cmd = cmd
        self._registry_name = registry_name
        self._run_id = run_id
        self._timeout = timeout
        # The fields below are populated by initialize() and refreshed while polling.
        self._client = None
        self._response = None  # most recently received response
        self._url = None  # URL that is re-requested on every poll
        self._deserialize = None  # callback turning a response into a Run model
        self.operation_status = ""
        self.operation_result = None

    def initialize(self, client, initial_response, deserialization_callback):
        """Store the client, first response and deserializer, then derive the initial status."""
        self._client, self._response = client, initial_response
        # Later polls re-issue a GET against the same URL as the initial request.
        self._url = initial_response.request.url
        self._deserialize = deserialization_callback
        self._set_operation_status(self._response)

    def run(self):
        """Block until the run finishes; raise CLIError if it did not succeed."""
        while True:
            if self.finished():
                break
            time.sleep(self._timeout)
            self._update_status()

        if self.operation_status in get_succeeded_run_status(self._cmd):
            return

        from knack.util import CLIError
        message_args = (self._run_id, self.operation_status,
                        self._registry_name, self._run_id,
                        self._registry_name, self._run_id)
        raise CLIError("The run with ID '{}' finished with unsuccessful status '{}'. "
                       "Show run details by 'az acr task show-run -r {} --run-id {}'. "
                       "Show run logs by 'az acr task logs -r {} --run-id {}'.".format(*message_args))

    def status(self):
        """Return the most recently observed run status."""
        return self.operation_status

    def finished(self):
        """Return whether the current status is one of the finished run statuses."""
        finished_statuses = get_finished_run_status(self._cmd)
        return self.operation_status in finished_statuses

    def resource(self):
        """Return the last deserialized Run model (None until a 200 response was seen)."""
        return self.operation_result

    def _set_operation_status(self, response):
        """Deserialize a 200 response into the cached result and status.

        A run that reports no status is treated as queued. Any other HTTP status code
        raises CloudError.
        """
        run_status_enum = self._cmd.get_models('RunStatus')
        if response.status_code != 200:
            raise CloudError(response)
        self.operation_result = self._deserialize(response)
        self.operation_status = self.operation_result.status or run_status_enum.queued.value

    def _update_status(self):
        """Re-fetch the run from the stored URL and refresh the cached status."""
        poll_request = self._client.get(self._url)
        self._response = self._client.send(poll_request, stream=False)
        self._set_operation_status(self._response)
99