This commit is contained in:
2025-11-18 15:15:04 -05:00
parent 81a6c323d6
commit 488bbd20d2
1641 changed files with 927512 additions and 14 deletions

View File

@@ -0,0 +1,60 @@
local byte = string.byte
local max = 0x7fffffff
---@class SDBMHash
local mt = {}
mt.__index = mt
mt.cache = nil
---@param str string
---@return integer
function mt:rawHash(str)
local id = 0
for i = 1, #str do
local b = byte(str, i, i)
id = id * 65599 + b
end
return id & max
end
---@param str string
---@return integer
function mt:hash(str)
local id = self:rawHash(str)
local other = self.cache[id]
if other == nil or str == other then
self.cache[id] = str
self.cache[str] = id
return id
else
log.warn(('哈希碰撞:[%s] -> [%s]: [%d]'):format(str, other, id))
for i = 1, max do
local newId = (id + i) % max
if not self.cache[newId] then
self.cache[newId] = str
self.cache[str] = newId
return newId
end
end
error(('哈希碰撞解决失败:[%s] -> [%s]: [%d]'):format(str, other, id))
end
end
function mt:setCache(t)
self.cache = t
end
function mt:getCache()
return self.cache
end
mt.__call = mt.hash
---@return SDBMHash
return function ()
local self = setmetatable({
cache = {}
}, mt)
return self
end

View File

