xref: /linux/drivers/virt/coco/sev-guest/sev-guest.c (revision db10cb9b)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * AMD Secure Encrypted Virtualization (SEV) guest driver interface
4  *
5  * Copyright (C) 2021 Advanced Micro Devices, Inc.
6  *
7  * Author: Brijesh Singh <brijesh.singh@amd.com>
8  */
9 
10 #include <linux/module.h>
11 #include <linux/kernel.h>
12 #include <linux/types.h>
13 #include <linux/mutex.h>
14 #include <linux/io.h>
15 #include <linux/platform_device.h>
16 #include <linux/miscdevice.h>
17 #include <linux/set_memory.h>
18 #include <linux/fs.h>
19 #include <crypto/aead.h>
20 #include <linux/scatterlist.h>
21 #include <linux/psp-sev.h>
22 #include <uapi/linux/sev-guest.h>
23 #include <uapi/linux/psp-sev.h>
24 
25 #include <asm/svm.h>
26 #include <asm/sev.h>
27 
28 #include "sev-guest.h"
29 
30 #define DEVICE_NAME	"sev-guest"
31 #define AAD_LEN		48
32 #define MSG_HDR_VER	1
33 
34 #define SNP_REQ_MAX_RETRY_DURATION	(60*HZ)
35 #define SNP_REQ_RETRY_DELAY		(2*HZ)
36 
37 struct snp_guest_crypto {
38 	struct crypto_aead *tfm;
39 	u8 *iv, *authtag;
40 	int iv_len, a_len;
41 };
42 
43 struct snp_guest_dev {
44 	struct device *dev;
45 	struct miscdevice misc;
46 
47 	void *certs_data;
48 	struct snp_guest_crypto *crypto;
49 	/* request and response are in unencrypted memory */
50 	struct snp_guest_msg *request, *response;
51 
52 	/*
53 	 * Avoid information leakage by double-buffering shared messages
54 	 * in fields that are in regular encrypted memory.
55 	 */
56 	struct snp_guest_msg secret_request, secret_response;
57 
58 	struct snp_secrets_page_layout *layout;
59 	struct snp_req_data input;
60 	union {
61 		struct snp_report_req report;
62 		struct snp_derived_key_req derived_key;
63 		struct snp_ext_report_req ext_report;
64 	} req;
65 	u32 *os_area_msg_seqno;
66 	u8 *vmpck;
67 };
68 
69 static u32 vmpck_id;
70 module_param(vmpck_id, uint, 0444);
71 MODULE_PARM_DESC(vmpck_id, "The VMPCK ID to use when communicating with the PSP.");
72 
73 /* Mutex to serialize the shared buffer access and command handling. */
74 static DEFINE_MUTEX(snp_cmd_mutex);
75 
76 static bool is_vmpck_empty(struct snp_guest_dev *snp_dev)
77 {
78 	char zero_key[VMPCK_KEY_LEN] = {0};
79 
80 	if (snp_dev->vmpck)
81 		return !memcmp(snp_dev->vmpck, zero_key, VMPCK_KEY_LEN);
82 
83 	return true;
84 }
85 
86 /*
87  * If an error is received from the host or AMD Secure Processor (ASP) there
88  * are two options. Either retry the exact same encrypted request or discontinue
89  * using the VMPCK.
90  *
91  * This is because in the current encryption scheme GHCB v2 uses AES-GCM to
92  * encrypt the requests. The IV for this scheme is the sequence number. GCM
93  * cannot tolerate IV reuse.
94  *
95  * The ASP FW v1.51 only increments the sequence numbers on a successful
96  * guest<->ASP back and forth and only accepts messages at its exact sequence
97  * number.
98  *
99  * So if the sequence number were to be reused the encryption scheme is
100  * vulnerable. If the sequence number were incremented for a fresh IV the ASP
101  * will reject the request.
102  */
103 static void snp_disable_vmpck(struct snp_guest_dev *snp_dev)
104 {
105 	dev_alert(snp_dev->dev, "Disabling vmpck_id %d to prevent IV reuse.\n",
106 		  vmpck_id);
107 	memzero_explicit(snp_dev->vmpck, VMPCK_KEY_LEN);
108 	snp_dev->vmpck = NULL;
109 }
110 
111 static inline u64 __snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
112 {
113 	u64 count;
114 
115 	lockdep_assert_held(&snp_cmd_mutex);
116 
117 	/* Read the current message sequence counter from secrets pages */
118 	count = *snp_dev->os_area_msg_seqno;
119 
120 	return count + 1;
121 }
122 
123 /* Return a non-zero on success */
124 static u64 snp_get_msg_seqno(struct snp_guest_dev *snp_dev)
125 {
126 	u64 count = __snp_get_msg_seqno(snp_dev);
127 
128 	/*
129 	 * The message sequence counter for the SNP guest request is a  64-bit
130 	 * value but the version 2 of GHCB specification defines a 32-bit storage
131 	 * for it. If the counter exceeds the 32-bit value then return zero.
132 	 * The caller should check the return value, but if the caller happens to
133 	 * not check the value and use it, then the firmware treats zero as an
134 	 * invalid number and will fail the  message request.
135 	 */
136 	if (count >= UINT_MAX) {
137 		dev_err(snp_dev->dev, "request message sequence counter overflow\n");
138 		return 0;
139 	}
140 
141 	return count;
142 }
143 
144 static void snp_inc_msg_seqno(struct snp_guest_dev *snp_dev)
145 {
146 	/*
147 	 * The counter is also incremented by the PSP, so increment it by 2
148 	 * and save in secrets page.
149 	 */
150 	*snp_dev->os_area_msg_seqno += 2;
151 }
152 
153 static inline struct snp_guest_dev *to_snp_dev(struct file *file)
154 {
155 	struct miscdevice *dev = file->private_data;
156 
157 	return container_of(dev, struct snp_guest_dev, misc);
158 }
159 
160 static struct snp_guest_crypto *init_crypto(struct snp_guest_dev *snp_dev, u8 *key, size_t keylen)
161 {
162 	struct snp_guest_crypto *crypto;
163 
164 	crypto = kzalloc(sizeof(*crypto), GFP_KERNEL_ACCOUNT);
165 	if (!crypto)
166 		return NULL;
167 
168 	crypto->tfm = crypto_alloc_aead("gcm(aes)", 0, 0);
169 	if (IS_ERR(crypto->tfm))
170 		goto e_free;
171 
172 	if (crypto_aead_setkey(crypto->tfm, key, keylen))
173 		goto e_free_crypto;
174 
175 	crypto->iv_len = crypto_aead_ivsize(crypto->tfm);
176 	crypto->iv = kmalloc(crypto->iv_len, GFP_KERNEL_ACCOUNT);
177 	if (!crypto->iv)
178 		goto e_free_crypto;
179 
180 	if (crypto_aead_authsize(crypto->tfm) > MAX_AUTHTAG_LEN) {
181 		if (crypto_aead_setauthsize(crypto->tfm, MAX_AUTHTAG_LEN)) {
182 			dev_err(snp_dev->dev, "failed to set authsize to %d\n", MAX_AUTHTAG_LEN);
183 			goto e_free_iv;
184 		}
185 	}
186 
187 	crypto->a_len = crypto_aead_authsize(crypto->tfm);
188 	crypto->authtag = kmalloc(crypto->a_len, GFP_KERNEL_ACCOUNT);
189 	if (!crypto->authtag)
190 		goto e_free_iv;
191 
192 	return crypto;
193 
194 e_free_iv:
195 	kfree(crypto->iv);
196 e_free_crypto:
197 	crypto_free_aead(crypto->tfm);
198 e_free:
199 	kfree(crypto);
200 
201 	return NULL;
202 }
203 
204 static void deinit_crypto(struct snp_guest_crypto *crypto)
205 {
206 	crypto_free_aead(crypto->tfm);
207 	kfree(crypto->iv);
208 	kfree(crypto->authtag);
209 	kfree(crypto);
210 }
211 
212 static int enc_dec_message(struct snp_guest_crypto *crypto, struct snp_guest_msg *msg,
213 			   u8 *src_buf, u8 *dst_buf, size_t len, bool enc)
214 {
215 	struct snp_guest_msg_hdr *hdr = &msg->hdr;
216 	struct scatterlist src[3], dst[3];
217 	DECLARE_CRYPTO_WAIT(wait);
218 	struct aead_request *req;
219 	int ret;
220 
221 	req = aead_request_alloc(crypto->tfm, GFP_KERNEL);
222 	if (!req)
223 		return -ENOMEM;
224 
225 	/*
226 	 * AEAD memory operations:
227 	 * +------ AAD -------+------- DATA -----+---- AUTHTAG----+
228 	 * |  msg header      |  plaintext       |  hdr->authtag  |
229 	 * | bytes 30h - 5Fh  |    or            |                |
230 	 * |                  |   cipher         |                |
231 	 * +------------------+------------------+----------------+
232 	 */
233 	sg_init_table(src, 3);
234 	sg_set_buf(&src[0], &hdr->algo, AAD_LEN);
235 	sg_set_buf(&src[1], src_buf, hdr->msg_sz);
236 	sg_set_buf(&src[2], hdr->authtag, crypto->a_len);
237 
238 	sg_init_table(dst, 3);
239 	sg_set_buf(&dst[0], &hdr->algo, AAD_LEN);
240 	sg_set_buf(&dst[1], dst_buf, hdr->msg_sz);
241 	sg_set_buf(&dst[2], hdr->authtag, crypto->a_len);
242 
243 	aead_request_set_ad(req, AAD_LEN);
244 	aead_request_set_tfm(req, crypto->tfm);
245 	aead_request_set_callback(req, 0, crypto_req_done, &wait);
246 
247 	aead_request_set_crypt(req, src, dst, len, crypto->iv);
248 	ret = crypto_wait_req(enc ? crypto_aead_encrypt(req) : crypto_aead_decrypt(req), &wait);
249 
250 	aead_request_free(req);
251 	return ret;
252 }
253 
254 static int __enc_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
255 			 void *plaintext, size_t len)
256 {
257 	struct snp_guest_crypto *crypto = snp_dev->crypto;
258 	struct snp_guest_msg_hdr *hdr = &msg->hdr;
259 
260 	memset(crypto->iv, 0, crypto->iv_len);
261 	memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
262 
263 	return enc_dec_message(crypto, msg, plaintext, msg->payload, len, true);
264 }
265 
266 static int dec_payload(struct snp_guest_dev *snp_dev, struct snp_guest_msg *msg,
267 		       void *plaintext, size_t len)
268 {
269 	struct snp_guest_crypto *crypto = snp_dev->crypto;
270 	struct snp_guest_msg_hdr *hdr = &msg->hdr;
271 
272 	/* Build IV with response buffer sequence number */
273 	memset(crypto->iv, 0, crypto->iv_len);
274 	memcpy(crypto->iv, &hdr->msg_seqno, sizeof(hdr->msg_seqno));
275 
276 	return enc_dec_message(crypto, msg, msg->payload, plaintext, len, false);
277 }
278 
279 static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload, u32 sz)
280 {
281 	struct snp_guest_crypto *crypto = snp_dev->crypto;
282 	struct snp_guest_msg *resp = &snp_dev->secret_response;
283 	struct snp_guest_msg *req = &snp_dev->secret_request;
284 	struct snp_guest_msg_hdr *req_hdr = &req->hdr;
285 	struct snp_guest_msg_hdr *resp_hdr = &resp->hdr;
286 
287 	dev_dbg(snp_dev->dev, "response [seqno %lld type %d version %d sz %d]\n",
288 		resp_hdr->msg_seqno, resp_hdr->msg_type, resp_hdr->msg_version, resp_hdr->msg_sz);
289 
290 	/* Copy response from shared memory to encrypted memory. */
291 	memcpy(resp, snp_dev->response, sizeof(*resp));
292 
293 	/* Verify that the sequence counter is incremented by 1 */
294 	if (unlikely(resp_hdr->msg_seqno != (req_hdr->msg_seqno + 1)))
295 		return -EBADMSG;
296 
297 	/* Verify response message type and version number. */
298 	if (resp_hdr->msg_type != (req_hdr->msg_type + 1) ||
299 	    resp_hdr->msg_version != req_hdr->msg_version)
300 		return -EBADMSG;
301 
302 	/*
303 	 * If the message size is greater than our buffer length then return
304 	 * an error.
305 	 */
306 	if (unlikely((resp_hdr->msg_sz + crypto->a_len) > sz))
307 		return -EBADMSG;
308 
309 	/* Decrypt the payload */
310 	return dec_payload(snp_dev, resp, payload, resp_hdr->msg_sz + crypto->a_len);
311 }
312 
313 static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8 type,
314 			void *payload, size_t sz)
315 {
316 	struct snp_guest_msg *req = &snp_dev->secret_request;
317 	struct snp_guest_msg_hdr *hdr = &req->hdr;
318 
319 	memset(req, 0, sizeof(*req));
320 
321 	hdr->algo = SNP_AEAD_AES_256_GCM;
322 	hdr->hdr_version = MSG_HDR_VER;
323 	hdr->hdr_sz = sizeof(*hdr);
324 	hdr->msg_type = type;
325 	hdr->msg_version = version;
326 	hdr->msg_seqno = seqno;
327 	hdr->msg_vmpck = vmpck_id;
328 	hdr->msg_sz = sz;
329 
330 	/* Verify the sequence number is non-zero */
331 	if (!hdr->msg_seqno)
332 		return -ENOSR;
333 
334 	dev_dbg(snp_dev->dev, "request [seqno %lld type %d version %d sz %d]\n",
335 		hdr->msg_seqno, hdr->msg_type, hdr->msg_version, hdr->msg_sz);
336 
337 	return __enc_payload(snp_dev, req, payload, sz);
338 }
339 
340 static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
341 				  struct snp_guest_request_ioctl *rio)
342 {
343 	unsigned long req_start = jiffies;
344 	unsigned int override_npages = 0;
345 	u64 override_err = 0;
346 	int rc;
347 
348 retry_request:
349 	/*
350 	 * Call firmware to process the request. In this function the encrypted
351 	 * message enters shared memory with the host. So after this call the
352 	 * sequence number must be incremented or the VMPCK must be deleted to
353 	 * prevent reuse of the IV.
354 	 */
355 	rc = snp_issue_guest_request(exit_code, &snp_dev->input, rio);
356 	switch (rc) {
357 	case -ENOSPC:
358 		/*
359 		 * If the extended guest request fails due to having too
360 		 * small of a certificate data buffer, retry the same
361 		 * guest request without the extended data request in
362 		 * order to increment the sequence number and thus avoid
363 		 * IV reuse.
364 		 */
365 		override_npages = snp_dev->input.data_npages;
366 		exit_code	= SVM_VMGEXIT_GUEST_REQUEST;
367 
368 		/*
369 		 * Override the error to inform callers the given extended
370 		 * request buffer size was too small and give the caller the
371 		 * required buffer size.
372 		 */
373 		override_err = SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN);
374 
375 		/*
376 		 * If this call to the firmware succeeds, the sequence number can
377 		 * be incremented allowing for continued use of the VMPCK. If
378 		 * there is an error reflected in the return value, this value
379 		 * is checked further down and the result will be the deletion
380 		 * of the VMPCK and the error code being propagated back to the
381 		 * user as an ioctl() return code.
382 		 */
383 		goto retry_request;
384 
385 	/*
386 	 * The host may return SNP_GUEST_VMM_ERR_BUSY if the request has been
387 	 * throttled. Retry in the driver to avoid returning and reusing the
388 	 * message sequence number on a different message.
389 	 */
390 	case -EAGAIN:
391 		if (jiffies - req_start > SNP_REQ_MAX_RETRY_DURATION) {
392 			rc = -ETIMEDOUT;
393 			break;
394 		}
395 		schedule_timeout_killable(SNP_REQ_RETRY_DELAY);
396 		goto retry_request;
397 	}
398 
399 	/*
400 	 * Increment the message sequence number. There is no harm in doing
401 	 * this now because decryption uses the value stored in the response
402 	 * structure and any failure will wipe the VMPCK, preventing further
403 	 * use anyway.
404 	 */
405 	snp_inc_msg_seqno(snp_dev);
406 
407 	if (override_err) {
408 		rio->exitinfo2 = override_err;
409 
410 		/*
411 		 * If an extended guest request was issued and the supplied certificate
412 		 * buffer was not large enough, a standard guest request was issued to
413 		 * prevent IV reuse. If the standard request was successful, return -EIO
414 		 * back to the caller as would have originally been returned.
415 		 */
416 		if (!rc && override_err == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
417 			rc = -EIO;
418 	}
419 
420 	if (override_npages)
421 		snp_dev->input.data_npages = override_npages;
422 
423 	return rc;
424 }
425 
426 static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
427 				struct snp_guest_request_ioctl *rio, u8 type,
428 				void *req_buf, size_t req_sz, void *resp_buf,
429 				u32 resp_sz)
430 {
431 	u64 seqno;
432 	int rc;
433 
434 	/* Get message sequence and verify that its a non-zero */
435 	seqno = snp_get_msg_seqno(snp_dev);
436 	if (!seqno)
437 		return -EIO;
438 
439 	/* Clear shared memory's response for the host to populate. */
440 	memset(snp_dev->response, 0, sizeof(struct snp_guest_msg));
441 
442 	/* Encrypt the userspace provided payload in snp_dev->secret_request. */
443 	rc = enc_payload(snp_dev, seqno, rio->msg_version, type, req_buf, req_sz);
444 	if (rc)
445 		return rc;
446 
447 	/*
448 	 * Write the fully encrypted request to the shared unencrypted
449 	 * request page.
450 	 */
451 	memcpy(snp_dev->request, &snp_dev->secret_request,
452 	       sizeof(snp_dev->secret_request));
453 
454 	rc = __handle_guest_request(snp_dev, exit_code, rio);
455 	if (rc) {
456 		if (rc == -EIO &&
457 		    rio->exitinfo2 == SNP_GUEST_VMM_ERR(SNP_GUEST_VMM_ERR_INVALID_LEN))
458 			return rc;
459 
460 		dev_alert(snp_dev->dev,
461 			  "Detected error from ASP request. rc: %d, exitinfo2: 0x%llx\n",
462 			  rc, rio->exitinfo2);
463 
464 		snp_disable_vmpck(snp_dev);
465 		return rc;
466 	}
467 
468 	rc = verify_and_dec_payload(snp_dev, resp_buf, resp_sz);
469 	if (rc) {
470 		dev_alert(snp_dev->dev, "Detected unexpected decode failure from ASP. rc: %d\n", rc);
471 		snp_disable_vmpck(snp_dev);
472 		return rc;
473 	}
474 
475 	return 0;
476 }
477 
478 static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
479 {
480 	struct snp_guest_crypto *crypto = snp_dev->crypto;
481 	struct snp_report_req *req = &snp_dev->req.report;
482 	struct snp_report_resp *resp;
483 	int rc, resp_len;
484 
485 	lockdep_assert_held(&snp_cmd_mutex);
486 
487 	if (!arg->req_data || !arg->resp_data)
488 		return -EINVAL;
489 
490 	if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
491 		return -EFAULT;
492 
493 	/*
494 	 * The intermediate response buffer is used while decrypting the
495 	 * response payload. Make sure that it has enough space to cover the
496 	 * authtag.
497 	 */
498 	resp_len = sizeof(resp->data) + crypto->a_len;
499 	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
500 	if (!resp)
501 		return -ENOMEM;
502 
503 	rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
504 				  SNP_MSG_REPORT_REQ, req, sizeof(*req), resp->data,
505 				  resp_len);
506 	if (rc)
507 		goto e_free;
508 
509 	if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
510 		rc = -EFAULT;
511 
512 e_free:
513 	kfree(resp);
514 	return rc;
515 }
516 
517 static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
518 {
519 	struct snp_derived_key_req *req = &snp_dev->req.derived_key;
520 	struct snp_guest_crypto *crypto = snp_dev->crypto;
521 	struct snp_derived_key_resp resp = {0};
522 	int rc, resp_len;
523 	/* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
524 	u8 buf[64 + 16];
525 
526 	lockdep_assert_held(&snp_cmd_mutex);
527 
528 	if (!arg->req_data || !arg->resp_data)
529 		return -EINVAL;
530 
531 	/*
532 	 * The intermediate response buffer is used while decrypting the
533 	 * response payload. Make sure that it has enough space to cover the
534 	 * authtag.
535 	 */
536 	resp_len = sizeof(resp.data) + crypto->a_len;
537 	if (sizeof(buf) < resp_len)
538 		return -ENOMEM;
539 
540 	if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
541 		return -EFAULT;
542 
543 	rc = handle_guest_request(snp_dev, SVM_VMGEXIT_GUEST_REQUEST, arg,
544 				  SNP_MSG_KEY_REQ, req, sizeof(*req), buf, resp_len);
545 	if (rc)
546 		return rc;
547 
548 	memcpy(resp.data, buf, sizeof(resp.data));
549 	if (copy_to_user((void __user *)arg->resp_data, &resp, sizeof(resp)))
550 		rc = -EFAULT;
551 
552 	/* The response buffer contains the sensitive data, explicitly clear it. */
553 	memzero_explicit(buf, sizeof(buf));
554 	memzero_explicit(&resp, sizeof(resp));
555 	return rc;
556 }
557 
558 static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_ioctl *arg)
559 {
560 	struct snp_ext_report_req *req = &snp_dev->req.ext_report;
561 	struct snp_guest_crypto *crypto = snp_dev->crypto;
562 	struct snp_report_resp *resp;
563 	int ret, npages = 0, resp_len;
564 
565 	lockdep_assert_held(&snp_cmd_mutex);
566 
567 	if (!arg->req_data || !arg->resp_data)
568 		return -EINVAL;
569 
570 	if (copy_from_user(req, (void __user *)arg->req_data, sizeof(*req)))
571 		return -EFAULT;
572 
573 	/* userspace does not want certificate data */
574 	if (!req->certs_len || !req->certs_address)
575 		goto cmd;
576 
577 	if (req->certs_len > SEV_FW_BLOB_MAX_SIZE ||
578 	    !IS_ALIGNED(req->certs_len, PAGE_SIZE))
579 		return -EINVAL;
580 
581 	if (!access_ok((const void __user *)req->certs_address, req->certs_len))
582 		return -EFAULT;
583 
584 	/*
585 	 * Initialize the intermediate buffer with all zeros. This buffer
586 	 * is used in the guest request message to get the certs blob from
587 	 * the host. If host does not supply any certs in it, then copy
588 	 * zeros to indicate that certificate data was not provided.
589 	 */
590 	memset(snp_dev->certs_data, 0, req->certs_len);
591 	npages = req->certs_len >> PAGE_SHIFT;
592 cmd:
593 	/*
594 	 * The intermediate response buffer is used while decrypting the
595 	 * response payload. Make sure that it has enough space to cover the
596 	 * authtag.
597 	 */
598 	resp_len = sizeof(resp->data) + crypto->a_len;
599 	resp = kzalloc(resp_len, GFP_KERNEL_ACCOUNT);
600 	if (!resp)
601 		return -ENOMEM;
602 
603 	snp_dev->input.data_npages = npages;
604 	ret = handle_guest_request(snp_dev, SVM_VMGEXIT_EXT_GUEST_REQUEST, arg,
605 				   SNP_MSG_REPORT_REQ, &req->data,
606 				   sizeof(req->data), resp->data, resp_len);
607 
608 	/* If certs length is invalid then copy the returned length */
609 	if (arg->vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN) {
610 		req->certs_len = snp_dev->input.data_npages << PAGE_SHIFT;
611 
612 		if (copy_to_user((void __user *)arg->req_data, req, sizeof(*req)))
613 			ret = -EFAULT;
614 	}
615 
616 	if (ret)
617 		goto e_free;
618 
619 	if (npages &&
620 	    copy_to_user((void __user *)req->certs_address, snp_dev->certs_data,
621 			 req->certs_len)) {
622 		ret = -EFAULT;
623 		goto e_free;
624 	}
625 
626 	if (copy_to_user((void __user *)arg->resp_data, resp, sizeof(*resp)))
627 		ret = -EFAULT;
628 
629 e_free:
630 	kfree(resp);
631 	return ret;
632 }
633 
634 static long snp_guest_ioctl(struct file *file, unsigned int ioctl, unsigned long arg)
635 {
636 	struct snp_guest_dev *snp_dev = to_snp_dev(file);
637 	void __user *argp = (void __user *)arg;
638 	struct snp_guest_request_ioctl input;
639 	int ret = -ENOTTY;
640 
641 	if (copy_from_user(&input, argp, sizeof(input)))
642 		return -EFAULT;
643 
644 	input.exitinfo2 = 0xff;
645 
646 	/* Message version must be non-zero */
647 	if (!input.msg_version)
648 		return -EINVAL;
649 
650 	mutex_lock(&snp_cmd_mutex);
651 
652 	/* Check if the VMPCK is not empty */
653 	if (is_vmpck_empty(snp_dev)) {
654 		dev_err_ratelimited(snp_dev->dev, "VMPCK is disabled\n");
655 		mutex_unlock(&snp_cmd_mutex);
656 		return -ENOTTY;
657 	}
658 
659 	switch (ioctl) {
660 	case SNP_GET_REPORT:
661 		ret = get_report(snp_dev, &input);
662 		break;
663 	case SNP_GET_DERIVED_KEY:
664 		ret = get_derived_key(snp_dev, &input);
665 		break;
666 	case SNP_GET_EXT_REPORT:
667 		ret = get_ext_report(snp_dev, &input);
668 		break;
669 	default:
670 		break;
671 	}
672 
673 	mutex_unlock(&snp_cmd_mutex);
674 
675 	if (input.exitinfo2 && copy_to_user(argp, &input, sizeof(input)))
676 		return -EFAULT;
677 
678 	return ret;
679 }
680 
681 static void free_shared_pages(void *buf, size_t sz)
682 {
683 	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
684 	int ret;
685 
686 	if (!buf)
687 		return;
688 
689 	ret = set_memory_encrypted((unsigned long)buf, npages);
690 	if (ret) {
691 		WARN_ONCE(ret, "failed to restore encryption mask (leak it)\n");
692 		return;
693 	}
694 
695 	__free_pages(virt_to_page(buf), get_order(sz));
696 }
697 
698 static void *alloc_shared_pages(struct device *dev, size_t sz)
699 {
700 	unsigned int npages = PAGE_ALIGN(sz) >> PAGE_SHIFT;
701 	struct page *page;
702 	int ret;
703 
704 	page = alloc_pages(GFP_KERNEL_ACCOUNT, get_order(sz));
705 	if (!page)
706 		return NULL;
707 
708 	ret = set_memory_decrypted((unsigned long)page_address(page), npages);
709 	if (ret) {
710 		dev_err(dev, "failed to mark page shared, ret=%d\n", ret);
711 		__free_pages(page, get_order(sz));
712 		return NULL;
713 	}
714 
715 	return page_address(page);
716 }
717 
718 static const struct file_operations snp_guest_fops = {
719 	.owner	= THIS_MODULE,
720 	.unlocked_ioctl = snp_guest_ioctl,
721 };
722 
723 static u8 *get_vmpck(int id, struct snp_secrets_page_layout *layout, u32 **seqno)
724 {
725 	u8 *key = NULL;
726 
727 	switch (id) {
728 	case 0:
729 		*seqno = &layout->os_area.msg_seqno_0;
730 		key = layout->vmpck0;
731 		break;
732 	case 1:
733 		*seqno = &layout->os_area.msg_seqno_1;
734 		key = layout->vmpck1;
735 		break;
736 	case 2:
737 		*seqno = &layout->os_area.msg_seqno_2;
738 		key = layout->vmpck2;
739 		break;
740 	case 3:
741 		*seqno = &layout->os_area.msg_seqno_3;
742 		key = layout->vmpck3;
743 		break;
744 	default:
745 		break;
746 	}
747 
748 	return key;
749 }
750 
751 static int __init sev_guest_probe(struct platform_device *pdev)
752 {
753 	struct snp_secrets_page_layout *layout;
754 	struct sev_guest_platform_data *data;
755 	struct device *dev = &pdev->dev;
756 	struct snp_guest_dev *snp_dev;
757 	struct miscdevice *misc;
758 	void __iomem *mapping;
759 	int ret;
760 
761 	if (!cc_platform_has(CC_ATTR_GUEST_SEV_SNP))
762 		return -ENODEV;
763 
764 	if (!dev->platform_data)
765 		return -ENODEV;
766 
767 	data = (struct sev_guest_platform_data *)dev->platform_data;
768 	mapping = ioremap_encrypted(data->secrets_gpa, PAGE_SIZE);
769 	if (!mapping)
770 		return -ENODEV;
771 
772 	layout = (__force void *)mapping;
773 
774 	ret = -ENOMEM;
775 	snp_dev = devm_kzalloc(&pdev->dev, sizeof(struct snp_guest_dev), GFP_KERNEL);
776 	if (!snp_dev)
777 		goto e_unmap;
778 
779 	ret = -EINVAL;
780 	snp_dev->vmpck = get_vmpck(vmpck_id, layout, &snp_dev->os_area_msg_seqno);
781 	if (!snp_dev->vmpck) {
782 		dev_err(dev, "invalid vmpck id %d\n", vmpck_id);
783 		goto e_unmap;
784 	}
785 
786 	/* Verify that VMPCK is not zero. */
787 	if (is_vmpck_empty(snp_dev)) {
788 		dev_err(dev, "vmpck id %d is null\n", vmpck_id);
789 		goto e_unmap;
790 	}
791 
792 	platform_set_drvdata(pdev, snp_dev);
793 	snp_dev->dev = dev;
794 	snp_dev->layout = layout;
795 
796 	/* Allocate the shared page used for the request and response message. */
797 	snp_dev->request = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
798 	if (!snp_dev->request)
799 		goto e_unmap;
800 
801 	snp_dev->response = alloc_shared_pages(dev, sizeof(struct snp_guest_msg));
802 	if (!snp_dev->response)
803 		goto e_free_request;
804 
805 	snp_dev->certs_data = alloc_shared_pages(dev, SEV_FW_BLOB_MAX_SIZE);
806 	if (!snp_dev->certs_data)
807 		goto e_free_response;
808 
809 	ret = -EIO;
810 	snp_dev->crypto = init_crypto(snp_dev, snp_dev->vmpck, VMPCK_KEY_LEN);
811 	if (!snp_dev->crypto)
812 		goto e_free_cert_data;
813 
814 	misc = &snp_dev->misc;
815 	misc->minor = MISC_DYNAMIC_MINOR;
816 	misc->name = DEVICE_NAME;
817 	misc->fops = &snp_guest_fops;
818 
819 	/* initial the input address for guest request */
820 	snp_dev->input.req_gpa = __pa(snp_dev->request);
821 	snp_dev->input.resp_gpa = __pa(snp_dev->response);
822 	snp_dev->input.data_gpa = __pa(snp_dev->certs_data);
823 
824 	ret =  misc_register(misc);
825 	if (ret)
826 		goto e_free_cert_data;
827 
828 	dev_info(dev, "Initialized SEV guest driver (using vmpck_id %d)\n", vmpck_id);
829 	return 0;
830 
831 e_free_cert_data:
832 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
833 e_free_response:
834 	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
835 e_free_request:
836 	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
837 e_unmap:
838 	iounmap(mapping);
839 	return ret;
840 }
841 
842 static int __exit sev_guest_remove(struct platform_device *pdev)
843 {
844 	struct snp_guest_dev *snp_dev = platform_get_drvdata(pdev);
845 
846 	free_shared_pages(snp_dev->certs_data, SEV_FW_BLOB_MAX_SIZE);
847 	free_shared_pages(snp_dev->response, sizeof(struct snp_guest_msg));
848 	free_shared_pages(snp_dev->request, sizeof(struct snp_guest_msg));
849 	deinit_crypto(snp_dev->crypto);
850 	misc_deregister(&snp_dev->misc);
851 
852 	return 0;
853 }
854 
855 /*
856  * This driver is meant to be a common SEV guest interface driver and to
857  * support any SEV guest API. As such, even though it has been introduced
858  * with the SEV-SNP support, it is named "sev-guest".
859  */
860 static struct platform_driver sev_guest_driver = {
861 	.remove		= __exit_p(sev_guest_remove),
862 	.driver		= {
863 		.name = "sev-guest",
864 	},
865 };
866 
867 module_platform_driver_probe(sev_guest_driver, sev_guest_probe);
868 
869 MODULE_AUTHOR("Brijesh Singh <brijesh.singh@amd.com>");
870 MODULE_LICENSE("GPL");
871 MODULE_VERSION("1.0.0");
872 MODULE_DESCRIPTION("AMD SEV Guest Driver");
873 MODULE_ALIAS("platform:sev-guest");
874