1 /*
2  * Copyright (C) by Argonne National Laboratory
3  *     See COPYRIGHT in top-level directory
4  */
5 
6 #include "mpiimpl.h"
7 #include "mpl.h"
8 
9 #define MAX_PROGRESS_HOOKS 4
10 
11 typedef int (*progress_func_ptr_t) (int *made_progress);
12 typedef struct progress_hook_slot {
13     progress_func_ptr_t func_ptr;
14     MPL_atomic_int_t active;
15 } progress_hook_slot_t;
16 
17 static int registered_progress_hooks = 0;
18 static progress_hook_slot_t progress_hooks[MAX_PROGRESS_HOOKS];
19 
MPIR_Progress_hook_exec_all(int * made_progress)20 int MPIR_Progress_hook_exec_all(int *made_progress)
21 {
22     int mpi_errno = MPI_SUCCESS;
23 
24     for (int i = 0; i < registered_progress_hooks; i++) {
25         int is_active = MPL_atomic_acquire_load_int(&progress_hooks[i].active);
26         if (is_active == TRUE) {
27             MPIR_Assert(progress_hooks[i].func_ptr != NULL);
28             int tmp_progress = 0;
29             mpi_errno = progress_hooks[i].func_ptr(&tmp_progress);
30             MPIR_ERR_CHECK(mpi_errno);
31 
32             *made_progress |= tmp_progress;
33         }
34     }
35 
36   fn_exit:
37     return mpi_errno;
38 
39   fn_fail:
40     goto fn_exit;
41 }
42 
MPIR_Progress_hook_register(int (* progress_fn)(int *),int * id)43 int MPIR_Progress_hook_register(int (*progress_fn) (int *), int *id)
44 {
45     int mpi_errno = MPI_SUCCESS;
46     int i;
47 
48     for (i = 0; i < MAX_PROGRESS_HOOKS; i++) {
49         if (progress_hooks[i].func_ptr == NULL) {
50             progress_hooks[i].func_ptr = progress_fn;
51             MPL_atomic_relaxed_store_int(&progress_hooks[i].active, FALSE);
52             break;
53         }
54     }
55 
56     if (i >= MAX_PROGRESS_HOOKS)
57         goto fn_fail;
58 
59     registered_progress_hooks++;
60 
61     (*id) = i;
62 
63   fn_exit:
64     return mpi_errno;
65   fn_fail:
66     mpi_errno = MPIR_Err_create_code(MPI_SUCCESS, MPIR_ERR_RECOVERABLE,
67                                      "MPID_Progress_register", __LINE__,
68                                      MPI_ERR_INTERN, "**progresshookstoomany", 0);
69     goto fn_exit;
70 }
71 
MPIR_Progress_hook_deregister(int id)72 int MPIR_Progress_hook_deregister(int id)
73 {
74     int mpi_errno = MPI_SUCCESS;
75 
76     MPIR_Assert(id >= 0);
77     MPIR_Assert(id < MAX_PROGRESS_HOOKS);
78     MPIR_Assert(progress_hooks[id].func_ptr != NULL);
79     progress_hooks[id].func_ptr = NULL;
80     MPL_atomic_release_store_int(&progress_hooks[id].active, FALSE);
81 
82     registered_progress_hooks--;
83 
84     return mpi_errno;
85 }
86 
/* The functions below assume that each progress hook is protected by
 * a mutex, which is also shared with the other functions that modify
 * the global state of that hook.  If we think of each hook as making
 * progress on a class, then we assume that the public functions of
 * that class are thread safe.
 *
 * In the code below, we only maintain atomicity when reading whether
 * the "active" field is set.  We intentionally avoid a critical
 * section for performance reasons.  A different thread may deactivate
 * a progress hook after we check whether it is active but before we
 * call the function pointer.  In that case we simply poll the hook one
 * extra time, which does not affect correctness.  Note that func_ptr
 * itself is only cleared by MPIR_Progress_hook_deregister, which is
 * presumably called only at finalize. */
101 
MPIR_Progress_hook_activate(int id)102 int MPIR_Progress_hook_activate(int id)
103 {
104     int mpi_errno = MPI_SUCCESS;
105 
106     MPIR_Assert(id >= 0);
107     MPIR_Assert(id < MAX_PROGRESS_HOOKS);
108 
109     MPL_atomic_release_store_int(&progress_hooks[id].active, TRUE);
110     MPIR_Assert(progress_hooks[id].func_ptr != NULL);
111 
112     return mpi_errno;
113 }
114 
MPIR_Progress_hook_deactivate(int id)115 int MPIR_Progress_hook_deactivate(int id)
116 {
117     int mpi_errno = MPI_SUCCESS;
118 
119     MPIR_Assert(id >= 0);
120     MPIR_Assert(id < MAX_PROGRESS_HOOKS);
121 
122     MPL_atomic_release_store_int(&progress_hooks[id].active, FALSE);
123     MPIR_Assert(progress_hooks[id].func_ptr != NULL);
124 
125     return mpi_errno;
126 }
127