1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2019 Chris Caron <lead2gold@gmail.com>
4# All rights reserved.
5#
6# This code is licensed under the MIT License.
7#
8# Permission is hereby granted, free of charge, to any person obtaining a copy
9# of this software and associated documentation files(the "Software"), to deal
10# in the Software without restriction, including without limitation the rights
11# to use, copy, modify, merge, publish, distribute, sublicense, and / or sell
12# copies of the Software, and to permit persons to whom the Software is
13# furnished to do so, subject to the following conditions :
14#
15# The above copyright notice and this permission notice shall be included in
16# all copies or substantial portions of the Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.IN NO EVENT SHALL THE
21# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
24# THE SOFTWARE.
25
26import re
27import hmac
28import requests
29from hashlib import sha256
30from datetime import datetime
31from collections import OrderedDict
32from xml.etree import ElementTree
33from itertools import chain
34
35from .NotifyBase import NotifyBase
36from ..URLBase import PrivacyMode
37from ..common import NotifyType
38from ..utils import is_phone_no
39from ..utils import parse_list
40from ..utils import validate_regex
41from ..AppriseLocale import gettext_lazy as _
42
# Topic Detection
# Summary: 1 to 256 characters, only alpha/numeric plus underscore (_) and
#          dash (-) additionally allowed.
#
#   Source: https://docs.aws.amazon.com/sns/latest/api/API_CreateTopic.html
#
# Allow a starting hashtag (#) specification to help eliminate possible
# ambiguity between a topic that is comprised of all digits and a phone number
IS_TOPIC = re.compile(r'^#?(?P<name>[A-Za-z0-9_-]{1,256})\s*$')

# Because our AWS Access Key Secret contains slashes, we actually use the
# region as a delimiter. This is a bit hacky; but it's much easier than having
# users of this product search through this Access Key Secret and escape all
# of the forward slashes!
# Matches entries such as 'us-east-1' (case-insensitive, surrounding
# whitespace ignored) and captures the country, area and number parts.
IS_REGION = re.compile(
    r'^\s*(?P<country>[a-z]{2})-(?P<area>[a-z]+)-(?P<no>[0-9]+)\s*$', re.I)

