Skip to content

Commit 364992b

Browse files
authored
fix(async): migrate to new "async" module (#250)
1 parent 2e5af77 commit 364992b

File tree

10 files changed

+219
-229
lines changed

10 files changed

+219
-229
lines changed

.github/workflows/ci.yml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ jobs:
3737
cp -rf ~/.commons.nvim/lua/commons ./lua/colorbox/
3838
cp ~/.commons.nvim/version.txt ./lua/colorbox/commons/version.txt
3939
find ./lua/colorbox/commons -type f -name '*.lua' -exec sed -i 's/require("commons/require("colorbox.commons/g' {} \;
40-
- uses: stevearc/nvim-typecheck-action@v2
41-
with:
42-
path: lua
43-
configpath: ".luarc.json"
40+
# - uses: stevearc/nvim-typecheck-action@v2
41+
# with:
42+
# path: lua
43+
# configpath: ".luarc.json"
4444
- uses: cargo-bins/cargo-binstall@main
4545
- name: Selene
4646
run: |
@@ -56,8 +56,6 @@ jobs:
5656
push_options: "--force"
5757
unit_test:
5858
name: Unit Test
59-
needs:
60-
- lint
6159
strategy:
6260
matrix:
6361
nvim_version: [stable, nightly]

codecov.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ coverage:
77
default:
88
threshold: 90%
99
ignore:
10-
- "lua/commons/*.lua"
11-
- "lua/commons/**/*.lua"
12-
- "lua/commons/**/**/*.lua"
13-
- "lua/commons/**/**/**/*.lua"
10+
- "lua/colorbox/commons/*.lua"
11+
- "lua/colorbox/commons/**/*.lua"
12+
- "lua/colorbox/commons/**/**/*.lua"

lua/colorbox/commons/async.lua

Lines changed: 71 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1,189 +1,95 @@
1-
---@diagnostic disable
2-
--- Small async library for Neovim plugins
3-
4-
local function validate_callback(func, callback)
5-
if callback and type(callback) ~= 'function' then
6-
local info = debug.getinfo(func, 'nS')
7-
error(
8-
string.format(
9-
'Callback is not a function for %s, got: %s',
10-
info.short_src .. ':' .. info.linedefined,
11-
vim.inspect(callback)
12-
)
13-
)
14-
end
1+
-- Copied from: <https://github.com/neovim/neovim/issues/19624#issuecomment-1202405058>
2+
3+
local co = coroutine
4+
5+
local async_thread = {
6+
threads = {},
7+
}
8+
9+
local function threadtostring(x)
10+
if jit then
11+
return string.format('%p', x)
12+
else
13+
return tostring(x):match('thread: (.*)')
14+
end
1515
end
1616

17-
-- Coroutine.running() was changed between Lua 5.1 and 5.2:
18-
-- - 5.1: Returns the running coroutine, or nil when called by the main thread.
19-
-- - 5.2: Returns the running coroutine plus a boolean, true when the running
20-
-- coroutine is the main one.
21-
--
22-
-- For LuaJIT, 5.2 behaviour is enabled with LUAJIT_ENABLE_LUA52COMPAT
23-
--
24-
-- We need to handle both.
25-
local _main_co_or_nil = coroutine.running()
26-
27-
--- Executes a future with a callback when it is done
28-
--- @param func function
29-
--- @param callback function?
30-
--- @param ... any
31-
local function run(func, callback, ...)
32-
validate_callback(func, callback)
17+
function async_thread.running()
18+
local thread = co.running()
19+
local id = threadtostring(thread)
20+
return async_thread.threads[id]
21+
end
3322

34-
local co = coroutine.create(func)
23+
function async_thread.create(fn)
24+
local thread = co.create(fn)
25+
local id = threadtostring(thread)
26+
async_thread.threads[id] = true
27+
return thread
28+
end
3529

36-
local function step(...)
37-
local ret = { coroutine.resume(co, ...) }
38-
local stat = ret[1]
30+
function async_thread.finished(x)
31+
if co.status(x) == 'dead' then
32+
local id = threadtostring(x)
33+
async_thread.threads[id] = nil
34+
return true
35+
end
36+
return false
37+
end
3938

40-
if not stat then
41-
local err = ret[2] --[[@as string]]
42-
error(
43-
string.format('The coroutine failed with this message: %s\n%s', err, debug.traceback(co))
44-
)
45-
end
39+
--- @param async_fn function
40+
--- @param ... any
41+
local function execute(async_fn, ...)
42+
local thread = async_thread.create(async_fn)
43+
44+
local function step(...)
45+
local ret = { co.resume(thread, ...) }
46+
local stat, err_or_fn, nargs = unpack(ret)
47+
48+
if not stat then
49+
error(string.format("The coroutine failed with this message: %s\n%s",
50+
err_or_fn, debug.traceback(thread)))
51+
end
4652

47-
if coroutine.status(co) == 'dead' then
48-
if callback then
49-
callback(unpack(ret, 2, table.maxn(ret)))
53+
if async_thread.finished(thread) then
54+
return
5055
end
51-
return
52-
end
5356

54-
--- @type integer, fun(...: any): any
55-
local nargs, fn = ret[2], ret[3]
56-
assert(type(fn) == 'function', 'type error :: expected func')
57+
assert(type(err_or_fn) == "function", "The 1st parameter must be a lua function")
5758

58-
--- @type any[]
59-
local args = { unpack(ret, 4, table.maxn(ret)) }
60-
args[nargs] = step
61-
fn(unpack(args, 1, nargs))
62-
end
59+
local ret_fn = err_or_fn
60+
local args = { select(4, unpack(ret)) }
61+
args[nargs] = step
62+
ret_fn(unpack(args, 1, nargs --[[@as integer]]))
63+
end
6364

64-
step(...)
65+
step(...)
6566
end
6667

6768
local M = {}
6869

69-
---Use this to create a function which executes in an async context but
70-
---called from a non-async context. Inherently this cannot return anything
71-
---since it is non-blocking
72-
--- @generic F: function
73-
--- @param argc integer
74-
--- @param func async F
75-
--- @return F
76-
function M.sync(argc, func)
77-
return function(...)
78-
assert(not coroutine.running())
79-
local callback = select(argc + 1, ...)
80-
run(func, callback, unpack({ ... }, 1, argc))
81-
end
82-
end
83-
84-
--- @param argc integer
8570
--- @param func function
86-
--- @param ... any
87-
--- @return any ...
88-
function M.wait(argc, func, ...)
89-
-- Always run the wrapped functions in xpcall and re-raise the error in the
90-
-- coroutine. This makes pcall work as normal.
91-
local function pfunc(...)
92-
local args = { ... } --- @type any[]
93-
local cb = args[argc]
94-
args[argc] = function(...)
95-
cb(true, ...)
96-
end
97-
xpcall(func, function(err)
98-
cb(false, err, debug.traceback())
99-
end, unpack(args, 1, argc))
100-
end
101-
102-
local ret = { coroutine.yield(argc, pfunc, ...) }
103-
104-
local ok = ret[1]
105-
if not ok then
106-
--- @type string, string
107-
local err, traceback = ret[2], ret[3]
108-
error(string.format('Wrapped function failed: %s\n%s', err, traceback))
109-
end
110-
111-
return unpack(ret, 2, table.maxn(ret))
112-
end
113-
114-
function M.run(func, ...)
115-
return run(func, nil, ...)
116-
end
117-
118-
--- Creates an async function with a callback style function.
11971
--- @param argc integer
120-
--- @param func function
12172
--- @return function
122-
function M.wrap(argc, func)
123-
assert(type(argc) == 'number')
124-
assert(type(func) == 'function')
125-
return function(...)
126-
return M.wait(argc, func, ...)
127-
end
128-
end
129-
130-
--- @generic R
131-
--- @param n integer Mx number of jobs to run concurrently
132-
--- @param thunks (fun(cb: function): R)[]
133-
--- @param interrupt_check fun()?
134-
--- @param callback fun(ret: R[][])
135-
M.join = M.wrap(4, function(n, thunks, interrupt_check, callback)
136-
n = math.min(n, #thunks)
137-
138-
local ret = {} --- @type any[][]
139-
140-
if #thunks == 0 then
141-
callback(ret)
142-
return
143-
end
144-
145-
local remaining = { unpack(thunks, n + 1) }
146-
local to_go = #thunks
147-
148-
local function cb(...)
149-
ret[#ret + 1] = { ... }
150-
to_go = to_go - 1
151-
if to_go == 0 then
152-
callback(ret)
153-
elseif not interrupt_check or not interrupt_check() then
154-
if #remaining > 0 then
155-
local next_thunk = table.remove(remaining, 1)
156-
next_thunk(cb)
73+
M.wrap = function(func, argc)
74+
return function(...)
75+
if not async_thread.running() then
76+
return func(...)
15777
end
158-
end
159-
end
160-
161-
for i = 1, n do
162-
thunks[i](cb)
163-
end
164-
end)
78+
return co.yield(func, argc, ...)
79+
end
80+
end
16581

166-
---Useful for partially applying arguments to an async function
167-
--- @param fn function
168-
--- @param ... any
82+
--- @param func function
16983
--- @return function
170-
function M.curry(fn, ...)
171-
--- @type integer, any[]
172-
local nargs, args = select('#', ...), { ... }
173-
174-
return function(...)
175-
local other = { ... }
176-
for i = 1, select('#', ...) do
177-
args[nargs + i] = other[i]
178-
end
179-
return fn(unpack(args))
180-
end
84+
M.void = function(func)
85+
return function(...)
86+
if async_thread.running() then
87+
return func(...)
88+
end
89+
execute(func, ...)
90+
end
18191
end
18292

183-
if vim.schedule then
184-
--- An async function that when called will yield to the Neovim scheduler to be
185-
--- able to call the API.
186-
M.schedule = M.wrap(1, vim.schedule)
187-
end
93+
M.schedule = M.wrap(vim.schedule, 1)
18894

18995
return M

lua/colorbox/commons/fileio.lua renamed to lua/colorbox/commons/fio.lua

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function FileLineReader:open(filename, batchsize)
2121
if type(handler) ~= "number" then
2222
error(
2323
string.format(
24-
"|commons.fileio - FileLineReader:open| failed to fs_open file: %s",
24+
"|commons.fio - FileLineReader:open| failed to fs_open file: %s",
2525
vim.inspect(filename)
2626
)
2727
)
@@ -31,7 +31,7 @@ function FileLineReader:open(filename, batchsize)
3131
if type(fstat) ~= "table" then
3232
error(
3333
string.format(
34-
"|commons.fileio - FileLineReader:open| failed to fs_fstat file: %s",
34+
"|commons.fio - FileLineReader:open| failed to fs_fstat file: %s",
3535
vim.inspect(filename)
3636
)
3737
)
@@ -67,7 +67,7 @@ function FileLineReader:_read_chunk()
6767
if read_err then
6868
error(
6969
string.format(
70-
"|commons.fileio - FileLineReader:_read_chunk| failed to fs_read file: %s, read_error:%s, read_name:%s",
70+
"|commons.fio - FileLineReader:_read_chunk| failed to fs_read file: %s, read_error:%s, read_name:%s",
7171
vim.inspect(self.filename),
7272
vim.inspect(read_err),
7373
vim.inspect(read_name)
@@ -199,12 +199,12 @@ end
199199
--- @alias commons.AsyncReadFileOnComplete fun(data:string?):any
200200
--- @alias commons.AsyncReadFileOnError fun(step:string?,err:string?):any
201201
--- @param filename string
202-
--- @param on_complete commons.AsyncReadFileOnComplete
203-
--- @param opts {trim:boolean?,on_error:commons.AsyncReadFileOnError?}?
204-
M.asyncreadfile = function(filename, on_complete, opts)
205-
opts = opts or { trim = false }
206-
opts.trim = type(opts.trim) == "boolean" and opts.trim or false
202+
--- @param opts {on_complete:commons.AsyncReadFileOnComplete,on_error:commons.AsyncReadFileOnError?,trim:boolean?}
203+
M.asyncreadfile = function(filename, opts)
204+
assert(type(opts) == "table")
205+
assert(type(opts.on_complete) == "function")
207206

207+
opts.trim = type(opts.trim) == "boolean" and opts.trim or false
208208
if type(opts.on_error) ~= "function" then
209209
opts.on_error = function(step1, err1)
210210
error(
@@ -240,11 +240,10 @@ M.asyncreadfile = function(filename, on_complete, opts)
240240
uv.fs_close(fd --[[@as integer]], function(close_complete_err)
241241
if opts.trim and type(data) == "string" then
242242
local trimmed_data = vim.trim(data)
243-
on_complete(trimmed_data)
243+
opts.on_complete(trimmed_data)
244244
else
245-
on_complete(data)
245+
opts.on_complete(data)
246246
end
247-
248247
if close_complete_err then
249248
opts.on_error("fs_close complete", close_complete_err)
250249
end

lua/colorbox/commons/platform.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ local os_name = uv.os_uname().sysname
55
local os_name_valid = type(os_name) == "string" and string.len(os_name) > 0
66

77
M.OS_NAME = os_name
8-
M.IS_WINDOWS = os_name_valid and os_name:gmatch("Windows") ~= nil
8+
M.IS_WINDOWS = os_name_valid and os_name:match("Windows") ~= nil
99
M.IS_MAC = os_name_valid and os_name:match("Darwin") ~= nil
1010
M.IS_BSD = vim.fn.has("bsd") > 0
1111
M.IS_LINUX = os_name_valid and os_name:match("Linux") ~= nil

0 commit comments

Comments
 (0)