1 // SPDX-License-Identifier: MIT
2 
3 /*
4  * Copyright © 2019 Intel Corporation
5  */
6 
7 #include <linux/delay.h>
8 #include <linux/dma-fence.h>
9 #include <linux/dma-fence-chain.h>
10 #include <linux/kernel.h>
11 #include <linux/kthread.h>
12 #include <linux/mm.h>
13 #include <linux/sched/signal.h>
14 #include <linux/slab.h>
15 #include <linux/spinlock.h>
16 #include <linux/random.h>
17 
18 #include "selftest.h"
19 
/* Number of links in the long chains used by the race and wait subtests. */
#define CHAIN_SZ (4 * 1024)

/* Slab cache backing every mock fence; created and destroyed by dma_fence_chain(). */
static struct kmem_cache *slab_fences;
23 
24 static inline struct mock_fence {
25 	struct dma_fence base;
26 	spinlock_t lock;
27 } *to_mock_fence(struct dma_fence *f) {
28 	return container_of(f, struct mock_fence, base);
29 }
30 
/* Name reported as both driver and timeline name of every mock fence. */
static const char *mock_name(struct dma_fence *f)
{
	static const char name[] = "mock";

	return name;
}
35 
36 static void mock_fence_release(struct dma_fence *f)
37 {
38 	kmem_cache_free(slab_fences, to_mock_fence(f));
39 }
40 
41 static const struct dma_fence_ops mock_ops = {
42 	.get_driver_name = mock_name,
43 	.get_timeline_name = mock_name,
44 	.release = mock_fence_release,
45 };
46 
47 static struct dma_fence *mock_fence(void)
48 {
49 	struct mock_fence *f;
50 
51 	f = kmem_cache_alloc(slab_fences, GFP_KERNEL);
52 	if (!f)
53 		return NULL;
54 
55 	spin_lock_init(&f->lock);
56 	dma_fence_init(&f->base, &mock_ops, &f->lock, 0, 0);
57 
58 	return &f->base;
59 }
60 
61 static struct dma_fence *mock_chain(struct dma_fence *prev,
62 				    struct dma_fence *fence,
63 				    u64 seqno)
64 {
65 	struct dma_fence_chain *f;
66 
67 	f = dma_fence_chain_alloc();
68 	if (!f)
69 		return NULL;
70 
71 	dma_fence_chain_init(f, dma_fence_get(prev), dma_fence_get(fence),
72 			     seqno);
73 
74 	return &f->base;
75 }
76 
77 static int sanitycheck(void *arg)
78 {
79 	struct dma_fence *f, *chain;
80 	int err = 0;
81 
82 	f = mock_fence();
83 	if (!f)
84 		return -ENOMEM;
85 
86 	chain = mock_chain(NULL, f, 1);
87 	if (chain)
88 		dma_fence_enable_sw_signaling(chain);
89 	else
90 		err = -ENOMEM;
91 
92 	dma_fence_signal(f);
93 	dma_fence_put(f);
94 
95 	dma_fence_put(chain);
96 
97 	return err;
98 }
99 
/*
 * A linear chain of mock fences, built by fence_chains_init() and torn
 * down by fence_chains_fini().
 */
struct fence_chains {
	/* Number of links; set to the requested count on successful init. */
	unsigned int chain_length;
	/*
	 * Mock fences; fences[i] is the one wrapped by chains[i] (until
	 * randomise_fences() shuffles this array).
	 */
	struct dma_fence **fences;
	/* Chain nodes; chains[i] is linked after chains[i - 1]. */
	struct dma_fence **chains;

