#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys
import warnings

if sys.version >= '3':
    basestring = str
    long = int

from pyspark import copy_func, since
from pyspark.context import SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.sql.types import *

__all__ = ["DataFrame", "Column", "DataFrameNaFunctions", "DataFrameStatFunctions"]


def _create_column_from_literal(literal):
    sc = SparkContext._active_spark_context
    return sc._jvm.functions.lit(literal)


def _create_column_from_name(name):
    sc = SparkContext._active_spark_context
    return sc._jvm.functions.col(name)


def _to_java_column(col):
    if isinstance(col, Column):
        jcol = col._jc
    else:
        jcol = _create_column_from_name(col)
    return jcol
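
# Illustrative note: ``_to_java_column`` normalises user input, accepting either an
# existing Column or a column name. For example, with a hypothetical DataFrame ``df``:
#   _to_java_column(df.age)   # unwraps the underlying JVM Column
#   _to_java_column("age")    # resolves the name via functions.col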


def _to_seq(sc, cols, converter=None):
    """
    Convert a list of Column (or names) into a JVM Seq of Column.

    An optional `converter` could be used to convert items in `cols`
    into JVM Column objects.
    """
    if converter:
        cols = [converter(c) for c in cols]
    return sc._jvm.PythonUtils.toSeq(cols)


def _to_list(sc, cols, converter=None):
    """
    Convert a list of Column (or names) into a JVM (Scala) List of Column.

    An optional `converter` could be used to convert items in `cols`
    into JVM Column objects.
    """
    if converter:
        cols = [converter(c) for c in cols]
    return sc._jvm.PythonUtils.toList(cols)
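
# Illustrative note: ``_to_seq`` and ``_to_list`` are usually paired with
# ``_to_java_column`` as the converter, so a mixed list of Columns and column
# names becomes a JVM Seq/List of Columns. A hypothetical call:
#   _to_seq(SparkContext._active_spark_context, [df.age, "name"], _to_java_column)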


def _unary_op(name, doc="unary operator"):
    """ Create a method for the given unary operator """
    def _(self):
        jc = getattr(self._jc, name)()
        return Column(jc)
    _.__doc__ = doc
    return _


def _func_op(name, doc=''):
    def _(self):
        sc = SparkContext._active_spark_context
        jc = getattr(sc._jvm.functions, name)(self._jc)
        return Column(jc)
    _.__doc__ = doc
    return _


def _bin_func_op(name, reverse=False, doc="binary function"):
    def _(self, other):
        sc = SparkContext._active_spark_context
        fn = getattr(sc._jvm.functions, name)
        jc = other._jc if isinstance(other, Column) else _create_column_from_literal(other)
        njc = fn(self._jc, jc) if not reverse else fn(jc, self._jc)
        return Column(njc)
    _.__doc__ = doc
    return _


def _bin_op(name, doc="binary operator"):
    """ Create a method for the given binary operator """
    def _(self, other):
        jc = other._jc if isinstance(other, Column) else other
        njc = getattr(self._jc, name)(jc)
        return Column(njc)
    _.__doc__ = doc
    return _


def _reverse_op(name, doc="binary operator"):
    """ Create a method for a binary operator (this object is on the right side) """
    def _(self, other):
        jother = _create_column_from_literal(other)
        jc = getattr(jother, name)(self._jc)
        return Column(jc)
    _.__doc__ = doc
    return _
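
# Illustrative note: the factories above build Column methods by delegating either
# to a method on the JVM Column (``_bin_op``/``_unary_op``/``_reverse_op``) or to a
# function in ``org.apache.spark.sql.functions`` (``_func_op``/``_bin_func_op``).
# For example (matching the definitions in Column below):
#   __add__ = _bin_op("plus")        # df.age + 1  ->  jc.plus(1)
#   __rsub__ = _reverse_op("minus")  # 1 - df.age  ->  lit(1).minus(jc)
#   __neg__ = _func_op("negate")     # -df.age     ->  functions.negate(jc)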


class Column(object):

    """
    A column in a DataFrame.

    :class:`Column` instances can be created by::

        # 1. Select a column out of a DataFrame

        df.colName
        df["colName"]

        # 2. Create from an expression
        df.colName + 1
        1 / df.colName

    .. versionadded:: 1.3
    """

    def __init__(self, jc):
        self._jc = jc

    # arithmetic operators
    __neg__ = _func_op("negate")
    __add__ = _bin_op("plus")
    __sub__ = _bin_op("minus")
    __mul__ = _bin_op("multiply")
    __div__ = _bin_op("divide")
    __truediv__ = _bin_op("divide")
    __mod__ = _bin_op("mod")
    __radd__ = _bin_op("plus")
    __rsub__ = _reverse_op("minus")
    __rmul__ = _bin_op("multiply")
    __rdiv__ = _reverse_op("divide")
    __rtruediv__ = _reverse_op("divide")
    __rmod__ = _reverse_op("mod")
    __pow__ = _bin_func_op("pow")
    __rpow__ = _bin_func_op("pow", reverse=True)

    # comparison operators
    __eq__ = _bin_op("equalTo")
    __ne__ = _bin_op("notEqual")
    __lt__ = _bin_op("lt")
    __le__ = _bin_op("leq")
    __ge__ = _bin_op("geq")
    __gt__ = _bin_op("gt")

    # `and`, `or`, `not` cannot be overloaded in Python,
    # so use bitwise operators as boolean operators
    __and__ = _bin_op('and')
    __or__ = _bin_op('or')
    __invert__ = _func_op('not')
    __rand__ = _bin_op("and")
    __ror__ = _bin_op("or")
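
    # Illustrative note: because ``and``/``or``/``not`` cannot be overloaded, boolean
    # expressions are built with ``&``, ``|`` and ``~``; parentheses are needed since
    # the bitwise operators bind tighter than comparisons. With a hypothetical ``df``:
    #   df.filter((df.age > 3) & ~(df.name == "Bob"))
    #   df.age > 3 and df.age < 10    # raises ValueError via __bool__ below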

    # container operators
    __contains__ = _bin_op("contains")
    __getitem__ = _bin_op("apply")

    # bitwise operators
    bitwiseOR = _bin_op("bitwiseOR")
    bitwiseAND = _bin_op("bitwiseAND")
    bitwiseXOR = _bin_op("bitwiseXOR")

    @since(1.3)
    def getItem(self, key):
        """
        An expression that gets an item at position ``key`` out of a list,
        or gets an item by key out of a dict.

        >>> df = sc.parallelize([([1, 2], {"key": "value"})]).toDF(["l", "d"])
        >>> df.select(df.l.getItem(0), df.d.getItem("key")).show()
        +----+------+
        |l[0]|d[key]|
        +----+------+
        |   1| value|
        +----+------+
        >>> df.select(df.l[0], df.d["key"]).show()
        +----+------+
        |l[0]|d[key]|
        +----+------+
        |   1| value|
        +----+------+
        """
        return self[key]

    @since(1.3)
    def getField(self, name):
        """
        An expression that gets a field by name in a :class:`StructType`.

        >>> from pyspark.sql import Row
        >>> df = sc.parallelize([Row(r=Row(a=1, b="b"))]).toDF()
        >>> df.select(df.r.getField("b")).show()
        +---+
        |r.b|
        +---+
        |  b|
        +---+
        >>> df.select(df.r.a).show()
        +---+
        |r.a|
        +---+
        |  1|
        +---+
        """
        return self[name]

    def __getattr__(self, item):
        if item.startswith("__"):
            raise AttributeError(item)
        return self.getField(item)

    def __iter__(self):
        raise TypeError("Column is not iterable")

    # string methods
    rlike = _bin_op("rlike")
    like = _bin_op("like")
    startswith = _bin_op("startsWith")
    endswith = _bin_op("endsWith")

    @ignore_unicode_prefix
    @since(1.3)
    def substr(self, startPos, length):
        """
        Return a :class:`Column` which is a substring of the column.

        :param startPos: start position (int or Column)
        :param length: length of the substring (int or Column)

        >>> df.select(df.name.substr(1, 3).alias("col")).collect()
        [Row(col=u'Ali'), Row(col=u'Bob')]
        """
        if type(startPos) != type(length):
            raise TypeError("startPos and length must be the same type")
        if isinstance(startPos, (int, long)):
            jc = self._jc.substr(startPos, length)
        elif isinstance(startPos, Column):
            jc = self._jc.substr(startPos._jc, length._jc)
        else:
            raise TypeError("Unexpected type: %s" % type(startPos))
        return Column(jc)

    __getslice__ = substr

    @ignore_unicode_prefix
    @since(1.5)
    def isin(self, *cols):
        """
        A boolean expression that is evaluated to true if the value of this
        expression is contained by the evaluated values of the arguments.

        >>> df[df.name.isin("Bob", "Mike")].collect()
        [Row(age=5, name=u'Bob')]
        >>> df[df.age.isin([1, 2, 3])].collect()
        [Row(age=2, name=u'Alice')]
        """
        if len(cols) == 1 and isinstance(cols[0], (list, set)):
            cols = cols[0]
        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
        sc = SparkContext._active_spark_context
        jc = getattr(self._jc, "isin")(_to_seq(sc, cols))
        return Column(jc)

    # order
    asc = _unary_op("asc", "Returns a sort expression based on the"
                           " ascending order of the column.")
    desc = _unary_op("desc", "Returns a sort expression based on the"
                             " descending order of the column.")

    isNull = _unary_op("isNull", "True if the current expression is null.")
    isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
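
    # Illustrative note: with a hypothetical ``df``, these expressions are typically
    # passed to sort or filter operations, e.g.
    #   df.orderBy(df.age.desc())
    #   df.filter(df.name.isNotNull())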

    @since(1.3)
    def alias(self, *alias):
        """
        Returns this column aliased with a new name or names (in the case of expressions that
        return more than one column, such as explode).

        >>> df.select(df.age.alias("age2")).collect()
        [Row(age2=2), Row(age2=5)]
        """

        if len(alias) == 1:
            return Column(getattr(self._jc, "as")(alias[0]))
        else:
            sc = SparkContext._active_spark_context
            return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))

    name = copy_func(alias, sinceversion=2.0, doc=":func:`name` is an alias for :func:`alias`.")

    @ignore_unicode_prefix
    @since(1.3)
    def cast(self, dataType):
        """ Convert the column into type ``dataType``.

        >>> df.select(df.age.cast("string").alias('ages')).collect()
        [Row(ages=u'2'), Row(ages=u'5')]
        >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
        [Row(ages=u'2'), Row(ages=u'5')]
        """
        if isinstance(dataType, basestring):
            jc = self._jc.cast(dataType)
        elif isinstance(dataType, DataType):
            from pyspark.sql import SparkSession
            spark = SparkSession.builder.getOrCreate()
            jdt = spark._jsparkSession.parseDataType(dataType.json())
            jc = self._jc.cast(jdt)
        else:
            raise TypeError("unexpected type: %s" % type(dataType))
        return Column(jc)

    astype = copy_func(cast, sinceversion=1.4, doc=":func:`astype` is an alias for :func:`cast`.")

    @since(1.3)
    def between(self, lowerBound, upperBound):
        """
        A boolean expression that is evaluated to true if the value of this
        expression is between the given lower and upper bounds (inclusive).

        >>> df.select(df.name, df.age.between(2, 4)).show()
        +-----+---------------------------+
        | name|((age >= 2) AND (age <= 4))|
        +-----+---------------------------+
        |Alice|                       true|
        |  Bob|                      false|
        +-----+---------------------------+
        """
        return (self >= lowerBound) & (self <= upperBound)

    @since(1.4)
    def when(self, condition, value):
        """
        Evaluates a list of conditions and returns one of multiple possible result expressions.
        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

        See :func:`pyspark.sql.functions.when` for example usage.

        :param condition: a boolean :class:`Column` expression.
        :param value: a literal value, or a :class:`Column` expression.

        >>> from pyspark.sql import functions as F
        >>> df.select(df.name, F.when(df.age > 4, 1).when(df.age < 3, -1).otherwise(0)).show()
        +-----+------------------------------------------------------------+
        | name|CASE WHEN (age > 4) THEN 1 WHEN (age < 3) THEN -1 ELSE 0 END|
        +-----+------------------------------------------------------------+
        |Alice|                                                          -1|
        |  Bob|                                                           1|
        +-----+------------------------------------------------------------+
        """
        if not isinstance(condition, Column):
            raise TypeError("condition should be a Column")
        v = value._jc if isinstance(value, Column) else value
        jc = self._jc.when(condition._jc, v)
        return Column(jc)

    @since(1.4)
    def otherwise(self, value):
        """
        Evaluates a list of conditions and returns one of multiple possible result expressions.
        If :func:`Column.otherwise` is not invoked, None is returned for unmatched conditions.

        See :func:`pyspark.sql.functions.when` for example usage.

        :param value: a literal value, or a :class:`Column` expression.

        >>> from pyspark.sql import functions as F
        >>> df.select(df.name, F.when(df.age > 3, 1).otherwise(0)).show()
        +-----+-------------------------------------+
        | name|CASE WHEN (age > 3) THEN 1 ELSE 0 END|
        +-----+-------------------------------------+
        |Alice|                                    0|
        |  Bob|                                    1|
        +-----+-------------------------------------+
        """
        v = value._jc if isinstance(value, Column) else value
        jc = self._jc.otherwise(v)
        return Column(jc)

    @since(1.4)
    def over(self, window):
        """
        Define a windowing column.

        :param window: a :class:`WindowSpec`
        :return: a Column

        >>> from pyspark.sql import Window
        >>> window = Window.partitionBy("name").orderBy("age").rowsBetween(-1, 1)
        >>> from pyspark.sql.functions import rank, min
        >>> # df.select(rank().over(window), min('age').over(window))
        """
        from pyspark.sql.window import WindowSpec
        if not isinstance(window, WindowSpec):
            raise TypeError("window should be WindowSpec")
        jc = self._jc.over(window._jspec)
        return Column(jc)

    def __nonzero__(self):
        raise ValueError("Cannot convert column into bool: please use '&' for 'and', '|' for 'or', "
                         "'~' for 'not' when building DataFrame boolean expressions.")
    __bool__ = __nonzero__

    def __repr__(self):
        return 'Column<%s>' % self._jc.toString().encode('utf8')


def _test():
    import doctest
    from pyspark.sql import SparkSession
    import pyspark.sql.column
    globs = pyspark.sql.column.__dict__.copy()
    spark = SparkSession.builder\
        .master("local[4]")\
        .appName("sql.column tests")\
        .getOrCreate()
    sc = spark.sparkContext
    globs['sc'] = sc
    globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \
        .toDF(StructType([StructField('age', IntegerType()),
                          StructField('name', StringType())]))

    (failure_count, test_count) = doctest.testmod(
        pyspark.sql.column, globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
    spark.stop()
    if failure_count:
        exit(-1)


if __name__ == "__main__":
    _test()