Skip to content

Commit c8af522

Browse files
committed
Rework GoStructInitializationInspection
fixes #2819
1 parent eb669fc commit c8af522

23 files changed

+306
-97
lines changed

src/com/goide/inspections/GoStructInitializationInspection.java

Lines changed: 83 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -22,94 +22,131 @@
2222
import com.goide.util.GoUtil;
2323
import com.intellij.codeInspection.*;
2424
import com.intellij.codeInspection.ui.SingleCheckboxOptionsPanel;
25-
import com.intellij.openapi.progress.ProgressManager;
2625
import com.intellij.openapi.project.Project;
26+
import com.intellij.openapi.util.Comparing;
2727
import com.intellij.openapi.util.InvalidDataException;
2828
import com.intellij.openapi.util.WriteExternalException;
29-
import com.intellij.psi.PsiElement;
30-
import com.intellij.psi.util.PsiTreeUtil;
31-
import com.intellij.util.containers.ContainerUtil;
29+
import com.intellij.util.ObjectUtils;
3230
import org.jdom.Element;
31+
import org.jetbrains.annotations.Contract;
3332
import org.jetbrains.annotations.NotNull;
3433
import org.jetbrains.annotations.Nullable;
3534

3635
import javax.swing.*;
3736
import java.util.List;
3837

