1 /* 2 * Copyright (c) 2018, Google LLC. All rights reserved. 3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. 4 * 5 * This code is free software; you can redistribute it and/or modify it 6 * under the terms of the GNU General Public License version 2 only, as 7 * published by the Free Software Foundation. 8 * 9 * This code is distributed in the hope that it will be useful, but WITHOUT 10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or 11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 12 * version 2 for more details (a copy is included in the LICENSE file that 13 * accompanied this code). 14 * 15 * You should have received a copy of the GNU General Public License version 16 * 2 along with this work; if not, write to the Free Software Foundation, 17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. 18 * 19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA 20 * or visit www.oracle.com if you need additional information or have any 21 * questions. 22 */ 23 24 /** 25 * @test 8200301 8201194 26 * @summary deduplicate lambda methods with the same body, target type, and captured state 27 * @modules jdk.jdeps/com.sun.tools.classfile jdk.compiler/com.sun.tools.javac.api 28 * jdk.compiler/com.sun.tools.javac.code jdk.compiler/com.sun.tools.javac.comp 29 * jdk.compiler/com.sun.tools.javac.file jdk.compiler/com.sun.tools.javac.main 30 * jdk.compiler/com.sun.tools.javac.tree jdk.compiler/com.sun.tools.javac.util 31 * @run main DeduplicationTest 32 */ 33 import static java.nio.charset.StandardCharsets.UTF_8; 34 import static java.util.stream.Collectors.joining; 35 import static java.util.stream.Collectors.toList; 36 import static java.util.stream.Collectors.toMap; 37 import static java.util.stream.Collectors.toSet; 38 39 import com.sun.source.util.JavacTask; 40 import com.sun.source.util.TaskEvent; 41 import com.sun.source.util.TaskEvent.Kind; 42 import com.sun.source.util.TaskListener; 43 import com.sun.tools.classfile.Attribute; 44 import com.sun.tools.classfile.BootstrapMethods_attribute; 45 import com.sun.tools.classfile.BootstrapMethods_attribute.BootstrapMethodSpecifier; 46 import com.sun.tools.classfile.ClassFile; 47 import com.sun.tools.classfile.ConstantPool.CONSTANT_MethodHandle_info; 48 import com.sun.tools.javac.api.ClientCodeWrapper.Trusted; 49 import com.sun.tools.javac.api.JavacTool; 50 import com.sun.tools.javac.code.Symbol; 51 import com.sun.tools.javac.code.Symbol.MethodSymbol; 52 import com.sun.tools.javac.comp.TreeDiffer; 53 import com.sun.tools.javac.comp.TreeHasher; 54 import com.sun.tools.javac.file.JavacFileManager; 55 import com.sun.tools.javac.tree.JCTree.JCCompilationUnit; 56 import com.sun.tools.javac.tree.JCTree.JCExpression; 57 import com.sun.tools.javac.tree.JCTree.JCIdent; 58 import com.sun.tools.javac.tree.JCTree.JCLambda; 59 import com.sun.tools.javac.tree.JCTree.JCMethodInvocation; 60 import com.sun.tools.javac.tree.JCTree.JCTypeCast; 61 import com.sun.tools.javac.tree.JCTree.Tag; 62 import com.sun.tools.javac.tree.TreeScanner; 63 import com.sun.tools.javac.util.Context; 64 import com.sun.tools.javac.util.JCDiagnostic; 65 import java.nio.file.Path; 66 import java.nio.file.Paths; 67 import java.util.ArrayList; 68 import java.util.Arrays; 69 import java.util.LinkedHashMap; 70 import java.util.List; 71 import java.util.Locale; 72 import java.util.Map; 73 import java.util.Set; 74 import java.util.TreeSet; 75 import javax.tools.Diagnostic; 76 import javax.tools.DiagnosticListener; 77 import javax.tools.JavaFileObject; 78 79 public class DeduplicationTest { 80 main(String[] args)81 public static void main(String[] args) throws Exception { 82 JavacFileManager fileManager = new JavacFileManager(new Context(), false, UTF_8); 83 JavacTool javacTool = JavacTool.create(); 84 Listener diagnosticListener = new Listener(); 85 Path testSrc = Paths.get(System.getProperty("test.src")); 86 Path file = testSrc.resolve("Deduplication.java"); 87 String sourceVersion = Integer.toString(Runtime.version().feature()); 88 JavacTask task = 89 javacTool.getTask( 90 null, 91 null, 92 diagnosticListener, 93 Arrays.asList( 94 "-d", 95 ".", 96 "-XDdebug.dumpLambdaToMethodDeduplication", 97 "-XDdebug.dumpLambdaToMethodStats", 98 "--enable-preview", "-source", sourceVersion), 99 null, 100 fileManager.getJavaFileObjects(file)); 101 Map<JCLambda, JCLambda> dedupedLambdas = new LinkedHashMap<>(); 102 task.addTaskListener(new TreeDiffHashTaskListener(dedupedLambdas)); 103 Iterable<? extends JavaFileObject> generated = task.generate(); 104 if (!diagnosticListener.unexpected.isEmpty()) { 105 throw new AssertionError( 106 diagnosticListener 107 .unexpected 108 .stream() 109 .map( 110 d -> 111 String.format( 112 "%s: %s", 113 d.getCode(), d.getMessage(Locale.getDefault()))) 114 .collect(joining(", ", "unexpected diagnostics: ", ""))); 115 } 116 117 // Assert that each group of lambdas was deduplicated. 118 Map<JCLambda, JCLambda> actual = diagnosticListener.deduplicationTargets(); 119 dedupedLambdas.forEach( 120 (k, v) -> { 121 if (!actual.containsKey(k)) { 122 throw new AssertionError("expected " + k + " to be deduplicated"); 123 } 124 if (!v.equals(actual.get(k))) { 125 throw new AssertionError( 126 String.format( 127 "expected %s to be deduplicated to:\n %s\nwas: %s", 128 k, v, actual.get(v))); 129 } 130 }); 131 132 // Assert that the output contains only the canonical lambdas, and not the deduplicated 133 // lambdas. 134 Set<String> bootstrapMethodNames = new TreeSet<>(); 135 for (JavaFileObject output : generated) { 136 ClassFile cf = ClassFile.read(output.openInputStream()); 137 BootstrapMethods_attribute bsm = 138 (BootstrapMethods_attribute) cf.getAttribute(Attribute.BootstrapMethods); 139 for (BootstrapMethodSpecifier b : bsm.bootstrap_method_specifiers) { 140 bootstrapMethodNames.add( 141 ((CONSTANT_MethodHandle_info) 142 cf.constant_pool.get(b.bootstrap_arguments[1])) 143 .getCPRefInfo() 144 .getNameAndTypeInfo() 145 .getName()); 146 } 147 } 148 Set<String> deduplicatedNames = 149 diagnosticListener 150 .expectedLambdaMethods() 151 .stream() 152 .map(s -> s.getSimpleName().toString()) 153 .sorted() 154 .collect(toSet()); 155 if (!deduplicatedNames.equals(bootstrapMethodNames)) { 156 throw new AssertionError( 157 String.format( 158 "expected deduplicated methods: %s, but saw: %s", 159 deduplicatedNames, bootstrapMethodNames)); 160 } 161 } 162 163 /** Returns the parameter symbols of the given lambda. */ paramSymbols(JCLambda lambda)164 private static List<Symbol> paramSymbols(JCLambda lambda) { 165 return lambda.params.stream().map(x -> x.sym).collect(toList()); 166 } 167 168 /** A diagnostic listener that records debug messages related to lambda desugaring. */ 169 @Trusted 170 static class Listener implements DiagnosticListener<JavaFileObject> { 171 172 /** A map from method symbols to lambda trees for desugared lambdas. */ 173 final Map<MethodSymbol, JCLambda> lambdaMethodSymbolsToTrees = new LinkedHashMap<>(); 174 175 /** 176 * A map from lambda trees that were deduplicated to the method symbol of the canonical 177 * lambda implementation method they were deduplicated to. 178 */ 179 final Map<JCLambda, MethodSymbol> deduped = new LinkedHashMap<>(); 180 181 final List<Diagnostic<? extends JavaFileObject>> unexpected = new ArrayList<>(); 182 183 @Override report(Diagnostic<? extends JavaFileObject> diagnostic)184 public void report(Diagnostic<? extends JavaFileObject> diagnostic) { 185 JCDiagnostic d = (JCDiagnostic) diagnostic; 186 switch (d.getCode()) { 187 case "compiler.note.lambda.stat": 188 lambdaMethodSymbolsToTrees.put( 189 (MethodSymbol) d.getArgs()[1], 190 (JCLambda) d.getDiagnosticPosition().getTree()); 191 break; 192 case "compiler.note.verbose.l2m.deduplicate": 193 deduped.put( 194 (JCLambda) d.getDiagnosticPosition().getTree(), 195 (MethodSymbol) d.getArgs()[0]); 196 break; 197 case "compiler.note.preview.filename": 198 case "compiler.note.preview.recompile": 199 break; //ignore 200 default: 201 unexpected.add(diagnostic); 202 } 203 } 204 205 /** Returns expected lambda implementation method symbols. */ expectedLambdaMethods()206 Set<MethodSymbol> expectedLambdaMethods() { 207 return lambdaMethodSymbolsToTrees 208 .entrySet() 209 .stream() 210 .filter(e -> !deduped.containsKey(e.getValue())) 211 .map(Map.Entry::getKey) 212 .collect(toSet()); 213 } 214 215 /** 216 * Returns a mapping from deduplicated lambda trees to the tree of the canonical lambda they 217 * were deduplicated to. 218 */ deduplicationTargets()219 Map<JCLambda, JCLambda> deduplicationTargets() { 220 return deduped.entrySet() 221 .stream() 222 .collect( 223 toMap( 224 Map.Entry::getKey, 225 e -> lambdaMethodSymbolsToTrees.get(e.getValue()), 226 (a, b) -> { 227 throw new AssertionError(); 228 }, 229 LinkedHashMap::new)); 230 } 231 } 232 233 /** 234 * A task listener that tests {@link TreeDiffer} and {@link TreeHasher} on all lambda trees in a 235 * compilation, post-analysis. 236 */ 237 private static class TreeDiffHashTaskListener implements TaskListener { 238 239 /** 240 * A map from deduplicated lambdas to the canonical lambda they are expected to be 241 * deduplicated to. 242 */ 243 private final Map<JCLambda, JCLambda> dedupedLambdas; 244 TreeDiffHashTaskListener(Map<JCLambda, JCLambda> dedupedLambdas)245 public TreeDiffHashTaskListener(Map<JCLambda, JCLambda> dedupedLambdas) { 246 this.dedupedLambdas = dedupedLambdas; 247 } 248 249 @Override finished(TaskEvent e)250 public void finished(TaskEvent e) { 251 if (e.getKind() != Kind.ANALYZE) { 252 return; 253 } 254 // Scan the compilation for calls to a varargs method named 'group', whose arguments 255 // are a group of lambdas that are equivalent to each other, but distinct from all 256 // lambdas in the compilation unit outside of that group. 257 List<List<JCLambda>> lambdaGroups = new ArrayList<>(); 258 new TreeScanner() { 259 @Override 260 public void visitApply(JCMethodInvocation tree) { 261 if (tree.getMethodSelect().getTag() == Tag.IDENT 262 && ((JCIdent) tree.getMethodSelect()) 263 .getName() 264 .contentEquals("group")) { 265 List<JCLambda> xs = new ArrayList<>(); 266 for (JCExpression arg : tree.getArguments()) { 267 if (arg instanceof JCTypeCast) { 268 arg = ((JCTypeCast) arg).getExpression(); 269 } 270 xs.add((JCLambda) arg); 271 } 272 lambdaGroups.add(xs); 273 } 274 super.visitApply(tree); 275 } 276 }.scan((JCCompilationUnit) e.getCompilationUnit()); 277 for (int i = 0; i < lambdaGroups.size(); i++) { 278 List<JCLambda> curr = lambdaGroups.get(i); 279 JCLambda first = null; 280 // Assert that all pairwise combinations of lambdas in the group are equal, and 281 // hash to the same value. 282 for (JCLambda lhs : curr) { 283 if (first == null) { 284 first = lhs; 285 } else { 286 dedupedLambdas.put(lhs, first); 287 } 288 for (JCLambda rhs : curr) { 289 if (!new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs)) 290 .scan(lhs.body, rhs.body)) { 291 throw new AssertionError( 292 String.format( 293 "expected lambdas to be equal\n%s\n%s", lhs, rhs)); 294 } 295 if (TreeHasher.hash(lhs, paramSymbols(lhs)) 296 != TreeHasher.hash(rhs, paramSymbols(rhs))) { 297 throw new AssertionError( 298 String.format( 299 "expected lambdas to hash to the same value\n%s\n%s", 300 lhs, rhs)); 301 } 302 } 303 } 304 // Assert that no lambdas in a group are equal to any lambdas outside that group, 305 // or hash to the same value as lambda outside the group. 306 // (Note that the hash collisions won't result in correctness problems but could 307 // regress performs, and do not currently occurr for any of the test inputs.) 308 for (int j = 0; j < lambdaGroups.size(); j++) { 309 if (i == j) { 310 continue; 311 } 312 for (JCLambda lhs : curr) { 313 for (JCLambda rhs : lambdaGroups.get(j)) { 314 if (new TreeDiffer(paramSymbols(lhs), paramSymbols(rhs)) 315 .scan(lhs.body, rhs.body)) { 316 throw new AssertionError( 317 String.format( 318 "expected lambdas to not be equal\n%s\n%s", 319 lhs, rhs)); 320 } 321 if (TreeHasher.hash(lhs, paramSymbols(lhs)) 322 == TreeHasher.hash(rhs, paramSymbols(rhs))) { 323 throw new AssertionError( 324 String.format( 325 "expected lambdas to hash to different values\n%s\n%s", 326 lhs, rhs)); 327 } 328 } 329 } 330 } 331 } 332 lambdaGroups.clear(); 333 } 334 } 335 } 336