1 /*
2  * Copyright (C) Mellanox Technologies Ltd. 2001-2019.  ALL RIGHTS RESERVED.
3  * See file LICENSE for terms.
4  */
5 
6 package org.openucx.jucx;
7 
8 import org.junit.Test;
9 import org.openucx.jucx.ucp.*;
10 import org.openucx.jucx.ucs.UcsConstants;
11 
12 import java.nio.ByteBuffer;
13 import java.util.Collections;
14 import java.util.concurrent.atomic.AtomicBoolean;
15 
16 import static org.junit.Assert.*;
17 
18 public class UcpWorkerTest extends UcxTest {
19     private static int numWorkers = Runtime.getRuntime().availableProcessors();
20 
21     @Test
testSingleWorker()22     public void testSingleWorker() {
23         UcpContext context = new UcpContext(new UcpParams().requestTagFeature());
24         assertEquals(2, UcsConstants.ThreadMode.UCS_THREAD_MODE_MULTI);
25         assertNotEquals(context.getNativeId(), null);
26         UcpWorker worker = context.newWorker(new UcpWorkerParams());
27         assertNotNull(worker.getNativeId());
28         assertEquals(0, worker.progress()); // No communications was submitted.
29         worker.close();
30         assertNull(worker.getNativeId());
31         context.close();
32     }
33 
34     @Test
testMultipleWorkersWithinSameContext()35     public void testMultipleWorkersWithinSameContext() {
36         UcpContext context = new UcpContext(new UcpParams().requestTagFeature());
37         assertNotEquals(context.getNativeId(), null);
38         UcpWorker[] workers = new UcpWorker[numWorkers];
39         UcpWorkerParams workerParam = new UcpWorkerParams();
40         for (int i = 0; i < numWorkers; i++) {
41             workerParam.clear().setCpu(i).requestThreadSafety();
42             workers[i] = context.newWorker(workerParam);
43             assertNotNull(workers[i].getNativeId());
44         }
45         for (int i = 0; i < numWorkers; i++) {
46             workers[i].close();
47         }
48         context.close();
49     }
50 
51     @Test
testMultipleWorkersFromMultipleContexts()52     public void testMultipleWorkersFromMultipleContexts() {
53         UcpContext tcpContext = new UcpContext(new UcpParams().requestTagFeature());
54         UcpContext rdmaContext = new UcpContext(new UcpParams().requestRmaFeature()
55             .requestAtomic64BitFeature().requestAtomic32BitFeature());
56         UcpWorker[] workers = new UcpWorker[numWorkers];
57         UcpWorkerParams workerParams = new UcpWorkerParams();
58         for (int i = 0; i < numWorkers; i++) {
59             ByteBuffer userData = ByteBuffer.allocateDirect(100);
60             workerParams.clear();
61             if (i % 2 == 0) {
62                 userData.asCharBuffer().put("TCPWorker" + i);
63                 workerParams.requestWakeupRX().setUserData(userData);
64                 workers[i] = tcpContext.newWorker(workerParams);
65             } else {
66                 userData.asCharBuffer().put("RDMAWorker" + i);
67                 workerParams.requestWakeupRMA().setCpu(i).setUserData(userData)
68                     .requestThreadSafety();
69                 workers[i] = rdmaContext.newWorker(workerParams);
70             }
71         }
72         for (int i = 0; i < numWorkers; i++) {
73             workers[i].close();
74         }
75         tcpContext.close();
76         rdmaContext.close();
77     }
78 
79     @Test
testGetWorkerAddress()80     public void testGetWorkerAddress() {
81         UcpContext context = new UcpContext(new UcpParams().requestTagFeature());
82         UcpWorker worker = context.newWorker(new UcpWorkerParams());
83         ByteBuffer workerAddress = worker.getAddress();
84         assertNotNull(workerAddress);
85         assertTrue(workerAddress.capacity() > 0);
86         worker.close();
87         context.close();
88     }
89 
90     @Test
testWorkerSleepWakeup()91     public void testWorkerSleepWakeup() throws InterruptedException {
92         UcpContext context = new UcpContext(new UcpParams()
93             .requestRmaFeature().requestWakeupFeature());
94         UcpWorker worker = context.newWorker(
95             new UcpWorkerParams().requestWakeupRMA());
96 
97         AtomicBoolean success = new AtomicBoolean(false);
98         Thread workerProgressThread = new Thread() {
99             @Override
100             public void run() {
101                 while (!isInterrupted()) {
102                     if (worker.progress() == 0) {
103                         worker.waitForEvents();
104                     }
105                 }
106                 success.set(true);
107             }
108         };
109 
110         workerProgressThread.start();
111 
112         workerProgressThread.interrupt();
113         worker.signal();
114 
115         workerProgressThread.join();
116         assertTrue(success.get());
117 
118         worker.close();
119         context.close();
120     }
121 
122     @Test
testFlushWorker()123     public void testFlushWorker() {
124         int numRequests = 10;
125         // Crerate 2 contexts + 2 workers
126         UcpParams params = new UcpParams().requestRmaFeature();
127         UcpWorkerParams rdmaWorkerParams = new UcpWorkerParams().requestWakeupRMA();
128         UcpContext context1 = new UcpContext(params);
129         UcpContext context2 = new UcpContext(params);
130 
131         ByteBuffer src = ByteBuffer.allocateDirect(UcpMemoryTest.MEM_SIZE);
132         ByteBuffer dst = ByteBuffer.allocateDirect(UcpMemoryTest.MEM_SIZE);
133         dst.asCharBuffer().put(UcpMemoryTest.RANDOM_TEXT);
134         UcpMemory memory = context2.registerMemory(src);
135 
136         UcpWorker worker1 = context1.newWorker(rdmaWorkerParams);
137         UcpWorker worker2 = context2.newWorker(rdmaWorkerParams);
138 
139         UcpEndpoint ep = worker1.newEndpoint( new UcpEndpointParams()
140             .setUcpAddress(worker2.getAddress()).setPeerErrorHandlingMode());
141         UcpRemoteKey rkey = ep.unpackRemoteKey(memory.getRemoteKeyBuffer());
142 
143         int blockSize = UcpMemoryTest.MEM_SIZE / numRequests;
144         for (int i = 0; i < numRequests; i++) {
145             ep.putNonBlockingImplicit(UcxUtils.getAddress(dst) + i * blockSize,
146                 blockSize, memory.getAddress() + i * blockSize, rkey);
147         }
148 
149         UcpRequest request = worker1.flushNonBlocking(new UcxCallback() {
150             @Override
151             public void onSuccess(UcpRequest request) {
152                 rkey.close();
153                 memory.deregister();
154                 assertEquals(dst.asCharBuffer().toString().trim(), UcpMemoryTest.RANDOM_TEXT);
155             }
156         });
157 
158         while (!request.isCompleted()) {
159             worker1.progress();
160             worker2.progress();
161         }
162 
163         assertTrue(request.isCompleted());
164         Collections.addAll(resources, context1, context2, worker1, worker2, ep);
165         closeResources();
166     }
167 
168     @Test
testTagProbe()169     public void testTagProbe() {
170         UcpParams params = new UcpParams().requestTagFeature();
171         UcpContext context1 = new UcpContext(params);
172         UcpContext context2 = new UcpContext(params);
173 
174         UcpWorker worker1 = context1.newWorker(new UcpWorkerParams());
175         UcpWorker worker2 = context2.newWorker(new UcpWorkerParams());
176         ByteBuffer recvBuffer = ByteBuffer.allocateDirect(UcpMemoryTest.MEM_SIZE);
177 
178         UcpTagMessage message = worker1.tagProbeNonBlocking(0, 0, false);
179 
180         assertNull(message);
181 
182         UcpEndpoint endpoint = worker2.newEndpoint(
183             new UcpEndpointParams().setUcpAddress(worker1.getAddress()));
184 
185         endpoint.sendTaggedNonBlocking(
186             ByteBuffer.allocateDirect(UcpMemoryTest.MEM_SIZE), null);
187 
188         do {
189             worker1.progress();
190             worker2.progress();
191             message = worker1.tagProbeNonBlocking(0, 0, true);
192         } while (message == null);
193 
194         assertEquals(UcpMemoryTest.MEM_SIZE, message.getRecvLength());
195         assertEquals(0, message.getSenderTag());
196 
197         UcpRequest recv = worker1.recvTaggedMessageNonBlocking(recvBuffer, message, null);
198 
199         worker1.progressRequest(recv);
200 
201         Collections.addAll(resources, context1, context2, worker1, worker2, endpoint);
202     }
203 }
204