1 /* 2 * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /* 25 * @test 26 * @run testng TestSpliterator 27 */ 28 29 import jdk.incubator.foreign.MemoryLayout; 30 import jdk.incubator.foreign.MemoryLayouts; 31 import jdk.incubator.foreign.MemorySegment; 32 import jdk.incubator.foreign.SequenceLayout; 33 34 import java.lang.invoke.VarHandle; 35 import java.util.LinkedList; 36 import java.util.List; 37 import java.util.Map; 38 import java.util.Spliterator; 39 import java.util.concurrent.CountedCompleter; 40 import java.util.concurrent.RecursiveTask; 41 import java.util.concurrent.atomic.AtomicLong; 42 import java.util.function.Consumer; 43 import java.util.function.Supplier; 44 import java.util.stream.LongStream; 45 import java.util.stream.StreamSupport; 46 47 import org.testng.annotations.*; 48 import static jdk.incubator.foreign.MemorySegment.*; 49 import static org.testng.Assert.*; 50 51 public class TestSpliterator { 52 53 static final VarHandle INT_HANDLE = MemoryLayout.ofSequence(MemoryLayouts.JAVA_INT) 54 .varHandle(int.class, MemoryLayout.PathElement.sequenceElement()); 55 56 final static int CARRIER_SIZE = 4; 57 58 @Test(dataProvider = "splits") testSum(int size, int threshold)59 public void testSum(int size, int threshold) { 60 SequenceLayout layout = MemoryLayout.ofSequence(size, MemoryLayouts.JAVA_INT); 61 62 //setup 63 MemorySegment segment = MemorySegment.allocateNative(layout).share(); 64 for (int i = 0; i < layout.elementCount().getAsLong(); i++) { 65 INT_HANDLE.set(segment, (long) i, i); 66 } 67 long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum(); 68 //serial 69 long serial = sum(0, segment); 70 assertEquals(serial, expected); 71 //parallel counted completer 72 long parallelCounted = new SumSegmentCounted(null, segment.spliterator(layout), threshold).invoke(); 73 assertEquals(parallelCounted, expected); 74 //parallel recursive action 75 long parallelRecursive = new SumSegmentRecursive(segment.spliterator(layout), threshold).invoke(); 76 assertEquals(parallelRecursive, expected); 77 //parallel stream 78 long streamParallel = StreamSupport.stream(segment.spliterator(layout), true) 79 .reduce(0L, TestSpliterator::sumSingle, Long::sum); 80 assertEquals(streamParallel, expected); 81 segment.close(); 82 } 83 testSumSameThread()84 public void testSumSameThread() { 85 SequenceLayout layout = MemoryLayout.ofSequence(1024, MemoryLayouts.JAVA_INT); 86 87 //setup 88 MemorySegment segment = MemorySegment.allocateNative(layout); 89 for (int i = 0; i < layout.elementCount().getAsLong(); i++) { 90 INT_HANDLE.set(segment, (long) i, i); 91 } 92 long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum(); 93 94 //check that a segment w/o ACQUIRE access mode can still be used from same thread 95 AtomicLong spliteratorSum = new AtomicLong(); 96 segment.withAccessModes(MemorySegment.READ).spliterator(layout) 97 .forEachRemaining(s -> spliteratorSum.addAndGet(sumSingle(0L, s))); 98 assertEquals(spliteratorSum.get(), expected); 99 } 100 sumSingle(long acc, MemorySegment segment)101 static long sumSingle(long acc, MemorySegment segment) { 102 return acc + (int)INT_HANDLE.get(segment, 0L); 103 } 104 sum(long start, MemorySegment segment)105 static long sum(long start, MemorySegment segment) { 106 long sum = start; 107 int length = (int)segment.byteSize(); 108 for (int i = 0 ; i < length / CARRIER_SIZE ; i++) { 109 sum += (int)INT_HANDLE.get(segment, (long)i); 110 } 111 return sum; 112 } 113 114 static class SumSegmentCounted extends CountedCompleter<Long> { 115 116 final long threshold; 117 long localSum = 0; 118 List<SumSegmentCounted> children = new LinkedList<>(); 119 120 private Spliterator<MemorySegment> segmentSplitter; 121 SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold)122 SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold) { 123 super(parent); 124 this.segmentSplitter = segmentSplitter; 125 this.threshold = threshold; 126 } 127 128 @Override compute()129 public void compute() { 130 Spliterator<MemorySegment> sub; 131 while (segmentSplitter.estimateSize() > threshold && 132 (sub = segmentSplitter.trySplit()) != null) { 133 addToPendingCount(1); 134 SumSegmentCounted child = new SumSegmentCounted(this, sub, threshold); 135 children.add(child); 136 child.fork(); 137 } 138 segmentSplitter.forEachRemaining(slice -> { 139 localSum += sumSingle(0, slice); 140 }); 141 tryComplete(); 142 } 143 144 @Override getRawResult()145 public Long getRawResult() { 146 long sum = localSum; 147 for (SumSegmentCounted c : children) { 148 sum += c.getRawResult(); 149 } 150 return sum; 151 } 152 } 153 154 static class SumSegmentRecursive extends RecursiveTask<Long> { 155 156 final long threshold; 157 private final Spliterator<MemorySegment> splitter; 158 private long result; 159 SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold)160 SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold) { 161 this.splitter = splitter; 162 this.threshold = threshold; 163 } 164 165 @Override compute()166 protected Long compute() { 167 if (splitter.estimateSize() > threshold) { 168 SumSegmentRecursive sub = new SumSegmentRecursive(splitter.trySplit(), threshold); 169 sub.fork(); 170 return compute() + sub.join(); 171 } else { 172 splitter.forEachRemaining(slice -> { 173 result += sumSingle(0, slice); 174 }); 175 return result; 176 } 177 } 178 } 179 180 @DataProvider(name = "splits") splits()181 public Object[][] splits() { 182 return new Object[][] { 183 { 10, 1 }, 184 { 100, 1 }, 185 { 1000, 1 }, 186 { 10000, 1 }, 187 { 10, 10 }, 188 { 100, 10 }, 189 { 1000, 10 }, 190 { 10000, 10 }, 191 { 10, 100 }, 192 { 100, 100 }, 193 { 1000, 100 }, 194 { 10000, 100 }, 195 { 10, 1000 }, 196 { 100, 1000 }, 197 { 1000, 1000 }, 198 { 10000, 1000 }, 199 { 10, 10000 }, 200 { 100, 10000 }, 201 { 1000, 10000 }, 202 { 10000, 10000 }, 203 }; 204 } 205 206 @DataProvider(name = "accessScenarios") accessScenarios()207 public Object[][] accessScenarios() { 208 SequenceLayout layout = MemoryLayout.ofSequence(16, MemoryLayouts.JAVA_INT); 209 var mallocSegment = MemorySegment.allocateNative(layout); 210 211 Map<Supplier<Spliterator<MemorySegment>>,Integer> l = Map.of( 212 () -> mallocSegment.withAccessModes(ALL_ACCESS).spliterator(layout), ALL_ACCESS, 213 () -> mallocSegment.withAccessModes(0).spliterator(layout), 0, 214 () -> mallocSegment.withAccessModes(READ).spliterator(layout), READ, 215 () -> mallocSegment.withAccessModes(CLOSE).spliterator(layout), 0, 216 () -> mallocSegment.withAccessModes(READ|WRITE).spliterator(layout), READ|WRITE, 217 () -> mallocSegment.withAccessModes(READ|WRITE| SHARE).spliterator(layout), READ|WRITE| SHARE, 218 () -> mallocSegment.withAccessModes(READ|WRITE| SHARE |HANDOFF).spliterator(layout), READ|WRITE| SHARE |HANDOFF 219 220 ); 221 return l.entrySet().stream().map(e -> new Object[] { e.getKey(), e.getValue() }).toArray(Object[][]::new); 222 } 223 assertAccessModes(int accessModes)224 static Consumer<MemorySegment> assertAccessModes(int accessModes) { 225 return segment -> { 226 assertTrue(segment.hasAccessModes(accessModes & ~CLOSE)); 227 assertEquals(segment.accessModes(), accessModes & ~CLOSE); 228 }; 229 } 230 231 @Test(dataProvider = "accessScenarios") testAccessModes(Supplier<Spliterator<MemorySegment>> spliteratorSupplier, int expectedAccessModes)232 public void testAccessModes(Supplier<Spliterator<MemorySegment>> spliteratorSupplier, 233 int expectedAccessModes) { 234 Spliterator<MemorySegment> spliterator = spliteratorSupplier.get(); 235 spliterator.forEachRemaining(assertAccessModes(expectedAccessModes)); 236 237 spliterator = spliteratorSupplier.get(); 238 do { } while (spliterator.tryAdvance(assertAccessModes(expectedAccessModes))); 239 240 splitOrConsume(spliteratorSupplier.get(), assertAccessModes(expectedAccessModes)); 241 } 242 splitOrConsume(Spliterator<MemorySegment> spliterator, Consumer<MemorySegment> consumer)243 static void splitOrConsume(Spliterator<MemorySegment> spliterator, 244 Consumer<MemorySegment> consumer) { 245 var s1 = spliterator.trySplit(); 246 if (s1 != null) { 247 splitOrConsume(s1, consumer); 248 splitOrConsume(spliterator, consumer); 249 } else { 250 spliterator.forEachRemaining(consumer); 251 } 252 } 253 } 254