1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package types
16
17import (
18	"fmt"
19	"reflect"
20	"strconv"
21	"time"
22
23	"github.com/golang/protobuf/proto"
24	"github.com/golang/protobuf/ptypes"
25
26	"github.com/google/cel-go/common/overloads"
27	"github.com/google/cel-go/common/types/ref"
28	"github.com/google/cel-go/common/types/traits"
29
30	dpb "github.com/golang/protobuf/ptypes/duration"
31	structpb "github.com/golang/protobuf/ptypes/struct"
32)
33
// Duration type that implements ref.Val and supports add, compare, negate,
// and subtract operators. This type is also a receiver which means it can
// participate in dispatch to receiver functions.
//
// The value wraps a *dpb.Duration protobuf message, which Value returns
// as-is. Operations that need arithmetic first convert it to a
// time.Duration with ptypes.Duration and propagate that conversion's error.
type Duration struct {
	// Embedded protobuf duration; embedding promotes its fields and methods
	// onto Duration.
	*dpb.Duration
}
40
41var (
42	// DurationType singleton.
43	DurationType = NewTypeValue("google.protobuf.Duration",
44		traits.AdderType,
45		traits.ComparerType,
46		traits.NegatorType,
47		traits.ReceiverType,
48		traits.SubtractorType)
49)
50
51// Add implements traits.Adder.Add.
52func (d Duration) Add(other ref.Val) ref.Val {
53	switch other.Type() {
54	case DurationType:
55		dur1, err := ptypes.Duration(d.Duration)
56		if err != nil {
57			return &Err{err}
58		}
59		dur2, err := ptypes.Duration(other.(Duration).Duration)
60		if err != nil {
61			return &Err{err}
62		}
63		return Duration{ptypes.DurationProto(dur1 + dur2)}
64	case TimestampType:
65		dur, err := ptypes.Duration(d.Duration)
66		if err != nil {
67			return &Err{err}
68		}
69		ts, err := ptypes.Timestamp(other.(Timestamp).Timestamp)
70		if err != nil {
71			return &Err{err}
72		}
73		tstamp, err := ptypes.TimestampProto(ts.Add(dur))
74		if err != nil {
75			return &Err{err}
76		}
77		return Timestamp{tstamp}
78	}
79	return ValOrErr(other, "no such overload")
80}
81
82// Compare implements traits.Comparer.Compare.
83func (d Duration) Compare(other ref.Val) ref.Val {
84	otherDur, ok := other.(Duration)
85	if !ok {
86		return ValOrErr(other, "no such overload")
87	}
88	dur1, err := ptypes.Duration(d.Duration)
89	if err != nil {
90		return &Err{err}
91	}
92	dur2, err := ptypes.Duration(otherDur.Duration)
93	if err != nil {
94		return &Err{err}
95	}
96	dur := dur1 - dur2
97	if dur < 0 {
98		return IntNegOne
99	}
100	if dur > 0 {
101		return IntOne
102	}
103	return IntZero
104}
105
106// ConvertToNative implements ref.Val.ConvertToNative.
107func (d Duration) ConvertToNative(typeDesc reflect.Type) (interface{}, error) {
108	switch typeDesc {
109	case anyValueType:
110		// Pack the underlying proto value into an Any value.
111		return ptypes.MarshalAny(d.Value().(*dpb.Duration))
112	case durationValueType:
113		// Unwrap the CEL value to its underlying proto value.
114		return d.Value(), nil
115	case jsonValueType:
116		// CEL follows the proto3 to JSON conversion.
117		// Note, using jsonpb would wrap the result in extra double quotes.
118		v := d.ConvertToType(StringType)
119		if IsError(v) {
120			return nil, v.(*Err)
121		}
122		return &structpb.Value{
123			Kind: &structpb.Value_StringValue{StringValue: string(v.(String))},
124		}, nil
125	}
126	// If the duration is already assignable to the desired type return it.
127	if reflect.TypeOf(d).AssignableTo(typeDesc) {
128		return d, nil
129	}
130	return nil, fmt.Errorf("type conversion error from "+
131		"'google.protobuf.Duration' to '%v'", typeDesc)
132}
133
134// ConvertToType implements ref.Val.ConvertToType.
135func (d Duration) ConvertToType(typeVal ref.Type) ref.Val {
136	switch typeVal {
137	case StringType:
138		if dur, err := ptypes.Duration(d.Duration); err == nil {
139			return String(strconv.FormatFloat(dur.Seconds(), 'f', -1, 64) + "s")
140		}
141	case IntType:
142		if dur, err := ptypes.Duration(d.Duration); err == nil {
143			return Int(dur)
144		}
145	case DurationType:
146		return d
147	case TypeType:
148		return DurationType
149	}
150	return NewErr("type conversion error from '%s' to '%s'", DurationType, typeVal)
151}
152
153// Equal implements ref.Val.Equal.
154func (d Duration) Equal(other ref.Val) ref.Val {
155	otherDur, ok := other.(Duration)
156	if !ok {
157		return ValOrErr(other, "no such overload")
158	}
159	return Bool(proto.Equal(d.Duration, otherDur.Value().(proto.Message)))
160}
161
162// Negate implements traits.Negater.Negate.
163func (d Duration) Negate() ref.Val {
164	dur, err := ptypes.Duration(d.Duration)
165	if err != nil {
166		return &Err{err}
167	}
168	return Duration{ptypes.DurationProto(-dur)}
169}
170
171// Receive implements traits.Receiver.Receive.
172func (d Duration) Receive(function string, overload string, args []ref.Val) ref.Val {
173	dur, err := ptypes.Duration(d.Duration)
174	if err != nil {
175		return &Err{err}
176	}
177	if len(args) == 0 {
178		if f, found := durationZeroArgOverloads[function]; found {
179			return f(dur)
180		}
181	}
182	return NewErr("no such overload")
183}
184
185// Subtract implements traits.Subtractor.Subtract.
186func (d Duration) Subtract(subtrahend ref.Val) ref.Val {
187	subtraDur, ok := subtrahend.(Duration)
188	if !ok {
189		return ValOrErr(subtrahend, "no such overload")
190	}
191	return d.Add(subtraDur.Negate())
192}
193
194// Type implements ref.Val.Type.
195func (d Duration) Type() ref.Type {
196	return DurationType
197}
198
199// Value implements ref.Val.Value.
200func (d Duration) Value() interface{} {
201	return d.Duration
202}
203
204var (
205	durationValueType = reflect.TypeOf(&dpb.Duration{})
206
207	durationZeroArgOverloads = map[string]func(time.Duration) ref.Val{
208		overloads.TimeGetHours: func(dur time.Duration) ref.Val {
209			return Int(dur.Hours())
210		},
211		overloads.TimeGetMinutes: func(dur time.Duration) ref.Val {
212			return Int(dur.Minutes())
213		},
214		overloads.TimeGetSeconds: func(dur time.Duration) ref.Val {
215			return Int(dur.Seconds())
216		},
217		overloads.TimeGetMilliseconds: func(dur time.Duration) ref.Val {
218			return Int(dur.Nanoseconds() / 1000000)
219		}}
220)
221