diff --git a/Engine/Helper.cs b/Engine/Helper.cs index da5bbb890..66105b1b3 100644 --- a/Engine/Helper.cs +++ b/Engine/Helper.cs @@ -679,7 +679,7 @@ internal string GetTypeFromMemberExpressionAstHelper(MemberExpressionAst memberA /// /// /// - internal Type GetTypeFromAnalysis(VariableExpressionAst varAst, Ast ast) + public Type GetTypeFromAnalysis(VariableExpressionAst varAst, Ast ast) { try { diff --git a/Rules/PossibleIncorrectComparisonWithNull.cs b/Rules/PossibleIncorrectComparisonWithNull.cs index da97c13f6..cc589f28a 100644 --- a/Rules/PossibleIncorrectComparisonWithNull.cs +++ b/Rules/PossibleIncorrectComparisonWithNull.cs @@ -11,6 +11,7 @@ // using System; +using System.Linq; using System.Collections.Generic; using System.Management.Automation.Language; using Microsoft.Windows.PowerShell.ScriptAnalyzer.Generic; @@ -33,17 +34,67 @@ public class PossibleIncorrectComparisonWithNull : IScriptRule { public IEnumerable AnalyzeScript(Ast ast, string fileName) { if (ast == null) throw new ArgumentNullException(Strings.NullAstErrorMessage); - IEnumerable binExpressionAsts = ast.FindAll(testAst => testAst is BinaryExpressionAst, true); + IEnumerable binExpressionAsts = ast.FindAll(testAst => testAst is BinaryExpressionAst, false); - if (binExpressionAsts != null) { - foreach (BinaryExpressionAst binExpressionAst in binExpressionAsts) { - if ((binExpressionAst.Operator.Equals(TokenKind.Equals) || binExpressionAst.Operator.Equals(TokenKind.Ceq) - || binExpressionAst.Operator.Equals(TokenKind.Cne) || binExpressionAst.Operator.Equals(TokenKind.Ine) || binExpressionAst.Operator.Equals(TokenKind.Ieq)) - && binExpressionAst.Right.Extent.Text.Equals("$null", StringComparison.OrdinalIgnoreCase)) { - yield return new DiagnosticRecord(Strings.PossibleIncorrectComparisonWithNullError, binExpressionAst.Extent, GetName(), DiagnosticSeverity.Warning, fileName); + foreach (BinaryExpressionAst binExpressionAst in binExpressionAsts) { + if ((binExpressionAst.Operator.Equals(TokenKind.Equals) || binExpressionAst.Operator.Equals(TokenKind.Ceq) + || binExpressionAst.Operator.Equals(TokenKind.Cne) || binExpressionAst.Operator.Equals(TokenKind.Ine) || binExpressionAst.Operator.Equals(TokenKind.Ieq)) + && binExpressionAst.Right.Extent.Text.Equals("$null", StringComparison.OrdinalIgnoreCase)) + { + if (IncorrectComparisonWithNull(binExpressionAst, ast)) + { + yield return new DiagnosticRecord(Strings.PossibleIncorrectComparisonWithNullError, binExpressionAst.Extent, GetName(), DiagnosticSeverity.Warning, fileName); } } } + + IEnumerable funcAsts = ast.FindAll(item => item is FunctionDefinitionAst, true).Union(ast.FindAll(item => item is FunctionMemberAst, true)); + foreach (Ast funcAst in funcAsts) + { + IEnumerable binAsts = funcAst.FindAll(item => item is BinaryExpressionAst, true); + foreach (BinaryExpressionAst binAst in binAsts) + { + if (IncorrectComparisonWithNull(binAst, funcAst)) + { + yield return new DiagnosticRecord(Strings.PossibleIncorrectComparisonWithNullError, binAst.Extent, GetName(), DiagnosticSeverity.Warning, fileName); + } + } + } + } + + private bool IncorrectComparisonWithNull(BinaryExpressionAst binExpressionAst, Ast ast) + { + if ((binExpressionAst.Operator.Equals(TokenKind.Equals) || binExpressionAst.Operator.Equals(TokenKind.Ceq) + || binExpressionAst.Operator.Equals(TokenKind.Cne) || binExpressionAst.Operator.Equals(TokenKind.Ine) || binExpressionAst.Operator.Equals(TokenKind.Ieq)) + && binExpressionAst.Right.Extent.Text.Equals("$null", StringComparison.OrdinalIgnoreCase)) + { + if (binExpressionAst.Left.StaticType.IsArray) + { + return true; + } + else if (binExpressionAst.Left is VariableExpressionAst) + { + // ignores if the variable is a special variable + if (!Helper.Instance.HasSpecialVars((binExpressionAst.Left as VariableExpressionAst).VariablePath.UserPath)) + { + Type lhsType = Helper.Instance.GetTypeFromAnalysis(binExpressionAst.Left as VariableExpressionAst, ast); + if (lhsType == null) + { + return true; + } + else if (lhsType.IsArray || lhsType == typeof(object) || lhsType == typeof(Undetermined) || lhsType == typeof(Unreached)) + { + return true; + } + } + } + else if (binExpressionAst.Left.StaticType == typeof(object)) + { + return true; + } + } + + return false; } /// diff --git a/Tests/Rules/PossibleIncorrectComparisonWithNull.ps1 b/Tests/Rules/PossibleIncorrectComparisonWithNull.ps1 index a2b520739..fbfffd93d 100644 --- a/Tests/Rules/PossibleIncorrectComparisonWithNull.ps1 +++ b/Tests/Rules/PossibleIncorrectComparisonWithNull.ps1 @@ -1,4 +1,23 @@ function CompareWithNull { if ($DebugPreference -eq $null) { } +} + +if (@("dfd", "eee") -eq $null) +{ +} + +if ($randomUninitializedVariable -eq $null) +{ +} + +function Test +{ + $b = "dd", "ddfd"; + if ($b -ceq $null) + { + if ("dd","ee" -eq $null) + { + } + } } \ No newline at end of file diff --git a/Tests/Rules/PossibleIncorrectComparisonWithNull.tests.ps1 b/Tests/Rules/PossibleIncorrectComparisonWithNull.tests.ps1 index 4ef850c55..bdcb40902 100644 --- a/Tests/Rules/PossibleIncorrectComparisonWithNull.tests.ps1 +++ b/Tests/Rules/PossibleIncorrectComparisonWithNull.tests.ps1 @@ -7,8 +7,8 @@ $noViolations = Invoke-ScriptAnalyzer $directory\PossibleIncorrectComparisonWith Describe "PossibleIncorrectComparisonWithNull" { Context "When there are violations" { - It "has 1 possible incorrect comparison with null violation" { - $violations.Count | Should Be 1 + It "has 4 possible incorrect comparison with null violation" { + $violations.Count | Should Be 4 } It "has the correct description message" { diff --git a/Tests/Rules/PossibleIncorrectComparisonWithNullNoViolations.ps1 b/Tests/Rules/PossibleIncorrectComparisonWithNullNoViolations.ps1 index 5ac741e0a..6a34d3c80 100644 --- a/Tests/Rules/PossibleIncorrectComparisonWithNullNoViolations.ps1 +++ b/Tests/Rules/PossibleIncorrectComparisonWithNullNoViolations.ps1 @@ -1,4 +1,15 @@ function CompareWithNull { if ($null -eq $DebugPreference) { } + if ($DebugPreference -eq $null) { + } +} + +$a = 3 + +if ($a -eq $null) +{ + if (3 -eq $null) + { + } } \ No newline at end of file