1import copy
2import json
3import logging
4import os
5import platform
6import time
7
8from aws_xray_sdk import global_sdk_config
9from aws_xray_sdk.version import VERSION
10from .models.segment import Segment, SegmentContextManager
11from .models.subsegment import Subsegment, SubsegmentContextManager
12from .models.default_dynamic_naming import DefaultDynamicNaming
13from .models.dummy_entities import DummySegment, DummySubsegment
14from .emitters.udp_emitter import UDPEmitter
15from .sampling.sampler import DefaultSampler
16from .sampling.local.sampler import LocalSampler
17from .streaming.default_streaming import DefaultStreaming
18from .context import Context
19from .daemon_config import DaemonConfig
20from .plugins.utils import get_plugin_modules
21from .lambda_launcher import check_in_lambda
22from .exceptions.exceptions import SegmentNameMissingException, SegmentNotFoundException
23from .utils.compat import string_types
24from .utils import stacktrace
25
# Module-level logger used by the recorder for diagnostics.
log = logging.getLogger(__name__)

# Names of the environment variables consulted by the recorder. When set,
# they take precedence over the matching ``configure()`` arguments.
TRACING_NAME_KEY = 'AWS_XRAY_TRACING_NAME'
DAEMON_ADDR_KEY = 'AWS_XRAY_DAEMON_ADDRESS'
CONTEXT_MISSING_KEY = 'AWS_XRAY_CONTEXT_MISSING'

# Baseline ``aws`` metadata attached to every segment; the recorder deep-copies
# it before adding plugin-provided entries.
XRAY_META = {
    'xray': dict(sdk='X-Ray for Python', sdk_version=VERSION),
}

