1 /*
2 * Copyright (C) by Argonne National Laboratory
3 * See COPYRIGHT in top-level directory
4 */
5
6 #include <stdlib.h>
7 #include <assert.h>
8 #include "yaksi.h"
9 #include "yaksu.h"
10 #include "yaksuri_cudai.h"
11
/*
 * Recursively compute the element count recorded for a datatype.
 *
 * - BUILTIN:  the type's num_contig field.
 * - CONTIG:   count times the child's element count.
 * - RESIZED:  the child's element count.
 * - HVECTOR / BLKHINDX: count * blocklength times the child's element count.
 * - HINDEXED: sum of all block lengths times the child's element count.
 *
 * Any other kind yields 0.
 *
 * The count/blocklength factors are widened to uintptr_t before being
 * multiplied so that the product of two narrow integer fields cannot
 * overflow before it is combined with the (uintptr_t) child count.
 */
static uintptr_t get_num_elements(yaksi_type_s * type)
{
    switch (type->kind) {
        case YAKSI_TYPE_KIND__BUILTIN:
            return type->num_contig;

        case YAKSI_TYPE_KIND__CONTIG:
            return (uintptr_t) type->u.contig.count * get_num_elements(type->u.contig.child);

        case YAKSI_TYPE_KIND__RESIZED:
            return get_num_elements(type->u.resized.child);

        case YAKSI_TYPE_KIND__HVECTOR:
            /* widen first: count * blocklength alone may not fit the field type */
            return (uintptr_t) type->u.hvector.count *
                (uintptr_t) type->u.hvector.blocklength *
                get_num_elements(type->u.hvector.child);

        case YAKSI_TYPE_KIND__BLKHINDX:
            /* widen first: count * blocklength alone may not fit the field type */
            return (uintptr_t) type->u.blkhindx.count *
                (uintptr_t) type->u.blkhindx.blocklength *
                get_num_elements(type->u.blkhindx.child);

        case YAKSI_TYPE_KIND__HINDEXED:
            {
                /* blocks may have different lengths, so add them up one by one */
                uintptr_t nelems = 0;
                for (int i = 0; i < type->u.hindexed.count; i++)
                    nelems += type->u.hindexed.array_of_blocklengths[i];
                nelems *= get_num_elements(type->u.hindexed.child);
                return nelems;
            }

        default:
            /* kinds not listed above report zero elements */
            return 0;
    }
}
45
/*
 * CUDA backend hook run when a datatype is created.
 *
 * Allocates the backend-private yaksuri_cudai_type_s and stores it in
 * type->backend.cuda.priv, records the element count computed by
 * get_num_elements(), clears the md pointer, initializes mdmutex, and then
 * calls yaksuri_cudai_populate_pupfns() on the type.
 *
 * Returns YAKSA_SUCCESS on success, YAKSA_ERR__OUT_OF_MEM if the private
 * structure cannot be allocated, or the error code reported by
 * yaksuri_cudai_populate_pupfns().
 *
 * On any failure the private structure is released again and
 * type->backend.cuda.priv is left NULL, so nothing allocated here leaks.
 * NOTE(review): this assumes callers do not run the free hook on a type
 * whose create hook failed -- confirm against the caller.
 */
int yaksuri_cudai_type_create_hook(yaksi_type_s * type)
{
    int rc = YAKSA_SUCCESS;
    /* stays NULL until the allocation has succeeded; fn_fail keys off this */
    yaksuri_cudai_type_s *cuda = NULL;

    type->backend.cuda.priv = malloc(sizeof(yaksuri_cudai_type_s));
    YAKSU_ERR_CHKANDJUMP(!type->backend.cuda.priv, rc, YAKSA_ERR__OUT_OF_MEM, fn_fail);

    cuda = (yaksuri_cudai_type_s *) type->backend.cuda.priv;

    cuda->num_elements = get_num_elements(type);
    cuda->md = NULL;
    pthread_mutex_init(&cuda->mdmutex, NULL);

    rc = yaksuri_cudai_populate_pupfns(type);
    YAKSU_ERR_CHECK(rc, fn_fail);

  fn_exit:
    return rc;
  fn_fail:
    /* the only failure point after allocation comes after the mutex was
     * initialized, so a non-NULL cuda always owns a live mutex here */
    if (cuda) {
        pthread_mutex_destroy(&cuda->mdmutex);
        free(cuda);
        type->backend.cuda.priv = NULL;
    }
    goto fn_exit;
}
68
/*
 * CUDA backend hook run when a datatype is freed.
 *
 * Destroys mdmutex, releases the md structure (plus the displacement /
 * block-length arrays it references for BLKHINDX and HINDEXED types) with
 * cudaFree, frees the host-side private structure, and clears
 * type->backend.cuda.priv so no dangling pointer is left behind.
 *
 * A type whose private pointer is NULL (for example because the create
 * hook could not allocate it) is accepted and treated as having nothing
 * to release.
 *
 * Returns YAKSA_SUCCESS on success. If a cudaFree call fails, the error
 * macro (which is handed the fn_fail label) ends the function early with a
 * failure code; allocations not yet released at that point, including the
 * host structure, stay allocated.
 */
int yaksuri_cudai_type_free_hook(yaksi_type_s * type)
{
    int rc = YAKSA_SUCCESS;
    yaksuri_cudai_type_s *cuda = (yaksuri_cudai_type_s *) type->backend.cuda.priv;
    cudaError_t cerr;

    /* no backend state attached: nothing to tear down */
    if (cuda == NULL)
        goto fn_exit;

    pthread_mutex_destroy(&cuda->mdmutex);
    if (cuda->md) {
        /* release the per-kind arrays before the md structure that holds
         * the pointers to them */
        if (type->kind == YAKSI_TYPE_KIND__BLKHINDX) {
            assert(cuda->md->u.blkhindx.array_of_displs);
            cerr = cudaFree((void *) cuda->md->u.blkhindx.array_of_displs);
            YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
        } else if (type->kind == YAKSI_TYPE_KIND__HINDEXED) {
            assert(cuda->md->u.hindexed.array_of_displs);
            cerr = cudaFree((void *) cuda->md->u.hindexed.array_of_displs);
            YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);

            assert(cuda->md->u.hindexed.array_of_blocklengths);
            cerr = cudaFree((void *) cuda->md->u.hindexed.array_of_blocklengths);
            YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
        }

        cerr = cudaFree(cuda->md);
        YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
    }

    free(cuda);
    /* prevent a stale pointer from being reused or freed a second time */
    type->backend.cuda.priv = NULL;

  fn_exit:
    return rc;
  fn_fail:
    goto fn_exit;
}
102