@@ -49959,18 +49959,17 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
49959
49959
SDValue Ptr = Ld->getBasePtr();
49960
49960
SDValue Chain = Ld->getChain();
49961
49961
for (SDNode *User : Chain->uses()) {
49962
- if (User != N &&
49962
+ auto *UserLd = dyn_cast<MemSDNode>(User);
49963
+ if (User != N && UserLd &&
49963
49964
(User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD ||
49964
49965
User->getOpcode() == X86ISD::VBROADCAST_LOAD ||
49965
49966
ISD::isNormalLoad(User)) &&
49966
- cast<MemSDNode>(User)->getChain() == Chain &&
49967
- !User->hasAnyUseOfValue(1) &&
49967
+ UserLd->getChain() == Chain && !User->hasAnyUseOfValue(1) &&
49968
49968
User->getValueSizeInBits(0).getFixedValue() >
49969
49969
RegVT.getFixedSizeInBits()) {
49970
49970
if (User->getOpcode() == X86ISD::SUBV_BROADCAST_LOAD &&
49971
- cast<MemSDNode>(User)->getBasePtr() == Ptr &&
49972
- cast<MemSDNode>(User)->getMemoryVT().getSizeInBits() ==
49973
- MemVT.getSizeInBits()) {
49971
+ UserLd->getBasePtr() == Ptr &&
49972
+ UserLd->getMemoryVT().getSizeInBits() == MemVT.getSizeInBits()) {
49974
49973
SDValue Extract = extractSubVector(SDValue(User, 0), 0, DAG, SDLoc(N),
49975
49974
RegVT.getSizeInBits());
49976
49975
Extract = DAG.getBitcast(RegVT, Extract);
@@ -49989,7 +49988,7 @@ static SDValue combineLoad(SDNode *N, SelectionDAG &DAG,
49989
49988
// See if we are loading a constant that matches in the lower
49990
49989
// bits of a longer constant (but from a different constant pool ptr).
49991
49990
EVT UserVT = User->getValueType(0);
49992
- SDValue UserPtr = cast<MemSDNode>(User) ->getBasePtr();
49991
+ SDValue UserPtr = UserLd ->getBasePtr();
49993
49992
const Constant *LdC = getTargetConstantFromBasePtr(Ptr);
49994
49993
const Constant *UserC = getTargetConstantFromBasePtr(UserPtr);
49995
49994
if (LdC && UserC && UserPtr != Ptr) {
0 commit comments