@@ -0,0 +1,267 @@
local timer = require 'timer'
local wkmt = { __mode = 'k' }
---@class await
local m = {}
m.type = 'await'
m.coMap = setmetatable({}, wkmt)
m.idMap = {}
m.delayQueue = {}
m.delayQueueIndex = 1
m.needClose = {}
m._enable = true
local function setID(id, co, callback)
if not coroutine.isyieldable(co) then
return
end
if not m.idMap[id] then
m.idMap[id] = setmetatable({}, wkmt)
end
m.idMap[id][co] = callback or true
end
--- 设置错误处理器
---@param errHandle function # 当有错误发生时,会以错误堆栈为参数调用该函数
function m.setErrorHandle(errHandle)
m.errorHandle = errHandle
end
function m.checkResult(co, ...)
local suc, err = ...
if not suc and m.errorHandle then
m.errorHandle(debug.traceback(co, err))
end
return ...
end
--- 创建一个任务
---@param callback async fun()
function m.call(callback, ...)
local co = coroutine.create(callback)
local closers = {}
m.coMap[co] = {
closers = closers,
priority = false,
}
for i = 1, select('#', ...) do
local id = select(i, ...)
if not id then
break
end
setID(id, co)
end
local currentCo = coroutine.running()
local current = m.coMap[currentCo]
if current then
for closer in pairs(current.closers) do
closers[closer] = true
closer(co)
end
end
return m.checkResult(co, coroutine.resume(co))
end
--- 创建一个任务,并挂起当前线程,当任务完成后再延续当前线程/若任务被关闭则返回nil
---@async
function m.await(callback, ...)
if not coroutine.isyieldable() then
return callback(...)
end
return m.wait(function (resume, ...)
m.call(function ()
local returnNil <close> = resume
resume(callback())
end, ...)
end, ...)
end
--- 设置一个id用于批量关闭任务
function m.setID(id, callback)
local co = coroutine.running()
setID(id, co, callback)
end
--- 根据id批量关闭任务
function m.close(id)
local map = m.idMap[id]
if not map then
return
end
m.idMap[id] = nil
for co, callback in pairs(map) do
if coroutine.status(co) == 'suspended' then
map[co] = nil
if type(callback) == 'function' then
xpcall(callback, log.error)
end
coroutine.close(co)
end
end
end
function m.hasID(id, co)
co = co or coroutine.running()
return m.idMap[id] and m.idMap[id][co] ~= nil
end
function m.unique(id, callback)
m.close(id)
m.setID(id, callback)
end
--- 休眠一段时间
---@param time number
---@async
function m.sleep(time)
if not coroutine.isyieldable() then
if m.errorHandle then
m.errorHandle(debug.traceback('Cannot yield'))
end
return
end
local co = coroutine.running()
timer.wait(time, function ()
if coroutine.status(co) ~= 'suspended' then
return
end
return m.checkResult(co, coroutine.resume(co))
end)
return coroutine.yield()
end
--- 等待直到唤醒
---@param callback function
---@async
function m.wait(callback, ...)
local co = coroutine.running()
local resumed
callback(function (...)
if resumed then
return
end
resumed = true
if coroutine.status(co) ~= 'suspended' then
return
end
return m.checkResult(co, coroutine.resume(co, ...))
end, ...)
return coroutine.yield()
end
--- 延迟
---@async
function m.delay()
if not m._enable then
return
end
if not coroutine.isyieldable() then
return
end
local co = coroutine.running()
local current = m.coMap[co]
-- TODO
if current.priority then
return
end
m.delayQueue[#m.delayQueue+1] = function ()
if coroutine.status(co) ~= 'suspended' then
return
end
return m.checkResult(co, coroutine.resume(co))
end
return coroutine.yield()
end
local throttledDelayer = {}
throttledDelayer.__index = throttledDelayer
---@async
function throttledDelayer:delay()
if not m._enable then
return
end
self.calls = self.calls + 1
if self.calls == self.factor then
self.calls = 0
return m.delay()
end
end
function m.newThrottledDelayer(factor)
return setmetatable({
factor = factor,
calls = 0,
}, throttledDelayer)
end
--- stop then close
---@async
function m.stop()
if not coroutine.isyieldable() then
return
end
m.needClose[#m.needClose+1] = coroutine.running()
coroutine.yield()
end
local function warnStepTime(passed, waker)
if passed < 2 then
log.warn(('Await step takes [%.3f] sec.'):format(passed))
return
end
for i = 1, 100 do
local name, v = debug.getupvalue(waker, i)
if not name then
return
end
if name == 'co' then
log.warn(debug.traceback(v, ('[fire]Await step takes [%.3f] sec.'):format(passed)))
return
end
end
end
--- 步进
function m.step()
for i = #m.needClose, 1, -1 do
coroutine.close(m.needClose[i])
m.needClose[i] = nil
end
local resume = m.delayQueue[m.delayQueueIndex]
if resume then
m.delayQueue[m.delayQueueIndex] = false
m.delayQueueIndex = m.delayQueueIndex + 1
local clock = os.clock()
resume()
local passed = os.clock() - clock
if passed > 0.5 then
warnStepTime(passed, resume)
end
return true
else
for i = 1, #m.delayQueue do
m.delayQueue[i] = nil
end
m.delayQueueIndex = 1
return false
end
end
function m.setPriority(n)
m.coMap[coroutine.running()].priority = true
end
function m.enable()
m._enable = true
end
function m.disable()
m._enable = false
end
return m

View File

@@ -0,0 +1,68 @@
local thread = require 'bee.thread'
local taskPad = thread.channel('taskpad')
local waiter = thread.channel('waiter')
---@class pub_brave
local m = {}
m.type = 'brave'
m.ability = {}
m.queue = {}
--- 注册成为勇者
function m.register(id, privatePad)
m.id = id
if #m.queue > 0 then
for _, info in ipairs(m.queue) do
waiter:push(m.id, info.name, info.params)
end
end
m.queue = nil
m.start(privatePad)
end
--- 注册能力
function m.on(name, callback)
m.ability[name] = callback
end
--- 报告
function m.push(name, params)
if m.id then
waiter:push(m.id, name, params)
else
m.queue[#m.queue+1] = {
name = name,
params = params,
}
end
end
--- 开始找工作
function m.start(privatePad)
local reqPad = privatePad and thread.channel('req:' .. privatePad) or taskPad
local resPad = privatePad and thread.channel('res:' .. privatePad) or waiter
m.push('mem', collectgarbage 'count')
while true do
local name, id, params = reqPad:bpop()
local ability = m.ability[name]
-- TODO
if not ability then
resPad:push(m.id, id)
log.error('Brave can not handle this work: ' .. name)
goto CONTINUE
end
local ok, res = xpcall(ability, log.error, params)
if ok then
resPad:push(m.id, id, res)
else
resPad:push(m.id, id)
end
m.push('mem', collectgarbage 'count')
::CONTINUE::
end
end
return m

View File

@@ -0,0 +1,4 @@
local brave = require 'brave.brave'
require 'brave.work'
return brave

View File

@@ -0,0 +1,55 @@
local brave = require 'brave'
local time = require 'bee.time'
local tablePack = table.pack
local tostring = tostring
local tableConcat = table.concat
local debugTraceBack = debug.traceback
local debugGetInfo = debug.getinfo
local monotonic = time.monotonic
_ENV = nil
local function pushLog(level, ...)
local t = tablePack(...)
for i = 1, t.n do
t[i] = tostring(t[i])
end
local str = tableConcat(t, '\t', 1, t.n)
if level == 'error' then
str = str .. '\n' .. debugTraceBack(nil, 3)
end
local info = debugGetInfo(3, 'Sl')
brave.push('log', {
level = level,
msg = str,
src = info.source,
line = info.currentline,
clock = monotonic(),
})
return str
end
local m = {}
function m.info(...)
pushLog('info', ...)
end
function m.debug(...)
pushLog('debug', ...)
end
function m.trace(...)
pushLog('trace', ...)
end
function m.warn(...)
pushLog('warn', ...)
end
function m.error(...)
pushLog('error', ...)
end
return m

View File

@@ -0,0 +1,125 @@
local brave = require 'brave.brave'
brave.on('loadProtoByStdio', function ()
local jsonrpc = require 'jsonrpc'
while true do
local proto, err = jsonrpc.decode(io.read)
--log.debug('loaded proto', proto.method)
if not proto then
brave.push('protoerror', err)
return
end
brave.push('proto', proto)
end
end)
brave.on('loadProtoBySocket', function (param)
local jsonrpc = require 'jsonrpc'
local net = require 'service.net'
local buf = ''
---@async
local parser = coroutine.create(function ()
while true do
---@async
local proto, err = jsonrpc.decode(function (len)
while true do
if #buf >= len then
local res = buf:sub(1, len)
buf = buf:sub(len + 1)
return res
end
coroutine.yield()
end
end)
--log.debug('loaded proto', proto.method)
if not proto then
brave.push('protoerror', err)
return
end
brave.push('proto', proto)
end
end)
local lsclient = net.connect('tcp', '127.0.0.1', param.port)
local lsmaster = net.connect('unix', param.unixPath)
assert(lsclient)
assert(lsmaster)
function lsclient:on_data(data)
buf = buf .. data
coroutine.resume(parser)
end
function lsclient:on_error(...)
log.error(...)
end
function lsmaster:on_data(data)
lsclient:write(data)
--net.update()
end
function lsmaster:on_error(...)
log.error(...)
end
while true do
net.update(10)
end
end)
brave.on('timer', function (time)
local thread = require 'bee.thread'
while true do
thread.sleep(math.floor(time * 1000))
brave.push('wakeup')
end
end)
brave.on('loadFile', function (path)
local util = require 'utility'
return util.loadFile(path)
end)
brave.on('removeCaches', function (path)
local fs = require 'bee.filesystem'
local fsu = require 'fs-utility'
for dir in fs.pairs(fs.path(path)) do
local lockFile = dir / '.lock'
local f = io.open(lockFile:string(), 'wb')
if f then
f:close()
fsu.fileRemove(dir)
end
end
end)
---@class brave.param.compile
---@field uri uri
---@field text string
---@field mode string
---@field version string
---@field options brave.param.compile.options
---@class brave.param.compile.options
---@field special table<string, string>
---@field unicodeName boolean
---@field nonstandardSymbol table<string, true>
---@param param brave.param.compile
brave.on('compile', function (param)
local parser = require 'parser'
local clock = os.clock()
local state, err = parser.compile(param.text
, param.mode
, param.version
, param.options
)
log.debug('Async compile', param.uri, 'takes:', os.clock() - clock)
return {
state = state,
err = err,
}
end)

View File

@@ -0,0 +1,119 @@
local lang = require 'language'
local platform = require 'bee.platform'
local subprocess = require 'bee.subprocess'
local json = require 'json'
local jsonb = require 'json-beautify'
local util = require 'utility'
local export = {}
local function logFileForThread(threadId)
return LOGPATH .. '/check-partial-' .. threadId .. '.json'
end
local function buildArgs(exe, numThreads, threadId, format, quiet)
local args = {exe}
local skipNext = false
for i = 1, #arg do
local arg = arg[i]
-- --check needs to be transformed into --check_worker
if arg:lower():match('^%-%-check$') or arg:lower():match('^%-%-check=') then
args[#args + 1] = arg:gsub('%-%-%w*', '--check_worker')
-- --check_out_path needs to be removed if we have more than one thread
elseif arg:lower():match('%-%-check_out_path') and numThreads > 1 then
if not arg:match('%-%-[%w_]*=') then
skipNext = true
end
else
if skipNext then
skipNext = false
else
args[#args + 1] = arg
end
end
end
args[#args + 1] = '--thread_id'
args[#args + 1] = tostring(threadId)
if numThreads > 1 then
if quiet then
args[#args + 1] = '--quiet'
end
if format then
args[#args + 1] = '--check_format=' .. format
end
args[#args + 1] = '--check_out_path'
args[#args + 1] = logFileForThread(threadId)
end
return args
end
function export.runCLI()
local numThreads = tonumber(NUM_THREADS or 1)
local exe
local minIndex = -1
while arg[minIndex] do
exe = arg[minIndex]
minIndex = minIndex - 1
end
-- TODO: is this necessary? got it from the shell.lua helper in bee.lua tests
if platform.os == 'windows' and not exe:match('%.[eE][xX][eE]$') then
exe = exe..'.exe'
end
if not QUIET and numThreads > 1 then
print(lang.script('CLI_CHECK_MULTIPLE_WORKERS', numThreads))
end
local procs = {}
for i = 1, numThreads do
local process, err = subprocess.spawn({buildArgs(exe, numThreads, i, CHECK_FORMAT, QUIET)})
if err then
print(err)
end
if process then
procs[#procs + 1] = process
end
end
local checkPassed = true
for _, process in ipairs(procs) do
checkPassed = process:wait() == 0 and checkPassed
end
if numThreads > 1 then
local mergedResults = {}
local count = 0
for i = 1, numThreads do
local result = json.decode(util.loadFile(logFileForThread(i)) or '[]')
for k, v in pairs(result) do
local entries = mergedResults[k] or {}
mergedResults[k] = entries
for _, entry in ipairs(v) do
entries[#entries + 1] = entry
count = count + 1
end
end
end
local outpath = nil
if CHECK_FORMAT == 'json' or CHECK_OUT_PATH then
outpath = CHECK_OUT_PATH or LOGPATH .. '/check.json'
util.saveFile(outpath, jsonb.beautify(mergedResults))
end
if not QUIET then
if count == 0 then
print(lang.script('CLI_CHECK_SUCCESS'))
elseif outpath then
print(lang.script('CLI_CHECK_RESULTS_OUTPATH', count, outpath))
else
print(lang.script('CLI_CHECK_RESULTS_PRETTY', count))
end
end
end
return checkPassed and 0 or 1
end
return export

View File

@@ -0,0 +1,295 @@
local lclient = require 'lclient'()
local furi = require 'file-uri'
local ws = require 'workspace'
local files = require 'files'
local diag = require 'provider.diagnostic'
local util = require 'utility'
local jsonb = require 'json-beautify'
local lang = require 'language'
local define = require 'proto.define'
local protoDiag = require 'proto.diagnostic'
local config = require 'config.config'
local fs = require 'bee.filesystem'
local provider = require 'provider'
local await = require 'await'
require 'plugin'
require 'vm'
local export = {}
local colors
if not os.getenv('NO_COLOR') then
colors = {
red = '\27[31m',
green = '\27[32m',
yellow = '\27[33m',
blue = '\27[34m',
magenta = '\27[35m',
white = '\27[37m',
grey = '\27[90m',
reset = '\27[0m'
}
else
colors = {
red = '',
green = '',
yellow = '',
blue = '',
magenta = '',
white = '',
grey = '',
reset = ''
}
end
--- @type table<DiagnosticSeverity, string>
local severity_colors = {
Error = colors.red,
Warning = colors.yellow,
Information = colors.white,
Hint = colors.white,
}
local severity_str = {} --- @type table<integer,DiagnosticSeverity>
for k, v in pairs(define.DiagnosticSeverity) do
severity_str[v] = k
end
local pwd
---@param path string
---@return string
local function relpath(path)
if not pwd then
pwd = furi.decode(furi.encode(fs.current_path():string()))
end
if pwd and path:sub(1, #pwd) == pwd then
path = path:sub(#pwd + 2)
end
return path
end
local function report_pretty(uri, diags)
local path = relpath(furi.decode(uri))
local lines = {} --- @type string[]
pcall(function()
for line in io.lines(path) do
table.insert(lines, line)
end
end)
for _, d in ipairs(diags) do
local rstart = d.range.start
local rend = d.range['end']
local severity = severity_str[d.severity]
print(
('%s%s:%s:%s%s [%s%s%s] %s %s(%s)%s'):format(
colors.blue,
path,
rstart.line + 1, -- Use 1-based indexing
rstart.character + 1, -- Use 1-based indexing
colors.reset,
severity_colors[severity],
severity,
colors.reset,
d.message,
colors.magenta,
d.code,
colors.reset
)
)
if #lines > 0 then
io.write(' ', lines[rstart.line + 1], '\n')
io.write(' ', colors.grey, (' '):rep(rstart.character), '^')
if rstart.line == rend.line then
io.write(('^'):rep(rend.character - rstart.character - 1))
end
io.write(colors.reset, '\n')
end
end
end
local function clear_line()
-- Write out empty space to ensure that the previous lien is cleared.
io.write('\x0D', (' '):rep(80), '\x0D')
end
--- @param i integer
--- @param max integer
--- @param results table<string, table[]>
local function report_progress(i, max, results)
local filesWithErrors = 0
local errors = 0
for _, diags in pairs(results) do
filesWithErrors = filesWithErrors + 1
errors = errors + #diags
end
clear_line()
io.write(
('>'):rep(math.ceil(i / max * 20)),
('='):rep(20 - math.ceil(i / max * 20)),
' ',
('0'):rep(#tostring(max) - #tostring(i)),
tostring(i),
'/',
tostring(max)
)
if errors > 0 then
io.write(' [', lang.script('CLI_CHECK_PROGRESS', errors, filesWithErrors), ']')
end
io.flush()
end
--- @param uri string
--- @param checkLevel integer
local function apply_check_level(uri, checkLevel)
local config_disables = util.arrayToHash(config.get(uri, 'Lua.diagnostics.disable'))
local config_severities = config.get(uri, 'Lua.diagnostics.severity')
for name, serverity in pairs(define.DiagnosticDefaultSeverity) do
serverity = config_severities[name] or serverity
if serverity:sub(-1) == '!' then
serverity = serverity:sub(1, -2)
end
if define.DiagnosticSeverity[serverity] > checkLevel then
config_disables[name] = true
end
end
config.set(uri, 'Lua.diagnostics.disable', util.getTableKeys(config_disables, true))
end
local function downgrade_checks_to_opened(uri)
local diagStatus = config.get(uri, 'Lua.diagnostics.neededFileStatus')
for d, status in pairs(diagStatus) do
if status == 'Any' or status == 'Any!' then
diagStatus[d] = 'Opened!'
end
end
for d, status in pairs(protoDiag.getDefaultStatus()) do
if status == 'Any' or status == 'Any!' then
diagStatus[d] = 'Opened!'
end
end
config.set(uri, 'Lua.diagnostics.neededFileStatus', diagStatus)
end
function export.runCLI()
lang(LOCALE)
local numThreads = tonumber(NUM_THREADS or 1)
local threadId = tonumber(THREAD_ID or 1)
local quiet = QUIET or numThreads > 1
if type(CHECK_WORKER) ~= 'string' then
print(lang.script('CLI_CHECK_ERROR_TYPE', type(CHECK_WORKER)))
return
end
local rootPath = fs.canonical(fs.path(CHECK_WORKER)):string()
local rootUri = furi.encode(rootPath)
if not rootUri then
print(lang.script('CLI_CHECK_ERROR_URI', rootPath))
return
end
rootUri = rootUri:gsub("/$", "")
if CHECKLEVEL and not define.DiagnosticSeverity[CHECKLEVEL] then
print(lang.script('CLI_CHECK_ERROR_LEVEL', 'Error, Warning, Information, Hint'))
return
end
local checkLevel = define.DiagnosticSeverity[CHECKLEVEL] or define.DiagnosticSeverity.Warning
util.enableCloseFunction()
local lastClock = os.clock()
local results = {} --- @type table<string, table[]>
local function errorhandler(err)
print(err)
print(debug.traceback())
end
---@async
xpcall(lclient.start, errorhandler, lclient, function (client)
await.disable()
client:registerFakers()
client:initialize {
rootUri = rootUri,
}
client:register('textDocument/publishDiagnostics', function (params)
results[params.uri] = params.diagnostics
if not QUIET and (CHECK_FORMAT == nil or CHECK_FORMAT == 'pretty') then
clear_line()
report_pretty(params.uri, params.diagnostics)
end
end)
if not quiet then
io.write(lang.script('CLI_CHECK_INITING'))
end
provider.updateConfig(rootUri)
ws.awaitReady(rootUri)
-- Disable any diagnostics that are above the check level
apply_check_level(rootUri, checkLevel)
-- Downgrade file opened status to Opened for everything to avoid
-- reporting during compilation on files that do not belong to this thread
downgrade_checks_to_opened(rootUri)
local uris = files.getChildFiles(rootUri)
local max = #uris
table.sort(uris) -- sort file list to ensure the work distribution order across multiple threads
for i, uri in ipairs(uris) do
if (i % numThreads + 1) == threadId and not ws.isIgnored(uri) then
files.open(uri)
diag.doDiagnostic(uri, true)
-- Print regularly but always print the last entry to ensure
-- that logs written to files don't look incomplete.
if not quiet and (os.clock() - lastClock > 0.2 or i == #uris) then
lastClock = os.clock()
client:update()
report_progress(i, max, results)
end
end
end
if not quiet then
clear_line()
end
end)
local count = 0
for uri, result in pairs(results) do
count = count + #result
if #result == 0 then
results[uri] = nil
end
end
local outpath = nil
if CHECK_FORMAT == 'json' or CHECK_OUT_PATH then
outpath = CHECK_OUT_PATH or LOGPATH .. '/check.json'
-- Always write result, even if it's empty to make sure no one accidentally looks at an old output after a successful run.
util.saveFile(outpath, jsonb.beautify(results))
end
if not quiet then
if count == 0 then
print(lang.script('CLI_CHECK_SUCCESS'))
elseif outpath then
print(lang.script('CLI_CHECK_RESULTS_OUTPATH', count, outpath))
else
print(lang.script('CLI_CHECK_RESULTS_PRETTY', count))
end
end
return count == 0 and 0 or 1
end
return export

View File

@@ -0,0 +1,362 @@
---@diagnostic disable: await-in-sync, param-type-mismatch
local ws = require 'workspace'
local vm = require 'vm'
local guide = require 'parser.guide'
local getDesc = require 'core.hover.description'
local getLabel = require 'core.hover.label'
local jsonb = require 'json-beautify'
local util = require 'utility'
local markdown = require 'provider.markdown'
local fs = require 'bee.filesystem'
local furi = require 'file-uri'
---@alias doctype
---| 'doc.alias'
---| 'doc.class'
---| 'doc.field'
---| 'doc.field.name'
---| 'doc.type.arg.name'
---| 'doc.type.function'
---| 'doc.type.table'
---| 'funcargs'
---| 'function'
---| 'function.return'
---| 'global.type'
---| 'global.variable'
---| 'local'
---| 'luals.config'
---| 'self'
---| 'setfield'
---| 'setglobal'
---| 'setindex'
---| 'setmethod'
---| 'tableindex'
---| 'type'
---@class docUnion broadest possible collection of exported docs, these are never all together.
---@field [1] string in name when table, always the same as view
---@field args docUnion[] list of argument docs passed to function
---@field async boolean has @async tag
---@field defines docUnion[] list of places where this is doc is defined and how its defined there
---@field deprecated boolean has @deprecated tag
---@field desc string code commentary
---@field extends string | docUnion ? what type this 'is'. string:<Parent_Class> for type: 'type', docUnion for type: 'function', string<primative> for other type 's
---@field fields docUnion[] class's fields
---@field file string path to where this token is defined
---@field finish [integer, integer] 0-indexed [line, column] position of end of token
---@field name string canonical name
---@field rawdesc string same as desc, but may have other things for types doc.retun andr doc.param (unused?)
---@field returns docUnion | docUnion[] list of docs for return values. if singluar, then always {type: 'undefined'}? might be a bug.
---@field start [integer, integer] 0-indexed [line, column] position of start of token
---@field type doctype role that this token plays in documentation. different from the 'type'/'class' this token is
---@field types docUnion[] type union? unclear. seems to be related to alias, maybe
---@field view string full method name, class, basal type, or unknown. in name table same as [1]
---@field visible 'package'|'private'|'protected'|'public' visibilty tag
local export = {}
function export.getLocalPath(uri)
local file_canonical = fs.canonical(furi.decode(uri)):string()
local doc_canonical = fs.canonical(DOC):string()
local relativePath = fs.relative(file_canonical, doc_canonical):string()
if relativePath == "" or relativePath:sub(1, 2) == '..' then
-- not under project directory
return '[FOREIGN] ' .. file_canonical
end
return relativePath
end
function export.positionOf(rowcol)
return type(rowcol) == 'table' and guide.positionOf(rowcol[1], rowcol[2]) or -1
end
function export.sortDoc(a,b)
if a.name ~= b.name then
return a.name < b.name
end
if a.file ~= b.file then
return a.file < b.file
end
return export.positionOf(a.start) < export.positionOf(b.start)
end
--- recursively generate documentation all parser objects downstream of `source`
---@async
---@param source parser.object | vm.global
---@param has_seen table? keeps track of visited nodes in documentation tree
---@return docUnion | [docUnion] | string | number | boolean | nil
function export.documentObject(source, has_seen)
--is this a primative type? then we dont need to process it.
if type(source) ~= 'table' then return source end
--set up/check recursion
if not has_seen then has_seen = {} end
if has_seen[source] then
return nil
end
has_seen[source] = true
--is this an array type? then process each array item and collect it
if (#source > 0 and next(source, #source) == nil) then
local objs = {} --make a pure numerical array
for i, child in ipairs(source) do
objs[i] = export.documentObject(child, has_seen)
end
return objs
end
--if neither, then this is a singular docUnion
local obj = export.makeDocObject['INIT'](source, has_seen)
--check if this source has a type (no type sources are usually autogen'd anon functions's return values that are not explicitly stated)
if not obj.type then return obj end
local res = export.makeDocObject[obj.type](source, obj, has_seen)
if res == false then
return nil
end
return res or obj
end
---Switch statement table. functions can be overriden by user file.
---@table
export.makeDocObject = setmetatable({}, {__index = function(t, k)
return function()
--print('DocError: no type "'..k..'"')
end
end})
export.makeDocObject['INIT'] = function(source, has_seen)
---@as docUnion
local ok, desc = pcall(getDesc, source)
local rawok, rawdesc = pcall(getDesc, source, true)
return {
type = source.cate or source.type,
name = export.documentObject((source.getCodeName and source:getCodeName()) or source.name, has_seen),
start = source.start and {guide.rowColOf(source.start)},
finish = source.finish and {guide.rowColOf(source.finish)},
types = export.documentObject(source.types, has_seen),
view = vm.getInfer(source):view(ws.rootUri),
desc = ok and desc or nil,
rawdesc = rawok and rawdesc or nil,
}
end
export.makeDocObject['doc.alias'] = function(source, obj, has_seen)
end
export.makeDocObject['doc.field'] = function(source, obj, has_seen)
if source.field.type == 'doc.field.name' then
obj.name = source.field[1]
else
obj.name = ('[%s]'):format(vm.getInfer(source.field):view(ws.rootUri))
end
obj.file = export.getLocalPath(guide.getUri(source))
obj.extends = source.extends and export.documentObject(source.extends, has_seen) --check if bug?
obj.async = vm.isAsync(source, true) and true or false --if vm.isAsync(set, true) then result.defines[#result.defines].extends['async'] = true end
obj.deprecated = vm.getDeprecated(source) and true or false -- if (depr and not depr.versions) the result.defines[#result.defines].extends['deprecated'] = true end
obj.visible = vm.getVisibleType(source)
end
export.makeDocObject['doc.class'] = function(source, obj, has_seen)
local extends = source.extends or source.value --doc.class or other
local field = source.field or source.method
obj.name = type(field) == 'table' and field[1] or nil
obj.file = export.getLocalPath(guide.getUri(source))
obj.extends = extends and export.documentObject(extends, has_seen)
obj.async = vm.isAsync(source, true) and true or false
obj.deprecated = vm.getDeprecated(source) and true or false
obj.visible = vm.getVisibleType(source)
end
export.makeDocObject['doc.field.name'] = function(source, obj, has_seen)
obj['[1]'] = export.documentObject(source[1], has_seen)
obj.view = source[1]
end
export.makeDocObject['doc.type.arg.name'] = export.makeDocObject['doc.field.name']
export.makeDocObject['doc.type.function'] = function(source, obj, has_seen)
obj.args = export.documentObject(source.args, has_seen)
obj.returns = export.documentObject(source.returns, has_seen)
end
export.makeDocObject['doc.type.table'] = function(source, obj, has_seen)
obj.fields = export.documentObject(source.fields, has_seen)
end
export.makeDocObject['funcargs'] = function(source, obj, has_seen)
local objs = {} --make a pure numerical array
for i, child in ipairs(source) do
objs[i] = export.documentObject(child, has_seen)
end
return objs
end
export.makeDocObject['function'] = function(source, obj, has_seen)
obj.args = export.documentObject(source.args, has_seen)
obj.view = getLabel(source, source.parent.type == 'setmethod', 1)
local _, _, max = vm.countReturnsOfFunction(source)
if max > 0 then obj.returns = {} end
for i = 1, max do
obj.returns[i] = export.documentObject(vm.getReturnOfFunction(source, i), has_seen) --check if bug?
end
end
export.makeDocObject['function.return'] = function(source, obj, has_seen)
obj.desc = source.comment and getDesc(source.comment)
obj.rawdesc = source.comment and getDesc(source.comment, true)
end
export.makeDocObject['local'] = function(source, obj, has_seen)
obj.name = source[1]
end
export.makeDocObject['self'] = export.makeDocObject['local']
export.makeDocObject['setfield'] = export.makeDocObject['doc.class']
export.makeDocObject['setglobal'] = export.makeDocObject['doc.class']
export.makeDocObject['setindex'] = export.makeDocObject['doc.class']
export.makeDocObject['setmethod'] = export.makeDocObject['doc.class']
export.makeDocObject['tableindex'] = function(source, obj, has_seen)
obj.name = source.index[1]
end
export.makeDocObject['type'] = function(source, obj, has_seen)
if export.makeDocObject['variable'](source, obj, has_seen) == false then
return false
end
obj.fields = {}
vm.getClassFields(ws.rootUri, source, vm.ANY, function (next_source, mark)
if next_source.type == 'doc.field'
or next_source.type == 'setfield'
or next_source.type == 'setmethod'
or next_source.type == 'tableindex'
then
table.insert(obj.fields, export.documentObject(next_source, has_seen))
end
end)
table.sort(obj.fields, export.sortDoc)
end
export.makeDocObject['variable'] = function(source, obj, has_seen)
obj.defines = {}
for _, set in ipairs(source:getSets(ws.rootUri)) do
if set.type == 'setglobal'
or set.type == 'setfield'
or set.type == 'setmethod'
or set.type == 'setindex'
or set.type == 'doc.alias'
or set.type == 'doc.class'
then
table.insert(obj.defines, export.documentObject(set, has_seen))
end
end
if #obj.defines == 0 then return false end
table.sort(obj.defines, export.sortDoc)
end
---gathers the globals that are to be exported in documentation
---@async
---@return table globals
function export.gatherGlobals()
local all_globals = vm.getAllGlobals()
local globals = {}
for _, g in pairs(all_globals) do
table.insert(globals, g)
end
return globals
end
---builds a lua table of based on `globals` and their elements
---@async
---@param globals table
---@param callback fun(i, max)
function export.makeDocs(globals, callback)
local docs = {}
for i, global in ipairs(globals) do
table.insert(docs, export.documentObject(global))
callback(i, #globals)
end
docs[#docs+1] = export.getLualsConfig()
table.sort(docs, export.sortDoc)
return docs
end
function export.getLualsConfig()
return {
name = 'LuaLS',
type = 'luals.config',
DOC = fs.canonical(fs.path(DOC)):string(),
defines = {},
fields = {}
}
end
---takes the table from `makeDocs`, serializes it, and exports it
---@async
---@param docs table
---@param outputDir string
---@return boolean ok, string[] outputPaths, (string|nil)[]? errs
function export.serializeAndExport(docs, outputDir)
local jsonPath = outputDir .. '/doc.json'
local mdPath = outputDir .. '/doc.md'
--export to json
local old_jsonb_supportSparseArray = jsonb.supportSparseArray
jsonb.supportSparseArray = true
local jsonOk, jsonErr = util.saveFile(jsonPath, jsonb.beautify(docs))
jsonb.supportSparseArray = old_jsonb_supportSparseArray
--export to markdown
local md = markdown()
for _, class in ipairs(docs) do
md:add('md', '# ' .. class.name)
md:emptyLine()
md:add('md', class.desc)
md:emptyLine()
if class.defines then
for _, define in ipairs(class.defines) do
if define.extends then
md:add('lua', define.extends.view)
md:emptyLine()
end
end
end
if class.fields then
local mark = {}
for _, field in ipairs(class.fields) do
if not mark[field.name] then
mark[field.name] = true
md:add('md', '## ' .. field.name)
md:emptyLine()
md:add('lua', field.extends.view)
md:emptyLine()
md:add('md', field.desc)
md:emptyLine()
end
end
end
md:splitLine()
end
local mdOk, mdErr = util.saveFile(mdPath, md:string())
--error checking save file
if( not (jsonOk and mdOk) ) then
return false, {jsonPath, mdPath}, {jsonErr, mdErr}
end
return true, {jsonPath, mdPath}
end
return export

View File

@@ -0,0 +1,258 @@
local lclient = require 'lclient'
local furi = require 'file-uri'
local ws = require 'workspace'
local files = require 'files'
local util = require 'utility'
local lang = require 'language'
local config = require 'config.config'
local await = require 'await'
local progress = require 'progress'
local fs = require 'bee.filesystem'
local doc = {}
---Find file 'doc.json'.
---@return fs.path
local function findDocJson()
local doc_json_path
if type(DOC_UPDATE) == 'string' then
doc_json_path = fs.canonical(fs.path(DOC_UPDATE)) .. '/doc.json'
else
doc_json_path = fs.current_path() .. '/doc.json'
end
if fs.exists(doc_json_path) then
return doc_json_path
else
error(string.format('Error: File "%s" not found.', doc_json_path))
end
end
---@return string # path of 'doc.json'
---@return string # path to be documented
local function getPathDocUpdate()
local doc_json_path = findDocJson()
local ok, doc_path = pcall(
function ()
local json = require('json')
local json_file = io.open(doc_json_path:string(), 'r'):read('*all')
local json_data = json.decode(json_file)
for _, section in ipairs(json_data) do
if section.type == 'luals.config' then
return section.DOC
end
end
end)
if ok then
local doc_json_dir = doc_json_path:string():gsub('/doc.json', '')
return doc_json_dir, doc_path
else
error(string.format('Error: Cannot update "%s".', doc_json_path))
end
end
---clones a module and assigns any internal upvalues pointing to the module to the new clone
---useful for sandboxing
---@param tbl any module to be cloned
---@return any module_clone the cloned module
local function reinstantiateModule(tbl, _new_module, _old_module, _has_seen)
_old_module = _old_module or tbl --remember old module only at root
_has_seen = _has_seen or {} --remember visited indecies
if(type(tbl) == 'table') then
if _has_seen[tbl] then return _has_seen[tbl] end
local clone = {}
_has_seen[tbl] = true
for key, value in pairs(tbl) do
clone[key] = reinstantiateModule(value, _new_module or clone, _old_module, _has_seen)
end
setmetatable(clone, getmetatable(tbl))
return clone
elseif(type(tbl) == 'function') then
local func = tbl
if _has_seen[func] then return _has_seen[func] end --copy function pointers instead of building clones
local upvalues = {}
local i = 1
while true do
local label, value = debug.getupvalue(func, i)
if not value then break end
upvalues[i] = value == _old_module and _new_module or value
i = i + 1
end
local new_func = load(string.dump(func))--, 'function@reinstantiateModule()', 'b', _ENV)
assert(new_func, 'could not load dumped function')
for index, upvalue in ipairs(upvalues) do
debug.setupvalue(new_func, index, upvalue)
end
_has_seen[func] = new_func
return new_func
else
return tbl
end
end
--these modules need to be loaded by the time this function is created
--im leaving them here since this is a pretty strange function that might get moved somewhere else later
--so make sure to bring these with you!
require 'workspace'
require 'vm'
require 'parser.guide'
require 'core.hover.description'
require 'core.hover.label'
require 'json-beautify'
require 'utility'
require 'provider.markdown'
---Gets config file's doc gen overrides.
---@return table dirty_module clone of the export module modified by user buildscript
local function injectBuildScript()
local sub_path = config.get(ws.rootUri, 'Lua.docScriptPath')
local module = reinstantiateModule( ( require 'cli.doc.export' ) )
--if default, then no build script modifications
if sub_path == '' then
return module
end
local resolved_path = fs.absolute(fs.path(DOC)):string() .. sub_path
local f <close> = io.open(resolved_path, 'r')
if not f then
error('could not open config file at '..tostring(resolved_path))
end
--include all `require`s in script.cli.doc.export in enviroment
--NOTE: allows access to the global enviroment!
local data, err = loadfile(resolved_path, 't', setmetatable({
export = module,
ws = require 'workspace',
vm = require 'vm',
guide = require 'parser.guide',
getDesc = require 'core.hover.description',
getLabel = require 'core.hover.label',
jsonb = require 'json-beautify',
util = require 'utility',
markdown = require 'provider.markdown'
},
{__index = _G}))
if err or not data then
error(err, 0)
end
data()
return module
end
---runtime call for documentation exporting
---@async
---@param outputPath string
function doc.makeDoc(outputPath)
ws.awaitReady(ws.rootUri)
local expandAlias = config.get(ws.rootUri, 'Lua.hover.expandAlias')
config.set(ws.rootUri, 'Lua.hover.expandAlias', false)
local _ <close> = function ()
config.set(ws.rootUri, 'Lua.hover.expandAlias', expandAlias)
end
await.sleep(0.1)
-- ready --
local prog <close> = progress.create(ws.rootUri, lang.script('CLI_DOC_WORKING'), 0)
local dirty_export = injectBuildScript()
local globals = dirty_export.gatherGlobals()
local docs = dirty_export.makeDocs(globals, function (i, max)
prog:setMessage(('%d/%d'):format(i, max))
prog:setPercentage((i) / max * 100)
end)
local ok, outPaths, err = dirty_export.serializeAndExport(docs, outputPath)
if not ok then
error(err)
end
return table.unpack(outPaths)
end
---CLI call for documentation (parameter '--DOC=...' is passed to server)
function doc.runCLI()
lang(LOCALE)
if DOC_UPDATE then
DOC_OUT_PATH, DOC = getPathDocUpdate()
end
if type(DOC) ~= 'string' then
print(lang.script('CLI_CHECK_ERROR_TYPE', type(DOC)))
return
end
local rootUri = furi.encode(fs.canonical(fs.path(DOC)):string())
if not rootUri then
print(lang.script('CLI_CHECK_ERROR_URI', DOC))
return
end
print('root uri = ' .. rootUri)
--- If '--configpath' is specified, get the folder path of the '.luarc.doc.json' configuration file (without the file name)
--- 如果指定了'--configpath',则获取`.luarc.doc.json` 配置文件的文件夹路径(不包含文件名)
--- This option is passed into the callback function of the initialized method in provide.
--- 该选项会被传入到`provide`中的`initialized`方法的回调函数中
local luarcParentUri
if CONFIGPATH then
luarcParentUri = furi.encode(fs.absolute(fs.path(CONFIGPATH)):parent_path():string())
end
util.enableCloseFunction()
local lastClock = os.clock()
---@async
lclient():start(function (client)
client:registerFakers()
client:initialize {
rootUri = rootUri,
luarcParentUri = luarcParentUri,
}
io.write(lang.script('CLI_DOC_INITING'))
config.set(nil, 'Lua.diagnostics.enable', false)
config.set(nil, 'Lua.hover.expandAlias', false)
ws.awaitReady(rootUri)
await.sleep(0.1)
--ready--
local dirty_export = injectBuildScript()
local globals = dirty_export.gatherGlobals()
local docs = dirty_export.makeDocs(globals, function (i, max)
if os.clock() - lastClock > 0.2 then
lastClock = os.clock()
local output = '\x0D'
.. ('>'):rep(math.ceil(i / max * 20))
.. ('='):rep(20 - math.ceil(i / max * 20))
.. ' '
.. ('0'):rep(#tostring(max) - #tostring(i))
.. tostring(i) .. '/' .. tostring(max)
io.write(output)
end
end)
io.write('\x0D')
if not DOC_OUT_PATH then
DOC_OUT_PATH = fs.current_path():string()
end
local ok, outPaths, err = dirty_export.serializeAndExport(docs, DOC_OUT_PATH)
print(lang.script('CLI_DOC_DONE'))
for i, path in ipairs(outPaths) do
local this_err = (type(err) == 'table') and err[i] or nil
print(this_err or files.normalize(path))
end
end)
end
return doc

View File

@@ -0,0 +1,170 @@
local util = require 'utility'
--- @class cli.arg
--- @field type? string|string[]
--- @field description string Description of the argument in markdown format.
--- @field example? string
--- @field default? any
--- @type table<string, cli.arg>
local args = {
['--help'] = {
description = [[
Print this message.
]],
},
['--check'] = {
type = 'string',
description = [[
Perform a "diagnosis report" where the results of the diagnosis are written to the logpath.
]],
example = [[--check=C:\Users\Me\path\to\workspace]]
},
['--checklevel'] = {
type = 'string',
description = [[
To be used with --check. The minimum level of diagnostic that should be logged.
Items with lower priority than the one listed here will not be written to the file.
Options include, in order of priority:
- Error
- Warning
- Information
- Hint
]],
default = 'Warning',
example = [[--checklevel=Information]]
},
['--check_format'] = {
type = { 'json', 'pretty' },
description = [[
Output format for the check results.
- 'pretty': results are displayed to stdout in a human-readable format.
- 'json': results are written to a file in JSON format. See --check_out_path
]],
default = 'pretty'
},
['--version'] = {
type = 'boolean',
description = [[
Get the version of the Lua language server.
This will print it to the command line and immediately exit.
]],
},
['--doc'] = {
type = 'string',
description = [[
Generate documentation from a workspace.
The files will be written to the documentation output path passed
in --doc_out_path.
]],
example = [[--doc=C:/Users/Me/Documents/myLuaProject/]]
},
['--doc_out_path'] = {
type = 'string',
description = [[
The path to output generated documentation at.
If --doc_out_path is missing, the documentation will be written
to the current directory.
See --doc for more info.
]],
example = [[--doc_out_path=C:/Users/Me/Documents/myLuaProjectDocumentation]]
},
['--doc_update'] = {
type = 'string',
description = [[
Update existing documentation files at the given path.
]]
},
['--logpath'] = {
type = 'string',
description = [[
Where the log should be written to.
]],
default = './log',
example = [[--logpath=D:/luaServer/logs]]
},
['--loglevel'] = {
type = 'string',
description = [[
The minimum level of logging that should appear in the logfile.
Can be used to log more detailed info for debugging and error reporting.
Options:
- error
- warn
- info
- debug
- trace
]],
example = [[--loglevel=trace]]
},
['--metapath'] = {
type = 'string',
description = [[
Where the standard Lua library definition files should be generated to.
]],
default = './meta',
example = [[--metapath=D:/sumnekoLua/metaDefintions]]
},
['--locale'] = {
type = 'string',
description = [[
The language to use. Defaults to en-us.
Options can be found in locale/ .
]],
example = [[--locale=zh-cn]]
},
['--configpath'] = {
type = 'string',
description = [[
The location of the configuration file that will be loaded.
Can be relative to the workspace.
When provided, config files from elsewhere (such as from VS Code) will no longer be loaded.
]],
example = [[--configpath=sumnekoLuaConfig.lua]]
},
['--force-accept-workspace'] = {
type = 'boolean',
description = [[
Allows the use of root/home directory as the workspace.
]]
},
['--socket'] = {
type = 'number',
description = [[
Will communicate to a client over the specified TCP port instead of through stdio.
]],
example = [[--socket=5050]]
},
['--develop'] = {
type = 'boolean',
description = [[
Enables development mode. This allows plugins to write to the logpath.
]]
}
}
for nm, attrs in util.sortPairs(args) do
if attrs.type == 'boolean' then
print(nm)
else
print(nm .. "=<value>")
end
if attrs.description then
local normalized_description = attrs.description:gsub("^%s+", ""):gsub("\n%s+", "\n"):gsub("%s+$", "")
print("\n " .. normalized_description:gsub('\n', '\n '))
end
local attr_type = attrs.type
if type(attr_type) == "table" then
print("\n Values: " .. table.concat(attr_type, ', '))
end
if attrs.default then
print("\n Default: " .. tostring(attrs.default))
end
if attrs.example then
print("\n Example: " .. attrs.example)
end
print()
end

View File

@@ -0,0 +1,34 @@
if _G['HELP'] then
require 'cli.help'
os.exit(0, true)
end
if _G['VERSION'] then
require 'cli.version'
os.exit(0, true)
end
if _G['CHECK'] then
local ret = require 'cli.check'.runCLI()
os.exit(ret, true)
end
if _G['CHECK_WORKER'] then
local ret = require 'cli.check_worker'.runCLI()
os.exit(ret or 0, true)
end
if _G['DOC_UPDATE'] then
require 'cli.doc' .runCLI()
os.exit(0, true)
end
if _G['DOC'] then
require 'cli.doc' .runCLI()
os.exit(0, true)
end
if _G['VISUALIZE'] then
local ret = require 'cli.visualize' .runCLI()
os.exit(ret or 0, true)
end

View File

@@ -0,0 +1,2 @@
local version = require 'version'
print(version.getVersion())

View File

@@ -0,0 +1,103 @@
local lang = require 'language'
local parser = require 'parser'
local guide = require 'parser.guide'
local function nodeId(node)
return node.type .. ':' .. node.start .. ':' .. node.finish
end
local function shorten(str)
if type(str) ~= 'string' then
return str
end
str = str:gsub('\n', '\\\\n')
if #str <= 20 then
return str
else
return str:sub(1, 17) .. '...'
end
end
local function getTooltipLine(k, v)
if type(v) == 'table' then
if v.type then
v = '<node ' .. v.type .. '>'
else
v = '<table>'
end
end
v = tostring(v)
v = v:gsub('"', '\\"')
return k .. ': ' .. shorten(v) .. '\\n'
end
local function getTooltip(node)
local str = ''
local skipNodes = {parent = true, start = true, finish = true, type = true}
str = str .. getTooltipLine('start', node.start)
str = str .. getTooltipLine('finish', node.finish)
for k, v in pairs(node) do
if type(k) ~= 'number' and not skipNodes[k] then
str = str .. getTooltipLine(k, v)
end
end
for i = 1, math.min(#node, 15) do
str = str .. getTooltipLine(i, node[i])
end
if #node > 15 then
str = str .. getTooltipLine('15..' .. #node, '(...)')
end
return str
end
local nodeEntry = '\t"%s" [\n\t\tlabel="%s\\l%s\\l"\n\t\ttooltip="%s"\n\t]'
local function getNodeLabel(node)
local keyName = guide.getKeyName(node)
if node.type == 'binary' or node.type == 'unary' then
keyName = node.op.type
elseif node.type == 'label' or node.type == 'goto' then
keyName = node[1]
end
return nodeEntry:format(nodeId(node), node.type, shorten(keyName) or '', getTooltip(node))
end
local function getVisualizeVisitor(writer)
local function visitNode(node, parent)
if node == nil then return end
writer:write(getNodeLabel(node))
writer:write('\n')
if parent then
writer:write(('\t"%s" -> "%s"'):format(nodeId(parent), nodeId(node)))
writer:write('\n')
end
guide.eachChild(node, function(child)
visitNode(child, node)
end)
end
return visitNode
end
local export = {}
function export.visualizeAst(code, writer)
local state = parser.compile(code, 'Lua', _G['LUA_VER'] or 'Lua 5.4')
writer:write('digraph AST {\n')
writer:write('\tnode [shape = rect]\n')
getVisualizeVisitor(writer)(state.ast)
writer:write('}\n')
end
function export.runCLI()
lang(LOCALE)
local file = _G['VISUALIZE']
local code, err = io.open(file)
if not code then
io.stderr:write('failed to open ' .. file .. ': ' .. err)
return 1
end
code = code:read('a')
return export.visualizeAst(code, io.stdout)
end
return export

View File

@@ -0,0 +1,668 @@
local fs = require 'bee.filesystem'
local nonil = require 'without-check-nil'
local util = require 'utility'
local lang = require 'language'
local proto = require 'proto'
local define = require 'proto.define'
local config = require 'config'
local converter = require 'proto.converter'
local await = require 'await'
local scope = require 'workspace.scope'
local inspect = require 'inspect'
local jsone = require 'json-edit'
local jsonc = require 'jsonc'
local m = {}
m._eventList = {}
function m.client(newClient)
if newClient then
m._client = newClient
else
return m._client
end
end
function m.isVSCode()
if not m._client then
return false
end
if m._isvscode == nil then
local lname = m._client:lower()
if lname:find 'vscode'
or lname:find 'visual studio code' then
m._isvscode = true
else
m._isvscode = false
end
end
return m._isvscode
end
function m.getOption(name)
nonil.enable()
local option = m.info.initializationOptions[name]
nonil.disable()
return option
end
function m.getAbility(name)
if not m.info
or not m.info.capabilities then
return nil
end
local current = m.info.capabilities
while true do
local parent, nextPos = name:match '^([^%.]+)()'
if not parent then
break
end
current = current[parent]
if not current then
return current
end
if nextPos > #name then
break
else
name = name:sub(nextPos + 1)
end
end
return current
end
function m.getOffsetEncoding()
if m._offsetEncoding then
return m._offsetEncoding
end
local clientEncodings = m.getAbility 'offsetEncoding'
if type(clientEncodings) == 'table' then
for _, encoding in ipairs(clientEncodings) do
if encoding == 'utf-8' then
m._offsetEncoding = 'utf-8'
return m._offsetEncoding
end
end
end
m._offsetEncoding = 'utf-16'
return m._offsetEncoding
end
local function packMessage(...)
local strs = table.pack(...)
for i = 1, strs.n do
strs[i] = tostring(strs[i])
end
return table.concat(strs, '\t')
end
---@alias message.type '"Error"'|'"Warning"'|'"Info"'|'"Log"'
---show message to client
---@param type message.type
function m.showMessage(type, ...)
local message = packMessage(...)
proto.notify('window/showMessage', {
type = define.MessageType[type] or 3,
message = message,
})
proto.notify('window/logMessage', {
type = define.MessageType[type] or 3,
message = message,
})
log.info('ShowMessage', type, message)
end
---@param type message.type
---@param message string
---@param titles string[]
---@param callback fun(action?: string, index?: integer)
function m.requestMessage(type, message, titles, callback)
proto.notify('window/logMessage', {
type = define.MessageType[type] or 3,
message = message,
})
local map = {}
local actions = {}
for i, title in ipairs(titles) do
actions[i] = {
title = title,
}
map[title] = i
end
log.info('requestMessage', type, message)
proto.request('window/showMessageRequest', {
type = define.MessageType[type] or 3,
message = message,
actions = actions,
}, function (item)
log.info('responseMessage', message, item and item.title or nil)
if item then
callback(item.title, map[item.title])
else
callback(nil, nil)
end
end)
end
---@param type message.type
---@param message string
---@param titles string[]
---@return string action
---@return integer index
---@async
function m.awaitRequestMessage(type, message, titles)
return await.wait(function (waker)
m.requestMessage(type, message, titles, waker)
end)
end
---@param type message.type
function m.logMessage(type, ...)
local message = packMessage(...)
proto.notify('window/logMessage', {
type = define.MessageType[type] or 4,
message = message,
})
end
function m.watchFiles(path)
path = path:gsub('\\', '/')
:gsub('[%[%]%{%}%*%?]', '\\%1')
local registration = {
id = path,
method = 'workspace/didChangeWatchedFiles',
registerOptions = {
watchers = {
{
globPattern = path .. '/**',
kind = 1 | 2 | 4,
},
},
},
}
proto.request('client/registerCapability', {
registrations = {
registration,
}
})
return function ()
local unregisteration = {
id = path,
method = 'workspace/didChangeWatchedFiles',
}
proto.request('client/registerCapability', {
unregisterations = {
unregisteration,
}
})
end
end
---@class config.change
---@field key string
---@field prop? string
---@field value any
---@field action '"add"'|'"set"'|'"prop"'
---@field global? boolean
---@field uri? uri
---@param uri uri?
---@param changes config.change[]
---@return config.change[]
local function getValidChanges(uri, changes)
local newChanges = {}
if not uri then
return changes
end
local scp = scope.getScope(uri)
for _, change in ipairs(changes) do
if scp:isChildUri(change.uri)
or scp:isLinkedUri(change.uri) then
newChanges[#newChanges+1] = change
end
end
return newChanges
end
---@class json.patch
---@field op 'add' | 'remove' | 'replace'
---@field path string
---@field value any
---@class json.patchInfo
---@field key string
---@field value any
---@param cfg table
---@param rawKey string
---@return json.patchInfo
local function searchPatchInfo(cfg, rawKey)
---@param key string
---@param parentKey string
---@param parentValue table
---@return json.patchInfo?
local function searchOnce(key, parentKey, parentValue)
if parentValue == nil then
return nil
end
if type(parentValue) ~= 'table' then
return {
key = parentKey,
value = parentValue,
}
end
if parentValue[key] then
return {
key = parentKey .. '/' .. key,
value = parentValue[key],
}
end
for pos in key:gmatch '()%.' do
local k = key:sub(1, pos - 1)
local v = parentValue[k]
local info = searchOnce(key:sub(pos + 1), parentKey .. '/' .. k, v)
if info then
return info
end
end
return nil
end
return searchOnce(rawKey, '', cfg)
or searchOnce(rawKey:gsub('^Lua%.', ''), '', cfg)
or {
key = '/' .. rawKey:gsub('^Lua%.', ''),
value = nil,
}
end
---@param uri? uri
---@param cfg table
---@param change config.change
---@return json.patch?
local function makeConfigPatch(uri, cfg, change)
local info = searchPatchInfo(cfg, change.key)
if change.action == 'add' then
if type(info.value) == 'table' and #info.value > 0 then
return {
op = 'add',
path = info.key .. '/-',
value = change.value,
}
else
return makeConfigPatch(uri, cfg, {
action = 'set',
key = change.key,
value = config.get(uri, change.key),
})
end
elseif change.action == 'set' then
if info.value ~= nil then
return {
op = 'replace',
path = info.key,
value = change.value,
}
else
return {
op = 'add',
path = info.key,
value = change.value,
}
end
elseif change.action == 'prop' then
if type(info.value) == 'table' and next(info.value) then
return {
op = 'add',
path = info.key .. '/' .. change.prop,
value = change.value,
}
else
return makeConfigPatch(uri, cfg, {
action = 'set',
key = change.key,
value = config.get(uri, change.key),
})
end
end
return nil
end
---@param uri? uri
---@param path string
---@param changes config.change[]
---@return string?
local function editConfigJson(uri, path, changes)
local text = util.loadFile(path)
if not text then
m.showMessage('Error', lang.script('CONFIG_LOAD_FAILED', path))
return nil
end
local suc, res = pcall(jsonc.decode_jsonc, text)
if not suc then
m.showMessage('Error', lang.script('CONFIG_MODIFY_FAIL_SYNTAX_ERROR', path .. res:match 'ERROR(.+)$'))
return nil
end
if type(res) ~= 'table' then
res = {}
end
---@cast res table
for _, change in ipairs(changes) do
local patch = makeConfigPatch(uri, res, change)
if patch then
text = jsone.edit(text, patch, { indent = ' ' })
end
end
return text
end
---@param changes config.change[]
---@param applied config.change[]
local function removeAppliedChanges(changes, applied)
local appliedMap = {}
for _, change in ipairs(applied) do
appliedMap[change] = true
end
for i = #changes, 1, -1 do
if appliedMap[changes[i]] then
table.remove(changes, i)
end
end
end
local function tryModifySpecifiedConfig(uri, finalChanges)
if #finalChanges == 0 then
return false
end
log.info('tryModifySpecifiedConfig', uri, inspect(finalChanges))
local workspace = require 'workspace'
local scp = scope.getScope(uri)
if scp:get('lastLocalType') ~= 'json' then
log.info('lastLocalType ~= json')
return false
end
local validChanges = getValidChanges(uri, finalChanges)
if #validChanges == 0 then
log.info('No valid changes')
return false
end
local path = workspace.getAbsolutePath(uri, CONFIGPATH)
if not path then
log.info('Can not get absolute path')
return false
end
local newJson = editConfigJson(uri, path, validChanges)
if not newJson then
log.info('Can not edit config json')
return false
end
util.saveFile(path, newJson)
log.info('Apply changes to config file', inspect(validChanges))
removeAppliedChanges(finalChanges, validChanges)
return true
end
local function tryModifyRC(uri, finalChanges, create)
if #finalChanges == 0 then
return false
end
log.info('tryModifyRC', uri, inspect(finalChanges))
local workspace = require 'workspace'
local path = workspace.getAbsolutePath(uri, '.luarc.jsonc')
if not path then
log.info('Can not get absolute path of .luarc.jsonc')
return false
end
path = fs.exists(fs.path(path)) and path or workspace.getAbsolutePath(uri, '.luarc.json')
if not path then
log.info('Can not get absolute path of .luarc.json')
return false
end
local buf = util.loadFile(path)
if not buf and not create then
log.info('Can not load .luarc.json and not create')
return false
end
local validChanges = getValidChanges(uri, finalChanges)
if #validChanges == 0 then
log.info('No valid changes')
return false
end
if not buf then
util.saveFile(path, '')
end
local newJson = editConfigJson(uri, path, validChanges)
if not newJson then
log.info('Can not edit config json')
return false
end
util.saveFile(path, newJson)
log.info('Apply changes to .luarc.json', inspect(validChanges))
removeAppliedChanges(finalChanges, validChanges)
return true
end
local function tryModifyClient(uri, finalChanges)
if #finalChanges == 0 then
return false
end
log.info('tryModifyClient', uri, inspect(finalChanges))
if not m.getOption 'changeConfiguration' then
return false
end
local scp = scope.getScope(uri)
local scpChanges = {}
for _, change in ipairs(finalChanges) do
if change.uri
and (scp:isChildUri(change.uri) or scp:isLinkedUri(change.uri)) then
scpChanges[#scpChanges+1] = change
end
end
if #scpChanges == 0 then
log.info('No changes in client scope')
return false
end
proto.notify('$/command', {
command = 'lua.config',
data = scpChanges,
})
log.info('Apply client changes', uri, inspect(scpChanges))
removeAppliedChanges(finalChanges, scpChanges)
return true
end
---@param finalChanges config.change[]
local function tryModifyClientGlobal(finalChanges)
if #finalChanges == 0 then
return
end
log.info('tryModifyClientGlobal', inspect(finalChanges))
if not m.getOption 'changeConfiguration' then
log.info('Client dose not support modifying config')
return
end
local changes = {}
for _, change in ipairs(finalChanges) do
if change.global then
changes[#changes+1] = change
end
end
if #changes == 0 then
log.info('No global changes')
return
end
proto.notify('$/command', {
command = 'lua.config',
data = changes,
})
log.info('Apply client global changes', inspect(changes))
removeAppliedChanges(finalChanges, changes)
end
---@param changes config.change[]
---@return string
local function buildMaunuallyMessage(changes)
local message = {}
for _, change in ipairs(changes) do
if change.action == 'add' then
message[#message+1] = '* ' .. lang.script('WINDOW_MANUAL_CONFIG_ADD', change.key, change.value)
elseif change.action == 'set' then
message[#message+1] = '* ' .. lang.script('WINDOW_MANUAL_CONFIG_SET', change.key, change.value)
elseif change.action == 'prop' then
message[#message+1] = '* ' .. lang.script('WINDOW_MANUAL_CONFIG_PROP', change.key, change.prop, change.value)
end
end
return table.concat(message, '\n')
end
---@param changes config.change[]
---@param onlyMemory? boolean
function m.setConfig(changes, onlyMemory)
local finalChanges = {}
for _, change in ipairs(changes) do
if change.action == 'add' then
local suc = config.add(change.uri, change.key, change.value)
if suc then
finalChanges[#finalChanges+1] = change
end
elseif change.action == 'set' then
local suc = config.set(change.uri, change.key, change.value)
if suc then
finalChanges[#finalChanges+1] = change
end
elseif change.action == 'prop' then
local suc = config.prop(change.uri, change.key, change.prop, change.value)
if suc then
finalChanges[#finalChanges+1] = change
end
end
end
if onlyMemory then
return
end
if #finalChanges == 0 then
return
end
log.info('Modify config', inspect(finalChanges))
xpcall(function ()
local ws = require 'workspace'
tryModifyClientGlobal(finalChanges)
if #ws.folders == 0 then
tryModifySpecifiedConfig(nil, finalChanges)
tryModifyClient(nil, finalChanges)
if #finalChanges > 0 then
local manuallyModifyConfig = buildMaunuallyMessage(finalChanges)
m.showMessage('Warning', lang.script('CONFIG_MODIFY_FAIL_NO_WORKSPACE', manuallyModifyConfig))
end
else
for _, scp in ipairs(ws.folders) do
tryModifySpecifiedConfig(scp.uri, finalChanges)
tryModifyRC(scp.uri, finalChanges, false)
tryModifyClient(scp.uri, finalChanges)
tryModifyRC(scp.uri, finalChanges, true)
end
if #finalChanges > 0 then
m.showMessage('Warning', lang.script('CONFIG_MODIFY_FAIL', buildMaunuallyMessage(finalChanges)))
log.warn('Config modify fail', inspect(finalChanges))
end
end
end, log.error)
end
---@alias textEditor {start: integer, finish: integer, text: string}
---@param uri uri
---@param edits textEditor[]
function m.editText(uri, edits)
local files = require 'files'
local state = files.getState(uri)
if not state then
return
end
local textEdits = {}
for i, edit in ipairs(edits) do
textEdits[i] = converter.textEdit(converter.packRange(state, edit.start, edit.finish), edit.text)
end
local params = {
edit = {
changes = {
[uri] = textEdits,
}
}
}
proto.request('workspace/applyEdit', params)
log.info('workspace/applyEdit', inspect(params))
end
---@alias textMultiEditor {uri: uri, start: integer, finish: integer, text: string}
---@param editors textMultiEditor[]
function m.editMultiText(editors)
local files = require 'files'
local changes = {}
for _, editor in ipairs(editors) do
local uri = editor.uri
local state = files.getState(uri)
if state then
if not changes[uri] then
changes[uri] = {}
end
local edit = converter.textEdit(converter.packRange(state, editor.start, editor.finish), editor.text)
table.insert(changes[uri], edit)
end
end
local params = {
edit = {
changes = changes,
}
}
proto.request('workspace/applyEdit', params)
log.info('workspace/applyEdit', inspect(params))
end
---@param callback async fun(ev: string)
function m.event(callback)
m._eventList[#m._eventList+1] = callback
end
function m._callEvent(ev)
for _, callback in ipairs(m._eventList) do
await.call(function ()
callback(ev)
end)
end
end
function m.setReady()
m._ready = true
m._callEvent('ready')
end
function m.isReady()
return m._ready == true
end
local function hookPrint()
if TEST or CLI then
return
end
print = function (...)
m.logMessage('Log', ...)
end
end
function m.init(t)
log.info('Client init', inspect(t))
m.info = t
nonil.enable()
m.client(t.clientInfo.name)
nonil.disable()
lang(LOCALE or t.locale)
converter.setOffsetEncoding(m.getOffsetEncoding())
hookPrint()
m._callEvent('init')
end
return m

View File

@@ -0,0 +1,276 @@
local util = require 'utility'
local timer = require 'timer'
local scope = require 'workspace.scope'
local template = require 'config.template'
---@alias config.source '"client"'|'"path"'|'"local"'
---@class config.api
local m = {}
m.watchList = {}
m.NULL = {}
m.nullSymbols = {
[m.NULL] = true,
}
---@param scp scope
---@param key string
---@param nowValue any
---@param rawValue any
local function update(scp, key, nowValue, rawValue)
local now = m.getNowTable(scp)
local raw = m.getRawTable(scp)
now[key] = nowValue
raw[key] = rawValue
end
---@param uri? uri
---@param key? string
---@return scope
local function getScope(uri, key)
local raw = m.getRawTable(scope.override)
if raw then
if not key or raw[key] ~= nil then
return scope.override
end
end
if uri then
---@type scope?
local scp = scope.getFolder(uri) or scope.getLinkedScope(uri)
if scp then
if not key
or m.getRawTable(scp)[key] ~= nil then
return scp
end
end
end
return scope.fallback
end
---@param scp scope
---@param key string
---@param value any
function m.setByScope(scp, key, value)
local unit = template[key]
if not unit then
return false
end
local raw = m.getRawTable(scp)
if util.equal(raw[key], value) then
return false
end
if unit:checker(value) then
update(scp, key, unit:loader(value), value)
else
update(scp, key, unit.default, unit.default)
end
return true
end
---@param uri? uri
---@param key string
---@param value any
function m.set(uri, key, value)
local unit = template[key]
assert(unit, 'unknown key: ' .. key)
local scp = getScope(uri, key)
local oldValue = m.get(uri, key)
m.setByScope(scp, key, value)
local newValue = m.get(uri, key)
if not util.equal(oldValue, newValue) then
m.event(uri, key, newValue, oldValue)
return true
end
return false
end
function m.add(uri, key, value)
local unit = template[key]
assert(unit, 'unknown key: ' .. key)
local list = m.getRaw(uri, key)
assert(type(list) == 'table', 'not a list: ' .. key)
local copyed = {}
for i, v in ipairs(list) do
if util.equal(v, value) then
return false
end
copyed[i] = v
end
copyed[#copyed+1] = value
local oldValue = m.get(uri, key)
m.set(uri, key, copyed)
local newValue = m.get(uri, key)
if not util.equal(oldValue, newValue) then
m.event(uri, key, newValue, oldValue)
return true
end
return false
end
function m.remove(uri, key, value)
local unit = template[key]
assert(unit, 'unknown key: ' .. key)
local list = m.getRaw(uri, key)
assert(type(list) == 'table', 'not a list: ' .. key)
local copyed = {}
for i, v in ipairs(list) do
if not util.equal(v, value) then
copyed[i] = v
end
end
local oldValue = m.get(uri, key)
m.set(uri, key, copyed)
local newValue = m.get(uri, key)
if not util.equal(oldValue, newValue) then
m.event(uri, key, newValue, oldValue)
return true
end
return false
end
function m.prop(uri, key, prop, value)
local unit = template[key]
assert(unit, 'unknown key: ' .. key)
local map = m.getRaw(uri, key)
assert(type(map) == 'table', 'not a map: ' .. key)
if util.equal(map[prop], value) then
return false
end
local copyed = {}
for k, v in pairs(map) do
copyed[k] = v
end
copyed[prop] = value
local oldValue = m.get(uri, key)
m.set(uri, key, copyed)
local newValue = m.get(uri, key)
if not util.equal(oldValue, newValue) then
m.event(uri, key, newValue, oldValue)
return true
end
return false
end
---@param uri? uri
---@param key string
---@return any
function m.get(uri, key)
local scp = getScope(uri, key)
local value = m.getNowTable(scp)[key]
if value == nil then
value = template[key].default
end
if value == m.NULL then
value = nil
end
return value
end
---@param uri uri
---@param key string
---@return any
function m.getRaw(uri, key)
local scp = getScope(uri, key)
local value = m.getRawTable(scp)[key]
if value == nil then
value = template[key].default
end
if value == m.NULL then
value = nil
end
return value
end
---@param scp scope
function m.getNowTable(scp)
return scp:get 'config.now'
or scp:set('config.now', {})
end
---@param scp scope
function m.getRawTable(scp)
return scp:get 'config.raw'
or scp:set('config.raw', {})
end
---@param scp scope
---@param ... table
function m.update(scp, ...)
local oldConfig = m.getNowTable(scp)
local newConfig = {}
scp:set('config.now', newConfig)
scp:set('config.raw', {})
local function expand(t, left)
for key, value in pairs(t) do
local fullKey = key
if left then
fullKey = left .. '.' .. key
end
if m.nullSymbols[value] then
value = m.NULL
end
if template[fullKey] then
m.setByScope(scp, fullKey, value)
elseif template['Lua.' .. fullKey] then
m.setByScope(scp, 'Lua.' .. fullKey, value)
elseif type(value) == 'table' then
expand(value, fullKey)
end
end
end
local news = table.pack(...)
for i = 1, news.n do
if type(news[i]) == 'table' then
expand(news[i])
end
end
-- compare then fire event
if oldConfig then
for key, oldValue in pairs(oldConfig) do
local newValue = newConfig[key]
if not util.equal(oldValue, newValue) then
m.event(scp.uri, key, newValue, oldValue)
end
end
end
m.event(scp.uri, '')
end
---@param callback fun(uri: uri, key: string, value: any, oldValue: any)
function m.watch(callback)
m.watchList[#m.watchList+1] = callback
end
function m.event(uri, key, value, oldValue)
if not m.changes then
m.changes = {}
timer.wait(0, function ()
local delay = m.changes
m.changes = nil
for _, info in ipairs(delay) do
for _, callback in ipairs(m.watchList) do
callback(info.uri, info.key, info.value, info.oldValue)
end
end
end)
end
m.changes[#m.changes+1] = {
uri = uri,
key = key,
value = value,
oldValue = oldValue,
}
end
function m.addNullSymbol(null)
m.nullSymbols[null] = true
end
return m

View File

@@ -0,0 +1,67 @@
-- Handles loading environment arguments
---Convert a string to boolean
---@param v string
local function strToBool(v)
return v == "true"
end
---ENV args are defined here.
---- `name` is the ENV arg name
---- `key` is the value used to index `_G` for setting the argument
---- `converter` if present, will be used to convert the string value into another type
---@type { name: string, key: string, converter: fun(value: string): any }[]
local vars = {
{
name = "LLS_CHECK_LEVEL",
key = "CHECKLEVEL",
},
{
name = "LLS_CHECK_PATH",
key = "CHECK",
},
{
name = "LLS_CONFIG_PATH",
key = "CONFIGPATH",
},
{
name = "LLS_DOC_OUT_PATH",
key = "DOC_OUT_PATH",
},
{
name = "LLS_DOC_PATH",
key = "DOC",
},
{
name = "LLS_FORCE_ACCEPT_WORKSPACE",
key = "FORCE_ACCEPT_WORKSPACE",
converter = strToBool,
},
{
name = "LLS_LOCALE",
key = "LOCALE",
},
{
name = "LLS_LOG_LEVEL",
key = "LOGLEVEL",
},
{
name = "LLS_LOG_PATH",
key = "LOGPATH",
},
{
name = "LLS_META_PATH",
key = "METAPATH",
},
}
for _, var in ipairs(vars) do
local value = os.getenv(var.name)
if value then
if var.converter then
value = var.converter(value)
end
_G[var.key] = value
end
end

View File

@@ -0,0 +1,3 @@
local config = require 'config.config'
return config

View File

@@ -0,0 +1,131 @@
local proto = require 'proto'
local lang = require 'language'
local util = require 'utility'
local workspace = require 'workspace'
local scope = require 'workspace.scope'
local inspect = require 'inspect'
local jsonc = require 'jsonc'
local function errorMessage(msg)
proto.notify('window/showMessage', {
type = 3,
message = msg,
})
log.error(msg)
end
---@class config.loader
local m = {}
---@return table?
function m.loadRCConfig(uri, filename)
local scp = scope.getScope(uri)
local path = workspace.getAbsolutePath(uri, filename)
if not path then
scp:set('lastRCConfig', nil)
return nil
end
local buf = util.loadFile(path)
if not buf then
scp:set('lastRCConfig', nil)
return nil
end
local suc, res = pcall(jsonc.decode_jsonc, buf)
if not suc then
errorMessage(lang.script('CONFIG_LOAD_ERROR', res))
return scp:get('lastRCConfig')
end
---@cast res table
scp:set('lastRCConfig', res)
return res
end
---@return table?
function m.loadLocalConfig(uri, filename)
if not filename then
return nil
end
local scp = scope.getScope(uri)
local path = workspace.getAbsolutePath(uri, filename)
if not path then
scp:set('lastLocalConfig', nil)
scp:set('lastLocalType', nil)
return nil
end
local buf = util.loadFile(path)
if not buf then
--errorMessage(lang.script('CONFIG_LOAD_FAILED', path))
scp:set('lastLocalConfig', nil)
scp:set('lastLocalType', nil)
return nil
end
local firstChar = buf:match '%S'
if firstChar == '{' then
local suc, res = pcall(jsonc.decode_jsonc, buf)
if not suc then
errorMessage(lang.script('CONFIG_LOAD_ERROR', res))
return scp:get('lastLocalConfig')
end
---@cast res table
scp:set('lastLocalConfig', res)
scp:set('lastLocalType', 'json')
return res
else
local suc, res = pcall(function ()
return assert(load(buf, '@' .. path, 't'))()
end)
if not suc then
errorMessage(lang.script('CONFIG_LOAD_ERROR', res))
scp:set('lastLocalConfig', res)
end
scp:set('lastLocalConfig', res)
scp:set('lastLocalType', 'lua')
return res
end
end
---@async
---@param uri? uri
---@return table?
function m.loadClientConfig(uri)
local configs = proto.awaitRequest('workspace/configuration', {
items = {
{
scopeUri = uri,
section = 'Lua',
},
{
scopeUri = uri,
section = 'files.associations',
},
{
scopeUri = uri,
section = 'files.exclude',
},
{
scopeUri = uri,
section = 'editor.semanticHighlighting.enabled',
},
{
scopeUri = uri,
section = 'editor.acceptSuggestionOnEnter',
},
},
})
if not configs or not configs[1] then
log.warn('No config?', inspect(configs))
return nil
end
local newConfig = {
['Lua'] = configs[1],
['files.associations'] = configs[2],
['files.exclude'] = configs[3],
['editor.semanticHighlighting.enabled'] = configs[4],
['editor.acceptSuggestionOnEnter'] = configs[5],
}
return newConfig
end
return m

View File

@@ -0,0 +1,437 @@
local util = require 'utility'
local define = require 'proto.define'
local diag = require 'proto.diagnostic'
---@class config.unit
---@field caller function
---@field loader function
---@field _checker fun(self: config.unit, value: any): boolean
---@field name string
---@operator shl: config.unit
---@operator shr: config.unit
---@operator call: config.unit
local mt = {}
mt.__index = mt
function mt:__call(...)
self:caller(...)
return self
end
function mt:__shr(default)
self.default = default
self.hasDefault = true
return self
end
function mt:__shl(enums)
self.enums = enums
return self
end
function mt:checker(v)
if self.enums then
local ok
for _, enum in ipairs(self.enums) do
if util.equal(enum, v) then
ok = true
break
end
end
if not ok then
return false
end
end
return self:_checker(v)
end
local units = {}
local function register(name, default, checker, loader, caller)
units[name] = {
name = name,
default = default,
_checker = checker,
loader = loader,
caller = caller,
}
end
---@class config.master
---@field [string] config.unit
local Type = setmetatable({}, { __index = function (_, name)
local unit = {}
for k, v in pairs(units[name]) do
unit[k] = v
end
return setmetatable(unit, mt)
end })
register('Boolean', false, function (self, v)
return type(v) == 'boolean'
end, function (self, v)
return v
end)
register('Integer', 0, function (self, v)
return type(v) == 'number'
end, function (self, v)
return math.floor(v)
end)
register('String', '', function (self, v)
return type(v) == 'string'
end, function (self, v)
return tostring(v)
end)
register('Nil', nil, function (self, v)
return type(v) == 'nil'
end, function (self, v)
return nil
end)
register('Array', {}, function (self, value)
return type(value) == 'table'
end, function (self, value)
local t = {}
if #value == 0 then
for k in pairs(value) do
if self.sub:checker(k) then
t[#t+1] = self.sub:loader(k)
end
end
else
for _, v in ipairs(value) do
if self.sub:checker(v) then
t[#t+1] = self.sub:loader(v)
end
end
end
return t
end, function (self, sub)
self.sub = sub
end)
register('Hash', {}, function (self, value)
if type(value) == 'table' then
if #value == 0 then
for k, v in pairs(value) do
if not self.subkey:checker(k)
or not self.subvalue:checker(v) then
return false
end
end
else
if not self.subvalue:checker(true) then
return false
end
for _, v in ipairs(value) do
if not self.subkey:checker(v) then
return false
end
end
end
return true
end
if type(value) == 'string' then
return self.subkey:checker('')
and self.subvalue:checker(true)
end
end, function (self, value)
if type(value) == 'table' then
local t = {}
if #value == 0 then
for k, v in pairs(value) do
t[k] = v
end
else
for _, k in pairs(value) do
t[k] = true
end
end
return t
end
if type(value) == 'string' then
local t = {}
for s in value:gmatch('[^' .. self.sep .. ']+') do
t[s] = true
end
return t
end
end, function (self, subkey, subvalue, sep)
self.subkey = subkey
self.subvalue = subvalue
self.sep = sep
end)
register('Or', nil, function (self, value)
for _, sub in ipairs(self.subs) do
if sub:checker(value) then
return true
end
end
return false
end, function (self, value)
for _, sub in ipairs(self.subs) do
if sub:checker(value) then
return sub:loader(value)
end
end
end, function (self, ...)
self.subs = { ... }
end)
---@format disable-next
local template = {
['Lua.runtime.version'] = Type.String >> 'Lua 5.4' << {
'Lua 5.1',
'Lua 5.2',
'Lua 5.3',
'Lua 5.4',
'LuaJIT',
},
['Lua.runtime.path'] = Type.Array(Type.String) >> {
"?.lua",
"?/init.lua",
},
['Lua.runtime.pathStrict'] = Type.Boolean >> false,
['Lua.runtime.special'] = Type.Hash(
Type.String,
Type.String >> 'require' << {
'_G',
'rawset',
'rawget',
'setmetatable',
'require',
'dofile',
'loadfile',
'pcall',
'xpcall',
'assert',
'error',
'type',
'os.exit',
}
),
['Lua.runtime.meta'] = Type.String >> '${version} ${language} ${encoding}',
['Lua.runtime.unicodeName'] = Type.Boolean,
['Lua.runtime.nonstandardSymbol'] = Type.Array(Type.String << {
'//', '/**/',
'`',
'+=', '-=', '*=', '/=', '%=', '^=', '//=',
'|=', '&=', '<<=', '>>=',
'||', '&&', '!', '!=',
'continue',
}),
['Lua.runtime.plugin'] = Type.Or(Type.String, Type.Array(Type.String)) ,
['Lua.runtime.pluginArgs'] = Type.Or(Type.Array(Type.String), Type.Hash(Type.String, Type.String)),
['Lua.runtime.fileEncoding'] = Type.String >> 'utf8' << {
'utf8',
'ansi',
'utf16le',
'utf16be',
},
['Lua.runtime.builtin'] = Type.Hash(
Type.String << util.getTableKeys(define.BuiltIn, true),
Type.String >> 'default' << {
'default',
'enable',
'disable',
}
)
>> util.deepCopy(define.BuiltIn),
['Lua.diagnostics.enable'] = Type.Boolean >> true,
['Lua.diagnostics.globals'] = Type.Array(Type.String),
['Lua.diagnostics.globalsRegex'] = Type.Array(Type.String),
['Lua.diagnostics.disable'] = Type.Array(Type.String << util.getTableKeys(diag.getDiagAndErrNameMap(), true)),
['Lua.diagnostics.severity'] = Type.Hash(
Type.String << util.getTableKeys(define.DiagnosticDefaultNeededFileStatus, true),
Type.String << {
'Error',
'Warning',
'Information',
'Hint',
'Error!',
'Warning!',
'Information!',
'Hint!',
}
)
>> util.deepCopy(define.DiagnosticDefaultSeverity),
['Lua.diagnostics.neededFileStatus'] = Type.Hash(
Type.String << util.getTableKeys(define.DiagnosticDefaultNeededFileStatus, true),
Type.String << {
'Any',
'Opened',
'None',
'Any!',
'Opened!',
'None!',
}
)
>> util.deepCopy(define.DiagnosticDefaultNeededFileStatus),
['Lua.diagnostics.groupSeverity'] = Type.Hash(
Type.String << util.getTableKeys(define.DiagnosticDefaultGroupSeverity, true),
Type.String << {
'Error',
'Warning',
'Information',
'Hint',
'Fallback',
}
)
>> util.deepCopy(define.DiagnosticDefaultGroupSeverity),
['Lua.diagnostics.groupFileStatus'] = Type.Hash(
Type.String << util.getTableKeys(define.DiagnosticDefaultGroupFileStatus, true),
Type.String << {
'Any',
'Opened',
'None',
'Fallback',
}
)
>> util.deepCopy(define.DiagnosticDefaultGroupFileStatus),
['Lua.diagnostics.disableScheme'] = Type.Array(Type.String) >> { 'git' },
['Lua.diagnostics.workspaceEvent'] = Type.String >> 'OnSave' << {
'OnChange',
'OnSave',
'None',
},
['Lua.diagnostics.workspaceDelay'] = Type.Integer >> 3000,
['Lua.diagnostics.workspaceRate'] = Type.Integer >> 100,
['Lua.diagnostics.libraryFiles'] = Type.String >> 'Opened' << {
'Enable',
'Opened',
'Disable',
},
['Lua.diagnostics.ignoredFiles'] = Type.String >> 'Opened' << {
'Enable',
'Opened',
'Disable',
},
['Lua.diagnostics.unusedLocalExclude'] = Type.Array(Type.String),
['Lua.workspace.ignoreDir'] = Type.Array(Type.String) >> {
'.vscode',
},
['Lua.workspace.ignoreSubmodules'] = Type.Boolean >> true,
['Lua.workspace.useGitIgnore'] = Type.Boolean >> true,
['Lua.workspace.maxPreload'] = Type.Integer >> 5000,
['Lua.workspace.preloadFileSize'] = Type.Integer >> 500,
['Lua.workspace.library'] = Type.Array(Type.String),
['Lua.workspace.checkThirdParty'] = Type.Or(Type.String >> 'Ask' << {
'Ask',
'Apply',
'ApplyInMemory',
'Disable',
}, Type.Boolean),
['Lua.workspace.userThirdParty'] = Type.Array(Type.String),
['Lua.completion.enable'] = Type.Boolean >> true,
['Lua.completion.callSnippet'] = Type.String >> 'Disable' << {
'Disable',
'Both',
'Replace',
},
['Lua.completion.keywordSnippet'] = Type.String >> 'Replace' << {
'Disable',
'Both',
'Replace',
},
['Lua.completion.displayContext'] = Type.Integer >> 0,
['Lua.completion.workspaceWord'] = Type.Boolean >> true,
['Lua.completion.showWord'] = Type.String >> 'Fallback' << {
'Enable',
'Fallback',
'Disable',
},
['Lua.completion.autoRequire'] = Type.Boolean >> true,
['Lua.completion.showParams'] = Type.Boolean >> true,
['Lua.completion.requireSeparator'] = Type.String >> '.',
['Lua.completion.postfix'] = Type.String >> '@',
['Lua.signatureHelp.enable'] = Type.Boolean >> true,
['Lua.hover.enable'] = Type.Boolean >> true,
['Lua.hover.viewString'] = Type.Boolean >> true,
['Lua.hover.viewStringMax'] = Type.Integer >> 1000,
['Lua.hover.viewNumber'] = Type.Boolean >> true,
['Lua.hover.previewFields'] = Type.Integer >> 10,
['Lua.hover.enumsLimit'] = Type.Integer >> 5,
['Lua.hover.expandAlias'] = Type.Boolean >> true,
['Lua.semantic.enable'] = Type.Boolean >> true,
['Lua.semantic.variable'] = Type.Boolean >> true,
['Lua.semantic.annotation'] = Type.Boolean >> true,
['Lua.semantic.keyword'] = Type.Boolean >> false,
['Lua.hint.enable'] = Type.Boolean >> false,
['Lua.hint.paramType'] = Type.Boolean >> true,
['Lua.hint.setType'] = Type.Boolean >> false,
['Lua.hint.paramName'] = Type.String >> 'All' << {
'All',
'Literal',
'Disable',
},
['Lua.hint.await'] = Type.Boolean >> true,
['Lua.hint.awaitPropagate'] = Type.Boolean >> false,
['Lua.hint.arrayIndex'] = Type.String >> 'Auto' << {
'Enable',
'Auto',
'Disable',
},
['Lua.hint.semicolon'] = Type.String >> 'SameLine' << {
'All',
'SameLine',
'Disable',
},
['Lua.window.statusBar'] = Type.Boolean >> true,
['Lua.window.progressBar'] = Type.Boolean >> true,
['Lua.codeLens.enable'] = Type.Boolean >> false,
['Lua.format.enable'] = Type.Boolean >> true,
['Lua.format.defaultConfig'] = Type.Hash(Type.String, Type.String)
>> {},
['Lua.typeFormat.config'] = Type.Hash(Type.String, Type.String)
>> {
format_line = "true",
auto_complete_end = "true",
auto_complete_table_sep = "true"
},
['Lua.spell.dict'] = Type.Array(Type.String),
['Lua.nameStyle.config'] = Type.Hash(Type.String, Type.Or(Type.String, Type.Array(Type.Hash(Type.String, Type.String))))
>> {},
['Lua.misc.parameters'] = Type.Array(Type.String),
['Lua.misc.executablePath'] = Type.String,
['Lua.language.fixIndent'] = Type.Boolean >> true,
['Lua.language.completeAnnotation'] = Type.Boolean >> true,
['Lua.type.castNumberToInteger'] = Type.Boolean >> true,
['Lua.type.weakUnionCheck'] = Type.Boolean >> false,
['Lua.type.maxUnionVariants'] = Type.Integer >> 0,
['Lua.type.weakNilCheck'] = Type.Boolean >> false,
['Lua.type.inferParamType'] = Type.Boolean >> false,
['Lua.type.checkTableShape'] = Type.Boolean >> false,
['Lua.type.inferTableSize'] = Type.Integer >> 10,
['Lua.doc.privateName'] = Type.Array(Type.String),
['Lua.doc.protectedName'] = Type.Array(Type.String),
['Lua.doc.packageName'] = Type.Array(Type.String),
['Lua.doc.regengine'] = Type.String >> 'glob' << {
'glob',
'lua',
},
--testma
["Lua.docScriptPath"] = Type.String,
["Lua.addonRepositoryPath"] = Type.String,
-- VSCode
["Lua.addonManager.enable"] = Type.Boolean >> true,
["Lua.addonManager.repositoryPath"] = Type.String,
["Lua.addonManager.repositoryBranch"] = Type.String,
['files.associations'] = Type.Hash(Type.String, Type.String),
-- copy from VSCode default
['files.exclude'] = Type.Hash(Type.String, Type.Boolean) >> {
["**/.DS_Store"] = true,
["**/.git"] = true,
["**/.hg"] = true,
["**/.svn"] = true,
["**/CVS"] = true,
["**/Thumbs.db"] = true,
},
['editor.semanticHighlighting.enabled'] = Type.Or(Type.Boolean, Type.String),
['editor.acceptSuggestionOnEnter'] = Type.String >> 'on',
}
return template

View File

@@ -0,0 +1,749 @@
local files = require 'files'
local lang = require 'language'
local util = require 'utility'
local guide = require "parser.guide"
local converter = require 'proto.converter'
local autoreq = require 'core.completion.auto-require'
local rpath = require 'workspace.require-path'
local furi = require 'file-uri'
local vm = require 'vm'
---@param uri uri
---@param row integer
---@param mode string
---@param code string
local function checkDisableByLuaDocExits(uri, row, mode, code)
if row < 0 then
return nil
end
local state = files.getState(uri)
if not state then
return nil
end
local lines = state.lines
if state.ast.docs and lines then
return guide.eachSourceBetween(
state.ast.docs,
guide.positionOf(row, 0),
guide.positionOf(row + 1, 0),
function (doc)
if doc.type == 'doc.diagnostic'
and doc.mode == mode then
if doc.names then
return {
start = doc.finish,
finish = doc.finish,
newText = ', ' .. code,
}
else
return {
start = doc.finish,
finish = doc.finish,
newText = ': ' .. code,
}
end
end
end
)
end
return nil
end
local function checkDisableByLuaDocInsert(_uri, row, mode, code)
return {
start = guide.positionOf(row, 0),
finish = guide.positionOf(row, 0),
newText = '---@diagnostic ' .. mode .. ': ' .. code .. '\n',
}
end
local function disableDiagnostic(uri, code, start, results)
local row = guide.rowColOf(start)
results[#results+1] = {
title = lang.script('ACTION_DISABLE_DIAG', code),
kind = 'quickfix',
command = {
title = lang.script.COMMAND_DISABLE_DIAG,
command = 'lua.setConfig',
arguments = {
{
key = 'Lua.diagnostics.disable',
action = 'add',
value = code,
uri = uri,
}
}
}
}
local function pushEdit(title, edit)
results[#results+1] = {
title = title,
kind = 'quickfix',
edit = {
changes = {
[uri] = { edit }
}
}
}
end
pushEdit(lang.script('ACTION_DISABLE_DIAG_LINE', code),
checkDisableByLuaDocExits (uri, row - 1, 'disable-next-line', code)
or checkDisableByLuaDocInsert(uri, row, 'disable-next-line', code))
pushEdit(lang.script('ACTION_DISABLE_DIAG_FILE', code),
checkDisableByLuaDocExits (uri, 0, 'disable', code)
or checkDisableByLuaDocInsert(uri, 0, 'disable', code))
end
local function markGlobal(uri, name, results)
results[#results+1] = {
title = lang.script('ACTION_MARK_GLOBAL', name),
kind = 'quickfix',
command = {
title = lang.script.COMMAND_MARK_GLOBAL,
command = 'lua.setConfig',
arguments = {
{
key = 'Lua.diagnostics.globals',
action = 'add',
value = name,
uri = uri,
}
}
}
}
end
local function changeVersion(uri, version, results)
results[#results+1] = {
title = lang.script('ACTION_RUNTIME_VERSION', version),
kind = 'quickfix',
command = {
title = lang.script.COMMAND_RUNTIME_VERSION,
command = 'lua.setConfig',
arguments = {
{
key = 'Lua.runtime.version',
action = 'set',
value = version,
uri = uri,
}
}
},
}
end
local function solveUndefinedGlobal(uri, diag, results)
local state = files.getState(uri)
if not state then
return
end
local start = converter.unpackRange(state, diag.range)
guide.eachSourceContain(state.ast, start, function (source)
if source.type ~= 'getglobal' then
return
end
local name = guide.getKeyName(source)
markGlobal(uri, name, results)
end)
if diag.data and diag.data.versions then
for _, version in ipairs(diag.data.versions) do
changeVersion(uri, version, results)
end
end
end
local function solveLowercaseGlobal(uri, diag, results)
local state = files.getState(uri)
if not state then
return
end
local start = converter.unpackRange(state, diag.range)
guide.eachSourceContain(state.ast, start, function (source)
if source.type ~= 'setglobal' then
return
end
local name = guide.getKeyName(source)
markGlobal(uri, name, results)
end)
end
local function findSyntax(uri, diag)
local state = files.getState(uri)
if not state then
return
end
for _, err in ipairs(state.errs) do
if err.type:lower():gsub('_', '-') == diag.code then
local range = converter.packRange(state, err.start, err.finish)
if util.equal(range, diag.range) then
return err
end
end
end
return nil
end
local function solveSyntaxByChangeVersion(uri, err, results)
if type(err.version) == 'table' then
for _, version in ipairs(err.version) do
changeVersion(uri, version, results)
end
else
changeVersion(uri, err.version, results)
end
end
local function solveSyntaxByAddDoEnd(uri, err, results)
results[#results+1] = {
title = lang.script.ACTION_ADD_DO_END,
kind = 'quickfix',
edit = {
changes = {
[uri] = {
{
start = err.start,
finish = err.start,
newText = 'do ',
},
{
start = err.finish,
finish = err.finish,
newText = ' end',
},
}
}
}
}
end
local function solveSyntaxByFix(uri, err, results)
local changes = {}
for _, fix in ipairs(err.fix) do
changes[#changes+1] = {
start = fix.start,
finish = fix.finish,
newText = fix.text,
}
end
results[#results+1] = {
title = lang.script('ACTION_' .. err.fix.title, err.fix),
kind = 'quickfix',
edit = {
changes = {
[uri] = changes,
}
}
}
end
local function solveSyntaxUnicodeName(uri, _err, results)
results[#results+1] = {
title = lang.script('ACTION_RUNTIME_UNICODE_NAME'),
kind = 'quickfix',
command = {
title = lang.script.COMMAND_UNICODE_NAME,
command = 'lua.setConfig',
arguments = {
{
key = 'Lua.runtime.unicodeName',
action = 'set',
value = true,
uri = uri,
}
}
},
}
end
local function solveSyntax(uri, diag, results)
local err = findSyntax(uri, diag)
if not err then
return
end
if err.version then
solveSyntaxByChangeVersion(uri, err, results)
end
if err.type == 'ACTION_AFTER_BREAK' or err.type == 'ACTION_AFTER_RETURN' then
solveSyntaxByAddDoEnd(uri, err, results)
end
if err.type == 'UNICODE_NAME' then
solveSyntaxUnicodeName(uri, err, results)
end
if err.fix then
solveSyntaxByFix(uri, err, results)
end
end
local function solveNewlineCall(uri, diag, results)
local state = files.getState(uri)
if not state then
return
end
local start = converter.unpackRange(state, diag.range)
results[#results+1] = {
title = lang.script.ACTION_ADD_SEMICOLON,
kind = 'quickfix',
edit = {
changes = {
[uri] = {
{
start = start,
finish = start,
newText = ';',
}
}
}
}
}
end
local function solveAmbiguity1(uri, diag, results)
results[#results+1] = {
title = lang.script.ACTION_ADD_BRACKETS,
kind = 'quickfix',
command = {
title = lang.script.COMMAND_ADD_BRACKETS,
command = 'lua.solve',
arguments = {
{
name = 'ambiguity-1',
uri = uri,
range = diag.range,
}
}
},
}
end
local function solveTrailingSpace(uri, _diag, results)
results[#results+1] = {
title = lang.script.ACTION_REMOVE_SPACE,
kind = 'quickfix',
command = {
title = lang.script.COMMAND_REMOVE_SPACE,
command = 'lua.removeSpace',
arguments = {
{
uri = uri,
}
}
},
}
end
local function solveAwaitInSync(uri, diag, results)
local state = files.getState(uri)
if not state then
return
end
local start, finish = converter.unpackRange(state, diag.range)
local parentFunction
guide.eachSourceType(state.ast, 'function', function (source)
if source.start > finish
or source.finish < start then
return
end
if not parentFunction or parentFunction.start < source.start then
parentFunction = source
end
end)
if not parentFunction then
return
end
local row = guide.rowColOf(parentFunction.start)
local pos = guide.positionOf(row, 0)
local offset = guide.positionToOffset(state, pos + 1)
local space = state.lua:match('[ \t]*', offset)
results[#results+1] = {
title = lang.script.ACTION_MARK_ASYNC,
kind = 'quickfix',
edit = {
changes = {
[uri] = {
{
start = pos,
finish = pos,
newText = space .. '---@async\n',
}
}
}
},
}
end
local function solveSpell(uri, diag, results)
local state = files.getState(uri)
if not state then
return
end
local spell = require 'provider.spell'
local word = diag.data
if word == nil then
return
end
results[#results+1] = {
title = lang.script('ACTION_ADD_DICT', word),
kind = 'quickfix',
command = {
title = lang.script.COMMAND_ADD_DICT,
command = 'lua.setConfig',
arguments = {
{
key = 'Lua.spell.dict',
action = 'add',
value = word,
uri = uri,
}
}
}
}
local suggests = spell.getSpellSuggest(word)
for _, suggest in ipairs(suggests) do
results[#results+1] = {
title = suggest,
kind = 'quickfix',
edit = {
changes = {
[uri] = {
{
start = converter.unpackPosition(state, diag.range.start),
finish = converter.unpackPosition(state, diag.range["end"]),
newText = suggest
}
}
}
}
}
end
end
local function solveDiagnostic(uri, diag, start, results)
if diag.source == lang.script.DIAG_SYNTAX_CHECK then
solveSyntax(uri, diag, results)
return
end
if not diag.code then
return
end
if diag.code == 'undefined-global' then
solveUndefinedGlobal(uri, diag, results)
elseif diag.code == 'lowercase-global' then
solveLowercaseGlobal(uri, diag, results)
elseif diag.code == 'newline-call' then
solveNewlineCall(uri, diag, results)
elseif diag.code == 'ambiguity-1' then
solveAmbiguity1(uri, diag, results)
elseif diag.code == 'trailing-space' then
solveTrailingSpace(uri, diag, results)
elseif diag.code == 'await-in-sync' then
solveAwaitInSync(uri, diag, results)
elseif diag.code == 'spell-check' then
solveSpell(uri, diag, results)
end
disableDiagnostic(uri, diag.code, start, results)
end
local function checkQuickFix(results, uri, start, diagnostics)
if not diagnostics then
return
end
for _, diag in ipairs(diagnostics) do
solveDiagnostic(uri, diag, start, results)
end
end
local function checkSwapParams(results, uri, start, finish)
local state = files.getState(uri)
local text = files.getText(uri)
if not state or not text then
return
end
local args = {}
guide.eachSourceBetween(state.ast, start, finish, function (source)
if source.type == 'callargs'
or source.type == 'funcargs' then
local targetIndex
for index, arg in ipairs(source) do
if arg.start <= finish and arg.finish >= start then
-- should select only one param
if targetIndex then
return
end
targetIndex = index
end
end
if not targetIndex then
return
end
local node
if source.type == 'callargs' then
node = text:sub(
guide.positionToOffset(state, source.parent.node.start) + 1,
guide.positionToOffset(state, source.parent.node.finish)
)
elseif source.type == 'funcargs' then
local var = source.parent.parent
if guide.isAssign(var) then
if var.type == 'tablefield' then
var = var.field
end
if var.type == 'tableindex' then
var = var.index
end
node = text:sub(
guide.positionToOffset(state, var.start) + 1,
guide.positionToOffset(state, var.finish)
)
else
node = lang.script.SYMBOL_ANONYMOUS
end
end
args[#args+1] = {
source = source,
index = targetIndex,
node = node,
}
end
end)
if #args == 0 then
return
end
table.sort(args, function (a, b)
return a.source.start > b.source.start
end)
local target = args[1]
local myArg = target.source[target.index]
for i, targetArg in ipairs(target.source) do
if i ~= target.index then
results[#results+1] = {
title = lang.script('ACTION_SWAP_PARAMS', {
node = target.node,
index = i,
}),
kind = 'refactor.rewrite',
edit = {
changes = {
[uri] = {
{
start = myArg.start,
finish = myArg.finish,
newText = text:sub(
guide.positionToOffset(state, targetArg.start) + 1,
guide.positionToOffset(state, targetArg.finish)
),
},
{
start = targetArg.start,
finish = targetArg.finish,
newText = text:sub(
guide.positionToOffset(state, myArg.start) + 1,
guide.positionToOffset(state, myArg.finish)
),
},
}
}
}
}
end
end
end
--local function checkExtractAsFunction(results, uri, start, finish)
-- local ast = files.getAst(uri)
-- local text = files.getText(uri)
-- local funcs = {}
-- guide.eachSourceContain(ast.ast, start, function (source)
-- if source.type == 'function'
-- or source.type == 'main' then
-- funcs[#funcs+1] = source
-- end
-- end)
-- table.sort(funcs, function (a, b)
-- return a.start > b.start
-- end)
-- local func = funcs[1]
-- if not func then
-- return
-- end
-- if #func == 0 then
-- return
-- end
-- if func.finish < finish then
-- return
-- end
-- local actions = {}
-- for i = 1, #func do
-- local action = func[i]
-- if action.start < start
-- and action.finish > start then
-- return
-- end
-- if action.start < finish
-- and action.finish > finish then
-- return
-- end
-- if action.finish >= start
-- and action.start <= finish then
-- actions[#actions+1] = action
-- end
-- end
-- if text:sub(start, actions[1].start - 1):find '[%C%S]' then
-- return
-- end
-- if text:sub(actions[1].finish + 1, finish):find '[%C%S]' then
-- return
-- end
-- while func do
-- local funcName = getExtractFuncName(uri, actions[1].start)
-- local funcParams = getExtractFuncParams(uri, actions[1].start)
-- results[#results+1] = {
-- title = lang.script('ACTION_EXTRACT'),
-- kind = 'refactor.extract',
-- edit = {
-- changes = {
-- [uri] = {
-- {
-- start = actions[1].start,
-- finish = actions[1].start - 1,
-- newText = text:sub(targetArg.start, targetArg.finish),
-- },
-- {
-- start = targetArg.start,
-- finish = targetArg.finish,
-- newText = text:sub(myArg.start, myArg.finish),
-- },
-- }
-- }
-- }
-- }
-- func = guide.getParentFunction(func)
-- end
--end
local function checkJsonToLua(results, uri, start, finish)
local text = files.getText(uri)
local state = files.getState(uri)
if not state or not text then
return
end
local startOffset = guide.positionToOffset(state, start)
local finishOffset = guide.positionToOffset(state, finish)
local jsonStart = text:match('()["%{%[]', startOffset + 1)
if not jsonStart then
return
end
local jsonFinish, finishChar
for i = math.min(finishOffset, #text), jsonStart + 1, -1 do
local char = text:sub(i, i)
if char == ']'
or char == '}' then
jsonFinish = i
finishChar = char
break
end
end
if not jsonFinish then
return
end
if finishChar == '}' then
if not text:sub(jsonStart, jsonFinish):find '"%s*%:' then
return
end
end
if finishChar == ']' then
if not text:sub(jsonStart, jsonFinish):find ',' then
return
end
end
results[#results+1] = {
title = lang.script.ACTION_JSON_TO_LUA,
kind = 'refactor.rewrite',
command = {
title = lang.script.COMMAND_JSON_TO_LUA,
command = 'lua.jsonToLua',
arguments = {
{
uri = uri,
start = guide.offsetToPosition(state, jsonStart) - 1,
finish = guide.offsetToPosition(state, jsonFinish),
}
}
},
}
end
local function findRequireTargets(visiblePaths)
local targets = {}
for _, visible in ipairs(visiblePaths) do
targets[#targets+1] = visible.name
end
return targets
end
local function checkMissingRequire(results, uri, start, finish)
local state = files.getState(uri)
local text = files.getText(uri)
if not state or not text then
return
end
local function addRequires(global, endpos)
if not global then
return
end
autoreq.check(state, global, endpos, function (moduleFile, _stemname, _targetSource, fullKeyPath)
local visiblePaths = rpath.getVisiblePath(uri, furi.decode(moduleFile))
if not visiblePaths or #visiblePaths == 0 then return end
for _, target in ipairs(findRequireTargets(visiblePaths)) do
results[#results+1] = {
title = lang.script('ACTION_AUTOREQUIRE', target .. (fullKeyPath or ''), global),
kind = 'refactor.rewrite',
command = {
title = 'autoRequire',
command = 'lua.autoRequire',
arguments = {
{
uri = guide.getUri(state.ast),
target = moduleFile,
name = global,
requireName = target,
fullKeyPath = fullKeyPath,
},
},
}
}
end
end)
end
guide.eachSourceBetween(state.ast, start, finish, function (source)
if vm.isUndefinedGlobal(source) then
addRequires(source[1], source.finish)
end
end)
end
return function (uri, start, finish, diagnostics)
local ast = files.getState(uri)
if not ast then
return nil
end
local results = {}
checkQuickFix(results, uri, start, diagnostics)
checkSwapParams(results, uri, start, finish)
--checkExtractAsFunction(results, uri, start, finish)
checkJsonToLua(results, uri, start, finish)
checkMissingRequire(results, uri, start, finish)
return results
end

View File

@@ -0,0 +1,166 @@
local files = require 'files'
local guide = require 'parser.guide'
local await = require 'await'
local conv = require 'proto.converter'
local getRef = require 'core.reference'
local lang = require 'language'
local client = require 'client'
---@class parser.state
---@field package _codeLens? codeLens
---@class codeLens.resolving
---@field mode 'reference'
---@field source? parser.object
---@class codeLens.result
---@field position integer
---@field id integer
---@class codeLens
local mt = {}
mt.__index = mt
mt.type = 'codeLens'
mt.id = 0
---@param uri uri
---@return boolean
function mt:init(uri)
self.state = files.getState(uri)
if not self.state then
return false
end
---@type uri
self.uri = uri
---@type codeLens.result[]
self.results = {}
---@type table<integer, codeLens.resolving>
self.resolving = {}
return true
end
---@param pos integer
---@param resolving codeLens.resolving
function mt:addResult(pos, resolving)
self.id = self.id + 1
self.results[#self.results+1] = {
position = pos,
id = self.id,
}
self.resolving[self.id] = resolving
end
---@async
---@param id integer
---@return proto.command?
function mt:resolve(id)
local resolving = self.resolving[id]
if not resolving then
return nil
end
if resolving.mode == 'reference' then
return self:resolveReference(resolving.source)
end
end
---@async
function mt:collectReferences()
await.delay()
---@async
guide.eachSourceType(self.state.ast, 'function', function (src)
local parent = src.parent
if guide.isAssign(parent) then
src = parent
elseif parent.type == 'return' then
else
return
end
await.delay()
self:addResult(src.start, {
mode = 'reference',
source = src,
})
end)
end
---@async
---@param source parser.object
---@return proto.command?
function mt:resolveReference(source)
local refs = getRef(self.uri, source.finish, false)
local count = refs and #refs or 0
if client.getOption('codeLensViewReferences') then
local locations = {}
for _, ref in ipairs(refs or {}) do
local state = files.getState(ref.uri)
if state then
locations[#locations+1] = conv.location(
ref.uri,
conv.packRange(state, ref.target.start, ref.target.finish)
)
end
end
local command = conv.command(
lang.script('COMMAND_REFERENCE_COUNT', count),
'lua.showReferences',
{
self.uri,
conv.packPosition(self.state, source.start),
locations,
}
)
return command
else
local command = conv.command(
lang.script('COMMAND_REFERENCE_COUNT', count),
'',
{}
)
return command
end
end
---@async
---@param uri uri
---@return codeLens.result[]?
local function getCodeLens(uri)
local state = files.getState(uri)
if not state then
return nil
end
local codeLens = setmetatable({}, mt)
local suc = codeLens:init(uri)
if not suc then
return nil
end
state._codeLens = codeLens
codeLens:collectReferences()
if #codeLens.results == 0 then
return nil
end
return codeLens.results
end
---@async
---@param id integer
---@return proto.command?
local function resolve(uri, id)
local state = files.getState(uri)
if not state then
return nil
end
local codeLens = state._codeLens
if not codeLens then
return nil
end
local command = codeLens:resolve(id)
return command
end
return {
codeLens = getCodeLens,
resolve = resolve,
}

View File

@@ -0,0 +1,97 @@
local files = require "files"
local guide = require "parser.guide"
local colorPattern = string.rep('%x', 8)
local hex6Pattern = string.format("^#%s", string.rep('%x', 6))
---@param source parser.object
---@return boolean
local function isColor(source)
---@type string
local text = source[1]
if text:len() == 8 then
return text:match(colorPattern)
end
if text:len() == 7 then
return text:match(hex6Pattern)
end
return false
end
---@param colorText string
---@return Color
local function textToColor(colorText)
return {
alpha = tonumber(colorText:sub(1, 2), 16) / 255,
red = tonumber(colorText:sub(3, 4), 16) / 255,
green = tonumber(colorText:sub(5, 6), 16) / 255,
blue = tonumber(colorText:sub(7, 8), 16) / 255,
}
end
---@param colorText string
---@return Color
local function hexTextToColor(colorText)
return {
alpha = 255,
red = tonumber(colorText:sub(2, 3), 16) / 255,
green = tonumber(colorText:sub(4, 5), 16) / 255,
blue = tonumber(colorText:sub(6, 7), 16) / 255,
}
end
---@param color Color
---@return string
local function colorToText(color)
return string.format('%02X%02X%02X%02X'
, math.floor(color.alpha * 255)
, math.floor(color.red * 255)
, math.floor(color.green * 255)
, math.floor(color.blue * 255)
)
end
---@class Color
---@field red number
---@field green number
---@field blue number
---@field alpha number
---@class ColorValue
---@field color Color
---@field start integer
---@field finish integer
---@async
local function colors(uri)
local state = files.getState(uri)
local text = files.getText(uri)
if not state or not text then
return nil
end
---@type ColorValue[]
local colorValues = {}
guide.eachSource(state.ast, function (source) ---@async
if source.type == 'string' and isColor(source) then
---@type string
local colorText = source[1]
local color = colorText:match(colorPattern) and textToColor(colorText) or hexTextToColor(colorText)
colorValues[#colorValues+1] = {
start = source.start + 1,
finish = source.finish - 1,
color = color
}
end
end)
return colorValues
end
return {
colors = colors,
colorToText = colorToText
}

View File

@@ -0,0 +1,164 @@
local files = require 'files'
local furi = require 'file-uri'
local rpath = require 'workspace.require-path'
local client = require 'client'
local lang = require 'language'
local guide = require 'parser.guide'
local function inComment(state, pos)
for _, comm in ipairs(state.comms) do
if comm.start <= pos and comm.finish >= pos then
return true
end
if comm.start > pos then
break
end
end
return false
end
local function findInsertRow(uri)
local text = files.getText(uri)
local state = files.getState(uri)
if not state or not text then
return
end
local lines = state.lines
local fmt = {
pair = false,
quot = '"',
col = nil,
}
local row
for i = 0, #lines do
if inComment(state, guide.positionOf(i, 0)) then
goto CONTINUE
end
local ln = lines[i]
local lnText = text:match('[^\r\n]*', ln)
if not lnText:find('require', 1, true) then
if row then
break
end
if not lnText:match '^local%s'
and not lnText:match '^%s*$'
and not lnText:match '^%-%-' then
break
end
else
row = i + 1
local lpPos = lnText:find '%('
if lpPos then
fmt.pair = true
else
fmt.pair = false
end
local quot = lnText:match [=[(['"])]=]
fmt.quot = quot or fmt.quot
local eqPos = lnText:find '='
if eqPos then
fmt.col = eqPos
end
end
::CONTINUE::
end
return row or 0, fmt
end
---@async
local function askAutoRequire(uri, visiblePaths)
local selects = {}
local nameMap = {}
for _, visible in ipairs(visiblePaths) do
local expect = visible.name
local select = lang.script(expect)
if not nameMap[select] then
nameMap[select] = expect
selects[#selects+1] = select
end
end
local disable = lang.script.COMPLETION_DISABLE_AUTO_REQUIRE
selects[#selects+1] = disable
local result = client.awaitRequestMessage('Info'
, lang.script.COMPLETION_ASK_AUTO_REQUIRE
, selects
)
if not result then
return
end
if result == disable then
client.setConfig {
{
key = 'Lua.completion.autoRequire',
action = 'set',
value = false,
uri = uri,
}
}
return
end
return nameMap[result]
end
local function applyAutoRequire(uri, row, name, result, fmt, fullKeyPath)
local quotedResult = ('%q'):format(result)
if fmt.quot == "'" then
quotedResult = ([['%s']]):format(quotedResult:sub(2, -2)
:gsub([[']], [[\']])
:gsub([[\"]], [["]])
)
end
if fmt.pair then
quotedResult = ('(%s)'):format(quotedResult)
else
quotedResult = (' %s'):format(quotedResult)
end
local sp = ' '
local text = ('local %s'):format(name)
if fmt.col and fmt.col > #text then
sp = (' '):rep(fmt.col - #text - 1)
end
text = ('local %s%s= require%s%s\n'):format(name, sp, quotedResult, fullKeyPath)
client.editText(uri, {
{
start = guide.positionOf(row, 0),
finish = guide.positionOf(row, 0),
text = text,
}
})
end
---@async
return function (data)
---@type uri
local uri = data.uri
local target = data.target
local name = data.name
local requireName = data.requireName
local state = files.getState(uri)
if not state then
return
end
local path = furi.decode(target)
local visiblePaths = rpath.getVisiblePath(uri, path)
if not visiblePaths or #visiblePaths == 0 then
return
end
table.sort(visiblePaths, function (a, b)
return #a.name < #b.name
end)
if not requireName then
requireName = askAutoRequire(uri, visiblePaths)
if not requireName then
return
end
end
local offset, fmt = findInsertRow(uri)
if offset and fmt then
applyAutoRequire(uri, offset, name, requireName, fmt, data.fullKeyPath or '')
end
end

View File

@@ -0,0 +1,15 @@
local doc = require 'cli.doc'
local client = require 'client'
local furi = require 'file-uri'
local lang = require 'language'
local files = require 'files'
---@async
return function (args)
local outputPath = args[1] and furi.decode(args[1]) or LOGPATH
local docPath, mdPath = doc.makeDoc(outputPath)
client.showMessage('Info', lang.script('CLI_DOC_DONE'
, ('[%s](%s)'):format(files.normalize(docPath), furi.encode(docPath))
, ('[%s](%s)'):format(files.normalize(mdPath), furi.encode(mdPath))
))
end

View File

@@ -0,0 +1,13 @@
local config = require 'config'
local client = require 'client'
local await = require 'await'
---@async
return function (data)
local uri = data[1].uri
local key = data[1].key
while not client:isReady() do
await.sleep(0.1)
end
return config.get(uri, key)
end

View File

@@ -0,0 +1,54 @@
local files = require 'files'
local util = require 'utility'
local proto = require 'proto'
local define = require 'proto.define'
local lang = require 'language'
local converter = require 'proto.converter'
local guide = require 'parser.guide'
local json = require 'json'
local jsonc = require 'jsonc'
---@async
return function (data)
local state = files.getState(data.uri)
local text = files.getText(data.uri)
if not text or not state then
return
end
local start = guide.positionToOffset(state, data.start)
local finish = guide.positionToOffset(state, data.finish)
local jsonStr = text:sub(start + 1, finish)
local suc, res = pcall(jsonc.decode_jsonc, jsonStr:match '[%{%[].+')
if not suc or res == json.null then
proto.notify('window/showMessage', {
type = define.MessageType.Warning,
message = lang.script('COMMAND_JSON_TO_LUA_FAILED', res:match '%:%d+%:(.+)'),
})
return
end
---@cast res table
local luaStr = util.dump(res)
if jsonStr:sub(1, 1) == '"' then
local key = jsonStr:match '^"([^\r\n]+)"'
if key then
if key:match '^[%a_]%w*$' then
luaStr = ('%s = %s'):format(key, luaStr)
else
luaStr = ('[%q] = %s'):format(key, luaStr)
end
end
end
proto.awaitRequest('workspace/applyEdit', {
label = 'json to lua',
edit = {
changes = {
[data.uri] = {
{
range = converter.packRange(state, data.start, data.finish),
newText = luaStr,
}
}
}
}
})
end

View File

@@ -0,0 +1,56 @@
local config = require 'config'
local ws = require 'workspace'
local fs = require 'bee.filesystem'
local scope = require 'workspace.scope'
local SDBMHash = require 'SDBMHash'
local searchCode = require 'plugins.ffi.searchCode'
local cdefRerence = require 'plugins.ffi.cdefRerence'
local ffi = require 'plugins.ffi'
local function createDir(uri)
local dir = scope.getScope(uri).uri or 'default'
local fileDir = fs.path(METAPATH) / ('%08x'):format(SDBMHash():hash(dir))
if fs.exists(fileDir) then
return fileDir, true
end
fs.create_directories(fileDir)
return fileDir
end
---@async
return function (uri)
if config.get(uri, 'Lua.runtime.version') ~= 'LuaJIT' then
return
end
ws.awaitReady(uri)
local fileDir, exists = createDir(uri)
local refs = cdefRerence()
if not refs or #refs == 0 then
return
end
for _, v in ipairs(refs) do
local target_uri = v.uri
local codes = searchCode(refs, target_uri)
if not codes then
return
end
ffi.build_single(codes, fileDir, target_uri)
end
if not exists then
local client = require 'client'
client.setConfig {
{
key = 'Lua.workspace.library',
action = 'add',
value = tostring(fileDir),
uri = uri,
}
}
end
end

View File

@@ -0,0 +1,60 @@
local files = require 'files'
local guide = require 'parser.guide'
local proto = require 'proto'
local lang = require 'language'
local converter = require 'proto.converter'
---@async
return function (data)
local uri = data.uri
local text = files.getText(uri)
local state = files.getState(uri)
if not state or not text then
return
end
local lines = state.lines
local textEdit = {}
for i = 0, #lines do
local startOffset = lines[i]
local finishOffset = text:find('[\r\n]', startOffset) or (#text + 1)
local lastOffset = finishOffset - 1
local lastChar = text:sub(lastOffset, lastOffset)
if lastChar ~= ' ' and lastChar ~= '\t' then
goto NEXT_LINE
end
local lastPos = guide.offsetToPosition(state, lastOffset)
if guide.isInString(state.ast, lastPos)
or guide.isInComment(state.ast, lastPos) then
goto NEXT_LINE
end
local firstOffset = startOffset
for n = lastOffset - 1, startOffset, -1 do
local char = text:sub(n, n)
if char ~= ' ' and char ~= '\t' then
firstOffset = n + 1
break
end
end
local firstPos = guide.offsetToPosition(state, firstOffset) - 1
textEdit[#textEdit+1] = {
range = converter.packRange(state, firstPos, lastPos),
newText = '',
}
::NEXT_LINE::
end
if #textEdit == 0 then
return
end
proto.awaitRequest('workspace/applyEdit', {
label = lang.script.COMMAND_REMOVE_SPACE,
edit = {
changes = {
[uri] = textEdit,
}
},
})
end

View File

@@ -0,0 +1,11 @@
local client = require 'client'
local await = require 'await'
---@async
---@param changes config.change[]
return function (changes)
while not client:isReady() do
await.sleep(0.1)
end
client.setConfig(changes)
end

View File

@@ -0,0 +1,99 @@
local files = require 'files'
local guide = require 'parser.guide'
local proto = require 'proto'
local lang = require 'language'
local converter = require 'proto.converter'
local opMap = {
['+'] = true,
['-'] = true,
['*'] = true,
['/'] = true,
['//'] = true,
['^'] = true,
['<<'] = true,
['>>'] = true,
['&'] = true,
['|'] = true,
['~'] = true,
['..'] = true,
}
local literalMap = {
['number'] = true,
['integer'] = true,
['boolean'] = true,
['string'] = true,
['table'] = true,
}
---@async
return function (data)
local uri = data.uri
local text = files.getText(uri)
local state = files.getState(uri)
if not state or not text then
return
end
local start, finish = converter.unpackRange(state, data.range)
local result = guide.eachSourceContain(state.ast, start, function (source)
if source.start ~= start
or source.finish ~= finish then
return
end
if not source.op or source.op.type ~= 'or' then
return
end
local first = source[1]
local second = source[2]
-- a + b or 0 --> a + (b or 0)
do
if first.op
and opMap[first.op.type]
and first.type ~= 'unary'
and not second.op
and literalMap[second.type] then
return {
start = source[1][2].start,
finish = source[2].finish,
}
end
end
-- a or b + c --> (a or b) + c
do
if second.op
and opMap[second.op.type]
and second.type ~= 'unary'
and not first.op
and literalMap[second[1].type] then
return {
start = source[1].start,
finish = source[2][1].finish,
}
end
end
end)
if not result then
return
end
proto.awaitRequest('workspace/applyEdit', {
label = lang.script.COMMAND_REMOVE_SPACE,
edit = {
changes = {
[uri] = {
{
range = converter.packRange(state, result.start, result.finish),
newText = ('(%s)'):format(text:sub(
guide.positionToOffset(state, result.start + 1),
guide.positionToOffset(state, result.finish)
)),
}
},
}
},
})
end

View File

@@ -0,0 +1,168 @@
local config = require 'config'
local util = require 'utility'
local guide = require 'parser.guide'
local workspace = require 'workspace'
local files = require 'files'
local furi = require 'file-uri'
local rpath = require 'workspace.require-path'
local vm = require 'vm'
local matchKey = require 'core.matchkey'
local ipairs = ipairs
---@class auto-require
local m = {}
---@type table<uri, true>
m.validUris = {}
---@param state parser.state
---@return parser.object?
function m.getTargetSource(state)
local targetReturns = state.ast.returns
if not targetReturns then
return nil
end
local targetSource = targetReturns[1] and targetReturns[1][1]
if not targetSource then
return nil
end
if targetSource.type ~= 'getlocal'
and targetSource.type ~= 'table'
and targetSource.type ~= 'function' then
return nil
end
return targetSource
end
function m.check(state, word, position, callback)
local globals = util.arrayToHash(config.get(state.uri, 'Lua.diagnostics.globals'))
local locals = guide.getVisibleLocals(state.ast, position)
local hit = false
for uri in files.eachFile(state.uri) do
if uri == guide.getUri(state.ast) then
goto CONTINUE
end
if not m.validUris[uri] then
goto CONTINUE
end
local path = furi.decode(uri)
local relativePath = workspace.getRelativePath(path)
local infos = rpath.getVisiblePath(uri, path)
local testedStem = { }
for _, sr in ipairs(infos) do
local stemName
if sr.searcher == '[[meta]]' then
stemName = sr.name
else
local pattern = sr.searcher
: gsub("(%p)", "%%%1")
: gsub("%%%?", "(.-)")
local stemPath = relativePath:match(pattern)
if not stemPath then
goto INNER_CONTINUE
end
stemName = stemPath:match("[%a_][%w_]*$")
if not stemName or testedStem[stemName] then
goto INNER_CONTINUE
end
end
testedStem[stemName] = true
if not locals[stemName]
and not vm.hasGlobalSets(state.uri, 'variable', stemName)
and not globals[stemName]
and matchKey(word, stemName) then
local targetState = files.getState(uri)
if not targetState then
goto INNER_CONTINUE
end
local targetSource = m.getTargetSource(targetState)
if not targetSource then
goto INNER_CONTINUE
end
if targetSource.type == 'getlocal'
and vm.getDeprecated(targetSource.node) then
goto INNER_CONTINUE
end
hit = true
callback(uri, stemName, targetSource)
end
::INNER_CONTINUE::
end
::CONTINUE::
end
-- 如果没命中, 则检查枚举
if not hit then
local docs = vm.getDocSets(state.uri)
for _, doc in ipairs(docs) do
if doc.type ~= 'doc.enum' or vm.getDeprecated(doc) then
goto CONTINUE
end
-- 检查枚举名是否匹配
if not (doc.enum[1] == word or doc.enum[1]:match(".*%.([^%.]*)$") == word) then
goto CONTINUE
end
local uri = guide.getUri(doc)
local targetState = files.getState(uri)
if not targetState then
goto CONTINUE
end
local targetSource = m.getTargetSource(targetState)
if not targetSource or (targetSource.type ~= 'getlocal' and targetSource.type ~= 'table') or vm.getDeprecated(targetSource.node) then
goto CONTINUE
end
-- 枚举的完整路径
local fullKeyPath = ""
local node = doc.bindSource.parent
while node do
-- 检查是否可见
if not vm.isVisible(state.ast, node) then
goto CONTINUE
end
if node.type == 'setfield' or node.type == 'getfield' then
fullKeyPath = "." .. node.field[1] .. fullKeyPath
end
if node.type == 'getlocal' then
node = node.node
break
end
node = node.node
end
-- 匹配导出的值, 确定最终路径
if targetSource.node == node then
hit = true
elseif targetSource.type == 'table' then
for _, value in ipairs(targetSource) do
if value.value.node == node then
fullKeyPath = "." .. value.value[1] .. fullKeyPath
hit = true
break
end
end
end
if hit then
callback(guide.getUri(doc), nil, nil, fullKeyPath)
end
::CONTINUE::
end
end
end
files.watch(function (ev, uri)
if ev == 'update'
or ev == 'remove' then
m.validUris[uri] = nil
end
if ev == 'compile' then
local state = files.getLastState(uri)
if state and m.getTargetSource(state) then
m.validUris[uri] = true
end
end
end)
return m

View File

@@ -0,0 +1 @@
return require 'core.completion.completion'

View File

@@ -0,0 +1,409 @@
local define = require 'proto.define'
local guide = require 'parser.guide'
local config = require 'config'
local util = require 'utility'
local lookback = require 'core.look-backward'
local keyWordMap = {
{ 'do', function(info, results)
if info.hasSpace then
results[#results+1] = {
label = 'do .. end',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = [[$0 end]],
}
else
results[#results+1] = {
label = 'do .. end',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
do\
\t$0\
end",
}
end
return true
end, function(info)
return guide.eachSourceContain(info.state.ast, info.start, function(source)
if source.type == 'while'
or source.type == 'in'
or source.type == 'loop' then
if source.finish - info.start <= 2 then
return true
end
end
end)
end },
{ 'and' },
{ 'break' },
{ 'else' },
{ 'elseif', function(info, results)
local offset = guide.positionToOffset(info.state, info.position)
if info.text:find('^%s*then', offset + 1)
or info.text:find('^%s*do', offset + 1) then
return false
end
if info.hasSpace then
results[#results+1] = {
label = 'elseif .. then',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = [[$1 then]],
}
else
results[#results+1] = {
label = 'elseif .. then',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = [[elseif $1 then]],
}
end
return true
end },
{ 'end' },
{ 'false' },
{ 'for', function(info, results)
if info.hasSpace then
results[#results+1] = {
label = 'for .. ipairs',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
${1:index}, ${2:value} in ipairs(${3:t}) do\
\t$0\
end"
}
results[#results+1] = {
label = 'for .. pairs',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
${1:key}, ${2:value} in pairs(${3:t}) do\
\t$0\
end"
}
results[#results+1] = {
label = 'for i = ..',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
${1:i} = ${2:1}, ${3:10, 1} do\
\t$0\
end"
}
else
results[#results+1] = {
label = 'for .. ipairs',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
for ${1:index}, ${2:value} in ipairs(${3:t}) do\
\t$0\
end"
}
results[#results+1] = {
label = 'for .. pairs',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
for ${1:key}, ${2:value} in pairs(${3:t}) do\
\t$0\
end"
}
results[#results+1] = {
label = 'for i = ..',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
for ${1:i} = ${2:1}, ${3:10, 1} do\
\t$0\
end"
}
end
return true
end },
{ 'function', function(info, results)
if info.hasSpace then
results[#results+1] = {
label = 'function ()',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = info.isExp and "\z
($1)\
\t$0\
end" or "\z
$1($2)\
\t$0\
end"
}
else
results[#results+1] = {
label = 'function ()',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = info.isExp and "\z
function ($1)\
\t$0\
end" or "\z
function $1($2)\
\t$0\
end"
}
end
return true
end },
{ 'goto' },
{ 'if', function(info, results)
local offset = guide.positionToOffset(info.state, info.position)
if info.text:find('^%s*then', offset + 1)
or info.text:find('^%s*do', offset + 1) then
return false
end
if info.hasSpace then
results[#results+1] = {
label = 'if .. then',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
$1 then\
\t$0\
end"
}
else
results[#results+1] = {
label = 'if .. then',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
if $1 then\
\t$0\
end"
}
end
return true
end },
{ 'in', function(info, results)
local offset = guide.positionToOffset(info.state, info.position)
if info.text:find('^%s*then', offset + 1)
or info.text:find('^%s*do', offset + 1) then
return false
end
if info.hasSpace then
results[#results+1] = {
label = 'in ..',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
${1:pairs(${2:t})} do\
\t$0\
end"
}
else
results[#results+1] = {
label = 'in ..',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
in ${1:pairs(${2:t})} do\
\t$0\
end"
}
end
return true
end },
{ 'local', function(info, results)
if info.hasSpace then
results[#results+1] = {
label = 'local function',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
function $1($2)\
\t$0\
end"
}
else
results[#results+1] = {
label = 'local function',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
local function $1($2)\
\t$0\
end"
}
end
return false
end },
{ 'nil' },
{ 'not' },
{ 'or' },
{ 'repeat', function(info, results)
if info.hasSpace then
results[#results+1] = {
label = 'repeat .. until',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = [[$0 until $1]]
}
else
results[#results+1] = {
label = 'repeat .. until',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
repeat\
\t$0\
until $1"
}
end
return true
end },
{ 'return', function(info, results)
if not info.hasSpace then
results[#results+1] = {
label = 'do return end',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = [[do return $1end]]
}
end
return false
end },
{ 'then', function(info, results)
local startOffset = guide.positionToOffset(info.state, info.start)
local pos, first = info.text:match('%S+%s+()(%S+)', startOffset + 1)
if first == 'end'
or first == 'else'
or first == 'elseif' then
local startRow = guide.rowColOf(info.start)
local finishPosition = guide.offsetToPosition(info.state, pos)
local finishRow = guide.rowColOf(finishPosition)
local startSp = info.text:match('^%s*', info.state.lines[startRow])
local finishSp = info.text:match('^%s*', info.state.lines[finishRow])
if startSp == finishSp then
return false
end
end
if not info.hasSpace then
results[#results+1] = {
label = 'then .. end',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = '\z
then\
\t$0\
end'
}
end
return true
end },
{ 'true' },
{ 'until' },
{ 'while', function(info, results)
if info.hasSpace then
results[#results+1] = {
label = 'while .. do',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
${1:true} do\
\t$0\
end"
}
else
results[#results+1] = {
label = 'while .. do',
kind = define.CompletionItemKind.Snippet,
insertTextFormat = 2,
insertText = "\z
while ${1:true} do\
\t$0\
end"
}
end
return true
end },
{ 'continue', function (info, results)
local nonstandardSymbol = config.get(info.uri, 'Lua.runtime.nonstandardSymbol')
if util.arrayHas(nonstandardSymbol, 'continue') then
return
end
local version = config.get(info.uri, 'Lua.runtime.version')
if version == 'Lua 5.1' then
return
end
local mostInsideBlock
guide.eachSourceContain(info.state.ast, info.start, function (src)
if src.type == 'while'
or src.type == 'in'
or src.type == 'loop'
or src.type == 'repeat' then
mostInsideBlock = src
end
end)
if not mostInsideBlock then
return
end
-- 找一下 end 的位置
local endPos
if mostInsideBlock.type == 'while' then
endPos = mostInsideBlock.keyword[5]
elseif mostInsideBlock.type == 'in' then
endPos = mostInsideBlock.keyword[7]
elseif mostInsideBlock.type == 'loop' then
endPos = mostInsideBlock.keyword[5]
elseif mostInsideBlock.type == 'repeat' then
endPos = mostInsideBlock.keyword[3]
end
if not endPos then
return
end
local endLine = guide.rowColOf(endPos)
local tabStr = info.state.lua:sub(
info.state.lines[endLine],
guide.positionToOffset(info.state, endPos)
)
local newText
if tabStr:match '^[\t ]*$' then
newText = ' ::continue::\n' .. tabStr
else
newText = '::continue::'
end
local additional = {}
local word = lookback.findWord(info.state.lua, guide.positionToOffset(info.state, info.start) - 1)
if word ~= 'goto' then
additional[#additional+1] = {
start = info.start,
finish = info.start,
newText = 'goto ',
}
end
local hasContinue = guide.eachSourceType(mostInsideBlock, 'label', function (src)
if src[1] == 'continue' then
return true
end
end)
if not hasContinue then
additional[#additional+1] = {
start = endPos,
finish = endPos,
newText = newText,
}
end
results[#results+1] = {
label = 'goto continue ..',
kind = define.CompletionItemKind.Snippet,
insertText = "continue",
additionalTextEdits = additional,
}
return true
end }
}
return keyWordMap

View File

@@ -0,0 +1,413 @@
local guide = require 'parser.guide'
local lookback = require 'core.look-backward'
local matchKey = require 'core.matchkey'
local subString = require 'core.substring'
local define = require 'proto.define'
local markdown = require 'provider.markdown'
local config = require 'config'
local actions = {}
local function register(key)
return function (data)
actions[#actions+1] = {
key = key,
data = data
}
end
end
local function hasNonFieldInNode(source)
local block = guide.getParentBlock(source)
while source ~= block do
if source.type == 'call'
or source.type == 'getindex'
or source.type == 'getmethod' then
return true
end
source = source.parent
end
return false
end
register 'function' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getlocal'
and source.type ~= 'local' then
return
end
if hasNonFieldInNode(source) then
return
end
local subber = subString(state)
callback(string.format('function %s($1)\n\t$0\nend'
, subber(source.start + 1, source.finish)
))
end
}
register 'method' {
function (state, source, callback)
if source.type == 'getfield' then
if hasNonFieldInNode(source) then
return
end
local subber = subString(state)
callback(string.format('function %s:%s($1)\n\t$0\nend'
, subber(source.start + 1, source.dot.start)
, subber(source.dot.finish + 1, source.finish)
))
end
if source.type == 'getmethod' then
if hasNonFieldInNode(source.parent) then
return
end
local subber = subString(state)
callback(string.format('function %s:%s($1)\n\t$0\nend'
, subber(source.start + 1, source.colon.start)
, subber(source.colon.finish + 1, source.finish)
))
end
end
}
register 'pcall' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call' then
return
end
local subber = subString(state)
if source.type == 'call' then
if source.args and #source.args > 0 then
callback(string.format('pcall(%s, %s)'
, subber(source.node.start + 1, source.node.finish)
, subber(source.args[1].start + 1, source.args[#source.args].finish)
))
else
callback(string.format('pcall(%s)'
, subber(source.node.start + 1, source.node.finish)
))
end
else
callback(string.format('pcall(%s$1)$0'
, subber(source.start + 1, source.finish)
))
end
end
}
register 'xpcall' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call' then
return
end
local subber = subString(state)
if source.type == 'call' then
if source.args and #source.args > 0 then
callback(string.format('xpcall(%s, ${1:debug.traceback}, %s)$0'
, subber(source.node.start + 1, source.node.finish)
, subber(source.args[1].start + 1, source.args[#source.args].finish)
))
else
callback(string.format('xpcall(%s, ${1:debug.traceback})$0'
, subber(source.node.start + 1, source.node.finish)
))
end
else
callback(string.format('xpcall(%s, ${1:debug.traceback}$2)$0'
, subber(source.start + 1, source.finish)
))
end
end
}
register 'ifcall' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call' then
return
end
local subber = subString(state)
if source.type == 'call' then
if source.args and #source.args > 0 then
callback(string.format('if %s then %s(%s) end$0'
, subber(source.node.start + 1, source.node.finish)
, subber(source.node.start + 1, source.node.finish)
, subber(source.args[1].start + 1, source.args[#source.args].finish)
))
else
callback(string.format('if %s then %s() end$0'
, subber(source.node.start + 1, source.node.finish)
, subber(source.node.start + 1, source.node.finish)
))
end
else
callback(string.format('if %s then %s($1) end$0'
, subber(source.node.start + 1, source.node.finish)
, subber(source.node.start + 1, source.node.finish)
))
end
end
}
register 'local' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call'
and source.type ~= 'table' then
return
end
local subber = subString(state)
callback(string.format('local $1 = %s$0'
, subber(source.start + 1, source.finish)
))
end
}
register 'ipairs' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call'
and source.type ~= 'table' then
return
end
local subber = subString(state)
callback(string.format('for ${1:i}, ${2:v} in ipairs(%s) do\n\t$0\nend'
, subber(source.start + 1, source.finish)
))
end
}
register 'pairs' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call'
and source.type ~= 'table' then
return
end
local subber = subString(state)
callback(string.format('for ${1:k}, ${2:v} in pairs(%s) do\n\t$0\nend'
, subber(source.start + 1, source.finish)
))
end
}
register 'unpack' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call'
and source.type ~= 'table' then
return
end
local subber = subString(state)
callback(string.format('unpack(%s)'
, subber(source.start + 1, source.finish)
))
end
}
register 'insert' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call'
and source.type ~= 'table' then
return
end
local subber = subString(state)
callback(string.format('table.insert(%s, $0)'
, subber(source.start + 1, source.finish)
))
end
}
register 'remove' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call'
and source.type ~= 'table' then
return
end
local subber = subString(state)
callback(string.format('table.remove(%s, $0)'
, subber(source.start + 1, source.finish)
))
end
}
register 'concat' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getmethod'
and source.type ~= 'getindex'
and source.type ~= 'getlocal'
and source.type ~= 'call'
and source.type ~= 'table' then
return
end
local subber = subString(state)
callback(string.format('table.concat(%s, $0)'
, subber(source.start + 1, source.finish)
))
end
}
register '++' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getindex'
and source.type ~= 'getlocal' then
return
end
local subber = subString(state)
callback(string.format('%s = %s + 1'
, subber(source.start + 1, source.finish)
, subber(source.start + 1, source.finish)
))
end
}
register '++?' {
function (state, source, callback)
if source.type ~= 'getglobal'
and source.type ~= 'getfield'
and source.type ~= 'getindex'
and source.type ~= 'getlocal' then
return
end
local subber = subString(state)
callback(string.format('%s = (%s or 0) + 1'
, subber(source.start + 1, source.finish)
, subber(source.start + 1, source.finish)
))
end
}
local accepts = {
['local'] = true,
['getlocal'] = true,
['getglobal'] = true,
['getfield'] = true,
['getindex'] = true,
['getmethod'] = true,
['call'] = true,
['table'] = true,
}
local function checkPostFix(state, word, wordPosition, position, symbol, results)
local source = guide.eachSourceContain(state.ast, wordPosition, function (source)
if accepts[source.type]
and source.finish == wordPosition then
return source
end
end)
if not source then
return
end
for i, action in ipairs(actions) do
if matchKey(word, action.key) then
action.data[1](state, source, function (newText)
local descText = newText:gsub('%$%{%d+:([^}]+)%}', function (val)
return val
end):gsub('%$%{?%d+%}?', '')
results[#results+1] = {
label = action.key,
kind = define.CompletionItemKind.Snippet,
description = markdown()
: add('lua', descText)
: string(),
textEdit = {
start = wordPosition + #symbol,
finish = position,
newText = newText,
},
sortText = ('postfix-%04d'):format(i),
insertTextFormat = 2,
additionalTextEdits = {
{
start = source.start,
finish = wordPosition + #symbol,
newText = '',
},
},
}
end)
end
end
end
return function (state, position, results)
if guide.isInString(state.ast, position) then
return false
end
local text = state.lua
local offset = guide.positionToOffset(state, position)
local word, newOffset = lookback.findWord(text, offset)
if newOffset then
offset = newOffset - 1
end
local symbol = text:sub(offset, offset)
if symbol == config.get(state.uri, 'Lua.completion.postfix') then
local wordPosition = guide.offsetToPosition(state, offset - 1)
checkPostFix(state, word or '', wordPosition, position, symbol, results)
return symbol ~= '.' and symbol ~= ':'
end
if not word then
if symbol == '+' then
word = text:sub(offset - 1, offset)
offset = offset - 2
end
if symbol == '?' then
word = text:sub(offset - 2, offset)
offset = offset - 3
end
if word then
local wordPosition = guide.offsetToPosition(state, offset)
checkPostFix(state, word or '', wordPosition, position, '', results)
return true
end
end
return false
end

View File

@@ -0,0 +1,237 @@
local workspace = require 'workspace'
local files = require 'files'
local vm = require 'vm'
local findSource = require 'core.find-source'
local guide = require 'parser.guide'
local rpath = require 'workspace.require-path'
local jumpSource = require 'core.jump-source'
local wssymbol = require 'core.workspace-symbol'
local function sortResults(results)
-- 先按照顺序排序
table.sort(results, function (a, b)
local u1 = guide.getUri(a.target)
local u2 = guide.getUri(b.target)
if u1 == u2 then
return a.target.start < b.target.start
else
return u1 < u2
end
end)
-- 如果2个结果处于嵌套状态则取范围小的那个
local lf, lu
for i = #results, 1, -1 do
local res = results[i].target
local f = res.finish
local uri = guide.getUri(res)
if lf and f > lf and uri == lu then
table.remove(results, i)
else
lu = uri
lf = f
end
end
end
local accept = {
['local'] = true,
['setlocal'] = true,
['getlocal'] = true,
['label'] = true,
['goto'] = true,
['field'] = true,
['method'] = true,
['setglobal'] = true,
['getglobal'] = true,
['string'] = true,
['boolean'] = true,
['number'] = true,
['integer'] = true,
['...'] = true,
['doc.type.name'] = true,
['doc.class.name'] = true,
['doc.extends.name'] = true,
['doc.alias.name'] = true,
['doc.see.name'] = true,
['doc.cast.name'] = true,
['doc.enum.name'] = true,
['doc.field.name'] = true,
}
local function checkRequire(source)
if source.type ~= 'string' then
return nil
end
local callargs = source.parent
if callargs.type ~= 'callargs' then
return
end
if callargs[1] ~= source then
return
end
local call = callargs.parent
local func = call.node
local literal = guide.getLiteral(source)
local libName = vm.getLibraryName(func)
if not libName then
return nil
end
if libName == 'require' then
return rpath.findUrisByRequireName(guide.getUri(source), literal)
elseif libName == 'dofile'
or libName == 'loadfile' then
return workspace.findUrisByFilePath(literal)
end
return nil
end
local function convertIndex(source)
if not source then
return
end
if source.type == 'string'
or source.type == 'boolean'
or source.type == 'number'
or source.type == 'integer' then
local parent = source.parent
if not parent then
return
end
if parent.type == 'setindex'
or parent.type == 'getindex'
or parent.type == 'tableindex' then
return parent
end
end
return source
end
---@async
---@param source parser.object
---@param results table
local function checkSee(source, results)
if source.type ~= 'doc.see.name' then
return
end
local symbols = wssymbol(source[1], guide.getUri(source))
for _, symbol in ipairs(symbols) do
if symbol.name == source[1] then
results[#results+1] = {
target = symbol.source,
source = source,
uri = guide.getUri(symbol.source),
}
end
end
end
---@async
return function (uri, offset)
local ast = files.getState(uri)
if not ast then
return nil
end
local source = convertIndex(findSource(ast, offset, accept))
if not source then
return nil
end
local results = {}
local uris = checkRequire(source)
if uris then
for _, uri0 in ipairs(uris) do
results[#results+1] = {
uri = uri0,
source = source,
target = {
start = 0,
finish = 0,
uri = uri0,
}
}
end
end
checkSee(source, results)
local defs = vm.getDefs(source)
for _, src in ipairs(defs) do
if src.type == 'global' then
goto CONTINUE
end
local root = guide.getRoot(src)
if not root then
goto CONTINUE
end
if src.type == 'self' then
goto CONTINUE
end
src = src.field or src.method or src
if src.type == 'getindex'
or src.type == 'setindex'
or src.type == 'tableindex' then
src = src.index
if not src then
goto CONTINUE
end
if not guide.isLiteral(src) then
goto CONTINUE
end
else
if guide.isLiteral(src)
and src.type ~= 'function'
and src.type ~= 'doc.type.function'
and src.type ~= 'doc.type.table' then
goto CONTINUE
end
end
if src.type == 'doc.class' then
src = src.class
end
if src.type == 'doc.alias' then
src = src.alias
end
if src.type == 'doc.enum' then
src = src.enum
end
if src.type == 'doc.type.field' then
src = src.name
end
if src.type == 'doc.class.name'
or src.type == 'doc.alias.name'
or src.type == 'doc.enum.name' then
if source.type ~= 'doc.type.name'
and source.type ~= 'doc.extends.name'
and source.type ~= 'doc.see.name'
and source.type ~= 'doc.class.name'
and source.type ~= 'doc.alias.name' then
goto CONTINUE
end
end
if src.type == 'doc.generic.name' then
goto CONTINUE
end
if src.type == 'doc.param' then
goto CONTINUE
end
results[#results+1] = {
target = src,
uri = root.uri,
source = source,
}
::CONTINUE::
end
if #results == 0 then
return nil
end
sortResults(results)
jumpSource(results)
return results
end

View File

@@ -0,0 +1,79 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local opMap = {
['+'] = true,
['-'] = true,
['*'] = true,
['/'] = true,
['//'] = true,
['^'] = true,
['<<'] = true,
['>>'] = true,
['&'] = true,
['|'] = true,
['~'] = true,
['..'] = true,
}
local literalMap = {
['number'] = true,
['integer'] = true,
['boolean'] = true,
['string'] = true,
['table'] = true,
}
return function (uri, callback)
local state = files.getState(uri)
local text = files.getText(uri)
if not state or not text then
return
end
guide.eachSourceType(state.ast, 'binary', function (source)
if source.op.type ~= 'or' then
return
end
local first = source[1]
local second = source[2]
if not first or not second then
return
end
-- a + (b or 0) --> (a + b) or 0
do
if opMap[first.op and first.op.type]
and first.type ~= 'unary'
and not second.op
and literalMap[second.type]
and not literalMap[first[2].type]
then
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_AMBIGUITY_1', text:sub(
guide.positionToOffset(state, first.start + 1),
guide.positionToOffset(state, first.finish)
))
}
end
end
-- (a or 0) + c --> a or (0 + c)
do
if opMap[second.op and second.op.type]
and second.type ~= 'unary'
and not first.op
and literalMap[second[1].type]
then
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_AMBIGUITY_1', text:sub(
guide.positionToOffset(state, second.start + 1),
guide.positionToOffset(state, second.finish)
))
}
end
end
end)
end

View File

@@ -0,0 +1,121 @@
local files = require 'files'
local lang = require 'language'
local guide = require 'parser.guide'
local vm = require 'vm'
local await = require 'await'
local checkTypes = {
'local',
'setlocal',
'setglobal',
'setfield',
'setindex',
'setmethod',
'tablefield',
'tableindex',
'tableexp',
}
---@param source parser.object
---@return boolean
local function hasMarkType(source)
if not source.bindDocs then
return false
end
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.type'
or doc.type == 'doc.class' then
return true
end
end
return false
end
---@param source parser.object
---@return boolean
local function hasMarkClass(source)
if not source.bindDocs then
return false
end
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.class' then
return true
end
end
return false
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
local delayer = await.newThrottledDelayer(15)
---@async
guide.eachSourceTypes(state.ast, checkTypes, function (source)
local value = source.value
if not value then
return
end
delayer:delay()
if source.type == 'setlocal' then
local locNode = vm.compileNode(source.node)
if not locNode.hasDefined then
return
end
end
if value.type == 'nil' then
--[[
---@class A
local mt
---@type X
mt._x = nil -- don't warn this
]]
if hasMarkType(source) then
return
end
if source.type == 'setfield'
or source.type == 'setindex' then
return
end
end
local valueNode = vm.compileNode(value)
if source.type == 'setindex'
or source.type == 'tableexp' then
-- boolean[1] = nil
valueNode = valueNode:copy():removeOptional()
end
if value.type == 'getfield'
or value.type == 'getindex' then
-- 由于无法对字段进行类型收窄,
-- 因此将假值移除再进行检查
valueNode = valueNode:copy():setTruthy()
end
local varNode = vm.compileNode(source)
local errs = {}
if vm.canCastType(uri, varNode, valueNode, errs) then
return
end
-- local Cat = setmetatable({}, {__index = Animal}) 允许逆变
if hasMarkClass(source) then
if vm.canCastType(uri, valueNode:copy():remove 'table', varNode) then
return
end
end
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_ASSIGN_TYPE_MISMATCH', {
def = vm.getInfer(varNode):view(uri),
ref = vm.getInfer(valueNode):view(uri),
}) .. '\n' .. vm.viewTypeErrorMessage(uri, errs),
}
end)
end

View File

@@ -0,0 +1,30 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'call', function (source)
local currentFunc = guide.getParentFunction(source)
if currentFunc and vm.isAsync(currentFunc, false) then
return
end
await.delay()
if vm.isAsyncCall(source) then
callback {
start = source.node.start,
finish = source.node.finish,
message = lang.script('DIAG_AWAIT_IN_SYNC'),
}
return
end
end)
end

View File

@@ -0,0 +1,54 @@
local files = require 'files'
local lang = require 'language'
local guide = require 'parser.guide'
local vm = require 'vm'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'local', function (loc)
if not loc.ref then
return
end
if loc[1] == '_' then
return
end
await.delay()
local locNode = vm.compileNode(loc)
if not locNode.hasDefined then
return
end
for _, ref in ipairs(loc.ref) do
if ref.type == 'setlocal' and ref.value then
await.delay()
local refNode = vm.compileNode(ref)
local value = ref.value
if value.type == 'getfield'
or value.type == 'getindex' then
-- 由于无法对字段进行类型收窄,
-- 因此将假值移除再进行检查
refNode = refNode:copy():setTruthy()
end
local errs = {}
if not vm.canCastType(uri, locNode, refNode, errs) then
callback {
start = ref.start,
finish = ref.finish,
message = lang.script('DIAG_CAST_LOCAL_TYPE', {
def = vm.getInfer(locNode):view(uri),
ref = vm.getInfer(refNode):view(uri),
}) .. '\n' .. vm.viewTypeErrorMessage(uri, errs),
}
end
end
end
end)
end

View File

@@ -0,0 +1,46 @@
local files = require 'files'
local lang = require 'language'
local vm = require 'vm'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
for _, doc in ipairs(state.ast.docs) do
if doc.type == 'doc.cast' and doc.name then
await.delay()
local defs = vm.getDefs(doc.name)
local loc = defs[1]
if loc then
local defNode = vm.compileNode(loc)
if defNode.hasDefined then
for _, cast in ipairs(doc.casts) do
if not cast.mode and cast.extends then
local refNode = vm.compileNode(cast.extends)
local errs = {}
if not vm.canCastType(uri, defNode, refNode, errs) then
assert(errs)
callback {
start = cast.extends.start,
finish = cast.extends.finish,
message = lang.script('DIAG_CAST_TYPE_MISMATCH', {
def = vm.getInfer(defNode):view(uri),
ref = vm.getInfer(refNode):view(uri),
}) .. '\n' .. vm.viewTypeErrorMessage(uri, errs),
}
end
end
end
end
end
end
end
end

View File

@@ -0,0 +1,58 @@
local files = require 'files'
local lang = require 'language'
local vm = require 'vm'
local guide = require 'parser.guide'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
for _, doc in ipairs(state.ast.docs) do
if doc.type == 'doc.class' then
if not doc.extends then
goto CONTINUE
end
await.delay()
local myName = guide.getKeyName(doc)
local list = { doc }
local mark = {}
for i = 1, 999 do
local current = list[i]
if not current then
goto CONTINUE
end
if current.extends then
for _, extend in ipairs(current.extends) do
local newName = extend[1]
if newName == myName then
callback {
start = doc.start,
finish = doc.finish,
message = lang.script('DIAG_CIRCLE_DOC_CLASS', myName)
}
goto CONTINUE
end
if newName and not mark[newName] then
mark[newName] = true
local docs = vm.getDocSets(uri, newName)
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class' then
list[#list+1] = otherDoc
end
end
end
end
end
end
::CONTINUE::
end
end
end

View File

@@ -0,0 +1,40 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local vm = require 'vm'
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
guide.eachSourceType(state.ast, 'local', function (source)
if not source.attrs then
return
end
if source.attrs[1][1] ~= 'close' then
return
end
if not source.value then
callback {
start = source.start,
finish = source.finish,
message = lang.script.DIAG_COSE_NON_OBJECT,
}
return
end
local infer = vm.getInfer(source.value)
if not infer:hasClass(uri)
and not infer:hasType(uri, 'nil')
and not infer:hasType(uri, 'table')
and not infer:hasUnknown(uri)
and not infer:hasAny(uri) then
callback {
start = source.value.start,
finish = source.value.finish,
message = lang.script.DIAG_COSE_NON_OBJECT,
}
end
end)
end

View File

@@ -0,0 +1,38 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
local mark = {}
---@async
guide.eachSourceType(state.ast, 'break', function (source)
local list = source.parent
if mark[list] then
return
end
mark[list] = true
await.delay()
for i = #list, 1, -1 do
local src = list[i]
if src == source then
if i == #list then
return
end
callback {
start = list[i+1].start,
finish = list[#list].range or list[#list].finish,
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script.DIAG_CODE_AFTER_BREAK,
}
end
end
end)
end

View File

@@ -0,0 +1,41 @@
local files = require 'files'
local converter = require 'proto.converter'
local log = require 'log'
local pformatting = require 'provider.formatting'
---@async
return function(uri, callback)
local state = files.getState(uri)
if not state then
return
end
local text = state.originText
local suc, codeFormat = pcall(require, 'code_format')
if not suc then
return
end
pformatting.updateConfig(uri)
local status, diagnosticInfos = codeFormat.diagnose_file(uri, text)
if not status then
if diagnosticInfos ~= nil then
log.error(diagnosticInfos)
end
return
end
if diagnosticInfos then
for _, diagnosticInfo in ipairs(diagnosticInfos) do
callback {
start = converter.unpackPosition(state, diagnosticInfo.range.start),
finish = converter.unpackPosition(state, diagnosticInfo.range["end"]),
message = diagnosticInfo.message
}
end
end
end

View File

@@ -0,0 +1,51 @@
local files = require "files"
local guide = require "parser.guide"
local lang = require 'language'
return function (uri, callback)
local state = files.getState(uri)
local text = files.getText(uri)
if not state or not text then
return
end
guide.eachSourceType(state.ast, 'loop', function (source)
local maxNumber = source.max and tonumber(source.max[1])
if not maxNumber then
return
end
local minNumber = source.init and tonumber(source.init[1])
if minNumber and maxNumber and minNumber <= maxNumber then
return
end
if not minNumber and maxNumber ~= 1 then
return
end
if not source.step then
callback {
start = source.init.start,
finish = source.max.finish,
message = lang.script('DIAG_COUNT_DOWN_LOOP'
, ('%s, %s'):format(text:sub(
guide.positionToOffset(state, source.init.start + 1),
guide.positionToOffset(state, source.max.finish)
), '-1')
)
}
else
local stepNumber = tonumber(source.step[1])
if stepNumber and stepNumber > 0 then
callback {
start = source.init.start,
finish = source.step.finish,
message = lang.script('DIAG_COUNT_DOWN_LOOP'
, ('%s, -%s'):format(text:sub(
guide.positionToOffset(state, source.init.start + 1),
guide.positionToOffset(state, source.max.finish)
), source.step[1])
)
}
end
end
end)
end

View File

@@ -0,0 +1,82 @@
local files = require 'files'
local vm = require 'vm'
local lang = require 'language'
local guide = require 'parser.guide'
local config = require 'config'
local define = require 'proto.define'
local await = require 'await'
local util = require 'utility'
local types = {'getglobal', 'getfield', 'getindex', 'getmethod'}
---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
local dglobals = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
local rspecial = config.get(uri, 'Lua.runtime.special')
guide.eachSourceTypes(ast.ast, types, function (src) ---@async
if src.type == 'getglobal' then
local key = src[1]
if not key then
return
end
if dglobals[key] then
return
end
if rspecial[key] then
return
end
end
await.delay()
local deprecated = vm.getDeprecated(src, true)
if not deprecated then
return
end
await.delay()
local message = lang.script.DIAG_DEPRECATED
local versions
if deprecated.type == 'doc.version' then
local validVersions = vm.getValidVersions(deprecated)
if not validVersions then
return
end
versions = {}
for version, valid in pairs(validVersions) do
if valid then
versions[#versions+1] = version
end
end
table.sort(versions)
if #versions > 0 then
message = ('%s(%s)'):format(message
, lang.script('DIAG_DEFINED_VERSION'
, table.concat(versions, '/')
, config.get(uri, 'Lua.runtime.version'))
)
end
end
if deprecated.type == 'doc.deprecated' then
if deprecated.comment then
message = ('%s(%s)'):format(message, util.trim(deprecated.comment.text))
end
end
callback {
start = src.start,
finish = src.finish,
tags = { define.DiagnosticTag.Deprecated },
message = message,
data = {
versions = versions,
}
}
end)
end

View File

@@ -0,0 +1,54 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local vm = require 'vm'
local rpath = require 'workspace.require-path'
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
local cache = vm.getCache 'different-requires'
guide.eachSpecialOf(state.ast, 'require', function (source)
local call = source.parent
if not call or call.type ~= 'call' then
return
end
local arg1 = call.args and call.args[1]
if not arg1 or arg1.type ~= 'string' then
return
end
local literal = arg1[1]
local results = rpath.findUrisByRequireName(uri, literal)
if not results or #results ~= 1 then
return
end
local result = results[1]
if not files.isLua(result) then
return
end
local other = cache[result]
if not other then
cache[result] = {
source = arg1,
require = literal,
}
return
end
if other.require ~= literal then
callback {
start = arg1.start,
finish = arg1.finish,
related = {
{
start = other.source.start,
finish = other.source.finish,
uri = guide.getUri(other.source),
}
},
message = lang.script('DIAG_DIFFERENT_REQUIRES'),
}
end
end)
end

View File

@@ -0,0 +1,30 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local await = require 'await'
local lang = require 'language'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'call', function (source)
if not guide.isBlockType(source.parent) then
return
end
if source.parent.filter == source then
return
end
await.delay()
if vm.isNoDiscard(source.node, true) then
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_DISCARD_RETURNS'),
}
end
end)
end

View File

@@ -0,0 +1,41 @@
local files = require 'files'
local lang = require 'language'
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
for _, doc in ipairs(state.ast.docs) do
if doc.type ~= 'doc.field' then
goto CONTINUE
end
local bindGroup = doc.bindGroup
if not bindGroup then
goto CONTINUE
end
local ok
for _, other in ipairs(bindGroup) do
if other.type == 'doc.class' then
ok = true
break
end
if other == doc then
break
end
end
if not ok then
callback {
start = doc.start,
finish = doc.finish,
message = lang.script('DIAG_DOC_FIELD_NO_CLASS'),
}
end
::CONTINUE::
end
end

View File

@@ -0,0 +1,54 @@
local files = require 'files'
local lang = require 'language'
local vm = require 'vm'
local guide = require 'parser.guide'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
local merged = {}
local cache = {}
for _, doc in ipairs(state.ast.docs) do
if doc.type == 'doc.alias'
or doc.type == 'doc.enum' then
local name = guide.getKeyName(doc)
if not name then
return
end
await.delay()
if not cache[name] then
local docs = vm.getDocSets(uri, name)
cache[name] = {}
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.alias'
or otherDoc.type == 'doc.class'
or otherDoc.type == 'doc.enum' then
cache[name][#cache[name]+1] = {
start = otherDoc.start,
finish = otherDoc.finish,
uri = guide.getUri(otherDoc),
}
merged[name] = merged[name] or vm.docHasAttr(otherDoc, 'partial')
end
end
end
if not merged[name] and #cache[name] > 1 then
callback {
start = (doc.alias or doc.enum).start,
finish = (doc.alias or doc.enum).finish,
related = cache,
message = lang.script('DIAG_DUPLICATE_DOC_ALIAS', name)
}
end
end
end
end

View File

@@ -0,0 +1,94 @@
local files = require 'files'
local lang = require 'language'
local vm = require 'vm.vm'
local await = require 'await'
local guide = require 'parser.guide'
local function isDocFunc(doc)
if not doc.extends then
return false
end
if #doc.extends.types ~= 1 then
return false
end
local docFunc = doc.extends.types[1]
if docFunc.type ~= 'doc.type.function' then
return false
end
return true
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
local cachedKeys = {}
---@param field parser.object
---@return string?
local function viewKey(field)
if isDocFunc(field) then
return nil
end
if not cachedKeys[field] then
local view = vm.viewKey(field, uri)
cachedKeys[field] = view or false
end
return cachedKeys[field] or nil
end
---@async
---@param myField parser.object
local function checkField(myField)
await.delay()
local myView = viewKey(myField)
if not myView then
return
end
local class = myField.class
if not class then
return
end
for _, set in ipairs(vm.getGlobal('type', class.class[1]):getSets(uri)) do
if not set.fields then
goto CONTINUE
end
for _, field in ipairs(set.fields) do
if field == myField then
goto CONTINUE
end
local view = viewKey(field)
if view ~= myView then
goto CONTINUE
end
callback {
start = myField.field.start,
finish = myField.field.finish,
message = lang.script('DIAG_DUPLICATE_DOC_FIELD', myView),
related = {{
start = field.field.start,
finish = field.field.finish,
uri = guide.getUri(field),
}}
}
do return end
::CONTINUE::
end
::CONTINUE::
end
end
for _, doc in ipairs(state.ast.docs) do
if doc.type == 'doc.field' then
checkField(doc)
end
end
end

View File

@@ -0,0 +1,37 @@
local files = require 'files'
local lang = require 'language'
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
for _, doc in ipairs(state.ast.docs) do
if doc.type ~= 'doc.param' then
goto CONTINUE
end
local name = doc.param[1]
local bindGroup = doc.bindGroup
if not bindGroup then
goto CONTINUE
end
for _, other in ipairs(bindGroup) do
if other ~= doc
and other.type == 'doc.param'
and other.param[1] == name then
callback {
start = doc.param.start,
finish = doc.param.finish,
message = lang.script('DIAG_DUPLICATE_DOC_PARAM', name)
}
goto CONTINUE
end
end
::CONTINUE::
end
end

View File

@@ -0,0 +1,65 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
local await = require 'await'
---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
---@async
guide.eachSourceType(ast.ast, 'table', function (source)
await.delay()
local mark = {}
for _, obj in ipairs(source) do
if obj.type == 'tablefield'
or obj.type == 'tableindex'
or obj.type == 'tableexp' then
local name = guide.getKeyName(obj)
if name then
if not mark[name] then
mark[name] = {}
end
mark[name][#mark[name]+1] = obj.field or obj.index or obj.value
end
end
end
for name, defs in pairs(mark) do
if #defs > 1 and name then
local related = {}
for i = 1, #defs do
local def = defs[i]
related[i] = {
start = def.start,
finish = def.finish,
uri = uri,
}
end
for i = 1, #defs - 1 do
local def = defs[i]
callback {
start = def.start,
finish = def.finish,
related = related,
message = lang.script('DIAG_DUPLICATE_INDEX', name),
level = define.DiagnosticSeverity.Hint,
tags = { define.DiagnosticTag.Unnecessary },
}
end
for i = #defs, #defs do
local def = defs[i]
callback {
start = def.start,
finish = def.finish,
related = related,
message = lang.script('DIAG_DUPLICATE_INDEX', name),
}
end
end
end
end)
end

View File

@@ -0,0 +1,88 @@
local files = require 'files'
local lang = require 'language'
local guide = require 'parser.guide'
local vm = require 'vm'
local await = require 'await'
local sourceTypes = {
'setfield',
'setmethod',
'setindex',
}
---@param source parser.object
---@return parser.object?
local function getTopFunctionOfIf(source)
while true do
if source.type == 'ifblock'
or source.type == 'elseifblock'
or source.type == 'elseblock'
or source.type == 'function'
or source.type == 'main' then
return source
end
source = source.parent
end
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if vm.isMetaFile(uri) then
return
end
---@async
guide.eachSourceTypes(state.ast, sourceTypes, function (src)
await.delay()
local name = guide.getKeyName(src)
if not name then
return
end
local value = vm.getObjectValue(src)
if not value or value.type ~= 'function' then
return
end
local myTopBlock = getTopFunctionOfIf(src)
local defs = vm.getDefs(src)
for _, def in ipairs(defs) do
if def == src then
goto CONTINUE
end
if def.type ~= 'setfield'
and def.type ~= 'setmethod'
and def.type ~= 'setindex' then
goto CONTINUE
end
local defTopBlock = getTopFunctionOfIf(def)
if uri == guide.getUri(def) and myTopBlock ~= defTopBlock then
goto CONTINUE
end
local defValue = vm.getObjectValue(def)
if not defValue or defValue.type ~= 'function' then
goto CONTINUE
end
if vm.getDefinedClass(guide.getUri(def), def.node)
and not vm.getDefinedClass(guide.getUri(src), src.node)
then
-- allow type variable to override function defined in class variable
goto CONTINUE
end
callback {
start = src.start,
finish = src.finish,
related = {{
start = def.start,
finish = def.finish,
uri = guide.getUri(def),
}},
message = lang.script('DIAG_DUPLICATE_SET_FIELD', name),
}
::CONTINUE::
end
end)
end

View File

@@ -0,0 +1,54 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
local await = require 'await'
-- 检查空代码块
-- 但是排除忙等待repeat/while)
---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
await.delay()
guide.eachSourceType(ast.ast, 'if', function (source)
for _, block in ipairs(source) do
if #block > 0 then
return
end
end
callback {
start = source.start,
finish = source.finish,
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script.DIAG_EMPTY_BLOCK,
}
end)
await.delay()
guide.eachSourceType(ast.ast, 'loop', function (source)
if #source > 0 then
return
end
callback {
start = source.start,
finish = source.finish,
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script.DIAG_EMPTY_BLOCK,
}
end)
await.delay()
guide.eachSourceType(ast.ast, 'in', function (source)
if #source > 0 then
return
end
callback {
start = source.start,
finish = source.finish,
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script.DIAG_EMPTY_BLOCK,
}
end)
end

View File

@@ -0,0 +1,75 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local config = require 'config'
local vm = require 'vm'
local util = require 'utility'
local function isDocClass(source)
if not source.bindDocs then
return false
end
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.class' then
return true
end
end
return false
end
local function isGlobalRegex(name, definedGlobalRegex)
if not definedGlobalRegex then
return false
end
for _, pattern in ipairs(definedGlobalRegex) do
if name:match(pattern) then
return true
end
end
return false
end
-- If global elements are discouraged by coding convention, this diagnostic helps with reminding about that
-- Exceptions may be added to Lua.diagnostics.globals
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
local definedGlobal = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
local definedGlobalRegex = config.get(uri, 'Lua.diagnostics.globalsRegex')
guide.eachSourceType(ast.ast, 'setglobal', function (source)
local name = guide.getKeyName(source)
if not name or definedGlobal[name] then
return
end
-- If the assignment is marked as doc.class, then it is considered allowed
if isDocClass(source) then
return
end
if isGlobalRegex(name, definedGlobalRegex) then
return
end
if definedGlobal[name] == nil then
definedGlobal[name] = false
local global = vm.getGlobal('variable', name)
if global then
for _, set in ipairs(global:getSets(uri)) do
if vm.isMetaFile(guide.getUri(set)) then
definedGlobal[name] = true
return
end
end
end
end
callback {
start = source.start,
finish = source.finish,
message = lang.script.DIAG_GLOBAL_ELEMENT,
}
end)
end

View File

@@ -0,0 +1,39 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
local function check(source)
local node = source.node
if node.tag == '_ENV' then
return
end
if guide.isParam(node) then
return
end
if not node.value or node.value.type == 'nil' then
callback {
start = source.start,
finish = source.finish,
uri = uri,
message = lang.script.DIAG_GLOBAL_IN_NIL_ENV,
related = {
{
start = node.start,
finish = node.finish,
uri = uri,
}
}
}
end
end
guide.eachSourceType(state.ast, 'getglobal', check)
guide.eachSourceType(state.ast, 'setglobal', check)
end

View File

@@ -0,0 +1,83 @@
local lang = require 'language'
local m = {}
local function findParam(docs, param)
if not docs then
return false
end
for _, doc in ipairs(docs) do
if doc.type == 'doc.param' then
if doc.param[1] == param then
return true
end
end
end
return false
end
local function findReturn(docs, index)
if not docs then
return false
end
for _, doc in ipairs(docs) do
if doc.type == 'doc.return' then
for _, ret in ipairs(doc.returns) do
if ret.returnIndex == index then
return true
end
end
end
end
return false
end
local function checkFunction(source, callback, commentId, paramId, returnId)
local functionName = source.parent[1]
local argCount = source.args and #source.args or 0
if argCount == 0 and not source.returns and not source.bindDocs then
callback {
start = source.start,
finish = source.finish,
message = lang.script(commentId, functionName),
}
end
if argCount > 0 then
for _, arg in ipairs(source.args) do
local argName = arg[1]
if argName ~= 'self'
and argName ~= '_' then
if not findParam(source.bindDocs, argName) then
callback {
start = arg.start,
finish = arg.finish,
message = lang.script(paramId, argName, functionName),
}
end
end
end
end
if source.returns then
for _, ret in ipairs(source.returns) do
for index, expr in ipairs(ret) do
if not findReturn(source.bindDocs, index) then
callback {
start = expr.start,
finish = expr.finish,
message = lang.script(returnId, index, functionName),
}
end
end
end
end
end
m.CheckFunction = checkFunction
return m

View File

@@ -0,0 +1,109 @@
-- incomplete-signature-doc
local files = require 'files'
local lang = require 'language'
local guide = require "parser.guide"
local await = require 'await'
local function findParam(docs, param)
if not docs then
return false
end
for _, doc in ipairs(docs) do
if doc.type == 'doc.param' then
if doc.param[1] == param then
return true
end
end
end
return false
end
local function findReturn(docs, index)
if not docs then
return false
end
for _, doc in ipairs(docs) do
if doc.type == 'doc.return' then
for _, ret in ipairs(doc.returns) do
if ret.returnIndex == index then
return true
end
end
end
end
return false
end
--- check if there's any signature doc (@param or @return), or just comments, @async, ...
local function findSignatureDoc(docs)
if not docs then
return false
end
for _, doc in ipairs(docs) do
if doc.type == 'doc.return' or doc.type == 'doc.param' then
return true
end
end
return false
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast then
return
end
---@async
guide.eachSourceType(state.ast, 'function', function (source)
await.delay()
if not source.bindDocs then
return
end
--- don't apply rule if there is no @param or @return annotation yet
--- so comments and @async can be applied without the need for a full documentation
if(not findSignatureDoc(source.bindDocs)) then
return
end
if source.args and #source.args > 0 then
for _, arg in ipairs(source.args) do
local argName = arg[1]
if argName ~= 'self'
and argName ~= '_' then
if not findParam(source.bindDocs, argName) then
callback {
start = arg.start,
finish = arg.finish,
message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_PARAM', argName),
}
end
end
end
end
if source.returns then
for _, ret in ipairs(source.returns) do
for index, expr in ipairs(ret) do
if not findReturn(source.bindDocs, index) then
callback {
start = expr.start,
finish = expr.finish,
message = lang.script('DIAG_INCOMPLETE_SIGNATURE_DOC_RETURN', index),
}
end
end
end
end
end)
end

View File

@@ -0,0 +1,191 @@
local files = require 'files'
local define = require 'proto.define'
local config = require 'config'
local await = require 'await'
local vm = require "vm.vm"
local util = require 'utility'
local diagd = require 'proto.diagnostic'
local sleepRest = 0.0
---@async
local function checkSleep(uri, passed)
local speedRate = config.get(uri, 'Lua.diagnostics.workspaceRate')
if speedRate <= 0 or speedRate >= 100 then
return
end
local sleepTime = passed * (100 - speedRate) / speedRate
if sleepTime + sleepRest < 0.001 then
sleepRest = sleepRest + sleepTime
return
end
sleepRest = sleepTime + sleepRest
sleepTime = sleepRest
if sleepTime > 0.1 then
sleepTime = 0.1
end
local clock = os.clock()
await.sleep(sleepTime)
local sleeped = os.clock() - clock
sleepRest = sleepRest - sleeped
end
---@param uri uri
---@param name string
---@return string
local function getSeverity(uri, name)
local severity = config.get(uri, 'Lua.diagnostics.severity')[name]
or define.DiagnosticDefaultSeverity[name]
if severity:sub(-1) == '!' then
return severity:sub(1, -2)
end
local groupSeverity = config.get(uri, 'Lua.diagnostics.groupSeverity')
local groups = diagd.getGroups(name)
local groupLevel = 999
for _, groupName in ipairs(groups) do
local gseverity = groupSeverity[groupName]
if gseverity and gseverity ~= 'Fallback' then
groupLevel = math.min(groupLevel, define.DiagnosticSeverity[gseverity])
end
end
if groupLevel == 999 then
return severity
end
for severityName, level in pairs(define.DiagnosticSeverity) do
if level == groupLevel then
return severityName
end
end
return severity
end
---@param uri uri
---@param name string
---@return string
local function getStatus(uri, name)
local status = config.get(uri, 'Lua.diagnostics.neededFileStatus')[name]
or define.DiagnosticDefaultNeededFileStatus[name]
if status:sub(-1) == '!' then
return status:sub(1, -2)
end
local groupStatus = config.get(uri, 'Lua.diagnostics.groupFileStatus')
local groups = diagd.getGroups(name)
local groupLevel = 0
for _, groupName in ipairs(groups) do
local gstatus = groupStatus[groupName]
if gstatus and gstatus ~= 'Fallback' then
groupLevel = math.max(groupLevel, define.DiagnosticFileStatus[gstatus])
end
end
if groupLevel == 0 then
return status
end
for statusName, level in pairs(define.DiagnosticFileStatus) do
if level == groupLevel then
return statusName
end
end
return status
end
---@async
---@param uri uri
---@param name string
---@param isScopeDiag boolean
---@param response async fun(result: any)
---@param ignoreFileOpenState? boolean
---@return boolean
local function check(uri, name, isScopeDiag, response, ignoreFileOpenState)
local disables = config.get(uri, 'Lua.diagnostics.disable')
if util.arrayHas(disables, name) then
return false
end
local severity = getSeverity(uri, name)
local status = getStatus(uri, name)
if status == 'None' then
return false
end
if not ignoreFileOpenState and status == 'Opened' and not files.isOpen(uri) then
return false
end
local level = define.DiagnosticSeverity[severity]
local clock = os.clock()
local mark = {}
---@async
require('core.diagnostics.' .. name)(uri, function (result)
if vm.isDiagDisabledAt(uri, result.start, name) then
return
end
if result.start < 0 then
return
end
if mark[result.start] then
return
end
mark[result.start] = true
result.level = level or result.level
result.code = name
response(result)
end, name)
local passed = os.clock() - clock
if passed >= 0.5 then
log.warn(('Diagnostics [%s] @ [%s] takes [%.3f] sec!'):format(name, uri, passed))
end
if isScopeDiag then
checkSleep(uri, passed)
end
if DIAGTIMES then
DIAGTIMES[name] = (DIAGTIMES[name] or 0) + passed
end
return true
end
local diagList
local diagCosts = {}
local diagCount = {}
local function buildDiagList()
if not diagList then
diagList = {}
for name in pairs(define.DiagnosticDefaultSeverity) do
diagList[#diagList+1] = name
end
end
table.sort(diagList, function (a, b)
local time1 = (diagCosts[a] or 0) / (diagCount[a] or 1)
local time2 = (diagCosts[b] or 0) / (diagCount[b] or 1)
return time1 < time2
end)
return diagList
end
---@async
---@param uri uri
---@param isScopeDiag boolean
---@param response async fun(result: any)
---@param checked? async fun(name: string)
---@param ignoreFileOpenState? boolean
return function (uri, isScopeDiag, response, checked, ignoreFileOpenState)
local ast = files.getState(uri)
if not ast then
return nil
end
for _, name in ipairs(buildDiagList()) do
await.delay()
local clock = os.clock()
local suc = check(uri, name, isScopeDiag, response, ignoreFileOpenState)
if suc then
local cost = os.clock() - clock
diagCosts[name] = (diagCosts[name] or 0) + cost
diagCount[name] = (diagCount[name] or 0) + 1
end
if checked then
checked(name)
end
end
end

View File

@@ -0,0 +1,150 @@
local files = require 'files'
local vm = require 'vm'
local lang = require 'language'
local guide = require 'parser.guide'
local await = require 'await'
local hname = require 'core.hover.name'
local skipCheckClass = {
['unknown'] = true,
['any'] = true,
['table'] = true,
}
---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
---@async
local function checkInjectField(src)
await.delay()
local node = src.node
if not node then
return
end
local ok
for view in vm.getInfer(node):eachView(uri) do
if skipCheckClass[view] then
return
end
ok = true
end
if not ok then
return
end
local isExact
local class = vm.getDefinedClass(uri, node)
if class then
for _, doc in ipairs(class:getSets(uri)) do
if vm.docHasAttr(doc, 'exact') then
isExact = true
break
end
end
if not isExact then
return
end
if src.type == 'setmethod'
and not guide.getSelfNode(node) then
return
end
end
for _, def in ipairs(vm.getDefs(src)) do
local dnode = def.node
if dnode
and not isExact
and vm.getDefinedClass(uri, dnode) then
return
end
if def.type == 'doc.type.field' then
return
end
if def.type == 'doc.field' then
return
end
if def.type == 'tablefield' and not isExact then
return
end
end
local howToFix = ''
if not isExact then
howToFix = lang.script('DIAG_INJECT_FIELD_FIX_CLASS', {
node = hname(node),
fix = '---@class',
})
for _, ndef in ipairs(vm.getDefs(node)) do
if ndef.type == 'doc.type.table' then
howToFix = lang.script('DIAG_INJECT_FIELD_FIX_TABLE', {
fix = '[any]: any',
})
break
end
end
end
local message = lang.script('DIAG_INJECT_FIELD', {
class = vm.getInfer(node):view(uri),
field = guide.getKeyName(src),
fix = howToFix,
})
if src.type == 'setfield' and src.field then
callback {
start = src.field.start,
finish = src.field.finish,
message = message,
}
elseif src.type == 'setmethod' and src.method then
callback {
start = src.method.start,
finish = src.method.finish,
message = message,
}
end
end
guide.eachSourceType(ast.ast, 'setfield', checkInjectField)
guide.eachSourceType(ast.ast, 'setmethod', checkInjectField)
---@async
local function checkExtraTableField(src)
await.delay()
if not src.bindSource then
return
end
if not vm.docHasAttr(src, 'exact') then
return
end
local value = src.bindSource.value
if not value or value.type ~= 'table' then
return
end
for _, field in ipairs(value) do
local defs = vm.getDefs(field)
for _, def in ipairs(defs) do
if def.type == 'doc.field' then
goto nextField
end
end
local message = lang.script('DIAG_INJECT_FIELD', {
class = vm.getInfer(src):view(uri),
field = guide.getKeyName(src),
fix = '',
})
callback {
start = field.start,
finish = field.finish,
message = message,
}
::nextField::
end
end
guide.eachSourceType(ast.ast, 'doc.class', checkExtraTableField)
end

View File

@@ -0,0 +1,67 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local vm = require 'vm.vm'
local await = require 'await'
local checkTypes = {'getfield', 'setfield', 'getmethod', 'setmethod', 'getindex', 'setindex'}
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceTypes(state.ast, checkTypes, function (src)
local child = src.field or src.method or src.index
if not child then
return
end
local key = guide.getKeyName(src)
if not key then
return
end
await.delay()
local defs = vm.getDefs(src)
for _, def in ipairs(defs) do
if not vm.isVisible(src.node, def) then
if vm.getVisibleType(def) == 'private' then
callback {
start = child.start,
finish = child.finish,
uri = uri,
message = lang.script('DIAG_INVISIBLE_PRIVATE', {
field = key,
class = vm.getParentClass(def):getName(),
}),
}
elseif vm.getVisibleType(def) == 'protected' then
callback {
start = child.start,
finish = child.finish,
uri = uri,
message = lang.script('DIAG_INVISIBLE_PROTECTED', {
field = key,
class = vm.getParentClass(def):getName(),
}),
}
elseif vm.getVisibleType(def) == 'package' then
callback {
start = child.start,
finish = child.finish,
uri = uri,
message = lang.script('DIAG_INVISIBLE_PACKAGE', {
field = key,
uri = guide.getUri(def),
}),
}
else
error('Unknown visible type: ' .. vm.getVisibleType(def))
end
break
end
end
end)
end

View File

@@ -0,0 +1,81 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local config = require 'config'
local vm = require 'vm'
local util = require 'utility'
local function isDocClass(source)
if not source.bindDocs then
return false
end
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.class' then
return true
end
end
return false
end
local function isGlobalRegex(name, definedGlobalRegex)
if not definedGlobalRegex then
return false
end
for _, pattern in ipairs(definedGlobalRegex) do
if name:match(pattern) then
return true
end
end
return false
end
-- 不允许定义首字母小写的全局变量(很可能是拼错或者漏删)
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
local definedGlobal = util.arrayToHash(config.get(uri, 'Lua.diagnostics.globals'))
local definedGlobalRegex = config.get(uri, 'Lua.diagnostics.globalsRegex')
guide.eachSourceType(ast.ast, 'setglobal', function (source)
local name = guide.getKeyName(source)
if not name or definedGlobal[name] then
return
end
local first = name:match '%w'
if not first then
return
end
if not first:match '%l' then
return
end
-- 如果赋值被标记为 doc.class ,则认为是允许的
if isDocClass(source) then
return
end
if isGlobalRegex(name, definedGlobalRegex) then
return
end
if definedGlobal[name] == nil then
definedGlobal[name] = false
local global = vm.getGlobal('variable', name)
if global then
for _, set in ipairs(global:getSets(uri)) do
if vm.isMetaFile(guide.getUri(set)) then
definedGlobal[name] = true
return
end
end
end
end
callback {
start = source.start,
finish = source.finish,
message = lang.script.DIAG_LOWERCASE_GLOBAL,
}
end)
end

View File

@@ -0,0 +1,118 @@
local vm = require 'vm'
local files = require 'files'
local guide = require 'parser.guide'
local await = require 'await'
local lang = require 'language'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'table', function (src)
await.delay()
vm.removeNode(src) -- the node is not updated correctly, reason still unknown
local defs = vm.getDefs(src)
local sortedDefs = {}
for _, def in ipairs(defs) do
if def.type == 'doc.class' then
if def.bindSource and guide.isInRange(def.bindSource, src.start) then
return
end
local className = def.class[1]
if not sortedDefs[className] then
sortedDefs[className] = {}
-- check if this class is a `partial` class
-- a partial class will not check missing inherited fields
local class = vm.getGlobal('type', className)
---@cast class -nil
for _, set in ipairs(class:getSets(uri)) do
if set.type == 'doc.class'
and vm.docHasAttr(set, 'partial')
then
sortedDefs[className].isPartial = true
break
end
end
end
local samedefs = sortedDefs[className]
samedefs[#samedefs+1] = def
end
if def.type == 'doc.type.array'
or def.type == 'doc.type.table' then
return
end
end
local myKeys
local warnings = {}
for className, samedefs in pairs(sortedDefs) do
local missedKeys = {}
for _, def in ipairs(samedefs) do
local fields = samedefs.isPartial and def.fields or vm.getFields(def)
if not fields or #fields == 0 then
goto continue
end
if not myKeys then
myKeys = {}
for _, field in ipairs(src) do
local key = vm.getKeyName(field) or field.tindex
if key then
myKeys[key] = true
end
end
end
for _, field in ipairs(fields) do
if not field.optional
and field.type == "doc.field"
and not vm.compileNode(field):isNullable() then
local key = vm.getKeyName(field)
if not key then
local fieldnode = vm.compileNode(field.field)[1]
if fieldnode and fieldnode.type == 'doc.type.integer' then
---@cast fieldnode parser.object
key = vm.getKeyName(fieldnode)
end
end
if key and not myKeys[key] then
if type(key) == "number" then
missedKeys[#missedKeys+1] = ('`[%s]`'):format(key)
else
missedKeys[#missedKeys+1] = ('`%s`'):format(key)
end
end
end
end
::continue::
if not samedefs.isPartial then
-- if not partial class, then all fields in this class have already been checked
-- because in the above uses `vm.getFields` to get all fields
break
end
end
if #missedKeys == 0 then
return
end
warnings[#warnings+1] = lang.script('DIAG_MISSING_FIELDS', className, table.concat(missedKeys, ', '))
end
if #warnings == 0 then
return
end
callback {
start = src.start,
finish = src.finish,
message = table.concat(warnings, '\n')
}
end)
end

View File

@@ -0,0 +1,27 @@
local files = require 'files'
local guide = require "parser.guide"
local await = require 'await'
local helper = require 'core.diagnostics.helper.missing-doc-helper'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast then
return
end
---@async
guide.eachSourceType(state.ast, 'function', function (source)
await.delay()
if source.parent.type ~= 'setglobal' then
return
end
helper.CheckFunction(source, callback, 'DIAG_MISSING_GLOBAL_DOC_COMMENT', 'DIAG_MISSING_GLOBAL_DOC_PARAM', 'DIAG_MISSING_GLOBAL_DOC_RETURN')
end)
end

View File

@@ -0,0 +1,52 @@
local files = require 'files'
local guide = require "parser.guide"
local await = require 'await'
local helper = require 'core.diagnostics.helper.missing-doc-helper'
---@async
local function findSetField(ast, name, callback)
---@async
guide.eachSourceType(ast, 'setfield', function (source)
await.delay()
if source.node[1] == name then
local funcPtr = source.value.node
if not funcPtr then
return
end
local func = funcPtr.value
if not func then
return
end
if funcPtr.type == 'local' and func.type == 'function' then
helper.CheckFunction(func, callback, 'DIAG_MISSING_LOCAL_EXPORT_DOC_COMMENT', 'DIAG_MISSING_LOCAL_EXPORT_DOC_PARAM', 'DIAG_MISSING_LOCAL_EXPORT_DOC_RETURN')
end
end
end)
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast then
return
end
---@async
guide.eachSourceType(state.ast, 'return', function (source)
await.delay()
--table
for _, ret in ipairs(source) do
if ret.type == 'getlocal' then
if ret.node.value and ret.node.value.type == 'table' then
findSetField(state.ast, ret[1], callback)
end
end
end
end)
end

View File

@@ -0,0 +1,32 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'call', function (source)
await.delay()
local _, callArgs = vm.countList(source.args)
local funcNode = vm.compileNode(source.node)
local funcArgs = vm.countParamsOfNode(funcNode)
if callArgs >= funcArgs then
return
end
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_MISS_ARGS', funcArgs, callArgs),
}
end)
end

View File

@@ -0,0 +1,51 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'function', function (source)
await.delay()
local returns = source.returns
if not returns then
return
end
local min = vm.countReturnsOfSource(source)
if min == 0 then
return
end
for _, ret in ipairs(returns) do
local rmin, rmax = vm.countList(ret)
if rmax < min then
if rmin == rmax then
callback {
start = ret.start,
finish = ret.start + #'return',
message = lang.script('DIAG_MISSING_RETURN_VALUE', {
min = min,
rmax = rmax,
}),
}
else
callback {
start = ret.start,
finish = ret.start + #'return',
message = lang.script('DIAG_MISSING_RETURN_VALUE_RANGE', {
min = min,
rmin = rmin,
rmax = rmax,
}),
}
end
end
end
end)
end

View File

@@ -0,0 +1,76 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
---@param block parser.object
---@return boolean
local function hasReturn(block)
if block.hasReturn or block.hasExit then
return true
end
if block.type == 'if' then
local hasElse
for _, subBlock in ipairs(block) do
if not hasReturn(subBlock) then
return false
end
if subBlock.type == 'elseblock' then
hasElse = true
end
end
return hasElse == true
else
if block.type == 'while' then
if vm.testCondition(block.filter) then
return true
end
end
for _, action in ipairs(block) do
if guide.isBlockType(action) then
if hasReturn(action) then
return true
end
end
end
end
return false
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
local isMeta = vm.isMetaFile(uri)
---@async
guide.eachSourceType(state.ast, 'function', function (source)
-- check declare only
if isMeta and vm.isEmptyFunction(source) then
return
end
await.delay()
if vm.countReturnsOfSource(source) == 0 then
return
end
if hasReturn(source) then
return
end
local lastAction = source[#source]
local pos
if lastAction then
pos = lastAction.range or lastAction.finish
else
pos = source.keyword[3] or source.finish
end
callback {
start = pos,
finish = pos,
message = lang.script('DIAG_MISSING_RETURN'),
}
end)
end

View File

@@ -0,0 +1,35 @@
local files = require 'files'
local converter = require 'proto.converter'
local log = require 'log'
local nameStyle = require 'provider.name-style'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
local text = state.originText
local status, diagnosticInfos = nameStyle.nameStyleCheck(uri, text)
if not status then
if diagnosticInfos ~= nil then
log.error(diagnosticInfos)
end
return
end
if diagnosticInfos then
for _, diagnosticInfo in ipairs(diagnosticInfos) do
callback {
start = converter.unpackPosition(state, diagnosticInfo.range.start),
finish = converter.unpackPosition(state, diagnosticInfo.range["end"]),
message = diagnosticInfo.message,
data = diagnosticInfo.data
}
end
end
end

View File

@@ -0,0 +1,48 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
local delayer = await.newThrottledDelayer(500)
---@async
guide.eachSourceType(state.ast, 'getlocal', function (src)
delayer:delay()
local checkNil
local nxt = src.next
if nxt then
if nxt.type == 'getfield'
or nxt.type == 'getmethod'
or nxt.type == 'getindex'
or nxt.type == 'call' then
checkNil = true
end
end
local call = src.parent
if call and call.type == 'call' and call.node == src then
checkNil = true
end
local setIndex = src.parent
if setIndex and setIndex.type == 'setindex' and setIndex.index == src then
checkNil = true
end
if not checkNil then
return
end
local node = vm.compileNode(src)
if node:hasFalsy() and not vm.getInfer(src):hasType(uri, 'any') then
callback {
start = src.start,
finish = src.finish,
message = lang.script('DIAG_NEED_CHECK_NIL'),
}
end
end)
end

View File

@@ -0,0 +1,49 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local await = require 'await'
local sub = require 'core.substring'
---@async
return function (uri, callback)
local state = files.getState(uri)
local text = files.getText(uri)
if not state or not text then
return
end
---@async
guide.eachSourceType(state.ast, 'table', function (source)
await.delay()
for i = 1, #source do
local field = source[i]
if field.type ~= 'tableexp' then
goto CONTINUE
end
local call = field.value
if not call then
goto CONTINUE
end
if call.type ~= 'call' then
return
end
local func = call.node
local args = call.args
if args then
local funcLine = guide.rowColOf(func.finish)
local argsLine = guide.rowColOf(args.start)
if argsLine > funcLine then
callback {
start = call.start,
finish = call.finish,
message = lang.script('DIAG_PREFIELD_CALL'
, sub(state)(func.start + 1, func.finish)
, sub(state)(args.start + 1, args.finish)
)
}
end
end
::CONTINUE::
end
end)
end

View File

@@ -0,0 +1,54 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local await = require 'await'
local sub = require 'core.substring'
---@async
return function (uri, callback)
local state = files.getState(uri)
local text = files.getText(uri)
if not state or not text then
return
end
---@async
guide.eachSourceType(state.ast, 'call', function (source)
local node = source.node
local args = source.args
if not args then
return
end
-- 必须有其他人在继续使用当前对象
if not source.next then
return
end
await.delay()
local startOffset = guide.positionToOffset(state, args.start) + 1
local finishOffset = guide.positionToOffset(state, args.finish)
if text:sub(startOffset, startOffset) ~= '('
or text:sub(finishOffset, finishOffset) ~= ')' then
return
end
local nodeRow = guide.rowColOf(node.finish)
local argRow = guide.rowColOf(args.start)
if nodeRow == argRow then
return
end
if #args == 1 then
callback {
start = node.start,
finish = args.finish,
message = lang.script('DIAG_PREVIOUS_CALL'
, sub(state)(node.start + 1, node.finish)
, sub(state)(args.start + 1, args.finish)
),
}
end
end)
end

View File

@@ -0,0 +1,36 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local vm = require 'vm'
local await = require 'await'
local types = {
'local',
'setlocal',
'setglobal',
'getglobal',
'setfield',
'setindex',
'tablefield',
'tableindex',
}
---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
---@async
guide.eachSourceTypes(ast.ast, types, function (source)
await.delay()
if vm.getInfer(source):view(uri) == 'unknown' then
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_UNKNOWN'),
}
end
end)
end

View File

@@ -0,0 +1,63 @@
local files = require 'files'
local await = require 'await'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local function isYieldAble(defs, i)
local hasFuncDef
for _, def in ipairs(defs) do
if def.type == 'function' then
local arg = def.args and def.args[i]
if arg then
hasFuncDef = true
if vm.getInfer(arg):hasType(guide.getUri(def), 'any')
or vm.isAsync(arg, true)
or arg.type == '...' then
return true
end
end
end
if def.type == 'doc.type.function' then
local arg = def.args and def.args[i]
if arg then
hasFuncDef = true
if vm.getInfer(arg.extends):hasType(guide.getUri(def), 'any')
or vm.isAsync(arg.extends, true) then
return true
end
end
end
end
return not hasFuncDef
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
guide.eachSourceType(state.ast, 'call', function (source) ---@async
if not source.args then
return
end
await.delay()
local defs = vm.getDefs(source.node)
if #defs == 0 then
return
end
for i, arg in ipairs(source.args) do
if vm.isAsync(arg, true)
and not vm.isLinkedCall(source.node, i)
and not isYieldAble(defs, i) then
callback {
start = arg.start,
finish = arg.finish,
message = lang.script('DIAG_NOT_YIELDABLE', i),
}
end
end
end)
end

View File

@@ -0,0 +1,122 @@
local files = require 'files'
local lang = require 'language'
local guide = require 'parser.guide'
local vm = require 'vm'
local await = require 'await'
---@param defNode vm.node
local function expandGenerics(defNode)
---@type parser.object[]
local generics = {}
for dn in defNode:eachObject() do
if dn.type == 'doc.generic.name' then
---@cast dn parser.object
generics[#generics+1] = dn
end
end
for _, generic in ipairs(generics) do
defNode:removeObject(generic)
end
for _, generic in ipairs(generics) do
local limits = generic.generic and generic.generic.extends
if limits then
defNode:merge(vm.compileNode(limits))
else
local unknownType = vm.declareGlobal('type', 'unknown')
defNode:merge(unknownType)
end
end
end
---@param funcNode vm.node
---@param i integer
---@return vm.node?
local function getDefNode(funcNode, i)
local defNode = vm.createNode()
for src in funcNode:eachObject() do
if src.type == 'function'
or src.type == 'doc.type.function' then
local param = src.args and src.args[i]
if param then
defNode:merge(vm.compileNode(param))
if param[1] == '...' then
defNode:addOptional()
end
end
end
end
if defNode:isEmpty() then
return nil
end
expandGenerics(defNode)
return defNode
end
---@param funcNode vm.node
---@param i integer
---@return vm.node
local function getRawDefNode(funcNode, i)
local defNode = vm.createNode()
for f in funcNode:eachObject() do
if f.type == 'function'
or f.type == 'doc.type.function' then
local param = f.args and f.args[i]
if param then
defNode:merge(vm.compileNode(param))
end
end
end
return defNode
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'call', function (source)
if not source.args then
return
end
await.delay()
local funcNode = vm.compileNode(source.node)
for i, arg in ipairs(source.args) do
local refNode = vm.compileNode(arg)
if not refNode then
goto CONTINUE
end
local defNode = getDefNode(funcNode, i)
if not defNode then
goto CONTINUE
end
if arg.type == 'getfield'
or arg.type == 'getindex'
or arg.type == 'self' then
-- 由于无法对字段进行类型收窄,
-- 因此将假值移除再进行检查
refNode = refNode:copy():setTruthy()
end
local errs = {}
if not vm.canCastType(uri, defNode, refNode, errs) then
local rawDefNode = getRawDefNode(funcNode, i)
assert(errs)
callback {
start = arg.start,
finish = arg.finish,
message = lang.script('DIAG_PARAM_TYPE_MISMATCH', {
def = vm.getInfer(rawDefNode):view(uri),
ref = vm.getInfer(refNode):view(uri),
}) .. '\n' .. vm.viewTypeErrorMessage(uri, errs),
}
end
::CONTINUE::
end
end)
end

View File

@@ -0,0 +1,37 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local await = require 'await'
---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
---@async
guide.eachSourceType(ast.ast, 'local', function (source)
local name = source[1]
if name == '_'
or name == ast.ENVMode then
return
end
await.delay()
local exist = guide.getLocal(source, name, source.start-1)
if exist then
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_REDEFINED_LOCAL', name),
related = {
{
start = exist.start,
finish = exist.finish,
uri = uri,
}
},
}
end
end)
end

View File

@@ -0,0 +1,73 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'call', function (source)
await.delay()
local callArgs = vm.countList(source.args)
if callArgs == 0 then
return
end
local funcNode = vm.compileNode(source.node)
local _, funcArgs = vm.countParamsOfNode(funcNode)
if callArgs <= funcArgs then
return
end
if callArgs == 1 and source.node.type == 'getmethod' then
return
end
if funcArgs + 1 > #source.args then
local lastArg = source.args[#source.args]
if lastArg.type == 'call' and funcArgs > 0 then
-- 如果函数接收至少一个参数,那么调用方最后一个参数是函数调用
-- 导致的参数数量太多可以忽略。
-- 如果函数不接收任何参数,那么任何参数都是错误的。
return
end
callback {
start = lastArg.start,
finish = lastArg.finish,
message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs)
}
else
for i = funcArgs + 1, #source.args do
local arg = source.args[i]
callback {
start = arg.start,
finish = arg.finish,
message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, callArgs)
}
end
end
end)
---@async
guide.eachSourceType(state.ast, 'function', function (source)
await.delay()
if not source.args then
return
end
local _, funcArgs = vm.countParamsOfSource(source)
local myArgs = #source.args
for i = funcArgs + 1, myArgs do
local arg = source.args[i]
callback {
start = arg.start,
finish = arg.finish,
message = lang.script('DIAG_OVER_MAX_ARGS', funcArgs, myArgs),
}
end
end)
end

View File

@@ -0,0 +1,58 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'function', function (source)
local returns = source.returns
if not returns then
return
end
await.delay()
local _, max = vm.countReturnsOfSource(source)
for _, ret in ipairs(returns) do
local rmin, rmax = vm.countList(ret)
if rmin > max then
for i = max + 1, #ret - 1 do
callback {
start = ret[i].start,
finish = ret[i].finish,
message = lang.script('DIAG_REDUNDANT_RETURN_VALUE', {
max = max,
rmax = i,
}),
}
end
if #ret == rmax then
callback {
start = ret[#ret].start,
finish = ret[#ret].finish,
message = lang.script('DIAG_REDUNDANT_RETURN_VALUE', {
max = max,
rmax = rmax,
}),
}
else
callback {
start = ret[#ret].start,
finish = ret[#ret].finish,
message = lang.script('DIAG_REDUNDANT_RETURN_VALUE_RANGE', {
max = max,
rmin = #ret,
rmax = rmax,
}),
}
end
end
end
end)
end

View File

@@ -0,0 +1,27 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local define = require 'proto.define'
-- reports 'return' without any return values at the end of functions
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
guide.eachSourceType(ast.ast, 'return', function (source)
if not source.parent or source.parent.type ~= "function" then
return
end
if #source > 0 then
return
end
callback {
start = source.start,
finish = source.finish,
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script.DIAG_REDUNDANT_RETURN,
}
end)
end

View File

@@ -0,0 +1,26 @@
local files = require 'files'
local define = require 'proto.define'
local lang = require 'language'
local guide = require 'parser.guide'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
local delayer = await.newThrottledDelayer(50000)
guide.eachSource(state.ast, function (src) ---@async
delayer:delay()
if src.redundant then
callback {
start = src.start,
finish = src.finish,
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script('DIAG_OVER_MAX_VALUES', src.redundant.max, src.redundant.passed)
}
end
end)
end

View File

@@ -0,0 +1,94 @@
local files = require 'files'
local lang = require 'language'
local guide = require 'parser.guide'
local vm = require 'vm'
local await = require 'await'
local util = require 'utility'
---@param func parser.object
---@return vm.node[]?
local function getDocReturns(func)
---@type table<integer, vm.node>
local returns = util.defaultTable(function ()
return vm.createNode()
end)
if func.bindDocs then
for _, doc in ipairs(func.bindDocs) do
if doc.type == 'doc.return' then
for _, ret in ipairs(doc.returns) do
returns[ret.returnIndex]:merge(vm.compileNode(ret))
end
end
if doc.type == 'doc.overload' then
for i, ret in ipairs(doc.overload.returns) do
returns[i]:merge(vm.compileNode(ret))
end
end
end
end
for nd in vm.compileNode(func):eachObject() do
if nd.type == 'doc.type.function' then
---@cast nd parser.object
for i, ret in ipairs(nd.returns) do
returns[i]:merge(vm.compileNode(ret))
end
end
end
setmetatable(returns, nil)
if #returns == 0 then
return nil
end
return returns
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@param docReturns vm.node[]
---@param rets parser.object
local function checkReturn(docReturns, rets)
for i, docRet in ipairs(docReturns) do
local retNode, exp = vm.selectNode(rets, i)
if not exp then
break
end
if retNode:hasName 'nil' then
if exp.type == 'getfield'
or exp.type == 'getindex' then
retNode = retNode:copy():removeOptional()
end
end
local errs = {}
if not vm.canCastType(uri, docRet, retNode, errs) then
callback {
start = exp.start,
finish = exp.finish,
message = lang.script('DIAG_RETURN_TYPE_MISMATCH', {
def = vm.getInfer(docRet):view(uri),
ref = vm.getInfer(retNode):view(uri),
index = i,
}) .. '\n' .. vm.viewTypeErrorMessage(uri, errs),
}
end
end
end
---@async
guide.eachSourceType(state.ast, 'function', function (source)
if not source.returns then
return
end
await.delay()
local docReturns = getDocReturns(source)
if not docReturns then
return
end
for _, ret in ipairs(source.returns) do
checkReturn(docReturns, ret)
await.delay()
end
end)
end

View File

@@ -0,0 +1,35 @@
local files = require 'files'
local converter = require 'proto.converter'
local log = require 'log'
local spell = require 'provider.spell'
---@async
return function(uri, callback)
local state = files.getState(uri)
if not state then
return
end
local text = state.originText
local status, diagnosticInfos = spell.spellCheck(uri, text)
if not status then
if diagnosticInfos ~= nil then
log.error(diagnosticInfos)
end
return
end
if diagnosticInfos then
for _, diagnosticInfo in ipairs(diagnosticInfos) do
callback {
start = converter.unpackPosition(state, diagnosticInfo.range.start),
finish = converter.unpackPosition(state, diagnosticInfo.range["end"]),
message = diagnosticInfo.message,
data = diagnosticInfo.data
}
end
end
end

View File

@@ -0,0 +1,53 @@
local files = require 'files'
local lang = require 'language'
local guide = require 'parser.guide'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
local text = files.getText(uri)
if not state or not text then
return
end
local delayer = await.newThrottledDelayer(5000)
local lines = state.lines
for i = 0, #lines do
delayer:delay()
local startOffset = lines[i]
local finishOffset = text:find('[\r\n]', startOffset) or (#text + 1)
local lastOffset = finishOffset - 1
local lastChar = text:sub(lastOffset, lastOffset)
if lastChar ~= ' ' and lastChar ~= '\t' then
goto NEXT_LINE
end
local lastPos = guide.offsetToPosition(state, lastOffset)
if guide.isInString(state.ast, lastPos)
or guide.isInComment(state.ast, lastPos) then
goto NEXT_LINE
end
local firstOffset = startOffset
for n = lastOffset - 1, startOffset, -1 do
local char = text:sub(n, n)
if char ~= ' ' and char ~= '\t' then
firstOffset = n + 1
break
end
end
local firstPos = guide.offsetToPosition(state, firstOffset) - 1
if firstOffset == startOffset then
callback {
start = firstPos,
finish = lastPos,
message = lang.script.DIAG_LINE_ONLY_SPACE,
}
else
callback {
start = firstPos,
finish = lastPos,
message = lang.script.DIAG_LINE_POST_SPACE,
}
end
::NEXT_LINE::
end
end

View File

@@ -0,0 +1,49 @@
local files = require 'files'
local lang = require 'language'
local guide = require 'parser.guide'
local await = require 'await'
local types = {
'local',
'setlocal',
'setglobal',
'setfield',
'setindex' ,
}
---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
local last
local function checkSet(source)
if source.value then
last = source
else
if not last then
return
end
if last.start <= source.start
and last.value.start >= source.finish then
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_UNBALANCED_ASSIGNMENTS')
}
else
last = nil
end
end
end
local delayer = await.newThrottledDelayer(1000)
---@async
guide.eachSourceTypes(ast.ast, types, function (source)
delayer:delay()
checkSet(source)
end)
end

View File

@@ -0,0 +1,48 @@
local files = require 'files'
local lang = require 'language'
local vm = require 'vm'
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
local cache = {}
for _, doc in ipairs(state.ast.docs) do
if doc.type == 'doc.class' then
if not doc.extends then
goto CONTINUE
end
for _, ext in ipairs(doc.extends) do
local name = ext.type == 'doc.extends.name' and ext[1]
if name then
local docs = vm.getDocSets(uri, name)
if cache[name] == nil then
cache[name] = false
for _, otherDoc in ipairs(docs) do
if otherDoc.type == 'doc.class' then
cache[name] = true
break
end
end
end
if not cache[name] then
callback {
start = ext.start,
finish = ext.finish,
related = cache,
message = lang.script('DIAG_UNDEFINED_DOC_CLASS', name)
}
end
end
end
end
::CONTINUE::
end
end

View File

@@ -0,0 +1,37 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local vm = require 'vm'
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
guide.eachSource(state.ast.docs, function (source)
if source.type ~= 'doc.extends.name'
and source.type ~= 'doc.type.name' then
return
end
if source.parent.type == 'doc.class' then
return
end
local name = source[1]
if name == '...' or name == '_' or name == 'self' then
return
end
if #vm.getDocSets(uri, name) > 0 then
return
end
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_UNDEFINED_DOC_NAME', name)
}
end)
end

View File

@@ -0,0 +1,24 @@
local files = require 'files'
local lang = require 'language'
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
for _, doc in ipairs(state.ast.docs) do
if doc.type == 'doc.param'
and not doc.bindSource then
callback {
start = doc.param.start,
finish = doc.param.finish,
message = lang.script('DIAG_UNDEFINED_DOC_PARAM', doc.param[1])
}
end
end
end

View File

@@ -0,0 +1,47 @@
local files = require 'files'
local guide = require 'parser.guide'
local lang = require 'language'
local vm = require "vm.vm"
---@param source parser.object
---@return boolean
local function isBindDoc(source)
if not source.bindDocs then
return false
end
for _, doc in ipairs(source.bindDocs) do
if doc.type == 'doc.type'
or doc.type == 'doc.class' then
return true
end
end
return false
end
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
guide.eachSourceType(state.ast, 'getglobal', function (source)
if source.node.tag == '_ENV' then
return
end
if not isBindDoc(source.node) then
return
end
if #vm.getDefs(source) > 0 then
return
end
local key = source[1]
callback {
start = source.start,
finish = source.finish,
message = lang.script('DIAG_UNDEF_ENV_CHILD', key),
}
end)
end

View File

@@ -0,0 +1,90 @@
local files = require 'files'
local vm = require 'vm'
local lang = require 'language'
local guide = require 'parser.guide'
local await = require 'await'
local skipCheckClass = {
['unknown'] = true,
['any'] = true,
['table'] = true,
}
---@async
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
---@async
local function checkUndefinedField(src)
await.delay()
if vm.hasDef(src) then
return
end
local node = src.node
if node then
local ok
for view in vm.getInfer(node):eachView(uri) do
if skipCheckClass[view] then
return
end
ok = true
end
if not ok then
return
end
end
local message = lang.script('DIAG_UNDEF_FIELD', guide.getKeyName(src))
if src.type == 'getfield' and src.field then
callback {
start = src.field.start,
finish = src.field.finish,
message = message,
}
elseif src.type == 'getmethod' and src.method then
callback {
start = src.method.start,
finish = src.method.finish,
message = message,
}
end
end
---@async
local function checkUndefinedFieldByIndexEnum(src)
await.delay()
local isEnum = false
for _, node in ipairs(vm.compileNode(src.node)) do
local docs = node.bindDocs
if docs then
for _, doc in ipairs(docs) do
if doc.type == "doc.enum" then
isEnum = true
break
end
end
end
end
if not isEnum then
return
end
if vm.hasDef(src) then
return
end
local keyName = guide.getKeyName(src)
if not keyName then
return
end
local message = lang.script('DIAG_UNDEF_FIELD', guide.getKeyName(src))
callback {
start = src.index.start,
finish = src.index.finish,
message = message,
}
end
guide.eachSourceType(ast.ast, 'getfield', checkUndefinedField)
guide.eachSourceType(ast.ast, 'getmethod', checkUndefinedField)
guide.eachSourceType(ast.ast, 'getindex', checkUndefinedFieldByIndexEnum)
end

View File

@@ -0,0 +1,37 @@
local files = require 'files'
local vm = require 'vm'
local lang = require 'language'
local guide = require 'parser.guide'
local requireLike = {
['include'] = true,
['import'] = true,
['require'] = true,
['load'] = true,
}
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
-- 遍历全局变量,检查所有没有 set 模式的全局变量
guide.eachSourceType(state.ast, 'getglobal', function (src) ---@async
if vm.isUndefinedGlobal(src) then
local key = src[1]
local message = lang.script('DIAG_UNDEF_GLOBAL', key)
if requireLike[key:lower()] then
message = ('%s(%s)'):format(message, lang.script('DIAG_REQUIRE_LIKE', key))
end
callback {
start = src.start,
finish = src.finish,
message = message,
undefinedGlobal = src[1]
}
end
end)
end

View File

@@ -0,0 +1,31 @@
local files = require 'files'
local lang = require 'language'
local vm = require 'vm'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
for _, doc in ipairs(state.ast.docs) do
if doc.type == 'doc.cast' and doc.name then
await.delay()
local defs = vm.getDefs(doc.name)
local loc = defs[1]
if not loc then
callback {
start = doc.name.start,
finish = doc.name.finish,
message = lang.script('DIAG_UNKNOWN_CAST_VARIABLE', doc.name[1])
}
end
end
end
end

View File

@@ -0,0 +1,31 @@
local files = require 'files'
local lang = require 'language'
local diag = require 'proto.diagnostic'
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
for _, doc in ipairs(state.ast.docs) do
if doc.type == 'doc.diagnostic' then
if doc.names then
for _, nameUnit in ipairs(doc.names) do
local code = nameUnit[1]
if not diag.getDiagAndErrNameMap()[code] then
callback {
start = nameUnit.start,
finish = nameUnit.finish,
message = lang.script('DIAG_UNKNOWN_DIAG_CODE', code),
}
end
end
end
end
end
end

View File

@@ -0,0 +1,33 @@
local files = require 'files'
local lang = require 'language'
local vm = require 'vm'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if not state.ast.docs then
return
end
for _, doc in ipairs(state.ast.docs) do
if doc.type == 'doc.operator' then
local op = doc.op
if op then
local opName = op[1]
if not vm.OP_BINARY_MAP[opName]
and not vm.OP_UNARY_MAP[opName]
and not vm.OP_OTHER_MAP[opName] then
callback {
start = doc.op.start,
finish = doc.op.finish,
message = lang.script('DIAG_UNKNOWN_OPERATOR', opName)
}
end
end
end
end
end

View File

@@ -0,0 +1,29 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceType(state.ast, 'call', function (source)
await.delay()
local currentFunc = guide.getParentFunction(source)
if currentFunc and source.node.special == 'assert' and source.args[1] then
local argNode = vm.compileNode(source.args[1])
if argNode:alwaysTruthy() then
callback {
start = source.node.start,
finish = source.node.finish,
message = lang.script('DIAG_UNNECESSARY_ASSERT'),
}
end
end
end)
end

View File

@@ -0,0 +1,84 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local lang = require 'language'
local await = require 'await'
local define = require 'proto.define'
---@param source parser.object
---@return boolean
local function allLiteral(source)
local result = true
guide.eachSource(source, function (src)
if src.type ~= 'unary'
and src.type ~= 'binary'
and not guide.isLiteral(src) then
result = false
return false
end
end)
return result
end
---@param block parser.object
---@return boolean
local function hasReturn(block)
if block.hasReturn or block.hasExit then
return true
end
if block.type == 'if' then
local hasElse
for _, subBlock in ipairs(block) do
if not hasReturn(subBlock) then
return false
end
if subBlock.type == 'elseblock' then
hasElse = true
end
end
return hasElse == true
else
if block.type == 'while' then
if vm.testCondition(block.filter)
and not block.breaks
and allLiteral(block.filter) then
return true
end
end
for _, action in ipairs(block) do
if guide.isBlockType(action) then
if hasReturn(action) then
return true
end
end
end
end
return false
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
---@async
guide.eachSourceTypes(state.ast, {'main', 'function'}, function (source)
await.delay()
for i, action in ipairs(source) do
if guide.isBlockType(action)
and hasReturn(action) then
if i < #source then
callback {
start = source[i+1].start,
finish = source[#source].finish,
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script('DIAG_UNREACHABLE_CODE'),
}
end
return
end
end
end)
end

View File

@@ -0,0 +1,129 @@
local files = require 'files'
local guide = require 'parser.guide'
local vm = require 'vm'
local define = require 'proto.define'
local lang = require 'language'
local await = require 'await'
local client = require 'client'
local util = require 'utility'
local function isToBeClosed(source)
if not source.attrs then
return false
end
for _, attr in ipairs(source.attrs) do
if attr[1] == 'close' then
return true
end
end
return false
end
---@param source parser.object?
---@return boolean
local function isValidFunction(source)
if not source then
return false
end
if source.type == 'main' then
return false
end
local parent = source.parent
if not parent then
return false
end
if parent.type ~= 'local'
and parent.type ~= 'setlocal' then
return false
end
if isToBeClosed(parent) then
return false
end
return true
end
---@async
local function collect(ast, white, roots, links)
---@async
guide.eachSourceType(ast, 'function', function (src)
await.delay()
if not isValidFunction(src) then
return
end
local loc = src.parent
if loc.type == 'setlocal' then
loc = loc.node
end
for _, ref in ipairs(loc.ref or {}) do
if ref.type == 'getlocal' then
local func = guide.getParentFunction(ref)
if not func or not isValidFunction(func) or roots[func] then
roots[src] = true
return
end
if not links[func] then
links[func] = {}
end
links[func][#links[func]+1] = src
end
end
white[src] = true
end)
return white, roots, links
end
local function turnBlack(source, black, white, links)
if black[source] then
return
end
black[source] = true
white[source] = nil
for _, link in ipairs(links[source] or {}) do
turnBlack(link, black, white, links)
end
end
---@async
return function (uri, callback)
local state = files.getState(uri)
if not state then
return
end
if vm.isMetaFile(uri) then
return
end
local black = {}
local white = {}
local roots = {}
local links = {}
collect(state.ast, white, roots, links)
for source in pairs(roots) do
turnBlack(source, black, white, links)
end
local tagSupports = client.getAbility('textDocument.completion.completionItem.tagSupport.valueSet')
local supportUnnecessary = tagSupports and util.arrayHas(tagSupports, define.DiagnosticTag.Unnecessary)
for source in pairs(white) do
if supportUnnecessary then
callback {
start = source.start,
finish = source.finish,
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script.DIAG_UNUSED_FUNCTION,
}
else
callback {
start = source.keyword[1],
finish = source.keyword[2],
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script.DIAG_UNUSED_FUNCTION,
}
end
end
end

View File

@@ -0,0 +1,22 @@
local files = require 'files'
local guide = require 'parser.guide'
local define = require 'proto.define'
local lang = require 'language'
return function (uri, callback)
local ast = files.getState(uri)
if not ast then
return
end
guide.eachSourceType(ast.ast, 'label', function (source)
if not source.ref then
callback {
start = source.start,
finish = source.finish,
tags = { define.DiagnosticTag.Unnecessary },
message = lang.script('DIAG_UNUSED_LABEL', source[1]),
}
end
end)
end

Some files were not shown because too many files have changed in this diff Show More