6
6
{-# LANGUAGE MultiParamTypeClasses #-}
7
7
{-# LANGUAGE RecordWildCards #-}
8
8
{-# LANGUAGE ScopedTypeVariables #-}
9
+ {-# LANGUAGE TypeApplications #-}
9
10
{-# LANGUAGE ViewPatterns #-}
10
- {-# OPTIONS_GHC -fno-warn-orphans #-}
11
11
12
- module Ide.Plugin.Tactic.Machinery
13
- ( module Ide.Plugin.Tactic.Machinery
14
- ) where
12
+ module Ide.Plugin.Tactic.Machinery where
15
13
16
- import Class (Class (classTyVars ))
17
- import Control.Arrow
14
+ import Class (Class (classTyVars ))
15
+ import Control.Lens ( (<>~) )
18
16
import Control.Monad.Error.Class
19
17
import Control.Monad.Reader
20
- import Control.Monad.State (MonadState (.. ))
21
- import Control.Monad.State.Class (gets , modify )
22
- import Control.Monad.State.Strict (StateT (.. ))
23
- import Data.Bool (bool )
18
+ import Control.Monad.State.Class (gets , modify )
19
+ import Control.Monad.State.Strict (StateT (.. ))
20
+ import Data.Bool (bool )
24
21
import Data.Coerce
25
22
import Data.Either
26
23
import Data.Foldable
27
- import Data.Functor ((<&>) )
28
- import Data.Generics (everything , gcount , mkQ )
29
- import Data.List (sortBy )
30
- import qualified Data.Map as M
31
- import Data.Ord (Down (.. ), comparing )
32
- import Data.Set (Set )
33
- import qualified Data.Set as S
24
+ import Data.Functor ((<&>) )
25
+ import Data.Generics (everything , gcount , mkQ )
26
+ import Data.Generics.Product (field' )
27
+ import Data.List (sortBy )
28
+ import qualified Data.Map as M
29
+ import Data.Monoid (getSum )
30
+ import Data.Ord (Down (.. ), comparing )
31
+ import Data.Set (Set )
32
+ import qualified Data.Set as S
34
33
import Development.IDE.GHC.Compat
35
34
import Ide.Plugin.Tactic.Judgements
36
- import Ide.Plugin.Tactic.Simplify (simplify )
35
+ import Ide.Plugin.Tactic.Simplify (simplify )
37
36
import Ide.Plugin.Tactic.Types
38
- import OccName (HasOccName (occName ))
37
+ import OccName (HasOccName (occName ))
39
38
import Refinery.ProofState
40
39
import Refinery.Tactic
41
40
import Refinery.Tactic.Internal
@@ -88,8 +87,8 @@ runTactic ctx jdg t =
88
87
(errs, [] ) -> Left $ take 50 errs
89
88
(_, fmap assoc23 -> solns) -> do
90
89
let sorted =
91
- flip sortBy solns $ comparing $ \ (ext, (jdg , holes)) ->
92
- Down $ scoreSolution ext jdg holes
90
+ flip sortBy solns $ comparing $ \ (ext, (_ , holes)) ->
91
+ Down $ scoreSolution ext holes
93
92
case sorted of
94
93
((syn, _) : _) ->
95
94
Right $
@@ -111,39 +110,37 @@ tracePrim :: String -> Trace
111
110
tracePrim = flip rose []
112
111
113
112
113
+ ------------------------------------------------------------------------------
114
+ -- | Mark that a tactic used the given string in its extract derivation. Mainly
115
+ -- used for debugging the search when things go terribly wrong.
114
116
tracing
115
117
:: Functor m
116
118
=> String
117
119
-> TacticT jdg (Synthesized ext ) err s m a
118
120
-> TacticT jdg (Synthesized ext ) err s m a
119
- tracing s (TacticT m)
120
- = TacticT $ StateT $ \ jdg ->
121
- mapExtract' (mapTrace $ rose s . pure ) $ runStateT m jdg
121
+ tracing s = mappingExtract (mapTrace $ rose s . pure )
122
122
123
123
124
124
------------------------------------------------------------------------------
125
- -- | Recursion is allowed only when we can prove it is on a structurally
126
- -- smaller argument. The top of the 'ts_recursion_stack' witnesses the smaller
127
- -- pattern val.
128
- guardStructurallySmallerRecursion
129
- :: TacticState
130
- -> Maybe TacticError
131
- guardStructurallySmallerRecursion s =
132
- case head $ ts_recursion_stack s of
133
- Just _ -> Nothing
134
- Nothing -> Just NoProgress
125
+ -- | Mark that a tactic performed recursion. Doing so incurs a small penalty in
126
+ -- the score.
127
+ markRecursion
128
+ :: Functor m
129
+ => TacticT jdg (Synthesized ext ) err s m a
130
+ -> TacticT jdg (Synthesized ext ) err s m a
131
+ markRecursion = mappingExtract (field' @ " syn_recursion_count" <>~ 1 )
135
132
136
133
137
134
------------------------------------------------------------------------------
138
- -- | Mark that the current recursive call is structurally smaller, due to
139
- -- having been matched on a pattern value.
140
- --
141
- -- Implemented by setting the top of the 'ts_recursion_stack'.
142
- markStructuralySmallerRecursion :: MonadState TacticState m => PatVal -> m ()
143
- markStructuralySmallerRecursion pv = do
144
- modify $ withRecursionStack $ \ case
145
- (_ : bs) -> Just pv : bs
146
- [] -> []
135
+ -- | Map a function over the extract created by a tactic.
136
+ mappingExtract
137
+ :: Functor m
138
+ => ( ext -> ext )
139
+ -> TacticT jdg ext err s m a
140
+ -> TacticT jdg ext err s m a
141
+ mappingExtract f ( TacticT m)
142
+ = TacticT $ StateT $ \ jdg ->
143
+ mapExtract' f $ runStateT m jdg
147
144
148
145
149
146
------------------------------------------------------------------------------
@@ -154,7 +151,6 @@ markStructuralySmallerRecursion pv = do
154
151
-- to produce the right test results.
155
152
scoreSolution
156
153
:: Synthesized (LHsExpr GhcPs )
157
- -> TacticState
158
154
-> [Judgement ]
159
155
-> ( Penalize Int -- number of holes
160
156
, Reward Bool -- all bindings used
@@ -164,19 +160,23 @@ scoreSolution
164
160
, Penalize Int -- number of recursive calls
165
161
, Penalize Int -- size of extract
166
162
)
167
- scoreSolution ext TacticState { .. } holes
163
+ scoreSolution ext holes
168
164
= ( Penalize $ length holes
169
165
, Reward $ S. null $ intro_vals S. \\ used_vals
170
166
, Penalize $ S. size unused_top_vals
171
167
, Penalize $ S. size intro_vals
172
168
, Reward $ S. size used_vals
173
- , Penalize ts_recursion_count
169
+ , Penalize $ getSum $ syn_recursion_count ext
174
170
, Penalize $ solutionSize $ syn_val ext
175
171
)
176
172
where
177
173
intro_vals = M. keysSet $ hyByName $ syn_scoped ext
178
174
used_vals = S. intersection intro_vals $ syn_used_vals ext
179
- top_vals = S. fromList . fmap hi_name . filter (isTopLevel . hi_provenance) $ unHypothesis $ syn_scoped ext
175
+ top_vals = S. fromList
176
+ . fmap hi_name
177
+ . filter (isTopLevel . hi_provenance)
178
+ . unHypothesis
179
+ $ syn_scoped ext
180
180
unused_top_vals = top_vals S. \\ used_vals
181
181
182
182
@@ -240,6 +240,26 @@ methodHypothesis ty = do
240
240
)
241
241
242
242
243
+ ------------------------------------------------------------------------------
244
+ -- | Mystical time-traveling combinator for inspecting the extracts produced by
245
+ -- a tactic. We can use it to guard that extracts match certain predicates, for
246
+ -- example.
247
+ --
248
+ -- Note, that this thing is WEIRD. To illustrate:
249
+ --
250
+ -- @@
251
+ -- peek f
252
+ -- blah
253
+ -- @@
254
+ --
255
+ -- Here, @f@ can inspect the extract _produced by @blah@,_ which means the
256
+ -- causality appears to go backwards.
257
+ --
258
+ -- 'peek' should be exposed directly by @refinery@ in the next release.
259
+ peek :: (ext -> TacticT jdg ext err s m () ) -> TacticT jdg ext err s m ()
260
+ peek k = tactic $ \ j -> Subgoal (() , j) $ \ e -> proofState (k e) j
261
+
262
+
243
263
------------------------------------------------------------------------------
244
264
-- | Run the given tactic iff the current hole contains no univars. Skolems and
245
265
-- already decided univars are OK though.
@@ -251,3 +271,4 @@ requireConcreteHole m = do
251
271
case S. size $ vars S. \\ skolems of
252
272
0 -> m
253
273
_ -> throwError TooPolymorphic
274
+
0 commit comments