Skip to content

Commit 7c481f5

Browse files
committed
Infer the parameter types of a same-named function in the subclass based on the parameter types in the superclass function.
1 parent 1ea4c04 commit 7c481f5

File tree

3 files changed

+67
-17
lines changed

3 files changed

+67
-17
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
<!-- Add all new changes here. They will be moved under a version at release -->
55
* `NEW` Custom documentation exporter
66
* `NEW` Setting: `Lua.docScriptPath`: Path to a script that overrides `cli.doc.export`, allowing user-specified documentation exporting.
7+
* `NEW` Infer the parameter types of a same-named function in the subclass based on the parameter types in the superclass function.
78
* `FIX` Fix `VM.OnCompileFunctionParam` function in plugins
89
* `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter
910
* `FIX` Improve type narrow by checking exact match on literal type params

script/vm/compiler.lua

Lines changed: 53 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,21 +1088,26 @@ end
10881088
---@param func parser.object
10891089
---@param source parser.object
10901090
local function compileFunctionParam(func, source)
1091+
local aindex
1092+
for index, arg in ipairs(func.args) do
1093+
if arg == source then
1094+
aindex = index
1095+
break
1096+
end
1097+
end
1098+
---@cast aindex integer
1099+
10911100
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
10921101
local funcNode = vm.compileNode(func)
10931102
for n in funcNode:eachObject() do
10941103
if n.type == 'doc.type.function' then
1095-
for index, arg in ipairs(n.args) do
1096-
if func.args[index] == source then
1097-
local argNode = vm.compileNode(arg)
1098-
for an in argNode:eachObject() do
1099-
if an.type ~= 'doc.generic.name' then
1100-
vm.setNode(source, an)
1101-
end
1102-
end
1103-
return true
1104+
local argNode = vm.compileNode(n.args[aindex])
1105+
for an in argNode:eachObject() do
1106+
if an.type ~= 'doc.generic.name' then
1107+
vm.setNode(source, an)
11041108
end
11051109
end
1110+
return true
11061111
end
11071112
end
11081113

@@ -1118,19 +1123,50 @@ local function compileFunctionParam(func, source)
11181123
if not caller.args then
11191124
goto continue
11201125
end
1121-
for index, arg in ipairs(source.parent) do
1122-
if arg == source then
1123-
local callerArg = caller.args[index]
1124-
if callerArg then
1125-
vm.setNode(source, vm.compileNode(callerArg))
1126-
found = true
1127-
end
1128-
end
1126+
local callerArg = caller.args[aindex]
1127+
if callerArg then
1128+
vm.setNode(source, vm.compileNode(callerArg))
1129+
found = true
11291130
end
11301131
::continue::
11311132
end
11321133
return found
11331134
end
1135+
1136+
do
1137+
local parent = func.parent
1138+
local key = vm.getKeyName(parent)
1139+
local classDef = vm.getParentClass(parent)
1140+
local suri = guide.getUri(func)
1141+
if classDef and key then
1142+
local found
1143+
for _, set in ipairs(classDef:getSets(suri)) do
1144+
if set.type == 'doc.class' and set.extends then
1145+
for _, ext in ipairs(set.extends) do
1146+
local extClass = vm.getGlobal('type', ext[1])
1147+
if extClass then
1148+
vm.getClassFields(suri, extClass, key, function (field, isMark)
1149+
for n in vm.compileNode(field):eachObject() do
1150+
if n.type == 'function' then
1151+
local argNode = vm.compileNode(n.args[aindex])
1152+
for an in argNode:eachObject() do
1153+
if an.type ~= 'doc.generic.name' then
1154+
vm.setNode(source, an)
1155+
found = true
1156+
end
1157+
end
1158+
end
1159+
end
1160+
end)
1161+
end
1162+
end
1163+
end
1164+
end
1165+
if found then
1166+
return true
1167+
end
1168+
end
1169+
end
11341170
end
11351171

11361172
---@param source parser.object

test/type_inference/common.lua

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4428,3 +4428,16 @@ TEST 'A' [[
44284428
local x
44294429
local <?y?> = 1 >> x
44304430
]]
4431+
4432+
TEST 'number' [[
4433+
---@class A
4434+
local A = {}
4435+
4436+
---@param x number
4437+
function A:func(x) end
4438+
4439+
---@class B: A
4440+
local B = {}
4441+
4442+
function B:func(<?x?>) end
4443+
]]

0 commit comments

Comments
 (0)