/******************************************************************************
 * Copyright 1998-2019 Lawrence Livermore National Security, LLC and other
 * HYPRE Project Developers. See the top-level COPYRIGHT file for details.
 *
 * SPDX-License-Identifier: (Apache-2.0 OR MIT)
 ******************************************************************************/

#include "seq_mv.h"
#include "csr_spgemm_device.h"

#if defined(HYPRE_USING_CUDA) || defined(HYPRE_USING_HIP)

/* assume d_c is of length m and contains the "sizes" */
void
hypre_create_ija( HYPRE_Int       m,
                  HYPRE_Int      *d_c,
                  HYPRE_Int      *d_i,
                  HYPRE_Int     **d_j,
                  HYPRE_Complex **d_a,
                  HYPRE_Int      *nnz )
{
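   /* d_i[0] = 0; the inclusive scan of the m sizes in d_c then fills d_i[1..m],
    * so d_i becomes the CSR row-pointer array and d_i[m] is the total number of
    * nonzeros, which is copied back to the host */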
   hypre_Memset(d_i, 0, sizeof(HYPRE_Int), HYPRE_MEMORY_DEVICE);

   HYPRE_THRUST_CALL(inclusive_scan, d_c, d_c + m, d_i + 1);

   hypre_TMemcpy(nnz, d_i + m, HYPRE_Int, 1, HYPRE_MEMORY_HOST, HYPRE_MEMORY_DEVICE);

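   /* optionally allocate the column-index and value arrays of length *nnz */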
   if (d_j)
   {
      *d_j = hypre_TAlloc(HYPRE_Int, *nnz, HYPRE_MEMORY_DEVICE);
   }

   if (d_a)
   {
      *d_a = hypre_TAlloc(HYPRE_Complex, *nnz, HYPRE_MEMORY_DEVICE);
   }
}

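/* Type-1 sizing kernel: rows are mapped to the num_ghash global hash tables in a
 * cyclic fashion (row i goes to table i % num_ghash). Each table is sized to the
 * largest "overflow", next_power_of_2(rnz - SHMEM_HASH_SIZE), over the rows
 * assigned to it. */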
__global__ void
hypre_SpGemmGhashSize1( HYPRE_Int  num_rows,
                        HYPRE_Int *row_id,
                        HYPRE_Int  num_ghash,
                        HYPRE_Int *row_sizes,
                        HYPRE_Int *ghash_sizes,
                        HYPRE_Int  SHMEM_HASH_SIZE )
{
   const HYPRE_Int global_thread_id = hypre_cuda_get_grid_thread_id<1,1>();

   if (global_thread_id >= num_ghash)
   {
      return;
   }

   HYPRE_Int j = 0;

   for (HYPRE_Int i = global_thread_id; i < num_rows; i += num_ghash)
   {
      const HYPRE_Int rid = row_id ? read_only_load(&row_id[i]) : i;
      const HYPRE_Int rnz = read_only_load(&row_sizes[rid]);
      const HYPRE_Int j1 = next_power_of_2(rnz - SHMEM_HASH_SIZE);
      j = hypre_max(j, j1);
   }

   ghash_sizes[global_thread_id] = j;
}

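/* Type-2 sizing kernel: one global hash table per row (indexed by the row id),
 * sized to that row's overflow beyond the shared-memory table,
 * next_power_of_2(rnz - SHMEM_HASH_SIZE). */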
__global__ void
hypre_SpGemmGhashSize2( HYPRE_Int  num_rows,
                        HYPRE_Int *row_id,
                        HYPRE_Int  num_ghash,
                        HYPRE_Int *row_sizes,
                        HYPRE_Int *ghash_sizes,
                        HYPRE_Int  SHMEM_HASH_SIZE )
{
   const HYPRE_Int i = hypre_cuda_get_grid_thread_id<1,1>();

   if (i < num_rows)
   {
      const HYPRE_Int rid = row_id ? read_only_load(&row_id[i]) : i;
      const HYPRE_Int rnz = read_only_load(&row_sizes[rid]);
      ghash_sizes[rid] = next_power_of_2(rnz - SHMEM_HASH_SIZE);
   }
}

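/* Create the global ("ghash") hash tables used when a row does not fit in the
 * shared-memory hash table of size SHMEM_HASH_SIZE: compute the size of each
 * table, build the CSR-style pointer array ghash_i by an exclusive scan, and
 * allocate the column-index (ghash_j) and value (ghash_a) arrays on the device.
 * The total size is returned through ghash_size_ptr. */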
HYPRE_Int
hypre_SpGemmCreateGlobalHashTable( HYPRE_Int       num_rows,        /* number of rows */
                                   HYPRE_Int      *row_id,          /* row_id[i] is index of ith row; i if row_id == NULL */
                                   HYPRE_Int       num_ghash,       /* number of hash tables <= num_rows */
                                   HYPRE_Int      *row_sizes,       /* row_sizes[row_id[i]] is the size of ith row */
                                   HYPRE_Int       SHMEM_HASH_SIZE,
                                   HYPRE_Int     **ghash_i_ptr,     /* of length num_ghash + 1 */
                                   HYPRE_Int     **ghash_j_ptr,
                                   HYPRE_Complex **ghash_a_ptr,
                                   HYPRE_Int      *ghash_size_ptr,
                                   HYPRE_Int       type )
{
   hypre_assert(type == 2 || num_ghash <= num_rows);

   HYPRE_Int *ghash_i, ghash_size;
   dim3 bDim = hypre_GetDefaultCUDABlockDimension();

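   /* type 1: num_ghash tables shared cyclically among the rows (num_ghash <= num_rows);
    * type 2: one table per row */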
   if (type == 1)
   {
      ghash_i = hypre_TAlloc(HYPRE_Int, num_ghash + 1, HYPRE_MEMORY_DEVICE);
      dim3 gDim = hypre_GetDefaultCUDAGridDimension(num_ghash, "thread", bDim);
      HYPRE_CUDA_LAUNCH( hypre_SpGemmGhashSize1, gDim, bDim,
                         num_rows, row_id, num_ghash, row_sizes, ghash_i, SHMEM_HASH_SIZE );
   }
   else if (type == 2)
   {
      ghash_i = hypre_CTAlloc(HYPRE_Int, num_ghash + 1, HYPRE_MEMORY_DEVICE);
      dim3 gDim = hypre_GetDefaultCUDAGridDimension(num_rows, "thread", bDim);
      HYPRE_CUDA_LAUNCH( hypre_SpGemmGhashSize2, gDim, bDim,
                         num_rows, row_id, num_ghash, row_sizes, ghash_i, SHMEM_HASH_SIZE );
   }

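   /* exclusive scan turns the per-table sizes into offsets; the last entry,
    * ghash_i[num_ghash], is the total global hash table size */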
   hypreDevice_IntegerExclusiveScan(num_ghash + 1, ghash_i);

   hypre_TMemcpy(&ghash_size, ghash_i + num_ghash, HYPRE_Int, 1, HYPRE_MEMORY_HOST, HYPRE_MEMORY_DEVICE);

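   /* if no row overflows the shared-memory hash table, there is nothing to store
    * globally: free ghash_i (hypre_TFree also resets the pointer, hence the assert) */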
   if (!ghash_size)
   {
      hypre_TFree(ghash_i, HYPRE_MEMORY_DEVICE);  hypre_assert(ghash_i == NULL);
   }

   if (ghash_i_ptr)
   {
      *ghash_i_ptr = ghash_i;
   }

   if (ghash_j_ptr)
   {
      *ghash_j_ptr = hypre_TAlloc(HYPRE_Int, ghash_size, HYPRE_MEMORY_DEVICE);
   }

   if (ghash_a_ptr)
   {
      *ghash_a_ptr = hypre_TAlloc(HYPRE_Complex, ghash_size, HYPRE_MEMORY_DEVICE);
   }

   if (ghash_size_ptr)
   {
      *ghash_size_ptr = ghash_size;
   }

   return hypre_error_flag;
}
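
/* A minimal usage sketch (illustrative only; the caller-side names below, e.g.
 * d_rc for the per-row size estimates and shmem_hash_size, are hypothetical):
 *
 *    HYPRE_Int     *ghash_i = NULL, *ghash_j = NULL, ghash_size = 0;
 *    HYPRE_Complex *ghash_a = NULL;
 *    hypre_SpGemmCreateGlobalHashTable(m, NULL, num_ghash, d_rc, shmem_hash_size,
 *                                      &ghash_i, &ghash_j, &ghash_a, &ghash_size, 1);
 */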

#endif