1from sentry_sdk import configure_scope
2from sentry_sdk.hub import Hub
3from sentry_sdk.integrations import Integration
4from sentry_sdk.utils import capture_internal_exceptions
5
6from sentry_sdk._types import MYPY
7
8if MYPY:
9    from typing import Any
10    from typing import Optional
11
12    from sentry_sdk._types import Event, Hint
13
14
class SparkIntegration(Integration):
    """Sentry integration for Apache Spark driver processes."""

    identifier = "spark"

    @staticmethod
    def setup_once():
        # type: () -> None
        # Monkeypatch SparkContext._do_init so every newly created
        # SparkContext gets a Sentry listener and event processor attached.
        patch_spark_context_init()
22
23
def _set_app_properties():
    # type: () -> None
    """
    Store the Spark app name and application id as local properties on the
    active driver context.

    Local properties propagate from the driver to worker processes, which
    lets the worker-side integration read ``sentry_app_name`` and
    ``sentry_application_id`` back out.
    """
    from pyspark import SparkContext

    active_context = SparkContext._active_spark_context
    if not active_context:
        # No driver context is running yet; nothing to propagate.
        return

    active_context.setLocalProperty("sentry_app_name", active_context.appName)
    active_context.setLocalProperty(
        "sentry_application_id", active_context.applicationId
    )
38
39
def _start_sentry_listener(sc):
    # type: (Any) -> None
    """
    Register a ``SentryListener`` with the JVM-side Spark scheduler.

    The py4j callback server must be running so the JVM can invoke the
    Python listener's methods; ``ensure_callback_server_started`` takes
    care of starting it if needed.
    """
    from pyspark.java_gateway import ensure_callback_server_started

    ensure_callback_server_started(sc._gateway)
    sc._jsc.sc().addSparkListener(SentryListener())
51
52
def patch_spark_context_init():
    # type: () -> None
    """
    Monkeypatch ``SparkContext._do_init`` so that context creation also
    starts the Sentry listener and installs an event processor that tags
    every event with driver/application metadata.
    """
    from pyspark import SparkContext

    original_init = SparkContext._do_init

    def _sentry_patched_spark_context_init(self, *args, **kwargs):
        # type: (SparkContext, *Any, **Any) -> Optional[Any]
        rv = original_init(self, *args, **kwargs)

        # Only hook in when the Spark integration is actually enabled.
        if Hub.current.get_integration(SparkIntegration) is None:
            return rv

        _start_sentry_listener(self)
        _set_app_properties()

        with configure_scope() as scope:

            def process_event(event, hint):
                # type: (Event, Hint) -> Optional[Event]
                with capture_internal_exceptions():
                    if Hub.current.get_integration(SparkIntegration) is None:
                        return event

                    # setdefault everywhere: never clobber values that the
                    # user (or another integration) already set.
                    event.setdefault("user", {}).setdefault("id", self.sparkUser())

                    tags = event.setdefault("tags", {})
                    driver_tags = {
                        "executor.id": self._conf.get("spark.executor.id"),
                        "spark-submit.deployMode": self._conf.get(
                            "spark.submit.deployMode"
                        ),
                        "driver.host": self._conf.get("spark.driver.host"),
                        "driver.port": self._conf.get("spark.driver.port"),
                        "spark_version": self.version,
                        "app_name": self.appName,
                        "application_id": self.applicationId,
                        "master": self.master,
                        "spark_home": self.sparkHome,
                    }
                    for key, value in driver_tags.items():
                        tags.setdefault(key, value)

                    event.setdefault("extra", {}).setdefault("web_url", self.uiWebUrl)

                return event

            scope.add_event_processor(process_event)

        return rv

    SparkContext._do_init = _sentry_patched_spark_context_init
