1 /*
2  * Copyright (c) 2016, 2020, Oracle and/or its affiliates. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 
25 package org.graalvm.compiler.replacements;
26 
27 import static org.graalvm.compiler.nodes.util.ConstantReflectionUtil.loadByteArrayConstant;
28 import static org.graalvm.compiler.nodes.util.ConstantReflectionUtil.loadCharArrayConstant;
29 import static org.graalvm.compiler.replacements.ReplacementsUtil.byteArrayBaseOffset;
30 import static org.graalvm.compiler.replacements.ReplacementsUtil.charArrayBaseOffset;
31 import static org.graalvm.compiler.replacements.SnippetTemplate.DEFAULT_REPLACER;
32 import static org.graalvm.compiler.serviceprovider.GraalUnsafeAccess.getUnsafe;
33 
34 import org.graalvm.compiler.api.replacements.Fold.InjectedParameter;
35 import org.graalvm.compiler.api.replacements.Snippet;
36 import org.graalvm.compiler.api.replacements.Snippet.ConstantParameter;
37 import org.graalvm.compiler.api.replacements.SnippetReflectionProvider;
38 import org.graalvm.compiler.debug.DebugHandlersFactory;
39 import org.graalvm.compiler.nodes.StructuredGraph;
40 import org.graalvm.compiler.nodes.spi.LoweringTool;
41 import org.graalvm.compiler.options.OptionValues;
42 import org.graalvm.compiler.phases.util.Providers;
43 import org.graalvm.compiler.replacements.SnippetTemplate.AbstractTemplates;
44 import org.graalvm.compiler.replacements.SnippetTemplate.Arguments;
45 import org.graalvm.compiler.replacements.SnippetTemplate.SnippetInfo;
46 import org.graalvm.compiler.replacements.nodes.ExplodeLoopNode;
47 
48 import jdk.vm.ci.code.TargetDescription;
49 import jdk.vm.ci.meta.JavaConstant;
50 import jdk.vm.ci.meta.JavaKind;
51 import jdk.vm.ci.meta.MetaAccessProvider;
52 import sun.misc.Unsafe;
53 
54 public class ConstantStringIndexOfSnippets implements Snippets {
    // Unsafe instance used by the helpers and snippets of this class for raw reads of array
    // elements (getChar / getByte at an explicit byte offset from the array base).
    private static final Unsafe UNSAFE = getUnsafe();
56 
57     public static class Templates extends AbstractTemplates {
58 
59         private final SnippetInfo indexOfConstant = snippet(ConstantStringIndexOfSnippets.class, "indexOfConstant");
60         private final SnippetInfo latin1IndexOfConstant = snippet(ConstantStringIndexOfSnippets.class, "latin1IndexOfConstant");
61         private final SnippetInfo utf16IndexOfConstant = snippet(ConstantStringIndexOfSnippets.class, "utf16IndexOfConstant");
62 
Templates(OptionValues options, Iterable<DebugHandlersFactory> factories, Providers providers, SnippetReflectionProvider snippetReflection, TargetDescription target)63         public Templates(OptionValues options, Iterable<DebugHandlersFactory> factories, Providers providers, SnippetReflectionProvider snippetReflection, TargetDescription target) {
64             super(options, factories, providers, snippetReflection, target);
65         }
66 
lower(SnippetLowerableMemoryNode stringIndexOf, LoweringTool tool)67         public void lower(SnippetLowerableMemoryNode stringIndexOf, LoweringTool tool) {
68             StructuredGraph graph = stringIndexOf.graph();
69             Arguments args = new Arguments(indexOfConstant, graph.getGuardsStage(), tool.getLoweringStage());
70             args.add("source", stringIndexOf.getArgument(0));
71             args.add("sourceOffset", stringIndexOf.getArgument(1));
72             args.add("sourceCount", stringIndexOf.getArgument(2));
73             args.addConst("target", stringIndexOf.getArgument(3));
74             args.add("targetOffset", stringIndexOf.getArgument(4));
75             args.add("targetCount", stringIndexOf.getArgument(5));
76             args.add("origFromIndex", stringIndexOf.getArgument(6));
77             JavaConstant targetArg = stringIndexOf.getArgument(3).asJavaConstant();
78             char[] targetCharArray = loadCharArrayConstant(providers.getConstantReflection(), targetArg, Integer.MAX_VALUE);
79             args.addConst("md2", md2(targetCharArray));
80             args.addConst("cache", computeCache(targetCharArray));
81             template(stringIndexOf, args).instantiate(providers.getMetaAccess(), stringIndexOf, DEFAULT_REPLACER, args);
82         }
83 
lowerLatin1(SnippetLowerableMemoryNode latin1IndexOf, LoweringTool tool)84         public void lowerLatin1(SnippetLowerableMemoryNode latin1IndexOf, LoweringTool tool) {
85             StructuredGraph graph = latin1IndexOf.graph();
86             Arguments args = new Arguments(latin1IndexOfConstant, graph.getGuardsStage(), tool.getLoweringStage());
87             args.add("source", latin1IndexOf.getArgument(0));
88             args.add("sourceCount", latin1IndexOf.getArgument(1));
89             args.addConst("target", latin1IndexOf.getArgument(2));
90             args.add("targetCount", latin1IndexOf.getArgument(3));
91             args.add("origFromIndex", latin1IndexOf.getArgument(4));
92             JavaConstant targetArg = latin1IndexOf.getArgument(2).asJavaConstant();
93             byte[] targetByteArray = loadByteArrayConstant(providers.getConstantReflection(), targetArg, Integer.MAX_VALUE);
94             args.addConst("md2", md2(targetByteArray));
95             args.addConst("cache", computeCache(targetByteArray));
96             template(latin1IndexOf, args).instantiate(providers.getMetaAccess(), latin1IndexOf, DEFAULT_REPLACER, args);
97         }
98 
lowerUTF16(SnippetLowerableMemoryNode utf16IndexOf, LoweringTool tool)99         public void lowerUTF16(SnippetLowerableMemoryNode utf16IndexOf, LoweringTool tool) {
100 
101             StructuredGraph graph = utf16IndexOf.graph();
102             Arguments args = new Arguments(utf16IndexOfConstant, graph.getGuardsStage(), tool.getLoweringStage());
103             args.add("source", utf16IndexOf.getArgument(0));
104             args.add("sourceCount", utf16IndexOf.getArgument(1));
105             args.addConst("target", utf16IndexOf.getArgument(2));
106             args.add("targetCount", utf16IndexOf.getArgument(3));
107             args.add("origFromIndex", utf16IndexOf.getArgument(4));
108             JavaConstant targetArg = utf16IndexOf.getArgument(2).asJavaConstant();
109             byte[] targetByteArray = loadByteArrayConstant(providers.getConstantReflection(), targetArg, Integer.MAX_VALUE);
110             args.addConst("md2", md2Utf16(tool.getMetaAccess(), targetByteArray));
111             args.addConst("cache", computeCacheUtf16(tool.getMetaAccess(), targetByteArray));
112             template(utf16IndexOf, args).instantiate(providers.getMetaAccess(), utf16IndexOf, DEFAULT_REPLACER, args);
113         }
114     }
115 
md2(char[] target)116     static int md2(char[] target) {
117         int c = target.length;
118         if (c == 0) {
119             return 0;
120         }
121         char lastChar = target[c - 1];
122         int md2 = c;
123         for (int i = 0; i < c - 1; i++) {
124             if (target[i] == lastChar) {
125                 md2 = (c - 1) - i;
126             }
127         }
128         return md2;
129     }
130 
computeCache(char[] s)131     static long computeCache(char[] s) {
132         int c = s.length;
133         int cache = 0;
134         int i;
135         for (i = 0; i < c - 1; i++) {
136             cache |= (1 << (s[i] & 63));
137         }
138         return cache;
139     }
140 
md2(byte[] target)141     static int md2(byte[] target) {
142         int c = target.length;
143         if (c == 0) {
144             return 0;
145         }
146         byte lastByte = target[c - 1];
147         int md2 = c;
148         for (int i = 0; i < c - 1; i++) {
149             if (target[i] == lastByte) {
150                 md2 = (c - 1) - i;
151             }
152         }
153         return md2;
154     }
155 
computeCache(byte[] s)156     static long computeCache(byte[] s) {
157         int c = s.length;
158         int cache = 0;
159         int i;
160         for (i = 0; i < c - 1; i++) {
161             cache |= (1 << (s[i] & 63));
162         }
163         return cache;
164     }
165 
md2Utf16(MetaAccessProvider metaAccess, byte[] target)166     static int md2Utf16(MetaAccessProvider metaAccess, byte[] target) {
167         int c = target.length / 2;
168         if (c == 0) {
169             return 0;
170         }
171         long base = metaAccess.getArrayBaseOffset(JavaKind.Byte);
172         char lastChar = UNSAFE.getChar(target, base + (c - 1) * 2);
173         int md2 = c;
174         for (int i = 0; i < c - 1; i++) {
175             char currChar = UNSAFE.getChar(target, base + i * 2);
176             if (currChar == lastChar) {
177                 md2 = (c - 1) - i;
178             }
179         }
180         return md2;
181     }
182 
computeCacheUtf16(MetaAccessProvider metaAccess, byte[] s)183     static long computeCacheUtf16(MetaAccessProvider metaAccess, byte[] s) {
184         int c = s.length / 2;
185         int cache = 0;
186         int i;
187         long base = metaAccess.getArrayBaseOffset(JavaKind.Byte);
188         for (i = 0; i < c - 1; i++) {
189             char currChar = UNSAFE.getChar(s, base + i * 2);
190             cache |= (1 << (currChar & 63));
191         }
192         return cache;
193     }
194 
    /**
     * Marker value for the {@link InjectedParameter} injected parameter. It is always
     * {@code null}; the snippets in this class pass it as the argument of
     * {@code charArrayBaseOffset} and {@code byteArrayBaseOffset}.
     */
    static final MetaAccessProvider INJECTED = null;
