Skip to content

Commit 76dadaa

Browse files
Remove recursion tracking from TacticState (#1453)
* Remove recursion from TacticState * Use lenses to simplify Synthesize changes * Minor cleanup and haddock * Commentary on 'recursion' Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent eff69a7 commit 76dadaa

File tree

5 files changed

+134
-147
lines changed

5 files changed

+134
-147
lines changed

plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/CodeGen.hs

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
{-# LANGUAGE FlexibleContexts #-}
2+
{-# LANGUAGE OverloadedLabels #-}
23
{-# LANGUAGE TupleSections #-}
34
{-# LANGUAGE TypeApplications #-}
45

@@ -7,9 +8,10 @@ module Ide.Plugin.Tactic.CodeGen
78
, module Ide.Plugin.Tactic.CodeGen.Utils
89
) where
910

10-
import Control.Lens ((+~))
11+
12+
import Control.Lens ((%~), (<>~), (&))
1113
import Control.Monad.Except
12-
import Data.Generics.Product (field)
14+
import Data.Generics.Labels ()
1315
import Data.List
1416
import qualified Data.Set as S
1517
import Data.Traversable
@@ -29,13 +31,6 @@ import Ide.Plugin.Tactic.Types
2931
import Type hiding (Var)
3032

3133

32-
33-
------------------------------------------------------------------------------
34-
-- | Doing recursion incurs a small penalty in the score.
35-
countRecursiveCall :: TacticState -> TacticState
36-
countRecursiveCall = field @"ts_recursion_count" +~ 1
37-
38-
3934
destructMatches
4035
:: (DataCon -> Judgement -> Rule)
4136
-- ^ How to construct each match
@@ -62,16 +57,12 @@ destructMatches f scrut t jdg = do
6257
$ coerce args
6358
j = introduce hy'
6459
$ withNewGoal g jdg
65-
Synthesized tr sc uv sg <- f dc j
66-
pure
67-
$ Synthesized
68-
( rose ("match " <> show dc <> " {" <>
69-
intercalate ", " (fmap show names) <> "}")
70-
$ pure tr)
71-
(sc <> hy')
72-
uv
73-
$ match [mkDestructPat dc names]
74-
$ unLoc sg
60+
ext <- f dc j
61+
pure $ ext
62+
& #syn_trace %~ rose ("match " <> show dc <> " {" <> intercalate ", " (fmap show names) <> "}")
63+
. pure
64+
& #syn_scoped <>~ hy'
65+
& #syn_val %~ match [mkDestructPat dc names] . unLoc
7566

7667

7768
------------------------------------------------------------------------------
@@ -138,19 +129,16 @@ destruct' :: (DataCon -> Judgement -> Rule) -> HyInfo CType -> Judgement -> Rule
138129
destruct' f hi jdg = do
139130
when (isDestructBlacklisted jdg) $ throwError NoApplicableTactic
140131
let term = hi_name hi
141-
Synthesized tr sc uv ms
132+
ext
142133
<- destructMatches
143134
f
144135
(Just term)
145136
(hi_type hi)
146137
$ disallowing AlreadyDestructed [term] jdg
147-
pure
148-
$ Synthesized
149-
(rose ("destruct " <> show term) $ pure tr)
150-
sc
151-
(S.insert term uv)
152-
$ noLoc
153-
$ case' (var' term) ms
138+
pure $ ext
139+
& #syn_trace %~ rose ("destruct " <> show term) . pure
140+
& #syn_used_vals %~ S.insert term
141+
& #syn_val %~ noLoc . case' (var' term)
154142

155143

156144
------------------------------------------------------------------------------
@@ -176,7 +164,7 @@ buildDataCon
176164
-> RuleM (Synthesized (LHsExpr GhcPs))
177165
buildDataCon jdg dc tyapps = do
178166
let args = dataConInstOrigArgTys' dc tyapps
179-
Synthesized tr sc uv sgs
167+
ext
180168
<- fmap unzipTrace
181169
$ traverse ( \(arg, n) ->
182170
newSubgoal
@@ -185,7 +173,7 @@ buildDataCon jdg dc tyapps = do
185173
. flip withNewGoal jdg
186174
$ CType arg
187175
) $ zip args [0..]
188-
pure
189-
$ Synthesized (rose (show dc) $ pure tr) sc uv
190-
$ mkCon dc sgs
176+
pure $ ext
177+
& #syn_trace %~ rose (show dc) . pure
178+
& #syn_val %~ mkCon dc
191179

plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/KnownStrategies/QuickCheck.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ deriveArbitrary = do
5050
-- But maybe it's fine for known rules?
5151
mempty
5252
mempty
53+
mempty
5354
$ noLoc $
5455
let' [valBind (fromString "terminal") $ list $ fmap genExpr terminal] $
5556
appDollar (mkFunc "sized") $ lambda [bvar' (mkVarOcc "n")] $

plugins/hls-tactics-plugin/src/Ide/Plugin/Tactic/Machinery.hs

Lines changed: 68 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,35 @@
66
{-# LANGUAGE MultiParamTypeClasses #-}
77
{-# LANGUAGE RecordWildCards #-}
88
{-# LANGUAGE ScopedTypeVariables #-}
9+
{-# LANGUAGE TypeApplications #-}
910
{-# LANGUAGE ViewPatterns #-}
10-
{-# OPTIONS_GHC -fno-warn-orphans #-}
1111

12-
module Ide.Plugin.Tactic.Machinery
13-
( module Ide.Plugin.Tactic.Machinery
14-
) where
12+
module Ide.Plugin.Tactic.Machinery where
1513

16-
import Class (Class (classTyVars))
17-
import Control.Arrow
14+
import Class (Class (classTyVars))
15+
import Control.Lens ((<>~))
1816
import Control.Monad.Error.Class
1917
import Control.Monad.Reader
20-
import Control.Monad.State (MonadState (..))
21-
import Control.Monad.State.Class (gets, modify)
22-
import Control.Monad.State.Strict (StateT (..))
23-
import Data.Bool (bool)
18+
import Control.Monad.State.Class (gets, modify)
19+
import Control.Monad.State.Strict (StateT (..))
20+
import Data.Bool (bool)
2421
import Data.Coerce
2522
import Data.Either
2623
import Data.Foldable
27-
import Data.Functor ((<&>))
28-
import Data.Generics (everything, gcount, mkQ)
29-
import Data.List (sortBy)
30-
import qualified Data.Map as M
31-
import Data.Ord (Down (..), comparing)
32-
import Data.Set (Set)
33-
import qualified Data.Set as S
24+
import Data.Functor ((<&>))
25+
import Data.Generics (everything, gcount, mkQ)
26+
import Data.Generics.Product (field')
27+
import Data.List (sortBy)
28+
import qualified Data.Map as M
29+
import Data.Monoid (getSum)
30+
import Data.Ord (Down (..), comparing)
31+
import Data.Set (Set)
32+
import qualified Data.Set as S
3433
import Development.IDE.GHC.Compat
3534
import Ide.Plugin.Tactic.Judgements
36-
import Ide.Plugin.Tactic.Simplify (simplify)
35+
import Ide.Plugin.Tactic.Simplify (simplify)
3736
import Ide.Plugin.Tactic.Types
38-
import OccName (HasOccName (occName))
37+
import OccName (HasOccName (occName))
3938
import Refinery.ProofState
4039
import Refinery.Tactic
4140
import Refinery.Tactic.Internal
@@ -88,8 +87,8 @@ runTactic ctx jdg t =
8887
(errs, []) -> Left $ take 50 errs
8988
(_, fmap assoc23 -> solns) -> do
9089
let sorted =
91-
flip sortBy solns $ comparing $ \(ext, (jdg, holes)) ->
92-
Down $ scoreSolution ext jdg holes
90+
flip sortBy solns $ comparing $ \(ext, (_, holes)) ->
91+
Down $ scoreSolution ext holes
9392
case sorted of
9493
((syn, _) : _) ->
9594
Right $
@@ -111,39 +110,37 @@ tracePrim :: String -> Trace
111110
tracePrim = flip rose []
112111

113112

113+
------------------------------------------------------------------------------
114+
-- | Mark that a tactic used the given string in its extract derivation. Mainly
115+
-- used for debugging the search when things go terribly wrong.
114116
tracing
115117
:: Functor m
116118
=> String
117119
-> TacticT jdg (Synthesized ext) err s m a
118120
-> TacticT jdg (Synthesized ext) err s m a
119-
tracing s (TacticT m)
120-
= TacticT $ StateT $ \jdg ->
121-
mapExtract' (mapTrace $ rose s . pure) $ runStateT m jdg
121+
tracing s = mappingExtract (mapTrace $ rose s . pure)
122122

123123

124124
------------------------------------------------------------------------------
125-
-- | Recursion is allowed only when we can prove it is on a structurally
126-
-- smaller argument. The top of the 'ts_recursion_stack' witnesses the smaller
127-
-- pattern val.
128-
guardStructurallySmallerRecursion
129-
:: TacticState
130-
-> Maybe TacticError
131-
guardStructurallySmallerRecursion s =
132-
case head $ ts_recursion_stack s of
133-
Just _ -> Nothing
134-
Nothing -> Just NoProgress
125+
-- | Mark that a tactic performed recursion. Doing so incurs a small penalty in
126+
-- the score.
127+
markRecursion
128+
:: Functor m
129+
=> TacticT jdg (Synthesized ext) err s m a
130+
-> TacticT jdg (Synthesized ext) err s m a
131+
markRecursion = mappingExtract (field' @"syn_recursion_count" <>~ 1)
135132

136133

137134
------------------------------------------------------------------------------
138-
-- | Mark that the current recursive call is structurally smaller, due to
139-
-- having been matched on a pattern value.
140-
--
141-
-- Implemented by setting the top of the 'ts_recursion_stack'.
142-
markStructuralySmallerRecursion :: MonadState TacticState m => PatVal -> m ()
143-
markStructuralySmallerRecursion pv = do
144-
modify $ withRecursionStack $ \case
145-
(_ : bs) -> Just pv : bs
146-
[] -> []
135+
-- | Map a function over the extract created by a tactic.
136+
mappingExtract
137+
:: Functor m
138+
=> (ext -> ext)
139+
-> TacticT jdg ext err s m a
140+
-> TacticT jdg ext err s m a
141+
mappingExtract f (TacticT m)
142+
= TacticT $ StateT $ \jdg ->
143+
mapExtract' f $ runStateT m jdg
147144

148145

149146
------------------------------------------------------------------------------
@@ -154,7 +151,6 @@ markStructuralySmallerRecursion pv = do
154151
-- to produce the right test results.
155152
scoreSolution
156153
:: Synthesized (LHsExpr GhcPs)
157-
-> TacticState
158154
-> [Judgement]
159155
-> ( Penalize Int -- number of holes
160156
, Reward Bool -- all bindings used
@@ -164,19 +160,23 @@ scoreSolution
164160
, Penalize Int -- number of recursive calls
165161
, Penalize Int -- size of extract
166162
)
167-
scoreSolution ext TacticState{..} holes
163+
scoreSolution ext holes
168164
= ( Penalize $ length holes
169165
, Reward $ S.null $ intro_vals S.\\ used_vals
170166
, Penalize $ S.size unused_top_vals
171167
, Penalize $ S.size intro_vals
172168
, Reward $ S.size used_vals
173-
, Penalize ts_recursion_count
169+
, Penalize $ getSum $ syn_recursion_count ext
174170
, Penalize $ solutionSize $ syn_val ext
175171
)
176172
where
177173
intro_vals = M.keysSet $ hyByName $ syn_scoped ext
178174
used_vals = S.intersection intro_vals $ syn_used_vals ext
179-
top_vals = S.fromList . fmap hi_name . filter (isTopLevel . hi_provenance) $ unHypothesis $ syn_scoped ext
175+
top_vals = S.fromList
176+
. fmap hi_name
177+
. filter (isTopLevel . hi_provenance)
178+
. unHypothesis
179+
$ syn_scoped ext
180180
unused_top_vals = top_vals S.\\ used_vals
181181

182182

@@ -240,6 +240,26 @@ methodHypothesis ty = do
240240
)
241241

242242

243+
------------------------------------------------------------------------------
244+
-- | Mystical time-traveling combinator for inspecting the extracts produced by
245+
-- a tactic. We can use it to guard that extracts match certain predicates, for
246+
-- example.
247+
--
248+
-- Note, that this thing is WEIRD. To illustrate:
249+
--
250+
-- @@
251+
-- peek f
252+
-- blah
253+
-- @@
254+
--
255+
-- Here, @f@ can inspect the extract _produced by @blah@,_ which means the
256+
-- causality appears to go backwards.
257+
--
258+
-- 'peek' should be exposed directly by @refinery@ in the next release.
259+
peek :: (ext -> TacticT jdg ext err s m ()) -> TacticT jdg ext err s m ()
260+
peek k = tactic $ \j -> Subgoal ((), j) $ \e -> proofState (k e) j
261+
262+
243263
------------------------------------------------------------------------------
244264
-- | Run the given tactic iff the current hole contains no univars. Skolems and
245265
-- already decided univars are OK though.
@@ -251,3 +271,4 @@ requireConcreteHole m = do
251271
case S.size $ vars S.\\ skolems of
252272
0 -> m
253273
_ -> throwError TooPolymorphic
274+

0 commit comments

Comments
 (0)