1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6import argparse
7
8from azure.cli.command_modules.monitor.util import (
9    get_aggregation_map, get_operator_map, get_autoscale_scale_direction_map)
10
11from azure.cli.core.azclierror import InvalidArgumentValueError
12
13
def timezone_name_type(value):
    """Return the canonical spelling of an autoscale time zone name.

    ``value`` is compared case-insensitively with the ``name`` of each entry in
    ``AUTOSCALE_TIMEZONES``; the name of the first matching entry is returned.

    :param value: time zone name as supplied on the command line.
    :returns: the matching time zone name, spelled as in ``AUTOSCALE_TIMEZONES``.
    :raises InvalidArgumentValueError: when no entry matches.
    """
    from azure.cli.command_modules.monitor._autoscale_util import AUTOSCALE_TIMEZONES
    requested = value.lower()
    for entry in AUTOSCALE_TIMEZONES:
        if entry['name'].lower() == requested:
            canonical = entry['name']
            break
    else:
        canonical = None
    if canonical:
        return canonical
    raise InvalidArgumentValueError(
        "Invalid time zone: '{}'. Run 'az monitor autoscale profile list-timezones' for values.".format(value))
21
22
def timezone_offset_type(value):
    """Normalize a UTC offset such as ``5``, ``-3`` or ``5:30`` to a signed, zero-padded form.

    The hour part is converted with ``int`` and must lie in [-12, 14]. It is
    rendered with an explicit sign and at least two digits (``+05``, ``-03``,
    ``+14``). A non-empty minute part is appended unchanged after a colon
    (``+05:30``).

    :param value: offset as ``H`` or ``H:MM``; anything whose ``str()`` has that shape.
    :returns: the normalized offset string.
    :raises InvalidArgumentValueError: if the hour is outside [-12, 14].
    :raises ValueError: if the hour part is not an integer.
    """
    text = str(value)
    pieces = text.split(':')
    if len(pieces) == 2:
        hour_text, minute = pieces
    else:
        # no colon, or more than one: the whole text is treated as the hour,
        # so int() below rejects anything that is not a plain integer
        hour_text, minute = text, None

    hours = int(hour_text)
    if not -12 <= hours <= 14:
        raise InvalidArgumentValueError('Offset out of range: -12 to +14')

    sign = '-' if hours < 0 else '+'
    normalized = '{}{:02d}'.format(sign, abs(hours))
    if minute:
        normalized = '{}:{}'.format(normalized, minute)
    return normalized
47
48
def get_period_type(as_timedelta=False):
    """Build an argparse ``type`` callable that parses a time period.

    Two input forms are accepted:

    * a string that contains both the ``P`` and ``T`` designators, which is
      treated as ISO 8601 and returned unchanged;
    * a case-insensitive shorthand such as ``1d``, ``5m``, ``1h30m`` or
      ``1d2h3m4s``. In shorthand, ``m`` always means minutes (never months),
      and a years component is matched but ignored.

    :param as_timedelta: when True the returned callable produces a
        :class:`datetime.timedelta` for shorthand input; otherwise it produces
        an ISO 8601 duration string such as ``P1DT2H3M4S``.
    :returns: a callable ``period_type(value)``.
    """

    def period_type(value):
        """Parse ``value`` as described in :func:`get_period_type`.

        :raises ValueError: if the period pattern does not consume all of ``value``.
        """
        import re

        def _get_substring(indices):
            # re reports (-1, -1) for a group that did not take part in the match
            if indices == (-1, -1):
                return ''
            return value[indices[0]: indices[1]]

        regex = r'(p)?(\d+y)?(\d+m)?(\d+d)?(t)?(\d+h)?(\d+m)?(\d+s)?'
        # every group is optional, so re.match always succeeds (possibly matching nothing)
        match = re.match(regex, value.lower())
        if match.span(0) != (0, len(value)):
            raise ValueError('PERIOD should be of the form "##h##m##s" or ISO8601')
        # simply return value if a valid ISO8601 string is supplied
        if match.span(1) != (-1, -1) and match.span(5) != (-1, -1):
            return value

        # if shorthand is used, only support days, hours, minutes and seconds.
        # A "##m" that appears before any hours lands in group 3; it is still
        # interpreted as minutes.
        days = _get_substring(match.span(4))
        hours = _get_substring(match.span(6))
        minutes = _get_substring(match.span(7)) or _get_substring(match.span(3))
        seconds = _get_substring(match.span(8))

        if as_timedelta:
            from datetime import timedelta
            return timedelta(
                days=int(days[:-1]) if days else 0,
                hours=int(hours[:-1]) if hours else 0,
                minutes=int(minutes[:-1]) if minutes else 0,
                seconds=int(seconds[:-1]) if seconds else 0
            )
        # ISO 8601 orders time components as hours, minutes, seconds (e.g. PT1H30M)
        return 'P{}T{}{}{}'.format(days, hours, minutes, seconds).upper()

    return period_type