197 
198     @Snippet
indexOfConstant(char[] source, int sourceOffset, int sourceCount, @ConstantParameter char[] target, int targetOffset, int targetCount, int origFromIndex, @ConstantParameter int md2, @ConstantParameter long cache)199     public static int indexOfConstant(char[] source, int sourceOffset, int sourceCount,
200                     @ConstantParameter char[] target, int targetOffset, int targetCount,
201                     int origFromIndex, @ConstantParameter int md2, @ConstantParameter long cache) {
202         int fromIndex = origFromIndex;
203         if (fromIndex >= sourceCount) {
204             return (targetCount == 0 ? sourceCount : -1);
205         }
206         if (fromIndex < 0) {
207             fromIndex = 0;
208         }
209         if (targetCount == 0) {
210             return fromIndex;
211         }
212 
213         int targetCountLess1 = targetCount - 1;
214         int sourceEnd = sourceCount - targetCountLess1;
215 
216         long base = charArrayBaseOffset(INJECTED);
217         int lastChar = UNSAFE.getChar(target, base + targetCountLess1 * 2);
218 
219         outer_loop: for (long i = sourceOffset + fromIndex; i < sourceEnd;) {
220             int src = UNSAFE.getChar(source, base + (i + targetCountLess1) * 2);
221             if (src == lastChar) {
222                 // With random strings and a 4-character alphabet,
223                 // reverse matching at this point sets up 0.8% fewer
224                 // frames, but (paradoxically) makes 0.3% more probes.
225                 // Since those probes are nearer the lastChar probe,
226                 // there is may be a net D$ win with reverse matching.
227                 // But, reversing loop inhibits unroll of inner loop
228                 // for unknown reason. So, does running outer loop from
229                 // (sourceOffset - targetCountLess1) to (sourceOffset + sourceCount)
230                 if (targetCount <= 8) {
231                     ExplodeLoopNode.explodeLoop();
232                 }
233                 for (long j = 0; j < targetCountLess1; j++) {
234                     char sourceChar = UNSAFE.getChar(source, base + (i + j) * 2);
235                     if (UNSAFE.getChar(target, base + (targetOffset + j) * 2) != sourceChar) {
236                         if ((cache & (1 << sourceChar)) == 0) {
237                             if (md2 < j + 1) {
238                                 i += j + 1;
239                                 continue outer_loop;
240                             }
241                         }
242                         i += md2;
243                         continue outer_loop;
244                     }
245                 }
246                 return (int) (i - sourceOffset);
247             }
248             if ((cache & (1 << src)) == 0) {
249                 i += targetCountLess1;
250             }
251             i++;
252         }
253         return -1;
254     }
255 
    /**
     * Snippet searching a UTF-16 encoded {@code byte[]} source for a constant UTF-16 encoded
     * {@code byte[]} target. Both arrays are read as 2-byte chars through
     * {@code Unsafe.getChar}; all counts and indices are in chars, not bytes.
     *
     * The search compares the last target char first. On a mismatch, {@code cache} and
     * {@code md2} decide how far the candidate position may be advanced.
     *
     * @param source the array that is searched
     * @param sourceCount number of chars in {@code source}
     * @param target the constant pattern
     * @param targetCount number of chars in {@code target}
     * @param origFromIndex char position at which to start; negative values are treated as 0
     * @param md2 shift for the last pattern char, as computed by {@code md2Utf16}
     * @param cache bit mask of the pattern chars except the last, as computed by
     *            {@code computeCacheUtf16}
     * @return the char position of the first match, or {@code -1}; for an empty target,
     *         {@code sourceCount} if {@code origFromIndex >= sourceCount} and the clamped start
     *         position otherwise
     */
    @Snippet
    public static int utf16IndexOfConstant(byte[] source, int sourceCount,
                    @ConstantParameter byte[] target, int targetCount,
                    int origFromIndex, @ConstantParameter int md2, @ConstantParameter long cache) {
        int fromIndex = origFromIndex;
        if (fromIndex >= sourceCount) {
            return (targetCount == 0 ? sourceCount : -1);
        }
        if (fromIndex < 0) {
            fromIndex = 0;
        }
        if (targetCount == 0) {
            return fromIndex;
        }

        int targetCountLess1 = targetCount - 1;
        // Exclusive bound for the candidate start position.
        int sourceEnd = sourceCount - targetCountLess1;

        long base = byteArrayBaseOffset(INJECTED);
        // Char indices are scaled by 2 to obtain byte offsets from the array base.
        int lastChar = UNSAFE.getChar(target, base + targetCountLess1 * 2);

        outer_loop: for (long i = fromIndex; i < sourceEnd;) {
            // Probe the source char aligned with the last pattern char first.
            int src = UNSAFE.getChar(source, base + (i + targetCountLess1) * 2);
            if (src == lastChar) {
                // With random strings and a 4-character alphabet,
                // reverse matching at this point sets up 0.8% fewer
                // frames, but (paradoxically) makes 0.3% more probes.
                // Since those probes are nearer the lastChar probe,
                // there may be a net D$ win with reverse matching.
                // But reversing the loop inhibits unrolling of the inner loop
                // for an unknown reason, and so does running the outer loop from
                // -targetCountLess1 to sourceCount.
                if (targetCount <= 8) {
                    // Presumably asks the compiler to fully unroll the following loop; see
                    // ExplodeLoopNode.
                    ExplodeLoopNode.explodeLoop();
                }
                for (long j = 0; j < targetCountLess1; j++) {
                    char sourceChar = UNSAFE.getChar(source, base + (i + j) * 2);
                    if (UNSAFE.getChar(target, base + j * 2) != sourceChar) {
                        // 1 << sourceChar is an int shift: only the low five bits of the char
                        // select the bit, which is how computeCacheUtf16 builds the mask.
                        if ((cache & (1 << sourceChar)) == 0) {
                            // The mismatching source char is not in the mask: step past it
                            // when that is further than the md2 shift.
                            if (md2 < j + 1) {
                                i += j + 1;
                                continue outer_loop;
                            }
                        }
                        i += md2;
                        continue outer_loop;
                    }
                }
                return (int) i;
            }
            if ((cache & (1 << src)) == 0) {
                // src is not in the mask: together with the i++ below this advances the
                // candidate position by targetCount.
                i += targetCountLess1;
            }
            i++;
        }
        return -1;
    }
