diff --git a/src/main/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitor.java b/src/main/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitor.java
index 5c98ad332..386ea6d77 100644
--- a/src/main/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitor.java
+++ b/src/main/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitor.java
@@ -29,6 +29,7 @@
import org.openrewrite.java.tree.JavaType.FullyQualified;
import org.openrewrite.marker.Markers;
+import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
@@ -143,6 +144,33 @@ public J visitVariableDeclarations(VariableDeclarations multiVariable, Execution
return mv;
}
+ private static
void addImportsForType(JavaTemplate.Builder template, JavaType type, JavaVisitor
visitor) {
+ List fqns = collectFullyQualifiedNames(type);
+ template.imports(fqns.toArray(new String[0]));
+ for (String fqn : fqns) {
+ visitor.maybeAddImport(fqn);
+ }
+ }
+
+ private static List collectFullyQualifiedNames(JavaType type) {
+ List fqns = new ArrayList<>();
+ if (type instanceof JavaType.Parameterized) {
+ JavaType.Parameterized pt = (JavaType.Parameterized) type;
+ fqns.add(pt.getType().getFullyQualifiedName());
+ pt.getTypeParameters().forEach(tp -> fqns.addAll(collectFullyQualifiedNames(tp)));
+ } else if (type instanceof JavaType.GenericTypeVariable) {
+ ((JavaType.GenericTypeVariable) type).getBounds().forEach(b -> fqns.addAll(collectFullyQualifiedNames(b)));
+ } else if (type instanceof JavaType.Array) {
+ fqns.addAll(collectFullyQualifiedNames(((JavaType.Array) type).getElemType()));
+ } else {
+ FullyQualified fq = TypeUtils.asFullyQualified(type);
+ if (fq != null) {
+ fqns.add(fq.getFullyQualifiedName());
+ }
+ }
+ return fqns;
+ }
+
@RequiredArgsConstructor
private static class AddConstructorVisitor extends JavaVisitor {
@@ -152,58 +180,49 @@ private static class AddConstructorVisitor extends JavaVisitor
@Override
public J visitBlock(Block block, ExecutionContext p) {
- if (getCursor().getParent() != null) {
- Object n = getCursor().getParent().getValue();
- if (n instanceof ClassDeclaration) {
- ClassDeclaration classDecl = (ClassDeclaration) n;
- JavaType.FullyQualified typeFqn = TypeUtils.asFullyQualified(type.getType());
- if (typeFqn != null && classDecl.getKind() == ClassDeclaration.Kind.Type.Class && className.equals(classDecl.getSimpleName())) {
- JavaTemplate.Builder template = JavaTemplate.builder("" +
- classDecl.getSimpleName() + "(" + typeFqn.getClassName() + " " + fieldName + ") {\n" +
- "this." + fieldName + " = " + fieldName + ";\n" +
- "}\n"
- ).contextSensitive();
- FullyQualified fq = TypeUtils.asFullyQualified(type.getType());
- if (fq != null) {
- template.imports(fq.getFullyQualifiedName());
- maybeAddImport(fq);
- }
- Optional firstMethod = block.getStatements().stream().filter(MethodDeclaration.class::isInstance).findFirst();
-
- return firstMethod.map(statement ->
- (J) template.build()
- .apply(getCursor(),
- statement.getCoordinates().before()
- )
- )
- .orElseGet(() ->
- template.build()
- .apply(
- getCursor(),
- block.getCoordinates().lastStatement()
- )
- );
- }
- }
+ if (getCursor().getParent() == null
+ || !(getCursor().getParent().getValue() instanceof ClassDeclaration)) {
+ return block;
}
- return block;
+ ClassDeclaration classDecl = (ClassDeclaration) getCursor().getParent().getValue();
+ JavaType fieldType = type.getType();
+ if (classDecl.getKind() != ClassDeclaration.Kind.Type.Class
+ || !className.equals(classDecl.getSimpleName())
+ || fieldType == null
+ || fieldType instanceof JavaType.Primitive) {
+ return block;
+ }
+
+ String fieldTypeStr = type.toString();
+ JavaTemplate.Builder template = JavaTemplate.builder(
+ classDecl.getSimpleName() + "(" + fieldTypeStr + " " + fieldName + ") {\n" +
+ "this." + fieldName + " = " + fieldName + ";\n" +
+ "}\n"
+ ).contextSensitive();
+ addImportsForType(template, fieldType, this);
+
+ Optional firstMethod = block.getStatements().stream()
+ .filter(MethodDeclaration.class::isInstance).findFirst();
+ return firstMethod.map(statement ->
+ (J) template.build()
+ .apply(getCursor(), statement.getCoordinates().before())
+ )
+ .orElseGet(() ->
+ template.build()
+ .apply(getCursor(), block.getCoordinates().lastStatement())
+ );
}
}
private static class AddConstructorParameterAndAssignment extends JavaIsoVisitor {
private final MethodDeclaration constructor;
private final String fieldName;
- private final String methodType;
+ private final TypeTree type;
public AddConstructorParameterAndAssignment(MethodDeclaration constructor, String fieldName, TypeTree type) {
this.constructor = constructor;
this.fieldName = fieldName;
- JavaType.FullyQualified fq = TypeUtils.asFullyQualified(type.getType());
- if (fq != null) {
- methodType = fq.getClassName();
- } else {
- throw new IllegalArgumentException("Unable to determine parameter type");
- }
+ this.type = type;
}
@Override
@@ -212,11 +231,12 @@ public MethodDeclaration visitMethodDeclaration(MethodDeclaration method, Execut
if (md == this.constructor && md.getBody() != null) {
List params = md.getParameters().stream().filter(s -> !(s instanceof J.Empty)).collect(toList());
String paramsStr = Stream.concat(params.stream()
- .map(s -> "#{}"), Stream.of(methodType + " " + fieldName)).collect(joining(", "));
+ .map(s -> "#{}"), Stream.of(type.toString() + " " + fieldName)).collect(joining(", "));
- md = JavaTemplate.builder(paramsStr)
- .contextSensitive()
- .build()
+ JavaTemplate.Builder templateBuilder = JavaTemplate.builder(paramsStr)
+ .contextSensitive();
+ addImportsForType(templateBuilder, type.getType(), this);
+ md = templateBuilder.build()
.apply(
getCursor(),
md.getCoordinates().replaceParameters(),
diff --git a/src/test/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitorTest.java b/src/test/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitorTest.java
index aeb6abe0b..a54941ce7 100644
--- a/src/test/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitorTest.java
+++ b/src/test/java/org/openrewrite/java/spring/AutowiredFieldIntoConstructorParameterVisitorTest.java
@@ -21,6 +21,7 @@
import org.openrewrite.java.JavaParser;
import org.openrewrite.test.RecipeSpec;
import org.openrewrite.test.RewriteTest;
+import org.openrewrite.test.TypeValidation;
import static org.openrewrite.java.Assertions.java;
import static org.openrewrite.test.RewriteTest.toRecipe;
@@ -413,4 +414,448 @@ public class A {
);
}
+ @Test
+ void fieldWithGenericTypeIntoNewConstructor() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ package demo;
+
+ import java.util.List;
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired
+ private List a;
+
+ }
+ """,
+ """
+ package demo;
+
+ import java.util.List;
+
+ public class A {
+
+ private final List a;
+
+ A(List a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithGenericTypeIntoExistingConstructor() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ package demo;
+
+ import java.util.List;
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired
+ private List a;
+
+ A() {
+ }
+
+ }
+ """,
+ """
+ package demo;
+
+ import java.util.List;
+
+ public class A {
+
+ private final List a;
+
+ A(List a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithGenericTypeIntoExistingConstructorWithParams() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ package demo;
+
+ import java.util.List;
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired
+ private List a;
+
+ A() {
+ }
+
+ @Autowired
+ A(long l) {
+ }
+ }
+ """,
+ """
+ package demo;
+
+ import java.util.List;
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ private final List a;
+
+ A() {
+ }
+
+ @Autowired
+ A(long l, List a) {
+ this.a = a;
+ }
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithNestedGenericType() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ package demo;
+
+ import java.util.List;
+ import java.util.Map;
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired
+ private Map> a;
+
+ }
+ """,
+ """
+ package demo;
+
+ import java.util.List;
+ import java.util.Map;
+
+ public class A {
+
+ private final Map> a;
+
+ A(Map> a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithWildcardGenericType() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ package demo;
+
+ import java.util.List;
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired
+ private List> a;
+
+ }
+ """,
+ """
+ package demo;
+
+ import java.util.List;
+
+ public class A {
+
+ private final List> a;
+
+ A(List> a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithUpperBoundedWildcard() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ package demo;
+
+ import java.util.List;
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired
+ private List extends Number> a;
+
+ }
+ """,
+ """
+ package demo;
+
+ import java.util.List;
+
+ public class A {
+
+ private final List extends Number> a;
+
+ A(List extends Number> a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithUserDefinedGenericType() {
+ //language=java
+ rewriteRun(
+ spec -> spec.afterTypeValidationOptions(TypeValidation.none()),
+ java(
+ """
+ package demo.model;
+
+ public class MyConfig {
+ }
+ """
+ ),
+ java(
+ """
+ package demo.service;
+
+ public class MyService {
+ }
+ """
+ ),
+ java(
+ """
+ package demo;
+
+ import demo.model.MyConfig;
+ import demo.service.MyService;
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired
+ private MyService a;
+
+ }
+ """,
+ """
+ package demo;
+
+ import demo.model.MyConfig;
+ import demo.service.MyService;
+
+ public class A {
+
+ private final MyService a;
+
+ A(MyService a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithLowerBoundedWildcard() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ package demo;
+
+ import java.util.List;
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired
+ private List super Integer> a;
+
+ }
+ """,
+ """
+ package demo;
+
+ import java.util.List;
+
+ public class A {
+
+ private final List super Integer> a;
+
+ A(List super Integer> a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithArrayType() {
+ //language=java
+ rewriteRun(
+ spec -> spec.afterTypeValidationOptions(TypeValidation.all().identifiers(false)),
+ java(
+ """
+ package demo;
+
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired
+ private String[] a;
+
+ }
+ """,
+ """
+ package demo;
+
+ public class A {
+
+ private final String[] a;
+
+ A(String[] a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithAutowiredRequired() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ package demo;
+
+ import org.springframework.beans.factory.annotation.Autowired;
+
+ public class A {
+
+ @Autowired(required = false)
+ private String a;
+
+ A() {
+ }
+
+ }
+ """,
+ """
+ package demo;
+
+ public class A {
+
+ private final String a;
+
+ A(String a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
+ @Test
+ void fieldWithMultipleAnnotations() {
+ //language=java
+ rewriteRun(
+ java(
+ """
+ package demo;
+
+ import org.springframework.beans.factory.annotation.Autowired;
+ import org.springframework.beans.factory.annotation.Qualifier;
+
+ public class A {
+
+ @Autowired
+ @Qualifier("myBean")
+ private String a;
+
+ A() {
+ }
+
+ }
+ """,
+ """
+ package demo;
+
+ import org.springframework.beans.factory.annotation.Qualifier;
+
+ public class A {
+
+ @Qualifier("myBean")
+ private final String a;
+
+ A(String a) {
+ this.a = a;
+ }
+
+ }
+ """
+ )
+ );
+ }
+
}