87
88
89# pylint: disable=protected-access, too-few-public-methods
class MetricAlertConditionAction(argparse._AppendAction):
    """Parse a metric alert ``--condition`` expression and append the resulting criterion.

    The tokens are joined with spaces and parsed with the antlr4-based
    ``MetricAlertConditionLexer``/``MetricAlertConditionParser``. The validator's
    result must be a ``MetricCriteria`` or a ``DynamicMetricCriteria`` with all
    of its required fields set; otherwise a usage error is raised.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # antlr4 is not available everywhere, restrict the import scope so that commands
        # that do not need it don't fail when it is absent
        import antlr4

        from azure.cli.command_modules.monitor.grammar.metric_alert import (
            MetricAlertConditionLexer, MetricAlertConditionParser, MetricAlertConditionValidator)
        from azure.mgmt.monitor.models import MetricCriteria, DynamicMetricCriteria

        usage = 'usage error: --condition {avg,min,max,total,count} [NAMESPACE.]METRIC\n' \
                '                         [{=,!=,>,>=,<,<=} THRESHOLD]\n' \
                '                         [{<,>,><} dynamic SENSITIVITY VIOLATION of EVALUATION [since DATETIME]]\n' \
                '                         [where DIMENSION {includes,excludes} VALUE [or VALUE ...]\n' \
                '                         [and   DIMENSION {includes,excludes} VALUE [or VALUE ...] ...]]'

        string_val = ' '.join(values)

        lexer = MetricAlertConditionLexer(antlr4.InputStream(string_val))
        stream = antlr4.CommonTokenStream(lexer)
        # distinct name so the argparse ``parser`` argument is not shadowed; the
        # argparse parser is what gets forwarded to the base class below
        condition_parser = MetricAlertConditionParser(stream)
        tree = condition_parser.expression()

        try:
            validator = MetricAlertConditionValidator()
            walker = antlr4.ParseTreeWalker()
            walker.walk(validator, tree)
            metric_condition = validator.result()
            if isinstance(metric_condition, MetricCriteria):
                # static metric criteria
                required = ['time_aggregation', 'metric_name', 'operator', 'threshold']
            elif isinstance(metric_condition, DynamicMetricCriteria):
                # dynamic metric criteria
                required = ['time_aggregation', 'metric_name', 'operator', 'alert_sensitivity', 'failing_periods']
            else:
                raise NotImplementedError()
            # NOTE(review): this is a truthiness test, so any falsy field value counts as
            # missing (e.g. a numeric threshold of 0, if the validator yields numbers) — confirm
            if not all(getattr(metric_condition, item, None) for item in required):
                raise InvalidArgumentValueError(usage)
        except (AttributeError, TypeError, KeyError) as ex:
            raise InvalidArgumentValueError(usage) from ex
        super(MetricAlertConditionAction, self).__call__(parser, namespace, metric_condition, option_string)
134
135
136# pylint: disable=protected-access, too-few-public-methods
class MetricAlertAddAction(argparse._AppendAction):
    """Build a ``MetricAlertAction`` from ``ACTION_GROUP_ID [KEY=VALUE ...]`` and append it."""

    def __call__(self, parser, namespace, values, option_string=None):
        action_group_id = values[0]
        extra_tokens = values[1:]

        # webhook properties stay None unless at least one KEY=VALUE token follows the id
        web_hook_properties = None
        if extra_tokens:
            try:
                web_hook_properties = dict(token.split('=', 1) for token in extra_tokens)
            except ValueError:
                # a token without '=' cannot form a (key, value) pair
                raise InvalidArgumentValueError(
                    "value of {} is invalid. Please refer to --help to get insight of correct format".format(
                        option_string))

        from azure.mgmt.monitor.models import MetricAlertAction
        alert_action = MetricAlertAction(action_group_id=action_group_id,
                                         web_hook_properties=web_hook_properties)
        # the OData type string is assigned explicitly after construction
        alert_action.odatatype = ('Microsoft.WindowsAzure.Management.Monitoring.Alerts.Models.'
                                  'Microsoft.AppInsights.Nexus.'
                                  'DataContracts.Resources.ScheduledQueryRules.Action')
        super(MetricAlertAddAction, self).__call__(parser, namespace, alert_action, option_string)
157
158
159# pylint: disable=too-few-public-methods
class ConditionAction(argparse.Action):
    """Parse ``METRIC {>,>=,<,<=} THRESHOLD AGGREGATION DURATION`` into a ``ThresholdRuleCondition``.

    The condition is stored on ``namespace.condition`` (replacing any previous
    value). If ``namespace.description`` is still None it is set to the raw
    condition text. The data source's resource URI is left as None here.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        from azure.mgmt.monitor.models import ThresholdRuleCondition, RuleMetricDataSource
        usage = '--condition METRIC {>,>=,<,<=} THRESHOLD {avg,min,max,total,last} DURATION'
        # get default description if not specified
        if namespace.description is None:
            namespace.description = ' '.join(values)
        if len(values) == 1:
            # workaround because CMD.exe eats > character... Allows condition to be
            # specified as a quoted expression
            values = values[0].split(' ')
        if len(values) < 5:
            raise InvalidArgumentValueError(usage)
        # the metric name may itself contain spaces: it is everything before the last four tokens
        metric_name = ' '.join(values[:-4])
        try:
            operator = get_operator_map()[values[-4]]
            # the model's threshold is a float, so fractional values such as 75.5 are accepted
            threshold = float(values[-3])
            aggregation = get_aggregation_map()[values[-2].lower()]
            window = get_period_type()(values[-1])
        except (KeyError, ValueError) as ex:
            # unknown operator/aggregation or a malformed threshold/duration
            raise InvalidArgumentValueError(usage) from ex
        metric = RuleMetricDataSource(resource_uri=None, metric_name=metric_name)  # target URI will be filled in later
        condition = ThresholdRuleCondition(
            operator=operator, threshold=threshold, data_source=metric,
            window_size=window, time_aggregation=aggregation)
        namespace.condition = condition
