diff --git a/plugins/hls-explicit-record-fields-plugin/src/Ide/Plugin/ExplicitFields.hs b/plugins/hls-explicit-record-fields-plugin/src/Ide/Plugin/ExplicitFields.hs index 12a0791b6c..751f5eefb6 100644 --- a/plugins/hls-explicit-record-fields-plugin/src/Ide/Plugin/ExplicitFields.hs +++ b/plugins/hls-explicit-record-fields-plugin/src/Ide/Plugin/ExplicitFields.hs @@ -6,6 +6,7 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE ViewPatterns #-} @@ -18,12 +19,13 @@ module Ide.Plugin.ExplicitFields import Control.Lens ((^.)) import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.Trans.Except (ExceptT) +import Data.Bifunctor (first) import Data.Functor ((<&>)) -import Data.Generics (GenericQ, everything, extQ, - mkQ) +import Data.Generics (GenericQ, everything, + everythingBut, extQ, mkQ) import qualified Data.HashMap.Strict as HashMap -import Data.Maybe (isJust, listToMaybe, - maybeToList, fromMaybe) +import Data.Maybe (fromMaybe, isJust, + listToMaybe, maybeToList) import Data.Text (Text) import Development.IDE (IdeState, NormalizedFilePath, Pretty (..), Recorder (..), @@ -36,11 +38,11 @@ import Development.IDE.Core.Shake (define, use) import qualified Development.IDE.Core.Shake as Shake import Development.IDE.GHC.Compat (HsConDetails (RecCon), HsRecFields (..), LPat, - Outputable, getLoc, unLoc, - recDotDot) + Outputable, getLoc, recDotDot, + unLoc) import Development.IDE.GHC.Compat.Core (Extension (NamedFieldPuns), - GhcPass, - HsExpr (RecordCon, rcon_flds), + GhcPass, HsExpansion (..), + HsExpr (RecordCon, XExpr, rcon_flds), HsRecField, LHsExpr, LocatedA, Name, Pass (..), Pat (..), RealSrcSpan, UniqFM, @@ -329,8 +331,13 @@ showRecordCon expr@(RecordCon _ _ flds) = expr { rcon_flds = preprocessRecordCon flds } showRecordCon _ = Nothing +-- It's important that we use everthingBut here, because if we used everything +-- we would get duplicates for every case that occurs inside a HsExpanded expression. collectRecords :: GenericQ [RecordInfo] -collectRecords = everything (<>) (maybeToList . (Nothing `mkQ` getRecPatterns `extQ` getRecCons)) +collectRecords = + everythingBut (<>) (first maybeToList . ((Nothing, False) `mkQ` getRecPatterns' `extQ` getRecCons)) + where + getRecPatterns' = (,False) . getRecPatterns -- | Collect 'Name's into a map, indexed by the names' unique identifiers. -- The 'Eq' instance of 'Name's makes use of their unique identifiers, hence @@ -347,14 +354,19 @@ collectRecords = everything (<>) (maybeToList . (Nothing `mkQ` getRecPatterns `e collectNames :: GenericQ (UniqFM Name [Name]) collectNames = everything (plusUFM_C (<>)) (emptyUFM `mkQ` (\x -> unitUFM x [x])) -getRecCons :: LHsExpr (GhcPass 'Renamed) -> Maybe RecordInfo +getRecCons :: LHsExpr (GhcPass 'Renamed) -> (Maybe RecordInfo, Bool) +-- When we stumble upon an occurrence of HsExpanded, we only want to follow a +-- single branch. We do this here, by explicitly returning occurrences from +-- traversing the original branch, and returning True, which keeps syb from +-- implicitly continuing to traverse. +getRecCons (unLoc -> XExpr (HsExpanded _ expanded)) = (listToMaybe (collectRecords expanded), True) getRecCons e@(unLoc -> RecordCon _ _ flds) - | isJust (rec_dotdot flds) = mkRecInfo e + | isJust (rec_dotdot flds) = (mkRecInfo e, False) where mkRecInfo :: LHsExpr (GhcPass 'Renamed) -> Maybe RecordInfo mkRecInfo expr = listToMaybe [ RecordInfoCon realSpan' (unLoc expr) | RealSrcSpan realSpan' _ <- [ getLoc expr ]] -getRecCons _ = Nothing +getRecCons _ = (Nothing, False) getRecPatterns :: LPat (GhcPass 'Renamed) -> Maybe RecordInfo getRecPatterns conPat@(conPatDetails . unLoc -> Just (RecCon flds)) diff --git a/plugins/hls-explicit-record-fields-plugin/test/Main.hs b/plugins/hls-explicit-record-fields-plugin/test/Main.hs index abbf3d8809..b686c08c2a 100644 --- a/plugins/hls-explicit-record-fields-plugin/test/Main.hs +++ b/plugins/hls-explicit-record-fields-plugin/test/Main.hs @@ -27,6 +27,7 @@ test = testGroup "explicit-fields" , mkTest "WithExplicitBind" "WithExplicitBind" 12 10 12 32 , mkTest "Mixed" "Mixed" 14 10 14 37 , mkTest "Construction" "Construction" 16 5 16 15 + , mkTest "Construction (Dot)" "3574" 16 5 16 15 , mkTestNoAction "ExplicitBinds" "ExplicitBinds" 11 10 11 52 , mkTestNoAction "Puns" "Puns" 12 10 12 31 , mkTestNoAction "Infix" "Infix" 11 11 11 31 @@ -41,12 +42,16 @@ mkTestNoAction title fp x1 y1 x2 y2 = actions <- getExplicitFieldsActions doc x1 y1 x2 y2 liftIO $ actions @?= [] -mkTest :: TestName -> FilePath -> UInt -> UInt -> UInt -> UInt -> TestTree -mkTest title fp x1 y1 x2 y2 = +mkTestWithCount :: Int -> TestName -> FilePath -> UInt -> UInt -> UInt -> UInt -> TestTree +mkTestWithCount cnt title fp x1 y1 x2 y2 = goldenWithHaskellDoc plugin title testDataDir fp "expected" "hs" $ \doc -> do - (act:_) <- getExplicitFieldsActions doc x1 y1 x2 y2 + acts@(act:_) <- getExplicitFieldsActions doc x1 y1 x2 y2 + liftIO $ length acts @?= cnt executeCodeAction act +mkTest :: TestName -> FilePath -> UInt -> UInt -> UInt -> UInt -> TestTree +mkTest = mkTestWithCount 1 + getExplicitFieldsActions :: TextDocumentIdentifier -> UInt -> UInt -> UInt -> UInt diff --git a/plugins/hls-explicit-record-fields-plugin/test/testdata/3574.expected.hs b/plugins/hls-explicit-record-fields-plugin/test/testdata/3574.expected.hs new file mode 100644 index 0000000000..9dacc875f0 --- /dev/null +++ b/plugins/hls-explicit-record-fields-plugin/test/testdata/3574.expected.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE Haskell2010 #-} +{-# LANGUAGE RecordWildCards #-} +{-# Language OverloadedRecordDot #-} +{-# LANGUAGE NamedFieldPuns #-} +module Construction where + +data MyRec = MyRec + { foo :: Int + , bar :: Int + , baz :: Char + } + +convertMe :: () -> Int +convertMe _ = + let foo = 3 + bar = 5 + baz = 'a' + in MyRec {foo, bar, baz}.foo diff --git a/plugins/hls-explicit-record-fields-plugin/test/testdata/3574.hs b/plugins/hls-explicit-record-fields-plugin/test/testdata/3574.hs new file mode 100644 index 0000000000..ae802d4345 --- /dev/null +++ b/plugins/hls-explicit-record-fields-plugin/test/testdata/3574.hs @@ -0,0 +1,17 @@ +{-# LANGUAGE Haskell2010 #-} +{-# LANGUAGE RecordWildCards #-} +{-# Language OverloadedRecordDot #-} +module Construction where + +data MyRec = MyRec + { foo :: Int + , bar :: Int + , baz :: Char + } + +convertMe :: () -> Int +convertMe _ = + let foo = 3 + bar = 5 + baz = 'a' + in MyRec {..}.foo