1// Copyright 2019 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// http://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13// limitations under the License. 14 15package bigquery 16 17import ( 18 "context" 19 "fmt" 20 "time" 21 22 "cloud.google.com/go/internal/optional" 23 "cloud.google.com/go/internal/trace" 24 bq "google.golang.org/api/bigquery/v2" 25) 26 27// Model represent a reference to a BigQuery ML model. 28// Within the API, models are used largely for communicating 29// statistical information about a given model, as creation of models is only 30// supported via BigQuery queries (e.g. CREATE MODEL .. AS ..). 31// 32// For more info, see documentation for Bigquery ML, 33// see: https://cloud.google.com/bigquery/docs/bigqueryml 34type Model struct { 35 ProjectID string 36 DatasetID string 37 // ModelID must contain only letters (a-z, A-Z), numbers (0-9), or underscores (_). 38 // The maximum length is 1,024 characters. 39 ModelID string 40 41 c *Client 42} 43 44// FullyQualifiedName returns the ID of the model in projectID:datasetID.modelid format. 45func (m *Model) FullyQualifiedName() string { 46 return fmt.Sprintf("%s:%s.%s", m.ProjectID, m.DatasetID, m.ModelID) 47} 48 49func (m *Model) toBQ() *bq.ModelReference { 50 return &bq.ModelReference{ 51 ProjectId: m.ProjectID, 52 DatasetId: m.DatasetID, 53 ModelId: m.ModelID, 54 } 55} 56 57// Metadata fetches the metadata for a model, which includes ML training statistics. 58func (m *Model) Metadata(ctx context.Context) (mm *ModelMetadata, err error) { 59 ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Metadata") 60 defer func() { trace.EndSpan(ctx, err) }() 61 62 req := m.c.bqs.Models.Get(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx) 63 setClientHeader(req.Header()) 64 var model *bq.Model 65 err = runWithRetry(ctx, func() (err error) { 66 model, err = req.Do() 67 return err 68 }) 69 if err != nil { 70 return nil, err 71 } 72 return bqToModelMetadata(model) 73} 74 75// Update updates mutable fields in an ML model. 76func (m *Model) Update(ctx context.Context, mm ModelMetadataToUpdate, etag string) (md *ModelMetadata, err error) { 77 ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Update") 78 defer func() { trace.EndSpan(ctx, err) }() 79 80 bqm, err := mm.toBQ() 81 if err != nil { 82 return nil, err 83 } 84 call := m.c.bqs.Models.Patch(m.ProjectID, m.DatasetID, m.ModelID, bqm).Context(ctx) 85 setClientHeader(call.Header()) 86 if etag != "" { 87 call.Header().Set("If-Match", etag) 88 } 89 var res *bq.Model 90 if err := runWithRetry(ctx, func() (err error) { 91 res, err = call.Do() 92 return err 93 }); err != nil { 94 return nil, err 95 } 96 return bqToModelMetadata(res) 97} 98 99// Delete deletes an ML model. 100func (m *Model) Delete(ctx context.Context) (err error) { 101 ctx = trace.StartSpan(ctx, "cloud.google.com/go/bigquery.Model.Delete") 102 defer func() { trace.EndSpan(ctx, err) }() 103 104 req := m.c.bqs.Models.Delete(m.ProjectID, m.DatasetID, m.ModelID).Context(ctx) 105 setClientHeader(req.Header()) 106 return req.Do() 107} 108 109// ModelMetadata represents information about a BigQuery ML model. 110type ModelMetadata struct { 111 // The user-friendly description of the model. 112 Description string 113 114 // The user-friendly name of the model. 115 Name string 116 117 // The type of the model. Possible values include: 118 // "LINEAR_REGRESSION" - a linear regression model 119 // "LOGISTIC_REGRESSION" - a logistic regression model 120 // "KMEANS" - a k-means clustering model 121 Type string 122 123 // The creation time of the model. 124 CreationTime time.Time 125 126 // The last modified time of the model. 127 LastModifiedTime time.Time 128 129 // The expiration time of the model. 130 ExpirationTime time.Time 131 132 // The geographic location where the model resides. This value is 133 // inherited from the encapsulating dataset. 134 Location string 135 136 // Custom encryption configuration (e.g., Cloud KMS keys). 137 EncryptionConfig *EncryptionConfig 138 139 // The input feature columns used to train the model. 140 featureColumns []*bq.StandardSqlField 141 142 // The label columns used to train the model. Output 143 // from the model will have a "predicted_" prefix for these columns. 144 labelColumns []*bq.StandardSqlField 145 146 // Information for all training runs, ordered by increasing start times. 147 trainingRuns []*bq.TrainingRun 148 149 Labels map[string]string 150 151 // ETag is the ETag obtained when reading metadata. Pass it to Model.Update 152 // to ensure that the metadata hasn't changed since it was read. 153 ETag string 154} 155 156// TrainingRun represents information about a single training run for a BigQuery ML model. 157// Experimental: This information may be modified or removed in future versions of this package. 158type TrainingRun bq.TrainingRun 159 160// RawTrainingRuns exposes the underlying training run stats for a model using types from 161// "google.golang.org/api/bigquery/v2", which are subject to change without warning. 162// It is EXPERIMENTAL and subject to change or removal without notice. 163func (mm *ModelMetadata) RawTrainingRuns() []*TrainingRun { 164 if mm.trainingRuns == nil { 165 return nil 166 } 167 var runs []*TrainingRun 168 169 for _, v := range mm.trainingRuns { 170 r := TrainingRun(*v) 171 runs = append(runs, &r) 172 } 173 return runs 174} 175 176// RawLabelColumns exposes the underlying label columns used to train an ML model and uses types from 177// "google.golang.org/api/bigquery/v2", which are subject to change without warning. 178// It is EXPERIMENTAL and subject to change or removal without notice. 179func (mm *ModelMetadata) RawLabelColumns() ([]*StandardSQLField, error) { 180 return bqToModelCols(mm.labelColumns) 181} 182 183// RawFeatureColumns exposes the underlying feature columns used to train an ML model and uses types from 184// "google.golang.org/api/bigquery/v2", which are subject to change without warning. 185// It is EXPERIMENTAL and subject to change or removal without notice. 186func (mm *ModelMetadata) RawFeatureColumns() ([]*StandardSQLField, error) { 187 return bqToModelCols(mm.featureColumns) 188} 189 190func bqToModelCols(s []*bq.StandardSqlField) ([]*StandardSQLField, error) { 191 if s == nil { 192 return nil, nil 193 } 194 var cols []*StandardSQLField 195 for _, v := range s { 196 c, err := bqToStandardSQLField(v) 197 if err != nil { 198 return nil, err 199 } 200 cols = append(cols, c) 201 } 202 return cols, nil 203} 204 205func bqToModelMetadata(m *bq.Model) (*ModelMetadata, error) { 206 md := &ModelMetadata{ 207 Description: m.Description, 208 Name: m.FriendlyName, 209 Type: m.ModelType, 210 Location: m.Location, 211 Labels: m.Labels, 212 ExpirationTime: unixMillisToTime(m.ExpirationTime), 213 CreationTime: unixMillisToTime(m.CreationTime), 214 LastModifiedTime: unixMillisToTime(m.LastModifiedTime), 215 EncryptionConfig: bqToEncryptionConfig(m.EncryptionConfiguration), 216 featureColumns: m.FeatureColumns, 217 labelColumns: m.LabelColumns, 218 trainingRuns: m.TrainingRuns, 219 ETag: m.Etag, 220 } 221 return md, nil 222} 223 224// ModelMetadataToUpdate is used when updating an ML model's metadata. 225// Only non-nil fields will be updated. 226type ModelMetadataToUpdate struct { 227 // The user-friendly description of this model. 228 Description optional.String 229 230 // The user-friendly name of this model. 231 Name optional.String 232 233 // The time when this model expires. To remove a model's expiration, 234 // set ExpirationTime to NeverExpire. The zero value is ignored. 235 ExpirationTime time.Time 236 237 // The model's encryption configuration. 238 EncryptionConfig *EncryptionConfig 239 240 labelUpdater 241} 242 243func (mm *ModelMetadataToUpdate) toBQ() (*bq.Model, error) { 244 m := &bq.Model{} 245 forceSend := func(field string) { 246 m.ForceSendFields = append(m.ForceSendFields, field) 247 } 248 249 if mm.Description != nil { 250 m.Description = optional.ToString(mm.Description) 251 forceSend("Description") 252 } 253 254 if mm.Name != nil { 255 m.FriendlyName = optional.ToString(mm.Name) 256 forceSend("FriendlyName") 257 } 258 259 if mm.EncryptionConfig != nil { 260 m.EncryptionConfiguration = mm.EncryptionConfig.toBQ() 261 } 262 263 if !validExpiration(mm.ExpirationTime) { 264 return nil, invalidTimeError(mm.ExpirationTime) 265 } 266 if mm.ExpirationTime == NeverExpire { 267 m.NullFields = append(m.NullFields, "ExpirationTime") 268 } else if !mm.ExpirationTime.IsZero() { 269 m.ExpirationTime = mm.ExpirationTime.UnixNano() / 1e6 270 forceSend("ExpirationTime") 271 } 272 labels, forces, nulls := mm.update() 273 m.Labels = labels 274 m.ForceSendFields = append(m.ForceSendFields, forces...) 275 m.NullFields = append(m.NullFields, nulls...) 276 return m, nil 277} 278