From 1806cdc571215fa82d6ffde3aa75204861aa6cb5 Mon Sep 17 00:00:00 2001 From: Philipp Janda Date: Sun, 18 Jan 2015 18:36:21 +0100 Subject: add table.sort, add code from lua-compat-5.2 --- compat53.lua | 662 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 658 insertions(+), 4 deletions(-) (limited to 'compat53.lua') diff --git a/compat53.lua b/compat53.lua index 761d554..03158f0 100644 --- a/compat53.lua +++ b/compat53.lua @@ -214,7 +214,70 @@ if lua_version < "5.3" then end end - -- TODO: table.sort + do + local function pivot(list, cmp, a, b) + local m = b - a + if m > 2 then + local c = a + (m-m%2)/2 + local x, y, z = list[a], list[b], list[c] + if not cmp(x, y) then + x, y, a, b = y, x, b, a + end + if not cmp(y, z) then + y, z, b, c = z, y, c, b + end + if not cmp(x, y) then + x, y, a, b = y, x, b, a + end + return b, y + else + return b, list[b] + end + end + + local function lt_cmp(a, b) + return a < b + end + + local function qsort(list, cmp, b, e) + if b < e then + local i, j, k, val = b, e, pivot(list, cmp, b, e) + while i < j do + while i < j and cmp(list[i], val) do + i = i + 1 + end + while i < j and not cmp(list[j], val) do + j = j - 1 + end + if i < j then + list[i], list[j] = list[j], list[i] + if i == k then k = j end -- update pivot position + i, j = i+1, j-1 + end + end + if i ~= k and not cmp(list[i], val) then + list[i], list[k] = val, list[i] + k = i -- update pivot position + end + qsort(list, cmp, b, i == k and i-1 or i) + return qsort(list, cmp, i+1, e) + end + end + + local table_sort = table.sort + function table.sort(list, cmp) + local mt = gmt(list) + local has_mt = type(mt) == "table" + local has_len = has_mt and type(mt.__len) == "function" + if has_len then + cmp = cmp or lt_cmp + local len = mt.__len(list) + return qsort(list, cmp, 1, len) + else + return table_sort(list, cmp) + end + end + end local table_unpack = lua_version == "5.1" and unpack or table.unpack local function unpack_helper(list, i, j, ...) @@ -238,16 +301,607 @@ if lua_version < "5.3" then end + + -- bring Lua 5.1 (and LuaJIT) up to speed with Lua 5.2 if lua_version == "5.1" then -- detect LuaJIT (including LUAJIT_ENABLE_LUA52COMPAT compilation flag) local is_luajit = (string.dump(function() end) or ""):sub(1, 3) == "\027LJ" local is_luajit52 = is_luajit and #setmetatable({}, { __len = function() return 1 end }) == 1 - -- TODO: add functions from lua-compat-5.2 - end + -- table that maps each running coroutine to the coroutine that resumed it + -- this is used to build complete tracebacks when "coroutine-friendly" pcall + -- is used. + local pcall_previous = {} + local pcall_callOf = {} + local xpcall_running = {} + local coroutine_running = coroutine.running + + -- handle debug functions + if type(debug) == "table" then + + if not is_luajit52 then + local _G, package = _G, package + local debug_setfenv = debug.setfenv + function debug.setuservalue(obj, value) + if type(obj) ~= "userdata" then + error("bad argument #1 to 'setuservalue' (userdata expected, got ".. + type(obj)..")", 0) + end + if value == nil then value = _G end + if type(value) ~= "table" then + error("bad argument #2 to 'setuservalue' (table expected, got ".. + type(value)..")", 0) + end + return debug_setfenv(obj, value) + end + + local debug_getfenv = debug.getfenv + function debug.getuservalue(obj) + if type(obj) ~= "userdata" then + return nil + else + local v = debug_getfenv(obj) + if v == _G or v == package then + return nil + end + return v + end + end + + local debug_setmetatable = debug.setmetatable + function debug.setmetatable(value, tab) + debug_setmetatable(value, tab) + return value + end + end -- not luajit with compat52 enabled + + local debug_getinfo = debug.getinfo + local function calculate_trace_level(co, level) + if level ~= nil then + for out = 1, 1/0 do + local info = (co==nil) and debug_getinfo(out, "") or debug_getinfo(co, out, "") + if info == nil then + local max = out-1 + if level <= max then + return level + end + return nil, level-max + end + end + end + return 1 + end + + local stack_pattern = "\nstack traceback:" + local stack_replace = "" + local debug_traceback = debug.traceback + function debug.traceback(co, msg, level) + local lvl + local nilmsg + if type(co) ~= "thread" then + co, msg, level = coroutine_running(), co, msg + end + if msg == nil then + msg = "" + nilmsg = true + elseif type(msg) ~= "string" then + return msg + end + if co == nil then + msg = debug_traceback(msg, level or 1) + else + local xpco = xpcall_running[co] + if xpco ~= nil then + lvl, level = calculate_trace_level(xpco, level) + if lvl then + msg = debug_traceback(xpco, msg, lvl) + else + msg = msg..stack_pattern + end + lvl, level = calculate_trace_level(co, level) + if lvl then + local trace = debug_traceback(co, "", lvl) + msg = msg..trace:gsub(stack_pattern, stack_replace) + end + else + co = pcall_callOf[co] or co + lvl, level = calculate_trace_level(co, level) + if lvl then + msg = debug_traceback(co, msg, lvl) + else + msg = msg..stack_pattern + end + end + co = pcall_previous[co] + while co ~= nil do + lvl, level = calculate_trace_level(co, level) + if lvl then + local trace = debug_traceback(co, "", lvl) + msg = msg..trace:gsub(stack_pattern, stack_replace) + end + co = pcall_previous[co] + end + end + if nilmsg then + msg = msg:gsub("^\n", "") + end + msg = msg:gsub("\n\t%(tail call%): %?", "\000") + msg = msg:gsub("\n\t%.%.%.\n", "\001\n") + msg = msg:gsub("\n\t%.%.%.$", "\001") + msg = msg:gsub("(%z+)\001(%z+)", function(some, other) + return "\n\t(..."..#some+#other.."+ tail call(s)...)" + end) + msg = msg:gsub("\001(%z+)", function(zeros) + return "\n\t(..."..#zeros.."+ tail call(s)...)" + end) + msg = msg:gsub("(%z+)\001", function(zeros) + return "\n\t(..."..#zeros.."+ tail call(s)...)" + end) + msg = msg:gsub("%z+", function(zeros) + return "\n\t(..."..#zeros.." tail call(s)...)" + end) + msg = msg:gsub("\001", function(zeros) + return "\n\t..." + end) + return msg + end + end -- debug table available + + + if not is_luajit52 then + local _pairs = pairs + function pairs(t) + local mt = gmt(t) + if type(mt) == "table" and type(mt.__pairs) == "function" then + return mt.__pairs(t) + else + return _pairs(t) + end + end + end + + + if not is_luajit then + local function check_mode(mode, prefix) + local has = { text = false, binary = false } + for i = 1,#mode do + local c = mode:sub(i, i) + if c == "t" then has.text = true end + if c == "b" then has.binary = true end + end + local t = prefix:sub(1, 1) == "\27" and "binary" or "text" + if not has[t] then + return "attempt to load a "..t.." chunk (mode is '"..mode.."')" + end + end + + local setfenv = setfenv + local _load, _loadstring = load, loadstring + function load(ld, source, mode, env) + mode = mode or "bt" + local chunk, msg + if type( ld ) == "string" then + if mode ~= "bt" then + local merr = check_mode(mode, ld) + if merr then return nil, merr end + end + chunk, msg = _loadstring(ld, source) + else + local ld_type = type(ld) + if ld_type ~= "function" then + error("bad argument #1 to 'load' (function expected, got ".. + ld_type..")", 0) + end + if mode ~= "bt" then + local checked, merr = false, nil + local function checked_ld() + if checked then + return ld() + else + checked = true + local v = ld() + merr = check_mode(mode, v or "") + if merr then return nil end + return v + end + end + chunk, msg = _load(checked_ld, source) + if merr then return nil, merr end + else + chunk, msg = _load(ld, source) + end + end + if not chunk then + return chunk, msg + end + if env ~= nil then + setfenv(chunk, env) + end + return chunk + end + + loadstring = load + + local _loadfile = loadfile + local io_open = io.open + function loadfile(file, mode, env) + mode = mode or "bt" + if mode ~= "bt" then + local f = io_open(file, "rb") + if f then + local prefix = f:read(1) + f:close() + if prefix then + local merr = check_mode(mode, prefix) + if merr then return nil, merr end + end + end + end + local chunk, msg = _loadfile(file) + if not chunk then + return chunk, msg + end + if env ~= nil then + setfenv(chunk, env) + end + return chunk + end + end -- not luajit + + + if not is_luajit52 then + function rawlen(v) + local t = type(v) + if t ~= "string" and t ~= "table" then + error("bad argument #1 to 'rawlen' (table or string expected)", 0) + end + return #v + end + end + + + if not is_luajit52 then + local os_execute = os.execute + function os.execute(cmd) + local code = os_execute(cmd) + -- Lua 5.1 does not report exit by signal. + if code == 0 then + return true, "exit", code + else + return nil, "exit", code/256 -- only correct on POSIX! + end + end + end + + + if not is_luajit52 then + table.pack = function(...) + return { n = select('#', ...), ... } + end + end + + + local main_coroutine = coroutine.create(function() end) + + local _pcall = pcall + local coroutine_create = coroutine.create + function coroutine.create(func) + local success, result = _pcall(coroutine_create, func) + if not success then + if type(func) ~= "function" then + error("bad argument #1 (function expected)", 0) + end + result = coroutine_create(function(...) return func(...) end) + end + return result + end + + local pcall_mainOf = {} + + if not is_luajit52 then + function coroutine.running() + local co = coroutine_running() + if co then + return pcall_mainOf[co] or co, false + else + return main_coroutine, true + end + end + end + + local coroutine_yield = coroutine.yield + function coroutine.yield(...) + local co = coroutine_running() + if co then + return coroutine_yield(...) + else + error("attempt to yield from outside a coroutine", 0) + end + end + + local coroutine_resume = coroutine.resume + function coroutine.resume(co, ...) + if co == main_coroutine then + return false, "cannot resume non-suspended coroutine" + else + return coroutine_resume(co, ...) + end + end + + local coroutine_status = coroutine.status + function coroutine.status(co) + local notmain = coroutine_running() + if co == main_coroutine then + return notmain and "normal" or "running" + else + return coroutine_status(co) + end + end + + local function pcall_results(current, call, success, ...) + if coroutine_status(call) == "suspended" then + return pcall_results(current, call, coroutine_resume(call, coroutine_yield(...))) + end + if pcall_previous then + pcall_previous[call] = nil + local main = pcall_mainOf[call] + if main == current then current = nil end + pcall_callOf[main] = current + end + pcall_mainOf[call] = nil + return success, ... + end + local function pcall_exec(current, call, ...) + local main = pcall_mainOf[current] or current + pcall_mainOf[call] = main + if pcall_previous then + pcall_previous[call] = current + pcall_callOf[main] = call + end + return pcall_results(current, call, coroutine_resume(call, ...)) + end + local coroutine_create52 = coroutine.create + local function pcall_coroutine(func) + if type(func) ~= "function" then + local callable = func + func = function (...) return callable(...) end + end + return coroutine_create52(func) + end + function pcall(func, ...) + local current = coroutine_running() + if not current then return _pcall(func, ...) end + return pcall_exec(current, pcall_coroutine(func), ...) + end + + local function xpcall_catch(current, call, msgh, success, ...) + if not success then + xpcall_running[current] = call + local ok, result = _pcall(msgh, ...) + xpcall_running[current] = nil + if not ok then + return false, "error in error handling ("..tostring(result)..")" + end + return false, result + end + return true, ... + end + local _xpcall = xpcall + local _unpack = unpack + function xpcall(f, msgh, ...) + local current = coroutine_running() + if not current then + local args, n = { ... }, select('#', ...) + return _xpcall(function() return f(_unpack(args, 1, n)) end, msgh) + end + local call = pcall_coroutine(f) + return xpcall_catch(current, call, msgh, pcall_exec(current, call, ...)) + end + + + if not is_luajit then + local math_log = math.log + math.log = function(x, base) + if base ~= nil then + return math_log(x)/math_log(base) + else + return math_log(x) + end + end + end + + + if not is_luajit then + local io_open = io.open + local table_concat = table.concat + function package.searchpath(name, path, sep, rep) + sep = (sep or "."):gsub("(%p)", "%%%1") + rep = (rep or package.config:sub(1, 1)):gsub("(%%)", "%%%1") + local pname = name:gsub(sep, rep):gsub("(%%)", "%%%1") + local msg = {} + for subpath in path:gmatch("[^;]+") do + local fpath = subpath:gsub("%?", pname) + local f = io_open(fpath, "r") + if f then + f:close() + return fpath + end + msg[#msg+1] = "\n\tno file '" .. fpath .. "'" + end + return nil, table_concat(msg) + end + end + + local p_index = { searchers = package.loaders } + local rawset = rawset + setmetatable(package, { + __index = p_index, + __newindex = function(p, k, v) + if k == "searchers" then + rawset(p, "loaders", v) + p_index.searchers = v + else + rawset(p, k, v) + end + end + }) + + + local string_gsub = string.gsub + local function fix_pattern(pattern) + return (string_gsub(pattern, "%z", "%%z")) + end + + local string_find = string.find + function string.find(s, pattern, ...) + return string_find(s, fix_pattern(pattern), ...) + end + + local string_gmatch = string.gmatch + function string.gmatch(s, pattern) + return string_gmatch(s, fix_pattern(pattern)) + end + + function string.gsub(s, pattern, ...) + return string_gsub(s, fix_pattern(pattern), ...) + end + + local string_match = string.match + function string.match(s, pattern, ...) + return string_match(s, fix_pattern(pattern), ...) + end + + if not is_luajit then + local string_rep = string.rep + function string.rep(s, n, sep) + if sep ~= nil and sep ~= "" and n >= 2 then + return s .. string_rep(sep..s, n-1) + else + return string_rep(s, n) + end + end + end + + if not is_luajit then + local string_format = string.format + do + local addqt = { + ["\n"] = "\\\n", + ["\\"] = "\\\\", + ["\""] = "\\\"" + } + + local function addquoted(c) + return addqt[c] or string_format("\\%03d", c:byte()) + end + + function string.format(fmt, ...) + local args, n = { ... }, select('#', ...) + local i = 0 + local function adjust_fmt(lead, mods, kind) + if #lead % 2 == 0 then + i = i + 1 + if kind == "s" then + args[i] = tostring(args[i]) + elseif kind == "q" then + args[i] = '"'..string_gsub(args[i], "[%z%c\\\"\n]", addquoted)..'"' + return lead.."%"..mods.."s" + end + end + end + fmt = string_gsub(fmt, "(%%*)%%([%d%.%-%+%# ]*)(%a)", adjust_fmt) + return string_format(fmt, _unpack(args, 1, n)) + end + end + end + + + local io_open = io.open + local io_write = io.write + local io_output = io.output + function io.write(...) + local res, msg, errno = io_write(...) + if res then + return io_output() + else + return nil, msg, errno + end + end + + if not is_luajit then + local lines_iterator + do + local function helper( st, var_1, ... ) + if var_1 == nil then + if st.doclose then st.f:close() end + if (...) ~= nil then + error((...), 0) + end + end + return var_1, ... + end + + function lines_iterator(st) + return helper(st, st.f:read(_unpack(st, 1, st.n))) + end + end + + local valid_format = { ["*l"] = true, ["*n"] = true, ["*a"] = true } + + local io_input = io.input + function io.lines(fname, ...) + local doclose, file, msg + if fname ~= nil then + doclose, file, msg = true, io_open(fname, "r") + if not file then error(msg, 0) end + else + doclose, file = false, io_input() + end + local st = { f=file, doclose=doclose, n=select('#', ...), ... } + for i = 1, st.n do + if type(st[i]) ~= "number" and not valid_format[st[i]] then + error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 0) + end + end + return lines_iterator, st + end + + do + local io_stdout = io.stdout + local io_type = io.type + local file_meta = gmt(io_stdout) + if type(file_meta) == "table" and type(file_meta.__index) == "table" then + local file_write = file_meta.__index.write + file_meta.__index.write = function(self, ...) + local res, msg, errno = file_write(self, ...) + if res then + return self + else + return nil, msg, errno + end + end + + file_meta.__index.lines = function(self, ...) + if io_type(self) == "closed file" then + error("attempt to use a closed file", 0) + end + local st = { f=self, doclose=false, n=select('#', ...), ... } + for i = 1, st.n do + if type(st[i]) ~= "number" and not valid_format[st[i]] then + error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 0) + end + end + return lines_iterator, st + end + end + end + end -- not luajit + + + end -- lua 5.1 -end +end -- lua < 5.3 -- vi: set expandtab softtabstop=3 shiftwidth=3 : -- cgit v1.2.3-55-g6feb