Skip to content

Commit 8c1669c

Browse files
committed
fix type infer in overload
1 parent ee590a4 commit 8c1669c

File tree

5 files changed

+96
-37
lines changed

5 files changed

+96
-37
lines changed

script/core/completion/completion.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1659,7 +1659,7 @@ local function tryCallArg(state, position, results)
16591659
return
16601660
end
16611661
---@diagnostic disable-next-line: missing-fields
1662-
local node = vm.compileCallArg({ type = 'dummyarg' }, call, argIndex)
1662+
local node = vm.compileCallArg({ type = 'dummyarg', uri = state.uri }, call, argIndex)
16631663
if not node then
16641664
return
16651665
end

script/vm/compiler.lua

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -882,52 +882,69 @@ local function compileCallArgNode(arg, call, callNode, fixIndex, myIndex)
882882
end
883883
end
884884

885-
for n in callNode:eachObject() do
886-
if n.type == 'function' then
887-
---@cast n parser.object
888-
local sign = vm.getSign(n)
885+
---@param n parser.object
886+
local function dealDocFunc(n)
887+
local myEvent
888+
if n.args[eventIndex] then
889+
local argNode = vm.compileNode(n.args[eventIndex])
890+
myEvent = argNode:get(1)
891+
end
892+
if not myEvent
893+
or not eventMap
894+
or myIndex <= eventIndex
895+
or myEvent.type ~= 'doc.type.string'
896+
or eventMap[myEvent[1]] then
889897
local farg = getFuncArg(n, myIndex)
890898
if farg then
891899
for fn in vm.compileNode(farg):eachObject() do
892900
if isValidCallArgNode(arg, fn) then
893-
if fn.type == 'doc.type.function' then
894-
---@cast fn parser.object
895-
if sign then
896-
local generic = vm.createGeneric(fn, sign)
897-
local args = {}
898-
for i = fixIndex + 1, myIndex - 1 do
899-
args[#args+1] = call.args[i]
900-
end
901-
local resolvedNode = generic:resolve(guide.getUri(call), args)
902-
vm.setNode(arg, resolvedNode)
903-
goto CONTINUE
904-
end
905-
end
906901
vm.setNode(arg, fn)
907-
::CONTINUE::
908902
end
909903
end
910904
end
911905
end
912-
if n.type == 'doc.type.function' then
913-
---@cast n parser.object
914-
local myEvent
915-
if n.args[eventIndex] then
916-
local argNode = vm.compileNode(n.args[eventIndex])
917-
myEvent = argNode:get(1)
918-
end
919-
if not myEvent
920-
or not eventMap
921-
or myIndex <= eventIndex
922-
or myEvent.type ~= 'doc.type.string'
923-
or eventMap[myEvent[1]] then
924-
local farg = getFuncArg(n, myIndex)
925-
if farg then
926-
for fn in vm.compileNode(farg):eachObject() do
927-
if isValidCallArgNode(arg, fn) then
928-
vm.setNode(arg, fn)
906+
end
907+
908+
---@param n parser.object
909+
local function dealFunction(n)
910+
local sign = vm.getSign(n)
911+
local farg = getFuncArg(n, myIndex)
912+
if farg then
913+
for fn in vm.compileNode(farg):eachObject() do
914+
if isValidCallArgNode(arg, fn) then
915+
if fn.type == 'doc.type.function' then
916+
---@cast fn parser.object
917+
if sign then
918+
local generic = vm.createGeneric(fn, sign)
919+
local args = {}
920+
for i = fixIndex + 1, myIndex - 1 do
921+
args[#args+1] = call.args[i]
922+
end
923+
local resolvedNode = generic:resolve(guide.getUri(call), args)
924+
vm.setNode(arg, resolvedNode)
925+
goto CONTINUE
929926
end
930927
end
928+
vm.setNode(arg, fn)
929+
::CONTINUE::
930+
end
931+
end
932+
end
933+
end
934+
935+
for n in callNode:eachObject() do
936+
if n.type == 'function' then
937+
---@cast n parser.object
938+
dealFunction(n)
939+
elseif n.type == 'doc.type.function' then
940+
---@cast n parser.object
941+
dealDocFunc(n)
942+
elseif n.type == 'global' and n.cate == 'type' then
943+
---@cast n vm.global
944+
local overloads = vm.getOverloadsByTypeName(n.name, guide.getUri(arg))
945+
if overloads then
946+
for _, func in ipairs(overloads) do
947+
dealDocFunc(func)
931948
end
932949
end
933950
end

script/vm/sign.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ function mt:resolve(uri, args)
254254
local argNode = vm.compileNode(arg)
255255
local knownTypes, genericNames = getSignInfo(sign)
256256
if not isAllResolved(genericNames) then
257-
local newArgNode = buildArgNode(argNode,sign, knownTypes)
257+
local newArgNode = buildArgNode(argNode, sign, knownTypes)
258258
resolve(sign, newArgNode)
259259
end
260260
end

script/vm/type.lua

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -767,3 +767,25 @@ function vm.viewTypeErrorMessage(uri, errs)
767767
return table.concat(lines, '\n')
768768
end
769769
end
770+
771+
---@param name string
772+
---@param uri uri
773+
---@return parser.object[]?
774+
function vm.getOverloadsByTypeName(name, uri)
775+
local global = vm.getGlobal('type', name)
776+
if not global then
777+
return nil
778+
end
779+
local results
780+
for _, set in ipairs(global:getSets(uri)) do
781+
for _, doc in ipairs(set.bindGroup) do
782+
if doc.type == 'doc.overload' then
783+
if not results then
784+
results = {}
785+
end
786+
results[#results+1] = doc.overload
787+
end
788+
end
789+
end
790+
return results
791+
end

test/completion/common.lua

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4393,3 +4393,23 @@ f {
43934393
kind = define.CompletionItemKind.Property,
43944394
},
43954395
}
4396+
4397+
TEST [[
4398+
---@class A
4399+
---@overload fun(x: {id: string})
4400+
4401+
---@generic T
4402+
---@param t `T`
4403+
---@return T
4404+
local function new(t) end
4405+
4406+
new 'A' {
4407+
<??>
4408+
}
4409+
]]
4410+
{
4411+
{
4412+
label = 'id',
4413+
kind = define.CompletionItemKind.Property,
4414+
}
4415+
}

0 commit comments

Comments
 (0)