1 /* 2 * Copyright (C) Mellanox Technologies Ltd. 2001-2019. ALL RIGHTS RESERVED. 3 * See file LICENSE for terms. 4 */ 5 6 package org.openucx.jucx; 7 8 import org.junit.Test; 9 import org.openucx.jucx.ucp.*; 10 import org.openucx.jucx.ucs.UcsConstants; 11 12 import java.nio.ByteBuffer; 13 import java.util.Collections; 14 import java.util.concurrent.atomic.AtomicBoolean; 15 16 import static org.junit.Assert.*; 17 18 public class UcpWorkerTest extends UcxTest { 19 private static int numWorkers = Runtime.getRuntime().availableProcessors(); 20 21 @Test testSingleWorker()22 public void testSingleWorker() { 23 UcpContext context = new UcpContext(new UcpParams().requestTagFeature()); 24 assertEquals(2, UcsConstants.ThreadMode.UCS_THREAD_MODE_MULTI); 25 assertNotEquals(context.getNativeId(), null); 26 UcpWorker worker = context.newWorker(new UcpWorkerParams()); 27 assertNotNull(worker.getNativeId()); 28 assertEquals(0, worker.progress()); // No communications was submitted. 29 worker.close(); 30 assertNull(worker.getNativeId()); 31 context.close(); 32 } 33 34 @Test testMultipleWorkersWithinSameContext()35 public void testMultipleWorkersWithinSameContext() { 36 UcpContext context = new UcpContext(new UcpParams().requestTagFeature()); 37 assertNotEquals(context.getNativeId(), null); 38 UcpWorker[] workers = new UcpWorker[numWorkers]; 39 UcpWorkerParams workerParam = new UcpWorkerParams(); 40 for (int i = 0; i < numWorkers; i++) { 41 workerParam.clear().setCpu(i).requestThreadSafety(); 42 workers[i] = context.newWorker(workerParam); 43 assertNotNull(workers[i].getNativeId()); 44 } 45 for (int i = 0; i < numWorkers; i++) { 46 workers[i].close(); 47 } 48 context.close(); 49 } 50 51 @Test testMultipleWorkersFromMultipleContexts()52 public void testMultipleWorkersFromMultipleContexts() { 53 UcpContext tcpContext = new UcpContext(new UcpParams().requestTagFeature()); 54 UcpContext rdmaContext = new UcpContext(new UcpParams().requestRmaFeature() 55 .requestAtomic64BitFeature().requestAtomic32BitFeature()); 56 UcpWorker[] workers = new UcpWorker[numWorkers]; 57 UcpWorkerParams workerParams = new UcpWorkerParams(); 58 for (int i = 0; i < numWorkers; i++) { 59 ByteBuffer userData = ByteBuffer.allocateDirect(100); 60 workerParams.clear(); 61 if (i % 2 == 0) { 62 userData.asCharBuffer().put("TCPWorker" + i); 63 workerParams.requestWakeupRX().setUserData(userData); 64 workers[i] = tcpContext.newWorker(workerParams); 65 } else { 66 userData.asCharBuffer().put("RDMAWorker" + i); 67 workerParams.requestWakeupRMA().setCpu(i).setUserData(userData) 68 .requestThreadSafety(); 69 workers[i] = rdmaContext.newWorker(workerParams); 70 } 71 } 72 for (int i = 0; i < numWorkers; i++) { 73 workers[i].close(); 74 } 75 tcpContext.close(); 76 rdmaContext.close(); 77 } 78 79 @Test testGetWorkerAddress()80 public void testGetWorkerAddress() { 81 UcpContext context = new UcpContext(new UcpParams().requestTagFeature()); 82 UcpWorker worker = context.newWorker(new UcpWorkerParams()); 83 ByteBuffer workerAddress = worker.getAddress(); 84 assertNotNull(workerAddress); 85 assertTrue(workerAddress.capacity() > 0); 86 worker.close(); 87 context.close(); 88 } 89 90 @Test testWorkerSleepWakeup()91 public void testWorkerSleepWakeup() throws InterruptedException { 92 UcpContext context = new UcpContext(new UcpParams() 93 .requestRmaFeature().requestWakeupFeature()); 94 UcpWorker worker = context.newWorker( 95 new UcpWorkerParams().requestWakeupRMA()); 96 97 AtomicBoolean success = new AtomicBoolean(false); 98 Thread workerProgressThread = new Thread() { 99 @Override 100 public void run() { 101 while (!isInterrupted()) { 102 if (worker.progress() == 0) { 103 worker.waitForEvents(); 104 } 105 } 106 success.set(true); 107 } 108 }; 109 110 workerProgressThread.start(); 111 112 workerProgressThread.interrupt(); 113 worker.signal(); 114 115 workerProgressThread.join(); 116 assertTrue(success.get()); 117 118 worker.close(); 119 context.close(); 120 } 121 122 @Test testFlushWorker()123 public void testFlushWorker() { 124 int numRequests = 10; 125 // Crerate 2 contexts + 2 workers 126 UcpParams params = new UcpParams().requestRmaFeature(); 127 UcpWorkerParams rdmaWorkerParams = new UcpWorkerParams().requestWakeupRMA(); 128 UcpContext context1 = new UcpContext(params); 129 UcpContext context2 = new UcpContext(params); 130 131 ByteBuffer src = ByteBuffer.allocateDirect(UcpMemoryTest.MEM_SIZE); 132 ByteBuffer dst = ByteBuffer.allocateDirect(UcpMemoryTest.MEM_SIZE); 133 dst.asCharBuffer().put(UcpMemoryTest.RANDOM_TEXT); 134 UcpMemory memory = context2.registerMemory(src); 135 136 UcpWorker worker1 = context1.newWorker(rdmaWorkerParams); 137 UcpWorker worker2 = context2.newWorker(rdmaWorkerParams); 138 139 UcpEndpoint ep = worker1.newEndpoint( new UcpEndpointParams() 140 .setUcpAddress(worker2.getAddress()).setPeerErrorHandlingMode()); 141 UcpRemoteKey rkey = ep.unpackRemoteKey(memory.getRemoteKeyBuffer()); 142 143 int blockSize = UcpMemoryTest.MEM_SIZE / numRequests; 144 for (int i = 0; i < numRequests; i++) { 145 ep.putNonBlockingImplicit(UcxUtils.getAddress(dst) + i * blockSize, 146 blockSize, memory.getAddress() + i * blockSize, rkey); 147 } 148 149 UcpRequest request = worker1.flushNonBlocking(new UcxCallback() { 150 @Override 151 public void onSuccess(UcpRequest request) { 152 rkey.close(); 153 memory.deregister(); 154 assertEquals(dst.asCharBuffer().toString().trim(), UcpMemoryTest.RANDOM_TEXT); 155 } 156 }); 157 158 while (!request.isCompleted()) { 159 worker1.progress(); 160 worker2.progress(); 161 } 162 163 assertTrue(request.isCompleted()); 164 Collections.addAll(resources, context1, context2, worker1, worker2, ep); 165 closeResources(); 166 } 167 168 @Test testTagProbe()169 public void testTagProbe() { 170 UcpParams params = new UcpParams().requestTagFeature(); 171 UcpContext context1 = new UcpContext(params); 172 UcpContext context2 = new UcpContext(params); 173 174 UcpWorker worker1 = context1.newWorker(new UcpWorkerParams()); 175 UcpWorker worker2 = context2.newWorker(new UcpWorkerParams()); 176 ByteBuffer recvBuffer = ByteBuffer.allocateDirect(UcpMemoryTest.MEM_SIZE); 177 178 UcpTagMessage message = worker1.tagProbeNonBlocking(0, 0, false); 179 180 assertNull(message); 181 182 UcpEndpoint endpoint = worker2.newEndpoint( 183 new UcpEndpointParams().setUcpAddress(worker1.getAddress())); 184 185 endpoint.sendTaggedNonBlocking( 186 ByteBuffer.allocateDirect(UcpMemoryTest.MEM_SIZE), null); 187 188 do { 189 worker1.progress(); 190 worker2.progress(); 191 message = worker1.tagProbeNonBlocking(0, 0, true); 192 } while (message == null); 193 194 assertEquals(UcpMemoryTest.MEM_SIZE, message.getRecvLength()); 195 assertEquals(0, message.getSenderTag()); 196 197 UcpRequest recv = worker1.recvTaggedMessageNonBlocking(recvBuffer, message, null); 198 199 worker1.progressRequest(recv); 200 201 Collections.addAll(resources, context1, context2, worker1, worker2, endpoint); 202 } 203 } 204