183
184
185# pylint: disable=protected-access
class AlertAddAction(argparse._AppendAction):
    """Build an email or webhook rule action from the tokens and append it."""

    def __call__(self, parser, namespace, values, option_string=None):
        action = self.get_action(values, option_string)
        super(AlertAddAction, self).__call__(parser, namespace, action, option_string)

    def get_action(self, values, option_string):  # pylint: disable=no-self-use
        """Translate ``email ADDRESS ...`` or ``webhook URI [KEY=VALUE ...]`` into a rule action.

        TYPE (the first token) is matched case-insensitively.

        :raises InvalidArgumentValueError: for a missing or unknown TYPE, a webhook
            without a URI, or a webhook property that is not ``KEY=VALUE``.
        """
        _type = values[0].lower() if values else ''
        if _type == 'email':
            from azure.mgmt.monitor.models import RuleEmailAction
            return RuleEmailAction(custom_emails=values[1:])
        if _type == 'webhook':
            from azure.mgmt.monitor.models import RuleWebhookAction
            webhook_usage = '{} webhook URI [KEY=VALUE ...]'.format(option_string)
            # a webhook needs at least a URI; without this check values[1] raises IndexError
            if len(values) < 2:
                raise InvalidArgumentValueError(webhook_usage)
            uri = values[1]
            try:
                properties = dict(x.split('=', 1) for x in values[2:])
            except ValueError as ex:
                raise InvalidArgumentValueError(webhook_usage) from ex
            return RuleWebhookAction(service_uri=uri, properties=properties)
        raise InvalidArgumentValueError('usage error: {} TYPE KEY [ARGS]'.format(option_string))
205
206
class AlertRemoveAction(argparse._AppendAction):
    """Validate ``TYPE KEY [KEY ...]`` and append the list of KEYs."""

    def __call__(self, parser, namespace, values, option_string=None):
        super(AlertRemoveAction, self).__call__(parser, namespace,
                                                self.get_action(values, option_string),
                                                option_string)

    def get_action(self, values, option_string):  # pylint: disable=no-self-use
        """Return every token after TYPE.

        TYPE must be ``email`` or ``webhook`` (case-insensitive). It is only checked
        so the syntax mirrors the --add-action argument; additional validation could
        be attached to it later.

        :raises InvalidArgumentValueError: if TYPE is not recognized.
        """
        if values[0].lower() in ('email', 'webhook'):
            return values[1:]
        raise InvalidArgumentValueError('{} TYPE KEY [KEY ...]'.format(option_string))
