diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs index 28a3bf8274..db20420ede 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs @@ -1,12 +1,16 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} module Ide.Plugin.Tactic.CodeGen where +import Control.Lens ((+~), (%~), (<>~)) import Control.Monad.Except import Control.Monad.State (MonadState) import Control.Monad.State.Class (modify) +import Data.Generics.Product (field) import Data.List import qualified Data.Map as M import qualified Data.Set as S @@ -31,10 +35,25 @@ useOccName :: MonadState TacticState m => Judgement -> OccName -> m () useOccName jdg name = -- Only score points if this is in the local hypothesis case M.lookup name $ jLocalHypothesis jdg of - Just{} -> modify $ withUsedVals $ S.insert name + Just{} -> modify + $ (withUsedVals $ S.insert name) + . (field @"ts_unused_top_vals" %~ S.delete name) Nothing -> pure () +------------------------------------------------------------------------------ +-- | Doing recursion incurs a small penalty in the score. +penalizeRecursion :: MonadState TacticState m => m () +penalizeRecursion = modify $ field @"ts_recursion_penality" +~ 1 + + +------------------------------------------------------------------------------ +-- | Insert some values into the unused top values field. These are +-- subsequently removed via 'useOccName'. +addUnusedTopVals :: MonadState TacticState m => S.Set OccName -> m () +addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals + + destructMatches :: (DataCon -> Judgement -> Rule) -- ^ How to construct each match diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs index 743448dc64..3beb40daa4 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs @@ -75,6 +75,13 @@ introducing ns = field @"_jHypothesis" <>~ M.fromList ns +------------------------------------------------------------------------------ +-- | Add some terms to the ambient hypothesis +introducingAmbient :: [(OccName, a)] -> Judgement' a -> Judgement' a +introducingAmbient ns = + field @"_jAmbientHypothesis" <>~ M.fromList ns + + filterPosition :: OccName -> Int -> Judgement -> Judgement filterPosition defn pos jdg = withHypothesis (M.filterWithKey go) jdg diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs index 94850fa4e0..25bf3e5c62 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs @@ -27,9 +27,10 @@ import Control.Monad.State.Class (gets, modify) import Control.Monad.State.Strict (StateT (..)) import Data.Coerce import Data.Either +import Data.Foldable import Data.Functor ((<&>)) import Data.Generics (mkQ, everything, gcount) -import Data.List (sortBy) +import Data.List (nub, sortBy) import Data.Ord (comparing, Down(..)) import qualified Data.Set as S import Development.IDE.GHC.Compat @@ -71,7 +72,12 @@ runTactic -> Either [TacticError] RunTacticResults runTactic ctx jdg t = let skolems = tyCoVarsOfTypeWellScoped $ unCType $ jGoal jdg - tacticState = defaultTacticState { ts_skolems = skolems } + unused_topvals = nub $ join $ join $ toList $ _jPositionMaps jdg + tacticState = + defaultTacticState + { ts_skolems = skolems + , ts_unused_top_vals = S.fromList unused_topvals + } in case partitionEithers . flip runReader ctx . unExtractM @@ -126,21 +132,31 @@ setRecursionFrameData b = do [] -> [] +------------------------------------------------------------------------------ +-- | Given the results of running a tactic, score the solutions by +-- desirability. +-- +-- TODO(sandy): This function is completely unprincipled and was just hacked +-- together to produce the right test results. scoreSolution :: LHsExpr GhcPs -> TacticState -> [Judgement] -> ( Penalize Int -- number of holes , Reward Bool -- all bindings used + , Penalize Int -- unused top-level bindings , Penalize Int -- number of introduced bindings , Reward Int -- number used bindings + , Penalize Int -- number of recursive calls , Penalize Int -- size of extract ) scoreSolution ext TacticState{..} holes = ( Penalize $ length holes , Reward $ S.null $ ts_intro_vals S.\\ ts_used_vals + , Penalize $ S.size ts_unused_top_vals , Penalize $ S.size ts_intro_vals , Reward $ S.size ts_used_vals + , Penalize $ ts_recursion_penality , Penalize $ solutionSize ext ) diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs index 4a6389ec9f..f1c2a6d220 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs @@ -69,8 +69,9 @@ recursion = requireConcreteHole $ tracing "recursion" $ do defs <- getCurrentDefinitions attemptOn (const $ fmap fst defs) $ \name -> do modify $ withRecursionStack (False :) + penalizeRecursion ensure recursiveCleanup (withRecursionStack tail) $ do - (localTactic (apply name) $ introducing defs) + (localTactic (apply name) $ introducingAmbient defs) <@> fmap (localTactic assumption . filterPosition name) [0..] @@ -88,6 +89,7 @@ intros = rule $ \jdg -> do let jdg' = introducing (zip vs $ coerce as) $ withNewGoal (CType b) jdg modify $ withIntroducedVals $ mappend $ S.fromList vs + when (isTopHole jdg) $ addUnusedTopVals $ S.fromList vs (tr, sg) <- newSubgoal $ bool diff --git a/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs b/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs index 2d7299a380..6b4201b49a 100644 --- a/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs +++ b/plugins/tactics/src/Ide/Plugin/Tactic/Types.hs @@ -74,10 +74,21 @@ instance Show DataCon where ------------------------------------------------------------------------------ data TacticState = TacticState { ts_skolems :: !([TyVar]) + -- ^ The known skolems. , ts_unifier :: !(TCvSubst) + -- ^ The current substitution of univars. , ts_used_vals :: !(Set OccName) + -- ^ Set of values used by tactics. , ts_intro_vals :: !(Set OccName) + -- ^ Set of values introduced by tactics. + , ts_unused_top_vals :: !(Set OccName) + -- ^ Set of currently unused arguments to the function being defined. , ts_recursion_stack :: ![Bool] + -- ^ Stack for tracking whether or not the current recursive call has + -- used at least one smaller pat val. Recursive calls for which this + -- value is 'False' are guaranteed to loop, and must be pruned. + , ts_recursion_penality :: !Int + -- ^ Number of calls to recursion. We penalize each. , ts_unique_gen :: !UniqSupply } deriving stock (Show, Generic) @@ -100,7 +111,9 @@ defaultTacticState = , ts_unifier = emptyTCvSubst , ts_used_vals = mempty , ts_intro_vals = mempty + , ts_unused_top_vals = mempty , ts_recursion_stack = mempty + , ts_recursion_penality = 0 , ts_unique_gen = unsafeDefaultUniqueSupply } diff --git a/test/functional/Tactic.hs b/test/functional/Tactic.hs index eb31c58327..2f3b05d31a 100644 --- a/test/functional/Tactic.hs +++ b/test/functional/Tactic.hs @@ -109,6 +109,7 @@ tests = testGroup , goldenTest "GoldenSuperclass.hs" 7 8 Auto "" , ignoreTestBecause "It is unreliable in circleci builds" $ goldenTest "GoldenApplicativeThen.hs" 2 11 Auto "" + , goldenTest "GoldenSafeHead.hs" 2 12 Auto "" ] diff --git a/test/testdata/tactic/GoldenSafeHead.hs b/test/testdata/tactic/GoldenSafeHead.hs new file mode 100644 index 0000000000..6a5d27c0d1 --- /dev/null +++ b/test/testdata/tactic/GoldenSafeHead.hs @@ -0,0 +1,2 @@ +safeHead :: [x] -> Maybe x +safeHead = _ diff --git a/test/testdata/tactic/GoldenSafeHead.hs.expected b/test/testdata/tactic/GoldenSafeHead.hs.expected new file mode 100644 index 0000000000..7a404f1d4e --- /dev/null +++ b/test/testdata/tactic/GoldenSafeHead.hs.expected @@ -0,0 +1,5 @@ +safeHead :: [x] -> Maybe x +safeHead = (\ l_x + -> case l_x of + [] -> Nothing + (x : l_x2) -> Just x)