1 /*
2  * Copyright (c) 2018, Google LLC. All rights reserved.
3  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4  *
5  * This code is free software; you can redistribute it and/or modify it
6  * under the terms of the GNU General Public License version 2 only, as
7  * published by the Free Software Foundation.
8  *
9  * This code is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
12  * version 2 for more details (a copy is included in the LICENSE file that
13  * accompanied this code).
14  *
15  * You should have received a copy of the GNU General Public License version
16  * 2 along with this work; if not, write to the Free Software Foundation,
17  * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18  *
19  * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20  * or visit www.oracle.com if you need additional information or have any
21  * questions.
22  */
23 
24 /**
25  * @test 8200301 8201194
26  * @summary deduplicate lambda methods with the same body, target type, and captured state
27  * @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api
28  *     jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp
29  *     jdk.compiler/com.sun.tools.javac.file jdk.compiler/com.sun.tools.javac.main
30  *     jdk.compiler/com.sun.tools.javac.tree jdk.compiler/com.sun.tools.javac.util
31  * @run main DeduplicationTest
32  */
33 import static java.nio.charset.StandardCharsets.UTF_8;
34 import static java.util.stream.Collectors.joining;
35 import static java.util.stream.Collectors.toList;
36 import static java.util.stream.Collectors.toMap;
37 import static java.util.stream.Collectors.toSet;
38 
39 import com.sun.source.util.JavacTask;
40 import com.sun.source.util.TaskEvent;
41 import com.sun.source.util.TaskEvent.Kind;
42 import com.sun.source.util.TaskListener;
43 import com.sun.tools.classfile.Attribute;
44 import com.sun.tools.classfile.BootstrapMethods_attribute;
45 import com.sun.tools.classfile.BootstrapMethods_attribute.BootstrapMethodSpecifier;
46 import com.sun.tools.classfile.ClassFile;
47 import com.sun.tools.classfile.ConstantPool.CONSTANT_MethodHandle_info;
48 import com.sun.tools.javac.api.ClientCodeWrapper.Trusted;
49 import com.sun.tools.javac.api.JavacTool;
50 import com.sun.tools.javac.code.Symbol;
51 import com.sun.tools.javac.code.Symbol.MethodSymbol;
52 import com.sun.tools.javac.comp.TreeDiffer;
53 import com.sun.tools.javac.comp.TreeHasher;
54 import com.sun.tools.javac.file.JavacFileManager;
55 import com.sun.tools.javac.tree.JCTree.JCCompilationUnit;
56 import com.sun.tools.javac.tree.JCTree.JCExpression;
57 import com.sun.tools.javac.tree.JCTree.JCIdent;
58 import com.sun.tools.javac.tree.JCTree.JCLambda;
59 import com.sun.tools.javac.tree.JCTree.JCMethodInvocation;
60 import com.sun.tools.javac.tree.JCTree.JCTypeCast;
61 import com.sun.tools.javac.tree.JCTree.Tag;
62 import com.sun.tools.javac.tree.TreeScanner;
63 import com.sun.tools.javac.util.Context;
64 import com.sun.tools.javac.util.JCDiagnostic;
65 import java.nio.file.Path;
66 import java.nio.file.Paths;
67 import java.util.ArrayList;
68 import java.util.Arrays;
69 import java.util.LinkedHashMap;
70 import java.util.List;
71 import java.util.Locale;
72 import java.util.Map;
73 import java.util.Set;
74 import java.util.TreeSet;
75 import javax.tools.Diagnostic;
76 import javax.tools.DiagnosticListener;
77 import javax.tools.JavaFileObject;
78 
79 public class DeduplicationTest {
80 
main(String[] args)81     public static void main(String[] args) throws Exception {
82         JavacFileManager fileManager = new JavacFileManager(new Context(), false, UTF_8);
83         JavacTool javacTool = JavacTool.create();
84         Listener diagnosticListener = new Listener();
85         Path testSrc = Paths.get(System.getProperty("test.src"));
86         Path file = testSrc.resolve("Deduplication.java");
87         String sourceVersion = Integer.toString(Runtime.version().feature());
88         JavacTask task =
89                 javacTool.getTask(
90                         null,
91                         null,
92                         diagnosticListener,
93                         Arrays.asList(
94                                 "-d",
95                                 ".",
96                                 "-XDdebug.dumpLambdaToMethodDeduplication",
97                                 "-XDdebug.dumpLambdaToMethodStats",
98                                 "--enable-preview", "-source", sourceVersion),
99                         null,
100                         fileManager.getJavaFileObjects(file));
101         Map<JCLambda, JCLambda> dedupedLambdas = new LinkedHashMap<>();
102         task.addTaskListener(new TreeDiffHashTaskListener(dedupedLambdas));
103         Iterable<? extends JavaFileObject> generated = task.generate();
104         if (!diagnosticListener.unexpected.isEmpty()) {
105             throw new AssertionError(
106                     diagnosticListener
107                             .unexpected
108                             .stream()
109                             .map(
110                                     d ->
111                                             String.format(
112                                                     "%s: %s",
113                                                     d.getCode(), d.getMessage(Locale.getDefault())))
114                             .collect(joining(", ", "unexpected diagnostics: ", "")));
115         }
116 
117         // Assert that each group of lambdas was deduplicated.
118         Map<JCLambda, JCLambda> actual = diagnosticListener.deduplicationTargets();
119         dedupedLambdas.forEach(
120                 (k, v) -> {
121                     if (!actual.containsKey(k)) {
122                         throw new AssertionError("expected " + k + " to be deduplicated");
123                     }
124                     if (!v.equals(actual.get(k))) {
125                         throw new AssertionError(
126                                 String.format(
127                                         "expected %s to be deduplicated to:\n  %s\nwas:  %s",
128                                         k, v, actual.get(v)));
129                     }
130                 });
131 
132         // Assert that the output contains only the canonical lambdas, and not the deduplicated
133         // lambdas.
134         Set<String> bootstrapMethodNames = new TreeSet<>();
135         for (JavaFileObject output : generated) {
136             ClassFile cf = ClassFile.read(output.openInputStream());
137             BootstrapMethods_attribute bsm =
138                     (BootstrapMethods_attribute) cf.getAttribute(Attribute.BootstrapMethods);
139             for (BootstrapMethodSpecifier b : bsm.bootstrap_method_specifiers) {
140                 bootstrapMethodNames.add(
141                         ((CONSTANT_MethodHandle_info)
142                                         cf.constant_pool.get(b.bootstrap_arguments[1]))
143                                 .getCPRefInfo()
144                                 .getNameAndTypeInfo()
145                                 .getName());
146             }
147         }
148         Set<String> deduplicatedNames =
149                 diagnosticListener
150                         .expectedLambdaMethods()
151                         .stream()
152                         .map(s -> s.getSimpleName().toString())
153                         .sorted()
154                         .collect(toSet());
155         if (!deduplicatedNames.equals(bootstrapMethodNames)) {
156             throw new AssertionError(
157                     String.format(
158                             "expected deduplicated methods: %s, but saw: %s",
159                             deduplicatedNames, bootstrapMethodNames));
160         }
161     }
162 
163     /** Returns the parameter symbols of the given lambda. */
paramSymbols(JCLambda lambda)164     private static List<Symbol> paramSymbols(JCLambda lambda) {
165         return lambda.params.stream().map(x -> x.sym).collect(toList());
166     }
167 
168     /** A diagnostic listener that records debug messages related to lambda desugaring. */
169     @Trusted
170     static class Listener implements DiagnosticListener<JavaFileObject> {
171 
172         /** A map from method symbols to lambda trees for desugared lambdas. */
173         final Map<MethodSymbol, JCLambda> lambdaMethodSymbolsToTrees = new LinkedHashMap<>();
174 
175         /**
176          * A map from lambda trees that were deduplicated to the method symbol of the canonical
177          * lambda implementation method they were deduplicated to.
178          */
179         final Map<JCLambda, MethodSymbol> deduped = new LinkedHashMap<>();
180 
181         final List<Diagnostic<? extends JavaFileObject>> unexpected = new ArrayList<>();
182 
183         @Override
report(Diagnostic<? extends JavaFileObject> diagnostic)184         public void report(Diagnostic<? extends JavaFileObject> diagnostic) {
185             JCDiagnostic d = (JCDiagnostic) diagnostic;
186             switch (d.getCode()) {
187                 case "compiler.note.lambda.stat":
188                     lambdaMethodSymbolsToTrees.put(
189                             (MethodSymbol) d.getArgs()[1],
190                             (JCLambda) d.getDiagnosticPosition().getTree());
191                     break;
192                 case "compiler.note.verbose.l2m.deduplicate":
193                     deduped.put(
194                             (JCLambda) d.getDiagnosticPosition().getTree(),
195                             (MethodSymbol) d.getArgs()[0]);
196                     break;
197                 case "compiler.note.preview.filename":
198                 case "compiler.note.preview.recompile":
199                     break; //ignore
200                 default:
201                     unexpected.add(diagnostic);
202             }
203         }
204 
205         /** Returns expected lambda implementation method symbols. */
expectedLambdaMethods()206         Set<MethodSymbol> expectedLambdaMethods() {
207             return lambdaMethodSymbolsToTrees
208                     .entrySet()
209                     .stream()
210                     .filter(e -> !deduped.containsKey(e.getValue()))
211                     .map(Map.Entry::getKey)
212                     .collect(toSet());
213         }
214 
215         /**
216          * Returns a mapping from deduplicated lambda trees to the tree of the canonical lambda they
217          * were deduplicated to.
218          */
deduplicationTargets()219         Map<JCLambda, JCLambda> deduplicationTargets() {
220             return deduped.entrySet()
221                     .stream()
222                     .collect(
223                             toMap(
224                                     Map.Entry::getKey,
225                                     e -> lambdaMethodSymbolsToTrees.get(e.getValue()),
226                                     (a, b) -> {
227                                         throw new AssertionError();
228                                     },
229                                     LinkedHashMap::new));
230         }
231     }
232 
233     /**
234      * A task listener that tests {@link TreeDiffer} and {@link TreeHasher} on all lambda trees in a
235      * compilation, post-analysis.
236      */
237     private static class TreeDiffHashTaskListener implements TaskListener {
238 
239         /**
240          * A map from deduplicated lambdas to the canonical lambda they are expected to be
241          * deduplicated to.
242          */
243         private final Map<JCLambda, JCLambda> dedupedLambdas;
244 
TreeDiffHashTaskListener(Map<JCLambda, JCLambda> dedupedLambdas)245         public TreeDiffHashTaskListener(Map<JCLambda, JCLambda> dedupedLambdas) {
246             this.dedupedLambdas = dedupedLambdas;
247         }
248 
249         @Override
finished(TaskEvent e)250         public void finished(TaskEvent e) {
251             if (e.getKind() != Kind.ANALYZE) {
252                 return;
253             }
254             // Scan the compilation for calls to a varargs method named 'group', whose arguments
255             // are a group of lambdas that are equivalent to each other, but distinct from all
256             // lambdas in the compilation unit outside of that group.
257             List<List<JCLambda>> lambdaGroups = new ArrayList<>();
258             new TreeScanner() {
259                 @Override
260                 public void visitApply(JCMethodInvocation tree) {
261                     if (tree.getMethodSelect().getTag() == Tag.IDENT
262                             && ((JCIdent) tree.getMethodSelect())
263                                     .getName()
264                                     .contentEquals("group")) {
265                         List<JCLambda> xs = new ArrayList<>();
266                         for (JCExpression arg : tree.getArguments()) {
267                             if (arg instanceof JCTypeCast) {
268                                 arg = ((JCTypeCast) arg).getExpression();
269                             }
270                             xs.add((JCLambda) arg);
271                         }
272                         lambdaGroups.add(xs);
273                     }
274                     super.visitApply(tree);
275                 }
276             }.scan((JCCompilationUnit) e.getCompilationUnit());
277             for (int i = 0; i < lambdaGroups.size(); i++) {
278                 List<JCLambda> curr = lambdaGroups.get(i);
279                 JCLambda first = null;
280                 // Assert that all pairwise combinations of lambdas in the group are equal, and
281                 // hash to the same value.
282                 for (JCLambda lhs : curr) {
283                     if (first == null) {
284                         first = lhs;
285                     } else {
286                         dedupedLambdas.put(lhs, first);
287                     }
288                     for (JCLambda rhs : curr) {
289                         if (!new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
290                                 .scan(lhs.body, rhs.body)) {
291                             throw new AssertionError(
292                                     String.format(
293                                             "expected lambdas to be equal\n%s\n%s", lhs, rhs));
294                         }
295                         if (TreeHasher.hash(lhs, paramSymbols(lhs))
296                                 != TreeHasher.hash(rhs, paramSymbols(rhs))) {
297                             throw new AssertionError(
298                                     String.format(
299                                             "expected lambdas to hash to the same value\n%s\n%s",
300                                             lhs, rhs));
301                         }
302                     }
303                 }
304                 // Assert that no lambdas in a group are equal to any lambdas outside that group,
305                 // or hash to the same value as lambda outside the group.
306                 // (Note that the hash collisions won't result in correctness problems but could
307                 // regress performs, and do not currently occurr for any of the test inputs.)
308                 for (int j = 0; j < lambdaGroups.size(); j++) {
309                     if (i == j) {
310                         continue;
311                     }
312                     for (JCLambda lhs : curr) {
313                         for (JCLambda rhs : lambdaGroups.get(j)) {
314                             if (new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs))
315                                     .scan(lhs.body, rhs.body)) {
316                                 throw new AssertionError(
317                                         String.format(
318                                                 "expected lambdas to not be equal\n%s\n%s",
319                                                 lhs, rhs));
320                             }
321                             if (TreeHasher.hash(lhs, paramSymbols(lhs))
322                                     == TreeHasher.hash(rhs, paramSymbols(rhs))) {
323                                 throw new AssertionError(
324                                         String.format(
325                                                 "expected lambdas to hash to different values\n%s\n%s",
326                                                 lhs, rhs));
327                             }
328                         }
329                     }
330                 }
331             }
332             lambdaGroups.clear();
333         }
334     }
335 }
336