Skip to content

[hls-explicit-record-fields-plugin] Expand used fields only #3386

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions ghcide/src/Development/IDE/GHC/Compat/Core.hs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,8 @@ module Development.IDE.GHC.Compat.Core (
noLocA,
unLocA,
LocatedAn,
LocatedA,
LocatedN,
#if MIN_VERSION_ghc(9,2,0)
GHC.AnnListItem(..),
GHC.NameAnn(..),
Expand Down Expand Up @@ -1031,6 +1033,19 @@ type LocatedAn a = GHC.LocatedAn a
type LocatedAn a = GHC.Located
#endif

#if MIN_VERSION_ghc(9,2,0)
type LocatedA = GHC.LocatedA
#else
type LocatedA = GHC.Located
#endif

#if MIN_VERSION_ghc(9,2,0)
type LocatedN = GHC.LocatedN
#else
type LocatedN = GHC.Located
#endif


#if MIN_VERSION_ghc(9,2,0)
locA :: SrcSpanAnn' a -> SrcSpan
locA = GHC.locA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ library
, transformers
, ghc-boot-th
, unordered-containers
, containers
, extra
hs-source-dirs: src
default-language: Haskell2010

Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,38 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}

module Ide.Plugin.ExplicitFields
( descriptor
, Log
) where

