1 /*
2  * Copyright (c) 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /*
25  * @test
26  * @run testng TestSpliterator
27  */
28 
29 import jdk.incubator.foreign.MemoryLayout;
30 import jdk.incubator.foreign.MemoryLayouts;
31 import jdk.incubator.foreign.MemorySegment;
32 import jdk.incubator.foreign.SequenceLayout;
33 
34 import java.lang.invoke.VarHandle;
35 import java.util.LinkedList;
36 import java.util.List;
37 import java.util.Map;
38 import java.util.Spliterator;
39 import java.util.concurrent.CountedCompleter;
40 import java.util.concurrent.RecursiveTask;
41 import java.util.concurrent.atomic.AtomicLong;
42 import java.util.function.Consumer;
43 import java.util.function.Supplier;
44 import java.util.stream.LongStream;
45 import java.util.stream.StreamSupport;
46 
47 import org.testng.annotations.*;
48 import static jdk.incubator.foreign.MemorySegment.*;
49 import static org.testng.Assert.*;
50 
51 public class TestSpliterator {
52 
53     static final VarHandle INT_HANDLE = MemoryLayout.ofSequence(MemoryLayouts.JAVA_INT)
54             .varHandle(int.class, MemoryLayout.PathElement.sequenceElement());
55 
56     final static int CARRIER_SIZE = 4;
57 
58     @Test(dataProvider = "splits")
testSum(int size, int threshold)59     public void testSum(int size, int threshold) {
60         SequenceLayout layout = MemoryLayout.ofSequence(size, MemoryLayouts.JAVA_INT);
61 
62         //setup
63         MemorySegment segment = MemorySegment.allocateNative(layout).share();
64         for (int i = 0; i < layout.elementCount().getAsLong(); i++) {
65             INT_HANDLE.set(segment, (long) i, i);
66         }
67         long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();
68         //serial
69         long serial = sum(0, segment);
70         assertEquals(serial, expected);
71         //parallel counted completer
72         long parallelCounted = new SumSegmentCounted(null, segment.spliterator(layout), threshold).invoke();
73         assertEquals(parallelCounted, expected);
74         //parallel recursive action
75         long parallelRecursive = new SumSegmentRecursive(segment.spliterator(layout), threshold).invoke();
76         assertEquals(parallelRecursive, expected);
77         //parallel stream
78         long streamParallel = StreamSupport.stream(segment.spliterator(layout), true)
79                 .reduce(0L, TestSpliterator::sumSingle, Long::sum);
80         assertEquals(streamParallel, expected);
81         segment.close();
82     }
83 
testSumSameThread()84     public void testSumSameThread() {
85         SequenceLayout layout = MemoryLayout.ofSequence(1024, MemoryLayouts.JAVA_INT);
86 
87         //setup
88         MemorySegment segment = MemorySegment.allocateNative(layout);
89         for (int i = 0; i < layout.elementCount().getAsLong(); i++) {
90             INT_HANDLE.set(segment, (long) i, i);
91         }
92         long expected = LongStream.range(0, layout.elementCount().getAsLong()).sum();
93 
94         //check that a segment w/o ACQUIRE access mode can still be used from same thread
95         AtomicLong spliteratorSum = new AtomicLong();
96         segment.withAccessModes(MemorySegment.READ).spliterator(layout)
97                 .forEachRemaining(s -> spliteratorSum.addAndGet(sumSingle(0L, s)));
98         assertEquals(spliteratorSum.get(), expected);
99     }
100 
sumSingle(long acc, MemorySegment segment)101     static long sumSingle(long acc, MemorySegment segment) {
102         return acc + (int)INT_HANDLE.get(segment, 0L);
103     }
104 
sum(long start, MemorySegment segment)105     static long sum(long start, MemorySegment segment) {
106         long sum = start;
107         int length = (int)segment.byteSize();
108         for (int i = 0 ; i < length / CARRIER_SIZE ; i++) {
109             sum += (int)INT_HANDLE.get(segment, (long)i);
110         }
111         return sum;
112     }
113 
114     static class SumSegmentCounted extends CountedCompleter<Long> {
115 
116         final long threshold;
117         long localSum = 0;
118         List<SumSegmentCounted> children = new LinkedList<>();
119 
120         private Spliterator<MemorySegment> segmentSplitter;
121 
SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold)122         SumSegmentCounted(SumSegmentCounted parent, Spliterator<MemorySegment> segmentSplitter, long threshold) {
123             super(parent);
124             this.segmentSplitter = segmentSplitter;
125             this.threshold = threshold;
126         }
127 
128         @Override
compute()129         public void compute() {
130             Spliterator<MemorySegment> sub;
131             while (segmentSplitter.estimateSize() > threshold &&
132                     (sub = segmentSplitter.trySplit()) != null) {
133                 addToPendingCount(1);
134                 SumSegmentCounted child = new SumSegmentCounted(this, sub, threshold);
135                 children.add(child);
136                 child.fork();
137             }
138             segmentSplitter.forEachRemaining(slice -> {
139                 localSum += sumSingle(0, slice);
140             });
141             tryComplete();
142         }
143 
144         @Override
getRawResult()145         public Long getRawResult() {
146             long sum = localSum;
147             for (SumSegmentCounted c : children) {
148                 sum += c.getRawResult();
149             }
150             return sum;
151         }
152      }
153 
154     static class SumSegmentRecursive extends RecursiveTask<Long> {
155 
156         final long threshold;
157         private final Spliterator<MemorySegment> splitter;
158         private long result;
159 
SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold)160         SumSegmentRecursive(Spliterator<MemorySegment> splitter, long threshold) {
161             this.splitter = splitter;
162             this.threshold = threshold;
163         }
164 
165         @Override
compute()166         protected Long compute() {
167             if (splitter.estimateSize() > threshold) {
168                 SumSegmentRecursive sub = new SumSegmentRecursive(splitter.trySplit(), threshold);
169                 sub.fork();
170                 return compute() + sub.join();
171             } else {
172                 splitter.forEachRemaining(slice -> {
173                     result += sumSingle(0, slice);
174                 });
175                 return result;
176             }
177         }
178     }
179 
180     @DataProvider(name = "splits")
splits()181     public Object[][] splits() {
182         return new Object[][] {
183                 { 10, 1 },
184                 { 100, 1 },
185                 { 1000, 1 },
186                 { 10000, 1 },
187                 { 10, 10 },
188                 { 100, 10 },
189                 { 1000, 10 },
190                 { 10000, 10 },
191                 { 10, 100 },
192                 { 100, 100 },
193                 { 1000, 100 },
194                 { 10000, 100 },
195                 { 10, 1000 },
196                 { 100, 1000 },
197                 { 1000, 1000 },
198                 { 10000, 1000 },
199                 { 10, 10000 },
200                 { 100, 10000 },
201                 { 1000, 10000 },
202                 { 10000, 10000 },
203         };
204     }
205 
206     @DataProvider(name = "accessScenarios")
accessScenarios()207     public Object[][] accessScenarios() {
208         SequenceLayout layout = MemoryLayout.ofSequence(16, MemoryLayouts.JAVA_INT);
209         var mallocSegment = MemorySegment.allocateNative(layout);
210 
211         Map<Supplier<Spliterator<MemorySegment>>,Integer> l = Map.of(
212             () -> mallocSegment.withAccessModes(ALL_ACCESS).spliterator(layout), ALL_ACCESS,
213             () -> mallocSegment.withAccessModes(0).spliterator(layout), 0,
214             () -> mallocSegment.withAccessModes(READ).spliterator(layout), READ,
215             () -> mallocSegment.withAccessModes(CLOSE).spliterator(layout), 0,
216             () -> mallocSegment.withAccessModes(READ|WRITE).spliterator(layout), READ|WRITE,
217             () -> mallocSegment.withAccessModes(READ|WRITE| SHARE).spliterator(layout), READ|WRITE| SHARE,
218             () -> mallocSegment.withAccessModes(READ|WRITE| SHARE |HANDOFF).spliterator(layout), READ|WRITE| SHARE |HANDOFF
219 
220         );
221         return l.entrySet().stream().map(e -> new Object[] { e.getKey(), e.getValue() }).toArray(Object[][]::new);
222     }
223 
assertAccessModes(int accessModes)224     static Consumer<MemorySegment> assertAccessModes(int accessModes) {
225         return segment -> {
226             assertTrue(segment.hasAccessModes(accessModes & ~CLOSE));
227             assertEquals(segment.accessModes(), accessModes & ~CLOSE);
228         };
229     }
230 
231     @Test(dataProvider = "accessScenarios")
testAccessModes(Supplier<Spliterator<MemorySegment>> spliteratorSupplier, int expectedAccessModes)232     public void testAccessModes(Supplier<Spliterator<MemorySegment>> spliteratorSupplier,
233                                 int expectedAccessModes) {
234         Spliterator<MemorySegment> spliterator = spliteratorSupplier.get();
235         spliterator.forEachRemaining(assertAccessModes(expectedAccessModes));
236 
237         spliterator = spliteratorSupplier.get();
238         do { } while (spliterator.tryAdvance(assertAccessModes(expectedAccessModes)));
239 
240         splitOrConsume(spliteratorSupplier.get(), assertAccessModes(expectedAccessModes));
241     }
242 
splitOrConsume(Spliterator<MemorySegment> spliterator, Consumer<MemorySegment> consumer)243     static void splitOrConsume(Spliterator<MemorySegment> spliterator,
244                                Consumer<MemorySegment> consumer) {
245         var s1 = spliterator.trySplit();
246         if (s1 != null) {
247             splitOrConsume(s1, consumer);
248             splitOrConsume(spliterator, consumer);
249         } else {
250             spliterator.forEachRemaining(consumer);
251         }
252     }
253 }
254