From 866cb79f65c844b3fcfa99d2caa4bf19930dbc6d Mon Sep 17 00:00:00 2001 From: Philipp Janda Date: Tue, 20 Jan 2015 11:35:16 +0100 Subject: luajit already has yieldable (x)pcall, add tests for code from compat52 --- compat53.lua | 318 ++++++++++++++++++++++++++++++----------------------------- 1 file changed, 162 insertions(+), 156 deletions(-) (limited to 'compat53.lua') diff --git a/compat53.lua b/compat53.lua index 5f88e8b..67a8c8f 100644 --- a/compat53.lua +++ b/compat53.lua @@ -336,12 +336,12 @@ if lua_version < "5.3" then function debug.setuservalue(obj, value) if type(obj) ~= "userdata" then error("bad argument #1 to 'setuservalue' (userdata expected, got ".. - type(obj)..")", 0) + type(obj)..")", 2) end if value == nil then value = _G end if type(value) ~= "table" then error("bad argument #2 to 'setuservalue' (table expected, got ".. - type(value)..")", 0) + type(value)..")", 2) end return debug_setfenv(obj, value) end @@ -366,96 +366,98 @@ if lua_version < "5.3" then end end -- not luajit with compat52 enabled - local debug_getinfo = debug.getinfo - local function calculate_trace_level(co, level) - if level ~= nil then - for out = 1, 1/0 do - local info = (co==nil) and debug_getinfo(out, "") or debug_getinfo(co, out, "") - if info == nil then - local max = out-1 - if level <= max then - return level + if not is_luajit then + local debug_getinfo = debug.getinfo + local function calculate_trace_level(co, level) + if level ~= nil then + for out = 1, 1/0 do + local info = (co==nil) and debug_getinfo(out, "") or debug_getinfo(co, out, "") + if info == nil then + local max = out-1 + if level <= max then + return level + end + return nil, level-max end - return nil, level-max end end + return 1 end - return 1 - end - local stack_pattern = "\nstack traceback:" - local stack_replace = "" - local debug_traceback = debug.traceback - function debug.traceback(co, msg, level) - local lvl - local nilmsg - if type(co) ~= "thread" then - co, msg, level = coroutine_running(), co, msg - end - if msg == nil then - msg = "" - nilmsg = true - elseif type(msg) ~= "string" then - return msg - end - if co == nil then - msg = debug_traceback(msg, level or 1) - else - local xpco = xpcall_running[co] - if xpco ~= nil then - lvl, level = calculate_trace_level(xpco, level) - if lvl then - msg = debug_traceback(xpco, msg, lvl) - else - msg = msg..stack_pattern - end - lvl, level = calculate_trace_level(co, level) - if lvl then - local trace = debug_traceback(co, "", lvl) - msg = msg..trace:gsub(stack_pattern, stack_replace) - end + local stack_pattern = "\nstack traceback:" + local stack_replace = "" + local debug_traceback = debug.traceback + function debug.traceback(co, msg, level) + local lvl + local nilmsg + if type(co) ~= "thread" then + co, msg, level = coroutine_running(), co, msg + end + if msg == nil then + msg = "" + nilmsg = true + elseif type(msg) ~= "string" then + return msg + end + if co == nil then + msg = debug_traceback(msg, level or 1) else - co = pcall_callOf[co] or co - lvl, level = calculate_trace_level(co, level) - if lvl then - msg = debug_traceback(co, msg, lvl) + local xpco = xpcall_running[co] + if xpco ~= nil then + lvl, level = calculate_trace_level(xpco, level) + if lvl then + msg = debug_traceback(xpco, msg, lvl) + else + msg = msg..stack_pattern + end + lvl, level = calculate_trace_level(co, level) + if lvl then + local trace = debug_traceback(co, "", lvl) + msg = msg..trace:gsub(stack_pattern, stack_replace) + end else - msg = msg..stack_pattern - end - end - co = pcall_previous[co] - while co ~= nil do - lvl, level = calculate_trace_level(co, level) - if lvl then - local trace = debug_traceback(co, "", lvl) - msg = msg..trace:gsub(stack_pattern, stack_replace) + co = pcall_callOf[co] or co + lvl, level = calculate_trace_level(co, level) + if lvl then + msg = debug_traceback(co, msg, lvl) + else + msg = msg..stack_pattern + end end co = pcall_previous[co] + while co ~= nil do + lvl, level = calculate_trace_level(co, level) + if lvl then + local trace = debug_traceback(co, "", lvl) + msg = msg..trace:gsub(stack_pattern, stack_replace) + end + co = pcall_previous[co] + end + end + if nilmsg then + msg = msg:gsub("^\n", "") end + msg = msg:gsub("\n\t%(tail call%): %?", "\000") + msg = msg:gsub("\n\t%.%.%.\n", "\001\n") + msg = msg:gsub("\n\t%.%.%.$", "\001") + msg = msg:gsub("(%z+)\001(%z+)", function(some, other) + return "\n\t(..."..#some+#other.."+ tail call(s)...)" + end) + msg = msg:gsub("\001(%z+)", function(zeros) + return "\n\t(..."..#zeros.."+ tail call(s)...)" + end) + msg = msg:gsub("(%z+)\001", function(zeros) + return "\n\t(..."..#zeros.."+ tail call(s)...)" + end) + msg = msg:gsub("%z+", function(zeros) + return "\n\t(..."..#zeros.." tail call(s)...)" + end) + msg = msg:gsub("\001", function(zeros) + return "\n\t..." + end) + return msg end - if nilmsg then - msg = msg:gsub("^\n", "") - end - msg = msg:gsub("\n\t%(tail call%): %?", "\000") - msg = msg:gsub("\n\t%.%.%.\n", "\001\n") - msg = msg:gsub("\n\t%.%.%.$", "\001") - msg = msg:gsub("(%z+)\001(%z+)", function(some, other) - return "\n\t(..."..#some+#other.."+ tail call(s)...)" - end) - msg = msg:gsub("\001(%z+)", function(zeros) - return "\n\t(..."..#zeros.."+ tail call(s)...)" - end) - msg = msg:gsub("(%z+)\001", function(zeros) - return "\n\t(..."..#zeros.."+ tail call(s)...)" - end) - msg = msg:gsub("%z+", function(zeros) - return "\n\t(..."..#zeros.." tail call(s)...)" - end) - msg = msg:gsub("\001", function(zeros) - return "\n\t..." - end) - return msg - end + end -- is not luajit end -- debug table available @@ -501,7 +503,7 @@ if lua_version < "5.3" then local ld_type = type(ld) if ld_type ~= "function" then error("bad argument #1 to 'load' (function expected, got ".. - ld_type..")", 0) + ld_type..")", 2) end if mode ~= "bt" then local checked, merr = false, nil @@ -564,7 +566,7 @@ if lua_version < "5.3" then function rawlen(v) local t = type(v) if t ~= "string" and t ~= "table" then - error("bad argument #1 to 'rawlen' (table or string expected)", 0) + error("bad argument #1 to 'rawlen' (table or string expected)", 2) end return #v end @@ -622,92 +624,94 @@ if lua_version < "5.3" then local coroutine_yield = coroutine.yield function coroutine.yield(...) - local co = coroutine_running() - if co then + local co, flag = coroutine_running() + if co and not flag then return coroutine_yield(...) else error("attempt to yield from outside a coroutine", 0) end end - local coroutine_resume = coroutine.resume - function coroutine.resume(co, ...) - if co == main_coroutine then - return false, "cannot resume non-suspended coroutine" - else - return coroutine_resume(co, ...) + if not is_luajit then + local coroutine_resume = coroutine.resume + function coroutine.resume(co, ...) + if co == main_coroutine then + return false, "cannot resume non-suspended coroutine" + else + return coroutine_resume(co, ...) + end end - end - local coroutine_status = coroutine.status - function coroutine.status(co) - local notmain = coroutine_running() - if co == main_coroutine then - return notmain and "normal" or "running" - else - return coroutine_status(co) + local coroutine_status = coroutine.status + function coroutine.status(co) + local notmain = coroutine_running() + if co == main_coroutine then + return notmain and "normal" or "running" + else + return coroutine_status(co) + end end - end - local function pcall_results(current, call, success, ...) - if coroutine_status(call) == "suspended" then - return pcall_results(current, call, coroutine_resume(call, coroutine_yield(...))) + local function pcall_results(current, call, success, ...) + if coroutine_status(call) == "suspended" then + return pcall_results(current, call, coroutine_resume(call, coroutine_yield(...))) + end + if pcall_previous then + pcall_previous[call] = nil + local main = pcall_mainOf[call] + if main == current then current = nil end + pcall_callOf[main] = current + end + pcall_mainOf[call] = nil + return success, ... end - if pcall_previous then - pcall_previous[call] = nil - local main = pcall_mainOf[call] - if main == current then current = nil end - pcall_callOf[main] = current + local function pcall_exec(current, call, ...) + local main = pcall_mainOf[current] or current + pcall_mainOf[call] = main + if pcall_previous then + pcall_previous[call] = current + pcall_callOf[main] = call + end + return pcall_results(current, call, coroutine_resume(call, ...)) end - pcall_mainOf[call] = nil - return success, ... - end - local function pcall_exec(current, call, ...) - local main = pcall_mainOf[current] or current - pcall_mainOf[call] = main - if pcall_previous then - pcall_previous[call] = current - pcall_callOf[main] = call + local coroutine_create52 = coroutine.create + local function pcall_coroutine(func) + if type(func) ~= "function" then + local callable = func + func = function (...) return callable(...) end + end + return coroutine_create52(func) end - return pcall_results(current, call, coroutine_resume(call, ...)) - end - local coroutine_create52 = coroutine.create - local function pcall_coroutine(func) - if type(func) ~= "function" then - local callable = func - func = function (...) return callable(...) end + function pcall(func, ...) + local current = coroutine_running() + if not current then return _pcall(func, ...) end + return pcall_exec(current, pcall_coroutine(func), ...) end - return coroutine_create52(func) - end - function pcall(func, ...) - local current = coroutine_running() - if not current then return _pcall(func, ...) end - return pcall_exec(current, pcall_coroutine(func), ...) - end - local function xpcall_catch(current, call, msgh, success, ...) - if not success then - xpcall_running[current] = call - local ok, result = _pcall(msgh, ...) - xpcall_running[current] = nil - if not ok then - return false, "error in error handling ("..tostring(result)..")" + local function xpcall_catch(current, call, msgh, success, ...) + if not success then + xpcall_running[current] = call + local ok, result = _pcall(msgh, ...) + xpcall_running[current] = nil + if not ok then + return false, "error in error handling ("..tostring(result)..")" + end + return false, result end - return false, result + return true, ... end - return true, ... - end - local _xpcall = xpcall - local _unpack = unpack - function xpcall(f, msgh, ...) - local current = coroutine_running() - if not current then - local args, n = { ... }, select('#', ...) - return _xpcall(function() return f(_unpack(args, 1, n)) end, msgh) + local _xpcall = xpcall + local _unpack = unpack + function xpcall(f, msgh, ...) + local current = coroutine_running() + if not current then + local args, n = { ... }, select('#', ...) + return _xpcall(function() return f(_unpack(args, 1, n)) end, msgh) + end + local call = pcall_coroutine(f) + return xpcall_catch(current, call, msgh, pcall_exec(current, call, ...)) end - local call = pcall_coroutine(f) - return xpcall_catch(current, call, msgh, pcall_exec(current, call, ...)) - end + end -- not luajit if not is_luajit then @@ -806,6 +810,7 @@ if lua_version < "5.3" then return addqt[c] or string_format("\\%03d", c:byte()) end + local _unpack = unpack function string.format(fmt, ...) local args, n = { ... }, select('#', ...) local i = 0 @@ -846,12 +851,13 @@ if lua_version < "5.3" then if var_1 == nil then if st.doclose then st.f:close() end if (...) ~= nil then - error((...), 0) + error((...), 2) end end return var_1, ... end + local _unpack = unpack function lines_iterator(st) return helper(st, st.f:read(_unpack(st, 1, st.n))) end @@ -864,14 +870,14 @@ if lua_version < "5.3" then local doclose, file, msg if fname ~= nil then doclose, file, msg = true, io_open(fname, "r") - if not file then error(msg, 0) end + if not file then error(msg, 2) end else doclose, file = false, io_input() end local st = { f=file, doclose=doclose, n=select('#', ...), ... } for i = 1, st.n do if type(st[i]) ~= "number" and not valid_format[st[i]] then - error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 0) + error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2) end end return lines_iterator, st @@ -894,12 +900,12 @@ if lua_version < "5.3" then file_meta.__index.lines = function(self, ...) if io_type(self) == "closed file" then - error("attempt to use a closed file", 0) + error("attempt to use a closed file", 2) end local st = { f=self, doclose=false, n=select('#', ...), ... } for i = 1, st.n do if type(st[i]) ~= "number" and not valid_format[st[i]] then - error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 0) + error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2) end end return lines_iterator, st -- cgit v1.2.3-55-g6feb