1 /*
2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @run testng TestSpliterator
27  */
28 
29 import jdk.incubator.foreign.MemoryAddress;
30 import jdk.incubator.foreign.MemoryLayout;
31 import jdk.incubator.foreign.MemoryLayouts;
32 import jdk.incubator.foreign.MemorySegment;
33 import jdk.incubator.foreign.SequenceLayout;
34 
35 import java.lang.invoke.VarHandle;
36 import java.util.LinkedList;
37 import java.util.List;
38 import java.util.Map;
39 import java.util.Spliterator;
40 import java.util.concurrent.CountedCompleter;
41 import java.util.concurrent.RecursiveTask;
42 import java.util.concurrent.atomic.AtomicLong;
43 import java.util.function.Consumer;
44 import java.util.function.Supplier;
45 import java.util.stream.LongStream;
46 import java.util.stream.StreamSupport;
47 
48 import org.testng.annotations.*;
49 import static jdk.incubator.foreign.MemorySegment.*;
50 import static org.testng.Assert.*;
51 
52 public class TestSpliterator {
53 
54     static final VarHandle INT_HANDLE = MemoryLayout.ofSequence(MemoryLayouts.JAVA_INT)
55             .varHandle(int.class, MemoryLayout.PathElement.sequenceElement());
56 
57     final static int CARRIER_SIZE = 4;
58 
59     @Test(dataProvider = "splits")
testSum(int size, int threshold)60     public void testSum(int size, int threshold) {
61         SequenceLayout layout = MemoryLayout.ofSequence(size, MemoryLayouts.JAVA_INT);
62 
63         //setup
64         MemorySegment segment = MemorySegment.allocateNative(layout);
65         for (int i = 0; i < layout.elementCount().getAsLong(); i++) {
66             INT_HANDLE.set(segment.baseAddress(), (long) i, i);
67         }
68         long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();
69         //serial
70         long serial = sum(0, segment);
71         assertEquals(serial, expected);
72         //parallel counted completer
73         long parallelCounted = new SumSegmentCounted(null, MemorySegment.spliterator(segment, layout), threshold).invoke();
74         assertEquals(parallelCounted, expected);
75         //parallel recursive action
76         long parallelRecursive = new SumSegmentRecursive(MemorySegment.spliterator(segment, layout), threshold).invoke();
77         assertEquals(parallelRecursive, expected);
78         //parallel stream
79         long streamParallel = StreamSupport.stream(MemorySegment.spliterator(segment, layout), true)
80                 .reduce(0L, TestSpliterator::sumSingle, Long::sum);
81         assertEquals(streamParallel, expected);
82         segment.close();
83     }
84 
testSumSameThread()85     public void testSumSameThread() {
86         SequenceLayout layout = MemoryLayout.ofSequence(1024, MemoryLayouts.JAVA_INT);
87 
88         //setup
89         MemorySegment segment = MemorySegment.allocateNative(layout);
90         for (int i = 0; i < layout.elementCount().getAsLong(); i++) {
91             INT_HANDLE.set(segment.baseAddress(), (long) i, i);
92         }
93         long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();
94 
95         //check that a segment w/o ACQUIRE access mode can still be used from same thread
96         AtomicLong spliteratorSum = new AtomicLong();
97         spliterator(segment.withAccessModes(MemorySegment.READ), layout)
98                 .forEachRemaining(s -> spliteratorSum.addAndGet(sumSingle(0L, s)));
99         assertEquals(spliteratorSum.get(), expected);
100     }
101 
sumSingle(long acc, MemorySegment segment)102     static long sumSingle(long acc, MemorySegment segment) {
103         return acc + (int)INT_HANDLE.get(segment.baseAddress(), 0L);
104     }
105 
sum(long start, MemorySegment segment)106     static long sum(long start, MemorySegment segment) {
107         long sum = start;
108         MemoryAddress base = segment.baseAddress();
109         int length = (int)segment.byteSize();
110         for (int i = 0 ; i < length / CARRIER_SIZE ; i++) {
111             sum += (int)INT_HANDLE.get(base, (long)i);
112         }
113         return sum;
114     }
115 
116     static class SumSegmentCounted extends CountedCompleter<Long> {
117 
118         final long threshold;
119         long localSum = 0;
120         List<SumSegmentCounted> children = new LinkedList<>();
121 
122         private Spliterator<MemorySegment> segmentSplitter;
123 
SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold)124         SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold) {
125             super(parent);
126             this.segmentSplitter = segmentSplitter;
127             this.threshold = threshold;
128         }
129 
130         @Override
compute()131         public void compute() {
132             Spliterator<MemorySegment> sub;
133             while (segmentSplitter.estimateSize() > threshold &&
134                     (sub = segmentSplitter.trySplit()) != null) {
135                 addToPendingCount(1);
136                 SumSegmentCounted child = new SumSegmentCounted(this, sub, threshold);
137                 children.add(child);
138                 child.fork();
139             }
140             segmentSplitter.forEachRemaining(slice -> {
141                 localSum += sumSingle(0, slice);
142             });
143             tryComplete();
144         }
145 
146         @Override
getRawResult()147         public Long getRawResult() {
148             long sum = localSum;
149             for (SumSegmentCounted c : children) {
150                 sum += c.getRawResult();
151             }
152             return sum;
153         }
154      }
155 
156     static class SumSegmentRecursive extends RecursiveTask<Long> {
157 
158         final long threshold;
159         private final Spliterator<MemorySegment> splitter;
160         private long result;
161 
SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold)162         SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold) {
163             this.splitter = splitter;
164             this.threshold = threshold;
165         }
166 
167         @Override
compute()168         protected Long compute() {
169             if (splitter.estimateSize() > threshold) {
170                 SumSegmentRecursive sub = new SumSegmentRecursive(splitter.trySplit(), threshold);
171                 sub.fork();
172                 return compute() + sub.join();
173             } else {
174                 splitter.forEachRemaining(slice -> {
175                     result += sumSingle(0, slice);
176                 });
177                 return result;
178             }
179         }
180     }
181 
182     @DataProvider(name = "splits")
splits()183     public Object[][] splits() {
184         return new Object[][] {
185                 { 10, 1 },
186                 { 100, 1 },
187                 { 1000, 1 },
188                 { 10000, 1 },
189                 { 10, 10 },
190                 { 100, 10 },
191                 { 1000, 10 },
192                 { 10000, 10 },
193                 { 10, 100 },
194                 { 100, 100 },
195                 { 1000, 100 },
196                 { 10000, 100 },
197                 { 10, 1000 },
198                 { 100, 1000 },
199                 { 1000, 1000 },
200                 { 10000, 1000 },
201                 { 10, 10000 },
202                 { 100, 10000 },
203                 { 1000, 10000 },
204                 { 10000, 10000 },
205         };
206     }
207 
208     @DataProvider(name = "accessScenarios")
accessScenarios()209     public Object[][] accessScenarios() {
210         SequenceLayout layout = MemoryLayout.ofSequence(16, MemoryLayouts.JAVA_INT);
211         var mallocSegment = MemorySegment.allocateNative(layout);
212 
213         Map<Supplier<Spliterator<MemorySegment>>,Integer> l = Map.of(
214             () -> spliterator(mallocSegment.withAccessModes(ALL_ACCESS), layout), ALL_ACCESS,
215             () -> spliterator(mallocSegment.withAccessModes(0), layout), 0,
216             () -> spliterator(mallocSegment.withAccessModes(READ), layout), READ,
217             () -> spliterator(mallocSegment.withAccessModes(CLOSE), layout), 0,
218             () -> spliterator(mallocSegment.withAccessModes(READ|WRITE), layout), READ|WRITE,
219             () -> spliterator(mallocSegment.withAccessModes(READ|WRITE|ACQUIRE), layout), READ|WRITE|ACQUIRE,
220             () -> spliterator(mallocSegment.withAccessModes(READ|WRITE|ACQUIRE|HANDOFF), layout), READ|WRITE|ACQUIRE|HANDOFF
221 
222         );
223         return l.entrySet().stream().map(e -> new Object[] { e.getKey(), e.getValue() }).toArray(Object[][]::new);
224     }
225 
assertAccessModes(int accessModes)226     static Consumer<MemorySegment> assertAccessModes(int accessModes) {
227         return segment -> {
228             assertTrue(segment.hasAccessModes(accessModes & ~CLOSE));
229             assertEquals(segment.accessModes(), accessModes & ~CLOSE);
230         };
231     }
232 
233     @Test(dataProvider = "accessScenarios")
testAccessModes(Supplier<Spliterator<MemorySegment>> spliteratorSupplier, int expectedAccessModes)234     public void testAccessModes(Supplier<Spliterator<MemorySegment>> spliteratorSupplier,
235                                 int expectedAccessModes) {
236         Spliterator<MemorySegment> spliterator = spliteratorSupplier.get();
237         spliterator.forEachRemaining(assertAccessModes(expectedAccessModes));
238 
239         spliterator = spliteratorSupplier.get();
240         do { } while (spliterator.tryAdvance(assertAccessModes(expectedAccessModes)));
241 
242         splitOrConsume(spliteratorSupplier.get(), assertAccessModes(expectedAccessModes));
243     }
244 
splitOrConsume(Spliterator<MemorySegment> spliterator, Consumer<MemorySegment> consumer)245     static void splitOrConsume(Spliterator<MemorySegment> spliterator,
246                                Consumer<MemorySegment> consumer) {
247         var s1 = spliterator.trySplit();
248         if (s1 != null) {
249             splitOrConsume(s1, consumer);
250             splitOrConsume(spliterator, consumer);
251         } else {
252             spliterator.forEachRemaining(consumer);
253         }
254     }
255 }
256