# Description of the running interpreter, attached to every segment as its
# ``service`` data.
SERVICE_INFO = dict(
    runtime=platform.python_implementation(),
    runtime_version=platform.python_version(),
)
43
44
class AWSXRayRecorder(object):
    """
    A global AWS X-Ray recorder that will begin/end segments/subsegments
    and send them to the X-Ray daemon. This recorder is initialized during
    loading time so you can use::

        from aws_xray_sdk.core import xray_recorder

    in your module to access it
    """
    def __init__(self):
        """
        Build a recorder with default collaborators.

        When ``check_in_lambda()`` returns a context, that context is used
        together with a ``LocalSampler`` and a streaming threshold of 0;
        otherwise a plain ``Context`` and a ``DefaultSampler`` are used.
        """
        self._streaming = DefaultStreaming()
        context = check_in_lambda()
        if context:
            # Special handling when running on AWS Lambda.
            self._context = context
            self.streaming_threshold = 0
            self._sampler = LocalSampler()
        else:
            self._context = Context()
            self._sampler = DefaultSampler()

        self._emitter = UDPEmitter()
        # Backing field of the ``enabled`` property; it must exist so that
        # reading the property before it is ever assigned does not fail.
        self._enabled = True
        self._sampling = True
        self._max_trace_back = 10
        self._plugins = None
        self._service = os.getenv(TRACING_NAME_KEY)
        self._dynamic_naming = None
        self._aws_metadata = copy.deepcopy(XRAY_META)
        self._origin = None
        self._stream_sql = True

        if type(self.sampler).__name__ == 'DefaultSampler':
            self.sampler.load_settings(DaemonConfig(), self.context)

    def configure(self, sampling=None, plugins=None,
                  context_missing=None, sampling_rules=None,
                  daemon_address=None, service=None,
                  context=None, emitter=None, streaming=None,
                  dynamic_naming=None, streaming_threshold=None,
                  max_trace_back=None, sampler=None,
                  stream_sql=True):
        """Configure global X-Ray recorder.

        Configure needs to run before patching third party libraries
        to avoid creating dangling subsegment.

        :param bool sampling: If sampling is enabled, every time the recorder
            creates a segment it decides whether to send this segment to
            the X-Ray daemon. This setting is not used if the recorder
            is running in AWS Lambda. The recorder always respect the incoming
            sampling decisions regardless of this setting.
        :param sampling_rules: Pass a set of local custom sampling rules.
            Can be an absolute path of the sampling rule config json file
            or a dictionary that defines those rules. This will also be the
            fallback rules in case of centralized sampling opted-in while
            the centralized sampling rules are not available.
        :param sampler: The sampler used to make sampling decisions. The SDK
            provides two built-in samplers. One is centralized rules based and
            the other is local rules based. The former is the default.
        :param tuple plugins: plugins that add extra metadata to each segment.
            Currently available plugins are EC2Plugin, ECS plugin and
            ElasticBeanstalkPlugin.
            If you want to disable all previously enabled plugins,
            pass an empty tuple ``()``.
        :param str context_missing: recorder behavior when it tries to mutate
            a segment or add a subsegment but there is no active segment.
            RUNTIME_ERROR means the recorder will raise an exception.
            LOG_ERROR means the recorder will only log the error and
            do nothing.
        :param str daemon_address: The X-Ray daemon address where the recorder
            sends data to.
        :param str service: default segment name if creating a segment without
            providing a name.
        :param context: You can pass your own implementation of context storage
            for active segment/subsegment by overriding the default
            ``Context`` class.
        :param emitter: The emitter that sends a segment/subsegment to
            the X-Ray daemon. You can override ``UDPEmitter`` class.
        :param dynamic_naming: a string that defines a pattern that host names
            should match. Alternatively you can pass a module which
            overrides ``DefaultDynamicNaming`` module.
        :param streaming: The streaming module to stream out trace documents
            when they grow too large. You can override ``DefaultStreaming``
            class to have your own implementation of the streaming process.
        :param streaming_threshold: The threshold that, once broken within a
            single segment, makes the recorder start streaming out children
            subsegments. By default it is the maximum number of subsegments
            within a segment.
        :param int max_trace_back: The maximum number of stack traces recorded
            by auto-capture. Lower this if a single document becomes too large.
            Values that are not non-negative plain integers are ignored.
        :param bool stream_sql: Whether SQL query texts should be streamed.
            Because the default is ``True`` rather than ``None``, every call
            that omits this argument sets the option back to ``True``.

        Environment variables AWS_XRAY_DAEMON_ADDRESS, AWS_XRAY_CONTEXT_MISSING
        and AWS_XRAY_TRACING_NAME respectively overrides arguments
        daemon_address, context_missing and service.
        """

        if sampling is not None:
            self.sampling = sampling
        if sampler:
            self.sampler = sampler
        if service:
            self.service = os.getenv(TRACING_NAME_KEY, service)
        if sampling_rules:
            self._load_sampling_rules(sampling_rules)
        if emitter:
            self.emitter = emitter
        if daemon_address:
            self.emitter.set_daemon_address(os.getenv(DAEMON_ADDR_KEY, daemon_address))
        if context:
            self.context = context
        if context_missing:
            self.context.context_missing = os.getenv(CONTEXT_MISSING_KEY, context_missing)
        if dynamic_naming:
            self.dynamic_naming = dynamic_naming
        if streaming:
            self.streaming = streaming
        if streaming_threshold is not None:
            self.streaming_threshold = streaming_threshold
        # Exact type check on purpose: booleans are ints too, but they are
        # not meaningful stack depths and must be ignored.
        if type(max_trace_back) is int and max_trace_back >= 0:
            self.max_trace_back = max_trace_back
        if stream_sql is not None:
            self.stream_sql = stream_sql

        if plugins:
            plugin_modules = get_plugin_modules(plugins)
            for plugin in plugin_modules:
                plugin.initialize()
                if plugin.runtime_context:
                    self._aws_metadata[plugin.SERVICE_NAME] = plugin.runtime_context
                    self._origin = plugin.ORIGIN
        # An explicitly passed empty collection clears previously loaded
        # plugin metadata.
        elif plugins is not None:
            self._aws_metadata = copy.deepcopy(XRAY_META)
            self._origin = None

        if type(self.sampler).__name__ == 'DefaultSampler':
            self.sampler.load_settings(DaemonConfig(daemon_address),
                                       self.context, self._origin)

    def in_segment(self, name=None, **segment_kwargs):
        """
        Return a segment context manager.

        :param str name: the name of the segment
        :param dict segment_kwargs: remaining arguments passed directly to `begin_segment`
        """
        return SegmentContextManager(self, name=name, **segment_kwargs)

    def in_subsegment(self, name=None, **subsegment_kwargs):
        """
        Return a subsegment context manager.

        :param str name: the name of the subsegment
        :param dict subsegment_kwargs: remaining arguments passed directly to `begin_subsegment`
        """
        return SubsegmentContextManager(self, name=name, **subsegment_kwargs)

    def begin_segment(self, name=None, traceid=None,
                      parent_id=None, sampling=None):
        """
        Begin a segment on the current thread and return it. The recorder
        only keeps one segment at a time. Create the second one without
        closing existing one will overwrite it.

        :param str name: the name of the segment. Falls back to the
            configured service name when omitted.
        :param str traceid: trace id of the segment
        :param parent_id: parent id handed to the new ``Segment``.
        :param int sampling: 0 means not sampled, 1 means sampled
        :raises SegmentNameMissingException: if neither ``name`` nor a
            configured service name is available.
        """
        # Disable the recorder; return a generated dummy segment.
        if not global_sdk_config.sdk_enabled():
            return DummySegment(global_sdk_config.DISABLED_ENTITY_NAME)

        seg_name = name or self.service
        if not seg_name:
            raise SegmentNameMissingException("Segment name is required.")

        # A falsy decision means "not sampled". A truthy decision is either
        # the incoming value, the sampler's answer, or the default ``True``;
        # string decisions are later recorded as the matched rule name.
        decision = True

        # we respect the input sampling decision
        # regardless of recorder configuration.
        if sampling == 0:
            decision = False
        elif sampling:
            decision = sampling
        elif self.sampling:
            decision = self._sampler.should_trace()

        if not decision:
            segment = DummySegment(seg_name)
        else:
            segment = Segment(name=seg_name, traceid=traceid,
                              parent_id=parent_id)
            self._populate_runtime_context(segment, decision)

        self.context.put_segment(segment)
        return segment

    def end_segment(self, end_time=None):
        """
        End the current segment and send it to X-Ray daemon
        if it is ready to send. Ready means segment and
        all its subsegments are closed.

        :param float end_time: segment completion in unix epoch in seconds.
        """
        # When the SDK is disabled we return
        if not global_sdk_config.sdk_enabled():
            return

        self.context.end_segment(end_time)
        segment = self.current_segment()
        if segment and segment.ready_to_send():
            self._send_segment()

    def current_segment(self):
        """
        Return the currently active segment. In a multithreading environment,
        this will make sure the segment returned is the one created by the
        same thread.

        When the SDK is disabled a dummy segment is returned, mirroring
        ``current_subsegment``.
        """
        # Do not consult the context at all when the SDK is disabled.
        if not global_sdk_config.sdk_enabled():
            return DummySegment(global_sdk_config.DISABLED_ENTITY_NAME)

        entity = self.get_trace_entity()
        if self._is_subsegment(entity):
            return entity.parent_segment
        return entity

    def begin_subsegment(self, name, namespace='local'):
        """
        Begin a new subsegment.
        If there is open subsegment, the newly created subsegment will be the
        child of latest opened subsegment.
        If not, it will be the child of the current open segment.

        :param str name: the name of the subsegment.
        :param str namespace: currently can only be 'local', 'remote', 'aws'.
        :return: the new subsegment, or ``None`` when there is no current
            segment to attach it to.
        """
        # Generating the parent dummy segment is necessary.
        # We don't need to store anything in context. Assumption here
        # is that we only work with recorder-level APIs.
        if not global_sdk_config.sdk_enabled():
            return DummySubsegment(DummySegment(global_sdk_config.DISABLED_ENTITY_NAME))

        segment = self.current_segment()
        if not segment:
            # Lazy logger arguments: the message is only formatted if emitted.
            log.warning("No segment found, cannot begin subsegment %s.", name)
            return None

        if not segment.sampled:
            subsegment = DummySubsegment(segment, name)
        else:
            subsegment = Subsegment(name, namespace, segment)

        self.context.put_subsegment(subsegment)

        return subsegment

    def current_subsegment(self):
        """
        Return the latest opened subsegment. In a multithreading environment,
        this will make sure the subsegment returned is one created
        by the same thread.

        Returns ``None`` when the current trace entity is not a subsegment.
        """
        if not global_sdk_config.sdk_enabled():
            return DummySubsegment(DummySegment(global_sdk_config.DISABLED_ENTITY_NAME))

        entity = self.get_trace_entity()
        if self._is_subsegment(entity):
            return entity
        return None

    def end_subsegment(self, end_time=None):
        """
        End the current active subsegment. If this is the last one open
        under its parent segment, the entire segment will be sent.

        :param float end_time: subsegment completion in unix epoch in seconds.
        """
        if not global_sdk_config.sdk_enabled():
            return

        if not self.context.end_subsegment(end_time):
            return

        # if segment is already close, we check if we can send entire segment
        # otherwise we check if we need to stream some subsegments
        if self.current_segment().ready_to_send():
            self._send_segment()
        else:
            self.stream_subsegments()

    def put_annotation(self, key, value):
        """
        Annotate current active trace entity with a key-value pair.
        Annotations will be indexed for later search query.

        :param str key: annotation key
        :param object value: annotation value. Any type other than
            string/number/bool will be dropped
        """
        if not global_sdk_config.sdk_enabled():
            return
        entity = self.get_trace_entity()
        if entity and entity.sampled:
            entity.put_annotation(key, value)

    def put_metadata(self, key, value, namespace='default'):
        """
        Add metadata to the current active trace entity.
        Metadata is not indexed but can be later retrieved
        by BatchGetTraces API.

        :param str namespace: optional. Default namespace is `default`.
            It must be a string and prefix `AWS.` is reserved.
        :param str key: metadata key under specified namespace
        :param object value: any object that can be serialized into JSON string
        """
        if not global_sdk_config.sdk_enabled():
            return
        entity = self.get_trace_entity()
        if entity and entity.sampled:
            entity.put_metadata(key, value, namespace)

    def is_sampled(self):
        """
        Check if the current trace entity is sampled or not.
        Return `False` if no active entity found.
        """
        if not global_sdk_config.sdk_enabled():
            # Disabled SDK is never sampled
            return False
        entity = self.get_trace_entity()
        if entity:
            return entity.sampled
        return False

    def get_trace_entity(self):
        """
        A pass through method to ``context.get_trace_entity()``.
        """
        return self.context.get_trace_entity()

    def set_trace_entity(self, trace_entity):
        """
        A pass through method to ``context.set_trace_entity()``.
        """
        self.context.set_trace_entity(trace_entity)

    def clear_trace_entities(self):
        """
        A pass through method to ``context.clear_trace_entities()``.
        """
        self.context.clear_trace_entities()

    def stream_subsegments(self):
        """
        Stream all closed subsegments to the daemon
        and remove reference to the parent segment.
        No-op for a not sampled segment.
        """
        segment = self.current_segment()

        if self.streaming.is_eligible(segment):
            self.streaming.stream(segment, self._stream_subsegment_out)

    def capture(self, name=None):
        """
        A decorator that records enclosed function in a subsegment.
        It only works with synchronous functions.

        :param str name: The name of the subsegment. If not specified
            the function name will be used.
        """
        return self.in_subsegment(name=name)

    def record_subsegment(self, wrapped, instance, args, kwargs, name,
                          namespace, meta_processor):
        """
        Call ``wrapped(*args, **kwargs)`` inside a new subsegment.

        :param wrapped: the callable to invoke.
        :param instance: passed through to ``meta_processor`` untouched.
        :param args: positional arguments for ``wrapped``.
        :param kwargs: keyword arguments for ``wrapped``.
        :param str name: the name of the subsegment.
        :param str namespace: the namespace of the subsegment.
        :param meta_processor: optional callable invoked with keyword
            arguments ``wrapped``, ``instance``, ``args``, ``kwargs``,
            ``return_value``, ``exception``, ``subsegment`` and ``stack``
            once ``wrapped`` has finished. When it is not callable, a raised
            exception is recorded on the subsegment instead.
        :return: whatever ``wrapped`` returns.
        :raises: any exception raised by ``wrapped`` is re-raised.
        """
        subsegment = self.begin_subsegment(name, namespace)

        exception = None
        stack = None
        return_value = None

        try:
            return_value = wrapped(*args, **kwargs)
            return return_value
        except Exception as e:
            exception = e
            stack = stacktrace.get_stacktrace(limit=self.max_trace_back)
            raise
        finally:
            # No-op if subsegment is `None` due to `LOG_ERROR`.
            if subsegment is not None:
                end_time = time.time()
                try:
                    if callable(meta_processor):
                        meta_processor(
                            wrapped=wrapped,
                            instance=instance,
                            args=args,
                            kwargs=kwargs,
                            return_value=return_value,
                            exception=exception,
                            subsegment=subsegment,
                            stack=stack,
                        )
                    elif exception:
                        subsegment.add_exception(exception, stack)
                finally:
                    # Always close the subsegment, even if recording the
                    # metadata failed; otherwise it would stay open.
                    self.end_subsegment(end_time)

    def _populate_runtime_context(self, segment, sampling_decision):
        """
        Attach origin, AWS metadata and service info to a new segment.

        A string ``sampling_decision`` is stored as the segment's rule name.
        """
        if self._origin:
            setattr(segment, 'origin', self._origin)

        # Deep copy so the segment does not share the recorder's metadata.
        segment.set_aws(copy.deepcopy(self._aws_metadata))
        segment.set_service(SERVICE_INFO)

        if isinstance(sampling_decision, string_types):
            segment.set_rule_name(sampling_decision)

    def _send_segment(self):
        """
        Send the current segment to X-Ray daemon if it is present and
        sampled, then clean up context storage.
        The emitter will handle failures.
        """
        segment = self.current_segment()

        if not segment:
            return

        if segment.sampled:
            self.emitter.send_entity(segment)
        self.clear_trace_entities()

    def _stream_subsegment_out(self, subsegment):
        """
        Callback handed to the streaming module: emit one subsegment.
        """
        log.debug("streaming subsegments...")
        self.emitter.send_entity(subsegment)

    def _load_sampling_rules(self, sampling_rules):
        """
        Load local sampling rules into the sampler.

        :param sampling_rules: either a dictionary of rules or the path of
            a JSON file containing them. Falsy values are ignored.
        """
        if not sampling_rules:
            return

        if isinstance(sampling_rules, dict):
            self.sampler.load_local_rules(sampling_rules)
        else:
            with open(sampling_rules) as f:
                self.sampler.load_local_rules(json.load(f))

    def _is_subsegment(self, entity):
        """
        Return ``True`` if ``entity`` has a ``type`` attribute equal to
        ``'subsegment'``.
        """
        return (hasattr(entity, 'type') and entity.type == 'subsegment')

    @property
    def enabled(self):
        """
        Value stored through the ``enabled`` setter.

        Falls back to ``True`` if the backing field was never set, for
        example on instances created without running ``__init__``.
        """
        return getattr(self, '_enabled', True)

    @enabled.setter
    def enabled(self, value):
        self._enabled = value

    @property
    def sampling(self):
        """Whether the recorder asks the sampler for a decision."""
        return self._sampling

    @sampling.setter
    def sampling(self, value):
        self._sampling = value

    @property
    def sampler(self):
        """The sampler used to make sampling decisions."""
        return self._sampler

    @sampler.setter
    def sampler(self, value):
        self._sampler = value

    @property
    def service(self):
        """Default segment name used when none is given."""
        return self._service

    @service.setter
    def service(self, value):
        self._service = value

    @property
    def dynamic_naming(self):
        """The dynamic naming object, or ``None`` if not configured."""
        return self._dynamic_naming

    @dynamic_naming.setter
    def dynamic_naming(self, value):
        # A string is treated as a pattern and wrapped in the default
        # implementation; anything else is stored as is.
        if isinstance(value, string_types):
            self._dynamic_naming = DefaultDynamicNaming(value, self.service)
        else:
            self._dynamic_naming = value

    @property
    def context(self):
        """The context storage holding the active trace entities."""
        return self._context

    @context.setter
    def context(self, cxt):
        self._context = cxt

    @property
    def emitter(self):
        """The emitter used to send entities to the daemon."""
        return self._emitter

    @emitter.setter
    def emitter(self, value):
        self._emitter = value

    @property
    def streaming(self):
        """The streaming module used to stream out subsegments."""
        return self._streaming

    @streaming.setter
    def streaming(self, value):
        self._streaming = value

    @property
    def streaming_threshold(self):
        """
        Proxy method to Streaming module's `streaming_threshold` property.
        """
        return self.streaming.streaming_threshold

    @streaming_threshold.setter
    def streaming_threshold(self, value):
        """
        Proxy method to Streaming module's `streaming_threshold` property.
        """
        self.streaming.streaming_threshold = value

    @property
    def max_trace_back(self):
        """Stack trace limit used when recording exceptions."""
        return self._max_trace_back

    @max_trace_back.setter
    def max_trace_back(self, value):
        self._max_trace_back = value

    @property
    def stream_sql(self):
        """Whether SQL query texts should be streamed."""
        return self._stream_sql

    @stream_sql.setter
    def stream_sql(self, value):
        self._stream_sql = value