Skip to content

Code action: Expand catch all variant #987

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
- Emit `%todo` instead of `failwith("TODO")` when we can (ReScript >= v11.1). https://github.com/rescript-lang/rescript-vscode/pull/981
- Complete `%todo`. https://github.com/rescript-lang/rescript-vscode/pull/981
- Add code action for extracting a locally defined module into its own file. https://github.com/rescript-lang/rescript-vscode/pull/983
- Add code action for expanding catch-all patterns. https://github.com/rescript-lang/rescript-vscode/pull/987

## 1.50.0

Expand Down
4 changes: 3 additions & 1 deletion analysis/src/CompletionFrontEnd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,9 @@ let completionWithParser1 ~currentFile ~debug ~offset ~path ~posCursor
typedCompletionExpr expr;
match expr.pexp_desc with
| Pexp_match (expr, cases)
when cases <> [] && locHasCursor expr.pexp_loc = false ->
when cases <> []
&& locHasCursor expr.pexp_loc = false
&& Option.is_none findThisExprLoc ->
if Debug.verbose () then
print_endline "[completionFrontend] Checking each case";
let ctxPath = exprToContextPath expr in
Expand Down
272 changes: 230 additions & 42 deletions analysis/src/Xform.ml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,42 @@ let rangeOfLoc (loc : Location.t) =
let end_ = loc |> Loc.end_ |> mkPosition in
{Protocol.start; end_}

let extractTypeFromExpr expr ~debug ~path ~currentFile ~full ~pos =
match
expr.Parsetree.pexp_loc
|> CompletionFrontEnd.findTypeOfExpressionAtLoc ~debug ~path ~currentFile
~posCursor:(Pos.ofLexing expr.Parsetree.pexp_loc.loc_start)
with
| Some (completable, scope) -> (
let env = SharedTypes.QueryEnv.fromFile full.SharedTypes.file in
let completions =
completable
|> CompletionBackEnd.processCompletable ~debug ~full ~pos ~scope ~env
~forHover:true
in
let rawOpens = Scope.getRawOpens scope in
match completions with
| {env} :: _ -> (
let opens =
CompletionBackEnd.getOpens ~debug ~rawOpens ~package:full.package ~env
in
match
CompletionBackEnd.completionsGetCompletionType2 ~debug ~full ~rawOpens
~opens ~pos completions
with
| Some (typ, _env) ->
let extractedType =
match typ with
| ExtractedType t -> Some t
| TypeExpr t ->
TypeUtils.extractType t ~env ~package:full.package
|> TypeUtils.getExtractedType
in
extractedType
| None -> None)
| _ -> None)
| _ -> None

module IfThenElse = struct
(* Convert if-then-else to switch *)

Expand Down Expand Up @@ -324,6 +360,196 @@ module AddTypeAnnotation = struct
| _ -> ()))
end

module ExpandCatchAllForVariants = struct
let mkIterator ~pos ~result =
let expr (iterator : Ast_iterator.iterator) (e : Parsetree.expression) =
(if e.pexp_loc |> Loc.hasPos ~pos then
match e.pexp_desc with
| Pexp_match (switchExpr, cases) -> (
let catchAllCase =
cases
|> List.find_opt (fun (c : Parsetree.case) ->
match c with
| {pc_lhs = {ppat_desc = Ppat_any}} -> true
| _ -> false)
in
match catchAllCase with
| None -> ()
| Some catchAllCase ->
result := Some (switchExpr, catchAllCase, cases))
| _ -> ());
Ast_iterator.default_iterator.expr iterator e
in
{Ast_iterator.default_iterator with expr}