38+
import static com.intellij.openapi.util.NullUtils.hasNull;
39+
import static com.intellij.util.containers.ContainerUtil.*;
40+
import static java.util.stream.Collectors.toList;
41+
import static java.util.stream.IntStream.range;
42+
3943
public class GoStructInitializationInspection extends GoInspectionBase {
40-
public static final String REPLACE_WITH_NAMED_STRUCT_FIELD_FIX_NAME = "Replace with named struct field";
44+
public static final String REPLACE_WITH_NAMED_STRUCT_FIELD_FIX_NAME = "Replace with named struct fields";
45+
private static final GoReplaceWithNamedStructFieldQuickFix QUICK_FIX = new GoReplaceWithNamedStructFieldQuickFix();
4146
public boolean reportLocalStructs;
4247
/**
43-
* @deprecated use reportLocalStructs
48+
* @deprecated use {@link #reportLocalStructs}
4449
*/
4550
@SuppressWarnings("WeakerAccess") public Boolean reportImportedStructs;
4651

4752
@NotNull
4853
@Override
4954
protected GoVisitor buildGoVisitor(@NotNull ProblemsHolder holder, @NotNull LocalInspectionToolSession session) {
5055
return new GoVisitor() {
56+
5157
@Override
52-
public void visitLiteralValue(@NotNull GoLiteralValue o) {
53-
if (PsiTreeUtil.getParentOfType(o, GoReturnStatement.class, GoShortVarDeclaration.class, GoAssignmentStatement.class) == null) {
54-
return;
55-
}
56-
PsiElement parent = o.getParent();
57-
GoType refType = GoPsiImplUtil.getLiteralType(parent, false);
58-
if (refType instanceof GoStructType) {
59-
processStructType(holder, o, (GoStructType)refType);
58+
public void visitCompositeLit(@NotNull GoCompositeLit compositeLit) {
59+
GoLiteralValue literalValue = compositeLit.getLiteralValue();
60+
GoStructType structType = getStructType(literalValue);
61+
if (structType == null || !isStructImportedOrLocalAllowed(structType, literalValue)) return;
62+
63+
List<String> elementsNames = getNames(literalValue.getElementList());
64+
if (hasNull(elementsNames.toArray()) && areElementsNamesMatchesDefinitions(elementsNames, getFieldDefinitionsNames(structType))) {
65+
holder.registerProblem(literalValue, "Unnamed field initializations", ProblemHighlightType.GENERIC_ERROR_OR_WARNING, QUICK_FIX);
6066
}
6167
}
6268
};
6369
}
6470

65-
@Override
66-
public JComponent createOptionsPanel() {
67-
return new SingleCheckboxOptionsPanel("Report for local type definitions as well", this, "reportLocalStructs");
71+
@Nullable
72+
@Contract("null -> null")
73+
private static GoStructType getStructType(@Nullable GoLiteralValue literal) {
74+
return literal != null ? ObjectUtils.tryCast(GoPsiImplUtil.getLiteralType(literal.getParent(), false), GoStructType.class) : null;
6875
}
6976

70-
private void processStructType(@NotNull ProblemsHolder holder, @NotNull GoLiteralValue element, @NotNull GoStructType structType) {
71-
if (reportLocalStructs || !GoUtil.inSamePackage(structType.getContainingFile(), element.getContainingFile())) {
72-
processLiteralValue(holder, element, structType.getFieldDeclarationList());
73-
}
77+
private boolean isStructImportedOrLocalAllowed(@NotNull GoStructType structType, @NotNull GoLiteralValue literalValue) {
78+
return reportLocalStructs || !GoUtil.inSamePackage(structType.getContainingFile(), literalValue.getContainingFile());
7479
}
7580

76-
private static void processLiteralValue(@NotNull ProblemsHolder holder,
77-
@NotNull GoLiteralValue o,
78-
@NotNull List<GoFieldDeclaration> fields) {
79-
List<GoElement> vals = o.getElementList();
80-
for (int elemId = 0; elemId < vals.size(); elemId++) {
81-
ProgressManager.checkCanceled();
82-
GoElement element = vals.get(elemId);
83-
if (element.getKey() == null && elemId < fields.size()) {
84-
String structFieldName = getFieldName(fields.get(elemId));
85-
LocalQuickFix[] fixes = structFieldName != null ? new LocalQuickFix[]{new GoReplaceWithNamedStructFieldQuickFix(structFieldName)}
86-
: LocalQuickFix.EMPTY_ARRAY;
87-
holder.registerProblem(element, "Unnamed field initialization", ProblemHighlightType.GENERIC_ERROR_OR_WARNING, fixes);
88-
}
89-
}
81+
@NotNull
82+
private static List<String> getNames(@NotNull List<GoElement> elements) {
83+
return map(elements, element -> {
84+
GoKey key = element.getKey();
85+
return key != null ? key.getText() : null;
86+
});
87+
}
88+
89+
private static boolean areElementsNamesMatchesDefinitions(@NotNull List<String> elementsNames,
90+
@NotNull List<String> fieldDefinitionsNames) {
91+
return range(0, elementsNames.size()).allMatch(i -> isNullOrEqual(elementsNames.get(i), getByIndex(fieldDefinitionsNames, i)));
92+
}
93+
94+
@Contract("null, _ -> true")
95+
private static boolean isNullOrEqual(@Nullable Object o, @Nullable Object objectToCompare) {
96+
return o == null || Comparing.equal(o, objectToCompare);
9097
}
9198

9299
@Nullable
93-
private static String getFieldName(@NotNull GoFieldDeclaration declaration) {
94-
List<GoFieldDefinition> list = declaration.getFieldDefinitionList();
95-
GoFieldDefinition fieldDefinition = ContainerUtil.getFirstItem(list);
96-
return fieldDefinition != null ? fieldDefinition.getIdentifier().getText() : null;
100+
private static String getByIndex(@NotNull List<String> list, int index) {
101+
return 0 <= index && index < list.size() ? list.get(index) : null;
102+
}
103+
104+
@NotNull
105+
private static List<String> getFieldDefinitionsNames(@NotNull GoStructType type) {
106+
return type.getFieldDeclarationList().stream()
107+
.flatMap(declaration -> getFieldDefinitionsNames(declaration).stream())
108+
.collect(toList());
109+
}
110+
111+
@NotNull
112+
private static List<String> getFieldDefinitionsNames(@NotNull GoFieldDeclaration declaration) {
113+
GoAnonymousFieldDefinition definition = declaration.getAnonymousFieldDefinition();
114+
return definition != null ? list(definition.getName()) : map(declaration.getFieldDefinitionList(), GoNamedElement::getName);
115+
}
116+
117+
@Override
118+
public JComponent createOptionsPanel() {
119+
return new SingleCheckboxOptionsPanel("Report for local type definitions as well", this, "reportLocalStructs");
97120
}
98121

99122
private static class GoReplaceWithNamedStructFieldQuickFix extends LocalQuickFixBase {
100-
private String myStructField;
101123

102-
public GoReplaceWithNamedStructFieldQuickFix(@NotNull String structField) {
124+
public GoReplaceWithNamedStructFieldQuickFix() {
103125
super(REPLACE_WITH_NAMED_STRUCT_FIELD_FIX_NAME);
104-
myStructField = structField;
105126
}
106127

107128
@Override
108129
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
109-
PsiElement startElement = descriptor.getStartElement();
110-
if (startElement instanceof GoElement) {
111-
startElement.replace(GoElementFactory.createLiteralValueElement(project, myStructField, startElement.getText()));
112-
}
130+
GoLiteralValue literal = ObjectUtils.tryCast(descriptor.getStartElement(), GoLiteralValue.class);
131+
GoStructType structType = getStructType(literal);
132+
List<GoElement> elements = structType != null ? literal.getElementList() : emptyList();
133+
List<String> fieldDefinitionNames = structType != null ? getFieldDefinitionsNames(structType) : emptyList();
134+
if (!areElementsNamesMatchesDefinitions(getNames(elements), fieldDefinitionNames)) return;
135+
replaceElementsByNamed(elements, fieldDefinitionNames, project);
136+
}
137+
}
138+
139+
private static void replaceElementsByNamed(@NotNull List<GoElement> elements,
140+
@NotNull List<String> fieldDefinitionNames,
141+
@NotNull Project project) {
142+
for (int i = 0; i < elements.size(); i++) {
143+
GoElement element = elements.get(i);
144+
String fieldDefinitionName = getByIndex(fieldDefinitionNames, i);
145+
GoValue value = fieldDefinitionName != null && element.getKey() == null ? element.getValue() : null;
146+
if (value == null) continue;
147+
148+
GoElement namedElement = GoElementFactory.createLiteralValueElement(project, fieldDefinitionName, value.getText());
149+
element.replace(namedElement);
113150
}
114151
}
115152

src/com/goide/psi/impl/GoElementFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ public static GoType createType(@NotNull Project project, @NotNull String text)
256256
return PsiTreeUtil.findChildOfType(file, GoType.class);
257257
}
258258

259-
public static PsiElement createLiteralValueElement(@NotNull Project project, @NotNull String key, @NotNull String value) {
259+
public static GoElement createLiteralValueElement(@NotNull Project project, @NotNull String key, @NotNull String value) {
260260
GoFile file = createFileFromText(project, "package a; var _ = struct { a string } { " + key + ": " + value + " }");
261261
return PsiTreeUtil.findChildOfType(file, GoElement.class);
262262
}

testData/inspections/go-struct-initialization/quickFix.go

Lines changed: 0 additions & 9 deletions
This file was deleted.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package foo
2+
3+
type S struct {
4+
X string
5+
string
6+
Y int
7+
}
8+
func main() {
9+
var s S
10+
s = S{X: "X", string: "a", Y: 1}
11+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package foo
2+
3+
type S struct {
4+
X string
5+
string
6+
Y int
7+
}
8+
func main() {
9+
var s S
10+
s = S<weak_warning descr="Unnamed field initializations">{<caret>"X", "a", Y: 1}</weak_warning>
11+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package foo
2+
3+
type S struct {
4+
X, Y int
5+
}
6+
func main() {
7+
s := S{X: 1, Y: 0, 2}
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package foo
2+
3+
type S struct {
4+
X, Y int
5+
}
6+
func main() {
7+
s := S<weak_warning descr="Unnamed field initializations">{<caret>1, 0, 2}</weak_warning>
8+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package foo
2+
3+
type S struct {
4+
X, Y int
5+
}
6+
func main() {
7+
s := S{<caret>1, 0, X: 2}
8+
9+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package foo
2+
3+
func main() {
4+
type B struct {
5+
Y int
6+
}
7+
8+
type S struct {
9+
X int
10+
B
11+
Z int
12+
}
13+
14+
s := S{X: 1, B: B{Y: 2}, Z: 3}
15+
print(s.B.Y)
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package foo
2+
3+
func main() {
4+
type B struct {
5+
Y int
6+
}
7+
8+
type S struct {
9+
X int
10+
B
11+
Z int
12+
}
13+
14+
s := S<weak_warning descr="Unnamed field initializations">{1<caret>, B{Y: 2}, 3}</weak_warning>
15+
print(s.B.Y)
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package foo
2+
3+
func main() {
4+
type B struct {
5+
Y int
6+
}
7+
8+
type S struct {
9+
X int
10+
b B
11+
Z int
12+
}
13+
14+
s := S{X: 1, b: B{Y: 2}}
15+
print(s.b.Y)
16+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package foo
2+
3+
func main() {
4+
type B struct {
5+
Y int
6+
}
7+
8+
type S struct {
9+
X int
10+
b B
11+
Z int
12+
}
13+
14+
s := S<weak_warning descr="Unnamed field initializations">{1<caret>, B{Y: 2}}</weak_warning>
15+
print(s.b.Y)
16+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package foo
2+
3+
type S struct {
4+
X, Y int
5+
}
6+
func main() {
7+
var s S
8+
s = S<weak_warning descr="Unnamed field initializations">{0, 0}</weak_warning>
9+
s = S<weak_warning descr="Unnamed field initializations">{X: 0, 0}</weak_warning>
10+
s = S<weak_warning descr="Unnamed field initializations">{0, Y: 0}</weak_warning>
11+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package foo
2+
3+
type S struct {
4+
X, Y int
5+
}
6+
func main() {
7+
var s S
8+
s = S{X: 0, Y: 0}
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package foo
2+
3+
type S struct {
4+
X, Y int
5+
}
6+
func main() {
7+
var s S
8+
s = S<weak_warning descr="Unnamed field initializations">{<caret>0, 0}</weak_warning>
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package foo
2+
3+
import "io"
4+
5+
func _() {
6+
_ = io.LimitedReader<weak_warning descr="Unnamed field initializations">{
7+
<caret>nil,
8+
}</weak_warning>
9+
}

testData/inspections/go-struct-initialization/uninitializedStructImportedOnly.go renamed to testData/inspections/struct-initialization/uninitializedStructImportedOnly.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,15 @@ func _() {
113113
}
114114

115115
// bad defs
116-
_ = io.LimitedReader{
117-
<weak_warning descr="Unnamed field initialization">nil</weak_warning>,
118-
}
119-
_ = os.LinkError{
120-
<weak_warning descr="Unnamed field initialization">"string"</weak_warning>,
121-
<weak_warning descr="Unnamed field initialization">"string"</weak_warning>,
122-
<weak_warning descr="Unnamed field initialization">"string"</weak_warning>,
123-
<weak_warning descr="Unnamed field initialization">nil</weak_warning>,
124-
}
116+
_ = io.LimitedReader<weak_warning descr="Unnamed field initializations">{
117+
nil,
118+
}</weak_warning>
119+
_ = os.LinkError<weak_warning descr="Unnamed field initializations">{
120+
"string",
121+
"string",
122+
"string",
123+
nil,
124+
}</weak_warning>
125125
}
126126

127127
type assertion struct {

0 commit comments

Comments
 (0)