diff --git a/src/main/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitor.java b/src/main/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitor.java index 5c98ad332..386ea6d77 100644 --- a/src/main/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitor.java +++ b/src/main/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitor.java @@ -29,6 +29,7 @@ import org.openrewrite.java.tree.JavaType.FullyQualified; import org.openrewrite.marker.Markers; +import java.util.ArrayList; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -143,6 +144,33 @@ public J visitVariableDeclarations(VariableDeclarations multiVariable, Execution return mv; } + private static

void addImportsForType(JavaTemplate.Builder template, JavaType type, JavaVisitor

visitor) { + List fqns = collectFullyQualifiedNames(type); + template.imports(fqns.toArray(new String[0])); + for (String fqn : fqns) { + visitor.maybeAddImport(fqn); + } + } + + private static List collectFullyQualifiedNames(JavaType type) { + List fqns = new ArrayList<>(); + if (type instanceof JavaType.Parameterized) { + JavaType.Parameterized pt = (JavaType.Parameterized) type; + fqns.add(pt.getType().getFullyQualifiedName()); + pt.getTypeParameters().forEach(tp -> fqns.addAll(collectFullyQualifiedNames(tp))); + } else if (type instanceof JavaType.GenericTypeVariable) { + ((JavaType.GenericTypeVariable) type).getBounds().forEach(b -> fqns.addAll(collectFullyQualifiedNames(b))); + } else if (type instanceof JavaType.Array) { + fqns.addAll(collectFullyQualifiedNames(((JavaType.Array) type).getElemType())); + } else { + FullyQualified fq = TypeUtils.asFullyQualified(type); + if (fq != null) { + fqns.add(fq.getFullyQualifiedName()); + } + } + return fqns; + } + @RequiredArgsConstructor private static class AddConstructorVisitor extends JavaVisitor { @@ -152,58 +180,49 @@ private static class AddConstructorVisitor extends JavaVisitor @Override public J visitBlock(Block block, ExecutionContext p) { - if (getCursor().getParent() != null) { - Object n = getCursor().getParent().getValue(); - if (n instanceof ClassDeclaration) { - ClassDeclaration classDecl = (ClassDeclaration) n; - JavaType.FullyQualified typeFqn = TypeUtils.asFullyQualified(type.getType()); - if (typeFqn != null && classDecl.getKind() == ClassDeclaration.Kind.Type.Class && className.equals(classDecl.getSimpleName())) { - JavaTemplate.Builder template = JavaTemplate.builder("" + - classDecl.getSimpleName() + "(" + typeFqn.getClassName() + " " + fieldName + ") {\n" + - "this." + fieldName + " = " + fieldName + ";\n" + - "}\n" - ).contextSensitive(); - FullyQualified fq = TypeUtils.asFullyQualified(type.getType()); - if (fq != null) { - template.imports(fq.getFullyQualifiedName()); - maybeAddImport(fq); - } - Optional firstMethod = block.getStatements().stream().filter(MethodDeclaration.class::isInstance).findFirst(); - - return firstMethod.map(statement -> - (J) template.build() - .apply(getCursor(), - statement.getCoordinates().before() - ) - ) - .orElseGet(() -> - template.build() - .apply( - getCursor(), - block.getCoordinates().lastStatement() - ) - ); - } - } + if (getCursor().getParent() == null + || !(getCursor().getParent().getValue() instanceof ClassDeclaration)) { + return block; } - return block; + ClassDeclaration classDecl = (ClassDeclaration) getCursor().getParent().getValue(); + JavaType fieldType = type.getType(); + if (classDecl.getKind() != ClassDeclaration.Kind.Type.Class + || !className.equals(classDecl.getSimpleName()) + || fieldType == null + || fieldType instanceof JavaType.Primitive) { + return block; + } + + String fieldTypeStr = type.toString(); + JavaTemplate.Builder template = JavaTemplate.builder( + classDecl.getSimpleName() + "(" + fieldTypeStr + " " + fieldName + ") {\n" + + "this." + fieldName + " = " + fieldName + ";\n" + + "}\n" + ).contextSensitive(); + addImportsForType(template, fieldType, this); + + Optional firstMethod = block.getStatements().stream() + .filter(MethodDeclaration.class::isInstance).findFirst(); + return firstMethod.map(statement -> + (J) template.build() + .apply(getCursor(), statement.getCoordinates().before()) + ) + .orElseGet(() -> + template.build() + .apply(getCursor(), block.getCoordinates().lastStatement()) + ); } } private static class AddConstructorParameterAndAssignment extends JavaIsoVisitor { private final MethodDeclaration constructor; private final String fieldName; - private final String methodType; + private final TypeTree type; public AddConstructorParameterAndAssignment(MethodDeclaration constructor, String fieldName, TypeTree type) { this.constructor = constructor; this.fieldName = fieldName; - JavaType.FullyQualified fq = TypeUtils.asFullyQualified(type.getType()); - if (fq != null) { - methodType = fq.getClassName(); - } else { - throw new IllegalArgumentException("Unable to determine parameter type"); - } + this.type = type; } @Override @@ -212,11 +231,12 @@ public MethodDeclaration visitMethodDeclaration(MethodDeclaration method, Execut if (md == this.constructor && md.getBody() != null) { List params = md.getParameters().stream().filter(s -> !(s instanceof J.Empty)).collect(toList()); String paramsStr = Stream.concat(params.stream() - .map(s -> "#{}"), Stream.of(methodType + " " + fieldName)).collect(joining(", ")); + .map(s -> "#{}"), Stream.of(type.toString() + " " + fieldName)).collect(joining(", ")); - md = JavaTemplate.builder(paramsStr) - .contextSensitive() - .build() + JavaTemplate.Builder templateBuilder = JavaTemplate.builder(paramsStr) + .contextSensitive(); + addImportsForType(templateBuilder, type.getType(), this); + md = templateBuilder.build() .apply( getCursor(), md.getCoordinates().replaceParameters(), diff --git a/src/test/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitorTest.java b/src/test/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitorTest.java index aeb6abe0b..a54941ce7 100644 --- a/src/test/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitorTest.java +++ b/src/test/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitorTest.java @@ -21,6 +21,7 @@ import org.openrewrite.java.JavaParser; import org.openrewrite.test.RecipeSpec; import org.openrewrite.test.RewriteTest; +import org.openrewrite.test.TypeValidation; import static org.openrewrite.java.Assertions.java; import static org.openrewrite.test.RewriteTest.toRecipe; @@ -413,4 +414,448 @@ public class A { ); } + @Test + void fieldWithGenericTypeIntoNewConstructor() { + //language=java + rewriteRun( + java( + """ + package demo; + + import java.util.List; + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired + private List a; + + } + """, + """ + package demo; + + import java.util.List; + + public class A { + + private final List a; + + A(List a) { + this.a = a; + } + + } + """ + ) + ); + } + + @Test + void fieldWithGenericTypeIntoExistingConstructor() { + //language=java + rewriteRun( + java( + """ + package demo; + + import java.util.List; + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired + private List a; + + A() { + } + + } + """, + """ + package demo; + + import java.util.List; + + public class A { + + private final List a; + + A(List a) { + this.a = a; + } + + } + """ + ) + ); + } + + @Test + void fieldWithGenericTypeIntoExistingConstructorWithParams() { + //language=java + rewriteRun( + java( + """ + package demo; + + import java.util.List; + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired + private List a; + + A() { + } + + @Autowired + A(long l) { + } + } + """, + """ + package demo; + + import java.util.List; + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + private final List a; + + A() { + } + + @Autowired + A(long l, List a) { + this.a = a; + } + } + """ + ) + ); + } + + @Test + void fieldWithNestedGenericType() { + //language=java + rewriteRun( + java( + """ + package demo; + + import java.util.List; + import java.util.Map; + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired + private Map> a; + + } + """, + """ + package demo; + + import java.util.List; + import java.util.Map; + + public class A { + + private final Map> a; + + A(Map> a) { + this.a = a; + } + + } + """ + ) + ); + } + + @Test + void fieldWithWildcardGenericType() { + //language=java + rewriteRun( + java( + """ + package demo; + + import java.util.List; + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired + private List a; + + } + """, + """ + package demo; + + import java.util.List; + + public class A { + + private final List a; + + A(List a) { + this.a = a; + } + + } + """ + ) + ); + } + + @Test + void fieldWithUpperBoundedWildcard() { + //language=java + rewriteRun( + java( + """ + package demo; + + import java.util.List; + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired + private List a; + + } + """, + """ + package demo; + + import java.util.List; + + public class A { + + private final List a; + + A(List a) { + this.a = a; + } + + } + """ + ) + ); + } + + @Test + void fieldWithUserDefinedGenericType() { + //language=java + rewriteRun( + spec -> spec.afterTypeValidationOptions(TypeValidation.none()), + java( + """ + package demo.model; + + public class MyConfig { + } + """ + ), + java( + """ + package demo.service; + + public class MyService { + } + """ + ), + java( + """ + package demo; + + import demo.model.MyConfig; + import demo.service.MyService; + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired + private MyService a; + + } + """, + """ + package demo; + + import demo.model.MyConfig; + import demo.service.MyService; + + public class A { + + private final MyService a; + + A(MyService a) { + this.a = a; + } + + } + """ + ) + ); + } + + @Test + void fieldWithLowerBoundedWildcard() { + //language=java + rewriteRun( + java( + """ + package demo; + + import java.util.List; + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired + private List a; + + } + """, + """ + package demo; + + import java.util.List; + + public class A { + + private final List a; + + A(List a) { + this.a = a; + } + + } + """ + ) + ); + } + + @Test + void fieldWithArrayType() { + //language=java + rewriteRun( + spec -> spec.afterTypeValidationOptions(TypeValidation.all().identifiers(false)), + java( + """ + package demo; + + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired + private String[] a; + + } + """, + """ + package demo; + + public class A { + + private final String[] a; + + A(String[] a) { + this.a = a; + } + + } + """ + ) + ); + } + + @Test + void fieldWithAutowiredRequired() { + //language=java + rewriteRun( + java( + """ + package demo; + + import org.springframework.beans.factory.annotation.Autowired; + + public class A { + + @Autowired(required = false) + private String a; + + A() { + } + + } + """, + """ + package demo; + + public class A { + + private final String a; + + A(String a) { + this.a = a; + } + + } + """ + ) + ); + } + + @Test + void fieldWithMultipleAnnotations() { + //language=java + rewriteRun( + java( + """ + package demo; + + import org.springframework.beans.factory.annotation.Autowired; + import org.springframework.beans.factory.annotation.Qualifier; + + public class A { + + @Autowired + @Qualifier("myBean") + private String a; + + A() { + } + + } + """, + """ + package demo; + + import org.springframework.beans.factory.annotation.Qualifier; + + public class A { + + @Qualifier("myBean") + private final String a; + + A(String a) { + this.a = a; + } + + } + """ + ) + ); + } + }