let xform ~path ~pos ~full ~structure ~currentFile ~codeActions ~debug =
let result = ref None in
let iterator = mkIterator ~pos ~result in
iterator.structure iterator structure;
match !result with
| None -> ()
| Some (switchExpr, catchAllCase, cases) -> (
if Debug.verbose () then
print_endline
"[codeAction - ExpandCatchAllForVariants] Found target switch";
let currentConstructorNames =
cases
|> List.filter_map (fun (c : Parsetree.case) ->
match c with
| {pc_lhs = {ppat_desc = Ppat_construct ({txt}, _)}} ->
Some (Longident.last txt)
| {pc_lhs = {ppat_desc = Ppat_variant (name, _)}} -> Some name
| _ -> None)
in
match
switchExpr
|> extractTypeFromExpr ~debug ~path ~currentFile ~full
~pos:(Pos.ofLexing switchExpr.pexp_loc.loc_end)
with
| Some (Tvariant {constructors}) ->
let missingConstructors =
constructors
|> List.filter (fun (c : SharedTypes.Constructor.t) ->
currentConstructorNames |> List.mem c.cname.txt = false)
in
if List.length missingConstructors > 0 then
let newText =
missingConstructors
|> List.map (fun (c : SharedTypes.Constructor.t) ->
c.cname.txt
^
match c.args with
| Args [] -> ""
| Args _ | InlineRecord _ -> "(_)")
|> String.concat " | "
in
let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
let codeAction =
CodeActions.make ~title:"Expand catch-all" ~kind:RefactorRewrite
~uri:path ~newText ~range
in
codeActions := codeAction :: !codeActions
else ()
| Some (Tpolyvariant {constructors}) ->
let missingConstructors =
constructors
|> List.filter (fun (c : SharedTypes.polyVariantConstructor) ->
currentConstructorNames |> List.mem c.name = false)
in
if List.length missingConstructors > 0 then
let newText =
missingConstructors
|> List.map (fun (c : SharedTypes.polyVariantConstructor) ->
Res_printer.polyVarIdentToString c.name
^
match c.args with
| [] -> ""
| _ -> "(_)")
|> String.concat " | "
in
let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
let codeAction =
CodeActions.make ~title:"Expand catch-all" ~kind:RefactorRewrite
~uri:path ~newText ~range
in
codeActions := codeAction :: !codeActions
else ()
| Some (Toption (env, innerType)) -> (
if Debug.verbose () then
print_endline
"[codeAction - ExpandCatchAllForVariants] Found option type";
let innerType =
match innerType with
| ExtractedType t -> Some t
| TypeExpr t -> (
match TypeUtils.extractType ~env ~package:full.package t with
| None -> None
| Some (t, _) -> Some t)
in
match innerType with
| Some ((Tvariant _ | Tpolyvariant _) as variant) ->
let currentConstructorNames =
cases
|> List.filter_map (fun (c : Parsetree.case) ->
match c with
| {
pc_lhs =
{
ppat_desc =
Ppat_construct
( {txt = Lident "Some"},
Some {ppat_desc = Ppat_construct ({txt}, _)} );
};
} ->
Some (Longident.last txt)
| {
pc_lhs =
{
ppat_desc =
Ppat_construct
( {txt = Lident "Some"},
Some {ppat_desc = Ppat_variant (name, _)} );
};
} ->
Some name
| _ -> None)
in
let hasNoneCase =
cases
|> List.exists (fun (c : Parsetree.case) ->
match c.pc_lhs.ppat_desc with
| Ppat_construct ({txt = Lident "None"}, _) -> true
| _ -> false)
in
let missingConstructors =
match variant with
| Tvariant {constructors} ->
constructors
|> List.filter_map (fun (c : SharedTypes.Constructor.t) ->
if currentConstructorNames |> List.mem c.cname.txt = false
then
Some
( c.cname.txt,
match c.args with
| Args [] -> false
| _ -> true )
else None)
| Tpolyvariant {constructors} ->
constructors
|> List.filter_map
(fun (c : SharedTypes.polyVariantConstructor) ->
if currentConstructorNames |> List.mem c.name = false then
Some
( Res_printer.polyVarIdentToString c.name,
match c.args with
| [] -> false
| _ -> true )
else None)
| _ -> []
in
if List.length missingConstructors > 0 || not hasNoneCase then
let newText =
"Some("
^ (missingConstructors
|> List.map (fun (name, hasArgs) ->
name ^ if hasArgs then "(_)" else "")
|> String.concat " | ")
^ ")"
in
let newText =
if hasNoneCase then newText else newText ^ " | None"
in
let range = rangeOfLoc catchAllCase.pc_lhs.ppat_loc in
let codeAction =
CodeActions.make ~title:"Expand catch-all" ~kind:RefactorRewrite
~uri:path ~newText ~range
in
codeActions := codeAction :: !codeActions
else ()
| _ -> ())
| _ -> ())
end

