1package allocrunner 2 3import ( 4 "context" 5 "fmt" 6 7 hclog "github.com/hashicorp/go-hclog" 8 multierror "github.com/hashicorp/go-multierror" 9 "github.com/hashicorp/nomad/client/pluginmanager/csimanager" 10 "github.com/hashicorp/nomad/nomad/structs" 11 "github.com/hashicorp/nomad/plugins/drivers" 12) 13 14// csiHook will wait for remote csi volumes to be attached to the host before 15// continuing. 16// 17// It is a noop for allocs that do not depend on CSI Volumes. 18type csiHook struct { 19 ar *allocRunner 20 alloc *structs.Allocation 21 logger hclog.Logger 22 csimanager csimanager.Manager 23 rpcClient RPCer 24 updater hookResourceSetter 25 volumeRequests map[string]*volumeAndRequest 26} 27 28func (c *csiHook) Name() string { 29 return "csi_hook" 30} 31 32func (c *csiHook) Prerun() error { 33 if !c.shouldRun() { 34 return nil 35 } 36 37 // We use this context only to attach hclog to the gRPC context. The 38 // lifetime is the lifetime of the gRPC stream, not specific RPC timeouts, 39 // but we manage the stream lifetime via Close in the pluginmanager. 40 ctx := context.Background() 41 42 volumes, err := c.claimVolumesFromAlloc() 43 if err != nil { 44 return fmt.Errorf("claim volumes: %v", err) 45 } 46 c.volumeRequests = volumes 47 48 mounts := make(map[string]*csimanager.MountInfo, len(volumes)) 49 for alias, pair := range volumes { 50 mounter, err := c.csimanager.MounterForPlugin(ctx, pair.volume.PluginID) 51 if err != nil { 52 return err 53 } 54 55 usageOpts := &csimanager.UsageOptions{ 56 ReadOnly: pair.request.ReadOnly, 57 AttachmentMode: pair.request.AttachmentMode, 58 AccessMode: pair.request.AccessMode, 59 MountOptions: pair.request.MountOptions, 60 } 61 62 mountInfo, err := mounter.MountVolume(ctx, pair.volume, c.alloc, usageOpts, pair.publishContext) 63 if err != nil { 64 return err 65 } 66 67 mounts[alias] = mountInfo 68 } 69 70 res := c.updater.GetAllocHookResources() 71 res.CSIMounts = mounts 72 c.updater.SetAllocHookResources(res) 73 74 return nil 75} 76 77// Postrun sends an RPC to the server to unpublish the volume. This may 78// forward client RPCs to the node plugins or to the controller plugins, 79// depending on whether other allocations on this node have claims on this 80// volume. 81func (c *csiHook) Postrun() error { 82 if !c.shouldRun() { 83 return nil 84 } 85 86 var mErr *multierror.Error 87 88 for _, pair := range c.volumeRequests { 89 90 mode := structs.CSIVolumeClaimRead 91 if !pair.request.ReadOnly { 92 mode = structs.CSIVolumeClaimWrite 93 } 94 95 source := pair.request.Source 96 if pair.request.PerAlloc { 97 // NOTE: PerAlloc can't be set if we have canaries 98 source = source + structs.AllocSuffix(c.alloc.Name) 99 } 100 101 req := &structs.CSIVolumeUnpublishRequest{ 102 VolumeID: source, 103 Claim: &structs.CSIVolumeClaim{ 104 AllocationID: c.alloc.ID, 105 NodeID: c.alloc.NodeID, 106 Mode: mode, 107 State: structs.CSIVolumeClaimStateUnpublishing, 108 }, 109 WriteRequest: structs.WriteRequest{ 110 Region: c.alloc.Job.Region, 111 Namespace: c.alloc.Job.Namespace, 112 AuthToken: c.ar.clientConfig.Node.SecretID, 113 }, 114 } 115 err := c.rpcClient.RPC("CSIVolume.Unpublish", 116 req, &structs.CSIVolumeUnpublishResponse{}) 117 if err != nil { 118 mErr = multierror.Append(mErr, err) 119 } 120 } 121 return mErr.ErrorOrNil() 122} 123 124type volumeAndRequest struct { 125 volume *structs.CSIVolume 126 request *structs.VolumeRequest 127 128 // When volumeAndRequest was returned from a volume claim, this field will be 129 // populated for plugins that require it. 130 publishContext map[string]string 131} 132 133// claimVolumesFromAlloc is used by the pre-run hook to fetch all of the volume 134// metadata and claim it for use by this alloc/node at the same time. 135func (c *csiHook) claimVolumesFromAlloc() (map[string]*volumeAndRequest, error) { 136 result := make(map[string]*volumeAndRequest) 137 tg := c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup) 138 139 // Initially, populate the result map with all of the requests 140 for alias, volumeRequest := range tg.Volumes { 141 142 if volumeRequest.Type == structs.VolumeTypeCSI { 143 144 for _, task := range tg.Tasks { 145 caps, err := c.ar.GetTaskDriverCapabilities(task.Name) 146 if err != nil { 147 return nil, fmt.Errorf("could not validate task driver capabilities: %v", err) 148 } 149 150 if caps.MountConfigs == drivers.MountConfigSupportNone { 151 return nil, fmt.Errorf( 152 "task driver %q for %q does not support CSI", task.Driver, task.Name) 153 } 154 } 155 156 result[alias] = &volumeAndRequest{request: volumeRequest} 157 } 158 } 159 160 // Iterate over the result map and upsert the volume field as each volume gets 161 // claimed by the server. 162 for alias, pair := range result { 163 claimType := structs.CSIVolumeClaimWrite 164 if pair.request.ReadOnly { 165 claimType = structs.CSIVolumeClaimRead 166 } 167 168 source := pair.request.Source 169 if pair.request.PerAlloc { 170 source = source + structs.AllocSuffix(c.alloc.Name) 171 } 172 173 req := &structs.CSIVolumeClaimRequest{ 174 VolumeID: source, 175 AllocationID: c.alloc.ID, 176 NodeID: c.alloc.NodeID, 177 Claim: claimType, 178 AccessMode: pair.request.AccessMode, 179 AttachmentMode: pair.request.AttachmentMode, 180 WriteRequest: structs.WriteRequest{ 181 Region: c.alloc.Job.Region, 182 Namespace: c.alloc.Job.Namespace, 183 AuthToken: c.ar.clientConfig.Node.SecretID, 184 }, 185 } 186 187 var resp structs.CSIVolumeClaimResponse 188 if err := c.rpcClient.RPC("CSIVolume.Claim", req, &resp); err != nil { 189 return nil, err 190 } 191 192 if resp.Volume == nil { 193 return nil, fmt.Errorf("Unexpected nil volume returned for ID: %v", pair.request.Source) 194 } 195 196 result[alias].request = pair.request 197 result[alias].volume = resp.Volume 198 result[alias].publishContext = resp.PublishContext 199 } 200 201 return result, nil 202} 203 204func newCSIHook(ar *allocRunner, logger hclog.Logger, alloc *structs.Allocation, rpcClient RPCer, csi csimanager.Manager, updater hookResourceSetter) *csiHook { 205 return &csiHook{ 206 ar: ar, 207 alloc: alloc, 208 logger: logger.Named("csi_hook"), 209 rpcClient: rpcClient, 210 csimanager: csi, 211 updater: updater, 212 } 213} 214 215func (h *csiHook) shouldRun() bool { 216 tg := h.alloc.Job.LookupTaskGroup(h.alloc.TaskGroup) 217 for _, vol := range tg.Volumes { 218 if vol.Type == structs.VolumeTypeCSI { 219 return true 220 } 221 } 222 223 return false 224} 225