219
220
221# pylint: disable=protected-access
class AutoscaleAddAction(argparse._AppendAction):
    """Build an email or webhook autoscale notification from the tokens and append it."""

    def __call__(self, parser, namespace, values, option_string=None):
        action = self.get_action(values, option_string)
        super(AutoscaleAddAction, self).__call__(parser, namespace, action, option_string)

    def get_action(self, values, option_string):  # pylint: disable=no-self-use
        """Translate ``email ADDRESS ...`` or ``webhook URI [KEY=VALUE ...]`` into a notification.

        TYPE (the first token) is matched case-insensitively.

        :raises InvalidArgumentValueError: for a missing or unknown TYPE, a webhook
            without a URI, or a webhook property that is not ``KEY=VALUE``.
        """
        _type = values[0].lower() if values else ''
        if _type == 'email':
            from azure.mgmt.monitor.models import EmailNotification
            return EmailNotification(custom_emails=values[1:])
        if _type == 'webhook':
            from azure.mgmt.monitor.models import WebhookNotification
            webhook_usage = '{} webhook URI [KEY=VALUE ...]'.format(option_string)
            # a webhook needs at least a URI; without this check values[1] raises IndexError
            if len(values) < 2:
                raise InvalidArgumentValueError(webhook_usage)
            uri = values[1]
            try:
                properties = dict(x.split('=', 1) for x in values[2:])
            except ValueError as ex:
                raise InvalidArgumentValueError(webhook_usage) from ex
            return WebhookNotification(service_uri=uri, properties=properties)
        raise InvalidArgumentValueError('{} TYPE KEY [ARGS]'.format(option_string))
241
242
class AutoscaleRemoveAction(argparse._AppendAction):
    """Validate ``TYPE KEY [KEY ...]`` for autoscale email/webhook removal and append the KEYs."""

    def __call__(self, parser, namespace, values, option_string=None):
        keys = self.get_action(values, option_string)
        super(AutoscaleRemoveAction, self).__call__(parser, namespace, keys, option_string)

    def get_action(self, values, option_string):  # pylint: disable=no-self-use
        """Return the tokens that follow TYPE.

        TYPE is required to be ``email`` or ``webhook`` (case-insensitive) purely for
        symmetry with the --add-action argument; it may gain extra validation later.

        :raises InvalidArgumentValueError: if TYPE is anything else.
        """
        kind = values[0].lower()
        if kind != 'email' and kind != 'webhook':
            raise InvalidArgumentValueError('{} TYPE KEY [KEY ...]'.format(option_string))
        return values[1:]