313 
    /**
     * Snippet searching a {@code byte[]} source, one byte per character, for a constant
     * {@code byte[]} target. Elements are read through {@code Unsafe.getByte}.
     *
     * The search compares the last target byte first. On a mismatch, {@code cache} and
     * {@code md2} decide how far the candidate position may be advanced.
     *
     * @param source the array that is searched
     * @param sourceCount number of bytes in {@code source}
     * @param target the constant pattern
     * @param targetCount number of bytes in {@code target}
     * @param origFromIndex position at which to start; negative values are treated as 0
     * @param md2 shift for the last pattern byte, as computed by {@code md2(byte[])}
     * @param cache bit mask of the pattern bytes except the last, as computed by
     *            {@code computeCache(byte[])}
     * @return the position of the first match, or {@code -1}; for an empty target,
     *         {@code sourceCount} if {@code origFromIndex >= sourceCount} and the clamped start
     *         position otherwise
     */
    @Snippet
    public static int latin1IndexOfConstant(byte[] source, int sourceCount,
                    @ConstantParameter byte[] target, int targetCount,
                    int origFromIndex, @ConstantParameter int md2, @ConstantParameter long cache) {
        int fromIndex = origFromIndex;
        if (fromIndex >= sourceCount) {
            return (targetCount == 0 ? sourceCount : -1);
        }
        if (fromIndex < 0) {
            fromIndex = 0;
        }
        if (targetCount == 0) {
            return fromIndex;
        }

        int targetCountLess1 = targetCount - 1;
        // Exclusive bound for the candidate start position.
        int sourceEnd = sourceCount - targetCountLess1;

        long base = byteArrayBaseOffset(INJECTED);
        int lastByte = UNSAFE.getByte(target, base + targetCountLess1);

        outer_loop: for (long i = fromIndex; i < sourceEnd;) {
            // Probe the source byte aligned with the last pattern byte first.
            int src = UNSAFE.getByte(source, base + i + targetCountLess1);
            if (src == lastByte) {
                // With random strings and a 4-character alphabet,
                // reverse matching at this point sets up 0.8% fewer
                // frames, but (paradoxically) makes 0.3% more probes.
                // Since those probes are nearer the lastByte probe,
                // there may be a net D$ win with reverse matching.
                // But reversing the loop inhibits unrolling of the inner loop
                // for an unknown reason, and so does running the outer loop from
                // -targetCountLess1 to sourceCount.
                if (targetCount <= 8) {
                    // Presumably asks the compiler to fully unroll the following loop; see
                    // ExplodeLoopNode.
                    ExplodeLoopNode.explodeLoop();
                }
                for (long j = 0; j < targetCountLess1; j++) {
                    byte sourceByte = UNSAFE.getByte(source, base + i + j);
                    if (UNSAFE.getByte(target, base + j) != sourceByte) {
                        // 1 << sourceByte is an int shift: only the low five bits of the byte
                        // select the bit, which is how computeCache builds the mask.
                        if ((cache & (1 << sourceByte)) == 0) {
                            // The mismatching source byte is not in the mask: step past it
                            // when that is further than the md2 shift.
                            if (md2 < j + 1) {
                                i += j + 1;
                                continue outer_loop;
                            }
                        }
                        i += md2;
                        continue outer_loop;
                    }
                }
                return (int) i;
            }
            if ((cache & (1 << src)) == 0) {
                // src is not in the mask: together with the i++ below this advances the
                // candidate position by targetCount.
                i += targetCountLess1;
            }
            i++;
        }
        return -1;
    }
371 }
372