1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include <stdlib.h>
7 #include <assert.h>
8 #include "yaksi.h"
9 #include "yaksu.h"
10 #include "yaksuri_cudai.h"
11 
get_num_elements(yaksi_type_s * type)12 static uintptr_t get_num_elements(yaksi_type_s * type)
13 {
14     switch (type->kind) {
15         case YAKSI_TYPE_KIND__BUILTIN:
16             return type->num_contig;
17 
18         case YAKSI_TYPE_KIND__CONTIG:
19             return type->u.contig.count * get_num_elements(type->u.contig.child);
20 
21         case YAKSI_TYPE_KIND__RESIZED:
22             return get_num_elements(type->u.resized.child);
23 
24         case YAKSI_TYPE_KIND__HVECTOR:
25             return type->u.hvector.count * type->u.hvector.blocklength *
26                 get_num_elements(type->u.hvector.child);
27 
28         case YAKSI_TYPE_KIND__BLKHINDX:
29             return type->u.blkhindx.count * type->u.blkhindx.blocklength *
30                 get_num_elements(type->u.blkhindx.child);
31 
32         case YAKSI_TYPE_KIND__HINDEXED:
33             {
34                 uintptr_t nelems = 0;
35                 for (int i = 0; i < type->u.hindexed.count; i++)
36                     nelems += type->u.hindexed.array_of_blocklengths[i];
37                 nelems *= get_num_elements(type->u.hindexed.child);
38                 return nelems;
39             }
40 
41         default:
42             return 0;
43     }
44 }
45 
yaksuri_cudai_type_create_hook(yaksi_type_s * type)46 int yaksuri_cudai_type_create_hook(yaksi_type_s * type)
47 {
48     int rc = YAKSA_SUCCESS;
49 
50     type->backend.cuda.priv = malloc(sizeof(yaksuri_cudai_type_s));
51     YAKSU_ERR_CHKANDJUMP(!type->backend.cuda.priv, rc, YAKSA_ERR__OUT_OF_MEM, fn_fail);
52 
53     yaksuri_cudai_type_s *cuda;
54     cuda = (yaksuri_cudai_type_s *) type->backend.cuda.priv;
55 
56     cuda->num_elements = get_num_elements(type);
57     cuda->md = NULL;
58     pthread_mutex_init(&cuda->mdmutex, NULL);
59 
60     rc = yaksuri_cudai_populate_pupfns(type);
61     YAKSU_ERR_CHECK(rc, fn_fail);
62 
63   fn_exit:
64     return rc;
65   fn_fail:
66     goto fn_exit;
67 }
68 
yaksuri_cudai_type_free_hook(yaksi_type_s * type)69 int yaksuri_cudai_type_free_hook(yaksi_type_s * type)
70 {
71     int rc = YAKSA_SUCCESS;
72     yaksuri_cudai_type_s *cuda = (yaksuri_cudai_type_s *) type->backend.cuda.priv;
73     cudaError_t cerr;
74 
75     pthread_mutex_destroy(&cuda->mdmutex);
76     if (cuda->md) {
77         if (type->kind == YAKSI_TYPE_KIND__BLKHINDX) {
78             assert(cuda->md->u.blkhindx.array_of_displs);
79             cerr = cudaFree((void *) cuda->md->u.blkhindx.array_of_displs);
80             YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
81         } else if (type->kind == YAKSI_TYPE_KIND__HINDEXED) {
82             assert(cuda->md->u.hindexed.array_of_displs);
83             cerr = cudaFree((void *) cuda->md->u.hindexed.array_of_displs);
84             YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
85 
86             assert(cuda->md->u.hindexed.array_of_blocklengths);
87             cerr = cudaFree((void *) cuda->md->u.hindexed.array_of_blocklengths);
88             YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
89         }
90 
91         cerr = cudaFree(cuda->md);
92         YAKSURI_CUDAI_CUDA_ERR_CHKANDJUMP(cerr, rc, fn_fail);
93     }
94 
95     free(cuda);
96 
97   fn_exit:
98     return rc;
99   fn_fail:
100     goto fn_exit;
101 }
102