	/* Most recently added chain node, i.e. chains[chain_length - 1]. */
	struct dma_fence *tail;
};
107 
/* seqno_fn for consecutive seqnos: link i gets i + 1, so seqnos run 1..count. */
static uint64_t seqno_inc(unsigned int i)
{
	unsigned int seqno = i + 1;

	return seqno;
}
112 
113 static int fence_chains_init(struct fence_chains *fc, unsigned int count,
114 			     uint64_t (*seqno_fn)(unsigned int))
115 {
116 	unsigned int i;
117 	int err = 0;
118 
119 	fc->chains = kvmalloc_array(count, sizeof(*fc->chains),
120 				    GFP_KERNEL | __GFP_ZERO);
121 	if (!fc->chains)
122 		return -ENOMEM;
123 
124 	fc->fences = kvmalloc_array(count, sizeof(*fc->fences),
125 				    GFP_KERNEL | __GFP_ZERO);
126 	if (!fc->fences) {
127 		err = -ENOMEM;
128 		goto err_chains;
129 	}
130 
131 	fc->tail = NULL;
132 	for (i = 0; i < count; i++) {
133 		fc->fences[i] = mock_fence();
134 		if (!fc->fences[i]) {
135 			err = -ENOMEM;
136 			goto unwind;
137 		}
138 
139 		fc->chains[i] = mock_chain(fc->tail,
140 					   fc->fences[i],
141 					   seqno_fn(i));
142 		if (!fc->chains[i]) {
143 			err = -ENOMEM;
144 			goto unwind;
145 		}
146 
147 		fc->tail = fc->chains[i];
148 
149 		dma_fence_enable_sw_signaling(fc->chains[i]);
150 	}
151 
152 	fc->chain_length = i;
153 	return 0;
154 
155 unwind:
156 	for (i = 0; i < count; i++) {
157 		dma_fence_put(fc->fences[i]);
158 		dma_fence_put(fc->chains[i]);
159 	}
160 	kvfree(fc->fences);
161 err_chains:
162 	kvfree(fc->chains);
163 	return err;
164 }
165 
166 static void fence_chains_fini(struct fence_chains *fc)
167 {
168 	unsigned int i;
169 
170 	for (i = 0; i < fc->chain_length; i++) {
171 		dma_fence_signal(fc->fences[i]);
172 		dma_fence_put(fc->fences[i]);
173 	}
174 	kvfree(fc->fences);
175 
176 	for (i = 0; i < fc->chain_length; i++)
177 		dma_fence_put(fc->chains[i]);
178 	kvfree(fc->chains);
179 }
180 
/*
 * Exercise dma_fence_chain_find_seqno() on an unsignaled 64-link chain
 * whose nodes carry seqnos 1..64 (seqno_inc). For every link i it checks:
 *  - a lookup of i + 1 from the tail returns chains[i];
 *  - a lookup of i + 1 (its own seqno) from chains[i] returns chains[i];
 *  - a lookup of i + 2 (newer than chains[i]) from chains[i] fails;
 *  - a lookup of i from chains[i] succeeds and, for i > 0, returns
 *    chains[i - 1].
 * A lookup of seqno 0 from the tail must also succeed.
 *
 * Every lookup uses the get / find / put pattern, which presumably means
 * dma_fence_chain_find_seqno() consumes the reference passed in and hands
 * back one on the result. Comparing @fence after the put relies on fc
 * still holding its own reference on every chain node.
 */
