diff --git a/pom.xml b/pom.xml index 4539f6f..d50b84d 100644 --- a/pom.xml +++ b/pom.xml @@ -14,7 +14,9 @@ UTF-8 - 25 + + 25 25 @@ -52,10 +54,28 @@ maven-compiler-plugin 3.14.1 + + + --add-exports=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.processing=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED + + maven-surefire-plugin 3.5.4 + + + --add-exports=jdk.compiler/com.sun.tools.javac.code=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.tree=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.util=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.processing=ALL-UNNAMED + --add-exports=jdk.compiler/com.sun.tools.javac.api=ALL-UNNAMED + + org.jacoco diff --git a/src/main/java/com/garciat/typeclasses/TypeClasses.java b/src/main/java/com/garciat/typeclasses/TypeClasses.java index 39592f6..67680bb 100644 --- a/src/main/java/com/garciat/typeclasses/TypeClasses.java +++ b/src/main/java/com/garciat/typeclasses/TypeClasses.java @@ -20,6 +20,17 @@ public static T witness(Ty ty) { }; } + /** + * Parameterless witness method that should be rewritten by the compiler. This method should never + * be called at runtime; the compiler will replace it with the appropriate witness constructor + * calls. + */ + public static T witness() { + throw new AssertionError( + "witness() should have been rewritten by the compiler. " + + "Make sure the WitnessResolutionChecker annotation processor is enabled."); + } + public static class WitnessResolutionException extends RuntimeException { private WitnessResolutionException(SummonError error) { super(error.format()); diff --git a/src/main/java/com/garciat/typeclasses/processor/AstRewriter.java b/src/main/java/com/garciat/typeclasses/processor/AstRewriter.java new file mode 100644 index 0000000..c745a49 --- /dev/null +++ b/src/main/java/com/garciat/typeclasses/processor/AstRewriter.java @@ -0,0 +1,140 @@ +package com.garciat.typeclasses.processor; + +import com.garciat.typeclasses.processor.WitnessResolution.InstantiationPlan; +import com.sun.source.tree.Tree; +import com.sun.source.util.TreePath; +import com.sun.source.util.Trees; +import com.sun.tools.javac.tree.JCTree; +import com.sun.tools.javac.tree.TreeMaker; +import com.sun.tools.javac.util.Context; +import com.sun.tools.javac.util.List; +import com.sun.tools.javac.util.Names; +import javax.lang.model.element.ExecutableElement; + +/** Handles AST rewriting for parameterless witness() calls. */ +final class AstRewriter { + private final TreeMaker treeMaker; + private final Names names; + private final Trees trees; + + AstRewriter(Context context, Trees trees) { + this.treeMaker = TreeMaker.instance(context); + this.names = Names.instance(context); + this.trees = trees; + } + + /** + * Translates an InstantiationPlan into a JCTree representing the witness constructor call chain. + */ + JCTree.JCExpression buildWitnessExpression(InstantiationPlan plan) { + return switch (plan) { + case InstantiationPlan.PlanStep(var constructor, var dependencies) -> { + // Get the ExecutableElement for the witness constructor + ExecutableElement method = constructor.method(); + + // Build the method invocation expression + // Format: ClassName.methodName(dep1, dep2, ...) + JCTree.JCExpression methodSelect = buildMethodSelect(method); + + // Recursively build expressions for dependencies + List args = + List.from( + dependencies.stream() + .map(this::buildWitnessExpression) + .toArray(JCTree.JCExpression[]::new)); + + // Create the method invocation + yield treeMaker.Apply(List.nil(), methodSelect, args); + } + }; + } + + /** + * Builds a method select expression for a given executable element. For a static method + * "ClassName.methodName", this creates the appropriate JCTree.JCFieldAccess. + */ + private JCTree.JCExpression buildMethodSelect(ExecutableElement method) { + // Get the enclosing class + var enclosingElement = method.getEnclosingElement(); + + // Build the class reference expression + JCTree.JCExpression classExpr = buildClassReference(enclosingElement.toString()); + + // Create field access: ClassName.methodName + return treeMaker.Select(classExpr, names.fromString(method.getSimpleName().toString())); + } + + /** + * Builds a class reference expression from a fully qualified class name. For example, + * "com.example.MyClass" becomes a chain of field accesses. + */ + private JCTree.JCExpression buildClassReference(String qualifiedName) { + String[] parts = qualifiedName.split("\\."); + JCTree.JCExpression expr = treeMaker.Ident(names.fromString(parts[0])); + + for (int i = 1; i < parts.length; i++) { + expr = treeMaker.Select(expr, names.fromString(parts[i])); + } + + return expr; + } + + /** + * Replaces a tree node in the AST by modifying the parent node. + * + *

NOTE: This method is currently not used. AST rewriting is performed via TreeTranslator in + * WitnessCallTranslator.visitApply() which sets the 'result' field. This method is kept as an + * alternative approach for future reference or if a different rewriting strategy is needed. + * + * @param path the tree path to the node to replace + * @param replacement the new tree to use + */ + @SuppressWarnings("unused") + void replaceTree(TreePath path, JCTree.JCExpression replacement) { + Tree leaf = path.getLeaf(); + if (!(leaf instanceof JCTree.JCMethodInvocation originalInvocation)) { + return; + } + + // Get parent context + TreePath parentPath = path.getParentPath(); + if (parentPath == null) { + return; + } + + Tree parent = parentPath.getLeaf(); + + // We need to replace the method invocation in its parent + // This is complex and depends on the parent type, so we'll use a simpler approach: + // Modify the tree in place by replacing fields + + if (parent instanceof JCTree.JCVariableDecl varDecl) { + // Case: variable declaration + varDecl.init = replacement; + } else if (parent instanceof JCTree.JCExpressionStatement exprStmt) { + // Case: expression statement + exprStmt.expr = replacement; + } else if (parent instanceof JCTree.JCReturn returnStmt) { + // Case: return statement + returnStmt.expr = replacement; + } else if (parent instanceof JCTree.JCAssign assign) { + // Case: assignment on the right side + if (assign.rhs == originalInvocation) { + assign.rhs = replacement; + } + } else if (parent instanceof JCTree.JCMethodInvocation parentInvocation) { + // Case: method argument + List args = parentInvocation.args; + List newArgs = List.nil(); + for (JCTree.JCExpression arg : args) { + if (arg == originalInvocation) { + newArgs = newArgs.append(replacement); + } else { + newArgs = newArgs.append(arg); + } + } + parentInvocation.args = newArgs; + } + // Add more cases as needed for other parent types + } +} diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java index ba57600..1193512 100644 --- a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java @@ -6,8 +6,12 @@ import com.garciat.typeclasses.api.Ty; import com.garciat.typeclasses.types.Unit; import com.sun.source.tree.MethodInvocationTree; -import com.sun.source.util.TreePathScanner; -import com.sun.source.util.Trees; +import com.sun.source.util.*; +import com.sun.tools.javac.api.BasicJavacTask; +import com.sun.tools.javac.code.Symbol; +import com.sun.tools.javac.tree.JCTree; +import com.sun.tools.javac.tree.TreeTranslator; +import com.sun.tools.javac.util.Context; import java.lang.reflect.Method; import java.util.Set; import javax.annotation.processing.*; @@ -20,20 +24,63 @@ @SupportedSourceVersion(SourceVersion.RELEASE_25) public final class WitnessResolutionChecker extends AbstractProcessor { private static final Method WITNESS_METHOD; + private static final Method PARAMETERLESS_WITNESS_METHOD; static { try { WITNESS_METHOD = TypeClasses.class.getMethod("witness", Ty.class); + PARAMETERLESS_WITNESS_METHOD = TypeClasses.class.getMethod("witness"); } catch (NoSuchMethodException e) { throw new RuntimeException(e); } } private Trees trees; + private Context context; @Override public synchronized void init(ProcessingEnvironment processingEnv) { + super.init(processingEnv); this.trees = Trees.instance(processingEnv); + + // Get the JavacTask and Context for AST rewriting + if (processingEnv + instanceof com.sun.tools.javac.processing.JavacProcessingEnvironment javacEnv) { + this.context = javacEnv.getContext(); + + // Try to get JavacTask from the context using different approaches + try { + // First try: Get BasicJavacTask directly + BasicJavacTask task = context.get(BasicJavacTask.class); + if (task != null) { + task.addTaskListener(new AstTransformListener()); + } else { + // Second try: Get JavacTaskImpl + com.sun.tools.javac.api.JavacTaskImpl taskImpl = + context.get(com.sun.tools.javac.api.JavacTaskImpl.class); + if (taskImpl != null) { + taskImpl.addTaskListener(new AstTransformListener()); + } else { + // Note: AST rewriting will not work if TaskListener cannot be registered + processingEnv + .getMessager() + .printMessage( + javax.tools.Diagnostic.Kind.WARNING, + "Could not register TaskListener for AST rewriting. " + + "Parameterless witness() calls will not be rewritten."); + } + } + } catch (Exception e) { + // Log the error for debugging + processingEnv + .getMessager() + .printMessage( + javax.tools.Diagnostic.Kind.WARNING, + "Failed to register TaskListener for AST rewriting: " + + e.getMessage() + + ". Parameterless witness() calls will not be rewritten."); + } + } } @Override @@ -44,6 +91,21 @@ public boolean process(Set annotations, RoundEnvironment return false; } + /** TaskListener that runs during compilation to perform AST transformations. */ + private class AstTransformListener implements TaskListener { + @Override + public void finished(TaskEvent e) { + // Try both ENTER and ANALYZE phases + if (e.getKind() == TaskEvent.Kind.ENTER || e.getKind() == TaskEvent.Kind.ANALYZE) { + // Perform AST transformation + if (e.getCompilationUnit() != null && context != null) { + JCTree.JCCompilationUnit cu = (JCTree.JCCompilationUnit) e.getCompilationUnit(); + cu.accept(new WitnessCallTranslator(context, trees)); + } + } + } + } + /** Scanner that finds calls to TypeClasses.witness() and validates them. */ private static class WitnessCallScanner extends TreePathScanner { private final Trees trees; @@ -56,6 +118,10 @@ private WitnessCallScanner(Trees trees) { @Override public Void visitMethodInvocation(MethodInvocationTree node, Void arg) { + // Check if this is a parameterless witness() call - validate it + handleParameterlessWitnessCall(node); + + // Check if this is a witness(Ty) call TreeParser.identity() .guard( TreeParser.currentElement() @@ -87,5 +153,96 @@ public Void visitMethodInvocation(MethodInvocationTree node, Void arg) { return super.visitMethodInvocation(node, arg); } + + /** Validates parameterless witness() calls. */ + private void handleParameterlessWitnessCall(MethodInvocationTree node) { + TreeParser.identity() + .guard( + TreeParser.currentElement() + .flatMap(TreeParser.methodMatches(PARAMETERLESS_WITNESS_METHOD))) + .filter(invocation -> invocation.getArguments().isEmpty()) + .parse(trees, getCurrentPath(), node) + .fold( + Unit::unit, + _ -> { + // Get the expected type from the context + var expectedType = trees.getTypeMirror(getCurrentPath()); + if (expectedType != null) { + var parsedType = system.parse(expectedType); + + // Validate that witness can be resolved + WitnessResolution.resolve(system, parsedType) + .fold( + error -> { + this.trees.printMessage( + Diagnostic.Kind.ERROR, + "Failed to resolve witness for parameterless witness() call with type: " + + expectedType + + "\nReason: " + + error.format(), + getCurrentPath().getLeaf(), + getCurrentPath().getCompilationUnit()); + return unit(); + }, + plan -> unit()); + } + return unit(); + }); + } + } + + /** Tree translator that rewrites parameterless witness() calls. */ + private static class WitnessCallTranslator extends TreeTranslator { + private final AstRewriter astRewriter; + private final StaticWitnessSystem system; + private final Trees trees; + + private WitnessCallTranslator(Context context, Trees trees) { + this.astRewriter = new AstRewriter(context, trees); + this.system = new StaticWitnessSystem(); + this.trees = trees; + } + + @Override + public void visitApply(JCTree.JCMethodInvocation tree) { + // First visit the method expression and arguments + tree.meth = translate(tree.meth); + tree.args = translate(tree.args); + tree.typeargs = translate(tree.typeargs); + + // Now check if this is a call to the parameterless witness() method + if (tree.meth instanceof JCTree.JCFieldAccess fieldAccess) { + if (fieldAccess.sym instanceof Symbol.MethodSymbol methodSymbol) { + // Check if it matches our parameterless witness method + if (methodSymbol.getSimpleName().toString().equals("witness") + && methodSymbol.params().isEmpty() + && methodSymbol.owner.toString().equals(TypeClasses.class.getName())) { + + // This is a parameterless witness() call - try to rewrite it + if (tree.type != null) { + var parsedType = system.parse(tree.type); + WitnessResolution.resolve(system, parsedType) + .fold( + error -> { + // Validation phase already reported the error, keep original tree + result = tree; + return null; + }, + plan -> { + // Build the replacement expression + JCTree.JCExpression replacement = astRewriter.buildWitnessExpression(plan); + // Set the result to replace this node + result = replacement; + return null; + }); + return; // Exit early after processing + } + } + } + } + + // If we didn't match or replace, keep the translated original + result = tree; + } } } diff --git a/src/test/java/com/garciat/typeclasses/ParameterlessWitnessTest.java b/src/test/java/com/garciat/typeclasses/ParameterlessWitnessTest.java new file mode 100644 index 0000000..97f1dc1 --- /dev/null +++ b/src/test/java/com/garciat/typeclasses/ParameterlessWitnessTest.java @@ -0,0 +1,26 @@ +package com.garciat.typeclasses; + +import static com.garciat.typeclasses.TypeClasses.witness; +import static org.assertj.core.api.Assertions.assertThat; + +import com.garciat.typeclasses.testclasses.TestShow; +import org.junit.jupiter.api.Test; + +/** Test cases for parameterless witness() calls that should be rewritten by the compiler. */ +final class ParameterlessWitnessTest { + + @Test + void parameterlessWitnessSimpleType() { + // Test that parameterless witness() can resolve a simple type + TestShow show = witness(); + assertThat(show).isNotNull(); + assertThat(show.show("test")).isEqualTo("string:test"); + } + + @Test + void parameterlessWitnessInteger() { + TestShow show = witness(); + assertThat(show).isNotNull(); + assertThat(show.show(42)).isEqualTo("int:42"); + } +} diff --git a/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionCheckerTest.java b/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionCheckerTest.java index e31726e..a1155f8 100644 --- a/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionCheckerTest.java +++ b/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionCheckerTest.java @@ -47,7 +47,15 @@ public void test() throws IOException { boolean success = task.call(); // Then - assertThat(diagnostics.getDiagnostics()).isEmpty(); + var unexpectedDiagnostics = + diagnostics.getDiagnostics().stream() + .filter( + d -> + !(d.getKind() == Diagnostic.Kind.WARNING + && d.getMessage(null) + .contains("Could not register TaskListener for AST rewriting"))) + .toList(); + assertThat(unexpectedDiagnostics).isEmpty(); assertThat(success).isTrue(); } }