1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package bigquery
16
17import (
18	"context"
19	"fmt"
20	"time"
21
22	"cloud.google.com/go/internal/optional"
23	"cloud.google.com/go/internal/trace"
24	bq "google.golang.org/api/bigquery/v2"
25)
26
// Model represents a reference to a BigQuery ML model.
// Within the API, models are used largely for communicating
// statistical information about a given model, as creation of models is only
// supported via BigQuery queries (e.g. CREATE MODEL .. AS ..).
//
// For more info, see the documentation for BigQuery ML:
// https://cloud.google.com/bigquery/docs/bigqueryml
type Model struct {
	// ProjectID is the project component of the model's fully qualified name.
	ProjectID string
	// DatasetID is the dataset component of the model's fully qualified name.
	DatasetID string
	// ModelID must contain only letters (a-z, A-Z), numbers (0-9), or underscores (_).
	// The maximum length is 1,024 characters.
	ModelID string

	// c is the client through which the model's API calls are issued.
	c *Client
}
43
44// FullyQualifiedName returns the ID of the model in projectID:datasetID.modelid format.
45func (m *Model) FullyQualifiedName() string {
46	return fmt.Sprintf("%s:%s.%s", m.ProjectID, m.DatasetID, m.ModelID)
47}
48
49func (m *Model) toBQ() *bq.ModelReference {
50	return &bq.ModelReference{
51		ProjectId: m.ProjectID,
52		DatasetId: m.DatasetID,
53		ModelId:   m.ModelID,
54	}
55}
56
57// Metadata fetches the metadata for a model, which includes ML training statistics.
58func (m *Model) Metadata(ctx context.Context) (mm *ModelMetadata, err error) {
59	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Metadata")
60	defer func() { trace.EndSpan(ctx, err) }()
61
62	req := m.c.bqs.Models.Get(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
63	setClientHeader(req.Header())
64	var model *bq.Model
65	err = runWithRetry(ctx, func() (err error) {
66		model, err = req.Do()
67		return err
68	})
69	if err != nil {
70		return nil, err
71	}
72	return bqToModelMetadata(model)
73}
74
75// Update updates mutable fields in an ML model.
76func (m *Model) Update(ctx context.Context, mm ModelMetadataToUpdate, etag string) (md *ModelMetadata, err error) {
77	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Update")
78	defer func() { trace.EndSpan(ctx, err) }()
79
80	bqm, err := mm.toBQ()
81	if err != nil {
82		return nil, err
83	}
84	call := m.c.bqs.Models.Patch(m.ProjectID, m.DatasetID, m.ModelID, bqm).Context(ctx)
85	setClientHeader(call.Header())
86	if etag != "" {
87		call.Header().Set("If-Match", etag)
88	}
89	var res *bq.Model
90	if err := runWithRetry(ctx, func() (err error) {
91		res, err = call.Do()
92		return err
93	}); err != nil {
94		return nil, err
95	}
96	return bqToModelMetadata(res)
97}
98
99// Delete deletes an ML model.
100func (m *Model) Delete(ctx context.Context) (err error) {
101	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Delete")
102	defer func() { trace.EndSpan(ctx, err) }()
103
104	req := m.c.bqs.Models.Delete(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
105	setClientHeader(req.Header())
106	return req.Do()
107}
108
// ModelMetadata represents information about a BigQuery ML model.
type ModelMetadata struct {
	// The user-friendly description of the model.
	Description string

	// The user-friendly name of the model.
	Name string

	// The type of the model.  Possible values include:
	// "LINEAR_REGRESSION" - a linear regression model
	// "LOGISTIC_REGRESSION" - a logistic regression model
	// "KMEANS" - a k-means clustering model
	Type string

	// The creation time of the model.
	CreationTime time.Time

	// The last modified time of the model.
	LastModifiedTime time.Time

	// The expiration time of the model.
	ExpirationTime time.Time

	// The geographic location where the model resides.  This value is
	// inherited from the encapsulating dataset.
	Location string

	// Custom encryption configuration (e.g., Cloud KMS keys).
	EncryptionConfig *EncryptionConfig

	// The input feature columns used to train the model.
	// Unexported; read through RawFeatureColumns.
	featureColumns []*bq.StandardSqlField

	// The label columns used to train the model.  Output
	// from the model will have a "predicted_" prefix for these columns.
	// Unexported; read through RawLabelColumns.
	labelColumns []*bq.StandardSqlField

	// Information for all training runs, ordered by increasing start times.
	// Unexported; read through RawTrainingRuns.
	trainingRuns []*bq.TrainingRun

	// The labels attached to the model, as key/value pairs.
	Labels map[string]string

	// ETag is the ETag obtained when reading metadata. Pass it to Model.Update
	// to ensure that the metadata hasn't changed since it was read.
	ETag string
}
155
// TrainingRun represents information about a single training run for a BigQuery ML model.
// Its underlying type is TrainingRun from "google.golang.org/api/bigquery/v2".
//
// Experimental: This information may be modified or removed in future versions of this package.
type TrainingRun bq.TrainingRun
159
160// RawTrainingRuns exposes the underlying training run stats for a model using types from
161// "google.golang.org/api/bigquery/v2", which are subject to change without warning.
162// It is EXPERIMENTAL and subject to change or removal without notice.
163func (mm *ModelMetadata) RawTrainingRuns() []*TrainingRun {
164	if mm.trainingRuns == nil {
165		return nil
166	}
167	var runs []*TrainingRun
168
169	for _, v := range mm.trainingRuns {
170		r := TrainingRun(*v)
171		runs = append(runs, &r)
172	}
173	return runs
174}
175
176// RawLabelColumns exposes the underlying label columns used to train an ML model and uses types from
177// "google.golang.org/api/bigquery/v2", which are subject to change without warning.
178// It is EXPERIMENTAL and subject to change or removal without notice.
179func (mm *ModelMetadata) RawLabelColumns() ([]*StandardSQLField, error) {
180	return bqToModelCols(mm.labelColumns)
181}
182
183// RawFeatureColumns exposes the underlying feature columns used to train an ML model and uses types from
184// "google.golang.org/api/bigquery/v2", which are subject to change without warning.
185// It is EXPERIMENTAL and subject to change or removal without notice.
186func (mm *ModelMetadata) RawFeatureColumns() ([]*StandardSQLField, error) {
187	return bqToModelCols(mm.featureColumns)
188}
189
190func bqToModelCols(s []*bq.StandardSqlField) ([]*StandardSQLField, error) {
191	if s == nil {
192		return nil, nil
193	}
194	var cols []*StandardSQLField
195	for _, v := range s {
196		c, err := bqToStandardSQLField(v)
197		if err != nil {
198			return nil, err
199		}
200		cols = append(cols, c)
201	}
202	return cols, nil
203}
204
205func bqToModelMetadata(m *bq.Model) (*ModelMetadata, error) {
206	md := &ModelMetadata{
207		Description:      m.Description,
208		Name:             m.FriendlyName,
209		Type:             m.ModelType,
210		Location:         m.Location,
211		Labels:           m.Labels,
212		ExpirationTime:   unixMillisToTime(m.ExpirationTime),
213		CreationTime:     unixMillisToTime(m.CreationTime),
214		LastModifiedTime: unixMillisToTime(m.LastModifiedTime),
215		EncryptionConfig: bqToEncryptionConfig(m.EncryptionConfiguration),
216		featureColumns:   m.FeatureColumns,
217		labelColumns:     m.LabelColumns,
218		trainingRuns:     m.TrainingRuns,
219		ETag:             m.Etag,
220	}
221	return md, nil
222}
223
// ModelMetadataToUpdate is used when updating an ML model's metadata.
// Only non-nil fields will be updated.
type ModelMetadataToUpdate struct {
	// The user-friendly description of this model.
	// Left nil, the description is not sent in the update.
	Description optional.String

	// The user-friendly name of this model.
	// Left nil, the name is not sent in the update.
	Name optional.String

	// The time when this model expires.  To remove a model's expiration,
	// set ExpirationTime to NeverExpire.  The zero value is ignored.
	ExpirationTime time.Time

	// The model's encryption configuration.
	// Left nil, the encryption configuration is not sent in the update.
	EncryptionConfig *EncryptionConfig

	// labelUpdater is embedded; toBQ obtains the label changes to send from
	// its update method.
	// NOTE(review): labelUpdater is declared elsewhere in the package —
	// confirm which label-editing methods it promotes.
	labelUpdater
}
242
243func (mm *ModelMetadataToUpdate) toBQ() (*bq.Model, error) {
244	m := &bq.Model{}
245	forceSend := func(field string) {
246		m.ForceSendFields = append(m.ForceSendFields, field)
247	}
248
249	if mm.Description != nil {
250		m.Description = optional.ToString(mm.Description)
251		forceSend("Description")
252	}
253
254	if mm.Name != nil {
255		m.FriendlyName = optional.ToString(mm.Name)
256		forceSend("FriendlyName")
257	}
258
259	if mm.EncryptionConfig != nil {
260		m.EncryptionConfiguration = mm.EncryptionConfig.toBQ()
261	}
262
263	if !validExpiration(mm.ExpirationTime) {
264		return nil, invalidTimeError(mm.ExpirationTime)
265	}
266	if mm.ExpirationTime == NeverExpire {
267		m.NullFields = append(m.NullFields, "ExpirationTime")
268	} else if !mm.ExpirationTime.IsZero() {
269		m.ExpirationTime = mm.ExpirationTime.UnixNano() / 1e6
270		forceSend("ExpirationTime")
271	}
272	labels, forces, nulls := mm.update()
273	m.Labels = labels
274	m.ForceSendFields = append(m.ForceSendFields, forces...)
275	m.NullFields = append(m.NullFields, nulls...)
276	return m, nil
277}
278