Skip to content

Commit a4b71fb

Browse files
committed
feat: infer function param type when overriding class function or method
1 parent 6614170 commit a4b71fb

File tree

2 files changed

+57
-15
lines changed

2 files changed

+57
-15
lines changed

script/vm/compiler.lua

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,26 +1117,56 @@ local function compileFunctionParam(func, source)
11171117
end
11181118
---@cast aindex integer
11191119

1120-
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
11211120
local funcNode = vm.compileNode(func)
1122-
local found = false
1123-
for n in funcNode:eachObject() do
1124-
if n.type == 'doc.type.function' and n.args[aindex] then
1125-
local argNode = vm.compileNode(n.args[aindex])
1126-
for an in argNode:eachObject() do
1127-
if an.type ~= 'doc.generic.name' then
1128-
vm.setNode(source, an)
1121+
if func.parent.type == 'callargs' then
1122+
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
1123+
for n in funcNode:eachObject() do
1124+
if n.type == 'doc.type.function' and n.args[aindex] then
1125+
local argNode = vm.compileNode(n.args[aindex])
1126+
for an in argNode:eachObject() do
1127+
if an.type ~= 'doc.generic.name' then
1128+
vm.setNode(source, an)
1129+
end
11291130
end
1130-
end
1131-
-- NOTE: keep existing behavior for local call which only set type based on the 1st match
1132-
if func.parent.type == 'callargs' then
1131+
-- NOTE: keep existing behavior for function as argument which only set type based on the 1st match
11331132
return true
11341133
end
1135-
found = true
11361134
end
1137-
end
1138-
if found then
1139-
return true
1135+
else
1136+
-- function declaration: use info from all `fun()`, also from the base function when overriding
1137+
--[[
1138+
---@type fun(x: string)|fun(x: number)
1139+
local function f1(x) end --> x -> string|number
1140+
1141+
---@overload fun(x: string)
1142+
---@overload fun(x: number)
1143+
local function f2(x) end --> x -> string|number
1144+
1145+
---@class A
1146+
local A = {}
1147+
---@param x number
1148+
function A:f(x) end --> x -> number
1149+
---@type A
1150+
local a = {}
1151+
function a:f(x) end --> x -> number
1152+
]]
1153+
local found = false
1154+
for n in funcNode:eachObject() do
1155+
if (n.type == 'doc.type.function' or n.type == 'function')
1156+
and n.args[aindex] and n.args[aindex] ~= source
1157+
then
1158+
local argNode = vm.compileNode(n.args[aindex])
1159+
for an in argNode:eachObject() do
1160+
if an.type ~= 'doc.generic.name' then
1161+
vm.setNode(source, an)
1162+
end
1163+
end
1164+
found = true
1165+
end
1166+
end
1167+
if found then
1168+
return true
1169+
end
11401170
end
11411171

11421172
local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType')

test/type_inference/common.lua

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4441,3 +4441,15 @@ local B = {}
44414441
44424442
function B:func(<?x?>) end
44434443
]]
4444+
4445+
TEST 'number' [[
4446+
---@class A
4447+
local A = {}
4448+
4449+
---@param x number
4450+
function A:func(x) end
4451+
4452+
---@type A
4453+
local a = {}
4454+
function a:func(<?x?>) end
4455+
]]

0 commit comments

Comments
 (0)