From 65658e90ffcf6a84f13ec7c781c228e4e36f1799 Mon Sep 17 00:00:00 2001 From: Philipp Janda Date: Sun, 23 Aug 2015 00:38:38 +0200 Subject: Make '*' optional for io.read() and io.lines(). --- compat53/module.lua | 98 +++++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 77 insertions(+), 21 deletions(-) (limited to 'compat53') diff --git a/compat53/module.lua b/compat53/module.lua index 4450bd5..c4fc515 100644 --- a/compat53/module.lua +++ b/compat53/module.lua @@ -9,8 +9,11 @@ if lua_version < "5.3" then -- cache globals in upvalues local error, ipairs, pairs, pcall, require, select, setmetatable, type = error, ipairs, pairs, pcall, require, select, setmetatable, type - local debug, math, package, string, table = - debug, math, package, string, table + local debug, io, math, package, string, table = + debug, io, math, package, string, table + local io_lines = io.lines + local io_read = io.read + local unpack = lua_version == "5.1" and unpack or table.unpack -- create module table M = {} @@ -21,6 +24,7 @@ if lua_version < "5.3" then setmetatable(M, M_meta) -- create subtables + M.io = setmetatable({}, { __index = io }) M.math = setmetatable({}, { __index = math }) M.string = setmetatable({}, { __index = string }) M.table = setmetatable({}, { __index = table }) @@ -148,15 +152,13 @@ if lua_version < "5.3" then -- assert should allow non-string error objects - do - function M.assert(cond, ...) - if cond then - return cond, ... - elseif select('#', ...) > 0 then - error((...), 0) - else - error("assertion failed!", 0) - end + function M.assert(cond, ...) + if cond then + return cond, ... + elseif select('#', ...) > 0 then + error((...), 0) + else + error("assertion failed!", 0) end end @@ -180,13 +182,63 @@ if lua_version < "5.3" then end + -- make '*' optional for io.read and io.lines + do + local function addasterisk(fmt) + if type(fmt) == "string" and fmt:sub(1, 1) ~= "*" then + return "*"..fmt + else + return fmt + end + end + + function M.io.read(...) + local n = select('#', ...) + for i = 1, n do + local a = select(i, ...) + local b = addasterisk(a) + -- as an optimization we only allocate a table for the + -- modified format arguments when we have a '*' somewhere. + if a ~= b then + local args = { ... } + args[i] = b + for j = i+1, n do + args[j] = addasterisk(args[j]) + end + return io_read(unpack(args, 1, n)) + end + end + return io_read(...) + end + + -- PUC-Rio Lua 5.1 uses a different implementation for io.lines! + function M.io.lines(...) + local n = select('#', ...) + for i = 2, n do + local a = select(i, ...) + local b = addasterisk(a) + -- as an optimization we only allocate a table for the + -- modified format arguments when we have a '*' somewhere. + if a ~= b then + local args = { ... } + args[i] = b + for j = i+1, n do + args[j] = addasterisk(args[j]) + end + return io_lines(unpack(args, 1, n)) + end + end + return io_lines(...) + end + end + + -- update table library (if C module not available) if not table_ok then local table_concat = table.concat local table_insert = table.insert local table_remove = table.remove local table_sort = table.sort - local table_unpack = lua_version == "5.1" and unpack or table.unpack function M.table.concat(list, sep, i, j) local mt = gmt(list) @@ -344,7 +396,7 @@ if lua_version < "5.3" then i, j = i or 1, j or (has_len and mt.__len(list)) or #list return unpack_helper(list, i, j) else - return table_unpack(list, i, j) + return unpack(list, i, j) end end end -- update table library @@ -358,9 +410,9 @@ if lua_version < "5.3" then #setmetatable({}, { __len = function() return 1 end }) == 1 -- cache globals in upvalues - local load, loadfile, loadstring, setfenv, unpack, xpcall = - load, loadfile, loadstring, setfenv, unpack, xpcall - local coroutine, io, os = coroutine, io, os + local load, loadfile, loadstring, setfenv, xpcall = + load, loadfile, loadstring, setfenv, xpcall + local coroutine, os = coroutine, os local coroutine_create = coroutine.create local coroutine_resume = coroutine.resume local coroutine_running = coroutine.running @@ -382,7 +434,6 @@ if lua_version < "5.3" then -- create subtables M.coroutine = setmetatable({}, { __index = coroutine }) - M.io = setmetatable({}, { __index = io }) M.os = setmetatable({}, { __index = os }) M.package = setmetatable({}, { __index = package }) @@ -734,8 +785,6 @@ if lua_version < "5.3" then return helper(st, st.f:read(unpack(st, 1, st.n))) end - local valid_format = { ["*l"] = true, ["*n"] = true, ["*a"] = true } - function M.io.lines(fname, ...) local doclose, file, msg if fname ~= nil then @@ -746,8 +795,15 @@ if lua_version < "5.3" then end local st = { f=file, doclose=doclose, n=select('#', ...), ... } for i = 1, st.n do - if type(st[i]) ~= "number" and not valid_format[st[i]] then - error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2) + local t = type(st[i]) + if t == "string" then + local fmt = st[i]:match("^%*?([aln])") + if not fmt then + error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2) + end + st[i] = "*"..fmt + elseif t ~= "number" then + error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2) end end return lines_iterator, st -- cgit v1.2.3-55-g6feb