import Control.Lens ((^.))
import Control.Monad.Extra (maybeM)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Except (ExceptT)
import Data.Foldable (find)
import Data.Generics (GenericQ, everything, extQ,
mkQ)
import qualified Data.HashMap.Strict as HashMap
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Maybe (isJust, listToMaybe,
maybeToList)
import Data.Text (Text)
import Development.IDE (IdeState, NormalizedFilePath,
Pretty (..), Recorder (..),
Rules, WithPriority (..),
realSrcSpanToRange)
realSpan, realSrcSpanToRange)
import Development.IDE.Core.Rules (runAction)
import Development.IDE.Core.RuleTypes (TcModuleResult (..),
TypeCheck (..))
Expand All @@ -38,11 +44,15 @@ import Development.IDE.GHC.Compat (HsConDetails (RecCon),
import Development.IDE.GHC.Compat.Core (Extension (NamedFieldPuns),
GhcPass,
HsExpr (RecordCon, rcon_flds),
LHsExpr, Pass (..), Pat (..),
RealSrcSpan, conPatDetails,
hfbPun, hs_valds,
HsRecField, LHsExpr, LocatedA,
LocatedN, Name, Pass (..),
Pat (..), RealSrcSpan,
conPatDetails, getUnique,
hfbPun, hfbRHS, hs_valds,
mapConPatDetail, mapLoc,
nameSrcSpan, nameUnique,
pattern RealSrcSpan)
import Development.IDE.GHC.Compat.Util (Unique, nonDetCmpUnique)
import Development.IDE.GHC.Util (getExtensions,
printOutputable)
import Development.IDE.Graph (RuleResult)
Expand Down Expand Up @@ -137,23 +147,42 @@ codeActionProvider ideState pId (CodeActionParams _ _ docId range _) = pluginRes
title = "Expand record wildcard"

collectRecordsRule :: Recorder (WithPriority Log) -> Rules ()
collectRecordsRule recorder = define (cmapWithPrio LogShake recorder) $ \CollectRecords nfp -> do
tmr <- use TypeCheck nfp
let exts = getEnabledExtensions <$> tmr
recs = concat $ maybeToList (getRecords <$> tmr)
logWith recorder Debug (LogCollectedRecords recs)
let renderedRecs = traverse renderRecordInfo recs
recMap = RangeMap.fromList (realSrcSpanToRange . renderedSrcSpan) <$> renderedRecs
logWith recorder Debug (LogRenderedRecords (concat renderedRecs))
pure ([], CRR <$> recMap <*> exts)
collectRecordsRule recorder = define (cmapWithPrio LogShake recorder) $ \CollectRecords nfp ->
justOrFail "Unable to TypeCheck" (use TypeCheck nfp) $ \tmr -> do
let exts = getEnabledExtensions tmr
recs = getRecords tmr
logWith recorder Debug (LogCollectedRecords recs)
let names = getNames tmr
renderedRecs = traverse (renderRecordInfo names) recs
recMap = RangeMap.fromList (realSrcSpanToRange . renderedSrcSpan) <$> renderedRecs
logWith recorder Debug (LogRenderedRecords (concat renderedRecs))
pure ([], CRR <$> recMap <*> Just exts)

where
getEnabledExtensions :: TcModuleResult -> [GhcExtension]
getEnabledExtensions = map GhcExtension . getExtensions . tmrParsed

justOrFail :: MonadFail m => String -> m (Maybe a) -> (a -> m b) -> m b
justOrFail = flip . maybeM . fail

getRecords :: TcModuleResult -> [RecordInfo]
getRecords (tmrRenamed -> (hs_valds -> valBinds,_,_,_)) =
collectRecords valBinds

-- | Collects all 'Name's of a given source file, to be used
-- in the variable usage analysis.
getNames :: TcModuleResult -> Map UniqueKey [LocatedN Name]
getNames (tmrRenamed -> (group,_,_,_)) = collectNames group

newtype UniqueKey = UniqueKey Unique
deriving newtype Eq

getUniqueKey :: Name -> UniqueKey
getUniqueKey = UniqueKey . nameUnique

instance Ord UniqueKey where
compare (UniqueKey u1) (UniqueKey u2) = getUnique u1 `nonDetCmpUnique` getUnique u2

data CollectRecords = CollectRecords
deriving (Eq, Show, Generic)

Expand Down Expand Up @@ -199,9 +228,41 @@ instance Pretty RenderedRecordInfo where

instance NFData RenderedRecordInfo

renderRecordInfo :: RecordInfo -> Maybe RenderedRecordInfo
renderRecordInfo (RecordInfoPat ss pat) = RenderedRecordInfo ss <$> showRecordPat pat
renderRecordInfo (RecordInfoCon ss expr) = RenderedRecordInfo ss <$> showRecordCon expr
renderRecordInfo :: Map UniqueKey [LocatedN Name] -> RecordInfo -> Maybe RenderedRecordInfo
renderRecordInfo names (RecordInfoPat ss pat) = RenderedRecordInfo ss <$> showRecordPat names pat
renderRecordInfo _ (RecordInfoCon ss expr) = RenderedRecordInfo ss <$> showRecordCon expr

-- | Checks if a 'Name' is referenced in a given list of names. The 'Eq'
-- instance of 'Name's makes use of their unique identifiers, hence any
-- to 'Name' referring to the same entity is considered equal. In order
-- to ensure that no false-positive is reported (in the case where the
-- 'name' itself is part of the given list), the inequality of source
-- locations is also checked.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps an example? I couldn't quite figure out why this matters, and tbh it seems rather weird to have a match on the unique but not on the srcspan?

referencedIn :: Name -> Map UniqueKey [LocatedN Name] -> Bool
referencedIn name names = maybe True hasNameRef $ Map.lookup (getUniqueKey name) names
where
hasNameRef :: [LocatedN Name] -> Bool
hasNameRef = isJust . find (\n -> realSpan (getLoc n) /= realSpan (nameSrcSpan name))

-- Default to leaving the element in if somehow a name can't be extracted (i.e.
-- `getName` returns `Nothing`).
filterReferenced :: (a -> Maybe Name) -> Map UniqueKey [LocatedN Name] -> [a] -> [a]
filterReferenced getName names = filter (\x -> maybe True (`referencedIn` names) (getName x))

preprocessRecordPat
:: p ~ GhcPass 'Renamed
=> Map UniqueKey [LocatedN Name]
-> HsRecFields p (LPat p)
-> HsRecFields p (LPat p)
preprocessRecordPat = preprocessRecord (getFieldName . unLoc)
where
getFieldName x = case unLoc (hfbRHS x) of
VarPat _ x' -> Just $ unLoc x'
_ -> Nothing

-- No need to check the name usage in the record construction case
preprocessRecordCon :: HsRecFields (GhcPass c) arg -> HsRecFields (GhcPass c) arg
preprocessRecordCon = preprocessRecord (const Nothing) Map.empty