255
256
class AutoscaleConditionAction(argparse.Action):  # pylint: disable=protected-access
    """Parse an autoscale ``--condition`` expression and store it on ``namespace.condition``.

    The tokens are joined with spaces and parsed with the antlr4-based autoscale
    condition grammar. The validator's result must have non-empty
    ``time_aggregation``, ``metric_name``, ``threshold``, ``operator`` and
    ``time_window`` attributes; otherwise a usage error is raised.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        # antlr4 is not available everywhere, restrict the import scope so that commands
        # that do not need it don't fail when it is absent
        import antlr4

        from azure.cli.command_modules.monitor.grammar.autoscale import (
            AutoscaleConditionLexer, AutoscaleConditionParser, AutoscaleConditionValidator)

        # pylint: disable=line-too-long
        usage = '--condition ["NAMESPACE"] METRIC {==,!=,>,>=,<,<=} THRESHOLD {avg,min,max,total,count} PERIOD\n' \
                '            [where DIMENSION {==,!=} VALUE [or VALUE ...]\n' \
                '            [and   DIMENSION {==,!=} VALUE [or VALUE ...] ...]]'

        token_stream = antlr4.CommonTokenStream(AutoscaleConditionLexer(antlr4.InputStream(' '.join(values))))
        syntax_tree = AutoscaleConditionParser(token_stream).expression()

        required_fields = ('time_aggregation', 'metric_name', 'threshold', 'operator', 'time_window')
        try:
            validator = AutoscaleConditionValidator()
            antlr4.ParseTreeWalker().walk(validator, syntax_tree)
            condition = validator.result()
            incomplete = any(not getattr(condition, field, None) for field in required_fields)
        except (AttributeError, TypeError, KeyError):
            raise InvalidArgumentValueError(usage)
        if incomplete:
            raise InvalidArgumentValueError(usage)

        namespace.condition = condition
290
291
class AutoscaleScaleAction(argparse.Action):  # pylint: disable=protected-access
    """Parse ``--scale {in,out,to} VALUE[%]`` into a ``ScaleAction`` stored on ``namespace.scale``.

    ``to`` selects ``ScaleType.exact_count``. Any other direction selects
    ``percent_change_count`` when VALUE ends in ``%`` (the ``%`` is stripped), and
    ``change_count`` otherwise. The cooldown is left as None.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        from azure.mgmt.monitor.models import ScaleAction, ScaleType
        usage = '--scale {in,out,to} VALUE[%]'
        if len(values) == 1:
            # workaround because CMD.exe eats > character... Allows condition to be
            # specified as a quoted expression
            values = values[0].split(' ')
        if len(values) != 2:
            raise InvalidArgumentValueError(usage)
        dir_val, amt_val = values
        try:
            direction = get_autoscale_scale_direction_map()[dir_val]
        except KeyError as ex:
            # report an unknown direction as a usage error instead of a raw KeyError
            raise InvalidArgumentValueError(usage) from ex

        if dir_val == 'to':
            scale_type = ScaleType.exact_count.value
        elif str(amt_val).endswith('%'):
            scale_type = ScaleType.percent_change_count.value
            amt_val = amt_val[:-1]  # strip off the percent
        else:
            scale_type = ScaleType.change_count.value

        namespace.scale = ScaleAction(
            direction=direction,
            type=scale_type,
            cooldown=None,  # this will be filled in later
            value=amt_val
        )