# Extend HTTP Error Messages; maps an HTTP status code to a friendlier,
# AWS-specific explanation used when logging failed requests.
AWS_HTTP_ERROR_MAP = {
    403: 'Unauthorized - Invalid Access/Secret Key Combination.',
}
65
66
class NotifySNS(NotifyBase):
    """
    A wrapper for AWS SNS (Amazon Simple Notification)

    Messages can be delivered directly to phone numbers (SMS) and/or to
    SNS topics.  Requests are signed with AWS Signature Version 4
    (AWS4-HMAC-SHA256).
    """

    # The default descriptive name associated with the Notification
    service_name = 'AWS Simple Notification Service (SNS)'

    # The services URL
    service_url = 'https://aws.amazon.com/sns/'

    # The default secure protocol
    secure_protocol = 'sns'

    # A URL that takes you to the setup/help of the specific protocol
    setup_url = 'https://github.com/caronc/apprise/wiki/Notify_sns'

    # AWS is pretty good for handling data load so request limits
    # can occur in much shorter bursts
    request_rate_per_sec = 2.5

    # The maximum length of the body
    # Source: https://docs.aws.amazon.com/sns/latest/api/API_Publish.html
    body_maxlen = 160

    # A title can not be used for SMS Messages.  Setting this to zero will
    # cause any title (if defined) to get placed into the message body.
    title_maxlen = 0

    # Define object templates
    templates = (
        '{schema}://{access_key_id}/{secret_access_key}/{region}/{targets}',
    )

    # Define our template tokens
    template_tokens = dict(NotifyBase.template_tokens, **{
        'access_key_id': {
            'name': _('Access Key ID'),
            'type': 'string',
            'private': True,
            'required': True,
        },
        'secret_access_key': {
            'name': _('Secret Access Key'),
            'type': 'string',
            'private': True,
            'required': True,
        },
        'region': {
            'name': _('Region'),
            'type': 'string',
            'required': True,
            'regex': (r'^[a-z]{2}-[a-z]+-[0-9]+$', 'i'),
            'map_to': 'region_name',
        },
        'target_phone_no': {
            'name': _('Target Phone No'),
            'type': 'string',
            'map_to': 'targets',
            'regex': (r'^[0-9\s)(+-]+$', 'i')
        },
        'target_topic': {
            'name': _('Target Topic'),
            'type': 'string',
            'map_to': 'targets',
            'prefix': '#',
            'regex': (r'^[A-Za-z0-9_-]+$', 'i'),
        },
        'targets': {
            'name': _('Targets'),
            'type': 'list:string',
        },
    })

    # Define our template arguments
    template_args = dict(NotifyBase.template_args, **{
        'to': {
            'alias_of': 'targets',
        },
    })

    def __init__(self, access_key_id, secret_access_key, region_name,
                 targets=None, **kwargs):
        """
        Initialize Notify AWS SNS Object

        Args:
            access_key_id: the AWS Access Key ID.
            secret_access_key: the AWS Secret Access Key.
            region_name: the AWS region (eg. us-east-1); it must match the
                'region' template token regex.
            targets: phone numbers and/or topic names (optionally prefixed
                with '#'); any entry that is neither is dropped with a
                warning.

        Raises:
            TypeError: if the access key id, secret access key or region
                is invalid.
        """
        super(NotifySNS, self).__init__(**kwargs)

        # Store our AWS API Access Key
        self.aws_access_key_id = validate_regex(access_key_id)
        if not self.aws_access_key_id:
            msg = 'An invalid AWS Access Key ID was specified.'
            self.logger.warning(msg)
            raise TypeError(msg)

        # Store our AWS API Secret Access key
        self.aws_secret_access_key = validate_regex(secret_access_key)
        if not self.aws_secret_access_key:
            msg = 'An invalid AWS Secret Access Key ' \
                  '({}) was specified.'.format(secret_access_key)
            self.logger.warning(msg)
            raise TypeError(msg)

        # Acquire our AWS Region Name:
        # eg. us-east-1, cn-north-1, us-west-2, ...
        self.aws_region_name = validate_regex(
            region_name, *self.template_tokens['region']['regex'])
        if not self.aws_region_name:
            msg = 'An invalid AWS Region ({}) was specified.'.format(
                region_name)
            self.logger.warning(msg)
            raise TypeError(msg)

        # Initialize topic list (stored without any leading '#')
        self.topics = list()

        # Initialize numbers list (stored with a leading '+')
        self.phone = list()

        # Set our notify_url based on our region
        self.notify_url = 'https://sns.{}.amazonaws.com/'\
            .format(self.aws_region_name)

        # AWS Service Details
        self.aws_service_name = 'sns'
        self.aws_canonical_uri = '/'

        # AWS Authentication Details
        self.aws_auth_version = 'AWS4'
        self.aws_auth_algorithm = 'AWS4-HMAC-SHA256'
        self.aws_auth_request = 'aws4_request'

        # Validate targets and drop bad ones.  Phone numbers are checked
        # first; a '#' prefix lets an all-digit topic bypass that check.
        for target in parse_list(targets):
            result = is_phone_no(target)
            if result:
                # store valid phone number
                self.phone.append('+{}'.format(result))
                continue

            result = IS_TOPIC.match(target)
            if result:
                # store valid topic
                self.topics.append(result.group('name'))
                continue

            self.logger.warning(
                'Dropped invalid phone/topic '
                '(%s) specified.' % target,
            )

        return

    def send(self, body, title='', notify_type=NotifyType.INFO, **kwargs):
        """
        Perform AWS SNS Notification

        Each phone number receives its own SMS 'Publish' request.  Each
        topic is first passed to 'CreateTopic' (which creates it if it does
        not already exist) to obtain its ARN; that ARN is then used for the
        'Publish' request to the topic.

        Returns True only if every target was notified successfully.
        """

        if not self.phone and not self.topics:
            # We have no target(s) to message
            self.logger.warning('No AWS targets to notify.')
            return False

        # Initialize our error tracking
        error_count = 0

        # Send an SMS to each phone number directly
        for no in self.phone:
            # Prepare SNS Message Payload
            payload = {
                'Action': u'Publish',
                'Message': body,
                'Version': u'2010-03-31',
                'PhoneNumber': no,
            }

            # Only the success flag is needed here.  (The result is not
            # unpacked into '_' since that would shadow the module-level
            # gettext alias.)
            if not self._post(payload=payload, to=no)[0]:
                error_count += 1

        # Send all our defined topic id's
        for topic in self.topics:
            # First ensure our topic exists, if it doesn't, it gets created
            payload = {
                'Action': u'CreateTopic',
                'Version': u'2010-03-31',
                'Name': topic,
            }

            (result, response) = self._post(payload=payload, to=topic)
            if not result:
                error_count += 1
                continue

            # Get the Amazon Resource Name
            topic_arn = response.get('topic_arn')
            if not topic_arn:
                # Could not acquire our topic; move on to the next one
                error_count += 1
                continue

            # Build our payload now that we know our topic_arn
            payload = {
                'Action': u'Publish',
                'Version': u'2010-03-31',
                'TopicArn': topic_arn,
                'Message': body,
            }

            # Send our payload to AWS
            if not self._post(payload=payload, to=topic)[0]:
                error_count += 1

        return error_count == 0

    def _post(self, payload, to):
        """
        Wrapper to request.post() to manage it's response better and make
        the send() function cleaner and easier to maintain.

        Args:
            payload: dict of AWS query parameters to send.
            to: the phone number or topic being notified (used for logging).

        Returns:
            A tuple (success, response) where success is True only on an
            HTTP 200 reply and response is the dict produced by
            aws_response_to_dict() from the reply body (a default,
            mostly-empty dict if the connection failed).
        """

        # Always call throttle before any remote server i/o is made; for AWS
        # time plays a huge factor in the headers being sent with the payload.
        # So for AWS (SNS) requests we must throttle before they're generated
        # and not directly before the i/o call like other notification
        # services do.
        self.throttle()

        # Convert our payload from a dict() into a urlencoded string
        payload = NotifySNS.urlencode(payload)

        # Prepare our AWS Headers based on our payload
        headers = self.aws_prepare_request(payload)

        self.logger.debug('AWS POST URL: %s (cert_verify=%r)' % (
            self.notify_url, self.verify_certificate,
        ))
        self.logger.debug('AWS Payload: %s' % str(payload))

        try:
            r = requests.post(
                self.notify_url,
                data=payload,
                headers=headers,
                verify=self.verify_certificate,
                timeout=self.request_timeout,
            )

            if r.status_code != requests.codes.ok:
                # We had a problem
                status_str = \
                    NotifySNS.http_response_code_lookup(
                        r.status_code, AWS_HTTP_ERROR_MAP)

                self.logger.warning(
                    'Failed to send AWS notification to {}: '
                    '{}{}error={}.'.format(
                        to,
                        status_str,
                        ', ' if status_str else '',
                        r.status_code))

                self.logger.debug('Response Details:\r\n{}'.format(r.content))

                return (False, NotifySNS.aws_response_to_dict(r.text))

            else:
                self.logger.info(
                    'Sent AWS notification to "%s".' % (to))

        except requests.RequestException as e:
            self.logger.warning(
                'A Connection error occurred sending AWS '
                'notification to "%s".' % (to),
            )
            self.logger.debug('Socket Exception: %s' % str(e))
            return (False, NotifySNS.aws_response_to_dict(None))

        return (True, NotifySNS.aws_response_to_dict(r.text))

    def aws_prepare_request(self, payload, reference=None):
        """
        Takes the intended payload and returns the headers for it.

        The payload is presumed to have been already urlencoded()

        Args:
            payload: the urlencoded request body (a string).
            reference: optional datetime (treated as UTC) used to build the
                X-Amz-Date header and the credential scope.  When omitted,
                the current UTC time is used.

        Returns:
            dict of HTTP headers, including the signed 'Authorization'
            header.
        """

        # Define our AWS header
        headers = {
            'User-Agent': self.app_id,
            'Content-Type': 'application/x-www-form-urlencoded; charset=utf-8',

            # Populated below
            'Content-Length': 0,
            'Authorization': None,
            'X-Amz-Date': None,
        }

        # Get a reference time (used for header construction); honour a
        # caller-supplied one so signing can be reproduced deterministically
        if reference is None:
            reference = datetime.utcnow()

        # Provide Content-Length
        headers['Content-Length'] = str(len(payload))

        # Amazon Date Format
        amzdate = reference.strftime('%Y%m%dT%H%M%SZ')
        headers['X-Amz-Date'] = amzdate

        # Credential Scope
        scope = '{date}/{region}/{service}/{request}'.format(
            date=reference.strftime('%Y%m%d'),
            region=self.aws_region_name,
            service=self.aws_service_name,
            request=self.aws_auth_request,
        )

        # Similar to headers; but a subset.  keys must be lowercase
        signed_headers = OrderedDict([
            ('content-type', headers['Content-Type']),
            ('host', '{service}.{region}.amazonaws.com'.format(
                service=self.aws_service_name,
                region=self.aws_region_name)),
            ('x-amz-date', headers['X-Amz-Date']),
        ])

        #
        # Build Canonical Request Object
        #
        canonical_request = '\n'.join([
            # Method
            u'POST',

            # URL
            self.aws_canonical_uri,

            # Query String (none set for POST)
            '',

            # Header Content (must include \n at end!)
            # All entries except characters in amazon date must be
            # lowercase
            '\n'.join(['%s:%s' % (k, v)
                      for k, v in signed_headers.items()]) + '\n',

            # Header Entries (in same order identified above)
            ';'.join(signed_headers.keys()),

            # Payload
            sha256(payload.encode('utf-8')).hexdigest(),
        ])

        # Prepare Unsigned Signature
        to_sign = '\n'.join([
            self.aws_auth_algorithm,
            amzdate,
            scope,
            sha256(canonical_request.encode('utf-8')).hexdigest(),
        ])

        # Our Authorization header
        headers['Authorization'] = ', '.join([
            '{algorithm} Credential={key}/{scope}'.format(
                algorithm=self.aws_auth_algorithm,
                key=self.aws_access_key_id,
                scope=scope,
            ),
            'SignedHeaders={signed_headers}'.format(
                signed_headers=';'.join(signed_headers.keys()),
            ),
            'Signature={signature}'.format(
                signature=self.aws_auth_signature(to_sign, reference)
            ),
        ])

        return headers

    def aws_auth_signature(self, to_sign, reference):
        """
        Generates a AWS v4 signature based on provided payload
        which should be in the form of a string.

        The signing key is derived by chaining HMAC-SHA256 over the date,
        region, service and request type, starting from 'AWS4' + secret.
        The result is the hex HMAC of to_sign under that key.
        """

        def _sign(key, msg, to_hex=False):
            """
            Perform AWS Signing
            """
            if to_hex:
                return hmac.new(key, msg.encode('utf-8'), sha256).hexdigest()
            return hmac.new(key, msg.encode('utf-8'), sha256).digest()

        _date = _sign((
            self.aws_auth_version +
            self.aws_secret_access_key).encode('utf-8'),
            reference.strftime('%Y%m%d'))

        _region = _sign(_date, self.aws_region_name)
        _service = _sign(_region, self.aws_service_name)
        _signed = _sign(_service, self.aws_auth_request)
        return _sign(_signed, to_sign, to_hex=True)

    @staticmethod
    def aws_response_to_dict(aws_response):
        """
        Takes an AWS Response object as input and returns it as a dictionary
        but not before extracting out what is useful to us first.

        eg:
          IN:
            <CreateTopicResponse
                  xmlns="http://sns.amazonaws.com/doc/2010-03-31/">
              <CreateTopicResult>
                <TopicArn>arn:aws:sns:us-east-1:000000000000:abcd</TopicArn>
                   </CreateTopicResult>
               <ResponseMetadata>
               <RequestId>604bef0f-369c-50c5-a7a4-bbd474c83d6a</RequestId>
               </ResponseMetadata>
           </CreateTopicResponse>

          OUT:
           {
              type: 'CreateTopicResponse',
              request_id: '604bef0f-369c-50c5-a7a4-bbd474c83d6a',
              topic_arn: 'arn:aws:sns:us-east-1:000000000000:abcd',
           }

        Unparseable input (including None) yields the default dict with
        'type' and 'request_id' set to None.  Empty leaf elements are
        stored as an empty string.
        """

        # Define ourselves a set of directives we want to keep if found and
        # then identify the value we want to map them to in our response
        # object
        aws_keep_map = {
            'RequestId': 'request_id',
            'TopicArn': 'topic_arn',
            'MessageId': 'message_id',

            # Error Message Handling
            'Type': 'error_type',
            'Code': 'error_code',
            'Message': 'error_message',
        }

        # A default response object that we'll manipulate as we pull more data
        # from our AWS Response object
        response = {
            'type': None,
            'request_id': None,
        }

        try:
            # we build our tree, but not before first eliminating any
            # reference to namespacing (if present) as it makes parsing
            # the tree so much easier.
            root = ElementTree.fromstring(
                re.sub(' xmlns="[^"]+"', '', aws_response, count=1))

            # Store our response tag object name
            response['type'] = str(root.tag)

            # Walk every element (document order, root included) and keep
            # the text of the leaf nodes we are interested in.  Later
            # matches overwrite earlier ones.
            for element in root.iter():
                if len(element) == 0 and element.tag in aws_keep_map:
                    # An empty element (eg. <TopicArn/>) has text None
                    response[aws_keep_map[element.tag]] = \
                        (element.text or '').strip()

        except (ElementTree.ParseError, TypeError):
            # bad data just causes us to generate a bad response
            pass

        return response

    def url(self, privacy=False, *args, **kwargs):
        """
        Returns the URL built dynamically based on specified arguments.
        """

        # Our URL parameters
        params = self.url_parameters(privacy=privacy, *args, **kwargs)

        return '{schema}://{key_id}/{key_secret}/{region}/{targets}/'\
            '?{params}'.format(
                schema=self.secure_protocol,
                key_id=self.pprint(self.aws_access_key_id, privacy, safe=''),
                key_secret=self.pprint(
                    self.aws_secret_access_key, privacy,
                    mode=PrivacyMode.Secret, safe=''),
                region=NotifySNS.quote(self.aws_region_name, safe=''),
                targets='/'.join(
                    [NotifySNS.quote(x) for x in chain(
                        # Phone #'s are already stored with their leading
                        # plus symbol, so they are used as-is
                        self.phone,
                        # Topics are prefixed with a pound/hashtag symbol
                        ['#{}'.format(x) for x in self.topics],
                    )]),
                params=NotifySNS.urlencode(params),
            )

    @staticmethod
    def parse_url(url):
        """
        Parses the URL and returns enough arguments that can allow
        us to re-instantiate this object.

        Expected form:
            sns://{access_key_id}/{secret_access_key}/{region}/{targets}

        The secret may itself contain slashes; the first path entry that
        looks like an AWS region marks where the secret ends.
        """
        results = NotifyBase.parse_url(url, verify_host=False)
        if not results:
            # We're done early as we couldn't load the results
            return results

        # The AWS Access Key ID is stored in the hostname
        access_key_id = NotifySNS.unquote(results['host'])

        # Our AWS Access Key Secret contains slashes in it which unfortunately
        # means it is of variable length after the hostname.  Since we require
        # that the user provides the region code, we intentionally use this
        # as our delimiter to detect where our Secret is.
        secret_access_key = None
        region_name = None

        # We need to iterate over each entry in the fullpath and find our
        # region. Once we get there we stop and build our secret from our
        # accumulated data.
        secret_access_key_parts = list()

        # Start with a list of entries to work with
        entries = NotifySNS.split_path(results['fullpath'])

        # Section 1: Get Region and Access Secret
        # (if no region is found, index stays 0 and every entry is treated
        # as a target; the secret and region remain None)
        index = 0
        for i, entry in enumerate(entries):

            # Are we at the region yet?
            result = IS_REGION.match(entry)
            if result:
                # We found our Region; Rebuild our access key secret based on
                # all entries we found prior to this:
                secret_access_key = '/'.join(secret_access_key_parts)

                # Ensure region is nicely formatted
                region_name = "{country}-{area}-{no}".format(
                    country=result.group('country').lower(),
                    area=result.group('area').lower(),
                    no=result.group('no'),
                )

                # Track our index as we'll use this to grab the remaining
                # content in the next Section
                index = i + 1

                # We're done with Section 1
                break

            # Store our secret parts
            secret_access_key_parts.append(entry)

        # Section 2: Get our Recipients (basically all remaining entries)
        results['targets'] = entries[index:]

        # Support the 'to' variable so that targets can be specified this
        # way too; the 'to' makes it easier to use yaml configuration
        if 'to' in results['qsd'] and len(results['qsd']['to']):
            results['targets'] += \
                NotifySNS.parse_list(results['qsd']['to'])

        # Store our other detected data (if at all)
        results['region_name'] = region_name
        results['access_key_id'] = access_key_id
        results['secret_access_key'] = secret_access_key

        # Return our result set
        return results
661