1package allocrunner
2
3import (
4	"context"
5	"fmt"
6
7	hclog "github.com/hashicorp/go-hclog"
8	multierror "github.com/hashicorp/go-multierror"
9	"github.com/hashicorp/nomad/client/pluginmanager/csimanager"
10	"github.com/hashicorp/nomad/nomad/structs"
11	"github.com/hashicorp/nomad/plugins/drivers"
12)
13
// csiHook will wait for remote csi volumes to be attached to the host before
// continuing.
//
// It is a noop for allocs that do not depend on CSI Volumes.
type csiHook struct {
	// ar is the alloc runner the hook belongs to. The hook uses it to look up
	// task driver capabilities and to read the node secret ID that
	// authenticates its RPCs.
	ar             *allocRunner
	// alloc is the allocation whose task group volume requests are processed.
	alloc          *structs.Allocation
	// logger is set by newCSIHook to a logger named "csi_hook".
	logger         hclog.Logger
	// csimanager hands out the per-plugin mounter used to mount claimed
	// volumes.
	csimanager     csimanager.Manager
	// rpcClient sends the CSIVolume.Claim and CSIVolume.Unpublish RPCs.
	rpcClient      RPCer
	// updater is used to read and write the alloc hook resources, where the
	// resulting CSI mounts are stored.
	updater        hookResourceSetter
	// volumeRequests holds the volumes claimed during Prerun, keyed by the
	// volume alias from the task group. Postrun unpublishes each entry.
	volumeRequests map[string]*volumeAndRequest
}
27
28func (c *csiHook) Name() string {
29	return "csi_hook"
30}
31
32func (c *csiHook) Prerun() error {
33	if !c.shouldRun() {
34		return nil
35	}
36
37	// We use this context only to attach hclog to the gRPC context. The
38	// lifetime is the lifetime of the gRPC stream, not specific RPC timeouts,
39	// but we manage the stream lifetime via Close in the pluginmanager.
40	ctx := context.Background()
41
42	volumes, err := c.claimVolumesFromAlloc()
43	if err != nil {
44		return fmt.Errorf("claim volumes: %v", err)
45	}
46	c.volumeRequests = volumes
47
48	mounts := make(map[string]*csimanager.MountInfo, len(volumes))
49	for alias, pair := range volumes {
50		mounter, err := c.csimanager.MounterForPlugin(ctx, pair.volume.PluginID)
51		if err != nil {
52			return err
53		}
54
55		usageOpts := &csimanager.UsageOptions{
56			ReadOnly:       pair.request.ReadOnly,
57			AttachmentMode: pair.request.AttachmentMode,
58			AccessMode:     pair.request.AccessMode,
59			MountOptions:   pair.request.MountOptions,
60		}
61
62		mountInfo, err := mounter.MountVolume(ctx, pair.volume, c.alloc, usageOpts, pair.publishContext)
63		if err != nil {
64			return err
65		}
66
67		mounts[alias] = mountInfo
68	}
69
70	res := c.updater.GetAllocHookResources()
71	res.CSIMounts = mounts
72	c.updater.SetAllocHookResources(res)
73
74	return nil
75}
76
77// Postrun sends an RPC to the server to unpublish the volume. This may
78// forward client RPCs to the node plugins or to the controller plugins,
79// depending on whether other allocations on this node have claims on this
80// volume.
81func (c *csiHook) Postrun() error {
82	if !c.shouldRun() {
83		return nil
84	}
85
86	var mErr *multierror.Error
87
88	for _, pair := range c.volumeRequests {
89
90		mode := structs.CSIVolumeClaimRead
91		if !pair.request.ReadOnly {
92			mode = structs.CSIVolumeClaimWrite
93		}
94
95		source := pair.request.Source
96		if pair.request.PerAlloc {
97			// NOTE: PerAlloc can't be set if we have canaries
98			source = source + structs.AllocSuffix(c.alloc.Name)
99		}
100
101		req := &structs.CSIVolumeUnpublishRequest{
102			VolumeID: source,
103			Claim: &structs.CSIVolumeClaim{
104				AllocationID: c.alloc.ID,
105				NodeID:       c.alloc.NodeID,
106				Mode:         mode,
107				State:        structs.CSIVolumeClaimStateUnpublishing,
108			},
109			WriteRequest: structs.WriteRequest{
110				Region:    c.alloc.Job.Region,
111				Namespace: c.alloc.Job.Namespace,
112				AuthToken: c.ar.clientConfig.Node.SecretID,
113			},
114		}
115		err := c.rpcClient.RPC("CSIVolume.Unpublish",
116			req, &structs.CSIVolumeUnpublishResponse{})
117		if err != nil {
118			mErr = multierror.Append(mErr, err)
119		}
120	}
121	return mErr.ErrorOrNil()
122}
123
// volumeAndRequest pairs a task group volume request with the CSI volume the
// server returned when that request was claimed.
type volumeAndRequest struct {
	// volume is taken from the CSIVolume.Claim response; it is unset until
	// the claim for this request has succeeded.
	volume  *structs.CSIVolume
	// request is the task group volume request this entry was built from.
	request *structs.VolumeRequest

	// When volumeAndRequest was returned from a volume claim, this field will be
	// populated for plugins that require it.
	publishContext map[string]string
}
132
133// claimVolumesFromAlloc is used by the pre-run hook to fetch all of the volume
134// metadata and claim it for use by this alloc/node at the same time.
135func (c *csiHook) claimVolumesFromAlloc() (map[string]*volumeAndRequest, error) {
136	result := make(map[string]*volumeAndRequest)
137	tg := c.alloc.Job.LookupTaskGroup(c.alloc.TaskGroup)
138
139	// Initially, populate the result map with all of the requests
140	for alias, volumeRequest := range tg.Volumes {
141
142		if volumeRequest.Type == structs.VolumeTypeCSI {
143
144			for _, task := range tg.Tasks {
145				caps, err := c.ar.GetTaskDriverCapabilities(task.Name)
146				if err != nil {
147					return nil, fmt.Errorf("could not validate task driver capabilities: %v", err)
148				}
149
150				if caps.MountConfigs == drivers.MountConfigSupportNone {
151					return nil, fmt.Errorf(
152						"task driver %q for %q does not support CSI", task.Driver, task.Name)
153				}
154			}
155
156			result[alias] = &volumeAndRequest{request: volumeRequest}
157		}
158	}
159
160	// Iterate over the result map and upsert the volume field as each volume gets
161	// claimed by the server.
162	for alias, pair := range result {
163		claimType := structs.CSIVolumeClaimWrite
164		if pair.request.ReadOnly {
165			claimType = structs.CSIVolumeClaimRead
166		}
167
168		source := pair.request.Source
169		if pair.request.PerAlloc {
170			source = source + structs.AllocSuffix(c.alloc.Name)
171		}
172
173		req := &structs.CSIVolumeClaimRequest{
174			VolumeID:       source,
175			AllocationID:   c.alloc.ID,
176			NodeID:         c.alloc.NodeID,
177			Claim:          claimType,
178			AccessMode:     pair.request.AccessMode,
179			AttachmentMode: pair.request.AttachmentMode,
180			WriteRequest: structs.WriteRequest{
181				Region:    c.alloc.Job.Region,
182				Namespace: c.alloc.Job.Namespace,
183				AuthToken: c.ar.clientConfig.Node.SecretID,
184			},
185		}
186
187		var resp structs.CSIVolumeClaimResponse
188		if err := c.rpcClient.RPC("CSIVolume.Claim", req, &resp); err != nil {
189			return nil, err
190		}
191
192		if resp.Volume == nil {
193			return nil, fmt.Errorf("Unexpected nil volume returned for ID: %v", pair.request.Source)
194		}
195
196		result[alias].request = pair.request
197		result[alias].volume = resp.Volume
198		result[alias].publishContext = resp.PublishContext
199	}
200
201	return result, nil
202}
203
204func newCSIHook(ar *allocRunner, logger hclog.Logger, alloc *structs.Allocation, rpcClient RPCer, csi csimanager.Manager, updater hookResourceSetter) *csiHook {
205	return &csiHook{
206		ar:         ar,
207		alloc:      alloc,
208		logger:     logger.Named("csi_hook"),
209		rpcClient:  rpcClient,
210		csimanager: csi,
211		updater:    updater,
212	}
213}
214
215func (h *csiHook) shouldRun() bool {
216	tg := h.alloc.Job.LookupTaskGroup(h.alloc.TaskGroup)
217	for _, vol := range tg.Volumes {
218		if vol.Type == structs.VolumeTypeCSI {
219			return true
220		}
221	}
222
223	return false
224}
225