From 7c97e8e40aaa665226fb54449773dc3134e755b2 Mon Sep 17 00:00:00 2001 From: Diego Nehab Date: Sat, 27 Nov 2004 07:58:04 +0000 Subject: Almost ready for beta3 --- src/buffer.c | 4 +++- src/except.c | 26 +++++++++++++++++++++++--- src/ftp.lua | 53 +++++++++++++++++++++++++++++++++++++++++++---------- src/http.lua | 23 ++++++++++++++--------- src/ltn12.lua | 22 ++++++++++++++-------- src/mime.lua | 11 +++++++---- src/smtp.lua | 25 +++++++++++++++---------- src/tcp.c | 3 ++- src/tp.lua | 29 +++++++++++++++++++---------- src/udp.c | 3 ++- src/url.lua | 15 ++++++++++----- src/wsocket.c | 15 +++++++-------- 12 files changed, 159 insertions(+), 70 deletions(-) (limited to 'src') diff --git a/src/buffer.c b/src/buffer.c index dbd5d2c..1b1b791 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -158,6 +158,7 @@ int buf_isempty(p_buf buf) { /*-------------------------------------------------------------------------*\ * Sends a block of data (unbuffered) \*-------------------------------------------------------------------------*/ +#define STEPSIZE 8192 static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent) { p_io io = buf->io; p_tm tm = buf->tm; @@ -165,7 +166,8 @@ static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent) { int err = IO_DONE; while (total < count && err == IO_DONE) { size_t done; - err = io->send(io->ctx, data+total, count-total, &done, tm); + size_t step = (count-total <= STEPSIZE)? count-total: STEPSIZE; + err = io->send(io->ctx, data+total, step, &done, tm); total += done; } *sent = total; diff --git a/src/except.c b/src/except.c index 80d7e5d..dabaf19 100644 --- a/src/except.c +++ b/src/except.c @@ -29,11 +29,21 @@ static luaL_reg func[] = { /*-------------------------------------------------------------------------*\ * Try factory \*-------------------------------------------------------------------------*/ +static void wrap(lua_State *L) { + lua_newtable(L); + lua_pushnumber(L, 1); + lua_pushvalue(L, -3); + lua_settable(L, -3); + lua_insert(L, -2); + lua_pop(L, 1); +} + static int finalize(lua_State *L) { if (!lua_toboolean(L, 1)) { lua_pushvalue(L, lua_upvalueindex(1)); lua_pcall(L, 0, 0, 0); lua_settop(L, 2); + wrap(L); lua_error(L); return 0; } else return lua_gettop(L); @@ -54,13 +64,23 @@ static int global_newtry(lua_State *L) { /*-------------------------------------------------------------------------*\ * Protect factory \*-------------------------------------------------------------------------*/ +static int unwrap(lua_State *L) { + if (lua_istable(L, -1)) { + lua_pushnumber(L, 1); + lua_gettable(L, -2); + lua_pushnil(L); + lua_insert(L, -2); + return 1; + } else return 0; +} + static int protected_(lua_State *L) { lua_pushvalue(L, lua_upvalueindex(1)); lua_insert(L, 1); if (lua_pcall(L, lua_gettop(L) - 1, LUA_MULTRET, 0) != 0) { - lua_pushnil(L); - lua_insert(L, 1); - return 2; + if (unwrap(L)) return 2; + else lua_error(L); + return 0; } else return lua_gettop(L); } diff --git a/src/ftp.lua b/src/ftp.lua index 9902c88..4529acd 100644 --- a/src/ftp.lua +++ b/src/ftp.lua @@ -8,13 +8,15 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- +local base = require("base") +local table = require("table") +local string = require("string") +local math = require("math") local socket = require("socket") local url = require("socket.url") local tp = require("socket.tp") - local ltn12 = require("ltn12") - -module("socket.ftp") +local ftp = module("socket.ftp") ----------------------------------------------------------------------------- -- Program constants @@ -35,7 +37,7 @@ local metat = { __index = {} } function open(server, port) local tp = socket.try(tp.connect(server, port or PORT, TIMEOUT)) - local f = setmetatable({ tp = tp }, metat) + local f = base.setmetatable({ tp = tp }, metat) -- make sure everything gets closed in an exception f.try = socket.newtry(function() f:close() end) return f @@ -102,7 +104,8 @@ function metat.__index:send(sendt) -- we just get the data connection into self.data if self.pasvt then self:pasvconnect() end -- get the transfer argument and command - local argument = sendt.argument or string.gsub(sendt.path, "^/", "") + local argument = sendt.argument or + url.unescape(string.gsub(sendt.path or "", "^/", "")) if argument == "" then argument = nil end local command = sendt.command or "stor" -- send the transfer command and check the reply @@ -134,7 +137,8 @@ end function metat.__index:receive(recvt) self.try(self.pasvt or self.server, "need port or pasv first") if self.pasvt then self:pasvconnect() end - local argument = recvt.argument or string.gsub(recvt.path, "^/", "") + local argument = recvt.argument or + url.unescape(string.gsub(recvt.path or "", "^/", "")) if argument == "" then argument = nil end local command = recvt.command or "retr" self.try(self.tp:command(command, argument)) @@ -182,7 +186,19 @@ end ----------------------------------------------------------------------------- -- High level FTP API ----------------------------------------------------------------------------- +function override(t) + if t.url then + u = url.parse(t.url) + for i,v in base.pairs(t) do + u[i] = v + end + return u + else return t end +end + local function tput(putt) + putt = override(putt) + socket.try(putt.host, "missing hostname") local f = open(putt.host, putt.port) f:greet() f:login(putt.user, putt.password) @@ -201,8 +217,8 @@ local default = { local function parse(u) local t = socket.try(url.parse(u, default)) - socket.try(t.scheme == "ftp", "invalid scheme '" .. t.scheme .. "'") - socket.try(t.host, "invalid host") + socket.try(t.scheme == "ftp", "wrong scheme '" .. t.scheme .. "'") + socket.try(t.host, "missing hostname") local pat = "^type=(.)$" if t.params then t.type = socket.skip(2, string.find(t.params, pat)) @@ -219,11 +235,13 @@ local function sput(u, body) end put = socket.protect(function(putt, body) - if type(putt) == "string" then return sput(putt, body) + if base.type(putt) == "string" then return sput(putt, body) else return tput(putt) end end) local function tget(gett) + gett = override(gett) + socket.try(gett.host, "missing hostname") local f = open(gett.host, gett.port) f:greet() f:login(gett.user, gett.password) @@ -242,7 +260,22 @@ local function sget(u) return table.concat(t) end +command = socket.protect(function(cmdt) + cmdt = override(cmdt) + socket.try(cmdt.host, "missing hostname") + socket.try(cmdt.command, "missing command") + local f = open(cmdt.host, cmdt.port) + f:greet() + f:login(cmdt.user, cmdt.password) + f.try(f.tp:command(cmdt.command, cmdt.argument)) + if cmdt.check then f.try(f.tp:check(cmdt.check)) end + f:quit() + return f:close() +end) + get = socket.protect(function(gett) - if type(gett) == "string" then return sget(gett) + if base.type(gett) == "string" then return sget(gett) else return tget(gett) end end) + +base.setmetatable(ftp, nil) diff --git a/src/http.lua b/src/http.lua index b265650..a15ea69 100644 --- a/src/http.lua +++ b/src/http.lua @@ -12,8 +12,10 @@ local socket = require("socket") local url = require("socket.url") local ltn12 = require("ltn12") local mime = require("mime") - -module("socket.http") +local string = require("string") +local base = require("base") +local table = require("table") +local http = module("socket.http") ----------------------------------------------------------------------------- -- Program constants @@ -32,7 +34,7 @@ local metat = { __index = {} } function open(host, port) local c = socket.try(socket.tcp()) - local h = setmetatable({ c = c }, metat) + local h = base.setmetatable({ c = c }, metat) -- make sure the connection gets closed on exception h.try = socket.newtry(function() h:close() end) h.try(c:settimeout(TIMEOUT)) @@ -46,7 +48,7 @@ function metat.__index:sendrequestline(method, uri) end function metat.__index:sendheaders(headers) - for i, v in pairs(headers) do + for i, v in base.pairs(headers) do self.try(self.c:send(i .. ": " .. v .. "\r\n")) end -- mark end of request headers @@ -66,7 +68,7 @@ end function metat.__index:receivestatusline() local status = self.try(self.c:receive()) local code = socket.skip(2, string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) - return self.try(tonumber(code), status) + return self.try(base.tonumber(code), status) end function metat.__index:receiveheaders() @@ -97,11 +99,11 @@ end function metat.__index:receivebody(headers, sink, step) sink = sink or ltn12.sink.null() step = step or ltn12.pump.step - local length = tonumber(headers["content-length"]) + local length = base.tonumber(headers["content-length"]) local TE = headers["transfer-encoding"] local mode = "default" -- connection close if TE and TE ~= "identity" then mode = "http-chunked" - elseif tonumber(headers["content-length"]) then mode = "by-length" end + elseif base.tonumber(headers["content-length"]) then mode = "by-length" end return self.try(ltn12.pump.all(socket.source(mode, self.c, length), sink, step)) end @@ -159,9 +161,10 @@ local default = { local function adjustrequest(reqt) -- parse url if provided local nreqt = reqt.url and url.parse(reqt.url, default) or {} + local t = url.parse(reqt.url, default) -- explicit components override url for i,v in reqt do nreqt[i] = reqt[i] end - socket.try(nreqt.host, "invalid host '" .. tostring(nreqt.host) .. "'") + socket.try(nreqt.host, "invalid host '" .. base.tostring(nreqt.host) .. "'") -- compute uri if user hasn't overriden nreqt.uri = reqt.uri or adjusturi(nreqt) -- ajust host and port if there is a proxy @@ -253,6 +256,8 @@ local function srequest(u, body) end request = socket.protect(function(reqt, body) - if type(reqt) == "string" then return srequest(reqt, body) + if base.type(reqt) == "string" then return srequest(reqt, body) else return trequest(reqt) end end) + +base.setmetatable(http, nil) diff --git a/src/ltn12.lua b/src/ltn12.lua index ed39ec8..43c2755 100644 --- a/src/ltn12.lua +++ b/src/ltn12.lua @@ -8,7 +8,11 @@ ----------------------------------------------------------------------------- -- Declare module ----------------------------------------------------------------------------- -module("ltn12") +local string = require("string") +local table = require("table") +local base = require("base") +local coroutine = require("coroutine") +local ltn12 = module("ltn12") filter = {} source = {} @@ -23,7 +27,7 @@ BLOCKSIZE = 2048 ----------------------------------------------------------------------------- -- returns a high level filter that cycles a low-level filter function filter.cycle(low, ctx, extra) - assert(low) + base.assert(low) return function(chunk) local ret ret, ctx = low(ctx, chunk, extra) @@ -121,7 +125,7 @@ end -- turns a fancy source into a simple source function source.simplify(src) - assert(src) + base.assert(src) return function() local chunk, err_or_new = src() src = err_or_new or src @@ -145,7 +149,7 @@ end -- creates rewindable source function source.rewind(src) - assert(src) + base.assert(src) local t = {} return function(chunk) if not chunk then @@ -160,7 +164,7 @@ end -- chains a source with a filter function source.chain(src, f) - assert(src and f) + base.assert(src and f) local co = coroutine.create(function() while true do local chunk, err = src() @@ -215,7 +219,7 @@ end -- turns a fancy sink into a simple sink function sink.simplify(snk) - assert(snk) + base.assert(snk) return function(chunk, err) local ret, err_or_new = snk(chunk, err) if not ret then return nil, err_or_new end @@ -254,7 +258,7 @@ end -- chains a sink with a filter function sink.chain(f, snk) - assert(f and snk) + base.assert(f and snk) return function(chunk, err) local filtered = f(chunk) local done = chunk and "" @@ -279,10 +283,12 @@ end -- pumps all data from a source to a sink, using a step function function pump.all(src, snk, step) - assert(src and snk) + base.assert(src and snk) step = step or pump.step while true do local ret, err = step(src, snk) if not ret then return not err, err end end end + +base.setmetatable(ltn12, nil) diff --git a/src/mime.lua b/src/mime.lua index 3dbcf79..712600c 100644 --- a/src/mime.lua +++ b/src/mime.lua @@ -8,9 +8,10 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- -module("mime") -local mime = require("lmime") +local base = require("base") local ltn12 = require("ltn12") +local mime = require("lmime") +module("mime") -- encode, decode and wrap algorithm tables mime.encodet = {} @@ -20,11 +21,11 @@ mime.wrapt = {} -- creates a function that chooses a filter by name from a given table local function choose(table) return function(name, opt1, opt2) - if type(name) ~= "string" then + if base.type(name) ~= "string" then name, opt1, opt2 = "default", name, opt1 end local f = table[name or "nil"] - if not f then error("unknown key (" .. tostring(name) .. ")", 3) + if not f then error("unknown key (" .. base.tostring(name) .. ")", 3) else return f(opt1, opt2) end end end @@ -74,3 +75,5 @@ end function mime.stuff() return ltn12.filter.cycle(dot, 2) end + +base.setmetatable(mime, nil) diff --git a/src/smtp.lua b/src/smtp.lua index 974d222..9d49178 100644 --- a/src/smtp.lua +++ b/src/smtp.lua @@ -8,13 +8,16 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- +local base = require("base") +local coroutine = require("coroutine") +local string = require("string") +local math = require("math") +local os = require("os") local socket = require("socket") local tp = require("socket.tp") - local ltn12 = require("ltn12") local mime = require("mime") - -module("socket.smtp") +local smtp = module("socket.smtp") ----------------------------------------------------------------------------- -- Program constants @@ -98,8 +101,8 @@ end -- send message or throw an exception function metat.__index:send(mailt) self:mail(mailt.from) - if type(mailt.rcpt) == "table" then - for i,v in ipairs(mailt.rcpt) do + if base.type(mailt.rcpt) == "table" then + for i,v in base.ipairs(mailt.rcpt) do self:rcpt(v) end else @@ -110,7 +113,7 @@ end function open(server, port) local tp = socket.try(tp.connect(server or SERVER, port or PORT, TIMEOUT)) - local s = setmetatable({tp = tp}, metat) + local s = base.setmetatable({tp = tp}, metat) -- make sure tp is closed if we get an exception s.try = socket.newtry(function() if s.tp:command("QUIT") then s.tp:check("2..") end @@ -145,7 +148,7 @@ local function send_multipart(mesgt) coroutine.yield("\r\n") end -- send each part separated by a boundary - for i, m in ipairs(mesgt.body) do + for i, m in base.ipairs(mesgt.body) do coroutine.yield("\r\n--" .. bd .. "\r\n") send_message(m) end @@ -191,7 +194,7 @@ end -- yield the headers one by one local function send_headers(mesgt) if mesgt.headers then - for i,v in pairs(mesgt.headers) do + for i,v in base.pairs(mesgt.headers) do coroutine.yield(i .. ':' .. v .. "\r\n") end end @@ -200,8 +203,8 @@ end -- message source function send_message(mesgt) send_headers(mesgt) - if type(mesgt.body) == "table" then send_multipart(mesgt) - elseif type(mesgt.body) == "function" then send_source(mesgt) + if base.type(mesgt.body) == "table" then send_multipart(mesgt) + elseif base.type(mesgt.body) == "function" then send_source(mesgt) else send_string(mesgt) end end @@ -241,3 +244,5 @@ send = socket.protect(function(mailt) s:quit() return s:close() end) + +base.setmetatable(smtp, nil) diff --git a/src/tcp.c b/src/tcp.c index 746c4b6..618f4ce 100644 --- a/src/tcp.c +++ b/src/tcp.c @@ -233,7 +233,8 @@ static int meth_close(lua_State *L) { p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1); sock_destroy(&tcp->sock); - return 0; + lua_pushnumber(L, 1); + return 1; } /*-------------------------------------------------------------------------*\ diff --git a/src/tp.lua b/src/tp.lua index ada00d2..0a671fb 100644 --- a/src/tp.lua +++ b/src/tp.lua @@ -8,10 +8,12 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- +local base = require("base") +local string = require("string") local socket = require("socket") local ltn12 = require("ltn12") -module("socket.tp") +local tp = module("socket.tp") ----------------------------------------------------------------------------- -- Program constants @@ -47,22 +49,27 @@ local metat = { __index = {} } function metat.__index:check(ok) local code, reply = get_reply(self.c) if not code then return nil, reply end - if type(ok) ~= "function" then - if type(ok) == "table" then - for i, v in ipairs(ok) do - if string.find(code, v) then return tonumber(code), reply end + if base.type(ok) ~= "function" then + if base.type(ok) == "table" then + for i, v in base.ipairs(ok) do + if string.find(code, v) then + return base.tonumber(code), reply + end end return nil, reply else - if string.find(code, ok) then return tonumber(code), reply + if string.find(code, ok) then return base.tonumber(code), reply else return nil, reply end end - else return ok(tonumber(code), reply) end + else return ok(base.tonumber(code), reply) end end function metat.__index:command(cmd, arg) - if arg then return self.c:send(cmd .. " " .. arg.. "\r\n") - else return self.c:send(cmd .. "\r\n") end + if arg then + return self.c:send(cmd .. " " .. arg.. "\r\n") + else + return self.c:send(cmd .. "\r\n") + end end function metat.__index:sink(snk, pat) @@ -111,5 +118,7 @@ function connect(host, port, timeout) c:close() return nil, e end - return setmetatable({c = c}, metat) + return base.setmetatable({c = c}, metat) end + +base.setmetatable(tp, nil) diff --git a/src/udp.c b/src/udp.c index 97a6169..7a60080 100644 --- a/src/udp.c +++ b/src/udp.c @@ -288,7 +288,8 @@ static int meth_setpeername(lua_State *L) { static int meth_close(lua_State *L) { p_udp udp = (p_udp) aux_checkgroup(L, "udp{any}", 1); sock_destroy(&udp->sock); - return 0; + lua_pushnumber(L, 1); + return 1; } /*-------------------------------------------------------------------------*\ diff --git a/src/url.lua b/src/url.lua index efe7254..08081f0 100644 --- a/src/url.lua +++ b/src/url.lua @@ -8,7 +8,10 @@ ----------------------------------------------------------------------------- -- Declare module ----------------------------------------------------------------------------- -module("socket.url") +local string = require("string") +local base = require("base") +local table = require("table") +local url = module("socket.url") ----------------------------------------------------------------------------- -- Encodes a string into its escaped hexadecimal representation @@ -18,7 +21,7 @@ module("socket.url") -- escaped representation of string binary ----------------------------------------------------------------------------- function escape(s) - return string.gsub(s, "(.)", function(c) + return string.gsub(s, "([^A-Za-z0-9_])", function(c) return string.format("%%%02x", string.byte(c)) end) end @@ -33,7 +36,7 @@ end ----------------------------------------------------------------------------- local function make_set(t) local s = {} - for i = 1, table.getn(t) do + for i,v in base.ipairs(t) do s[t[i]] = 1 end return s @@ -62,7 +65,7 @@ end ----------------------------------------------------------------------------- function unescape(s) return string.gsub(s, "%%(%x%x)", function(hex) - return string.char(tonumber(hex, 16)) + return string.char(base.tonumber(hex, 16)) end) end @@ -191,7 +194,7 @@ end -- corresponding absolute url ----------------------------------------------------------------------------- function absolute(base_url, relative_url) - local base = type(base_url) == "table" and base_url or parse(base_url) + local base = base.type(base_url) == "table" and base_url or parse(base_url) local relative = parse(relative_url) if not base then return relative_url elseif not relative then return base_url @@ -269,3 +272,5 @@ function build_path(parsed, unsafe) if parsed.is_absolute then path = "/" .. path end return path end + +base.setmetatable(url, nil) diff --git a/src/wsocket.c b/src/wsocket.c index 1b169ed..0294dce 100644 --- a/src/wsocket.c +++ b/src/wsocket.c @@ -180,9 +180,10 @@ int sock_accept(p_sock ps, p_sock pa, SA *addr, socklen_t *len, p_tm tm) { /*-------------------------------------------------------------------------*\ * Send with timeout +* On windows, if you try to send 10MB, the OS will buffer EVERYTHING +* this can take an awful lot of time and we will end up blocked. +* Therefore, whoever calls this function should not pass a huge buffer. \*-------------------------------------------------------------------------*/ -/* has to be larger than UDP_DATAGRAMSIZE !!!*/ -#define MAXCHUNK (64*1024) int sock_send(p_sock ps, const char *data, size_t count, size_t *sent, p_tm tm) { int err; @@ -192,9 +193,7 @@ int sock_send(p_sock ps, const char *data, size_t count, size_t *sent, p_tm tm) *sent = 0; for ( ;; ) { /* try to send something */ - /* on windows, if you try to send 10MB, the OS will buffer EVERYTHING - * this can take an awful lot of time and we will end up blocked. */ - int put = send(*ps, data, (count < MAXCHUNK)? (int)count: MAXCHUNK, 0); + int put = send(*ps, data, count, 0); /* if we sent something, we are done */ if (put > 0) { *sent = put; @@ -221,7 +220,7 @@ int sock_sendto(p_sock ps, const char *data, size_t count, size_t *sent, if (*ps == SOCK_INVALID) return IO_CLOSED; *sent = 0; for ( ;; ) { - int put = send(*ps, data, (int) count, 0); + int put = sendto(*ps, data, (int) count, 0, addr, len); if (put > 0) { *sent = put; return IO_DONE; @@ -298,13 +297,13 @@ void sock_setnonblocking(p_sock ps) { int sock_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp) { *hp = gethostbyaddr(addr, len, AF_INET); if (*hp) return IO_DONE; - else return h_errno; + else return WSAGetLastError(); } int sock_gethostbyname(const char *addr, struct hostent **hp) { *hp = gethostbyname(addr); if (*hp) return IO_DONE; - else return h_errno; + else return WSAGetLastError(); } /*-------------------------------------------------------------------------*\ -- cgit v1.2.3-55-g6feb