1// Copyright 2019 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package bigquery
16
17import (
18	"context"
19	"fmt"
20	"time"
21
22	"cloud.google.com/go/internal/optional"
23	"cloud.google.com/go/internal/trace"
24	bq "google.golang.org/api/bigquery/v2"
25)
26
// Model represents a reference to a BigQuery ML model.
// Within the API, models are used largely for communicating
// statistical information about a given model, as creation of models is only
// supported via BigQuery queries (e.g. CREATE MODEL .. AS ..).
//
// For more information, see the BigQuery ML documentation:
// https://cloud.google.com/bigquery/docs/bigqueryml
type Model struct {
	// ProjectID is the ID of the project that holds the model's dataset.
	ProjectID string
	// DatasetID is the ID of the dataset that contains the model.
	DatasetID string
	// ModelID must contain only letters (a-z, A-Z), numbers (0-9), or underscores (_).
	// The maximum length is 1,024 characters.
	ModelID string

	// c is the client whose BigQuery service (c.bqs) issues the model
	// Get/Patch/Delete requests.
	c *Client
}
43
44// FullyQualifiedName returns the ID of the model in projectID:datasetID.modelid format.
45func (m *Model) FullyQualifiedName() string {
46	return fmt.Sprintf("%s:%s.%s", m.ProjectID, m.DatasetID, m.ModelID)
47}
48
49// Metadata fetches the metadata for a model, which includes ML training statistics.
50func (m *Model) Metadata(ctx context.Context) (mm *ModelMetadata, err error) {
51	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Metadata")
52	defer func() { trace.EndSpan(ctx, err) }()
53
54	req := m.c.bqs.Models.Get(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
55	setClientHeader(req.Header())
56	var model *bq.Model
57	err = runWithRetry(ctx, func() (err error) {
58		model, err = req.Do()
59		return err
60	})
61	if err != nil {
62		return nil, err
63	}
64	return bqToModelMetadata(model)
65}
66
67// Update updates mutable fields in an ML model.
68func (m *Model) Update(ctx context.Context, mm ModelMetadataToUpdate, etag string) (md *ModelMetadata, err error) {
69	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Update")
70	defer func() { trace.EndSpan(ctx, err) }()
71
72	bqm, err := mm.toBQ()
73	if err != nil {
74		return nil, err
75	}
76	call := m.c.bqs.Models.Patch(m.ProjectID, m.DatasetID, m.ModelID, bqm).Context(ctx)
77	setClientHeader(call.Header())
78	if etag != "" {
79		call.Header().Set("If-Match", etag)
80	}
81	var res *bq.Model
82	if err := runWithRetry(ctx, func() (err error) {
83		res, err = call.Do()
84		return err
85	}); err != nil {
86		return nil, err
87	}
88	return bqToModelMetadata(res)
89}
90
91// Delete deletes an ML model.
92func (m *Model) Delete(ctx context.Context) (err error) {
93	ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Delete")
94	defer func() { trace.EndSpan(ctx, err) }()
95
96	req := m.c.bqs.Models.Delete(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx)
97	setClientHeader(req.Header())
98	return req.Do()
99}
100
// ModelMetadata represents information about a BigQuery ML model.
// Values are produced from API responses by Model.Metadata and Model.Update.
type ModelMetadata struct {
	// The user-friendly description of the model.
	Description string

	// The user-friendly name of the model.
	Name string

	// The type of the model.  Possible values include:
	// "LINEAR_REGRESSION" - a linear regression model
	// "LOGISTIC_REGRESSION" - a logistic regression model
	// "KMEANS" - a k-means clustering model
	Type string

	// The creation time of the model.
	CreationTime time.Time

	// The last modified time of the model.
	LastModifiedTime time.Time

	// The expiration time of the model.
	ExpirationTime time.Time

	// The geographic location where the model resides.  This value is
	// inherited from the encapsulating dataset.
	Location string

	// Custom encryption configuration (e.g., Cloud KMS keys).
	EncryptionConfig *EncryptionConfig

	// The input feature columns used to train the model.
	// Exposed to callers through RawFeatureColumns.
	featureColumns []*bq.StandardSqlField

	// The label columns used to train the model.  Output
	// from the model will have a "predicted_" prefix for these columns.
	// Exposed to callers through RawLabelColumns.
	labelColumns []*bq.StandardSqlField

	// Information for all training runs, ordered by increasing start times.
	// Exposed to callers through RawTrainingRuns.
	trainingRuns []*bq.TrainingRun

	// Labels holds the key/value labels attached to the model.
	Labels map[string]string

	// ETag is the ETag obtained when reading metadata. Pass it to Model.Update
	// to ensure that the metadata hasn't changed since it was read.
	ETag string
}
147
// TrainingRun represents information about a single training run for a BigQuery ML model.
// It is a defined type over bq.TrainingRun (package "google.golang.org/api/bigquery/v2"),
// so its fields mirror that type and may change whenever that package does.
//
// Experimental: This information may be modified or removed in future versions of this package.
type TrainingRun bq.TrainingRun
151
152// RawTrainingRuns exposes the underlying training run stats for a model using types from
153// "google.golang.org/api/bigquery/v2", which are subject to change without warning.
154// It is EXPERIMENTAL and subject to change or removal without notice.
155func (mm *ModelMetadata) RawTrainingRuns() []*TrainingRun {
156	if mm.trainingRuns == nil {
157		return nil
158	}
159	var runs []*TrainingRun
160
161	for _, v := range mm.trainingRuns {
162		r := TrainingRun(*v)
163		runs = append(runs, &r)
164	}
165	return runs
166}
167
168// RawLabelColumns exposes the underlying label columns used to train an ML model and uses types from
169// "google.golang.org/api/bigquery/v2", which are subject to change without warning.
170// It is EXPERIMENTAL and subject to change or removal without notice.
171func (mm *ModelMetadata) RawLabelColumns() ([]*StandardSQLField, error) {
172	return bqToModelCols(mm.labelColumns)
173}
174
175// RawFeatureColumns exposes the underlying feature columns used to train an ML model and uses types from
176// "google.golang.org/api/bigquery/v2", which are subject to change without warning.
177// It is EXPERIMENTAL and subject to change or removal without notice.
178func (mm *ModelMetadata) RawFeatureColumns() ([]*StandardSQLField, error) {
179	return bqToModelCols(mm.featureColumns)
180}
181
182func bqToModelCols(s []*bq.StandardSqlField) ([]*StandardSQLField, error) {
183	if s == nil {
184		return nil, nil
185	}
186	var cols []*StandardSQLField
187	for _, v := range s {
188		c, err := bqToStandardSQLField(v)
189		if err != nil {
190			return nil, err
191		}
192		cols = append(cols, c)
193	}
194	return cols, nil
195}
196
197func bqToModelMetadata(m *bq.Model) (*ModelMetadata, error) {
198	md := &ModelMetadata{
199		Description:      m.Description,
200		Name:             m.FriendlyName,
201		Type:             m.ModelType,
202		Location:         m.Location,
203		Labels:           m.Labels,
204		ExpirationTime:   unixMillisToTime(m.ExpirationTime),
205		CreationTime:     unixMillisToTime(m.CreationTime),
206		LastModifiedTime: unixMillisToTime(m.LastModifiedTime),
207		EncryptionConfig: bqToEncryptionConfig(m.EncryptionConfiguration),
208		featureColumns:   m.FeatureColumns,
209		labelColumns:     m.LabelColumns,
210		trainingRuns:     m.TrainingRuns,
211		ETag:             m.Etag,
212	}
213	return md, nil
214}
215
// ModelMetadataToUpdate is used when updating an ML model's metadata.
// Only non-nil fields will be updated; a zero ExpirationTime is ignored.
type ModelMetadataToUpdate struct {
	// The user-friendly description of this model.
	Description optional.String

	// The user-friendly name of this model.
	Name optional.String

	// The time when this model expires.  To remove a model's expiration,
	// set ExpirationTime to NeverExpire.  The zero value is ignored.
	// A value rejected by validExpiration makes the update fail with an error.
	ExpirationTime time.Time

	// The model's encryption configuration. Left unchanged when nil.
	EncryptionConfig *EncryptionConfig

	// labelUpdater records label changes; its update method supplies the
	// labels and the force-send/null field names merged into the patch request.
	labelUpdater
}
234
235func (mm *ModelMetadataToUpdate) toBQ() (*bq.Model, error) {
236	m := &bq.Model{}
237	forceSend := func(field string) {
238		m.ForceSendFields = append(m.ForceSendFields, field)
239	}
240
241	if mm.Description != nil {
242		m.Description = optional.ToString(mm.Description)
243		forceSend("Description")
244	}
245
246	if mm.Name != nil {
247		m.FriendlyName = optional.ToString(mm.Name)
248		forceSend("FriendlyName")
249	}
250
251	if mm.EncryptionConfig != nil {
252		m.EncryptionConfiguration = mm.EncryptionConfig.toBQ()
253	}
254
255	if !validExpiration(mm.ExpirationTime) {
256		return nil, invalidTimeError(mm.ExpirationTime)
257	}
258	if mm.ExpirationTime == NeverExpire {
259		m.NullFields = append(m.NullFields, "ExpirationTime")
260	} else if !mm.ExpirationTime.IsZero() {
261		m.ExpirationTime = mm.ExpirationTime.UnixNano() / 1e6
262		forceSend("ExpirationTime")
263	}
264	labels, forces, nulls := mm.update()
265	m.Labels = labels
266	m.ForceSendFields = append(m.ForceSendFields, forces...)
267	m.NullFields = append(m.NullFields, nulls...)
268	return m, nil
269}
270