From 4c062846ae003dd747dfcd3eca91493c7eee4f50 Mon Sep 17 00:00:00 2001 From: Philipp Janda Date: Sat, 17 Jan 2015 20:36:32 +0100 Subject: table library (except table.sort for now) respects metamethods --- README.md | 5 +- compat53.lua | 158 +++++++++++++++++++++++++++++++++++++++++++++++++++++---- tests/test.lua | 116 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 265 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 5adaa38..b788094 100644 --- a/README.md +++ b/README.md @@ -66,6 +66,8 @@ your project: * `math.maxinteger` and `math.mininteger`, `math.tointeger`, `math.type`, and `math.ult` * `ipairs` respects `__index` metamethod +* `table.move` +* `table` library (except `table.sort`) respects metamethods ### C @@ -112,9 +114,8 @@ For Lua 5.1 additionally: * bit operators * integer division operator -* `table.move` * `coroutine.isyieldable` -* `table` library doesn't respect metamethods yet +* `table.sort` doesn't respect metamethods yet * Lua 5.1: `_ENV`, `goto`, labels, ephemeron tables, etc. See [`lua-compat-5.2`][2] for a detailed list. * the following C API functions/macros: diff --git a/compat53.lua b/compat53.lua index 8bf6bd8..761d554 100644 --- a/compat53.lua +++ b/compat53.lua @@ -1,12 +1,31 @@ local lua_version = _VERSION:sub(-3) -if lua_version ~= "5.3" then - local _type = type +if lua_version < "5.3" then + -- local aliases for commonly used functions + local type, select, error = type, select, error + -- select the most powerful getmetatable function available - local gmt = _type(debug) == "table" and debug.getmetatable or + local gmt = type(debug) == "table" and debug.getmetatable or getmetatable or function() return false end + + -- type checking functions local checkinteger -- forward declararation + local function argcheck(cond, i, f, extra) + if not cond then + error("bad argument #"..i.." to '"..f.."' ("..extra..")", 0) + end + end + + local function checktype(x, t, i, f) + local xt = type(x) + if xt ~= t then + error("bad argument #"..i.." to '"..f.."' ("..t.. + " expected, got "..xt..")", 0) + end + end + + -- load utf8 library local ok, utf8lib = pcall(require, "compat53.utf8") if ok then @@ -17,6 +36,7 @@ if lua_version ~= "5.3" then end end + -- use Roberto's struct module for string packing/unpacking for now -- maybe we'll later extract the functions from the 5.3 string -- library for greater compatiblity, but it uses the 5.3 buffer API @@ -47,14 +67,14 @@ if lua_version ~= "5.3" then math.mininteger = minint function math.tointeger(n) - if _type(n) == "number" and n <= maxint and n >= minint and n % 1 == 0 then + if type(n) == "number" and n <= maxint and n >= minint and n % 1 == 0 then return n end return nil end function math.type(n) - if _type(n) == "number" then + if type(n) == "number" then if n <= maxint and n >= minint and n % 1 == 0 then return "integer" else @@ -65,15 +85,14 @@ if lua_version ~= "5.3" then end end - local _error = error function checkinteger(x, i, f) - local t = _type(x) + local t = type(x) if t ~= "number" then - _error("bad argument #"..i.." to '"..f.. - "' (number expected, got "..t..")", 0) + error("bad argument #"..i.." to '"..f.. + "' (number expected, got "..t..")", 0) elseif x > maxint or x < minint or x % 1 ~= 0 then - _error("bad argument #"..i.." to '"..f.. - "' (number has no integer representation)", 0) + error("bad argument #"..i.." to '"..f.. + "' (number has no integer representation)", 0) else return x end @@ -112,6 +131,123 @@ if lua_version ~= "5.3" then end end + + -- update table library + do + local table_concat = table.concat + function table.concat(list, sep, i, j) + local mt = gmt(list) + if type(mt) == "table" and type(mt.__len) == "function" then + local src = list + list, i, j = {}, i or 1, j or mt.__len(src) + for k = i, j do + list[k] = src[k] + end + end + return table_concat(list, sep, i, j) + end + + local table_insert = table.insert + function table.insert(list, ...) + local mt = gmt(list) + local has_mt = type(mt) == "table" + local has_len = has_mt and type(mt.__len) == "function" + if has_mt and (has_len or mt.__index or mt.__newindex) then + local e = (has_len and mt.__len(list) or #list)+1 + local nargs, pos, value = select('#', ...), ... + if nargs == 1 then + pos, value = e, pos + elseif nargs == 2 then + pos = checkinteger(pos, "2", "table.insert") + argcheck(1 <= pos and pos <= e, "2", "table.insert", + "position out of bounds" ) + else + error("wrong number of arguments to 'insert'", 0) + end + for i = e-1, pos, -1 do + list[i+1] = list[i] + end + list[pos] = value + else + return table_insert(list, ...) + end + end + + function table.move(a1, f, e, t, a2) + a2 = a2 or a1 + f = checkinteger(f, "2", "table.move") + argcheck(f > 0, "2", "table.move", + "initial position must be positive") + e = checkinteger(e, "3", "table.move") + t = checkinteger(t, "4", "table.move") + if e >= f then + local m, n, d = 0, e-f, 1 + if t > f then m, n, d = n, m, -1 end + for i = m, n, d do + a2[t+i] = a1[f+i] + end + end + return a2 + end + + local table_remove = table.remove + function table.remove(list, pos) + local mt = gmt(list) + local has_mt = type(mt) == "table" + local has_len = has_mt and type(mt.__len) == "function" + if has_mt and (has_len or mt.__index or mt.__newindex) then + local e = (has_len and mt.__len(list) or #list) + pos = pos ~= nil and checkinteger(pos, "2", "table.remove") or e + if pos ~= e then + argcheck(1 <= pos and pos <= e+1, "2", "table.remove", + "position out of bounds" ) + end + local result = list[pos] + while pos < e do + list[pos] = list[pos+1] + pos = pos + 1 + end + list[pos] = nil + return result + else + return table_remove(list, pos) + end + end + + -- TODO: table.sort + + local table_unpack = lua_version == "5.1" and unpack or table.unpack + local function unpack_helper(list, i, j, ...) + if j < i then + return ... + else + return unpack_helper(list, i, j-1, list[j], ...) + end + end + function table.unpack(list, i, j) + local mt = gmt(list) + local has_mt = type(mt) == "table" + local has_len = has_mt and type(mt.__len) == "function" + if has_mt and (has_len or mt.__index) then + i, j = i or 1, j or (has_len and mt.__len(list)) or #list + return unpack_helper(list, i, j) + else + return table_unpack(list, i, j) + end + end + end + + + if lua_version == "5.1" then + -- detect LuaJIT (including LUAJIT_ENABLE_LUA52COMPAT compilation flag) + local is_luajit = (string.dump(function() end) or ""):sub(1, 3) == "\027LJ" + local is_luajit52 = is_luajit and + #setmetatable({}, { __len = function() return 1 end }) == 1 + + -- TODO: add functions from lua-compat-5.2 + + end + end -- vi: set expandtab softtabstop=3 shiftwidth=3 : diff --git a/tests/test.lua b/tests/test.lua index ab25820..7e1c9da 100755 --- a/tests/test.lua +++ b/tests/test.lua @@ -1,6 +1,6 @@ #!/usr/bin/env lua -local F, ___ +local F, tproxy, ___ do local type, unpack = type, table.unpack or unpack function F(...) @@ -13,6 +13,13 @@ do end return unpack(args, 1, n) end + function tproxy(t) + return setmetatable({}, { + __index = t, + __newindex = t, + __len = function() return #t end, + }), t + end local sep = ("="):rep(70) function ___() print(sep) @@ -32,6 +39,113 @@ do end +___'' +do + local p, t = tproxy{ "a", "b", "c" } + print("table.concat", table.concat(p)) + print("table.concat", table.concat(p, ",", 2)) + print("table.concat", table.concat(p, ".", 1, 2)) + print("table.concat", table.concat(t)) + print("table.concat", table.concat(t, ",", 2)) + print("table.concat", table.concat(t, ".", 1, 2)) +end + + +___'' +do + local p, t = tproxy{ "a", "b", "c" } + table.insert(p, "d") + print("table.insert", next(p), t[4]) + table.insert(p, 1, "z") + print("table.insert", next(p), t[1], t[2]) + table.insert(p, 2, "y") + print("table.insert", next(p), t[1], t[2], p[3]) + t = { "a", "b", "c" } + table.insert(t, "d") + print("table.insert", t[1], t[2], t[3], t[4]) + table.insert(t, 1, "z") + print("table.insert", t[1], t[2], t[3], t[4], t[5]) + table.insert(t, 2, "y") + print("table.insert", t[1], t[2], t[3], t[4], t[5]) +end + + +___'' +do + local ps, s = tproxy{ "a", "b", "c", "d" } + local pd, d = tproxy{ "A", "B", "C", "D" } + table.move(ps, 1, 4, 1, pd) + print("table.move", next(pd), d[1], d[2], d[3], d[4]) + pd, d = tproxy{ "A", "B", "C", "D" } + table.move(ps, 2, 4, 1, pd) + print("table.move", next(pd), d[1], d[2], d[3], d[4]) + pd, d = tproxy{ "A", "B", "C", "D" } + table.move(ps, 2, 3, 4, pd) + print("table.move", next(pd), d[1], d[2], d[3], d[4], d[5]) + table.move(ps, 2, 4, 1) + print("table.move", next(ps), s[1], s[2], s[3], s[4]) + ps, s = tproxy{ "a", "b", "c", "d" } + table.move(ps, 2, 3, 4) + print("table.move", next(ps), s[1], s[2], s[3], s[4], s[5]) + s = { "a", "b", "c", "d" } + d = { "A", "B", "C", "D" } + table.move(s, 1, 4, 1, d) + print("table.move", d[1], d[2], d[3], d[4]) + d = { "A", "B", "C", "D" } + table.move(s, 2, 4, 1, d) + print("table.move", d[1], d[2], d[3], d[4]) + d = { "A", "B", "C", "D" } + table.move(s, 2, 3, 4, d) + print("table.move", d[1], d[2], d[3], d[4], d[5]) + table.move(s, 2, 4, 1) + print("table.move", s[1], s[2], s[3], s[4]) + s = { "a", "b", "c", "d" } + table.move(s, 2, 3, 4) + print("table.move", s[1], s[2], s[3], s[4], s[5]) +end + + +___'' +do + local p, t = tproxy{ "a", "b", "c", "d", "e" } + print("table.remove", table.remove(p)) + print("table.remove", next(p), t[1], t[2], t[3], t[4], t[5]) + print("table.remove", table.remove(p, 1)) + print("table.remove", next(p), t[1], t[2], t[3], t[4]) + print("table.remove", table.remove(p, 2)) + print("table.remove", next(p), t[1], t[2], t[3]) + print("table.remove", table.remove(p, 3)) + print("table.remove", next(p), t[1], t[2], t[3]) + p, t = tproxy{} + print("table.remove", table.remove(p)) + print("table.remove", next(p), next(t)) + t = { "a", "b", "c", "d", "e" } + print("table.remove", table.remove(t)) + print("table.remove", t[1], t[2], t[3], t[4], t[5]) + print("table.remove", table.remove(t, 1)) + print("table.remove", t[1], t[2], t[3], t[4]) + print("table.remove", table.remove(t, 2)) + print("table.remove", t[1], t[2], t[3]) + print("table.remove", table.remove(t, 3)) + print("table.remove", t[1], t[2], t[3]) + t = {} + print("table.remove", table.remove(t)) + print("table.remove", next(t)) +end + + +___'' +do + local p, t = tproxy{ "a", "b", "c" } + print("table.unpack", table.unpack(p)) + print("table.unpack", table.unpack(p, 2)) + print("table.unpack", table.unpack(p, 1, 2)) + print("table.unpack", table.unpack(t)) + print("table.unpack", table.unpack(t, 2)) + print("table.unpack", table.unpack(t, 1, 2)) +end + + ___'' print("math.maxinteger", math.maxinteger+1 > math.maxinteger) print("math.mininteger", math.mininteger-1 < math.mininteger) -- cgit v1.2.3-55-g6feb