Skip to content

Better scoring metric for deriving safeHead #545

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 28, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE ViewPatterns #-}

module Ide.Plugin.Tactic.CodeGen where

import Control.Lens ((+~), (%~), (<>~))
import Control.Monad.Except
import Control.Monad.State (MonadState)
import Control.Monad.State.Class (modify)
import Data.Generics.Product (field)
import Data.List
import qualified Data.Map as M
import qualified Data.Set as S
Expand All @@ -31,10 +35,25 @@ useOccName :: MonadState TacticState m => Judgement -> OccName -> m ()
useOccName jdg name =
-- Only score points if this is in the local hypothesis
case M.lookup name $ jLocalHypothesis jdg of
Just{} -> modify $ withUsedVals $ S.insert name
Just{} -> modify
$ (withUsedVals $ S.insert name)
. (field @"ts_unused_top_vals" %~ S.delete name)
Nothing -> pure ()


------------------------------------------------------------------------------
-- | Doing recursion incurs a small penalty in the score.
penalizeRecursion :: MonadState TacticState m => m ()
penalizeRecursion = modify $ field @"ts_recursion_penality" +~ 1


------------------------------------------------------------------------------
-- | Insert some values into the unused top values field. These are
-- subsequently removed via 'useOccName'.
addUnusedTopVals :: MonadState TacticState m => S.Set OccName -> m ()
addUnusedTopVals vals = modify $ field @"ts_unused_top_vals" <>~ vals


destructMatches
:: (DataCon -> Judgement -> Rule)
-- ^ How to construct each match
Expand Down
7 changes: 7 additions & 0 deletions plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ introducing ns =
field @"_jHypothesis" <>~ M.fromList ns


------------------------------------------------------------------------------
-- | Add some terms to the ambient hypothesis
introducingAmbient :: [(OccName, a)] -> Judgement' a -> Judgement' a
introducingAmbient ns =
field @"_jAmbientHypothesis" <>~ M.fromList ns


filterPosition :: OccName -> Int -> Judgement -> Judgement
filterPosition defn pos jdg =
withHypothesis (M.filterWithKey go) jdg
Expand Down
20 changes: 18 additions & 2 deletions plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ import Control.Monad.State.Class (gets, modify)
import Control.Monad.State.Strict (StateT (..))
import Data.Coerce
import Data.Either
import Data.Foldable
import Data.Functor ((<&>))
import Data.Generics (mkQ, everything, gcount)
import Data.List (sortBy)
import Data.List (nub, sortBy)
import Data.Ord (comparing, Down(..))
import qualified Data.Set as S
import Development.IDE.GHC.Compat
Expand Down Expand Up @@ -71,7 +72,12 @@ runTactic
-> Either [TacticError] RunTacticResults
runTactic ctx jdg t =
let skolems = tyCoVarsOfTypeWellScoped $ unCType $ jGoal jdg
tacticState = defaultTacticState { ts_skolems = skolems }
unused_topvals = nub $ join $ join $ toList $ _jPositionMaps jdg
tacticState =
defaultTacticState
{ ts_skolems = skolems
, ts_unused_top_vals = S.fromList unused_topvals
}
in case partitionEithers
. flip runReader ctx
. unExtractM
Expand Down Expand Up @@ -126,21 +132,31 @@ setRecursionFrameData b = do
[] -> []


------------------------------------------------------------------------------
-- | Given the results of running a tactic, score the solutions by
-- desirability.
--
-- TODO(sandy): This function is completely unprincipled and was just hacked
-- together to produce the right test results.
scoreSolution
:: LHsExpr GhcPs
-> TacticState
-> [Judgement]
-> ( Penalize Int -- number of holes
, Reward Bool -- all bindings used
, Penalize Int -- unused top-level bindings
, Penalize Int -- number of introduced bindings
, Reward Int -- number used bindings
, Penalize Int -- number of recursive calls
, Penalize Int -- size of extract
)
scoreSolution ext TacticState{..} holes
= ( Penalize $ length holes
, Reward $ S.null $ ts_intro_vals S.\\ ts_used_vals
, Penalize $ S.size ts_unused_top_vals
, Penalize $ S.size ts_intro_vals
, Reward $ S.size ts_used_vals
, Penalize $ ts_recursion_penality
, Penalize $ solutionSize ext
)

Expand Down
4 changes: 3 additions & 1 deletion plugins/tactics/src/Ide/Plugin/Tactic/Tactics.hs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ recursion = requireConcreteHole $ tracing "recursion" $ do
defs <- getCurrentDefinitions
attemptOn (const $ fmap fst defs) $ \name -> do
modify $ withRecursionStack (False :)
penalizeRecursion
ensure recursiveCleanup (withRecursionStack tail) $ do
(localTactic (apply name) $ introducing defs)
(localTactic (apply name) $ introducingAmbient defs)
<@> fmap (localTactic assumption . filterPosition name) [0..]


Expand All @@ -88,6 +89,7 @@ intros = rule $ \jdg -> do
let jdg' = introducing (zip vs $ coerce as)
$ withNewGoal (CType b) jdg
modify $ withIntroducedVals $ mappend $ S.fromList vs
when (isTopHole jdg) $ addUnusedTopVals $ S.fromList vs
(tr, sg)
<- newSubgoal
$ bool
Expand Down
13 changes: 13 additions & 0 deletions plugins/tactics/src/Ide/Plugin/Tactic/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,21 @@ instance Show DataCon where
------------------------------------------------------------------------------
data TacticState = TacticState
{ ts_skolems :: !([TyVar])
-- ^ The known skolems.
, ts_unifier :: !(TCvSubst)
-- ^ The current substitution of univars.
, ts_used_vals :: !(Set OccName)
-- ^ Set of values used by tactics.
, ts_intro_vals :: !(Set OccName)
-- ^ Set of values introduced by tactics.
, ts_unused_top_vals :: !(Set OccName)
-- ^ Set of currently unused arguments to the function being defined.
, ts_recursion_stack :: ![Bool]
-- ^ Stack for tracking whether or not the current recursive call has
-- used at least one smaller pat val. Recursive calls for which this
-- value is 'False' are guaranteed to loop, and must be pruned.
, ts_recursion_penality :: !Int
-- ^ Number of calls to recursion. We penalize each.
, ts_unique_gen :: !UniqSupply
} deriving stock (Show, Generic)

Expand All @@ -100,7 +111,9 @@ defaultTacticState =
, ts_unifier = emptyTCvSubst
, ts_used_vals = mempty
, ts_intro_vals = mempty
, ts_unused_top_vals = mempty
, ts_recursion_stack = mempty
, ts_recursion_penality = 0
, ts_unique_gen = unsafeDefaultUniqueSupply
}

Expand Down
1 change: 1 addition & 0 deletions test/functional/Tactic.hs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ tests = testGroup
, goldenTest "GoldenShowMapChar.hs" 2 8 Auto ""
, goldenTest "GoldenSuperclass.hs" 7 8 Auto ""
, goldenTest "GoldenApplicativeThen.hs" 2 11 Auto ""
, goldenTest "GoldenSafeHead.hs" 2 12 Auto ""
]


Expand Down
2 changes: 2 additions & 0 deletions test/testdata/tactic/GoldenSafeHead.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
safeHead :: [x] -> Maybe x
safeHead = _
5 changes: 5 additions & 0 deletions test/testdata/tactic/GoldenSafeHead.hs.expected
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
safeHead :: [x] -> Maybe x
safeHead = (\ l_x
-> case l_x of
[] -> Nothing
(x : l_x2) -> Just x)