Skip to content

Commit d1320ae

Browse files
authored
Merge pull request #2838 from tomlau10/fix/func_type_union_overload
Fix incorrect function params' type infer when there is only `@overload`
2 parents c9d8193 + fe84a71 commit d1320ae

File tree

4 files changed

+79
-1
lines changed

4 files changed

+79
-1
lines changed

changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
* `FIX` Fix `VM.OnCompileFunctionParam` function in plugins
99
* `FIX` Lua 5.1: fix incorrect warning when using setfenv with an int as first parameter
1010
* `FIX` Improve type narrow by checking exact match on literal type params
11+
* `FIX` Incorrect function params' type infer when there is only `@overload` [#2509](https://github.com/LuaLS/lua-language-server/issues/2509) [#2708](https://github.com/LuaLS/lua-language-server/issues/2708) [#2709](https://github.com/LuaLS/lua-language-server/issues/2709)
1112

1213
## 3.10.5
1314
`2024-8-19`

script/vm/compiler.lua

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,6 +1099,7 @@ local function compileFunctionParam(func, source)
10991099

11001100
-- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number
11011101
local funcNode = vm.compileNode(func)
1102+
local found = false
11021103
for n in funcNode:eachObject() do
11031104
if n.type == 'doc.type.function' and n.args[aindex] then
11041105
local argNode = vm.compileNode(n.args[aindex])
@@ -1107,9 +1108,16 @@ local function compileFunctionParam(func, source)
11071108
vm.setNode(source, an)
11081109
end
11091110
end
1110-
return true
1111+
-- NOTE: keep existing behavior for local call which only set type based on the 1st match
1112+
if func.parent.type == 'callargs' then
1113+
return true
1114+
end
1115+
found = true
11111116
end
11121117
end
1118+
if found then
1119+
return true
1120+
end
11131121

11141122
local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType')
11151123
if derviationParam and func.parent.type == 'local' and func.parent.ref then

script/vm/function.lua

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ end
359359
---@return number
360360
local function calcFunctionMatchScore(uri, args, func)
361361
if vm.isVarargFunctionWithOverloads(func)
362+
or vm.isFunctionWithOnlyOverloads(func)
362363
or not isAllParamMatched(uri, args, func.args)
363364
then
364365
return -1
@@ -490,6 +491,36 @@ function vm.isVarargFunctionWithOverloads(func)
490491
return false
491492
end
492493

494+
---@param func table
495+
---@return boolean
496+
function vm.isFunctionWithOnlyOverloads(func)
497+
if func.type ~= 'function' then
498+
return false
499+
end
500+
if func._onlyOverloadFunction ~= nil then
501+
return func._onlyOverloadFunction
502+
end
503+
504+
if not func.bindDocs then
505+
func._onlyOverloadFunction = false
506+
return false
507+
end
508+
local hasOverload = false
509+
for _, doc in ipairs(func.bindDocs) do
510+
if doc.type == 'doc.overload' then
511+
hasOverload = true
512+
elseif doc.type == 'doc.param'
513+
or doc.type == 'doc.return'
514+
then
515+
-- has specified @param or @return, thus not only @overload
516+
func._onlyOverloadFunction = false
517+
return false
518+
end
519+
end
520+
func._onlyOverloadFunction = hasOverload
521+
return true
522+
end
523+
493524
---@param func parser.object
494525
---@return boolean
495526
function vm.isEmptyFunction(func)

test/type_inference/param_match.lua

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,44 @@ local v = 'y'
172172
local <?r?> = f(v)
173173
]]
174174

175+
TEST 'string|number' [[
176+
---@overload fun(a: string)
177+
---@overload fun(a: number)
178+
local function f(<?a?>) end
179+
]]
180+
181+
TEST '1|2' [[
182+
---@overload fun(a: 1)
183+
---@overload fun(a: 2)
184+
local function f(<?a?>) end
185+
]]
186+
187+
TEST 'string' [[
188+
---@overload fun(a: 1): string
189+
---@overload fun(a: 2): number
190+
local function f(a) end
191+
192+
local <?r?> = f(1)
193+
]]
194+
195+
TEST 'number' [[
196+
---@overload fun(a: 1): string
197+
---@overload fun(a: 2): number
198+
local function f(a) end
199+
200+
local <?r?> = f(2)
201+
]]
202+
203+
TEST 'string|number' [[
204+
---@overload fun(a: 1): string
205+
---@overload fun(a: 2): number
206+
local function f(a) end
207+
208+
---@type number
209+
local v
210+
local <?r?> = f(v)
211+
]]
212+
175213
TEST 'number' [[
176214
---@overload fun(a: 1, c: fun(x: number))
177215
---@overload fun(a: 2, c: fun(x: string))

0 commit comments

Comments
 (0)