319
320
class MultiObjectsDeserializeAction(argparse._AppendAction):  # pylint: disable=protected-access
    """Append-style action that turns ``TYPE PROP [PROP ...]`` into an object.

    Subclasses implement :meth:`deserialize_object`. Errors raised while building
    or appending the object are reported as ``InvalidArgumentValueError``:
    ``KeyError`` means an unknown type; ``TypeError`` and ``ValueError`` mean
    the properties could not be parsed. For ``ValueError``, the exception text
    is appended to the message.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        type_name = values[0]
        type_properties = values[1:]

        try:
            deserialized = self.deserialize_object(type_name, type_properties)
            super(MultiObjectsDeserializeAction, self).__call__(parser, namespace, deserialized, option_string)
        except (KeyError, TypeError, ValueError) as err:
            if isinstance(err, KeyError):
                message = 'the type "{}" is not recognizable.'.format(type_name)
            else:
                message = 'Failed to parse "{}" as object of type "{}".'.format(' '.join(values), type_name)
                if not isinstance(err, TypeError):
                    # ValueError: surface the underlying reason after the generic message
                    message = '{} {}'.format(message, str(err))
            raise InvalidArgumentValueError(message)

    def deserialize_object(self, type_name, type_properties):
        """Build the object for ``type_name`` from ``type_properties``; subclasses must override."""
        raise NotImplementedError()
343
344
class ActionGroupReceiverParameterAction(MultiObjectsDeserializeAction):
    """Turn ``--action TYPE NAME ...`` tokens into one of the action group receiver models."""

    def deserialize_object(self, type_name, type_properties):
        """Build the receiver described by ``type_name`` and its positional ``type_properties``.

        ``type_name`` is matched case-insensitively against the keys of ``syntax``.
        The flags ``usecommonalertschema`` and ``isglobalrunbook`` are detected anywhere
        in ``type_properties``. ``useaadauth`` is only recognized as the third property
        of a webhook receiver. All three flags are matched case-insensitively.

        :param type_name: receiver kind, e.g. ``email`` or ``webhook``.
        :param type_properties: the remaining ``--action`` tokens, ordered as in ``syntax``.
        :returns: the receiver model instance.
        :raises InvalidArgumentValueError: for an unknown type, or when a required
            positional property is missing.
        """
        from azure.mgmt.monitor.models import EmailReceiver, SmsReceiver, WebhookReceiver, \
            ArmRoleReceiver, AzureAppPushReceiver, ItsmReceiver, AutomationRunbookReceiver, \
            VoiceReceiver, LogicAppReceiver, AzureFunctionReceiver
        syntax = {
            'email': 'NAME EMAIL_ADDRESS [usecommonalertschema]',
            'sms': 'NAME COUNTRY_CODE PHONE_NUMBER',
            'webhook': 'NAME URI [useaadauth OBJECT_ID IDENTIFIER URI] [usecommonalertschema]',
            'armrole': 'NAME ROLE_ID [usecommonalertschema]',
            'azureapppush': 'NAME EMAIL_ADDRESS',
            'itsm': 'NAME WORKSPACE_ID CONNECTION_ID TICKET_CONFIG REGION',
            'automationrunbook': 'NAME AUTOMATION_ACCOUNT_ID RUNBOOK_NAME WEBHOOK_RESOURCE_ID '
                                 'SERVICE_URI [isglobalrunbook] [usecommonalertschema]',
            'voice': 'NAME COUNTRY_CODE PHONE_NUMBER',
            'logicapp': 'NAME RESOURCE_ID CALLBACK_URL [usecommonalertschema]',
            'azurefunction': 'NAME FUNCTION_APP_RESOURCE_ID '
                             'FUNCTION_NAME HTTP_TRIGGER_URL [usecommonalertschema]'
        }

        receiver_type = type_name.lower()
        # lower-cased copy used only for flag detection; the original tokens go to the models
        flags = [prop.lower() for prop in type_properties]
        use_common_alert_schema = 'usecommonalertschema' in flags
        try:
            if receiver_type == 'email':
                receiver = EmailReceiver(name=type_properties[0], email_address=type_properties[1],
                                         use_common_alert_schema=use_common_alert_schema)
            elif receiver_type == 'sms':
                receiver = SmsReceiver(
                    name=type_properties[0],
                    country_code=type_properties[1],
                    phone_number=type_properties[2]
                )
            elif receiver_type == 'webhook':
                use_aad_auth = len(flags) >= 3 and flags[2] == 'useaadauth'
                object_id = type_properties[3] if use_aad_auth else None
                identifier_uri = type_properties[4] if use_aad_auth else None
                receiver = WebhookReceiver(name=type_properties[0], service_uri=type_properties[1],
                                           use_common_alert_schema=use_common_alert_schema,
                                           use_aad_auth=use_aad_auth, object_id=object_id,
                                           identifier_uri=identifier_uri)
            elif receiver_type == 'armrole':
                receiver = ArmRoleReceiver(name=type_properties[0], role_id=type_properties[1],
                                           use_common_alert_schema=use_common_alert_schema)
            elif receiver_type == 'azureapppush':
                receiver = AzureAppPushReceiver(name=type_properties[0], email_address=type_properties[1])
            elif receiver_type == 'itsm':
                receiver = ItsmReceiver(name=type_properties[0], workspace_id=type_properties[1],
                                        connection_id=type_properties[2], ticket_configuration=type_properties[3],
                                        region=type_properties[4])
            elif receiver_type == 'automationrunbook':
                is_global_runbook = 'isglobalrunbook' in flags
                receiver = AutomationRunbookReceiver(name=type_properties[0], automation_account_id=type_properties[1],
                                                     runbook_name=type_properties[2],
                                                     webhook_resource_id=type_properties[3],
                                                     service_uri=type_properties[4],
                                                     is_global_runbook=is_global_runbook,
                                                     use_common_alert_schema=use_common_alert_schema)
            elif receiver_type == 'voice':
                receiver = VoiceReceiver(
                    name=type_properties[0],
                    country_code=type_properties[1],
                    phone_number=type_properties[2]
                )
            elif receiver_type == 'logicapp':
                receiver = LogicAppReceiver(name=type_properties[0], resource_id=type_properties[1],
                                            callback_url=type_properties[2],
                                            use_common_alert_schema=use_common_alert_schema)
            elif receiver_type == 'azurefunction':
                receiver = AzureFunctionReceiver(name=type_properties[0], function_app_resource_id=type_properties[1],
                                                 function_name=type_properties[2],
                                                 http_trigger_url=type_properties[3],
                                                 use_common_alert_schema=use_common_alert_schema)
            else:
                raise InvalidArgumentValueError('The type "{}" is not recognizable.'.format(type_name))

        except IndexError as ex:
            # too few positional properties for a known type: show that type's syntax
            raise InvalidArgumentValueError('--action {}'.format(syntax[receiver_type])) from ex
        return receiver
423