From 65658e90ffcf6a84f13ec7c781c228e4e36f1799 Mon Sep 17 00:00:00 2001 From: Philipp Janda Date: Sun, 23 Aug 2015 00:38:38 +0200 Subject: Make '*' optional for io.read() and io.lines(). --- compat53/module.lua | 98 +++++++++++++++++++++++++++++++++++++++++------------ tests/test.lua | 14 +++++++- 2 files changed, 90 insertions(+), 22 deletions(-) diff --git a/compat53/module.lua b/compat53/module.lua index 4450bd5..c4fc515 100644 --- a/compat53/module.lua +++ b/compat53/module.lua @@ -9,8 +9,11 @@ if lua_version < "5.3" then -- cache globals in upvalues local error, ipairs, pairs, pcall, require, select, setmetatable, type = error, ipairs, pairs, pcall, require, select, setmetatable, type - local debug, math, package, string, table = - debug, math, package, string, table + local debug, io, math, package, string, table = + debug, io, math, package, string, table + local io_lines = io.lines + local io_read = io.read + local unpack = lua_version == "5.1" and unpack or table.unpack -- create module table M = {} @@ -21,6 +24,7 @@ if lua_version < "5.3" then setmetatable(M, M_meta) -- create subtables + M.io = setmetatable({}, { __index = io }) M.math = setmetatable({}, { __index = math }) M.string = setmetatable({}, { __index = string }) M.table = setmetatable({}, { __index = table }) @@ -148,15 +152,13 @@ if lua_version < "5.3" then -- assert should allow non-string error objects - do - function M.assert(cond, ...) - if cond then - return cond, ... - elseif select('#', ...) > 0 then - error((...), 0) - else - error("assertion failed!", 0) - end + function M.assert(cond, ...) + if cond then + return cond, ... + elseif select('#', ...) > 0 then + error((...), 0) + else + error("assertion failed!", 0) end end @@ -180,13 +182,63 @@ if lua_version < "5.3" then end + -- make '*' optional for io.read and io.lines + do + local function addasterisk(fmt) + if type(fmt) == "string" and fmt:sub(1, 1) ~= "*" then + return "*"..fmt + else + return fmt + end + end + + function M.io.read(...) + local n = select('#', ...) + for i = 1, n do + local a = select(i, ...) + local b = addasterisk(a) + -- as an optimization we only allocate a table for the + -- modified format arguments when we have a '*' somewhere. + if a ~= b then + local args = { ... } + args[i] = b + for j = i+1, n do + args[j] = addasterisk(args[j]) + end + return io_read(unpack(args, 1, n)) + end + end + return io_read(...) + end + + -- PUC-Rio Lua 5.1 uses a different implementation for io.lines! + function M.io.lines(...) + local n = select('#', ...) + for i = 2, n do + local a = select(i, ...) + local b = addasterisk(a) + -- as an optimization we only allocate a table for the + -- modified format arguments when we have a '*' somewhere. + if a ~= b then + local args = { ... } + args[i] = b + for j = i+1, n do + args[j] = addasterisk(args[j]) + end + return io_lines(unpack(args, 1, n)) + end + end + return io_lines(...) + end + end + + -- update table library (if C module not available) if not table_ok then local table_concat = table.concat local table_insert = table.insert local table_remove = table.remove local table_sort = table.sort - local table_unpack = lua_version == "5.1" and unpack or table.unpack function M.table.concat(list, sep, i, j) local mt = gmt(list) @@ -344,7 +396,7 @@ if lua_version < "5.3" then i, j = i or 1, j or (has_len and mt.__len(list)) or #list return unpack_helper(list, i, j) else - return table_unpack(list, i, j) + return unpack(list, i, j) end end end -- update table library @@ -358,9 +410,9 @@ if lua_version < "5.3" then #setmetatable({}, { __len = function() return 1 end }) == 1 -- cache globals in upvalues - local load, loadfile, loadstring, setfenv, unpack, xpcall = - load, loadfile, loadstring, setfenv, unpack, xpcall - local coroutine, io, os = coroutine, io, os + local load, loadfile, loadstring, setfenv, xpcall = + load, loadfile, loadstring, setfenv, xpcall + local coroutine, os = coroutine, os local coroutine_create = coroutine.create local coroutine_resume = coroutine.resume local coroutine_running = coroutine.running @@ -382,7 +434,6 @@ if lua_version < "5.3" then -- create subtables M.coroutine = setmetatable({}, { __index = coroutine }) - M.io = setmetatable({}, { __index = io }) M.os = setmetatable({}, { __index = os }) M.package = setmetatable({}, { __index = package }) @@ -734,8 +785,6 @@ if lua_version < "5.3" then return helper(st, st.f:read(unpack(st, 1, st.n))) end - local valid_format = { ["*l"] = true, ["*n"] = true, ["*a"] = true } - function M.io.lines(fname, ...) local doclose, file, msg if fname ~= nil then @@ -746,8 +795,15 @@ if lua_version < "5.3" then end local st = { f=file, doclose=doclose, n=select('#', ...), ... } for i = 1, st.n do - if type(st[i]) ~= "number" and not valid_format[st[i]] then - error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2) + local t = type(st[i]) + if t == "string" then + local fmt = st[i]:match("^%*?([aln])") + if not fmt then + error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2) + end + st[i] = "*"..fmt + elseif t ~= "number" then + error("bad argument #"..(i+1).." to 'for iterator' (invalid format)", 2) end end return lines_iterator, st diff --git a/tests/test.lua b/tests/test.lua index 7e9d57f..c634ebe 100755 --- a/tests/test.lua +++ b/tests/test.lua @@ -537,6 +537,18 @@ do end +___'' +do + writefile("data.txt", "123 18.8 hello world\ni'm here\n") + io.input("data.txt") + print(io.read("*n", "*number", "*l", "*a")) + io.input("data.txt") + print(io.read("n", "number", "l", "a")) + io.input(io.stdin) + os.remove("data.txt") +end + + ___'' do writefile("data.txt", "123 18.8 hello world\ni'm here\n") @@ -548,7 +560,7 @@ do print("io.lines()", l) break end - for n1,n2,rest in io.lines("data.txt", "*n", "*n", "*a") do + for n1,n2,rest in io.lines("data.txt", "*n", "n", "*a") do print("io.lines()", n1, n2, rest) end for l in io.lines("data.txt") do -- cgit v1.2.3-55-g6feb