static int find_seqno(void *arg)
{
	struct fence_chains fc;
	struct dma_fence *fence;
	int err;
	int i;

	err = fence_chains_init(&fc, 64, seqno_inc);
	if (err)
		return err;

	/* Seqno 0 must be accepted without error. */
	fence = dma_fence_get(fc.tail);
	err = dma_fence_chain_find_seqno(&fence, 0);
	dma_fence_put(fence);
	if (err) {
		pr_err("Reported %d for find_seqno(0)!\n", err);
		goto err;
	}

	for (i = 0; i < fc.chain_length; i++) {
		/* From the tail, seqno i + 1 must resolve to chains[i]. */
		fence = dma_fence_get(fc.tail);
		err = dma_fence_chain_find_seqno(&fence, i + 1);
		dma_fence_put(fence);
		if (err) {
			pr_err("Reported %d for find_seqno(%d:%d)!\n",
			       err, fc.chain_length + 1, i + 1);
			goto err;
		}
		if (fence != fc.chains[i]) {
			pr_err("Incorrect fence reported by find_seqno(%d:%d)\n",
			       fc.chain_length + 1, i + 1);
			err = -EINVAL;
			goto err;
		}

		/* Looking up a node's own seqno must return the node itself. */
		dma_fence_get(fence);
		err = dma_fence_chain_find_seqno(&fence, i + 1);
		dma_fence_put(fence);
		if (err) {
			pr_err("Error reported for finding self\n");
			goto err;
		}
		if (fence != fc.chains[i]) {
			pr_err("Incorrect fence reported by find self\n");
			err = -EINVAL;
			goto err;
		}

		/*
		 * A seqno newer than chains[i] must be rejected. The next step
		 * assumes the failed lookup left @fence on chains[i].
		 */
		dma_fence_get(fence);
		err = dma_fence_chain_find_seqno(&fence, i + 2);
		dma_fence_put(fence);
		if (!err) {
			pr_err("Error not reported for future fence: find_seqno(%d:%d)!\n",
			       i + 1, i + 2);
			err = -EINVAL;
			goto err;
		}

		/*
		 * The previous seqno must succeed and, when a previous link
		 * exists (i > 0), return chains[i - 1]. For i == 0 this is a
		 * seqno-0 lookup, which is only required not to fail.
		 */
		dma_fence_get(fence);
		err = dma_fence_chain_find_seqno(&fence, i);
		dma_fence_put(fence);
		if (err) {
			pr_err("Error reported for previous fence!\n");
			goto err;
		}
		if (i > 0 && fence != fc.chains[i - 1]) {
			pr_err("Incorrect fence reported by find_seqno(%d:%d)\n",
			       i + 1, i);
			err = -EINVAL;
			goto err;
		}
	}

err:
	fence_chains_fini(&fc);
	return err;
}
258 
259 static int find_signaled(void *arg)
260 {
261 	struct fence_chains fc;
262 	struct dma_fence *fence;
263 	int err;
264 
265 	err = fence_chains_init(&fc, 2, seqno_inc);
266 	if (err)
267 		return err;
268 
269 	dma_fence_signal(fc.fences[0]);
270 
271 	fence = dma_fence_get(fc.tail);
272 	err = dma_fence_chain_find_seqno(&fence, 1);
273 	dma_fence_put(fence);
274 	if (err) {
275 		pr_err("Reported %d for find_seqno()!\n", err);
276 		goto err;
277 	}
278 
279 	if (fence && fence != fc.chains[0]) {
280 		pr_err("Incorrect chain-fence.seqno:%lld reported for completed seqno:1\n",
281 		       fence->seqno);
282 
283 		dma_fence_get(fence);
284 		err = dma_fence_chain_find_seqno(&fence, 1);
285 		dma_fence_put(fence);
286 		if (err)
287 			pr_err("Reported %d for finding self!\n", err);
288 
289 		err = -EINVAL;
290 	}
291 
292 err:
293 	fence_chains_fini(&fc);
294 	return err;
295 }
296 
297 static int find_out_of_order(void *arg)
298 {
299 	struct fence_chains fc;
300 	struct dma_fence *fence;
301 	int err;
302 
303 	err = fence_chains_init(&fc, 3, seqno_inc);
304 	if (err)
305 		return err;
306 
307 	dma_fence_signal(fc.fences[1]);
308 
309 	fence = dma_fence_get(fc.tail);
310 	err = dma_fence_chain_find_seqno(&fence, 2);
311 	dma_fence_put(fence);
312 	if (err) {
313 		pr_err("Reported %d for find_seqno()!\n", err);
314 		goto err;
315 	}
316 
317 	/*
318 	 * We signaled the middle fence (2) of the 1-2-3 chain. The behavior
319 	 * of the dma-fence-chain is to make us wait for all the fences up to
320 	 * the point we want. Since fence 1 is still not signaled, this what
321 	 * we should get as fence to wait upon (fence 2 being garbage
322 	 * collected during the traversal of the chain).
323 	 */
324 	if (fence != fc.chains[0]) {
325 		pr_err("Incorrect chain-fence.seqno:%lld reported for completed seqno:2\n",
326 		       fence ? fence->seqno : 0);
327 
328 		err = -EINVAL;
329 	}
330 
331 err:
332 	fence_chains_fini(&fc);
333 	return err;
334 }
335 
/* seqno_fn leaving a one-seqno gap between links: link i gets 2 * (i + 1). */
static uint64_t seqno_inc2(unsigned int i)
{
	return 2 * (i + 1);
}
340 
341 static int find_gap(void *arg)
342 {
343 	struct fence_chains fc;
344 	struct dma_fence *fence;
345 	int err;
346 	int i;
347 
348 	err = fence_chains_init(&fc, 64, seqno_inc2);
349 	if (err)
350 		return err;
351 
352 	for (i = 0; i < fc.chain_length; i++) {
353 		fence = dma_fence_get(fc.tail);
354 		err = dma_fence_chain_find_seqno(&fence, 2 * i + 1);
355 		dma_fence_put(fence);
356 		if (err) {
357 			pr_err("Reported %d for find_seqno(%d:%d)!\n",
358 			       err, fc.chain_length + 1, 2 * i + 1);
359 			goto err;
360 		}
361 		if (fence != fc.chains[i]) {
362 			pr_err("Incorrect fence.seqno:%lld reported by find_seqno(%d:%d)\n",
363 			       fence->seqno,
364 			       fc.chain_length + 1,
365 			       2 * i + 1);
366 			err = -EINVAL;
367 			goto err;
368 		}
369 
370 		dma_fence_get(fence);
371 		err = dma_fence_chain_find_seqno(&fence, 2 * i + 2);
372 		dma_fence_put(fence);
373 		if (err) {
374 			pr_err("Error reported for finding self\n");
375 			goto err;
376 		}
377 		if (fence != fc.chains[i]) {
378 			pr_err("Incorrect fence reported by find self\n");
379 			err = -EINVAL;
380 			goto err;
381 		}
382 	}
383 
384 err:
385 	fence_chains_fini(&fc);
386 	return err;
387 }
388 
/* State shared between find_race() and its __find_race() worker threads. */
struct find_race {
	/* The chain every worker searches and signals. */
	struct fence_chains fc;
	/* Started workers not yet exited; each worker decrements it on exit. */
	atomic_t children;
};
393 
394 static int __find_race(void *arg)
395 {
396 	struct find_race *data = arg;
397 	int err = 0;
398 
399 	while (!kthread_should_stop()) {
400 		struct dma_fence *fence = dma_fence_get(data->fc.tail);
401 		int seqno;
402 
403 		seqno = get_random_u32_inclusive(1, data->fc.chain_length);
404 
405 		err = dma_fence_chain_find_seqno(&fence, seqno);
406 		if (err) {
407 			pr_err("Failed to find fence seqno:%d\n",
408 			       seqno);
409 			dma_fence_put(fence);
410 			break;
411 		}
412 		if (!fence)
413 			goto signal;
414 
415 		/*
416 		 * We can only find ourselves if we are on fence we were
417 		 * looking for.
418 		 */
419 		if (fence->seqno == seqno) {
420 			err = dma_fence_chain_find_seqno(&fence, seqno);
421 			if (err) {
422 				pr_err("Reported an invalid fence for find-self:%d\n",
423 				       seqno);
424 				dma_fence_put(fence);
425 				break;
426 			}
427 		}
428 
429 		dma_fence_put(fence);
430 
431 signal:
432 		seqno = get_random_u32_below(data->fc.chain_length - 1);
433 		dma_fence_signal(data->fc.fences[seqno]);
434 		cond_resched();
435 	}
436 
437 	if (atomic_dec_and_test(&data->children))
438 		wake_up_var(&data->children);
439 	return err;
440 }
441 
442 static int find_race(void *arg)
443 {
444 	struct find_race data;
445 	int ncpus = num_online_cpus();
446 	struct task_struct **threads;
447 	unsigned long count;
448 	int err;
449 	int i;
450 
451 	err = fence_chains_init(&data.fc, CHAIN_SZ, seqno_inc);
452 	if (err)
453 		return err;
454 
455 	threads = kmalloc_array(ncpus, sizeof(*threads), GFP_KERNEL);
456 	if (!threads) {
457 		err = -ENOMEM;
458 		goto err;
459 	}
460 
461 	atomic_set(&data.children, 0);
462 	for (i = 0; i < ncpus; i++) {
463 		threads[i] = kthread_run(__find_race, &data, "dmabuf/%d", i);
464 		if (IS_ERR(threads[i])) {
465 			ncpus = i;
466 			break;
467 		}
468 		atomic_inc(&data.children);
469 		get_task_struct(threads[i]);
470 	}
471 
472 	wait_var_event_timeout(&data.children,
473 			       !atomic_read(&data.children),
474 			       5 * HZ);
475 
476 	for (i = 0; i < ncpus; i++) {
477 		int ret;
478 
479 		ret = kthread_stop_put(threads[i]);
480 		if (ret && !err)
481 			err = ret;
482 	}
483 	kfree(threads);
484 
485 	count = 0;
486 	for (i = 0; i < data.fc.chain_length; i++)
487 		if (dma_fence_is_signaled(data.fc.fences[i]))
488 			count++;
489 	pr_info("Completed %lu cycles\n", count);
490 
491 err:
492 	fence_chains_fini(&data.fc);
493 	return err;
494 }
495 
496 static int signal_forward(void *arg)
497 {
498 	struct fence_chains fc;
499 	int err;
500 	int i;
501 
502 	err = fence_chains_init(&fc, 64, seqno_inc);
503 	if (err)
504 		return err;
505 
506 	for (i = 0; i < fc.chain_length; i++) {
507 		dma_fence_signal(fc.fences[i]);
508 
509 		if (!dma_fence_is_signaled(fc.chains[i])) {
510 			pr_err("chain[%d] not signaled!\n", i);
511 			err = -EINVAL;
512 			goto err;
513 		}
514 
515 		if (i + 1 < fc.chain_length &&
516 		    dma_fence_is_signaled(fc.chains[i + 1])) {
517 			pr_err("chain[%d] is signaled!\n", i);
518 			err = -EINVAL;
519 			goto err;
520 		}
521 	}
522 
523 err:
524 	fence_chains_fini(&fc);
525 	return err;
526 }
527 
528 static int signal_backward(void *arg)
529 {
530 	struct fence_chains fc;
531 	int err;
532 	int i;
533 
534 	err = fence_chains_init(&fc, 64, seqno_inc);
535 	if (err)
536 		return err;
537 
538 	for (i = fc.chain_length; i--; ) {
539 		dma_fence_signal(fc.fences[i]);
540 
541 		if (i > 0 && dma_fence_is_signaled(fc.chains[i])) {
542 			pr_err("chain[%d] is signaled!\n", i);
543 			err = -EINVAL;
544 			goto err;
545 		}
546 	}
547 
548 	for (i = 0; i < fc.chain_length; i++) {
549 		if (!dma_fence_is_signaled(fc.chains[i])) {
550 			pr_err("chain[%d] was not signaled!\n", i);
551 			err = -EINVAL;
552 			goto err;
553 		}
554 	}
555 
556 err:
557 	fence_chains_fini(&fc);
558 	return err;
559 }
560 
561 static int __wait_fence_chains(void *arg)
562 {
563 	struct fence_chains *fc = arg;
564 
565 	if (dma_fence_wait(fc->tail, false))
566 		return -EIO;
567 
568 	return 0;
569 }
570 
571 static int wait_forward(void *arg)
572 {
573 	struct fence_chains fc;
574 	struct task_struct *tsk;
575 	int err;
576 	int i;
577 
578 	err = fence_chains_init(&fc, CHAIN_SZ, seqno_inc);
579 	if (err)
580 		return err;
581 
582 	tsk = kthread_run(__wait_fence_chains, &fc, "dmabuf/wait");
583 	if (IS_ERR(tsk)) {
584 		err = PTR_ERR(tsk);
585 		goto err;
586 	}
587 	get_task_struct(tsk);
588 	yield_to(tsk, true);
589 
590 	for (i = 0; i < fc.chain_length; i++)
591 		dma_fence_signal(fc.fences[i]);
592 
593 	err = kthread_stop_put(tsk);
594 
595 err:
596 	fence_chains_fini(&fc);
597 	return err;
598 }
599 
600 static int wait_backward(void *arg)
601 {
602 	struct fence_chains fc;
603 	struct task_struct *tsk;
604 	int err;
605 	int i;
606 
607 	err = fence_chains_init(&fc, CHAIN_SZ, seqno_inc);
608 	if (err)
609 		return err;
610 
611 	tsk = kthread_run(__wait_fence_chains, &fc, "dmabuf/wait");
612 	if (IS_ERR(tsk)) {
613 		err = PTR_ERR(tsk);
614 		goto err;
615 	}
616 	get_task_struct(tsk);
617 	yield_to(tsk, true);
618 
619 	for (i = fc.chain_length; i--; )
620 		dma_fence_signal(fc.fences[i]);
621 
622 	err = kthread_stop_put(tsk);
623 
624 err:
625 	fence_chains_fini(&fc);
626 	return err;
627 }
628 
629 static void randomise_fences(struct fence_chains *fc)
630 {
631 	unsigned int count = fc->chain_length;
632 
633 	/* Fisher-Yates shuffle courtesy of Knuth */
634 	while (--count) {
635 		unsigned int swp;
636 
637 		swp = get_random_u32_below(count + 1);
638 		if (swp == count)
639 			continue;
640 
641 		swap(fc->fences[count], fc->fences[swp]);
642 	}
643 }
644 
645 static int wait_random(void *arg)
646 {
647 	struct fence_chains fc;
648 	struct task_struct *tsk;
649 	int err;
650 	int i;
651 
652 	err = fence_chains_init(&fc, CHAIN_SZ, seqno_inc);
653 	if (err)
654 		return err;
655 
656 	randomise_fences(&fc);
657 
658 	tsk = kthread_run(__wait_fence_chains, &fc, "dmabuf/wait");
659 	if (IS_ERR(tsk)) {
660 		err = PTR_ERR(tsk);
661 		goto err;
662 	}
663 	get_task_struct(tsk);
664 	yield_to(tsk, true);
665 
666 	for (i = 0; i < fc.chain_length; i++)
667 		dma_fence_signal(fc.fences[i]);
668 
669 	err = kthread_stop_put(tsk);
670 
671 err:
672 	fence_chains_fini(&fc);
673 	return err;
674 }
675 
676 int dma_fence_chain(void)
677 {
678 	static const struct subtest tests[] = {
679 		SUBTEST(sanitycheck),
680 		SUBTEST(find_seqno),
681 		SUBTEST(find_signaled),
682 		SUBTEST(find_out_of_order),
683 		SUBTEST(find_gap),
684 		SUBTEST(find_race),
685 		SUBTEST(signal_forward),
686 		SUBTEST(signal_backward),
687 		SUBTEST(wait_forward),
688 		SUBTEST(wait_backward),
689 		SUBTEST(wait_random),
690 	};
691 	int ret;
692 
693 	pr_info("sizeof(dma_fence_chain)=%zu\n",
694 		sizeof(struct dma_fence_chain));
695 
696 	slab_fences = KMEM_CACHE(mock_fence,
697 				 SLAB_TYPESAFE_BY_RCU |
698 				 SLAB_HWCACHE_ALIGN);
699 	if (!slab_fences)
700 		return -ENOMEM;
701 
702 	ret = subtests(tests, NULL);
703 
704 	kmem_cache_destroy(slab_fences);
705 	return ret;
706 }
707