diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs index c73b6090ff..750743dca1 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs @@ -1,4 +1,5 @@ {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} @@ -7,9 +8,10 @@ module Ide.Plugin.Tactic.CodeGen , module Ide.Plugin.Tactic.CodeGen.Utils ) where -import Control.Lens ((+~)) + +import Control.Lens ((%~), (<>~), (&)) import Control.Monad.Except -import Data.Generics.Product (field) +import Data.Generics.Labels () import Data.List import qualified Data.Set as S import Data.Traversable @@ -29,13 +31,6 @@ import Ide.Plugin.Tactic.Types import Type hiding (Var) - ------------------------------------------------------------------------------- --- | Doing recursion incurs a small penalty in the score. -countRecursiveCall :: TacticState -> TacticState -countRecursiveCall = field @"ts_recursion_count" +~ 1 - - destructMatches :: (DataCon -> Judgement -> Rule) -- ^ How to construct each match @@ -62,16 +57,12 @@ destructMatches f scrut t jdg = do $ coerce args j = introduce hy' $ withNewGoal g jdg - Synthesized tr sc uv sg <- f dc j - pure - $ Synthesized - ( rose ("match " <> show dc <> " {" <> - intercalate ", " (fmap show names) <> "}") - $ pure tr) - (sc <> hy') - uv - $ match [mkDestructPat dc names] - $ unLoc sg + ext <- f dc j + pure $ ext + & #syn_trace %~ rose ("match " <> show dc <> " {" <> intercalate ", " (fmap show names) <> "}") + . pure + & #syn_scoped <>~ hy' + & #syn_val %~ match [mkDestructPat dc names] . unLoc ------------------------------------------------------------------------------ @@ -138,19 +129,16 @@ destruct' :: (DataCon -> Judgement -> Rule) -> HyInfo CType -> Judgement -> Rule destruct' f hi jdg = do when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic let term = hi_name hi - Synthesized tr sc uv ms + ext <- destructMatches f (Just term) (hi_type hi) $ disallowing AlreadyDestructed [term] jdg - pure - $ Synthesized - (rose ("destruct " <> show term) $ pure tr) - sc - (S.insert term uv) - $ noLoc - $ case' (var' term) ms + pure $ ext + & #syn_trace %~ rose ("destruct " <> show term) . pure + & #syn_used_vals %~ S.insert term + & #syn_val %~ noLoc . case' (var' term) ------------------------------------------------------------------------------ @@ -176,7 +164,7 @@ buildDataCon -> RuleM (Synthesized (LHsExpr GhcPs)) buildDataCon jdg dc tyapps = do let args = dataConInstOrigArgTys' dc tyapps - Synthesized tr sc uv sgs + ext <- fmap unzipTrace $ traverse ( \(arg, n) -> newSubgoal @@ -185,7 +173,7 @@ buildDataCon jdg dc tyapps = do . flip withNewGoal jdg $ CType arg ) $ zip args [0..] - pure - $ Synthesized (rose (show dc) $ pure tr) sc uv - $ mkCon dc sgs + pure $ ext + & #syn_trace %~ rose (show dc) . pure + & #syn_val %~ mkCon dc diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs index 25ba3b0832..a61f86dbce 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs @@ -50,6 +50,7 @@ deriveArbitrary = do -- But maybe it's fine for known rules? mempty mempty + mempty $ noLoc $ let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $ appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $ diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs index bdaa0aa77f..a4569af5b9 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs @@ -6,36 +6,35 @@ {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} -module Ide.Plugin.Tactic.Machinery - ( module Ide.Plugin.Tactic.Machinery - ) where +module Ide.Plugin.Tactic.Machinery where -import Class (Class (classTyVars)) -import Control.Arrow +import Class (Class (classTyVars)) +import Control.Lens ((<>~)) import Control.Monad.Error.Class import Control.Monad.Reader -import Control.Monad.State (MonadState (..)) -import Control.Monad.State.Class (gets, modify) -import Control.Monad.State.Strict (StateT (..)) -import Data.Bool (bool) +import Control.Monad.State.Class (gets, modify) +import Control.Monad.State.Strict (StateT (..)) +import Data.Bool (bool) import Data.Coerce import Data.Either import Data.Foldable -import Data.Functor ((<&>)) -import Data.Generics (everything, gcount, mkQ) -import Data.List (sortBy) -import qualified Data.Map as M -import Data.Ord (Down (..), comparing) -import Data.Set (Set) -import qualified Data.Set as S +import Data.Functor ((<&>)) +import Data.Generics (everything, gcount, mkQ) +import Data.Generics.Product (field') +import Data.List (sortBy) +import qualified Data.Map as M +import Data.Monoid (getSum) +import Data.Ord (Down (..), comparing) +import Data.Set (Set) +import qualified Data.Set as S import Development.IDE.GHC.Compat import Ide.Plugin.Tactic.Judgements -import Ide.Plugin.Tactic.Simplify (simplify) +import Ide.Plugin.Tactic.Simplify (simplify) import Ide.Plugin.Tactic.Types -import OccName (HasOccName (occName)) +import OccName (HasOccName (occName)) import Refinery.ProofState import Refinery.Tactic import Refinery.Tactic.Internal @@ -88,8 +87,8 @@ runTactic ctx jdg t = (errs, []) -> Left $ take 50 errs (_, fmap assoc23 -> solns) -> do let sorted = - flip sortBy solns $ comparing $ \(ext, (jdg, holes)) -> - Down $ scoreSolution ext jdg holes + flip sortBy solns $ comparing $ \(ext, (_, holes)) -> + Down $ scoreSolution ext holes case sorted of ((syn, _) : _) -> Right $ @@ -111,39 +110,37 @@ tracePrim :: String -> Trace tracePrim = flip rose [] +------------------------------------------------------------------------------ +-- | Mark that a tactic used the given string in its extract derivation. Mainly +-- used for debugging the search when things go terribly wrong. tracing :: Functor m => String -> TacticT jdg (Synthesized ext) err s m a -> TacticT jdg (Synthesized ext) err s m a -tracing s (TacticT m) - = TacticT $ StateT $ \jdg -> - mapExtract' (mapTrace $ rose s . pure) $ runStateT m jdg +tracing s = mappingExtract (mapTrace $ rose s . pure) ------------------------------------------------------------------------------ --- | Recursion is allowed only when we can prove it is on a structurally --- smaller argument. The top of the 'ts_recursion_stack' witnesses the smaller --- pattern val. -guardStructurallySmallerRecursion - :: TacticState - -> Maybe TacticError -guardStructurallySmallerRecursion s = - case head $ ts_recursion_stack s of - Just _ -> Nothing - Nothing -> Just NoProgress +-- | Mark that a tactic performed recursion. Doing so incurs a small penalty in +-- the score. +markRecursion + :: Functor m + => TacticT jdg (Synthesized ext) err s m a + -> TacticT jdg (Synthesized ext) err s m a +markRecursion = mappingExtract (field' @"syn_recursion_count" <>~ 1) ------------------------------------------------------------------------------ --- | Mark that the current recursive call is structurally smaller, due to --- having been matched on a pattern value. --- --- Implemented by setting the top of the 'ts_recursion_stack'. -markStructuralySmallerRecursion :: MonadState TacticState m => PatVal -> m () -markStructuralySmallerRecursion pv = do - modify $ withRecursionStack $ \case - (_ : bs) -> Just pv : bs - [] -> [] +-- | Map a function over the extract created by a tactic. +mappingExtract + :: Functor m + => (ext -> ext) + -> TacticT jdg ext err s m a + -> TacticT jdg ext err s m a +mappingExtract f (TacticT m) + = TacticT $ StateT $ \jdg -> + mapExtract' f $ runStateT m jdg ------------------------------------------------------------------------------ @@ -154,7 +151,6 @@ markStructuralySmallerRecursion pv = do -- to produce the right test results. scoreSolution :: Synthesized (LHsExpr GhcPs) - -> TacticState -> [Judgement] -> ( Penalize Int -- number of holes , Reward Bool -- all bindings used @@ -164,19 +160,23 @@ scoreSolution , Penalize Int -- number of recursive calls , Penalize Int -- size of extract ) -scoreSolution ext TacticState{..} holes +scoreSolution ext holes = ( Penalize $ length holes , Reward $ S.null $ intro_vals S.\\ used_vals , Penalize $ S.size unused_top_vals , Penalize $ S.size intro_vals , Reward $ S.size used_vals - , Penalize ts_recursion_count + , Penalize $ getSum $ syn_recursion_count ext , Penalize $ solutionSize $ syn_val ext ) where intro_vals = M.keysSet $ hyByName $ syn_scoped ext used_vals = S.intersection intro_vals $ syn_used_vals ext - top_vals = S.fromList . fmap hi_name . filter (isTopLevel . hi_provenance) $ unHypothesis $ syn_scoped ext + top_vals = S.fromList + . fmap hi_name + . filter (isTopLevel . hi_provenance) + . unHypothesis + $ syn_scoped ext unused_top_vals = top_vals S.\\ used_vals @@ -240,6 +240,26 @@ methodHypothesis ty = do ) +------------------------------------------------------------------------------ +-- | Mystical time-traveling combinator for inspecting the extracts produced by +-- a tactic. We can use it to guard that extracts match certain predicates, for +-- example. +-- +-- Note, that this thing is WEIRD. To illustrate: +-- +-- @@ +-- peek f +-- blah +-- @@ +-- +-- Here, @f@ can inspect the extract _produced by @blah@,_ which means the +-- causality appears to go backwards. +-- +-- 'peek' should be exposed directly by @refinery@ in the next release. +peek :: (ext -> TacticT jdg ext err s m ()) -> TacticT jdg ext err s m () +peek k = tactic $ \j -> Subgoal ((), j) $ \e -> proofState (k e) j + + ------------------------------------------------------------------------------ -- | Run the given tactic iff the current hole contains no univars. Skolems and -- already decided univars are OK though. @@ -251,3 +271,4 @@ requireConcreteHole m = do case S.size $ vars S.\\ skolems of 0 -> m _ -> throwError TooPolymorphic + diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Tactics.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Tactics.hs index b388a4cee9..476a6c3232 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Tactics.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Tactics.hs @@ -1,19 +1,24 @@ {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE OverloadedLabels #-} {-# LANGUAGE TupleSections #-} {-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TypeApplications #-} module Ide.Plugin.Tactic.Tactics ( module Ide.Plugin.Tactic.Tactics , runTactic ) where +import Control.Applicative (Alternative(empty)) +import Control.Lens ((&), (%~)) +import Control.Monad (unless) import Control.Monad.Except (throwError) import Control.Monad.Reader.Class (MonadReader (ask)) -import Control.Monad.State.Class import Control.Monad.State.Strict (StateT(..), runStateT) import Data.Foldable +import Data.Generics.Labels () import Data.List import qualified Data.Map as M import Data.Maybe @@ -51,28 +56,33 @@ assume name = rule $ \jdg -> do case M.lookup name $ hyByName $ jHypothesis jdg of Just (hi_type -> ty) -> do unify ty $ jGoal jdg - for_ (M.lookup name $ jPatHypothesis jdg) markStructuralySmallerRecursion - pure $ Synthesized (tracePrim $ "assume " <> occNameString name) - mempty - (S.singleton name) - $ noLoc - $ var' name + pure $ + -- This slightly terrible construct is producing a mostly-empty + -- 'Synthesized'; but there is no monoid instance to do something more + -- reasonable for a default value. + (pure (noLoc $ var' name)) + { syn_trace = tracePrim $ "assume " <> occNameString name + , syn_used_vals = S.singleton name + } Nothing -> throwError $ UndefinedHypothesis name recursion :: TacticsM () recursion = requireConcreteHole $ tracing "recursion" $ do defs <- getCurrentDefinitions - attemptOn (const defs) $ \(name, ty) -> do - -- TODO(sandy): When we can inspect the extract of a TacticsM bind - -- (requires refinery support), this recursion stack stuff is unnecessary. - -- We can just inspect the extract to see i we used any pattern vals, and - -- then be on our merry way. - modify $ pushRecursionStack . countRecursiveCall - ensure guardStructurallySmallerRecursion popRecursionStack $ do - let hy' = recursiveHypothesis defs - localTactic (apply $ HyInfo name RecursivePrv ty) (introduce hy') - <@> fmap (localTactic assumption . filterPosition name) [0..] + attemptOn (const defs) $ \(name, ty) -> markRecursion $ do + -- Peek allows us to look at the extract produced by this block. + peek $ \ext -> do + jdg <- goal + let pat_vals = jPatHypothesis jdg + -- Make sure that the recursive call contains at least one already-bound + -- pattern value. This ensures it is structurally smaller, and thus + -- suggests termination. + unless (any (flip M.member pat_vals) $ syn_used_vals ext) empty + + let hy' = recursiveHypothesis defs + localTactic (apply $ HyInfo name RecursivePrv ty) (introduce hy') + <@> fmap (localTactic assumption . filterPosition name) [0..] ------------------------------------------------------------------------------ @@ -89,15 +99,12 @@ intros = rule $ \jdg -> do hy' = lambdaHypothesis top_hole $ zip vs $ coerce as jdg' = introduce hy' $ withNewGoal (CType b) jdg - Synthesized tr sc uv sg <- newSubgoal jdg' - pure - . Synthesized - (rose ("intros {" <> intercalate ", " (fmap show vs) <> "}") $ pure tr) - (sc <> hy') - uv - . noLoc - . lambda (fmap bvar' vs) - $ unLoc sg + ext <- newSubgoal jdg' + pure $ + ext + & #syn_trace %~ rose ("intros {" <> intercalate ", " (fmap show vs) <> "}") + . pure + & #syn_val %~ noLoc . lambda (fmap bvar' vs) . unLoc ------------------------------------------------------------------------------ @@ -164,16 +171,17 @@ apply hi = requireConcreteHole $ tracing ("apply' " <> show (hi_name hi)) $ do -- see https://github.com/haskell/haskell-language-server/issues/1447 requireNewHoles $ rule $ \jdg -> do unify g (CType ret) - Synthesized tr sc uv sgs + ext <- fmap unzipTrace $ traverse ( newSubgoal . blacklistingDestruct . flip withNewGoal jdg . CType ) args - pure $ Synthesized tr sc (S.insert func uv) - $ noLoc . foldl' (@@) (var' func) - $ fmap unLoc sgs + pure $ + ext + & #syn_used_vals %~ S.insert func + & #syn_val %~ noLoc . foldl' (@@) (var' func) . fmap unLoc ------------------------------------------------------------------------------ diff --git a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs index 0ea9c81c8d..92a1f66c88 100644 --- a/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs +++ b/plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Types.hs @@ -89,21 +89,6 @@ data TacticState = TacticState { ts_skolems :: !(Set TyVar) -- ^ The known skolems. , ts_unifier :: !TCvSubst - -- ^ The current substitution of univars. - , ts_recursion_stack :: ![Maybe PatVal] - -- ^ Stack for tracking whether or not the current recursive call has - -- used at least one smaller pat val. Recursive calls for which this - -- value is 'False' are guaranteed to loop, and must be pruned. - -- - -- TODO(sandy): This thing need not exist; we should just inspect - -- 'syn_used_vals' to see if anything was a pattern val. - , ts_recursion_count :: !Int - -- ^ Number of calls to recursion. We penalize each. - -- - -- TODO(sandy): This thing need not exist; it should just be a field - -- inside of 'Synthesized', but can't implement that without support from - -- refinery directly. Need the ability to get the extract of a TacticT - -- inside of TacticT, first. , ts_unique_gen :: !UniqSupply } deriving stock (Show, Generic) @@ -124,8 +109,6 @@ defaultTacticState = TacticState { ts_skolems = mempty , ts_unifier = emptyTCvSubst - , ts_recursion_stack = mempty - , ts_recursion_count = 0 , ts_unique_gen = unsafeDefaultUniqueSupply } @@ -139,18 +122,6 @@ freshUnique = do pure uniq -withRecursionStack - :: ([Maybe PatVal] -> [Maybe PatVal]) -> TacticState -> TacticState -withRecursionStack f = - field @"ts_recursion_stack" %~ f - -pushRecursionStack :: TacticState -> TacticState -pushRecursionStack = withRecursionStack (Nothing :) - -popRecursionStack :: TacticState -> TacticState -popRecursionStack = withRecursionStack tail - - ------------------------------------------------------------------------------ -- | Describes where hypotheses came from. Used extensively to prune stupid -- solutions from the search space. @@ -261,11 +232,7 @@ newtype ExtractM a = ExtractM { unExtractM :: Reader Context a } ------------------------------------------------------------------------------ -- | Orphan instance for producing holes when attempting to solve tactics. instance MonadExtract (Synthesized (LHsExpr GhcPs)) ExtractM where - hole - = pure - . Synthesized mempty mempty mempty - . noLoc - $ var "_" + hole = pure . pure . noLoc $ var "_" ------------------------------------------------------------------------------ @@ -344,12 +311,14 @@ data Synthesized a = Synthesized -- ^ All of the bindings created to produce the 'syn_val'. , syn_used_vals :: Set OccName -- ^ The values used when synthesizing the 'syn_val'. + , syn_recursion_count :: Sum Int + -- ^ The number of recursive calls , syn_val :: a } - deriving (Eq, Show, Functor, Foldable, Traversable) + deriving (Eq, Show, Functor, Foldable, Traversable, Generic) mapTrace :: (Trace -> Trace) -> Synthesized a -> Synthesized a -mapTrace f (Synthesized tr sc uv a) = Synthesized (f tr) sc uv a +mapTrace f (Synthesized tr sc uv rc a) = Synthesized (f tr) sc uv rc a ------------------------------------------------------------------------------ @@ -357,9 +326,9 @@ mapTrace f (Synthesized tr sc uv a) = Synthesized (f tr) sc uv a -- lawful. But that's only for debug output, so it's not anything I'm concerned -- about. instance Applicative Synthesized where - pure = Synthesized mempty mempty mempty - Synthesized tr1 sc1 uv1 f <*> Synthesized tr2 sc2 uv2 a = - Synthesized (tr1 <> tr2) (sc1 <> sc2) (uv1 <> uv2) $ f a + pure = Synthesized mempty mempty mempty mempty + Synthesized tr1 sc1 uv1 rc1 f <*> Synthesized tr2 sc2 uv2 rc2 a = + Synthesized (tr1 <> tr2) (sc1 <> sc2) (uv1 <> uv2) (rc1 <> rc2) $ f a ------------------------------------------------------------------------------