diff --git a/Engine/Helper.cs b/Engine/Helper.cs
index da5bbb890..66105b1b3 100644
--- a/Engine/Helper.cs
+++ b/Engine/Helper.cs
@@ -679,7 +679,7 @@ internal string GetTypeFromMemberExpressionAstHelper(MemberExpressionAst memberA
///
///
///
- internal Type GetTypeFromAnalysis(VariableExpressionAst varAst, Ast ast)
+ public Type GetTypeFromAnalysis(VariableExpressionAst varAst, Ast ast)
{
try
{
diff --git a/Rules/PossibleIncorrectComparisonWithNull.cs b/Rules/PossibleIncorrectComparisonWithNull.cs
index da97c13f6..cc589f28a 100644
--- a/Rules/PossibleIncorrectComparisonWithNull.cs
+++ b/Rules/PossibleIncorrectComparisonWithNull.cs
@@ -11,6 +11,7 @@
//
using System;
+using System.Linq;
using System.Collections.Generic;
using System.Management.Automation.Language;
using Microsoft.Windows.PowerShell.ScriptAnalyzer.Generic;
@@ -33,17 +34,67 @@ public class PossibleIncorrectComparisonWithNull : IScriptRule {
public IEnumerable AnalyzeScript(Ast ast, string fileName) {
if (ast == null) throw new ArgumentNullException(Strings.NullAstErrorMessage);
- IEnumerable binExpressionAsts = ast.FindAll(testAst => testAst is BinaryExpressionAst, true);
+ IEnumerable binExpressionAsts = ast.FindAll(testAst => testAst is BinaryExpressionAst, false);
- if (binExpressionAsts != null) {
- foreach (BinaryExpressionAst binExpressionAst in binExpressionAsts) {
- if ((binExpressionAst.Operator.Equals(TokenKind.Equals) || binExpressionAst.Operator.Equals(TokenKind.Ceq)
- || binExpressionAst.Operator.Equals(TokenKind.Cne) || binExpressionAst.Operator.Equals(TokenKind.Ine) || binExpressionAst.Operator.Equals(TokenKind.Ieq))
- && binExpressionAst.Right.Extent.Text.Equals("$null", StringComparison.OrdinalIgnoreCase)) {
- yield return new DiagnosticRecord(Strings.PossibleIncorrectComparisonWithNullError, binExpressionAst.Extent, GetName(), DiagnosticSeverity.Warning, fileName);
+ foreach (BinaryExpressionAst binExpressionAst in binExpressionAsts) {
+ if ((binExpressionAst.Operator.Equals(TokenKind.Equals) || binExpressionAst.Operator.Equals(TokenKind.Ceq)
+ || binExpressionAst.Operator.Equals(TokenKind.Cne) || binExpressionAst.Operator.Equals(TokenKind.Ine) || binExpressionAst.Operator.Equals(TokenKind.Ieq))
+ && binExpressionAst.Right.Extent.Text.Equals("$null", StringComparison.OrdinalIgnoreCase))
+ {
+ if (IncorrectComparisonWithNull(binExpressionAst, ast))
+ {
+ yield return new DiagnosticRecord(Strings.PossibleIncorrectComparisonWithNullError, binExpressionAst.Extent, GetName(), DiagnosticSeverity.Warning, fileName);
}
}
}
+
+ IEnumerable funcAsts = ast.FindAll(item => item is FunctionDefinitionAst, true).Union(ast.FindAll(item => item is FunctionMemberAst, true));
+ foreach (Ast funcAst in funcAsts)
+ {
+ IEnumerable binAsts = funcAst.FindAll(item => item is BinaryExpressionAst, true);
+ foreach (BinaryExpressionAst binAst in binAsts)
+ {
+ if (IncorrectComparisonWithNull(binAst, funcAst))
+ {
+ yield return new DiagnosticRecord(Strings.PossibleIncorrectComparisonWithNullError, binAst.Extent, GetName(), DiagnosticSeverity.Warning, fileName);
+ }
+ }
+ }
+ }
+
+ private bool IncorrectComparisonWithNull(BinaryExpressionAst binExpressionAst, Ast ast)
+ {
+ if ((binExpressionAst.Operator.Equals(TokenKind.Equals) || binExpressionAst.Operator.Equals(TokenKind.Ceq)
+ || binExpressionAst.Operator.Equals(TokenKind.Cne) || binExpressionAst.Operator.Equals(TokenKind.Ine) || binExpressionAst.Operator.Equals(TokenKind.Ieq))
+ && binExpressionAst.Right.Extent.Text.Equals("$null", StringComparison.OrdinalIgnoreCase))
+ {
+ if (binExpressionAst.Left.StaticType.IsArray)
+ {
+ return true;
+ }
+ else if (binExpressionAst.Left is VariableExpressionAst)
+ {
+ // ignores if the variable is a special variable
+ if (!Helper.Instance.HasSpecialVars((binExpressionAst.Left as VariableExpressionAst).VariablePath.UserPath))
+ {
+ Type lhsType = Helper.Instance.GetTypeFromAnalysis(binExpressionAst.Left as VariableExpressionAst, ast);
+ if (lhsType == null)
+ {
+ return true;
+ }
+ else if (lhsType.IsArray || lhsType == typeof(object) || lhsType == typeof(Undetermined) || lhsType == typeof(Unreached))
+ {
+ return true;
+ }
+ }
+ }
+ else if (binExpressionAst.Left.StaticType == typeof(object))
+ {
+ return true;
+ }
+ }
+
+ return false;
}
///
diff --git a/Tests/Rules/PossibleIncorrectComparisonWithNull.ps1 b/Tests/Rules/PossibleIncorrectComparisonWithNull.ps1
index a2b520739..fbfffd93d 100644
--- a/Tests/Rules/PossibleIncorrectComparisonWithNull.ps1
+++ b/Tests/Rules/PossibleIncorrectComparisonWithNull.ps1
@@ -1,4 +1,23 @@
function CompareWithNull {
if ($DebugPreference -eq $null) {
}
+}
+
+if (@("dfd", "eee") -eq $null)
+{
+}
+
+if ($randomUninitializedVariable -eq $null)
+{
+}
+
+function Test
+{
+ $b = "dd", "ddfd";
+ if ($b -ceq $null)
+ {
+ if ("dd","ee" -eq $null)
+ {
+ }
+ }
}
\ No newline at end of file
diff --git a/Tests/Rules/PossibleIncorrectComparisonWithNull.tests.ps1 b/Tests/Rules/PossibleIncorrectComparisonWithNull.tests.ps1
index 4ef850c55..bdcb40902 100644
--- a/Tests/Rules/PossibleIncorrectComparisonWithNull.tests.ps1
+++ b/Tests/Rules/PossibleIncorrectComparisonWithNull.tests.ps1
@@ -7,8 +7,8 @@ $noViolations = Invoke-ScriptAnalyzer $directory\PossibleIncorrectComparisonWith
Describe "PossibleIncorrectComparisonWithNull" {
Context "When there are violations" {
- It "has 1 possible incorrect comparison with null violation" {
- $violations.Count | Should Be 1
+ It "has 4 possible incorrect comparison with null violation" {
+ $violations.Count | Should Be 4
}
It "has the correct description message" {
diff --git a/Tests/Rules/PossibleIncorrectComparisonWithNullNoViolations.ps1 b/Tests/Rules/PossibleIncorrectComparisonWithNullNoViolations.ps1
index 5ac741e0a..6a34d3c80 100644
--- a/Tests/Rules/PossibleIncorrectComparisonWithNullNoViolations.ps1
+++ b/Tests/Rules/PossibleIncorrectComparisonWithNullNoViolations.ps1
@@ -1,4 +1,15 @@
function CompareWithNull {
if ($null -eq $DebugPreference) {
}
+ if ($DebugPreference -eq $null) {
+ }
+}
+
+$a = 3
+
+if ($a -eq $null)
+{
+ if (3 -eq $null)
+ {
+ }
}
\ No newline at end of file