-- We make use of the `Outputable` instances on AST types to pretty-print
-- the renamed and expanded records back into source form, to be substituted
Expand All @@ -212,8 +273,13 @@ renderRecordInfo (RecordInfoCon ss expr) = RenderedRecordInfo ss <$> showRecordC
-- as we want to print the records in their fully expanded form.
-- Here `rec_dotdot` is set to `Nothing` so that fields are printed without
-- such post-processing.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this comment could do with some tweaks. It suggests that this is primarily about printing the records which makes it sounds like you always print them as they are, but now you're adding some additional logic to print something else depending on the fields in use.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have revised this and a bunch of other comments, can you take another look to see if it looks better?

preprocessRecord :: HsRecFields (GhcPass c) arg -> HsRecFields (GhcPass c) arg
preprocessRecord flds = flds { rec_dotdot = Nothing , rec_flds = rec_flds' }
preprocessRecord
:: p ~ GhcPass c
=> (LocatedA (HsRecField p arg) -> Maybe Name)
-> Map UniqueKey [LocatedN Name]
-> HsRecFields p arg
-> HsRecFields p arg
preprocessRecord getName names flds = flds { rec_dotdot = Nothing , rec_flds = rec_flds' }
where
no_pun_count = maybe (length (rec_flds flds)) unLoc (rec_dotdot flds)
-- Field binds of the explicit form (e.g. `{ a = a' }`) should be
Expand All @@ -223,29 +289,35 @@ preprocessRecord flds = flds { rec_dotdot = Nothing , rec_flds = rec_flds' }
-- puns (since there is similar mechanism in the `Outputable` instance as
-- explained above).
puns' = map (mapLoc (\fld -> fld { hfbPun = True })) puns
rec_flds' = no_puns <> puns'

showRecordPat :: Outputable (Pat (GhcPass c)) => Pat (GhcPass c) -> Maybe Text
showRecordPat = fmap printOutputable . mapConPatDetail (\case
RecCon flds -> Just $ RecCon (preprocessRecord flds)
-- Unused fields are filtered out so that they don't end up in the expanded
-- form.
punsUsed = filterReferenced getName names puns'
rec_flds' = no_puns <> punsUsed

showRecordPat :: Outputable (Pat (GhcPass 'Renamed)) => Map UniqueKey [LocatedN Name] -> Pat (GhcPass 'Renamed) -> Maybe Text
showRecordPat names = fmap printOutputable . mapConPatDetail (\case
RecCon flds -> Just $ RecCon (preprocessRecordPat names flds)
_ -> Nothing)

showRecordCon :: Outputable (HsExpr (GhcPass c)) => HsExpr (GhcPass c) -> Maybe Text
showRecordCon expr@(RecordCon _ _ flds) =
Just $ printOutputable $
expr { rcon_flds = preprocessRecord flds }
expr { rcon_flds = preprocessRecordCon flds }
showRecordCon _ = Nothing

collectRecords :: GenericQ [RecordInfo]
collectRecords = everything (<>) (maybeToList . (Nothing `mkQ` getRecPatterns `extQ` getRecCons))

collectNames :: GenericQ (Map UniqueKey [LocatedN Name])
collectNames = everything (Map.unionWith (<>)) (Map.empty `mkQ` (\x -> Map.singleton (getUniqueKey (unLoc x)) [x]))

getRecCons :: LHsExpr (GhcPass 'Renamed) -> Maybe RecordInfo
getRecCons e@(unLoc -> RecordCon _ _ flds)
| isJust (rec_dotdot flds) = mkRecInfo e
where
mkRecInfo :: LHsExpr (GhcPass 'Renamed) -> Maybe RecordInfo
mkRecInfo expr = listToMaybe
[ RecordInfoCon realSpan (unLoc expr) | RealSrcSpan realSpan _ <- [ getLoc expr ]]
[ RecordInfoCon realSpan' (unLoc expr) | RealSrcSpan realSpan' _ <- [ getLoc expr ]]
getRecCons _ = Nothing

getRecPatterns :: LPat (GhcPass 'Renamed) -> Maybe RecordInfo
Expand All @@ -254,7 +326,7 @@ getRecPatterns conPat@(conPatDetails . unLoc -> Just (RecCon flds))
where
mkRecInfo :: LPat (GhcPass 'Renamed) -> Maybe RecordInfo
mkRecInfo pat = listToMaybe
[ RecordInfoPat realSpan (unLoc pat) | RealSrcSpan realSpan _ <- [ getLoc pat ]]
[ RecordInfoPat realSpan' (unLoc pat) | RealSrcSpan realSpan' _ <- [ getLoc pat ]]
getRecPatterns _ = Nothing

collectRecords' :: MonadIO m => IdeState -> NormalizedFilePath -> ExceptT String m CollectRecordsResult
Expand Down
4 changes: 3 additions & 1 deletion plugins/hls-explicit-record-fields-plugin/test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ plugin = mkPluginTestDescriptor ExplicitFields.descriptor "explicit-fields"
test :: TestTree
test = testGroup "explicit-fields"
[ mkTest "WildcardOnly" "WildcardOnly" 12 10 12 20
, mkTest "Unused" "Unused" 12 10 12 20
, mkTest "Unused2" "Unused2" 12 10 12 20
, mkTest "WithPun" "WithPun" 13 10 13 25
, mkTest "WithExplicitBind" "WithExplicitBind" 12 10 12 32
, mkTest "Mixed" "Mixed" 13 10 13 37
, mkTest "Mixed" "Mixed" 14 10 14 37
, mkTest "Construction" "Construction" 16 5 16 15
, mkTestNoAction "ExplicitBinds" "ExplicitBinds" 11 10 11 52
, mkTestNoAction "Puns" "Puns" 12 10 12 31
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
, quux :: Double
}

convertMe :: MyRec -> String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
, quux :: Double
}

convertMe :: MyRec -> String
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{-# LANGUAGE Haskell2010 #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE NamedFieldPuns #-}

module Unused where

data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
}

convertMe :: MyRec -> String
convertMe MyRec {foo, bar} = show foo ++ show bar
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{-# LANGUAGE Haskell2010 #-}
{-# LANGUAGE RecordWildCards #-}

module Unused where

data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
}

convertMe :: MyRec -> String
convertMe MyRec {..} = show foo ++ show bar
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{-# LANGUAGE Haskell2010 #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE NamedFieldPuns #-}

module Unused2 where

data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
}

convertMe :: MyRec -> String
convertMe MyRec {foo, bar} = let baz = "baz" in show foo ++ show bar ++ baz
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{-# LANGUAGE Haskell2010 #-}
{-# LANGUAGE RecordWildCards #-}

module Unused2 where

data MyRec = MyRec
{ foo :: Int
, bar :: Int
, baz :: Char
}

convertMe :: MyRec -> String
convertMe MyRec {..} = let baz = "baz" in show foo ++ show bar ++ baz
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
{-# LANGUAGE CPP #-}
module Development.IDE.GHC.Dump(showAstDataHtml) where
import Data.Data hiding (Fixity)
import Development.IDE.GHC.Compat hiding (NameAnn)
import Development.IDE.GHC.Compat hiding (LocatedA,
NameAnn)
import Development.IDE.GHC.Compat.ExactPrint
import GHC.Hs.Dump
#if MIN_VERSION_ghc(9,2,1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ import GHC (AddEpAnn (Ad
DeltaPos (..),
EpAnn (..),
EpaLocation (..),
LEpaComment,
LocatedA)
LEpaComment)
#else
import Language.Haskell.GHC.ExactPrint.Types (Annotation (annsDP),
DeltaPos,
Expand Down Expand Up @@ -1535,11 +1534,7 @@ findPositionNoImports ps fileContents =

-- | find line number right after module ... where
findPositionAfterModuleName :: Annotated ParsedSource
#if MIN_VERSION_ghc(9,2,0)
-> LocatedA ModuleName
#else
-> Located ModuleName
#endif
-> Maybe Int
findPositionAfterModuleName ps hsmodName' = do
-- Note that 'where' keyword and comments are not part of the AST. They belongs to
Expand Down