1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.spark.sql.catalyst.expressions
19
20import java.sql.{Date, Timestamp}
21import java.util.{Calendar, TimeZone}
22
23import org.apache.spark.SparkFunSuite
24import org.apache.spark.sql.Row
25import org.apache.spark.sql.catalyst.InternalRow
26import org.apache.spark.sql.catalyst.util.DateTimeUtils
27import org.apache.spark.sql.types._
28import org.apache.spark.unsafe.types.UTF8String
29
30/**
31 * Test suite for data type casting expression [[Cast]].
32 */
33class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
34
35  private def cast(v: Any, targetType: DataType): Cast = {
36    v match {
37      case lit: Expression => Cast(lit, targetType)
38      case _ => Cast(Literal(v), targetType)
39    }
40  }
41
42  // expected cannot be null
43  private def checkCast(v: Any, expected: Any): Unit = {
44    checkEvaluation(cast(v, Literal(expected).dataType), expected)
45  }
46
47  private def checkNullCast(from: DataType, to: DataType): Unit = {
48    checkEvaluation(Cast(Literal.create(null, from), to), null)
49  }
50
51  test("null cast") {
52    import DataTypeTestUtils._
53
54    // follow [[org.apache.spark.sql.catalyst.expressions.Cast.canCast]] logic
55    // to ensure we test every possible cast situation here
56    atomicTypes.zip(atomicTypes).foreach { case (from, to) =>
57      checkNullCast(from, to)
58    }
59
60    atomicTypes.foreach(dt => checkNullCast(NullType, dt))
61    atomicTypes.foreach(dt => checkNullCast(dt, StringType))
62    checkNullCast(StringType, BinaryType)
63    checkNullCast(StringType, BooleanType)
64    checkNullCast(DateType, BooleanType)
65    checkNullCast(TimestampType, BooleanType)
66    numericTypes.foreach(dt => checkNullCast(dt, BooleanType))
67
68    checkNullCast(StringType, TimestampType)
69    checkNullCast(BooleanType, TimestampType)
70    checkNullCast(DateType, TimestampType)
71    numericTypes.foreach(dt => checkNullCast(dt, TimestampType))
72
73    checkNullCast(StringType, DateType)
74    checkNullCast(TimestampType, DateType)
75
76    checkNullCast(StringType, CalendarIntervalType)
77    numericTypes.foreach(dt => checkNullCast(StringType, dt))
78    numericTypes.foreach(dt => checkNullCast(BooleanType, dt))
79    numericTypes.foreach(dt => checkNullCast(DateType, dt))
80    numericTypes.foreach(dt => checkNullCast(TimestampType, dt))
81    for (from <- numericTypes; to <- numericTypes) checkNullCast(from, to)
82  }
83
84  test("cast string to date") {
85    var c = Calendar.getInstance()
86    c.set(2015, 0, 1, 0, 0, 0)
87    c.set(Calendar.MILLISECOND, 0)
88    checkEvaluation(Cast(Literal("2015"), DateType), new Date(c.getTimeInMillis))
89    c = Calendar.getInstance()
90    c.set(2015, 2, 1, 0, 0, 0)
91    c.set(Calendar.MILLISECOND, 0)
92    checkEvaluation(Cast(Literal("2015-03"), DateType), new Date(c.getTimeInMillis))
93    c = Calendar.getInstance()
94    c.set(2015, 2, 18, 0, 0, 0)
95    c.set(Calendar.MILLISECOND, 0)
96    checkEvaluation(Cast(Literal("2015-03-18"), DateType), new Date(c.getTimeInMillis))
97    checkEvaluation(Cast(Literal("2015-03-18 "), DateType), new Date(c.getTimeInMillis))
98    checkEvaluation(Cast(Literal("2015-03-18 123142"), DateType), new Date(c.getTimeInMillis))
99    checkEvaluation(Cast(Literal("2015-03-18T123123"), DateType), new Date(c.getTimeInMillis))
100    checkEvaluation(Cast(Literal("2015-03-18T"), DateType), new Date(c.getTimeInMillis))
101
102    checkEvaluation(Cast(Literal("2015-03-18X"), DateType), null)
103    checkEvaluation(Cast(Literal("2015/03/18"), DateType), null)
104    checkEvaluation(Cast(Literal("2015.03.18"), DateType), null)
105    checkEvaluation(Cast(Literal("20150318"), DateType), null)
106    checkEvaluation(Cast(Literal("2015-031-8"), DateType), null)
107  }
108
109  test("cast string to timestamp") {
110    checkEvaluation(Cast(Literal("123"), TimestampType), null)
111
112    var c = Calendar.getInstance()
113    c.set(2015, 0, 1, 0, 0, 0)
114    c.set(Calendar.MILLISECOND, 0)
115    checkEvaluation(Cast(Literal("2015"), TimestampType),
116      new Timestamp(c.getTimeInMillis))
117    c = Calendar.getInstance()
118    c.set(2015, 2, 1, 0, 0, 0)
119    c.set(Calendar.MILLISECOND, 0)
120    checkEvaluation(Cast(Literal("2015-03"), TimestampType),
121      new Timestamp(c.getTimeInMillis))
122    c = Calendar.getInstance()
123    c.set(2015, 2, 18, 0, 0, 0)
124    c.set(Calendar.MILLISECOND, 0)
125    checkEvaluation(Cast(Literal("2015-03-18"), TimestampType),
126      new Timestamp(c.getTimeInMillis))
127    checkEvaluation(Cast(Literal("2015-03-18 "), TimestampType),
128      new Timestamp(c.getTimeInMillis))
129    checkEvaluation(Cast(Literal("2015-03-18T"), TimestampType),
130      new Timestamp(c.getTimeInMillis))
131
132    c = Calendar.getInstance()
133    c.set(2015, 2, 18, 12, 3, 17)
134    c.set(Calendar.MILLISECOND, 0)
135    checkEvaluation(Cast(Literal("2015-03-18 12:03:17"), TimestampType),
136      new Timestamp(c.getTimeInMillis))
137    checkEvaluation(Cast(Literal("2015-03-18T12:03:17"), TimestampType),
138      new Timestamp(c.getTimeInMillis))
139
140    c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
141    c.set(2015, 2, 18, 12, 3, 17)
142    c.set(Calendar.MILLISECOND, 0)
143    checkEvaluation(Cast(Literal("2015-03-18T12:03:17Z"), TimestampType),
144      new Timestamp(c.getTimeInMillis))
145    checkEvaluation(Cast(Literal("2015-03-18 12:03:17Z"), TimestampType),
146      new Timestamp(c.getTimeInMillis))
147
148    c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
149    c.set(2015, 2, 18, 12, 3, 17)
150    c.set(Calendar.MILLISECOND, 0)
151    checkEvaluation(Cast(Literal("2015-03-18T12:03:17-1:0"), TimestampType),
152      new Timestamp(c.getTimeInMillis))
153    checkEvaluation(Cast(Literal("2015-03-18T12:03:17-01:00"), TimestampType),
154      new Timestamp(c.getTimeInMillis))
155
156    c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
157    c.set(2015, 2, 18, 12, 3, 17)
158    c.set(Calendar.MILLISECOND, 0)
159    checkEvaluation(Cast(Literal("2015-03-18T12:03:17+07:30"), TimestampType),
160      new Timestamp(c.getTimeInMillis))
161
162    c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03"))
163    c.set(2015, 2, 18, 12, 3, 17)
164    c.set(Calendar.MILLISECOND, 0)
165    checkEvaluation(Cast(Literal("2015-03-18T12:03:17+7:3"), TimestampType),
166      new Timestamp(c.getTimeInMillis))
167
168    c = Calendar.getInstance()
169    c.set(2015, 2, 18, 12, 3, 17)
170    c.set(Calendar.MILLISECOND, 123)
171    checkEvaluation(Cast(Literal("2015-03-18 12:03:17.123"), TimestampType),
172      new Timestamp(c.getTimeInMillis))
173    checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123"), TimestampType),
174      new Timestamp(c.getTimeInMillis))
175
176    c = Calendar.getInstance(TimeZone.getTimeZone("UTC"))
177    c.set(2015, 2, 18, 12, 3, 17)
178    c.set(Calendar.MILLISECOND, 456)
179    checkEvaluation(Cast(Literal("2015-03-18T12:03:17.456Z"), TimestampType),
180      new Timestamp(c.getTimeInMillis))
181    checkEvaluation(Cast(Literal("2015-03-18 12:03:17.456Z"), TimestampType),
182      new Timestamp(c.getTimeInMillis))
183
184    c = Calendar.getInstance(TimeZone.getTimeZone("GMT-01:00"))
185    c.set(2015, 2, 18, 12, 3, 17)
186    c.set(Calendar.MILLISECOND, 123)
187    checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-1:0"), TimestampType),
188      new Timestamp(c.getTimeInMillis))
189    checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123-01:00"), TimestampType),
190      new Timestamp(c.getTimeInMillis))
191
192    c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:30"))
193    c.set(2015, 2, 18, 12, 3, 17)
194    c.set(Calendar.MILLISECOND, 123)
195    checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+07:30"), TimestampType),
196      new Timestamp(c.getTimeInMillis))
197
198    c = Calendar.getInstance(TimeZone.getTimeZone("GMT+07:03"))
199    c.set(2015, 2, 18, 12, 3, 17)
200    c.set(Calendar.MILLISECOND, 123)
201    checkEvaluation(Cast(Literal("2015-03-18T12:03:17.123+7:3"), TimestampType),
202      new Timestamp(c.getTimeInMillis))
203
204    checkEvaluation(Cast(Literal("2015-03-18 123142"), TimestampType), null)
205    checkEvaluation(Cast(Literal("2015-03-18T123123"), TimestampType), null)
206    checkEvaluation(Cast(Literal("2015-03-18X"), TimestampType), null)
207    checkEvaluation(Cast(Literal("2015/03/18"), TimestampType), null)
208    checkEvaluation(Cast(Literal("2015.03.18"), TimestampType), null)
209    checkEvaluation(Cast(Literal("20150318"), TimestampType), null)
210    checkEvaluation(Cast(Literal("2015-031-8"), TimestampType), null)
211    checkEvaluation(Cast(Literal("2015-03-18T12:03:17-0:70"), TimestampType), null)
212  }
213
214  test("cast from int") {
215    checkCast(0, false)
216    checkCast(1, true)
217    checkCast(-5, true)
218    checkCast(1, 1.toByte)
219    checkCast(1, 1.toShort)
220    checkCast(1, 1)
221    checkCast(1, 1.toLong)
222    checkCast(1, 1.0f)
223    checkCast(1, 1.0)
224    checkCast(123, "123")
225
226    checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
227    checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
228    checkEvaluation(cast(123, DecimalType(3, 1)), null)
229    checkEvaluation(cast(123, DecimalType(2, 0)), null)
230  }
231
232  test("cast from long") {
233    checkCast(0L, false)
234    checkCast(1L, true)
235    checkCast(-5L, true)
236    checkCast(1L, 1.toByte)
237    checkCast(1L, 1.toShort)
238    checkCast(1L, 1)
239    checkCast(1L, 1.toLong)
240    checkCast(1L, 1.0f)
241    checkCast(1L, 1.0)
242    checkCast(123L, "123")
243
244    checkEvaluation(cast(123L, DecimalType.USER_DEFAULT), Decimal(123))
245    checkEvaluation(cast(123L, DecimalType(3, 0)), Decimal(123))
246    checkEvaluation(cast(123L, DecimalType(3, 1)), null)
247
248    checkEvaluation(cast(123L, DecimalType(2, 0)), null)
249  }
250
251  test("cast from boolean") {
252    checkEvaluation(cast(true, IntegerType), 1)
253    checkEvaluation(cast(false, IntegerType), 0)
254    checkEvaluation(cast(true, StringType), "true")
255    checkEvaluation(cast(false, StringType), "false")
256    checkEvaluation(cast(cast(1, BooleanType), IntegerType), 1)
257    checkEvaluation(cast(cast(0, BooleanType), IntegerType), 0)
258  }
259
260  test("cast from int 2") {
261    checkEvaluation(cast(1, LongType), 1.toLong)
262    checkEvaluation(cast(cast(1000, TimestampType), LongType), 1000.toLong)
263    checkEvaluation(cast(cast(-1200, TimestampType), LongType), -1200.toLong)
264
265    checkEvaluation(cast(123, DecimalType.USER_DEFAULT), Decimal(123))
266    checkEvaluation(cast(123, DecimalType(3, 0)), Decimal(123))
267    checkEvaluation(cast(123, DecimalType(3, 1)), null)
268    checkEvaluation(cast(123, DecimalType(2, 0)), null)
269  }
270
271  test("cast from float") {
272    checkCast(0.0f, false)
273    checkCast(0.5f, true)
274    checkCast(-5.0f, true)
275    checkCast(1.5f, 1.toByte)
276    checkCast(1.5f, 1.toShort)
277    checkCast(1.5f, 1)
278    checkCast(1.5f, 1.toLong)
279    checkCast(1.5f, 1.5)
280    checkCast(1.5f, "1.5")
281  }
282
283  test("cast from double") {
284    checkCast(0.0, false)
285    checkCast(0.5, true)
286    checkCast(-5.0, true)
287    checkCast(1.5, 1.toByte)
288    checkCast(1.5, 1.toShort)
289    checkCast(1.5, 1)
290    checkCast(1.5, 1.toLong)
291    checkCast(1.5, 1.5f)
292    checkCast(1.5, "1.5")
293
294    checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
295    checkEvaluation(cast(cast(1.toDouble, TimestampType), DoubleType), 1.toDouble)
296  }
297
298  test("cast from string") {
299    assert(cast("abcdef", StringType).nullable === false)
300    assert(cast("abcdef", BinaryType).nullable === false)
301    assert(cast("abcdef", BooleanType).nullable === true)
302    assert(cast("abcdef", TimestampType).nullable === true)
303    assert(cast("abcdef", LongType).nullable === true)
304    assert(cast("abcdef", IntegerType).nullable === true)
305    assert(cast("abcdef", ShortType).nullable === true)
306    assert(cast("abcdef", ByteType).nullable === true)
307    assert(cast("abcdef", DecimalType.USER_DEFAULT).nullable === true)
308    assert(cast("abcdef", DecimalType(4, 2)).nullable === true)
309    assert(cast("abcdef", DoubleType).nullable === true)
310    assert(cast("abcdef", FloatType).nullable === true)
311  }
312
313  test("data type casting") {
314    val sd = "1970-01-01"
315    val d = Date.valueOf(sd)
316    val zts = sd + " 00:00:00"
317    val sts = sd + " 00:00:02"
318    val nts = sts + ".1"
319    val ts = Timestamp.valueOf(nts)
320
321    var c = Calendar.getInstance()
322    c.set(2015, 2, 8, 2, 30, 0)
323    checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType),
324      c.getTimeInMillis * 1000)
325    c = Calendar.getInstance()
326    c.set(2015, 10, 1, 2, 30, 0)
327    checkEvaluation(cast(cast(new Timestamp(c.getTimeInMillis), StringType), TimestampType),
328      c.getTimeInMillis * 1000)
329
330    checkEvaluation(cast("abdef", StringType), "abdef")
331    checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)
332    checkEvaluation(cast("abdef", TimestampType), null)
333    checkEvaluation(cast("12.65", DecimalType.SYSTEM_DEFAULT), Decimal(12.65))
334
335    checkEvaluation(cast(cast(sd, DateType), StringType), sd)
336    checkEvaluation(cast(cast(d, StringType), DateType), 0)
337    checkEvaluation(cast(cast(nts, TimestampType), StringType), nts)
338    checkEvaluation(cast(cast(ts, StringType), TimestampType), DateTimeUtils.fromJavaTimestamp(ts))
339
340    // all convert to string type to check
341    checkEvaluation(cast(cast(cast(nts, TimestampType), DateType), StringType), sd)
342    checkEvaluation(cast(cast(cast(ts, DateType), TimestampType), StringType), zts)
343
344    checkEvaluation(cast(cast("abdef", BinaryType), StringType), "abdef")
345
346    checkEvaluation(cast(cast(cast(cast(
347      cast(cast("5", ByteType), ShortType), IntegerType), FloatType), DoubleType), LongType),
348      5.toLong)
349    checkEvaluation(
350      cast(cast(cast(cast(cast(cast("5", ByteType), TimestampType),
351        DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
352      5.toShort)
353    checkEvaluation(
354      cast(cast(cast(cast(cast(cast("5", TimestampType), ByteType),
355        DecimalType.SYSTEM_DEFAULT), LongType), StringType), ShortType),
356      null)
357    checkEvaluation(cast(cast(cast(cast(cast(cast("5", DecimalType.SYSTEM_DEFAULT),
358      ByteType), TimestampType), LongType), StringType), ShortType),
359      5.toShort)
360
361    checkEvaluation(cast("23", DoubleType), 23d)
362    checkEvaluation(cast("23", IntegerType), 23)
363    checkEvaluation(cast("23", FloatType), 23f)
364    checkEvaluation(cast("23", DecimalType.USER_DEFAULT), Decimal(23))
365    checkEvaluation(cast("23", ByteType), 23.toByte)
366    checkEvaluation(cast("23", ShortType), 23.toShort)
367    checkEvaluation(cast("2012-12-11", DoubleType), null)
368    checkEvaluation(cast(123, IntegerType), 123)
369
370    checkEvaluation(cast(Literal.create(null, IntegerType), ShortType), null)
371  }
372
373  test("cast and add") {
374    checkEvaluation(Add(Literal(23d), cast(true, DoubleType)), 24d)
375    checkEvaluation(Add(Literal(23), cast(true, IntegerType)), 24)
376    checkEvaluation(Add(Literal(23f), cast(true, FloatType)), 24f)
377    checkEvaluation(Add(Literal(Decimal(23)), cast(true, DecimalType.USER_DEFAULT)), Decimal(24))
378    checkEvaluation(Add(Literal(23.toByte), cast(true, ByteType)), 24.toByte)
379    checkEvaluation(Add(Literal(23.toShort), cast(true, ShortType)), 24.toShort)
380  }
381
382  test("from decimal") {
383    checkCast(Decimal(0.0), false)
384    checkCast(Decimal(0.5), true)
385    checkCast(Decimal(-5.0), true)
386    checkCast(Decimal(1.5), 1.toByte)
387    checkCast(Decimal(1.5), 1.toShort)
388    checkCast(Decimal(1.5), 1)
389    checkCast(Decimal(1.5), 1.toLong)
390    checkCast(Decimal(1.5), 1.5f)
391    checkCast(Decimal(1.5), 1.5)
392    checkCast(Decimal(1.5), "1.5")
393  }
394
395  test("casting to fixed-precision decimals") {
396    // Overflow and rounding for casting to fixed-precision decimals:
397    // - Values should round with HALF_UP mode by default when you lower scale
398    // - Values that would overflow the target precision should turn into null
399    // - Because of this, casts to fixed-precision decimals should be nullable
400
401    assert(cast(123, DecimalType.USER_DEFAULT).nullable === true)
402    assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true)
403    assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true)
404    assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true)
405
406    assert(cast(123, DecimalType(2, 1)).nullable === true)
407    assert(cast(10.03f, DecimalType(2, 1)).nullable === true)
408    assert(cast(10.03, DecimalType(2, 1)).nullable === true)
409    assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true)
410
411
412    checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03))
413    checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03))
414    checkEvaluation(cast(10.03, DecimalType(3, 1)), Decimal(10.0))
415    checkEvaluation(cast(10.03, DecimalType(2, 0)), Decimal(10))
416    checkEvaluation(cast(10.03, DecimalType(1, 0)), null)
417    checkEvaluation(cast(10.03, DecimalType(2, 1)), null)
418    checkEvaluation(cast(10.03, DecimalType(3, 2)), null)
419    checkEvaluation(cast(Decimal(10.03), DecimalType(3, 1)), Decimal(10.0))
420    checkEvaluation(cast(Decimal(10.03), DecimalType(3, 2)), null)
421
422    checkEvaluation(cast(10.05, DecimalType.SYSTEM_DEFAULT), Decimal(10.05))
423    checkEvaluation(cast(10.05, DecimalType(4, 2)), Decimal(10.05))
424    checkEvaluation(cast(10.05, DecimalType(3, 1)), Decimal(10.1))
425    checkEvaluation(cast(10.05, DecimalType(2, 0)), Decimal(10))
426    checkEvaluation(cast(10.05, DecimalType(1, 0)), null)
427    checkEvaluation(cast(10.05, DecimalType(2, 1)), null)
428    checkEvaluation(cast(10.05, DecimalType(3, 2)), null)
429    checkEvaluation(cast(Decimal(10.05), DecimalType(3, 1)), Decimal(10.1))
430    checkEvaluation(cast(Decimal(10.05), DecimalType(3, 2)), null)
431
432    checkEvaluation(cast(9.95, DecimalType(3, 2)), Decimal(9.95))
433    checkEvaluation(cast(9.95, DecimalType(3, 1)), Decimal(10.0))
434    checkEvaluation(cast(9.95, DecimalType(2, 0)), Decimal(10))
435    checkEvaluation(cast(9.95, DecimalType(2, 1)), null)
436    checkEvaluation(cast(9.95, DecimalType(1, 0)), null)
437    checkEvaluation(cast(Decimal(9.95), DecimalType(3, 1)), Decimal(10.0))
438    checkEvaluation(cast(Decimal(9.95), DecimalType(1, 0)), null)
439
440    checkEvaluation(cast(-9.95, DecimalType(3, 2)), Decimal(-9.95))
441    checkEvaluation(cast(-9.95, DecimalType(3, 1)), Decimal(-10.0))
442    checkEvaluation(cast(-9.95, DecimalType(2, 0)), Decimal(-10))
443    checkEvaluation(cast(-9.95, DecimalType(2, 1)), null)
444    checkEvaluation(cast(-9.95, DecimalType(1, 0)), null)
445    checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0))
446    checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null)
447
448    checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null)
449    checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null)
450    checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null)
451    checkEvaluation(cast(1.0f / 0.0f, DecimalType.SYSTEM_DEFAULT), null)
452
453    checkEvaluation(cast(Double.NaN, DecimalType(2, 1)), null)
454    checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null)
455    checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null)
456    checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null)
457  }
458
459  test("cast from date") {
460    val d = Date.valueOf("1970-01-01")
461    checkEvaluation(cast(d, ShortType), null)
462    checkEvaluation(cast(d, IntegerType), null)
463    checkEvaluation(cast(d, LongType), null)
464    checkEvaluation(cast(d, FloatType), null)
465    checkEvaluation(cast(d, DoubleType), null)
466    checkEvaluation(cast(d, DecimalType.SYSTEM_DEFAULT), null)
467    checkEvaluation(cast(d, DecimalType(10, 2)), null)
468    checkEvaluation(cast(d, StringType), "1970-01-01")
469    checkEvaluation(cast(cast(d, TimestampType), StringType), "1970-01-01 00:00:00")
470  }
471
472  test("cast from timestamp") {
473    val millis = 15 * 1000 + 3
474    val seconds = millis * 1000 + 3
475    val ts = new Timestamp(millis)
476    val tss = new Timestamp(seconds)
477    checkEvaluation(cast(ts, ShortType), 15.toShort)
478    checkEvaluation(cast(ts, IntegerType), 15)
479    checkEvaluation(cast(ts, LongType), 15.toLong)
480    checkEvaluation(cast(ts, FloatType), 15.003f)
481    checkEvaluation(cast(ts, DoubleType), 15.003)
482    checkEvaluation(cast(cast(tss, ShortType), TimestampType),
483      DateTimeUtils.fromJavaTimestamp(ts) * 1000)
484    checkEvaluation(cast(cast(tss, IntegerType), TimestampType),
485      DateTimeUtils.fromJavaTimestamp(ts) * 1000)
486    checkEvaluation(cast(cast(tss, LongType), TimestampType),
487      DateTimeUtils.fromJavaTimestamp(ts) * 1000)
488    checkEvaluation(
489      cast(cast(millis.toFloat / 1000, TimestampType), FloatType),
490      millis.toFloat / 1000)
491    checkEvaluation(
492      cast(cast(millis.toDouble / 1000, TimestampType), DoubleType),
493      millis.toDouble / 1000)
494    checkEvaluation(
495      cast(cast(Decimal(1), TimestampType), DecimalType.SYSTEM_DEFAULT),
496      Decimal(1))
497
498    // A test for higher precision than millis
499    checkEvaluation(cast(cast(0.000001, TimestampType), DoubleType), 0.000001)
500
501    checkEvaluation(cast(Double.NaN, TimestampType), null)
502    checkEvaluation(cast(1.0 / 0.0, TimestampType), null)
503    checkEvaluation(cast(Float.NaN, TimestampType), null)
504    checkEvaluation(cast(1.0f / 0.0f, TimestampType), null)
505  }
506
507  test("cast from array") {
508    val array = Literal.create(Seq("123", "true", "f", null),
509      ArrayType(StringType, containsNull = true))
510    val array_notNull = Literal.create(Seq("123", "true", "f"),
511      ArrayType(StringType, containsNull = false))
512
513    checkNullCast(ArrayType(StringType), ArrayType(IntegerType))
514
515    {
516      val ret = cast(array, ArrayType(IntegerType, containsNull = true))
517      assert(ret.resolved === true)
518      checkEvaluation(ret, Seq(123, null, null, null))
519    }
520    {
521      val ret = cast(array, ArrayType(IntegerType, containsNull = false))
522      assert(ret.resolved === false)
523    }
524    {
525      val ret = cast(array, ArrayType(BooleanType, containsNull = true))
526      assert(ret.resolved === true)
527      checkEvaluation(ret, Seq(null, true, false, null))
528    }
529    {
530      val ret = cast(array, ArrayType(BooleanType, containsNull = false))
531      assert(ret.resolved === false)
532    }
533
534    {
535      val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = true))
536      assert(ret.resolved === true)
537      checkEvaluation(ret, Seq(123, null, null))
538    }
539    {
540      val ret = cast(array_notNull, ArrayType(IntegerType, containsNull = false))
541      assert(ret.resolved === false)
542    }
543    {
544      val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = true))
545      assert(ret.resolved === true)
546      checkEvaluation(ret, Seq(null, true, false))
547    }
548    {
549      val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false))
550      assert(ret.resolved === false)
551    }
552
553    {
554      val ret = cast(array, IntegerType)
555      assert(ret.resolved === false)
556    }
557  }
558
559  test("cast from map") {
560    val map = Literal.create(
561      Map("a" -> "123", "b" -> "true", "c" -> "f", "d" -> null),
562      MapType(StringType, StringType, valueContainsNull = true))
563    val map_notNull = Literal.create(
564      Map("a" -> "123", "b" -> "true", "c" -> "f"),
565      MapType(StringType, StringType, valueContainsNull = false))
566
567    checkNullCast(MapType(StringType, IntegerType), MapType(StringType, StringType))
568
569    {
570      val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = true))
571      assert(ret.resolved === true)
572      checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null, "d" -> null))
573    }
574    {
575      val ret = cast(map, MapType(StringType, IntegerType, valueContainsNull = false))
576      assert(ret.resolved === false)
577    }
578    {
579      val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = true))
580      assert(ret.resolved === true)
581      checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false, "d" -> null))
582    }
583    {
584      val ret = cast(map, MapType(StringType, BooleanType, valueContainsNull = false))
585      assert(ret.resolved === false)
586    }
587    {
588      val ret = cast(map, MapType(IntegerType, StringType, valueContainsNull = true))
589      assert(ret.resolved === false)
590    }
591
592    {
593      val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = true))
594      assert(ret.resolved === true)
595      checkEvaluation(ret, Map("a" -> 123, "b" -> null, "c" -> null))
596    }
597    {
598      val ret = cast(map_notNull, MapType(StringType, IntegerType, valueContainsNull = false))
599      assert(ret.resolved === false)
600    }
601    {
602      val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = true))
603      assert(ret.resolved === true)
604      checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false))
605    }
606    {
607      val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false))
608      assert(ret.resolved === false)
609    }
610    {
611      val ret = cast(map_notNull, MapType(IntegerType, StringType, valueContainsNull = true))
612      assert(ret.resolved === false)
613    }
614
615    {
616      val ret = cast(map, IntegerType)
617      assert(ret.resolved === false)
618    }
619  }
620
621  test("cast from struct") {
622    checkNullCast(
623      StructType(Seq(
624        StructField("a", StringType),
625        StructField("b", IntegerType))),
626      StructType(Seq(
627        StructField("a", StringType),
628        StructField("b", StringType))))
629
630    val struct = Literal.create(
631      InternalRow(
632        UTF8String.fromString("123"),
633        UTF8String.fromString("true"),
634        UTF8String.fromString("f"),
635        null),
636      StructType(Seq(
637        StructField("a", StringType, nullable = true),
638        StructField("b", StringType, nullable = true),
639        StructField("c", StringType, nullable = true),
640        StructField("d", StringType, nullable = true))))
641    val struct_notNull = Literal.create(
642      InternalRow(
643        UTF8String.fromString("123"),
644        UTF8String.fromString("true"),
645        UTF8String.fromString("f")),
646      StructType(Seq(
647        StructField("a", StringType, nullable = false),
648        StructField("b", StringType, nullable = false),
649        StructField("c", StringType, nullable = false))))
650
651    {
652      val ret = cast(struct, StructType(Seq(
653        StructField("a", IntegerType, nullable = true),
654        StructField("b", IntegerType, nullable = true),
655        StructField("c", IntegerType, nullable = true),
656        StructField("d", IntegerType, nullable = true))))
657      assert(ret.resolved === true)
658      checkEvaluation(ret, InternalRow(123, null, null, null))
659    }
660    {
661      val ret = cast(struct, StructType(Seq(
662        StructField("a", IntegerType, nullable = true),
663        StructField("b", IntegerType, nullable = true),
664        StructField("c", IntegerType, nullable = false),
665        StructField("d", IntegerType, nullable = true))))
666      assert(ret.resolved === false)
667    }
668    {
669      val ret = cast(struct, StructType(Seq(
670        StructField("a", BooleanType, nullable = true),
671        StructField("b", BooleanType, nullable = true),
672        StructField("c", BooleanType, nullable = true),
673        StructField("d", BooleanType, nullable = true))))
674      assert(ret.resolved === true)
675      checkEvaluation(ret, InternalRow(null, true, false, null))
676    }
677    {
678      val ret = cast(struct, StructType(Seq(
679        StructField("a", BooleanType, nullable = true),
680        StructField("b", BooleanType, nullable = true),
681        StructField("c", BooleanType, nullable = false),
682        StructField("d", BooleanType, nullable = true))))
683      assert(ret.resolved === false)
684    }
685
686    {
687      val ret = cast(struct_notNull, StructType(Seq(
688        StructField("a", IntegerType, nullable = true),
689        StructField("b", IntegerType, nullable = true),
690        StructField("c", IntegerType, nullable = true))))
691      assert(ret.resolved === true)
692      checkEvaluation(ret, InternalRow(123, null, null))
693    }
694    {
695      val ret = cast(struct_notNull, StructType(Seq(
696        StructField("a", IntegerType, nullable = true),
697        StructField("b", IntegerType, nullable = true),
698        StructField("c", IntegerType, nullable = false))))
699      assert(ret.resolved === false)
700    }
701    {
702      val ret = cast(struct_notNull, StructType(Seq(
703        StructField("a", BooleanType, nullable = true),
704        StructField("b", BooleanType, nullable = true),
705        StructField("c", BooleanType, nullable = true))))
706      assert(ret.resolved === true)
707      checkEvaluation(ret, InternalRow(null, true, false))
708    }
709    {
710      val ret = cast(struct_notNull, StructType(Seq(
711        StructField("a", BooleanType, nullable = true),
712        StructField("b", BooleanType, nullable = true),
713        StructField("c", BooleanType, nullable = false))))
714      assert(ret.resolved === false)
715    }
716
717    {
718      val ret = cast(struct, StructType(Seq(
719        StructField("a", StringType, nullable = true),
720        StructField("b", StringType, nullable = true),
721        StructField("c", StringType, nullable = true))))
722      assert(ret.resolved === false)
723    }
724    {
725      val ret = cast(struct, IntegerType)
726      assert(ret.resolved === false)
727    }
728  }
729
730  test("cast struct with a timestamp field") {
731    val originalSchema = new StructType().add("tsField", TimestampType, nullable = false)
732    // nine out of ten times I'm casting a struct, it's to normalize its fields nullability
733    val targetSchema = new StructType().add("tsField", TimestampType, nullable = true)
734
735    val inp = Literal.create(InternalRow(0L), originalSchema)
736    val expected = InternalRow(0L)
737    checkEvaluation(cast(inp, targetSchema), expected)
738  }
739
740  test("complex casting") {
741    val complex = Literal.create(
742      Row(
743        Seq("123", "true", "f"),
744        Map("a" -> "123", "b" -> "true", "c" -> "f"),
745        Row(0)),
746      StructType(Seq(
747        StructField("a",
748          ArrayType(StringType, containsNull = false), nullable = true),
749        StructField("m",
750          MapType(StringType, StringType, valueContainsNull = false), nullable = true),
751        StructField("s",
752          StructType(Seq(
753            StructField("i", IntegerType, nullable = true)))))))
754
755    val ret = cast(complex, StructType(Seq(
756      StructField("a",
757        ArrayType(IntegerType, containsNull = true), nullable = true),
758      StructField("m",
759        MapType(StringType, BooleanType, valueContainsNull = false), nullable = true),
760      StructField("s",
761        StructType(Seq(
762          StructField("l", LongType, nullable = true)))))))
763
764    assert(ret.resolved === false)
765  }
766
767  test("cast between string and interval") {
768    import org.apache.spark.unsafe.types.CalendarInterval
769
770    checkEvaluation(Cast(Literal(""), CalendarIntervalType), null)
771    checkEvaluation(Cast(Literal("interval -3 month 7 hours"), CalendarIntervalType),
772      new CalendarInterval(-3, 7 * CalendarInterval.MICROS_PER_HOUR))
773    checkEvaluation(Cast(Literal.create(
774      new CalendarInterval(15, -3 * CalendarInterval.MICROS_PER_DAY), CalendarIntervalType),
775      StringType),
776      "interval 1 years 3 months -3 days")
777  }
778
779  test("cast string to boolean") {
780    checkCast("t", true)
781    checkCast("true", true)
782    checkCast("tRUe", true)
783    checkCast("y", true)
784    checkCast("yes", true)
785    checkCast("1", true)
786
787    checkCast("f", false)
788    checkCast("false", false)
789    checkCast("FAlsE", false)
790    checkCast("n", false)
791    checkCast("no", false)
792    checkCast("0", false)
793
794    checkEvaluation(cast("abc", BooleanType), null)
795    checkEvaluation(cast("", BooleanType), null)
796  }
797
798  test("SPARK-16729 type checking for casting to date type") {
799    assert(cast("1234", DateType).checkInputDataTypes().isSuccess)
800    assert(cast(new Timestamp(1), DateType).checkInputDataTypes().isSuccess)
801    assert(cast(false, DateType).checkInputDataTypes().isFailure)
802    assert(cast(1.toByte, DateType).checkInputDataTypes().isFailure)
803    assert(cast(1.toShort, DateType).checkInputDataTypes().isFailure)
804    assert(cast(1, DateType).checkInputDataTypes().isFailure)
805    assert(cast(1L, DateType).checkInputDataTypes().isFailure)
806    assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
807    assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
808  }
809}
810