Skip to content

Commit 91d2711

Browse files
authored
Tactics support for using given constraints (#534)
This PR allows tactics to use methods from given constraints, meaning it can solve holes like this: ```haskell showMe :: Show a => a -> String showMe = _ ``` It will not, however, discover instances. So this one *won't* solve: ```haskell showMe :: Int -> String showMe = _ ``` There's quite a lot of finicky details going on in order to support this. The primary challenge is that our types are running after the typechecker has finished, meaning it's already solved the constraints and inlined their evidence. Our solution is to look up the written polymorphic type and unify it with the final, typechecked type. We can use the polymorphic type to find the theta context, and instantiate every class method in the theta. In addition, this PR fixes a subtle bug in our unification code, which could cause skolems to unify in some circumstances. Furthermore, it adds a tie-breaker to the scoring metric to prefer shorter programs.
1 parent d13f670 commit 91d2711

22 files changed

+411
-116
lines changed

plugins/tactics/hls-tactics-plugin.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ test-suite tests
8484
main-is: Main.hs
8585
other-modules:
8686
AutoTupleSpec
87+
UnificationSpec
8788
hs-source-dirs:
8889
test
8990
ghc-options: -Wall -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N

plugins/tactics/src/Ide/Plugin/Tactic.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,9 +265,11 @@ judgementForHole state nfp range = do
265265
$ getDefiningBindings binds rss)
266266
tcg
267267
hyps = hypothesisFromBindings rss binds
268+
ambient = M.fromList $ contextMethodHypothesis ctx
268269
pure ( resulting_range
269270
, mkFirstJudgement
270271
hyps
272+
ambient
271273
(isRhsHole rss tcs)
272274
(maybe
273275
mempty

plugins/tactics/src/Ide/Plugin/Tactic/CodeGen.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ import Type hiding (Var)
2929

3030
useOccName :: MonadState TacticState m => Judgement -> OccName -> m ()
3131
useOccName jdg name =
32-
case M.lookup name $ jHypothesis jdg of
32+
-- Only score points if this is in the local hypothesis
33+
case M.lookup name $ jLocalHypothesis jdg of
3334
Just{} -> modify $ withUsedVals $ S.insert name
3435
Nothing -> pure ()
3536

plugins/tactics/src/Ide/Plugin/Tactic/Context.hs

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,50 @@ import Development.IDE.GHC.Compat
1010
import Ide.Plugin.Tactic.Types
1111
import OccName
1212
import TcRnTypes
13+
import Ide.Plugin.Tactic.GHC (tacticsThetaTy)
14+
import Ide.Plugin.Tactic.Machinery (methodHypothesis)
15+
import Data.Maybe (mapMaybe)
16+
import Data.List
17+
import TcType (substTy, tcSplitSigmaTy)
18+
import Unify (tcUnifyTy)
1319

1420

1521
mkContext :: [(OccName, CType)] -> TcGblEnv -> Context
16-
mkContext locals
17-
= Context locals
18-
. fmap splitId
19-
. (getFunBindId =<<)
20-
. fmap unLoc
21-
. bagToList
22-
. tcg_binds
22+
mkContext locals tcg = Context
23+
{ ctxDefiningFuncs = locals
24+
, ctxModuleFuncs = fmap splitId
25+
. (getFunBindId =<<)
26+
. fmap unLoc
27+
. bagToList
28+
$ tcg_binds tcg
29+
}
30+
31+
32+
------------------------------------------------------------------------------
33+
-- | Find all of the class methods that exist from the givens in the context.
34+
contextMethodHypothesis :: Context -> [(OccName, CType)]
35+
contextMethodHypothesis ctx
36+
= join
37+
. concatMap
38+
( mapMaybe methodHypothesis
39+
. tacticsThetaTy
40+
. unCType
41+
)
42+
. mapMaybe (definedThetaType ctx)
43+
. fmap fst
44+
$ ctxDefiningFuncs ctx
45+
46+
47+
------------------------------------------------------------------------------
48+
-- | Given the name of a function that exists in 'ctxDefiningFuncs', get its
49+
-- theta type.
50+
definedThetaType :: Context -> OccName -> Maybe CType
51+
definedThetaType ctx name = do
52+
(_, CType mono) <- find ((== name) . fst) $ ctxDefiningFuncs ctx
53+
(_, CType poly) <- find ((== name) . fst) $ ctxModuleFuncs ctx
54+
let (_, _, poly') = tcSplitSigmaTy poly
55+
subst <- tcUnifyTy poly' mono
56+
pure $ CType $ substTy subst $ snd $ splitForAllTys poly
2357

2458

2559
splitId :: Id -> (OccName, CType)

plugins/tactics/src/Ide/Plugin/Tactic/GHC.hs

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
1-
{-# LANGUAGE CPP #-}
2-
{-# LANGUAGE PatternSynonyms #-}
3-
{-# LANGUAGE ViewPatterns #-}
1+
{-# LANGUAGE CPP #-}
2+
{-# LANGUAGE FlexibleContexts #-}
3+
{-# LANGUAGE PatternSynonyms #-}
4+
{-# LANGUAGE ViewPatterns #-}
45

56
module Ide.Plugin.Tactic.GHC where
67

7-
import Data.Maybe (isJust)
8-
import Development.IDE.GHC.Compat
9-
import OccName
10-
import TcType
11-
import TyCoRep
12-
import Type
13-
import TysWiredIn (intTyCon, floatTyCon, doubleTyCon, charTyCon)
14-
import Unique
15-
import Var
8+
import Control.Monad.State
9+
import qualified Data.Map as M
10+
import Data.Maybe (isJust)
11+
import Data.Traversable
12+
import Development.IDE.GHC.Compat
13+
import Generics.SYB (mkT, everywhere)
14+
import Ide.Plugin.Tactic.Types
15+
import OccName
16+
import TcType
17+
import TyCoRep
18+
import Type
19+
import TysWiredIn (intTyCon, floatTyCon, doubleTyCon, charTyCon)
20+
import Unique
21+
import Var
1622

1723

1824
tcTyVar_maybe :: Type -> Maybe Var
@@ -43,8 +49,44 @@ cloneTyVar t =
4349
------------------------------------------------------------------------------
4450
-- | Is this a function type?
4551
isFunction :: Type -> Bool
46-
isFunction (tcSplitFunTys -> ((_:_), _)) = True
47-
isFunction _ = False
52+
isFunction (tacticsSplitFunTy -> (_, _, [], _)) = False
53+
isFunction _ = True
54+
55+
56+
------------------------------------------------------------------------------
57+
-- | Split a function, also splitting out its quantified variables and theta
58+
-- context.
59+
tacticsSplitFunTy :: Type -> ([TyVar], ThetaType, [Type], Type)
60+
tacticsSplitFunTy t
61+
= let (vars, theta, t') = tcSplitSigmaTy t
62+
(args, res) = tcSplitFunTys t'
63+
in (vars, theta, args, res)
64+
65+
66+
------------------------------------------------------------------------------
67+
-- | Rip the theta context out of a regular type.
68+
tacticsThetaTy :: Type -> ThetaType
69+
tacticsThetaTy (tcSplitSigmaTy -> (_, theta, _)) = theta
70+
71+
72+
------------------------------------------------------------------------------
73+
-- | Instantiate all of the quantified type variables in a type with fresh
74+
-- skolems.
75+
freshTyvars :: MonadState TacticState m => Type -> m Type
76+
freshTyvars t = do
77+
let (tvs, _, _, _) = tacticsSplitFunTy t
78+
reps <- fmap M.fromList
79+
$ for tvs $ \tv -> do
80+
uniq <- freshUnique
81+
pure $ (tv, setTyVarUnique tv uniq)
82+
pure $
83+
everywhere
84+
(mkT $ \tv ->
85+
case M.lookup tv reps of
86+
Just tv' -> tv'
87+
Nothing -> tv
88+
) t
89+
4890

4991
------------------------------------------------------------------------------
5092
-- | Is this an algebraic type?

plugins/tactics/src/Ide/Plugin/Tactic/Judgements.hs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,17 @@ disallowing ns =
162162
field @"_jHypothesis" %~ flip M.withoutKeys (S.fromList ns)
163163

164164

165+
------------------------------------------------------------------------------
166+
-- | The hypothesis, consisting of local terms and the ambient environment
167+
-- (includes and class methods.)
165168
jHypothesis :: Judgement' a -> Map OccName a
166-
jHypothesis = _jHypothesis
169+
jHypothesis = _jHypothesis <> _jAmbientHypothesis
170+
171+
172+
------------------------------------------------------------------------------
173+
-- | Just the local hypothesis.
174+
jLocalHypothesis :: Judgement' a -> Map OccName a
175+
jLocalHypothesis = _jHypothesis
167176

168177

169178
isPatVal :: Judgement' a -> OccName -> Bool
@@ -191,13 +200,15 @@ substJdg :: TCvSubst -> Judgement -> Judgement
191200
substJdg subst = fmap $ coerce . substTy subst . coerce
192201

193202
mkFirstJudgement
194-
:: M.Map OccName CType
203+
:: M.Map OccName CType -- ^ local hypothesis
204+
-> M.Map OccName CType -- ^ ambient hypothesis
195205
-> Bool -- ^ are we in the top level rhs hole?
196206
-> M.Map OccName [[OccName]] -- ^ existing pos vals
197207
-> Type
198208
-> Judgement' CType
199-
mkFirstJudgement hy top posvals goal = Judgement
209+
mkFirstJudgement hy ambient top posvals goal = Judgement
200210
{ _jHypothesis = hy
211+
, _jAmbientHypothesis = ambient
201212
, _jDestructed = mempty
202213
, _jPatternVals = mempty
203214
, _jBlacklistDestruct = False

plugins/tactics/src/Ide/Plugin/Tactic/Machinery.hs

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
{-# LANGUAGE ScopedTypeVariables #-}
12
{-# LANGUAGE DeriveFunctor #-}
23
{-# LANGUAGE DeriveGeneric #-}
34
{-# LANGUAGE DerivingStrategies #-}
@@ -17,6 +18,7 @@ module Ide.Plugin.Tactic.Machinery
1718
( module Ide.Plugin.Tactic.Machinery
1819
) where
1920

21+
import Class (Class(classTyVars))
2022
import Control.Arrow
2123
import Control.Monad.Error.Class
2224
import Control.Monad.Reader
@@ -25,12 +27,15 @@ import Control.Monad.State.Class (gets, modify)
2527
import Control.Monad.State.Strict (StateT (..))
2628
import Data.Coerce
2729
import Data.Either
28-
import Data.List (intercalate, sortBy)
30+
import Data.Functor ((<&>))
31+
import Data.Generics (mkQ, everything, gcount)
32+
import Data.List (sortBy)
2933
import Data.Ord (comparing, Down(..))
3034
import qualified Data.Set as S
3135
import Development.IDE.GHC.Compat
3236
import Ide.Plugin.Tactic.Judgements
3337
import Ide.Plugin.Tactic.Types
38+
import OccName (HasOccName(occName))
3439
import Refinery.ProofState
3540
import Refinery.Tactic
3641
import Refinery.Tactic.Internal
@@ -74,7 +79,8 @@ runTactic ctx jdg t =
7479
(errs, []) -> Left $ take 50 $ errs
7580
(_, fmap assoc23 -> solns) -> do
7681
let sorted =
77-
sortBy (comparing $ Down . uncurry scoreSolution . snd) solns
82+
flip sortBy solns $ comparing $ \((_, ext), (jdg, holes)) ->
83+
Down $ scoreSolution ext jdg holes
7884
case sorted of
7985
(((tr, ext), _) : _) ->
8086
Right
@@ -121,56 +127,97 @@ setRecursionFrameData b = do
121127

122128

123129
scoreSolution
124-
:: TacticState
130+
:: LHsExpr GhcPs
131+
-> TacticState
125132
-> [Judgement]
126133
-> ( Penalize Int -- number of holes
127134
, Reward Bool -- all bindings used
128135
, Penalize Int -- number of introduced bindings
129136
, Reward Int -- number used bindings
137+
, Penalize Int -- size of extract
130138
)
131-
scoreSolution TacticState{..} holes
139+
scoreSolution ext TacticState{..} holes
132140
= ( Penalize $ length holes
133-
, Reward $ S.null $ ts_intro_vals S.\\ ts_used_vals
141+
, Reward $ S.null $ ts_intro_vals S.\\ ts_used_vals
134142
, Penalize $ S.size ts_intro_vals
135-
, Reward $ S.size ts_used_vals
143+
, Reward $ S.size ts_used_vals
144+
, Penalize $ solutionSize ext
136145
)
137146

138147

148+
------------------------------------------------------------------------------
149+
-- | Compute the number of 'LHsExpr' nodes; used as a rough metric for code
150+
-- size.
151+
solutionSize :: LHsExpr GhcPs -> Int
152+
solutionSize = everything (+) $ gcount $ mkQ False $ \case
153+
(_ :: LHsExpr GhcPs) -> True
154+
155+
139156
newtype Penalize a = Penalize a
140157
deriving (Eq, Ord, Show) via (Down a)
141158

142159
newtype Reward a = Reward a
143160
deriving (Eq, Ord, Show) via a
144161

145162

163+
------------------------------------------------------------------------------
164+
-- | Like 'tcUnifyTy', but takes a list of skolems to prevent unification of.
165+
tryUnifyUnivarsButNotSkolems :: [TyVar] -> CType -> CType -> Maybe TCvSubst
166+
tryUnifyUnivarsButNotSkolems skolems goal inst =
167+
case tcUnifyTysFG (skolemsOf skolems) [unCType inst] [unCType goal] of
168+
Unifiable subst -> pure subst
169+
_ -> Nothing
170+
146171

147172
------------------------------------------------------------------------------
148-
-- | We need to make sure that we don't try to unify any skolems.
149-
-- To see why, consider the case:
150-
--
151-
-- uhh :: (Int -> Int) -> a
152-
-- uhh f = _
153-
--
154-
-- If we were to apply 'f', then we would try to unify 'Int' and 'a'.
155-
-- This is fine from the perspective of 'tcUnifyTy', but will cause obvious
156-
-- type errors in our use case. Therefore, we need to ensure that our
157-
-- 'TCvSubst' doesn't try to unify skolems.
158-
checkSkolemUnification :: CType -> CType -> TCvSubst -> RuleM ()
159-
checkSkolemUnification t1 t2 subst = do
160-
skolems <- gets ts_skolems
161-
unless (all (flip notElemTCvSubst subst) skolems) $
162-
throwError (UnificationError t1 t2)
173+
-- | Helper method for 'tryUnifyUnivarsButNotSkolems'
174+
skolemsOf :: [TyVar] -> TyVar -> BindFlag
175+
skolemsOf tvs tv =
176+
case elem tv tvs of
177+
True -> Skolem
178+
False -> BindMe
163179

164180

165181
------------------------------------------------------------------------------
166182
-- | Attempt to unify two types.
167183
unify :: CType -- ^ The goal type
168184
-> CType -- ^ The type we are trying unify the goal type with
169185
-> RuleM ()
170-
unify goal inst =
171-
case tcUnifyTy (unCType inst) (unCType goal) of
172-
Just subst -> do
173-
checkSkolemUnification inst goal subst
174-
modify (\s -> s { ts_unifier = unionTCvSubst subst (ts_unifier s) })
175-
Nothing -> throwError (UnificationError inst goal)
186+
unify goal inst = do
187+
skolems <- gets ts_skolems
188+
case tryUnifyUnivarsButNotSkolems skolems goal inst of
189+
Just subst ->
190+
modify (\s -> s { ts_unifier = unionTCvSubst subst (ts_unifier s) })
191+
Nothing -> throwError (UnificationError inst goal)
192+
193+
194+
------------------------------------------------------------------------------
195+
-- | Get the class methods of a 'PredType', correctly dealing with
196+
-- instantiation of quantified class types.
197+
methodHypothesis :: PredType -> Maybe [(OccName, CType)]
198+
methodHypothesis ty = do
199+
(tc, apps) <- splitTyConApp_maybe ty
200+
cls <- tyConClass_maybe tc
201+
let methods = classMethods cls
202+
tvs = classTyVars cls
203+
subst = zipTvSubst tvs apps
204+
sc_methods <- fmap join
205+
$ traverse (methodHypothesis . substTy subst)
206+
$ classSCTheta cls
207+
pure $ mappend sc_methods $ methods <&> \method ->
208+
let (_, _, ty) = tcSplitSigmaTy $ idType method
209+
in (occName method, CType $ substTy subst ty)
210+
211+
212+
------------------------------------------------------------------------------
213+
-- | Run the given tactic iff the current hole contains no univars. Skolems and
214+
-- already decided univars are OK though.
215+
requireConcreteHole :: TacticsM a -> TacticsM a
216+
requireConcreteHole m = do
217+
jdg <- goal
218+
skolems <- gets $ S.fromList . ts_skolems
219+
let vars = S.fromList $ tyCoVarsOfTypeWellScoped $ unCType $ jGoal jdg
220+
case S.size $ vars S.\\ skolems of
221+
0 -> m
222+
_ -> throwError TooPolymorphic
176223

0 commit comments

Comments
 (0)