diff --git a/src/Type/Doctrine/Query/QueryResultTypeWalker.php b/src/Type/Doctrine/Query/QueryResultTypeWalker.php index d68bcdb8..b053a696 100644 --- a/src/Type/Doctrine/Query/QueryResultTypeWalker.php +++ b/src/Type/Doctrine/Query/QueryResultTypeWalker.php @@ -3,7 +3,7 @@ namespace PHPStan\Type\Doctrine\Query; use BackedEnum; -use Doctrine\DBAL\Types\Types; +use Doctrine\DBAL\Types\Type as DbalType; use Doctrine\ORM\EntityManagerInterface; use Doctrine\ORM\Mapping\ClassMetadata; use Doctrine\ORM\Query; @@ -361,9 +361,11 @@ public function walkFunction($function): string case $function instanceof AST\Functions\MaxFunction: case $function instanceof AST\Functions\MinFunction: case $function instanceof AST\Functions\SumFunction: - case $function instanceof AST\Functions\CountFunction: return $function->getSql($this); + case $function instanceof AST\Functions\CountFunction: + return $this->marshalType(IntegerRangeType::fromInterval(0, null)); + case $function instanceof AST\Functions\AbsFunction: $exprType = $this->unmarshalType($this->walkSimpleArithmeticExpression($function->simpleArithmeticExpression)); @@ -816,28 +818,11 @@ public function walkSelectExpression($selectExpression): string $resultAlias = $selectExpression->fieldIdentificationVariable ?? $this->scalarResultCounter++; $type = $this->unmarshalType($expr->dispatch($this)); - if (class_exists(TypedExpression::class) && $expr instanceof TypedExpression) { - $enforcedType = $this->resolveDoctrineType(Types::INTEGER); - $type = TypeTraverser::map($type, static function (Type $type, callable $traverse) use ($enforcedType): Type { - if ($type instanceof UnionType || $type instanceof IntersectionType) { - return $traverse($type); - } - if ($type instanceof NullType) { - return $type; - } - if ($enforcedType->accepts($type, true)->yes()) { - return $type; - } - if ($enforcedType instanceof StringType) { - if ($type instanceof IntegerType || $type instanceof FloatType) { - return TypeCombinator::union($type->toString(), $type); - } - if ($type instanceof BooleanType) { - return TypeCombinator::union($type->toInteger()->toString(), $type); - } - } - return $enforcedType; - }); + if ($expr instanceof TypedExpression) { + $type = TypeCombinator::intersect( + $type, + $this->resolveDoctrineType(DbalType::lookupName($expr->getReturnType()), null, TypeCombinator::containsNull($type)) + ); } else { // Expressions default to Doctrine's StringType, whose // convertToPHPValue() is a no-op. So the actual type depends on diff --git a/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php b/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php index 983362c2..9801fd5f 100644 --- a/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php +++ b/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php @@ -9,7 +9,6 @@ use Doctrine\Common\Collections\ArrayCollection; use Doctrine\ORM\EntityManagerInterface; use Doctrine\ORM\Mapping\Column; -use Doctrine\ORM\Query\AST\TypedExpression; use Doctrine\ORM\Tools\SchemaTool; use PHPStan\Testing\PHPStanTestCase; use PHPStan\Type\Accessory\AccessoryNumericStringType; @@ -611,9 +610,7 @@ public function getTestData(): iterable ], [ new ConstantIntegerType(3), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], [ new ConstantIntegerType(4), @@ -621,9 +618,7 @@ public function getTestData(): iterable ], [ new ConstantIntegerType(5), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], [ new ConstantIntegerType(6), @@ -645,6 +640,29 @@ public function getTestData(): iterable ', ]; + yield 'count' => [ + $this->constantArray([ + [ + new ConstantIntegerType(1), + $this->uint(), + ], + [ + new ConstantIntegerType(2), + $this->uint(), + ], + [ + new ConstantIntegerType(3), + $this->uint(), + ], + ]), + ' + SELECT COUNT(m.stringNullColumn), + COUNT(m.stringColumn), + COUNT(m) + FROM QueryResult\Entities\Many m + ', + ]; + yield 'aggregate lowercase' => [ $this->constantArray([ [ @@ -678,9 +696,7 @@ public function getTestData(): iterable ], [ new ConstantStringType('count'), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], ]), ' @@ -1346,29 +1362,28 @@ public function getTestData(): iterable $this->constantArray([ [ new ConstantIntegerType(1), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], [ new ConstantIntegerType(2), TypeCombinator::addNull( - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified() + $this->uint() ), ], [ new ConstantIntegerType(3), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), + ], + [ + new ConstantIntegerType(4), + $this->uint(), ], ]), ' SELECT LENGTH(m.stringColumn), LENGTH(m.stringNullColumn), - LENGTH(\'foo\') + LENGTH(\'foo\'), + LENGTH(COALESCE(m.stringNullColumn, \'\')) FROM QueryResult\Entities\Many m ', ]; @@ -1554,9 +1569,7 @@ public function getTestData(): iterable [new ConstantIntegerType(1), TypeCombinator::addNull($this->numericStringOrInt())], [ new ConstantIntegerType(2), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], ]), ' @@ -1678,11 +1691,6 @@ private function unumericStringified(): Type ); } - private function hasTypedExpressions(): bool - { - return class_exists(TypedExpression::class); - } - /** * @param array $arrays *