diff --git a/changelog.md b/changelog.md index e82aa7687..1d4db462f 100644 --- a/changelog.md +++ b/changelog.md @@ -3,6 +3,7 @@ ## Unreleased * `NEW` Added support for Japanese locale +* `NEW` Infer function parameter types when overriding the same-named class function in an instance of that class [#2158](https://github.com/LuaLS/lua-language-server/issues/2158) * `FIX` Eliminate floating point error in test benchmark output * `FIX` Remove luamake install from make scripts diff --git a/script/core/diagnostics/duplicate-set-field.lua b/script/core/diagnostics/duplicate-set-field.lua index a4b205dda..2705c2d4b 100644 --- a/script/core/diagnostics/duplicate-set-field.lua +++ b/script/core/diagnostics/duplicate-set-field.lua @@ -68,6 +68,12 @@ return function (uri, callback) if not defValue or defValue.type ~= 'function' then goto CONTINUE end + if vm.getDefinedClass(guide.getUri(def), def.node) + and not vm.getDefinedClass(guide.getUri(src), src.node) + then + -- allow type variable to override function defined in class variable + goto CONTINUE + end callback { start = src.start, finish = src.finish, diff --git a/script/vm/compiler.lua b/script/vm/compiler.lua index a454731d9..8d6b2db76 100644 --- a/script/vm/compiler.lua +++ b/script/vm/compiler.lua @@ -1117,26 +1117,56 @@ local function compileFunctionParam(func, source) end ---@cast aindex integer - -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number local funcNode = vm.compileNode(func) - local found = false - for n in funcNode:eachObject() do - if n.type == 'doc.type.function' and n.args[aindex] then - local argNode = vm.compileNode(n.args[aindex]) - for an in argNode:eachObject() do - if an.type ~= 'doc.generic.name' then - vm.setNode(source, an) + if func.parent.type == 'callargs' then + -- local call ---@type fun(f: fun(x: number));call(function (x) end) --> x -> number + for n in funcNode:eachObject() do + if n.type == 'doc.type.function' and n.args[aindex] then + local argNode = vm.compileNode(n.args[aindex]) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + end end - end - -- NOTE: keep existing behavior for local call which only set type based on the 1st match - if func.parent.type == 'callargs' then + -- NOTE: keep existing behavior for function as argument which only set type based on the 1st match return true end - found = true end - end - if found then - return true + else + -- function declaration: use info from all `fun()`, also from the base function when overriding + --[[ + ---@type fun(x: string)|fun(x: number) + local function f1(x) end --> x -> string|number + + ---@overload fun(x: string) + ---@overload fun(x: number) + local function f2(x) end --> x -> string|number + + ---@class A + local A = {} + ---@param x number + function A:f(x) end --> x -> number + ---@type A + local a = {} + function a:f(x) end --> x -> number + ]] + local found = false + for n in funcNode:eachObject() do + if (n.type == 'doc.type.function' or n.type == 'function') + and n.args[aindex] and n.args[aindex] ~= source + then + local argNode = vm.compileNode(n.args[aindex]) + for an in argNode:eachObject() do + if an.type ~= 'doc.generic.name' then + vm.setNode(source, an) + end + end + found = true + end + end + if found then + return true + end end local derviationParam = config.get(guide.getUri(func), 'Lua.type.inferParamType') diff --git a/test/diagnostics/duplicate-set-field.lua b/test/diagnostics/duplicate-set-field.lua index 469bc3eaa..7b92b061c 100644 --- a/test/diagnostics/duplicate-set-field.lua +++ b/test/diagnostics/duplicate-set-field.lua @@ -72,3 +72,29 @@ else function X.f() end end ]] + +TEST [[ +---@class A +X = {} + +function X:f() end + +---@type x +local x + +function x:f() end +]] + +TEST [[ +---@class A +X = {} + +function X:f() end + +---@type x +local x + +function () end + +function () end +]] diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index 11fa39b8a..969f71126 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -4441,3 +4441,15 @@ local B = {} function B:func() end ]] + +TEST 'number' [[ +---@class A +local A = {} + +---@param x number +function A:func(x) end + +---@type A +local a = {} +function a:func() end +]]