diff --git a/changelog.md b/changelog.md index 1d4db462f..e41293f60 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,7 @@ * `NEW` Infer function parameter types when overriding the same-named class function in an instance of that class [#2158](https://github.com/LuaLS/lua-language-server/issues/2158) * `FIX` Eliminate floating point error in test benchmark output * `FIX` Remove luamake install from make scripts +* `NEW` Types with literal fields can be narrowed. ## 3.10.6 `2024-9-10` diff --git a/script/vm/tracer.lua b/script/vm/tracer.lua index e47a98245..ef94cc9e8 100644 --- a/script/vm/tracer.lua +++ b/script/vm/tracer.lua @@ -256,6 +256,44 @@ function mt:fastWardCasts(pos, node) return node end +--- Return types of source which have a field with the value of literal. +--- @param uri uri +--- @param source parser.object +--- @param fieldName string +--- @param literal parser.object +--- @return string[]? +local function getNodeTypesWithLiteralField(uri, source, fieldName, literal) + local loc = vm.getVariable(source) + if not loc then + return + end + + local tys + + for _, c in ipairs(vm.compileNode(loc)) do + if c.cate == 'type' then + for _, set in ipairs(c:getSets(uri)) do + if set.type == 'doc.class' then + for _, f in ipairs(set.fields) do + if f.field[1] == fieldName then + for _, t in ipairs(f.extends.types) do + if t[1] == literal[1] then + tys = tys or {} + table.insert(tys, set.class[1]) + break + end + end + break + end + end + end + end + end + end + + return tys +end + local lookIntoChild = util.switch() : case 'getlocal' : case 'getglobal' @@ -637,6 +675,27 @@ local lookIntoChild = util.switch() end end end + elseif handler.type == 'getfield' + and handler.node.type == 'getlocal' then + local tys = getNodeTypesWithLiteralField( + tracer.uri, handler.node, handler.field[1], checker) + + -- TODO: handle more types + if tys and #tys == 1 then + local ty = tys[1] + topNode = topNode:copy() + if action.op.type == '==' then + topNode:narrow(tracer.uri, ty) + if outNode then + outNode:remove(ty) + end + else + topNode:remove(ty) + if outNode then + outNode:narrow(tracer.uri, ty) + end + end + end elseif handler.type == 'call' and checker.type == 'string' and handler.node.special == 'type' diff --git a/test/type_inference/common.lua b/test/type_inference/common.lua index 969f71126..f7eda03ff 100644 --- a/test/type_inference/common.lua +++ b/test/type_inference/common.lua @@ -4453,3 +4453,53 @@ function A:func(x) end local a = {} function a:func() end ]] + +TEST 'A' [[ +---@class A +---@field type 'a' +---@field field1 integer + +---@class B +---@field type 'b' + +local obj --- @type A|B + +if obj.type == 'a' and obj.field1 > 0 then + local = obj +end +]] + +TEST 'B' [[ +---@class A +---@field type 'a' + +---@class B +---@field type 'b' + +local obj --- @type A|B + +if obj.type == 'a' then + --- +else + local = obj +end +]] + +TEST 'A' [[ +---@class A +---@field type 'a' + +---@class B +---@field type 'b' + +---@class C +---@field type 'c' + +---@alias AB A|B + +local obj --- @type C|AB + +if obj.type == 'a' then + local = obj +end +]]