106
107
class SparkListener(object):
    """
    No-op Python implementation of Spark's ``SparkListenerInterface``.

    py4j requires every interface method to exist on the Python callback
    object, so each scheduler callback is stubbed out here; subclasses
    (e.g. ``SentryListener``) override only the events they care about.
    """

    def onApplicationEnd(self, applicationEnd):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onApplicationStart(self, applicationStart):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onBlockManagerAdded(self, blockManagerAdded):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onBlockManagerRemoved(self, blockManagerRemoved):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onBlockUpdated(self, blockUpdated):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onEnvironmentUpdate(self, environmentUpdate):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onExecutorAdded(self, executorAdded):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onExecutorBlacklisted(self, executorBlacklisted):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onExecutorBlacklistedForStage(  # noqa: N802
        self, executorBlacklistedForStage  # noqa: N803
    ):
        # type: (Any) -> None
        pass

    def onExecutorMetricsUpdate(self, executorMetricsUpdate):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onExecutorRemoved(self, executorRemoved):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onJobEnd(self, jobEnd):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onJobStart(self, jobStart):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onNodeBlacklisted(self, nodeBlacklisted):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onNodeBlacklistedForStage(self, nodeBlacklistedForStage):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onNodeUnblacklisted(self, nodeUnblacklisted):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onOtherEvent(self, event):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onSpeculativeTaskSubmitted(self, speculativeTask):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onStageCompleted(self, stageCompleted):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onStageSubmitted(self, stageSubmitted):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onTaskEnd(self, taskEnd):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onTaskGettingResult(self, taskGettingResult):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onTaskStart(self, taskStart):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    def onUnpersistRDD(self, unpersistRDD):  # noqa: N802,N803
        # type: (Any) -> None
        pass

    class Java:
        # py4j marker: tells the gateway which JVM interface this Python
        # object implements for callback dispatch.
        implements = ["org.apache.spark.scheduler.SparkListenerInterface"]
209
210
class SentryListener(SparkListener):
    """
    Spark listener that records job and stage lifecycle events as Sentry
    breadcrumbs, and refreshes the propagated app properties on each
    job/stage submission.
    """

    def __init__(self):
        # type: () -> None
        # Bind the hub that is current when the listener is created
        # (i.e. at driver startup).
        self.hub = Hub.current

    def onJobStart(self, jobStart):  # noqa: N802,N803
        # type: (Any) -> None
        self.hub.add_breadcrumb(
            level="info", message="Job {} Started".format(jobStart.jobId())
        )
        _set_app_properties()

    def onJobEnd(self, jobEnd):  # noqa: N802,N803
        # type: (Any) -> None
        result = jobEnd.jobResult().toString()

        if result == "JobSucceeded":
            level = "info"
            message = "Job {} Ended".format(jobEnd.jobId())
        else:
            level = "warning"
            message = "Job {} Failed".format(jobEnd.jobId())

        self.hub.add_breadcrumb(level=level, message=message, data={"result": result})

    def onStageSubmitted(self, stageSubmitted):  # noqa: N802,N803
        # type: (Any) -> None
        info = stageSubmitted.stageInfo()
        self.hub.add_breadcrumb(
            level="info",
            message="Stage {} Submitted".format(info.stageId()),
            data={"attemptId": info.attemptId(), "name": info.name()},
        )
        _set_app_properties()

    def onStageCompleted(self, stageCompleted):  # noqa: N802,N803
        # type: (Any) -> None
        from py4j.protocol import Py4JJavaError  # type: ignore

        info = stageCompleted.stageInfo()
        data = {"attemptId": info.attemptId(), "name": info.name()}

        # stageInfo.failureReason() is a Scala Option; calling .get() on an
        # empty Option raises through py4j, which signals a successful stage.
        try:
            data["reason"] = info.failureReason().get()
            message = "Stage {} Failed".format(info.stageId())
            level = "warning"
        except Py4JJavaError:
            message = "Stage {} Completed".format(info.stageId())
            level = "info"

        self.hub.add_breadcrumb(level=level, message=message, data=data)
264