From 4fe943f712ae4a47f8ee986b768ddd1a1db14e3f Mon Sep 17 00:00:00 2001 From: Gabriel Garcia Date: Tue, 16 Dec 2025 15:57:38 +0100 Subject: [PATCH 1/8] Implement Java Compiler Plugin for witness resolution verification --- .../processor/OverlappingInstances.java | 35 +++++ .../typeclasses/processor/ParsedType.java | 34 +++++ .../processor/StaticWitnessSystem.java | 122 ++++++++++++++++++ .../typeclasses/processor/Unification.java | 52 ++++++++ .../processor/WitnessConstructor.java | 11 ++ .../processor/WitnessResolution.java | 78 +++++++++++ .../processor/WitnessResolutionChecker.java | 122 ++++++++++++++++++ .../services/com.sun.source.util.Plugin | 1 + .../com/garciat/typeclasses/ExamplesTest.java | 17 ++- .../WitnessResolutionProcessorTest.java | 71 ++++++++++ 10 files changed, 541 insertions(+), 2 deletions(-) create mode 100644 src/main/java/com/garciat/typeclasses/processor/OverlappingInstances.java create mode 100644 src/main/java/com/garciat/typeclasses/processor/ParsedType.java create mode 100644 src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java create mode 100644 src/main/java/com/garciat/typeclasses/processor/Unification.java create mode 100644 src/main/java/com/garciat/typeclasses/processor/WitnessConstructor.java create mode 100644 src/main/java/com/garciat/typeclasses/processor/WitnessResolution.java create mode 100644 src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java create mode 100644 src/main/resources/META-INF/services/com.sun.source.util.Plugin create mode 100644 src/test/java/com/garciat/typeclasses/processor/WitnessResolutionProcessorTest.java diff --git a/src/main/java/com/garciat/typeclasses/processor/OverlappingInstances.java b/src/main/java/com/garciat/typeclasses/processor/OverlappingInstances.java new file mode 100644 index 0000000..acb0cbb --- /dev/null +++ b/src/main/java/com/garciat/typeclasses/processor/OverlappingInstances.java @@ -0,0 +1,35 @@ +package com.garciat.typeclasses.processor; + +import static com.garciat.typeclasses.api.TypeClass.Witness.Overlap.OVERLAPPABLE; +import static com.garciat.typeclasses.api.TypeClass.Witness.Overlap.OVERLAPPING; + +import java.util.List; + +public final class OverlappingInstances { + private OverlappingInstances() {} + + /** + * @implSpec 6.8.8.5. + * Overlapping instances + */ + public static List reduce(List candidates) { + return candidates.stream() + .filter( + iX -> + candidates.stream().filter(iY -> iX != iY).noneMatch(iY -> isOverlappedBy(iX, iY))) + .toList(); + } + + private static boolean isOverlappedBy(WitnessConstructor iX, WitnessConstructor iY) { + return (iX.overlap() == OVERLAPPABLE || iY.overlap() == OVERLAPPING) + && isSubstitutionInstance(iX, iY) + && !isSubstitutionInstance(iY, iX); + } + + private static boolean isSubstitutionInstance( + WitnessConstructor base, WitnessConstructor reference) { + return Unification.unify(base.returnType(), reference.returnType()) + .fold(() -> false, map -> !map.isEmpty()); + } +} diff --git a/src/main/java/com/garciat/typeclasses/processor/ParsedType.java b/src/main/java/com/garciat/typeclasses/processor/ParsedType.java new file mode 100644 index 0000000..57bc210 --- /dev/null +++ b/src/main/java/com/garciat/typeclasses/processor/ParsedType.java @@ -0,0 +1,34 @@ +package com.garciat.typeclasses.processor; + +import javax.lang.model.type.DeclaredType; +import javax.lang.model.type.PrimitiveType; +import javax.lang.model.type.TypeMirror; +import javax.lang.model.type.TypeVariable; + +public sealed interface ParsedType { + record Var(TypeVariable java) implements ParsedType {} + + record App(ParsedType fun, ParsedType arg) implements ParsedType {} + + record ArrayOf(ParsedType elementType) implements ParsedType {} + + record Const(DeclaredType java) implements ParsedType {} + + record Primitive(PrimitiveType java) implements ParsedType {} + + default String format() { + return switch (this) { + case Var v -> v.java.toString(); + case Const c -> + c.java().asElement().getSimpleName() + + c.java().getTypeArguments().stream() + .map(TypeMirror::toString) + .reduce((a, b) -> a + ", " + b) + .map(s -> "[" + s + "]") + .orElse(""); + case App a -> a.fun.format() + "(" + a.arg.format() + ")"; + case ArrayOf a -> a.elementType.format() + "[]"; + case Primitive p -> p.java().toString(); + }; + } +} diff --git a/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java b/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java new file mode 100644 index 0000000..0d4b494 --- /dev/null +++ b/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java @@ -0,0 +1,122 @@ +package com.garciat.typeclasses.processor; + +import com.garciat.typeclasses.api.TypeClass; +import com.garciat.typeclasses.api.hkt.TApp; +import com.garciat.typeclasses.api.hkt.TPar; +import com.garciat.typeclasses.api.hkt.TagBase; +import com.garciat.typeclasses.impl.utils.Lists; +import com.garciat.typeclasses.types.Maybe; +import com.garciat.typeclasses.types.Pair; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Stream; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.Modifier; +import javax.lang.model.element.TypeElement; +import javax.lang.model.element.VariableElement; +import javax.lang.model.type.*; +import javax.lang.model.util.Types; + +public class StaticWitnessSystem { + private static final Class TAG_BASE_CLASS = TagBase.class; + private static final Class TAPP_CLASS = TApp.class; + private static final Class TPAR_CLASS = TPar.class; + + private final Types types; + + public StaticWitnessSystem(Types types) { + this.types = types; + } + + public List findRules(ParsedType target) { + return switch (target) { + case ParsedType.App(var fun, var arg) -> Lists.concat(findRules(fun), findRules(arg)); + case ParsedType.Const(var java) -> + java.asElement().getEnclosedElements().stream() + .flatMap(isInstanceOf(ExecutableElement.class)) + .flatMap(method -> parseWitnessConstructor(method).stream()) + .toList(); + case ParsedType.Var(var ignore) -> List.of(); + case ParsedType.ArrayOf(var ignore) -> List.of(); + case ParsedType.Primitive(var ignore) -> List.of(); + }; + } + + public Maybe parseWitnessConstructor(ExecutableElement method) { + if (method.getModifiers().contains(Modifier.PUBLIC) + && method.getModifiers().contains(Modifier.STATIC) + && method.getAnnotation(TypeClass.Witness.class) instanceof TypeClass.Witness witnessAnn) { + return Maybe.just( + new WitnessConstructor( + method, + witnessAnn.overlap(), + method.getParameters().stream() + .map(VariableElement::asType) + .map(this::parse) + .toList(), + parse(method.getReturnType()))); + + } else { + return Maybe.nothing(); + } + } + + public List parseAll(List types) { + return types.stream().map(this::parse).toList(); + } + + public ParsedType parse(TypeMirror type) { + return switch (type) { + case TypeVariable tv -> new ParsedType.Var(tv); + case ArrayType at -> new ParsedType.ArrayOf(parse(at.getComponentType())); + // Store primitive as its boxed type representation, just to have a DeclaredType. + case PrimitiveType pt -> new ParsedType.Primitive(pt); + case DeclaredType dt + when parseTagType(dt) instanceof Maybe.Just(var realType) -> + new ParsedType.Const(realType); + case DeclaredType dt when dt.getTypeArguments().isEmpty() -> new ParsedType.Const(dt); + case DeclaredType dt + when parseAppType(dt) + instanceof + Maybe.Just>( + Pair(var fun, var arg)) -> + new ParsedType.App(parse(fun), parse(arg)); + case DeclaredType dt -> + parseAll(dt.getTypeArguments()).stream() + .reduce(parse(types.erasure(dt)), ParsedType.App::new); + case WildcardType wt -> + throw new IllegalArgumentException("Cannot parse wildcard type: " + wt); + default -> throw new IllegalArgumentException("Unsupported type: " + type); + }; + } + + private static Maybe parseTagType(DeclaredType t) { + if (t.asElement() instanceof TypeElement tag + && tag.getEnclosingElement() instanceof TypeElement enclosing + && enclosing.asType() instanceof DeclaredType enclosingType + && tag.getSuperclass() instanceof DeclaredType tagSuperType + && tagSuperType.asElement() instanceof TypeElement tagSuper + && tagSuper.getQualifiedName().contentEquals(TAG_BASE_CLASS.getName())) { + return Maybe.just(enclosingType); + } else { + return Maybe.nothing(); + } + } + + private Maybe> parseAppType(DeclaredType t) { + return t.getTypeArguments().size() == 2 && isAppType(types.erasure(t)) + ? Maybe.just(new Pair<>(t.getTypeArguments().get(0), t.getTypeArguments().get(1))) + : Maybe.nothing(); + } + + private boolean isAppType(TypeMirror erasure) { + return erasure instanceof DeclaredType dt + && dt.asElement() instanceof TypeElement te + && (te.getQualifiedName().contentEquals(TAPP_CLASS.getName()) + || te.getQualifiedName().contentEquals(TPAR_CLASS.getName())); + } + + private static Function> isInstanceOf(Class cls) { + return u -> cls.isInstance(u) ? Stream.of(cls.cast(u)) : Stream.empty(); + } +} diff --git a/src/main/java/com/garciat/typeclasses/processor/Unification.java b/src/main/java/com/garciat/typeclasses/processor/Unification.java new file mode 100644 index 0000000..e340617 --- /dev/null +++ b/src/main/java/com/garciat/typeclasses/processor/Unification.java @@ -0,0 +1,52 @@ +package com.garciat.typeclasses.processor; + +import com.garciat.typeclasses.impl.utils.Maps; +import com.garciat.typeclasses.types.Maybe; +import com.garciat.typeclasses.types.Pair; +import java.util.List; +import java.util.Map; + +public final class Unification { + private Unification() {} + + public static Maybe> unify(ParsedType t1, ParsedType t2) { + return switch (Pair.of(t1, t2)) { + case Pair(ParsedType.Var var1, ParsedType.Primitive p) -> + Maybe.nothing(); // no primitives in generics + case Pair(ParsedType.Var var1, var t) -> Maybe.just(Map.of(var1, t)); + case Pair(ParsedType.Const const1, ParsedType.Const const2) + when const1.equals(const2) -> + Maybe.just(Map.of()); + case Pair( + ParsedType.App(var fun1, var arg1), + ParsedType.App(var fun2, var arg2)) -> + Maybe.apply(Maps::merge, unify(fun1, fun2), unify(arg1, arg2)); + case Pair( + ParsedType.ArrayOf(var elem1), + ParsedType.ArrayOf(var elem2)) -> + unify(elem1, elem2); + case Pair( + ParsedType.Primitive(var prim1), + ParsedType.Primitive(var prim2)) + when prim1.equals(prim2) -> + Maybe.just(Map.of()); + default -> Maybe.nothing(); + }; + } + + public static ParsedType substitute(Map map, ParsedType type) { + return switch (type) { + case ParsedType.Var var -> map.getOrDefault(var, var); + case ParsedType.App(var fun, var arg) -> + new ParsedType.App(substitute(map, fun), substitute(map, arg)); + case ParsedType.ArrayOf var -> new ParsedType.ArrayOf(substitute(map, var.elementType())); + case ParsedType.Primitive p -> p; + case ParsedType.Const c -> c; + }; + } + + public static List substituteAll( + Map map, List types) { + return types.stream().map(t -> substitute(map, t)).toList(); + } +} diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessConstructor.java b/src/main/java/com/garciat/typeclasses/processor/WitnessConstructor.java new file mode 100644 index 0000000..d89d7ba --- /dev/null +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessConstructor.java @@ -0,0 +1,11 @@ +package com.garciat.typeclasses.processor; + +import com.garciat.typeclasses.api.TypeClass; +import java.util.List; +import javax.lang.model.element.ExecutableElement; + +public record WitnessConstructor( + ExecutableElement method, + TypeClass.Witness.Overlap overlap, + List paramTypes, + ParsedType returnType) {} diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolution.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolution.java new file mode 100644 index 0000000..d5a5dd0 --- /dev/null +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolution.java @@ -0,0 +1,78 @@ +package com.garciat.typeclasses.processor; + +import com.garciat.typeclasses.impl.utils.ZeroOneMore; +import com.garciat.typeclasses.types.Either; +import com.garciat.typeclasses.types.Maybe; +import java.util.List; +import java.util.stream.Collectors; + +public final class WitnessResolution { + private WitnessResolution() {} + + /** Resolves a ParsedType into an InstantiationPlan. */ + public static Either resolve( + StaticWitnessSystem system, ParsedType target) { + + List matches = + OverlappingInstances.reduce(system.findRules(target)).stream() + .flatMap(rule -> tryMatch(rule, target).stream()) + .toList(); + + return switch (ZeroOneMore.of(matches)) { + case ZeroOneMore.One(Match(var rule, var requirements)) -> + Either.traverse(requirements, req -> resolve(system, req)) + .map( + dependencies -> new InstantiationPlan.PlanStep(rule, dependencies)) + .mapLeft(error -> new ResolutionError.Nested(target, error)); + case ZeroOneMore.Zero() -> Either.left(new ResolutionError.NotFound(target)); + case ZeroOneMore.More(var matches2) -> + Either.left( + new ResolutionError.Ambiguous(target, matches2.stream().map(Match::rule).toList())); + }; + } + + private static Maybe tryMatch(WitnessConstructor rule, ParsedType target) { + return Unification.unify(rule.returnType(), target) + .map(map -> Unification.substituteAll(map, rule.paramTypes())) + .map(requirements -> new Match(rule, requirements)); + } + + record Match(WitnessConstructor rule, List requirements) {} + + /** + * Represents the fully resolved instantiation plan. This is a tree structure where each node is a + * step in the instantiation process, with dependencies on other steps. + */ + public sealed interface InstantiationPlan { + record PlanStep(WitnessConstructor target, List dependencies) + implements InstantiationPlan {} + } + + public sealed interface ResolutionError { + record NotFound(ParsedType target) implements ResolutionError {} + + record Ambiguous(ParsedType target, List candidates) + implements ResolutionError {} + + record Nested(ParsedType target, ResolutionError cause) implements ResolutionError {} + + default String format() { + return switch (this) { + case NotFound(ParsedType target) -> "No witness found for type: " + target.format(); + case Ambiguous(ParsedType target, List candidates) -> + "Ambiguous witnesses found for type: " + + target.format() + + "\nCandidates:\n" + + candidates.stream() + .map(WitnessConstructor::toString) + .collect(Collectors.joining("\n")) + .indent(2); + case Nested(ParsedType target, ResolutionError cause) -> + "While resolving witness for type: " + + target.format() + + "\nCaused by: " + + cause.format().indent(2); + }; + } + } +} diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java new file mode 100644 index 0000000..45f9d24 --- /dev/null +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java @@ -0,0 +1,122 @@ +package com.garciat.typeclasses.processor; + +import com.garciat.typeclasses.TypeClasses; +import com.garciat.typeclasses.api.Ty; +import com.garciat.typeclasses.types.Either; +import com.sun.source.tree.*; +import com.sun.source.util.*; +import java.lang.reflect.Method; +import javax.lang.model.element.Element; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.TypeElement; +import javax.lang.model.type.DeclaredType; +import javax.lang.model.type.TypeMirror; +import javax.lang.model.util.Types; +import javax.tools.Diagnostic; + +public final class WitnessResolutionChecker implements Plugin { + private static final Method WITNESS_METHOD; + + static { + try { + WITNESS_METHOD = TypeClasses.class.getMethod("witness", Ty.class); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } + + @Override + public String getName() { + return "WitnessResolutionChecker"; + } + + @Override + public void init(JavacTask task, String... args) { + task.addTaskListener( + new TaskListener() { + @Override + public void finished(TaskEvent e) { + if (e.getKind() != TaskEvent.Kind.ANALYZE) { + return; + } + + if (e.getCompilationUnit() == null) { + return; + } + + Trees trees = Trees.instance(task); + new WitnessCallScanner(trees, task.getTypes()).scan(e.getCompilationUnit(), trees); + } + }); + } + + /** Scanner that finds calls to TypeClasses.witness() and validates them. */ + private static class WitnessCallScanner extends TreePathScanner { + private final Trees trees; + private final Types types; + private final StaticWitnessSystem system; + + WitnessCallScanner(Trees trees, Types types) { + this.trees = trees; + this.types = types; + this.system = new StaticWitnessSystem(types); + } + + @Override + public Void visitClass(ClassTree node, Trees trees) { + return super.visitClass(node, trees); + } + + @Override + public Void visitMethodInvocation(MethodInvocationTree node, Trees trees) { + Element element = trees.getElement(getCurrentPath()); + + if (isMethodCall(WITNESS_METHOD, element)) { + // Found a call to TypeClasses.witness() + // The first argument is expected to be of the form "new Ty<>() {}" + ExpressionTree firstArg = node.getArguments().get(0); + + // Check if it's a "new Ty<>() {}" anonymous class creation + if (firstArg instanceof NewClassTree newClass) { + Tree tyApp = newClass.getClassBody().getImplementsClause().get(0); + + TypeMirror typeMirror = + trees.getTypeMirror(trees.getPath(getCurrentPath().getCompilationUnit(), tyApp)); + + // Try to extract witness type and verify resolution + if (typeMirror instanceof DeclaredType declaredType) { + TypeMirror witnessTypeMirror = declaredType.getTypeArguments().get(0); + + ParsedType target = system.parse(witnessTypeMirror); + + switch (WitnessResolution.resolve(system, target)) { + case Either.Left< + WitnessResolution.ResolutionError, WitnessResolution.InstantiationPlan>( + var error) -> + this.trees.printMessage( + Diagnostic.Kind.ERROR, + "Failed to resolve witness for type: " + + target.format() + + "\nReason: " + + error.format(), + getCurrentPath().getLeaf(), + getCurrentPath().getCompilationUnit()); + case Either.Right< + WitnessResolution.ResolutionError, WitnessResolution.InstantiationPlan> + v -> {} + } + } + } + } + + return super.visitMethodInvocation(node, trees); + } + } + + private static boolean isMethodCall(Method target, Element element) { + return element instanceof ExecutableElement method + && method.getSimpleName().contentEquals(target.getName()) + && method.getEnclosingElement() instanceof TypeElement methodOwner + && methodOwner.getQualifiedName().contentEquals(target.getDeclaringClass().getName()); + } +} diff --git a/src/main/resources/META-INF/services/com.sun.source.util.Plugin b/src/main/resources/META-INF/services/com.sun.source.util.Plugin new file mode 100644 index 0000000..ae61afe --- /dev/null +++ b/src/main/resources/META-INF/services/com.sun.source.util.Plugin @@ -0,0 +1 @@ +com.garciat.typeclasses.processor.WitnessResolutionChecker \ No newline at end of file diff --git a/src/test/java/com/garciat/typeclasses/ExamplesTest.java b/src/test/java/com/garciat/typeclasses/ExamplesTest.java index 1addc77..353ef6e 100644 --- a/src/test/java/com/garciat/typeclasses/ExamplesTest.java +++ b/src/test/java/com/garciat/typeclasses/ExamplesTest.java @@ -3,8 +3,21 @@ import static com.garciat.typeclasses.TypeClasses.witness; import com.garciat.typeclasses.api.Ty; -import com.garciat.typeclasses.classes.*; -import com.garciat.typeclasses.types.*; +import com.garciat.typeclasses.classes.Arbitrary; +import com.garciat.typeclasses.classes.Eq; +import com.garciat.typeclasses.classes.Foldable; +import com.garciat.typeclasses.classes.Monoid; +import com.garciat.typeclasses.classes.Ord; +import com.garciat.typeclasses.classes.PrintAll; +import com.garciat.typeclasses.classes.Show; +import com.garciat.typeclasses.classes.SumAllInt; +import com.garciat.typeclasses.classes.Traversable; +import com.garciat.typeclasses.types.F3; +import com.garciat.typeclasses.types.FwdList; +import com.garciat.typeclasses.types.JavaList; +import com.garciat.typeclasses.types.Maybe; +import com.garciat.typeclasses.types.Sum; +import com.garciat.typeclasses.types.Unit; import java.util.List; import java.util.Map; import java.util.Optional; diff --git a/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionProcessorTest.java b/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionProcessorTest.java new file mode 100644 index 0000000..3569d6c --- /dev/null +++ b/src/test/java/com/garciat/typeclasses/processor/WitnessResolutionProcessorTest.java @@ -0,0 +1,71 @@ +package com.garciat.typeclasses.processor; + +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; + +import java.io.File; +import java.io.IOException; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.attribute.BasicFileAttributes; +import java.util.List; +import javax.tools.DiagnosticCollector; +import javax.tools.JavaFileObject; +import javax.tools.StandardLocation; +import javax.tools.ToolProvider; +import org.jspecify.annotations.Nullable; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +public class WitnessResolutionProcessorTest { + @Nullable @TempDir Path tempDir; + + @Test + public void test() throws IOException { + requireNonNull(tempDir); + + // Given + var compiler = ToolProvider.getSystemJavaCompiler(); + + var diagnostics = new DiagnosticCollector(); + + var fileManager = compiler.getStandardFileManager(diagnostics, null, null); + fileManager.setLocation(StandardLocation.CLASS_OUTPUT, List.of(tempDir.toFile())); + fileManager.setLocation(StandardLocation.SOURCE_OUTPUT, List.of(tempDir.toFile())); + + var files = new java.util.ArrayList(); + files.add(new File("src/test/java/com/garciat/typeclasses/ExamplesTest.java")); + + Files.walkFileTree( + Path.of("src/main/java"), + new SimpleFileVisitor<>() { + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) { + if (file.toString().endsWith(".java")) { + files.add(file.toFile()); + } + return FileVisitResult.CONTINUE; + } + }); + + var compilationUnits = fileManager.getJavaFileObjectsFromFiles(files); + + var task = + compiler.getTask( + null, + fileManager, + diagnostics, + List.of("-Xplugin:WitnessResolutionChecker"), + null, + compilationUnits); + + // When + boolean success = task.call(); + + // Then + assertThat(diagnostics.getDiagnostics()).isEmpty(); + assertThat(success).isTrue(); + } +} From bd7b8d0e24bafe223f630377b28bcfda6b6b1732 Mon Sep 17 00:00:00 2001 From: Gabriel Garcia Date: Tue, 16 Dec 2025 16:09:52 +0100 Subject: [PATCH 2/8] Bump jacoco --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index b6bd176..b7534dd 100644 --- a/pom.xml +++ b/pom.xml @@ -92,7 +92,7 @@ org.jacoco jacoco-maven-plugin - 0.8.12 + 0.8.14 From 4536ec587c364b2355c054c99478b5100e7166f4 Mon Sep 17 00:00:00 2001 From: Gabriel Garcia Date: Tue, 16 Dec 2025 16:09:59 +0100 Subject: [PATCH 3/8] Clean up --- .../processor/WitnessResolutionChecker.java | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java index 45f9d24..c773b03 100644 --- a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java @@ -53,12 +53,10 @@ public void finished(TaskEvent e) { /** Scanner that finds calls to TypeClasses.witness() and validates them. */ private static class WitnessCallScanner extends TreePathScanner { private final Trees trees; - private final Types types; private final StaticWitnessSystem system; WitnessCallScanner(Trees trees, Types types) { this.trees = trees; - this.types = types; this.system = new StaticWitnessSystem(types); } @@ -74,18 +72,18 @@ public Void visitMethodInvocation(MethodInvocationTree node, Trees trees) { if (isMethodCall(WITNESS_METHOD, element)) { // Found a call to TypeClasses.witness() // The first argument is expected to be of the form "new Ty<>() {}" - ExpressionTree firstArg = node.getArguments().get(0); + ExpressionTree firstArg = node.getArguments().getFirst(); // Check if it's a "new Ty<>() {}" anonymous class creation if (firstArg instanceof NewClassTree newClass) { - Tree tyApp = newClass.getClassBody().getImplementsClause().get(0); + Tree tyApp = newClass.getClassBody().getImplementsClause().getFirst(); TypeMirror typeMirror = trees.getTypeMirror(trees.getPath(getCurrentPath().getCompilationUnit(), tyApp)); // Try to extract witness type and verify resolution if (typeMirror instanceof DeclaredType declaredType) { - TypeMirror witnessTypeMirror = declaredType.getTypeArguments().get(0); + TypeMirror witnessTypeMirror = declaredType.getTypeArguments().getFirst(); ParsedType target = system.parse(witnessTypeMirror); @@ -103,7 +101,7 @@ public Void visitMethodInvocation(MethodInvocationTree node, Trees trees) { getCurrentPath().getCompilationUnit()); case Either.Right< WitnessResolution.ResolutionError, WitnessResolution.InstantiationPlan> - v -> {} + ignore -> {} } } } From d0c16a0cf8e3e04d58fb6882fa9e58d614cb4657 Mon Sep 17 00:00:00 2001 From: Gabriel Garcia Date: Tue, 16 Dec 2025 19:01:42 +0100 Subject: [PATCH 4/8] Simplify some bits --- .../processor/StaticWitnessSystem.java | 24 ++++++++----------- .../processor/WitnessResolutionChecker.java | 21 ++++++---------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java b/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java index 0d4b494..8bfe4ee 100644 --- a/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java +++ b/src/main/java/com/garciat/typeclasses/processor/StaticWitnessSystem.java @@ -15,18 +15,13 @@ import javax.lang.model.element.TypeElement; import javax.lang.model.element.VariableElement; import javax.lang.model.type.*; -import javax.lang.model.util.Types; public class StaticWitnessSystem { private static final Class TAG_BASE_CLASS = TagBase.class; private static final Class TAPP_CLASS = TApp.class; private static final Class TPAR_CLASS = TPar.class; - private final Types types; - - public StaticWitnessSystem(Types types) { - this.types = types; - } + public StaticWitnessSystem() {} public List findRules(ParsedType target) { return switch (target) { @@ -42,7 +37,7 @@ public List findRules(ParsedType target) { }; } - public Maybe parseWitnessConstructor(ExecutableElement method) { + private Maybe parseWitnessConstructor(ExecutableElement method) { if (method.getModifiers().contains(Modifier.PUBLIC) && method.getModifiers().contains(Modifier.STATIC) && method.getAnnotation(TypeClass.Witness.class) instanceof TypeClass.Witness witnessAnn) { @@ -61,10 +56,6 @@ public Maybe parseWitnessConstructor(ExecutableElement metho } } - public List parseAll(List types) { - return types.stream().map(this::parse).toList(); - } - public ParsedType parse(TypeMirror type) { return switch (type) { case TypeVariable tv -> new ParsedType.Var(tv); @@ -82,8 +73,9 @@ when parseAppType(dt) Pair(var fun, var arg)) -> new ParsedType.App(parse(fun), parse(arg)); case DeclaredType dt -> - parseAll(dt.getTypeArguments()).stream() - .reduce(parse(types.erasure(dt)), ParsedType.App::new); + dt.getTypeArguments().stream() + .map(this::parse) + .reduce(new ParsedType.Const(erasure(dt)), ParsedType.App::new); case WildcardType wt -> throw new IllegalArgumentException("Cannot parse wildcard type: " + wt); default -> throw new IllegalArgumentException("Unsupported type: " + type); @@ -104,7 +96,7 @@ private static Maybe parseTagType(DeclaredType t) { } private Maybe> parseAppType(DeclaredType t) { - return t.getTypeArguments().size() == 2 && isAppType(types.erasure(t)) + return t.getTypeArguments().size() == 2 && isAppType(erasure(t)) ? Maybe.just(new Pair<>(t.getTypeArguments().get(0), t.getTypeArguments().get(1))) : Maybe.nothing(); } @@ -116,6 +108,10 @@ private boolean isAppType(TypeMirror erasure) { || te.getQualifiedName().contentEquals(TPAR_CLASS.getName())); } + private DeclaredType erasure(DeclaredType t) { + return t.asElement().asType() instanceof DeclaredType typeCtor ? typeCtor : t; + } + private static Function> isInstanceOf(Class cls) { return u -> cls.isInstance(u) ? Stream.of(cls.cast(u)) : Stream.empty(); } diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java index c773b03..13bafff 100644 --- a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java @@ -11,7 +11,6 @@ import javax.lang.model.element.TypeElement; import javax.lang.model.type.DeclaredType; import javax.lang.model.type.TypeMirror; -import javax.lang.model.util.Types; import javax.tools.Diagnostic; public final class WitnessResolutionChecker implements Plugin { @@ -44,29 +43,23 @@ public void finished(TaskEvent e) { return; } - Trees trees = Trees.instance(task); - new WitnessCallScanner(trees, task.getTypes()).scan(e.getCompilationUnit(), trees); + new WitnessCallScanner(Trees.instance(task)).scan(e.getCompilationUnit(), null); } }); } /** Scanner that finds calls to TypeClasses.witness() and validates them. */ - private static class WitnessCallScanner extends TreePathScanner { + private static class WitnessCallScanner extends TreePathScanner { private final Trees trees; private final StaticWitnessSystem system; - WitnessCallScanner(Trees trees, Types types) { + private WitnessCallScanner(Trees trees) { this.trees = trees; - this.system = new StaticWitnessSystem(types); + this.system = new StaticWitnessSystem(); } @Override - public Void visitClass(ClassTree node, Trees trees) { - return super.visitClass(node, trees); - } - - @Override - public Void visitMethodInvocation(MethodInvocationTree node, Trees trees) { + public Void visitMethodInvocation(MethodInvocationTree node, Void arg) { Element element = trees.getElement(getCurrentPath()); if (isMethodCall(WITNESS_METHOD, element)) { @@ -94,7 +87,7 @@ public Void visitMethodInvocation(MethodInvocationTree node, Trees trees) { this.trees.printMessage( Diagnostic.Kind.ERROR, "Failed to resolve witness for type: " - + target.format() + + witnessTypeMirror + "\nReason: " + error.format(), getCurrentPath().getLeaf(), @@ -107,7 +100,7 @@ public Void visitMethodInvocation(MethodInvocationTree node, Trees trees) { } } - return super.visitMethodInvocation(node, trees); + return super.visitMethodInvocation(node, arg); } } From 7fc6dfa92db5e9bb2a5c4016cc1ea7c4d94b2e9a Mon Sep 17 00:00:00 2001 From: Gabriel Garcia Date: Tue, 16 Dec 2025 20:35:38 +0100 Subject: [PATCH 5/8] Simplify tree matching code --- .../processor/WitnessResolutionChecker.java | 162 ++++++++++++------ 1 file changed, 114 insertions(+), 48 deletions(-) diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java index 13bafff..e217e54 100644 --- a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java @@ -1,12 +1,24 @@ package com.garciat.typeclasses.processor; +import static com.garciat.typeclasses.types.Unit.unit; + import com.garciat.typeclasses.TypeClasses; import com.garciat.typeclasses.api.Ty; -import com.garciat.typeclasses.types.Either; -import com.sun.source.tree.*; -import com.sun.source.util.*; +import com.garciat.typeclasses.types.Maybe; +import com.garciat.typeclasses.types.Unit; +import com.sun.source.tree.ClassTree; +import com.sun.source.tree.ExpressionTree; +import com.sun.source.tree.MethodInvocationTree; +import com.sun.source.tree.NewClassTree; +import com.sun.source.tree.Tree; +import com.sun.source.util.JavacTask; +import com.sun.source.util.Plugin; +import com.sun.source.util.TaskEvent; +import com.sun.source.util.TaskListener; +import com.sun.source.util.TreePath; +import com.sun.source.util.TreePathScanner; +import com.sun.source.util.Trees; import java.lang.reflect.Method; -import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.TypeElement; import javax.lang.model.type.DeclaredType; @@ -60,54 +72,108 @@ private WitnessCallScanner(Trees trees) { @Override public Void visitMethodInvocation(MethodInvocationTree node, Void arg) { - Element element = trees.getElement(getCurrentPath()); - - if (isMethodCall(WITNESS_METHOD, element)) { - // Found a call to TypeClasses.witness() - // The first argument is expected to be of the form "new Ty<>() {}" - ExpressionTree firstArg = node.getArguments().getFirst(); - - // Check if it's a "new Ty<>() {}" anonymous class creation - if (firstArg instanceof NewClassTree newClass) { - Tree tyApp = newClass.getClassBody().getImplementsClause().getFirst(); - - TypeMirror typeMirror = - trees.getTypeMirror(trees.getPath(getCurrentPath().getCompilationUnit(), tyApp)); - - // Try to extract witness type and verify resolution - if (typeMirror instanceof DeclaredType declaredType) { - TypeMirror witnessTypeMirror = declaredType.getTypeArguments().getFirst(); - - ParsedType target = system.parse(witnessTypeMirror); - - switch (WitnessResolution.resolve(system, target)) { - case Either.Left< - WitnessResolution.ResolutionError, WitnessResolution.InstantiationPlan>( - var error) -> - this.trees.printMessage( - Diagnostic.Kind.ERROR, - "Failed to resolve witness for type: " - + witnessTypeMirror - + "\nReason: " - + error.format(), - getCurrentPath().getLeaf(), - getCurrentPath().getCompilationUnit()); - case Either.Right< - WitnessResolution.ResolutionError, WitnessResolution.InstantiationPlan> - ignore -> {} - } - } - } - } + Parser.unaryMethodCallArgument(WITNESS_METHOD) + .flatMap(Parser.newAnonymousClassBody()) + .flatMap(Parser.singleImplementsClause()) + .flatMap(Parser.treeTypeMirror()) + .flatMap(Parser.rawTypeMatches(Ty.class)) + .flatMap(Parser.unaryTypeArgument()) + .parse(trees, getCurrentPath(), node) + .fold( + Unit::unit, + witnessType -> + WitnessResolution.resolve(system, system.parse(witnessType)) + .fold( + error -> { + this.trees.printMessage( + Diagnostic.Kind.ERROR, + "Failed to resolve witness for type: " + + witnessType + + "\nReason: " + + error.format(), + getCurrentPath().getLeaf(), + getCurrentPath().getCompilationUnit()); + return unit(); + }, + plan -> unit())); return super.visitMethodInvocation(node, arg); } } +} + +interface Parser { + Maybe parse(Trees trees, TreePath current, T input); + + default Parser flatMap(Parser next) { + return (trees, current, input) -> + this.parse(trees, current, input).flatMap(r -> next.parse(trees, current, r)); + } + + static Parser unaryMethodCallArgument(Method target) { + return (trees, current, input) -> { + if (trees.getElement(current) instanceof ExecutableElement method + && method.getSimpleName().contentEquals(target.getName()) + && method.getEnclosingElement() instanceof TypeElement methodOwner + && methodOwner.getQualifiedName().contentEquals(target.getDeclaringClass().getName()) + && input.getArguments().size() == 1) { + return Maybe.just(input.getArguments().getFirst()); + } else { + return Maybe.nothing(); + } + }; + } + + static Parser newAnonymousClassBody() { + return (trees, current, input) -> { + if (input instanceof NewClassTree newClass && newClass.getClassBody() != null) { + return Maybe.just(newClass.getClassBody()); + } else { + return Maybe.nothing(); + } + }; + } + + static Parser singleImplementsClause() { + return (trees, current, input) -> { + if (input.getImplementsClause() != null && input.getImplementsClause().size() == 1) { + return Maybe.just(input.getImplementsClause().getFirst()); + } else { + return Maybe.nothing(); + } + }; + } + + static Parser treeTypeMirror() { + return (trees, current, input) -> { + try { + TypeMirror typeMirror = + trees.getTypeMirror(trees.getPath(current.getCompilationUnit(), input)); + return Maybe.just(typeMirror); + } catch (IllegalArgumentException e) { + return Maybe.nothing(); + } + }; + } + + static Parser rawTypeMatches(Class cls) { + return (trees, current, input) -> { + if (input instanceof DeclaredType declaredType + && declaredType.asElement() instanceof TypeElement typeElement + && typeElement.getQualifiedName().contentEquals(cls.getName())) { + return Maybe.just(declaredType); + } + return Maybe.nothing(); + }; + } - private static boolean isMethodCall(Method target, Element element) { - return element instanceof ExecutableElement method - && method.getSimpleName().contentEquals(target.getName()) - && method.getEnclosingElement() instanceof TypeElement methodOwner - && methodOwner.getQualifiedName().contentEquals(target.getDeclaringClass().getName()); + static Parser unaryTypeArgument() { + return (trees, current, input) -> { + if (input.getTypeArguments().size() == 1) { + return Maybe.just(input.getTypeArguments().getFirst()); + } else { + return Maybe.nothing(); + } + }; } } From e31bf617a1b39459e6569d6e243d760641dda68a Mon Sep 17 00:00:00 2001 From: Gabriel Garcia Date: Tue, 16 Dec 2025 20:38:53 +0100 Subject: [PATCH 6/8] Remove unused --- .../garciat/typeclasses/processor/WitnessResolutionChecker.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java index e217e54..b551c56 100644 --- a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java @@ -110,7 +110,7 @@ default Parser flatMap(Parser next) { this.parse(trees, current, input).flatMap(r -> next.parse(trees, current, r)); } - static Parser unaryMethodCallArgument(Method target) { + static Parser unaryMethodCallArgument(Method target) { return (trees, current, input) -> { if (trees.getElement(current) instanceof ExecutableElement method && method.getSimpleName().contentEquals(target.getName()) From 3204e4f6eae1df1309ef21063a65c280e0bc3e0d Mon Sep 17 00:00:00 2001 From: Gabriel Garcia Date: Tue, 16 Dec 2025 21:14:23 +0100 Subject: [PATCH 7/8] Split method parser --- .../processor/WitnessResolutionChecker.java | 53 ++++++++++++++++--- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java index b551c56..7b6697e 100644 --- a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java @@ -19,6 +19,7 @@ import com.sun.source.util.TreePathScanner; import com.sun.source.util.Trees; import java.lang.reflect.Method; +import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.TypeElement; import javax.lang.model.type.DeclaredType; @@ -72,7 +73,11 @@ private WitnessCallScanner(Trees trees) { @Override public Void visitMethodInvocation(MethodInvocationTree node, Void arg) { - Parser.unaryMethodCallArgument(WITNESS_METHOD) + Parser.unaryCallArgument() + .guard( + Parser.currentElement() + .flatMap(Parser.executableElement()) + .flatMap(Parser.methodMatches(WITNESS_METHOD))) .flatMap(Parser.newAnonymousClassBody()) .flatMap(Parser.singleImplementsClause()) .flatMap(Parser.treeTypeMirror()) @@ -110,13 +115,47 @@ default Parser flatMap(Parser next) { this.parse(trees, current, input).flatMap(r -> next.parse(trees, current, r)); } - static Parser unaryMethodCallArgument(Method target) { + default Parser guard(Parser predicate) { + return (trees, current, input) -> + predicate.parse(trees, current, input).flatMap(ignore -> this.parse(trees, current, input)); + } + + static Parser currentElement() { + return (trees, current, input) -> { + Element element = trees.getElement(current); + if (element != null) { + return Maybe.just(element); + } else { + return Maybe.nothing(); + } + }; + } + + static Parser executableElement() { + return (trees, current, input) -> { + if (input instanceof ExecutableElement method) { + return Maybe.just(method); + } else { + return Maybe.nothing(); + } + }; + } + + static Parser methodMatches(Method target) { + return (trees, current, input) -> { + if (input.getSimpleName().contentEquals(target.getName()) + && input.getEnclosingElement() instanceof TypeElement methodOwner + && methodOwner.getQualifiedName().contentEquals(target.getDeclaringClass().getName())) { + return Maybe.just(input); + } else { + return Maybe.nothing(); + } + }; + } + + static Parser unaryCallArgument() { return (trees, current, input) -> { - if (trees.getElement(current) instanceof ExecutableElement method - && method.getSimpleName().contentEquals(target.getName()) - && method.getEnclosingElement() instanceof TypeElement methodOwner - && methodOwner.getQualifiedName().contentEquals(target.getDeclaringClass().getName()) - && input.getArguments().size() == 1) { + if (input.getArguments().size() == 1) { return Maybe.just(input.getArguments().getFirst()); } else { return Maybe.nothing(); From 0ede5363cae3f11b4de5f0ab1d08a174764438e1 Mon Sep 17 00:00:00 2001 From: Gabriel Garcia Date: Wed, 17 Dec 2025 09:08:33 +0100 Subject: [PATCH 8/8] Split parsers --- .../processor/WitnessResolutionChecker.java | 124 ++++++++++-------- 1 file changed, 72 insertions(+), 52 deletions(-) diff --git a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java index 7b6697e..a367f43 100644 --- a/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java +++ b/src/main/java/com/garciat/typeclasses/processor/WitnessResolutionChecker.java @@ -19,6 +19,10 @@ import com.sun.source.util.TreePathScanner; import com.sun.source.util.Trees; import java.lang.reflect.Method; +import java.util.List; +import java.util.Objects; +import java.util.function.Function; +import java.util.function.Predicate; import javax.lang.model.element.Element; import javax.lang.model.element.ExecutableElement; import javax.lang.model.element.TypeElement; @@ -73,11 +77,11 @@ private WitnessCallScanner(Trees trees) { @Override public Void visitMethodInvocation(MethodInvocationTree node, Void arg) { - Parser.unaryCallArgument() + Parser.identity() .guard( Parser.currentElement() - .flatMap(Parser.executableElement()) .flatMap(Parser.methodMatches(WITNESS_METHOD))) + .flatMap(Parser.unaryCallArgument()) .flatMap(Parser.newAnonymousClassBody()) .flatMap(Parser.singleImplementsClause()) .flatMap(Parser.treeTypeMirror()) @@ -115,72 +119,90 @@ default Parser flatMap(Parser next) { this.parse(trees, current, input).flatMap(r -> next.parse(trees, current, r)); } - default Parser guard(Parser predicate) { + default Parser map(Function mapper) { + return flatMap(mapping(mapper)); + } + + default Parser filter(Predicate predicate) { + return flatMap(filtering(predicate)); + } + + default Parser guard(Parser predicate) { return (trees, current, input) -> - predicate.parse(trees, current, input).flatMap(ignore -> this.parse(trees, current, input)); + this.parse(trees, current, input) + .flatMap(r -> predicate.parse(trees, current, r).map(x -> r)); } - static Parser currentElement() { + static Parser identity() { + return (trees, current, input) -> Maybe.just(input); + } + + static Parser mapping(Function mapper) { + return (trees, current, input) -> Maybe.just(mapper.apply(input)); + } + + static Parser filtering(Predicate predicate) { return (trees, current, input) -> { - Element element = trees.getElement(current); - if (element != null) { - return Maybe.just(element); + if (predicate.test(input)) { + return Maybe.just(input); } else { return Maybe.nothing(); } }; } - static Parser executableElement() { + static Parser notNull() { + return filtering(Objects::nonNull); + } + + static Parser as(Class cls) { return (trees, current, input) -> { - if (input instanceof ExecutableElement method) { - return Maybe.just(method); + if (cls.isInstance(input)) { + return Maybe.just(cls.cast(input)); } else { return Maybe.nothing(); } }; } - static Parser methodMatches(Method target) { + static Parser currentElement() { return (trees, current, input) -> { - if (input.getSimpleName().contentEquals(target.getName()) - && input.getEnclosingElement() instanceof TypeElement methodOwner - && methodOwner.getQualifiedName().contentEquals(target.getDeclaringClass().getName())) { - return Maybe.just(input); + Element element = trees.getElement(current); + if (element != null) { + return Maybe.just(element); } else { return Maybe.nothing(); } }; } + static Parser methodMatches(Method target) { + return Parser.as(ExecutableElement.class) + .filter(m -> m.getSimpleName().contentEquals(target.getName())) + .guard( + mapping(ExecutableElement::getEnclosingElement) + .flatMap(as(TypeElement.class)) + .map(TypeElement::getQualifiedName) + .filter(name -> name.contentEquals(target.getDeclaringClass().getName()))); + } + static Parser unaryCallArgument() { - return (trees, current, input) -> { - if (input.getArguments().size() == 1) { - return Maybe.just(input.getArguments().getFirst()); - } else { - return Maybe.nothing(); - } - }; + return mapping(MethodInvocationTree::getArguments) + .filter(list -> list.size() == 1) + .map(List::getFirst); } static Parser newAnonymousClassBody() { - return (trees, current, input) -> { - if (input instanceof NewClassTree newClass && newClass.getClassBody() != null) { - return Maybe.just(newClass.getClassBody()); - } else { - return Maybe.nothing(); - } - }; + return Parser.as(NewClassTree.class) + .map(NewClassTree::getClassBody) + .flatMap(notNull()); } static Parser singleImplementsClause() { - return (trees, current, input) -> { - if (input.getImplementsClause() != null && input.getImplementsClause().size() == 1) { - return Maybe.just(input.getImplementsClause().getFirst()); - } else { - return Maybe.nothing(); - } - }; + return mapping(ClassTree::getImplementsClause) + .flatMap(notNull()) + .filter(list -> list.size() == 1) + .map(List::getFirst); } static Parser treeTypeMirror() { @@ -196,23 +218,21 @@ static Parser treeTypeMirror() { } static Parser rawTypeMatches(Class cls) { - return (trees, current, input) -> { - if (input instanceof DeclaredType declaredType - && declaredType.asElement() instanceof TypeElement typeElement - && typeElement.getQualifiedName().contentEquals(cls.getName())) { - return Maybe.just(declaredType); - } - return Maybe.nothing(); - }; + return Parser.as(DeclaredType.class) + .guard( + declaredTypeElement() + .flatMap(as(TypeElement.class)) + .map(TypeElement::getQualifiedName) + .filter(name -> name.contentEquals(cls.getName()))); } static Parser unaryTypeArgument() { - return (trees, current, input) -> { - if (input.getTypeArguments().size() == 1) { - return Maybe.just(input.getTypeArguments().getFirst()); - } else { - return Maybe.nothing(); - } - }; + return mapping(DeclaredType::getTypeArguments) + .filter(list -> list.size() == 1) + .map(List::getFirst); + } + + static Parser declaredTypeElement() { + return mapping(DeclaredType::asElement).flatMap(notNull()); } }