1// Copyright 2019 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package bigquery 16 17import ( 18 "context" 19 "fmt" 20 "time" 21 22 "cloud.google.com/go/internal/optional" 23 "cloud.google.com/go/internal/trace" 24 bq "google.golang.org/api/bigquery/v2" 25) 26 27// Model represent a reference to a BigQuery ML model. 28// Within the API, models are used largely for communicating 29// statistical information about a given model, as creation of models is only 30// supported via BigQuery queries (e.g. CREATE MODEL .. AS ..). 31// 32// For more info, see documentation for Bigquery ML, 33// see: https://cloud.google.com/bigquery/docs/bigqueryml 34type Model struct { 35 ProjectID string 36 DatasetID string 37 // ModelID must contain only letters (a-z, A-Z), numbers (0-9), or underscores (_). 38 // The maximum length is 1,024 characters. 39 ModelID string 40 41 c *Client 42} 43 44// FullyQualifiedName returns the ID of the model in projectID:datasetID.modelid format. 45func (m *Model) FullyQualifiedName() string { 46 return fmt.Sprintf("%s:%s.%s", m.ProjectID, m.DatasetID, m.ModelID) 47} 48 49// Metadata fetches the metadata for a model, which includes ML training statistics. 50func (m *Model) Metadata(ctx context.Context) (mm *ModelMetadata, err error) { 51 ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Metadata") 52 defer func() { trace.EndSpan(ctx, err) }() 53 54 req := m.c.bqs.Models.Get(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx) 55 setClientHeader(req.Header()) 56 var model *bq.Model 57 err = runWithRetry(ctx, func() (err error) { 58 model, err = req.Do() 59 return err 60 }) 61 if err != nil { 62 return nil, err 63 } 64 return bqToModelMetadata(model) 65} 66 67// Update updates mutable fields in an ML model. 68func (m *Model) Update(ctx context.Context, mm ModelMetadataToUpdate, etag string) (md *ModelMetadata, err error) { 69 ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Update") 70 defer func() { trace.EndSpan(ctx, err) }() 71 72 bqm, err := mm.toBQ() 73 if err != nil { 74 return nil, err 75 } 76 call := m.c.bqs.Models.Patch(m.ProjectID, m.DatasetID, m.ModelID, bqm).Context(ctx) 77 setClientHeader(call.Header()) 78 if etag != "" { 79 call.Header().Set("If-Match", etag) 80 } 81 var res *bq.Model 82 if err := runWithRetry(ctx, func() (err error) { 83 res, err = call.Do() 84 return err 85 }); err != nil { 86 return nil, err 87 } 88 return bqToModelMetadata(res) 89} 90 91// Delete deletes an ML model. 92func (m *Model) Delete(ctx context.Context) (err error) { 93 ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Delete") 94 defer func() { trace.EndSpan(ctx, err) }() 95 96 req := m.c.bqs.Models.Delete(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx) 97 setClientHeader(req.Header()) 98 return req.Do() 99} 100 101// ModelMetadata represents information about a BigQuery ML model. 102type ModelMetadata struct { 103 // The user-friendly description of the model. 104 Description string 105 106 // The user-friendly name of the model. 107 Name string 108 109 // The type of the model. Possible values include: 110 // "LINEAR_REGRESSION" - a linear regression model 111 // "LOGISTIC_REGRESSION" - a logistic regression model 112 // "KMEANS" - a k-means clustering model 113 Type string 114 115 // The creation time of the model. 116 CreationTime time.Time 117 118 // The last modified time of the model. 119 LastModifiedTime time.Time 120 121 // The expiration time of the model. 122 ExpirationTime time.Time 123 124 // The geographic location where the model resides. This value is 125 // inherited from the encapsulating dataset. 126 Location string 127 128 // Custom encryption configuration (e.g., Cloud KMS keys). 129 EncryptionConfig *EncryptionConfig 130 131 // The input feature columns used to train the model. 132 featureColumns []*bq.StandardSqlField 133 134 // The label columns used to train the model. Output 135 // from the model will have a "predicted_" prefix for these columns. 136 labelColumns []*bq.StandardSqlField 137 138 // Information for all training runs, ordered by increasing start times. 139 trainingRuns []*bq.TrainingRun 140 141 Labels map[string]string 142 143 // ETag is the ETag obtained when reading metadata. Pass it to Model.Update 144 // to ensure that the metadata hasn't changed since it was read. 145 ETag string 146} 147 148// TrainingRun represents information about a single training run for a BigQuery ML model. 149// Experimental: This information may be modified or removed in future versions of this package. 150type TrainingRun bq.TrainingRun 151 152// RawTrainingRuns exposes the underlying training run stats for a model using types from 153// "google.golang.org/api/bigquery/v2", which are subject to change without warning. 154// It is EXPERIMENTAL and subject to change or removal without notice. 155func (mm *ModelMetadata) RawTrainingRuns() []*TrainingRun { 156 if mm.trainingRuns == nil { 157 return nil 158 } 159 var runs []*TrainingRun 160 161 for _, v := range mm.trainingRuns { 162 r := TrainingRun(*v) 163 runs = append(runs, &r) 164 } 165 return runs 166} 167 168// RawLabelColumns exposes the underlying label columns used to train an ML model and uses types from 169// "google.golang.org/api/bigquery/v2", which are subject to change without warning. 170// It is EXPERIMENTAL and subject to change or removal without notice. 171func (mm *ModelMetadata) RawLabelColumns() ([]*StandardSQLField, error) { 172 return bqToModelCols(mm.labelColumns) 173} 174 175// RawFeatureColumns exposes the underlying feature columns used to train an ML model and uses types from 176// "google.golang.org/api/bigquery/v2", which are subject to change without warning. 177// It is EXPERIMENTAL and subject to change or removal without notice. 178func (mm *ModelMetadata) RawFeatureColumns() ([]*StandardSQLField, error) { 179 return bqToModelCols(mm.featureColumns) 180} 181 182func bqToModelCols(s []*bq.StandardSqlField) ([]*StandardSQLField, error) { 183 if s == nil { 184 return nil, nil 185 } 186 var cols []*StandardSQLField 187 for _, v := range s { 188 c, err := bqToStandardSQLField(v) 189 if err != nil { 190 return nil, err 191 } 192 cols = append(cols, c) 193 } 194 return cols, nil 195} 196 197func bqToModelMetadata(m *bq.Model) (*ModelMetadata, error) { 198 md := &ModelMetadata{ 199 Description: m.Description, 200 Name: m.FriendlyName, 201 Type: m.ModelType, 202 Location: m.Location, 203 Labels: m.Labels, 204 ExpirationTime: unixMillisToTime(m.ExpirationTime), 205 CreationTime: unixMillisToTime(m.CreationTime), 206 LastModifiedTime: unixMillisToTime(m.LastModifiedTime), 207 EncryptionConfig: bqToEncryptionConfig(m.EncryptionConfiguration), 208 featureColumns: m.FeatureColumns, 209 labelColumns: m.LabelColumns, 210 trainingRuns: m.TrainingRuns, 211 ETag: m.Etag, 212 } 213 return md, nil 214} 215 216// ModelMetadataToUpdate is used when updating an ML model's metadata. 217// Only non-nil fields will be updated. 218type ModelMetadataToUpdate struct { 219 // The user-friendly description of this model. 220 Description optional.String 221 222 // The user-friendly name of this model. 223 Name optional.String 224 225 // The time when this model expires. To remove a model's expiration, 226 // set ExpirationTime to NeverExpire. The zero value is ignored. 227 ExpirationTime time.Time 228 229 // The model's encryption configuration. 230 EncryptionConfig *EncryptionConfig 231 232 labelUpdater 233} 234 235func (mm *ModelMetadataToUpdate) toBQ() (*bq.Model, error) { 236 m := &bq.Model{} 237 forceSend := func(field string) { 238 m.ForceSendFields = append(m.ForceSendFields, field) 239 } 240 241 if mm.Description != nil { 242 m.Description = optional.ToString(mm.Description) 243 forceSend("Description") 244 } 245 246 if mm.Name != nil { 247 m.FriendlyName = optional.ToString(mm.Name) 248 forceSend("FriendlyName") 249 } 250 251 if mm.EncryptionConfig != nil { 252 m.EncryptionConfiguration = mm.EncryptionConfig.toBQ() 253 } 254 255 if !validExpiration(mm.ExpirationTime) { 256 return nil, invalidTimeError(mm.ExpirationTime) 257 } 258 if mm.ExpirationTime == NeverExpire { 259 m.NullFields = append(m.NullFields, "ExpirationTime") 260 } else if !mm.ExpirationTime.IsZero() { 261 m.ExpirationTime = mm.ExpirationTime.UnixNano() / 1e6 262 forceSend("ExpirationTime") 263 } 264 labels, forces, nulls := mm.update() 265 m.Labels = labels 266 m.ForceSendFields = append(m.ForceSendFields, forces...) 267 m.NullFields = append(m.NullFields, nulls...) 268 return m, nil 269} 270