module ExhaustiveSwitch = struct
(* Expand expression to be an exhaustive switch of the underlying value *)
type posType = Single of Pos.t | Range of Pos.t * Pos.t
Expand All @@ -336,46 +562,6 @@ module ExhaustiveSwitch = struct
}
| Selection of {expr: Parsetree.expression}

module C = struct
let extractTypeFromExpr expr ~debug ~path ~currentFile ~full ~pos =
match
expr.Parsetree.pexp_loc
|> CompletionFrontEnd.findTypeOfExpressionAtLoc ~debug ~path
~currentFile
~posCursor:(Pos.ofLexing expr.Parsetree.pexp_loc.loc_start)
with
| Some (completable, scope) -> (
let env = SharedTypes.QueryEnv.fromFile full.SharedTypes.file in
let completions =
completable
|> CompletionBackEnd.processCompletable ~debug ~full ~pos ~scope ~env
~forHover:true
in
let rawOpens = Scope.getRawOpens scope in
match completions with
| {env} :: _ -> (
let opens =
CompletionBackEnd.getOpens ~debug ~rawOpens ~package:full.package
~env
in
match
CompletionBackEnd.completionsGetCompletionType2 ~debug ~full
~rawOpens ~opens ~pos completions
with
| Some (typ, _env) ->
let extractedType =
match typ with
| ExtractedType t -> Some t
| TypeExpr t ->
TypeUtils.extractType t ~env ~package:full.package
|> TypeUtils.getExtractedType
in
extractedType
| None -> None)
| _ -> None)
| _ -> None
end

let mkIteratorSingle ~pos ~result =
let expr (iterator : Ast_iterator.iterator) (exp : Parsetree.expression) =
(match exp.pexp_desc with
Expand Down Expand Up @@ -434,7 +620,7 @@ module ExhaustiveSwitch = struct
| Some (Selection {expr}) -> (
match
expr
|> C.extractTypeFromExpr ~debug ~path ~currentFile ~full
|> extractTypeFromExpr ~debug ~path ~currentFile ~full
~pos:(Pos.ofLexing expr.pexp_loc.loc_start)
with
| None -> ()
Expand All @@ -460,7 +646,7 @@ module ExhaustiveSwitch = struct
| Some (Switch {switchExpr; completionExpr; pos}) -> (
match
completionExpr
|> C.extractTypeFromExpr ~debug ~path ~currentFile ~full ~pos
|> extractTypeFromExpr ~debug ~path ~currentFile ~full ~pos
with
| None -> ()
| Some extractedType -> (
Expand Down Expand Up @@ -743,6 +929,8 @@ let extractCodeActions ~path ~startPos ~endPos ~currentFile ~debug =
match Cmt.loadFullCmtFromPath ~path with
| Some full ->
AddTypeAnnotation.xform ~path ~pos ~full ~structure ~codeActions ~debug;
ExpandCatchAllForVariants.xform ~path ~pos ~full ~structure ~codeActions
~currentFile ~debug;
ExhaustiveSwitch.xform ~printExpr ~path
~pos:
(if startPos = endPos then Single startPos
Expand Down
43 changes: 38 additions & 5 deletions analysis/tests/src/Xform.res
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
type kind = First | Second | Third
type kind = First | Second | Third | Fourth(int)
type r = {name: string, age: int}

let ret = _ => assert false
let kind = assert false
let ret = _ => assert(false)
let kind = assert(false)

if kind == First {
// ^xfm
Expand Down Expand Up @@ -63,7 +63,7 @@ let bar = () => {
}
//^xfm
}
@res.partial Inner.foo(1)
Inner.foo(1, ...)
}

module ExtractableModule = {
Expand All @@ -72,4 +72,37 @@ module ExtractableModule = {
// A comment here
let doStuff = a => a + 1
// ^xfm
}
}

let variant = First

let _x = switch variant {
| First => "first"
| _ => "other"
// ^xfm
}

let polyvariant: [#first | #second | #"illegal identifier" | #third(int)] = #first

let _y = switch polyvariant {
| #first => "first"
| _ => "other"
// ^xfm
}

let variantOpt = Some(variant)

let _x = switch variantOpt {
| Some(First) => "first"
| _ => "other"
// ^xfm
}

let polyvariantOpt = Some(polyvariant)

let _x = switch polyvariantOpt {
| Some(#first) => "first"
| None => "nothing"
| _ => "other"
// ^xfm
}
Loading
Loading