1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package dbutil
5
6import (
7	"database/sql"
8	"database/sql/driver"
9	"time"
10
11	"github.com/zeebo/errs"
12)
13
// Layouts used by parseSqliteTimeString to decode time values that sqlite
// returns as text.
const (
	// sqliteTimeLayout matches a date and time followed by a numeric UTC offset.
	sqliteTimeLayout = "2006-01-02 15:04:05-07:00"

	// sqliteTimeLayoutNoTimeZone matches a date and time with no zone information.
	sqliteTimeLayoutNoTimeZone = "2006-01-02 15:04:05"

	// sqliteTimeLayoutDate matches a bare calendar date.
	sqliteTimeLayoutDate = "2006-01-02"
)
19
20// ErrNullTime defines error class for NullTime.
21var ErrNullTime = errs.Class("null time")
22
// NullTime is a time.Time that can also represent a database NULL.
// It implements sql.Scanner and driver.Valuer; Valid is false when the
// value is NULL. The embedded time.Time's methods are promoted onto
// NullTime.
type NullTime struct {
	time.Time
	// Valid reports whether Time holds a non-NULL value.
	Valid bool
}
28
29// Scan implements the Scanner interface.
30func (nt *NullTime) Scan(value interface{}) error {
31	nt.Time, nt.Valid = time.Time{}, false
32
33	if value == nil {
34		return nil
35	}
36
37	switch v := value.(type) {
38	// Postgres could return for lagged time values.
39	case time.Time:
40		nt.Time, nt.Valid = v, true
41
42		// Database could return for nullable time values.
43	case sql.NullTime:
44		nt.Time, nt.Valid = v.Time, v.Valid
45
46		// Sqlite may return this.
47	case string:
48		parsed, err := parseSqliteTimeString(v)
49		if err != nil {
50			return ErrNullTime.Wrap(err)
51		}
52		nt.Time, nt.Valid = parsed, true
53
54		// Sqlite may return this.
55	case []byte:
56		parsed, err := parseSqliteTimeString(string(v))
57		if err != nil {
58			return ErrNullTime.Wrap(err)
59		}
60		nt.Time, nt.Valid = parsed, true
61
62	default:
63		return ErrNullTime.New("scan received unsupported value %T", value)
64	}
65
66	return nil
67}
68
69// Value implements the driver Valuer interface.
70func (nt NullTime) Value() (driver.Value, error) {
71	if !nt.Valid {
72		return nil, nil
73	}
74	return nt.Time, nil
75}
76
77// parseSqliteTimeString parses sqlite times string.
78// It tries to process value as string with timezone first,
79// then fallback to parsing as string without timezone and
80// finally to parsing value as date.
81func parseSqliteTimeString(val string) (time.Time, error) {
82	var times time.Time
83	var err error
84
85	times, err = time.Parse(sqliteTimeLayout, val)
86	if err == nil {
87		return times, nil
88	}
89
90	times, err = time.Parse(sqliteTimeLayoutNoTimeZone, val)
91	if err == nil {
92		return times, nil
93	}
94
95	return time.Parse(sqliteTimeLayoutDate, val)
96}
97