From 69b0cae3ed31de9f41a4d03a80254ec92b582d67 Mon Sep 17 00:00:00 2001 From: Jan Nedbal Date: Tue, 25 Jun 2024 09:18:40 +0200 Subject: [PATCH] QueryResultTypeWalker: fix TypedExpression handling --- .../Doctrine/Query/QueryResultTypeWalker.php | 34 ++++++------------- .../Query/QueryResultTypeWalkerTest.php | 34 ++++--------------- 2 files changed, 18 insertions(+), 50 deletions(-) diff --git a/src/Type/Doctrine/Query/QueryResultTypeWalker.php b/src/Type/Doctrine/Query/QueryResultTypeWalker.php index 55375c5f..30bbf495 100644 --- a/src/Type/Doctrine/Query/QueryResultTypeWalker.php +++ b/src/Type/Doctrine/Query/QueryResultTypeWalker.php @@ -3,7 +3,8 @@ namespace PHPStan\Type\Doctrine\Query; use BackedEnum; -use Doctrine\DBAL\Types\Types; +use Doctrine\DBAL\Types\StringType as DbalStringType; +use Doctrine\DBAL\Types\Type as DbalType; use Doctrine\ORM\EntityManagerInterface; use Doctrine\ORM\Mapping\ClassMetadata; use Doctrine\ORM\Query; @@ -817,28 +818,15 @@ public function walkSelectExpression($selectExpression): string $resultAlias = $selectExpression->fieldIdentificationVariable ?? $this->scalarResultCounter++; $type = $this->unmarshalType($expr->dispatch($this)); - if (class_exists(TypedExpression::class) && $expr instanceof TypedExpression) { - $enforcedType = $this->resolveDoctrineType(Types::INTEGER); - $type = TypeTraverser::map($type, static function (Type $type, callable $traverse) use ($enforcedType): Type { - if ($type instanceof UnionType || $type instanceof IntersectionType) { - return $traverse($type); - } - if ($type instanceof NullType) { - return $type; - } - if ($enforcedType->accepts($type, true)->yes()) { - return $type; - } - if ($enforcedType instanceof StringType) { - if ($type instanceof IntegerType || $type instanceof FloatType) { - return TypeCombinator::union($type->toString(), $type); - } - if ($type instanceof BooleanType) { - return TypeCombinator::union($type->toInteger()->toString(), $type); - } - } - return $enforcedType; - }); + if ( + $expr instanceof TypedExpression + && !$expr->getReturnType() instanceof DbalStringType // StringType is no-op, so using TypedExpression with that does nothing + ) { + $dbalTypeName = DbalType::getTypeRegistry()->lookupName($expr->getReturnType()); + $type = TypeCombinator::intersect( // e.g. count is typed as int, but we infer int<0, max> + $type, + $this->resolveDoctrineType($dbalTypeName, null, TypeCombinator::containsNull($type)) + ); } else { // Expressions default to Doctrine's StringType, whose // convertToPHPValue() is a no-op. So the actual type depends on diff --git a/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php b/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php index f2f349ed..b82c0c1b 100644 --- a/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php +++ b/tests/Type/Doctrine/Query/QueryResultTypeWalkerTest.php @@ -9,7 +9,6 @@ use Doctrine\Common\Collections\ArrayCollection; use Doctrine\ORM\EntityManagerInterface; use Doctrine\ORM\Mapping\Column; -use Doctrine\ORM\Query\AST\TypedExpression; use Doctrine\ORM\Tools\SchemaTool; use PHPStan\Testing\PHPStanTestCase; use PHPStan\Type\Accessory\AccessoryNumericStringType; @@ -613,9 +612,7 @@ public function getTestData(): iterable ], [ new ConstantIntegerType(3), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], [ new ConstantIntegerType(4), @@ -623,9 +620,7 @@ public function getTestData(): iterable ], [ new ConstantIntegerType(5), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], [ new ConstantIntegerType(6), @@ -680,9 +675,7 @@ public function getTestData(): iterable ], [ new ConstantStringType('count'), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], ]), ' @@ -1358,23 +1351,17 @@ public function getTestData(): iterable $this->constantArray([ [ new ConstantIntegerType(1), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], [ new ConstantIntegerType(2), TypeCombinator::addNull( - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified() + $this->uint() ), ], [ new ConstantIntegerType(3), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], ]), ' @@ -1566,9 +1553,7 @@ public function getTestData(): iterable [new ConstantIntegerType(1), TypeCombinator::addNull($this->numericStringOrInt())], [ new ConstantIntegerType(2), - $this->hasTypedExpressions() - ? $this->uint() - : $this->uintStringified(), + $this->uint(), ], ]), ' @@ -1690,11 +1675,6 @@ private function unumericStringified(): Type ); } - private function hasTypedExpressions(): bool - { - return class_exists(TypedExpression::class); - } - /** * @param array $arrays *