#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

# Python 2/3 compatibility aliases.  Compare sys.version_info (a tuple of
# ints) rather than sys.version (a string): string comparison is fragile
# and would misorder a hypothetical major version '10' against '3'.
if sys.version_info[0] >= 3:
    long = int
    unicode = str

import py4j.protocol
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
from py4j.java_collections import JavaArray, JavaList

from pyspark import RDD, SparkContext
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql import DataFrame, SQLContext

# Hack for support float('inf') in Py4j: Py4J's smart_decode does not know
# the Java spellings of the IEEE-754 special values, so we wrap it.
_old_smart_decode = py4j.protocol.smart_decode

# Python repr of float specials -> the strings Java's Double parser expects.
_float_str_mapping = {
    'nan': 'NaN',
    'inf': 'Infinity',
    '-inf': '-Infinity',
}


def _new_smart_decode(obj):
    """ Decode a value for Py4J, translating float specials (nan/inf/-inf)
    into their Java spellings; everything else is deferred to the original
    py4j.protocol.smart_decode.
    """
    if isinstance(obj, float):
        s = str(obj)
        # Ordinary floats fall through unchanged (their str() is returned).
        return _float_str_mapping.get(s, s)
    return _old_smart_decode(obj)

# Module-level side effect: install the wrapper so all Py4J calls benefit.
py4j.protocol.smart_decode = _new_smart_decode


# Simple names of the JVM classes that MLSerDe can pickle directly.
_picklable_classes = [
    'SparseVector',
    'DenseVector',
    'SparseMatrix',
    'DenseMatrix',
]


# this will call the ML version of pythonToJava()
def _to_java_object_rdd(rdd):
    """ Return a JavaRDD of Object by unpickling

    It will convert each Python object into Java object by Pyrolite, whenever the
    RDD is serialized in batch or not.
    """
    # Re-serialize with autobatched pickling so the JVM side can unpickle
    # regardless of how the RDD was originally batched.
    rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
    return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True)


def _py2java(sc, obj):
    """ Convert Python object into Java

    RDDs, DataFrames and SparkContexts map to their JVM counterparts; lists
    are converted element-wise; JavaObjects and Py4J-native scalars pass
    through unchanged; anything else is pickled and loaded via MLSerDe.
    """
    if isinstance(obj, RDD):
        obj = _to_java_object_rdd(obj)
    elif isinstance(obj, DataFrame):
        obj = obj._jdf
    elif isinstance(obj, SparkContext):
        obj = obj._jsc
    elif isinstance(obj, list):
        obj = [_py2java(sc, x) for x in obj]
    elif isinstance(obj, JavaObject):
        pass
    elif isinstance(obj, (int, long, float, bool, bytes, unicode)):
        # Py4J converts these primitive types natively.
        pass
    else:
        data = bytearray(PickleSerializer().dumps(obj))
        obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data)
    return obj


def _java2py(sc, r, encoding="bytes"):
    """ Convert a value returned from the JVM into the matching Python object.

    JavaRDDs become pyspark RDDs, Datasets become DataFrames, picklable ML
    types and Java arrays/lists are round-tripped through MLSerDe, and raw
    bytes are unpickled with the given encoding.  Values that match none of
    these cases are returned unchanged.
    """
    if isinstance(r, JavaObject):
        clsName = r.getClass().getSimpleName()
        # convert RDD into JavaRDD
        if clsName != 'JavaRDD' and clsName.endswith("RDD"):
            r = r.toJavaRDD()
            clsName = 'JavaRDD'

        if clsName == 'JavaRDD':
            jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r)
            return RDD(jrdd, sc)

        if clsName == 'Dataset':
            return DataFrame(r, SQLContext.getOrCreate(sc))

        if clsName in _picklable_classes:
            r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
        elif isinstance(r, (JavaArray, JavaList)):
            try:
                r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
            except Py4JJavaError:
                pass  # not pickable

    if isinstance(r, (bytearray, bytes)):
        r = PickleSerializer().loads(bytes(r), encoding=encoding)
    return r


def callJavaFunc(sc, func, *args):
    """ Call Java Function

    Converts every argument with _py2java, invokes ``func``, and converts
    the result back with _java2py.
    """
    args = [_py2java(sc, a) for a in args]
    return _java2py(sc, func(*args))


def inherit_doc(cls):
    """
    A decorator that makes a class inherit documentation from its parents.
    """
    for name, func in vars(cls).items():
        # only inherit docstring for public functions
        if name.startswith("_"):
            continue
        if not func.__doc__:
            # Copy the first documented parent's docstring (MRO-like order
            # over the direct bases).
            for parent in cls.__bases__:
                parent_func = getattr(parent, name, None)
                if parent_func and getattr(parent_func, "__doc__", None):
                    func.__doc__ = parent_func.__doc__
                    break
    return cls