From 58096449c6044b7aade5cd41cfd71c6bec1d273d Mon Sep 17 00:00:00 2001 From: Diego Nehab Date: Tue, 15 Jun 2004 06:24:00 +0000 Subject: Manual is almost done. HTTP is missing. Implemented new distribution scheme. Select is now purely C. HTTP reimplemented seems faster dunno why. LTN12 functions that coroutines fail gracefully. --- src/auxiliar.c | 50 +++++--- src/auxiliar.h | 25 ++-- src/buffer.c | 9 +- src/except.c | 52 +++++++++ src/except.h | 35 ++++++ src/ftp.lua | 73 ++++++------ src/http.lua | 351 ++++++++++++++++++++++++-------------------------------- src/inet.c | 1 - src/ltn12.lua | 21 ++-- src/luasocket.c | 71 +++++++++++- src/luasocket.h | 1 + src/mime.c | 12 +- src/mime.h | 1 + src/mime.lua | 14 +-- src/select.c | 221 +++++++++++++++++++++++------------ src/smtp.lua | 50 ++++---- src/socket.lua | 13 ++- src/tcp.c | 5 +- src/timeout.c | 8 ++ src/tp.lua | 26 +++-- src/udp.c | 5 +- src/url.lua | 4 +- 22 files changed, 621 insertions(+), 427 deletions(-) create mode 100644 src/except.c create mode 100644 src/except.h (limited to 'src') diff --git a/src/auxiliar.c b/src/auxiliar.c index b1f9203..9a37e10 100644 --- a/src/auxiliar.c +++ b/src/auxiliar.c @@ -7,7 +7,6 @@ #include #include -#include "luasocket.h" #include "auxiliar.h" /*=========================================================================*\ @@ -16,16 +15,15 @@ /*-------------------------------------------------------------------------*\ * Initializes the module \*-------------------------------------------------------------------------*/ -int aux_open(lua_State *L) -{ +int aux_open(lua_State *L) { return 0; } /*-------------------------------------------------------------------------*\ * Creates a new class with given methods +* Methods whose names start with __ are passed directly to the metatable. \*-------------------------------------------------------------------------*/ -void aux_newclass(lua_State *L, const char *classname, luaL_reg *func) -{ +void aux_newclass(lua_State *L, const char *classname, luaL_reg *func) { luaL_newmetatable(L, classname); /* mt */ /* create __index table to place methods */ lua_pushstring(L, "__index"); /* mt,"__index" */ @@ -45,11 +43,31 @@ void aux_newclass(lua_State *L, const char *classname, luaL_reg *func) lua_pop(L, 1); } +/*-------------------------------------------------------------------------*\ +* Prints the value of a class in a nice way +\*-------------------------------------------------------------------------*/ +int aux_tostring(lua_State *L) { + char buf[32]; + if (!lua_getmetatable(L, 1)) goto error; + lua_pushstring(L, "__index"); + lua_gettable(L, -2); + if (!lua_istable(L, -1)) goto error; + lua_pushstring(L, "class"); + lua_gettable(L, -2); + if (!lua_isstring(L, -1)) goto error; + sprintf(buf, "%p", lua_touserdata(L, 1)); + lua_pushfstring(L, "%s: %s", lua_tostring(L, -1), buf); + return 1; +error: + lua_pushstring(L, "invalid object passed to 'auxiliar.c:__tostring'"); + lua_error(L); + return 1; +} + /*-------------------------------------------------------------------------*\ * Insert class into group \*-------------------------------------------------------------------------*/ -void aux_add2group(lua_State *L, const char *classname, const char *groupname) -{ +void aux_add2group(lua_State *L, const char *classname, const char *groupname) { luaL_getmetatable(L, classname); lua_pushstring(L, groupname); lua_pushboolean(L, 1); @@ -60,8 +78,7 @@ void aux_add2group(lua_State *L, const char *classname, const char *groupname) /*-------------------------------------------------------------------------*\ * Make sure argument is a boolean \*-------------------------------------------------------------------------*/ -int aux_checkboolean(lua_State *L, int objidx) -{ +int aux_checkboolean(lua_State *L, int objidx) { if (!lua_isboolean(L, objidx)) luaL_typerror(L, objidx, lua_typename(L, LUA_TBOOLEAN)); return lua_toboolean(L, objidx); @@ -71,8 +88,7 @@ int aux_checkboolean(lua_State *L, int objidx) * Return userdata pointer if object belongs to a given class, abort with * error otherwise \*-------------------------------------------------------------------------*/ -void *aux_checkclass(lua_State *L, const char *classname, int objidx) -{ +void *aux_checkclass(lua_State *L, const char *classname, int objidx) { void *data = aux_getclassudata(L, classname, objidx); if (!data) { char msg[45]; @@ -86,8 +102,7 @@ void *aux_checkclass(lua_State *L, const char *classname, int objidx) * Return userdata pointer if object belongs to a given group, abort with * error otherwise \*-------------------------------------------------------------------------*/ -void *aux_checkgroup(lua_State *L, const char *groupname, int objidx) -{ +void *aux_checkgroup(lua_State *L, const char *groupname, int objidx) { void *data = aux_getgroupudata(L, groupname, objidx); if (!data) { char msg[45]; @@ -100,8 +115,7 @@ void *aux_checkgroup(lua_State *L, const char *groupname, int objidx) /*-------------------------------------------------------------------------*\ * Set object class \*-------------------------------------------------------------------------*/ -void aux_setclass(lua_State *L, const char *classname, int objidx) -{ +void aux_setclass(lua_State *L, const char *classname, int objidx) { luaL_getmetatable(L, classname); if (objidx < 0) objidx--; lua_setmetatable(L, objidx); @@ -111,8 +125,7 @@ void aux_setclass(lua_State *L, const char *classname, int objidx) * Get a userdata pointer if object belongs to a given group. Return NULL * otherwise \*-------------------------------------------------------------------------*/ -void *aux_getgroupudata(lua_State *L, const char *groupname, int objidx) -{ +void *aux_getgroupudata(lua_State *L, const char *groupname, int objidx) { if (!lua_getmetatable(L, objidx)) return NULL; lua_pushstring(L, groupname); @@ -130,7 +143,6 @@ void *aux_getgroupudata(lua_State *L, const char *groupname, int objidx) * Get a userdata pointer if object belongs to a given class. Return NULL * otherwise \*-------------------------------------------------------------------------*/ -void *aux_getclassudata(lua_State *L, const char *classname, int objidx) -{ +void *aux_getclassudata(lua_State *L, const char *classname, int objidx) { return luaL_checkudata(L, objidx, classname); } diff --git a/src/auxiliar.h b/src/auxiliar.h index bc45182..70f4704 100644 --- a/src/auxiliar.h +++ b/src/auxiliar.h @@ -2,26 +2,28 @@ #define AUX_H /*=========================================================================*\ * Auxiliar routines for class hierarchy manipulation -* LuaSocket toolkit +* LuaSocket toolkit (but completely independent of other LuaSocket modules) * * A LuaSocket class is a name associated with Lua metatables. A LuaSocket -* group is a name associated to a class. A class can belong to any number +* group is a name associated with a class. A class can belong to any number * of groups. This module provides the functionality to: * * - create new classes * - add classes to groups -* - set the class of object +* - set the class of objects * - check if an object belongs to a given class or group +* - get the userdata associated to objects +* - print objects in a pretty way * * LuaSocket class names follow the convention {}. Modules * can define any number of classes and groups. The module tcp.c, for * example, defines the classes tcp{master}, tcp{client} and tcp{server} and -* the groups tcp{client, server} and tcp{any}. Module functions can then -* perform type-checking on it's arguments by either class or group. +* the groups tcp{client,server} and tcp{any}. Module functions can then +* perform type-checking on their arguments by either class or group. * * LuaSocket metatables define the __index metamethod as being a table. This -* table has one field for each method supported by the class. In DEBUG -* mode, it also has one field with the class name. +* table has one field for each method supported by the class, and a field +* "class" with the class name. * * The mapping from class name to the corresponding metatable and the * reverse mapping are done using lauxlib. @@ -32,14 +34,6 @@ #include #include -/* min and max macros */ -#ifndef MIN -#define MIN(x, y) ((x) < (y) ? x : y) -#endif -#ifndef MAX -#define MAX(x, y) ((x) > (y) ? x : y) -#endif - int aux_open(lua_State *L); void aux_newclass(lua_State *L, const char *classname, luaL_reg *func); void aux_add2group(lua_State *L, const char *classname, const char *group); @@ -49,5 +43,6 @@ void *aux_checkgroup(lua_State *L, const char *groupname, int objidx); void *aux_getclassudata(lua_State *L, const char *groupname, int objidx); void *aux_getgroupudata(lua_State *L, const char *groupname, int objidx); int aux_checkboolean(lua_State *L, int objidx); +int aux_tostring(lua_State *L); #endif /* AUX_H */ diff --git a/src/buffer.c b/src/buffer.c index b771047..fd885a2 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -7,7 +7,6 @@ #include #include -#include "auxiliar.h" #include "buffer.h" /*=========================================================================*\ @@ -20,6 +19,14 @@ static int buf_get(p_buf buf, const char **data, size_t *count); static void buf_skip(p_buf buf, size_t count); static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent); +/* min and max macros */ +#ifndef MIN +#define MIN(x, y) ((x) < (y) ? x : y) +#endif +#ifndef MAX +#define MAX(x, y) ((x) > (y) ? x : y) +#endif + /*=========================================================================*\ * Exported functions \*=========================================================================*/ diff --git a/src/except.c b/src/except.c new file mode 100644 index 0000000..c9eb20e --- /dev/null +++ b/src/except.c @@ -0,0 +1,52 @@ +#include +#include + +#include "except.h" + +static int global_try(lua_State *L); +static int global_protect(lua_State *L); +static int protected(lua_State *L); + +static luaL_reg func[] = { + {"try", global_try}, + {"protect", global_protect}, + {NULL, NULL} +}; + +/*-------------------------------------------------------------------------*\ +* Exception handling: try method +\*-------------------------------------------------------------------------*/ +static int global_try(lua_State *L) { + if (lua_isnil(L, 1) || (lua_isboolean(L, 1) && !lua_toboolean(L, 1))) { + lua_settop(L, 2); + lua_error(L); + return 0; + } else return lua_gettop(L); +} + +/*-------------------------------------------------------------------------*\ +* Exception handling: protect factory +\*-------------------------------------------------------------------------*/ +static int protected(lua_State *L) { + lua_pushvalue(L, lua_upvalueindex(1)); + lua_insert(L, 1); + if (lua_pcall(L, lua_gettop(L) - 1, LUA_MULTRET, 0) != 0) { + lua_pushnil(L); + lua_insert(L, 1); + return 2; + } else return lua_gettop(L); +} + +static int global_protect(lua_State *L) { + lua_insert(L, 1); + lua_pushcclosure(L, protected, 1); + return 1; +} + +/*-------------------------------------------------------------------------*\ +* Init module +\*-------------------------------------------------------------------------*/ +int except_open(lua_State *L) { + luaL_openlib(L, NULL, func, 0); + return 0; +} diff --git a/src/except.h b/src/except.h new file mode 100644 index 0000000..2c57b27 --- /dev/null +++ b/src/except.h @@ -0,0 +1,35 @@ +#ifndef EXCEPT_H +#define EXCEPT_H +/*=========================================================================*\ +* Exception control +* LuaSocket toolkit (but completely independent from other modules) +* +* This provides support for simple exceptions in Lua. During the +* development of the HTTP/FTP/SMTP support, it became aparent that +* error checking was taking a substantial amount of the coding. These +* function greatly simplify the task of checking errors. +* +* The main idea is that functions should return nil as its first return +* value when it finds an error, and return an error message (or value) +* following nil. In case of success, as long as the first value is not nil, +* the other values don't matter. +* +* The idea is to nest function calls with the "try" function. This function +* checks the first value, and calls "error" on the second if the first is +* nil. Otherwise, it returns all values it received. +* +* The protect function returns a new function that behaves exactly like the +* function it receives, but the new function doesn't throw exceptions: it +* returns nil followed by the error message instead. +* +* With these two function, it's easy to write functions that throw +* exceptions on error, but that don't interrupt the user script. +* +* RCS ID: $Id$ +\*=========================================================================*/ + +#include + +int except_open(lua_State *L); + +#endif diff --git a/src/ftp.lua b/src/ftp.lua index 79772f8..c130d1a 100644 --- a/src/ftp.lua +++ b/src/ftp.lua @@ -7,7 +7,7 @@ ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- --- Load other required modules +-- Load required modules ----------------------------------------------------------------------------- local socket = require("socket") local ltn12 = require("ltn12") @@ -17,10 +17,7 @@ local tp = require("tp") ----------------------------------------------------------------------------- -- Setup namespace ----------------------------------------------------------------------------- -local ftp = {} --- make all module globals fall into namespace -setmetatable(ftp, { __index = _G }) -setfenv(1, ftp) +_LOADED["ftp"] = getfenv(1) ----------------------------------------------------------------------------- -- Program constants @@ -32,9 +29,7 @@ PORT = 21 -- this is the default anonymous password. used when no password is -- provided in url. should be changed to your e-mail. USER = "ftp" -EMAIL = "anonymous@anonymous.org" --- block size used in transfers -BLOCKSIZE = 2048 +PASSWORD = "anonymous@anonymous.org" ----------------------------------------------------------------------------- -- Low level FTP API @@ -42,7 +37,7 @@ BLOCKSIZE = 2048 local metat = { __index = {} } function open(server, port) - local tp = socket.try(socket.tp.connect(server, port or PORT)) + local tp = socket.try(tp.connect(server, port or PORT, TIMEOUT)) return setmetatable({tp = tp}, metat) end @@ -51,14 +46,17 @@ local function port(portt) end local function pasv(pasvt) - return socket.connect(pasvt.ip, pasvt.port) + local data = socket.try(socket.tcp()) + socket.try(data:settimeout(TIMEOUT)) + socket.try(data:connect(pasvt.ip, pasvt.port)) + return data end function metat.__index:login(user, password) socket.try(self.tp:command("user", user or USER)) local code, reply = socket.try(self.tp:check{"2..", 331}) if code == 331 then - socket.try(self.tp:command("pass", password or EMAIL)) + socket.try(self.tp:command("pass", password or PASSWORD)) socket.try(self.tp:check("2..")) end return 1 @@ -104,6 +102,7 @@ function metat.__index:send(sendt) socket.try(self.pasvt or self.portt, "need port or pasv first") if self.pasvt then data = socket.try(pasv(self.pasvt)) end local argument = sendt.argument or string.gsub(sendt.path, "^/", "") + if argument == "" then argument = nil end local command = sendt.command or "stor" socket.try(self.tp:command(command, argument)) local code, reply = socket.try(self.tp:check{"2..", "1.."}) @@ -133,6 +132,7 @@ function metat.__index:receive(recvt) socket.try(self.pasvt or self.portt, "need port or pasv first") if self.pasvt then data = socket.try(pasv(self.pasvt)) end local argument = recvt.argument or string.gsub(recvt.path, "^/", "") + if argument == "" then argument = nil end local command = recvt.command or "retr" socket.try(self.tp:command(command, argument)) local code = socket.try(self.tp:check{"1..", "2.."}) @@ -182,14 +182,14 @@ end -- High level FTP API ----------------------------------------------------------------------------- local function tput(putt) - local ftp = socket.ftp.open(putt.host, putt.port) - ftp:greet() - ftp:login(putt.user, putt.password) - if putt.type then ftp:type(putt.type) end - ftp:pasv() - ftp:send(putt) - ftp:quit() - return ftp:close() + local con = ftp.open(putt.host, putt.port) + con:greet() + con:login(putt.user, putt.password) + if putt.type then con:type(putt.type) end + con:pasv() + con:send(putt) + con:quit() + return con:close() end local default = { @@ -198,15 +198,16 @@ local default = { } local function parse(u) - local putt = socket.try(url.parse(u, default)) - socket.try(putt.scheme == "ftp", "invalid scheme '" .. putt.scheme .. "'") - socket.try(putt.host, "invalid host") + local t = socket.try(url.parse(u, default)) + socket.try(t.scheme == "ftp", "invalid scheme '" .. t.scheme .. "'") + socket.try(t.host, "invalid host") local pat = "^type=(.)$" - if putt.params then - putt.type = socket.skip(2, string.find(putt.params, pat)) - socket.try(putt.type == "a" or putt.type == "i") + if t.params then + t.type = socket.skip(2, string.find(t.params, pat)) + socket.try(t.type == "a" or t.type == "i", + "invalid type '" .. t.type .. "'") end - return putt + return t end local function sput(u, body) @@ -221,17 +222,17 @@ put = socket.protect(function(putt, body) end) local function tget(gett) - local ftp = socket.ftp.open(gett.host, gett.port) - ftp:greet() - ftp:login(gett.user, gett.password) - if gett.type then ftp:type(gett.type) end - ftp:pasv() - ftp:receive(gett) - ftp:quit() - return ftp:close() + local con = ftp.open(gett.host, gett.port) + con:greet() + con:login(gett.user, gett.password) + if gett.type then con:type(gett.type) end + con:pasv() + con:receive(gett) + con:quit() + return con:close() end -local function sget(u, body) +local function sget(u) local gett = parse(u) local t = {} gett.sink = ltn12.sink.table(t) @@ -240,7 +241,7 @@ local function sget(u, body) end get = socket.protect(function(gett) - if type(gett) == "string" then return sget(gett, body) + if type(gett) == "string" then return sget(gett) else return tget(gett) end end) diff --git a/src/http.lua b/src/http.lua index ebe6b54..129b562 100644 --- a/src/http.lua +++ b/src/http.lua @@ -7,7 +7,7 @@ ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- --- Load other required modules +-- Load required modules ------------------------------------------------------------------------------- local socket = require("socket") local ltn12 = require("ltn12") @@ -17,42 +17,68 @@ local url = require("url") ----------------------------------------------------------------------------- -- Setup namespace ------------------------------------------------------------------------------- -http = {} --- make all module globals fall into namespace -setmetatable(http, { __index = _G }) -setfenv(1, http) +_LOADED["http"] = getfenv(1) ----------------------------------------------------------------------------- -- Program constants ----------------------------------------------------------------------------- -- connection timeout in seconds -TIMEOUT = 60 +TIMEOUT = 4 -- default port for document retrieval PORT = 80 -- user agent field sent in request -USERAGENT = socket.version +USERAGENT = socket.VERSION -- block size used in transfers BLOCKSIZE = 2048 ----------------------------------------------------------------------------- --- Function return value selectors +-- Low level HTTP API ----------------------------------------------------------------------------- -local function second(a, b) - return b +local metat = { __index = {} } + +function open(host, port) + local con = socket.try(socket.tcp()) + socket.try(con:settimeout(TIMEOUT)) + port = port or PORT + socket.try(con:connect(host, port)) + return setmetatable({ con = con }, metat) +end + +function metat.__index:sendrequestline(method, uri) + local reqline = string.format("%s %s HTTP/1.1\r\n", method or "GET", uri) + return socket.try(self.con:send(reqline)) end -local function third(a, b, c) - return c +function metat.__index:sendheaders(headers) + for i, v in pairs(headers) do + socket.try(self.con:send(i .. ": " .. v .. "\r\n")) + end + -- mark end of request headers + socket.try(self.con:send("\r\n")) + return 1 end -local function receive_headers(reqt, respt, tmp) - local sock = tmp.sock +function metat.__index:sendbody(headers, source, step) + source = source or ltn12.source.empty() + step = step or ltn12.pump.step + -- if we don't know the size in advance, send chunked and hope for the best + local mode + if headers["content-length"] then mode = "keep-open" + else mode = "http-chunked" end + return socket.try(ltn12.pump.all(source, socket.sink(mode, self.con), step)) +end + +function metat.__index:receivestatusline() + local status = socket.try(self.con:receive()) + local code = socket.skip(2, string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) + return socket.try(tonumber(code), status) +end + +function metat.__index:receiveheaders() local line, name, value local headers = {} - -- store results - respt.headers = headers -- get first line - line = socket.try(sock:receive()) + line = socket.try(self.con:receive()) -- headers go until a blank line is found while line ~= "" do -- get field-name and value @@ -60,189 +86,137 @@ local function receive_headers(reqt, respt, tmp) socket.try(name and value, "malformed reponse headers") name = string.lower(name) -- get next line (value might be folded) - line = socket.try(sock:receive()) + line = socket.try(self.con:receive()) -- unfold any folded values while string.find(line, "^%s") do value = value .. line - line = socket.try(sock:receive()) + line = socket.try(self.con:receive()) end -- save pair in table if headers[name] then headers[name] = headers[name] .. ", " .. value else headers[name] = value end end + return headers end -local function receive_body(reqt, respt, tmp) - local sink = reqt.sink or ltn12.sink.null() - local step = reqt.step or ltn12.pump.step - local source - local te = respt.headers["transfer-encoding"] - if te and te ~= "identity" then - -- get by chunked transfer-coding of message body - source = socket.source("http-chunked", tmp.sock) - elseif tonumber(respt.headers["content-length"]) then - -- get by content-length - local length = tonumber(respt.headers["content-length"]) - source = socket.source("by-length", tmp.sock, length) - else - -- get it all until connection closes - source = socket.source(tmp.sock) - end - socket.try(ltn12.pump.all(source, sink, step)) +function metat.__index:receivebody(headers, sink, step) + sink = sink or ltn12.sink.null() + step = step or ltn12.pump.step + local length = tonumber(headers["content-length"]) + local TE = headers["transfer-encoding"] + local mode + if TE and TE ~= "identity" then mode = "http-chunked" + elseif tonumber(headers["content-length"]) then mode = "by-length" + else mode = "default" end + return socket.try(ltn12.pump.all(socket.source(mode, self.con, length), + sink, step)) end -local function send_headers(sock, headers) - -- send request headers - for i, v in pairs(headers) do - socket.try(sock:send(i .. ": " .. v .. "\r\n")) - end - -- mark end of request headers - socket.try(sock:send("\r\n")) +function metat.__index:close() + return self.con:close() end -local function should_receive_body(reqt, respt, tmp) - if reqt.method == "HEAD" then return nil end - if respt.code == 204 or respt.code == 304 then return nil end - if respt.code >= 100 and respt.code < 200 then return nil end - return 1 -end - -local function receive_status(reqt, respt, tmp) - local status = socket.try(tmp.sock:receive()) - local code = third(string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) - -- store results - respt.code, respt.status = tonumber(code), status -end - -local function request_uri(reqt, respt, tmp) - local u = tmp.parsed - if not reqt.proxy then - local parsed = tmp.parsed - u = { - path = parsed.path, - params = parsed.params, - query = parsed.query, - fragment = parsed.fragment +----------------------------------------------------------------------------- +-- High level HTTP API +----------------------------------------------------------------------------- +local function uri(reqt) + local u = reqt + if not reqt.proxy and not PROXY then + u = { + path = reqt.path, + params = reqt.params, + query = reqt.query, + fragment = reqt.fragment } end return url.build(u) end -local function send_request(reqt, respt, tmp) - local uri = request_uri(reqt, respt, tmp) - local headers = tmp.headers - local step = reqt.step or ltn12.pump.step - -- send request line - socket.try(tmp.sock:send((reqt.method or "GET") - .. " " .. uri .. " HTTP/1.1\r\n")) - if reqt.source and not headers["content-length"] then - headers["transfer-encoding"] = "chunked" - end - send_headers(tmp.sock, headers) - -- send request message body, if any - if not reqt.source then return end - if headers["content-length"] then - socket.try(ltn12.pump.all(reqt.source, - socket.sink(tmp.sock), step)) - else - socket.try(ltn12.pump.all(reqt.source, - socket.sink("http-chunked", tmp.sock), step)) - end -end - -local function open(reqt, respt, tmp) - local proxy = reqt.proxy or PROXY - local host, port - if proxy then - local pproxy = url.parse(proxy) - socket.try(pproxy.port and pproxy.host, "invalid proxy") - host, port = pproxy.host, pproxy.port - else - host, port = tmp.parsed.host, tmp.parsed.port - end - -- store results - tmp.sock = socket.try(socket.tcp()) - socket.try(tmp.sock:settimeout(reqt.timeout or TIMEOUT)) - socket.try(tmp.sock:connect(host, port)) -end - -local function adjust_headers(reqt, respt, tmp) +local function adjustheaders(headers, host) local lower = {} -- override with user values - for i,v in (reqt.headers or lower) do + for i,v in (headers or lower) do lower[string.lower(i)] = v end lower["user-agent"] = lower["user-agent"] or USERAGENT -- these cannot be overriden - lower["host"] = tmp.parsed.host - lower["connection"] = "close" - -- store results - tmp.headers = lower + lower["host"] = host + return lower end -local function parse_url(reqt, respt, tmp) +local function adjustrequest(reqt) -- parse url with default fields local parsed = url.parse(reqt.url, { host = "", - port = PORT, + port = PORT, path ="/", - scheme = "http" + scheme = "http" }) - -- scheme has to be http - socket.try(parsed.scheme == "http", - string.format("unknown scheme '%s'", parsed.scheme)) - -- explicit authentication info overrides that given by the URL - parsed.user = reqt.user or parsed.user - parsed.password = reqt.password or parsed.password - -- store results - tmp.parsed = parsed + -- explicit info in reqt overrides that given by the URL + for i,v in reqt do parsed[i] = v end + -- compute uri if user hasn't overriden + parsed.uri = parsed.uri or uri(parsed) + -- adjust headers in request + parsed.headers = adjustheaders(parsed.headers, parsed.host) + return parsed end --- forward declaration -local request_p +local function shouldredirect(reqt, respt) + return (reqt.redirect ~= false) and + (respt.code == 301 or respt.code == 302) and + (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") + and (not reqt.nredirects or reqt.nredirects < 5) +end -local function should_authorize(reqt, respt, tmp) +local function shouldauthorize(reqt, respt) -- if there has been an authorization attempt, it must have failed if reqt.headers and reqt.headers["authorization"] then return nil end -- if last attempt didn't fail due to lack of authentication, -- or we don't have authorization information, we can't retry - return respt.code == 401 and tmp.parsed.user and tmp.parsed.password + return respt.code == 401 and reqt.user and reqt.password end -local function clone(headers) - if not headers then return nil end - local copy = {} - for i,v in pairs(headers) do - copy[i] = v - end - return copy +local function shouldreceivebody(reqt, respt) + if reqt.method == "HEAD" then return nil end + local code = respt.code + if code == 204 or code == 304 then return nil end + if code >= 100 and code < 200 then return nil end + return 1 end -local function authorize(reqt, respt, tmp) - local headers = clone(reqt.headers) or {} - headers["authorization"] = "Basic " .. - (mime.b64(tmp.parsed.user .. ":" .. tmp.parsed.password)) - local autht = { - method = reqt.method, - url = reqt.url, - source = reqt.source, - sink = reqt.sink, - headers = headers, - timeout = reqt.timeout, - proxy = reqt.proxy, - } - request_p(autht, respt, tmp) +local requestp, authorizep, redirectp + +function requestp(reqt) + local reqt = adjustrequest(reqt) + local respt = {} + local con = open(reqt.host, reqt.port) + con:sendrequestline(reqt.method, reqt.uri) + con:sendheaders(reqt.headers) + con:sendbody(reqt.headers, reqt.source, reqt.step) + respt.code, respt.status = con:receivestatusline() + respt.headers = con:receiveheaders() + if shouldredirect(reqt, respt) then + con:close() + return redirectp(reqt, respt) + elseif shouldauthorize(reqt, respt) then + con:close() + return authorizep(reqt, respt) + elseif shouldreceivebody(reqt, respt) then + con:receivebody(respt.headers, reqt.sink, reqt.step) + end + con:close() + return respt end -local function should_redirect(reqt, respt, tmp) - return (reqt.redirect ~= false) and - (respt.code == 301 or respt.code == 302) and - (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") - and (not tmp.nredirects or tmp.nredirects < 5) +function authorizep(reqt, respt) + local auth = "Basic " .. (mime.b64(reqt.user .. ":" .. reqt.password)) + reqt.headers["authorization"] = auth + return requestp(reqt) end -local function redirect(reqt, respt, tmp) - tmp.nredirects = (tmp.nredirects or 0) + 1 +function redirectp(reqt, respt) + -- we create a new table to get rid of anything we don't + -- absolutely need, including authentication info local redirt = { method = reqt.method, -- the RFC says the redirect URL has to be absolute, but some @@ -251,69 +225,38 @@ local function redirect(reqt, respt, tmp) source = reqt.source, sink = reqt.sink, headers = reqt.headers, - timeout = reqt.timeout, - proxy = reqt.proxy + proxy = reqt.proxy, + nredirects = (reqt.nredirects or 0) + 1 } - request_p(redirt, respt, tmp) + respt = requestp(redirt) -- we pass the location header as a clue we redirected if respt.headers then respt.headers.location = redirt.url end -end - -local function skip_continue(reqt, respt, tmp) - if respt.code == 100 then - receive_status(reqt, respt, tmp) - end -end - --- execute a request of through an exception -function request_p(reqt, respt, tmp) - parse_url(reqt, respt, tmp) - adjust_headers(reqt, respt, tmp) - open(reqt, respt, tmp) - send_request(reqt, respt, tmp) - receive_status(reqt, respt, tmp) - skip_continue(reqt, respt, tmp) - receive_headers(reqt, respt, tmp) - if should_redirect(reqt, respt, tmp) then - tmp.sock:close() - redirect(reqt, respt, tmp) - elseif should_authorize(reqt, respt, tmp) then - tmp.sock:close() - authorize(reqt, respt, tmp) - elseif should_receive_body(reqt, respt, tmp) then - receive_body(reqt, respt, tmp) - end -end - -function request(reqt) - local respt, tmp = {}, {} - local s, e = pcall(request_p, reqt, respt, tmp) - if not s then respt.error = e end - if tmp.sock then tmp.sock:close() end return respt end -function get(u) +request = socket.protect(requestp) + +get = socket.protect(function(u) local t = {} - respt = request { - url = u, - sink = ltn12.sink.table(t) + local respt = requestp { + url = u, + sink = ltn12.sink.table(t) } - return (table.getn(t) > 0 or nil) and table.concat(t), respt.headers, - respt.code, respt.error -end + return (table.getn(t) > 0 or nil) and table.concat(t), respt.headers, + respt.code +end) -function post(u, body) +post = socket.protect(function(u, body) local t = {} - respt = request { - url = u, - method = "POST", + local respt = requestp { + url = u, + method = "POST", source = ltn12.source.string(body), sink = ltn12.sink.table(t), - headers = { ["content-length"] = string.len(body) } + headers = { ["content-length"] = string.len(body) } } - return (table.getn(t) > 0 or nil) and table.concat(t), - respt.headers, respt.code, respt.error -end + return (table.getn(t) > 0 or nil) and table.concat(t), + respt.headers, respt.code +end) return http diff --git a/src/inet.c b/src/inet.c index 3a57441..62c67f1 100644 --- a/src/inet.c +++ b/src/inet.c @@ -10,7 +10,6 @@ #include #include -#include "luasocket.h" #include "inet.h" /*=========================================================================*\ diff --git a/src/ltn12.lua b/src/ltn12.lua index 41855f0..6228247 100644 --- a/src/ltn12.lua +++ b/src/ltn12.lua @@ -8,9 +8,8 @@ ----------------------------------------------------------------------------- -- Setup namespace ----------------------------------------------------------------------------- -local ltn12 = {} -setmetatable(ltn12, { __index = _G }) -setfenv(1, ltn12) +_LOADED["ltn12"] = getfenv(1) + filter = {} source = {} sink = {} @@ -19,10 +18,6 @@ pump = {} -- 2048 seems to be better in windows... BLOCKSIZE = 2048 -local function shift(a, b, c) - return b, c -end - ----------------------------------------------------------------------------- -- Filter stuff ----------------------------------------------------------------------------- @@ -53,7 +48,9 @@ local function chain2(f1, f2) end end) return function(chunk) - return shift(coroutine.resume(co, chunk)) + local ret, a, b = coroutine.resume(co, chunk) + if ret then return a, b + else return nil, a end end end @@ -149,7 +146,9 @@ function source.chain(src, f) end end) return function() - return shift(coroutine.resume(co)) + local ret, a, b = coroutine.resume(co) + if ret then return a, b + else return nil, a end end end @@ -166,7 +165,9 @@ function source.cat(...) end end) return function() - return shift(coroutine.resume(co)) + local ret, a, b = coroutine.resume(co) + if ret then return a, b + else return nil, a end end end diff --git a/src/luasocket.c b/src/luasocket.c index ca3a52c..2b0a1fa 100644 --- a/src/luasocket.c +++ b/src/luasocket.c @@ -26,7 +26,7 @@ #include "luasocket.h" #include "auxiliar.h" -#include "base.h" +#include "except.h" #include "timeout.h" #include "buffer.h" #include "inet.h" @@ -35,11 +35,18 @@ #include "select.h" /*-------------------------------------------------------------------------*\ -* Modules +* Internal function prototypes +\*-------------------------------------------------------------------------*/ +static int global_skip(lua_State *L); +static int global_unload(lua_State *L); +static int base_open(lua_State *L); + +/*-------------------------------------------------------------------------*\ +* Modules and functions \*-------------------------------------------------------------------------*/ static const luaL_reg mod[] = { {"auxiliar", aux_open}, - {"base", base_open}, + {"except", except_open}, {"timeout", tm_open}, {"buffer", buf_open}, {"inet", inet_open}, @@ -49,11 +56,69 @@ static const luaL_reg mod[] = { {NULL, NULL} }; +static luaL_reg func[] = { + {"skip", global_skip}, + {"__unload", global_unload}, + {NULL, NULL} +}; + +/*-------------------------------------------------------------------------*\ +* Skip a few arguments +\*-------------------------------------------------------------------------*/ +static int global_skip(lua_State *L) { + int amount = luaL_checkint(L, 1); + int ret = lua_gettop(L) - amount - 1; + return ret >= 0 ? ret : 0; +} + +/*-------------------------------------------------------------------------*\ +* Unloads the library +\*-------------------------------------------------------------------------*/ +static int global_unload(lua_State *L) { + sock_close(); + return 0; +} + +/*-------------------------------------------------------------------------*\ +* Setup basic stuff. +\*-------------------------------------------------------------------------*/ +static int base_open(lua_State *L) { + if (sock_open()) { + /* whoever is loading the library replaced the global environment + * with the namespace table */ + lua_pushvalue(L, LUA_GLOBALSINDEX); + /* make sure library is still "requirable" if initialized staticaly */ + lua_pushstring(L, "_LOADEDLIB"); + lua_gettable(L, -2); + lua_pushstring(L, LUASOCKET_LIBNAME); + lua_pushcfunction(L, (lua_CFunction) luaopen_socket); + lua_settable(L, -3); + lua_pop(L, 1); +#ifdef LUASOCKET_DEBUG + lua_pushstring(L, "DEBUG"); + lua_pushboolean(L, 1); + lua_rawset(L, -3); +#endif + /* make version string available to scripts */ + lua_pushstring(L, "VERSION"); + lua_pushstring(L, LUASOCKET_VERSION); + lua_rawset(L, -3); + /* export other functions */ + luaL_openlib(L, NULL, func, 0); + return 1; + } else { + lua_pushstring(L, "unable to initialize library"); + lua_error(L); + return 0; + } +} + /*-------------------------------------------------------------------------*\ * Initializes all library modules. \*-------------------------------------------------------------------------*/ LUASOCKET_API int luaopen_socket(lua_State *L) { int i; + base_open(L); for (i = 0; mod[i].name; i++) mod[i].func(L); return 1; } diff --git a/src/luasocket.h b/src/luasocket.h index 716b7ff..6d30605 100644 --- a/src/luasocket.h +++ b/src/luasocket.h @@ -25,6 +25,7 @@ /*-------------------------------------------------------------------------*\ * Initializes the library. \*-------------------------------------------------------------------------*/ +#define LUASOCKET_LIBNAME "socket" LUASOCKET_API int luaopen_socket(lua_State *L); #endif /* LUASOCKET_H */ diff --git a/src/mime.c b/src/mime.c index f42528c..5750714 100644 --- a/src/mime.c +++ b/src/mime.c @@ -76,7 +76,17 @@ static UC b64unbase[256]; \*-------------------------------------------------------------------------*/ MIME_API int luaopen_mime(lua_State *L) { - lua_newtable(L); + /* whoever is loading the library replaced the global environment + * with the namespace table */ + lua_pushvalue(L, LUA_GLOBALSINDEX); + /* make sure library is still "requirable" if initialized staticaly */ + lua_pushstring(L, "_LOADEDLIB"); + lua_gettable(L, -2); + lua_pushstring(L, MIME_LIBNAME); + lua_pushcfunction(L, (lua_CFunction) luaopen_mime); + lua_settable(L, -3); + lua_pop(L, 1); + /* export functions */ luaL_openlib(L, NULL, func, 0); /* initialize lookup tables */ qpsetup(qpclass, qpunbase); diff --git a/src/mime.h b/src/mime.h index 6febedf..b82d61a 100644 --- a/src/mime.h +++ b/src/mime.h @@ -19,6 +19,7 @@ #define MIME_API extern #endif +#define MIME_LIBNAME "mime" MIME_API int luaopen_mime(lua_State *L); #endif /* MIME_H */ diff --git a/src/mime.lua b/src/mime.lua index ecf310d..000404f 100644 --- a/src/mime.lua +++ b/src/mime.lua @@ -5,24 +5,16 @@ -- RCS ID: $Id$ ----------------------------------------------------------------------------- ------------------------------------------------------------------------------ --- Load MIME from dynamic library --- Comment these lines if you are loading static ------------------------------------------------------------------------------ -local open = assert(loadlib("mime", "luaopen_mime")) -local mime = assert(open()) - ----------------------------------------------------------------------------- -- Load other required modules ----------------------------------------------------------------------------- +local mime = requirelib("mime", "luaopen_mime", getfenv(1)) local ltn12 = require("ltn12") ----------------------------------------------------------------------------- -- Setup namespace ----------------------------------------------------------------------------- --- make all module globals fall into mime namespace -setmetatable(mime, { __index = _G }) -setfenv(1, mime) +_LOADED["mime"] = mime -- encode, decode and wrap algorithm tables encodet = {} @@ -48,7 +40,7 @@ end encodet['quoted-printable'] = function(mode) return ltn12.filter.cycle(qp, "", - (mode == "binary") and "=0D=0A" or "\13\10") + (mode == "binary") and "=0D=0A" or "\r\n") end -- define the decoding filters diff --git a/src/select.c b/src/select.c index 1ebd82c..13f9d6e 100644 --- a/src/select.c +++ b/src/select.c @@ -9,26 +9,21 @@ #include #include -#include "luasocket.h" #include "socket.h" -#include "auxiliar.h" #include "select.h" /*=========================================================================*\ * Internal function prototypes. \*=========================================================================*/ -static int meth_set(lua_State *L); -static int meth_isset(lua_State *L); -static int c_select(lua_State *L); +static int getfd(lua_State *L); +static int dirty(lua_State *L); +static int collect_fd(lua_State *L, int tab, int max_fd, int itab, fd_set *set); +static int check_dirty(lua_State *L, int tab, int dtab, fd_set *set); +static void return_fd(lua_State *L, fd_set *set, int max_fd, + int itab, int tab, int start); +static void make_assoc(lua_State *L, int tab); static int global_select(lua_State *L); -/* fd_set object methods */ -static luaL_reg set[] = { - {"set", meth_set}, - {"isset", meth_isset}, - {NULL, NULL} -}; - /* functions in library namespace */ static luaL_reg func[] = { {"select", global_select}, @@ -36,22 +31,13 @@ static luaL_reg func[] = { }; /*=========================================================================*\ -* Internal function prototypes. +* Exported functions \*=========================================================================*/ /*-------------------------------------------------------------------------*\ * Initializes module \*-------------------------------------------------------------------------*/ -int select_open(lua_State *L) -{ - /* get select auxiliar lua function from lua code and register - * pass it as an upvalue to global_select */ -#ifdef LUASOCKET_COMPILED -#include "select.lch" -#else - lua_dofile(L, "select.lua"); -#endif - luaL_openlib(L, NULL, func, 1); - aux_newclass(L, "select{fd_set}", set); +int select_open(lua_State *L) { + luaL_openlib(L, NULL, func, 0); return 0; } @@ -61,64 +47,149 @@ int select_open(lua_State *L) /*-------------------------------------------------------------------------*\ * Waits for a set of sockets until a condition is met or timeout. \*-------------------------------------------------------------------------*/ -static int global_select(lua_State *L) -{ - fd_set *read_fd_set, *write_fd_set; - /* make sure we have enough arguments (nil is the default) */ +static int global_select(lua_State *L) { + int timeout, rtab, wtab, itab, max_fd, ret, ndirty; + fd_set rset, wset; + FD_ZERO(&rset); FD_ZERO(&wset); lua_settop(L, 3); - /* check timeout */ - if (!lua_isnil(L, 3) && !lua_isnumber(L, 3)) - luaL_argerror(L, 3, "number or nil expected"); - /* select auxiliar lua function to be called comes first */ - lua_pushvalue(L, lua_upvalueindex(1)); - lua_insert(L, 1); - /* pass fd_set objects */ - read_fd_set = (fd_set *) lua_newuserdata(L, sizeof(fd_set)); - FD_ZERO(read_fd_set); - aux_setclass(L, "select{fd_set}", -1); - write_fd_set = (fd_set *) lua_newuserdata(L, sizeof(fd_set)); - FD_ZERO(write_fd_set); - aux_setclass(L, "select{fd_set}", -1); - /* pass select auxiliar C function */ - lua_pushcfunction(L, c_select); - /* call select auxiliar lua function */ - lua_call(L, 6, 3); - return 3; + timeout = lua_isnil(L, 3) ? -1 : (int)(luaL_checknumber(L, 3) * 1000); + lua_newtable(L); itab = lua_gettop(L); + lua_newtable(L); rtab = lua_gettop(L); + lua_newtable(L); wtab = lua_gettop(L); + max_fd = collect_fd(L, 1, -1, itab, &rset); + ndirty = check_dirty(L, 1, rtab, &rset); + timeout = ndirty > 0? 0: timeout; + max_fd = collect_fd(L, 2, max_fd, itab, &wset); + ret = sock_select(max_fd+1, &rset, &wset, NULL, timeout); + if (ret > 0 || (ret == 0 && ndirty > 0)) { + return_fd(L, &rset, max_fd+1, itab, rtab, ndirty); + return_fd(L, &wset, max_fd+1, itab, wtab, 0); + make_assoc(L, rtab); + make_assoc(L, wtab); + return 2; + } else if (ret == 0) { + lua_pushstring(L, "timeout"); + return 3; + } else { + lua_pushnil(L); + lua_pushnil(L); + lua_pushstring(L, "error"); + return 3; + } } /*=========================================================================*\ -* Lua methods +* Internal functions \*=========================================================================*/ -static int meth_set(lua_State *L) -{ - fd_set *set = (fd_set *) aux_checkclass(L, "select{fd_set}", 1); - t_sock fd = (t_sock) lua_tonumber(L, 2); - if (fd >= 0) FD_SET(fd, set); - return 0; +static int getfd(lua_State *L) { + int fd = -1; + lua_pushstring(L, "getfd"); + lua_gettable(L, -2); + if (!lua_isnil(L, -1)) { + lua_pushvalue(L, -2); + lua_call(L, 1, 1); + if (lua_isnumber(L, -1)) + fd = lua_tonumber(L, -1); + } + lua_pop(L, 1); + return fd; } -static int meth_isset(lua_State *L) -{ - fd_set *set = (fd_set *) aux_checkclass(L, "select{fd_set}", 1); - t_sock fd = (t_sock) lua_tonumber(L, 2); - if (fd >= 0 && FD_ISSET(fd, set)) lua_pushnumber(L, 1); - else lua_pushnil(L); - return 1; +static int dirty(lua_State *L) { + int is = 0; + lua_pushstring(L, "dirty"); + lua_gettable(L, -2); + if (!lua_isnil(L, -1)) { + lua_pushvalue(L, -2); + lua_call(L, 1, 1); + is = lua_toboolean(L, -1); + } + lua_pop(L, 1); + return is; } -/*=========================================================================*\ -* Internal functions -\*=========================================================================*/ -static int c_select(lua_State *L) -{ - int max_fd = (int) lua_tonumber(L, 1); - fd_set *read_fd_set = (fd_set *) aux_checkclass(L, "select{fd_set}", 2); - fd_set *write_fd_set = (fd_set *) aux_checkclass(L, "select{fd_set}", 3); - int timeout = lua_isnil(L, 4) ? -1 : (int)(lua_tonumber(L, 4) * 1000); - struct timeval tv; - tv.tv_sec = timeout / 1000; - tv.tv_usec = (timeout % 1000) * 1000; - lua_pushnumber(L, select(max_fd, read_fd_set, write_fd_set, NULL, - timeout < 0 ? NULL : &tv)); - return 1; +static int collect_fd(lua_State *L, int tab, int max_fd, + int itab, fd_set *set) { + int i = 1; + if (lua_isnil(L, tab)) + return max_fd; + while (1) { + int fd; + lua_pushnumber(L, i); + lua_gettable(L, tab); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + break; + } + fd = getfd(L); + if (fd > 0) { + FD_SET(fd, set); + if (max_fd < fd) max_fd = fd; + lua_pushnumber(L, fd); + lua_pushvalue(L, -2); + lua_settable(L, itab); + } + lua_pop(L, 1); + i = i + 1; + } + return max_fd; +} + +static int check_dirty(lua_State *L, int tab, int dtab, fd_set *set) { + int ndirty = 0, i = 1; + if (lua_isnil(L, tab)) + return 0; + while (1) { + int fd; + lua_pushnumber(L, i); + lua_gettable(L, tab); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + break; + } + fd = getfd(L); + if (fd > 0 && dirty(L)) { + lua_pushnumber(L, ++ndirty); + lua_pushvalue(L, -2); + lua_settable(L, dtab); + FD_CLR(fd, set); + } + lua_pop(L, 1); + i = i + 1; + } + return ndirty; +} + +static void return_fd(lua_State *L, fd_set *set, int max_fd, + int itab, int tab, int start) { + int fd; + for (fd = 0; fd < max_fd; fd++) { + if (FD_ISSET(fd, set)) { + lua_pushnumber(L, ++start); + lua_pushnumber(L, fd); + lua_gettable(L, itab); + lua_settable(L, tab); + } + } +} + +static void make_assoc(lua_State *L, int tab) { + int i = 1, atab; + lua_newtable(L); atab = lua_gettop(L); + while (1) { + lua_pushnumber(L, i); + lua_gettable(L, tab); + if (!lua_isnil(L, -1)) { + lua_pushnumber(L, i); + lua_pushvalue(L, -2); + lua_settable(L, atab); + lua_pushnumber(L, i); + lua_settable(L, atab); + } else { + lua_pop(L, 1); + break; + } + i = i+1; + } } + diff --git a/src/smtp.lua b/src/smtp.lua index 7ae99a5..dc80c35 100644 --- a/src/smtp.lua +++ b/src/smtp.lua @@ -7,15 +7,9 @@ ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- --- Load SMTP from dynamic library --- Comment these lines if you are loading static ------------------------------------------------------------------------------ -local open = assert(loadlib("smtp", "luaopen_smtp")) -local smtp = assert(open()) - ------------------------------------------------------------------------------ --- Load other required modules +-- Load required modules ----------------------------------------------------------------------------- +local smtp = requirelib("smtp") local socket = require("socket") local ltn12 = require("ltn12") local tp = require("tp") @@ -23,10 +17,10 @@ local tp = require("tp") ----------------------------------------------------------------------------- -- Setup namespace ----------------------------------------------------------------------------- --- make all module globals fall into smtp namespace -setmetatable(smtp, { __index = _G }) -setfenv(1, smtp) +_LOADED["smtp"] = smtp +-- timeout for connection +TIMEOUT = 60 -- default server used to send e-mails SERVER = "localhost" -- default port @@ -94,9 +88,7 @@ function metat.__index:send(mailt) end function open(server, port) - print(server or SERVER, port or PORT) - local tp, error = tp.connect(server or SERVER, port or PORT) - if not tp then return nil, error end + local tp = socket.try(tp.connect(server or SERVER, port or PORT, TIMEOUT)) return setmetatable({tp = tp}, metat) end @@ -121,7 +113,10 @@ local function send_multipart(mesgt) coroutine.yield('content-type: multipart/mixed; boundary="' .. bd .. '"\r\n\r\n') -- send preamble - if mesgt.body.preamble then coroutine.yield(mesgt.body.preamble) end + if mesgt.body.preamble then + coroutine.yield(mesgt.body.preamble) + coroutine.yield("\r\n") + end -- send each part separated by a boundary for i, m in ipairs(mesgt.body) do coroutine.yield("\r\n--" .. bd .. "\r\n") @@ -130,7 +125,10 @@ local function send_multipart(mesgt) -- send last boundary coroutine.yield("\r\n--" .. bd .. "--\r\n\r\n") -- send epilogue - if mesgt.body.epilogue then coroutine.yield(mesgt.body.epilogue) end + if mesgt.body.epilogue then + coroutine.yield(mesgt.body.epilogue) + coroutine.yield("\r\n") + end end -- yield message body from a source @@ -183,12 +181,12 @@ end -- set defaul headers local function adjust_headers(mesgt) local lower = {} - for i,v in (mesgt or lower) do + for i,v in (mesgt.headers or lower) do lower[string.lower(i)] = v end lower["date"] = lower["date"] or os.date("!%a, %d %b %Y %H:%M:%S ") .. (mesgt.zone or ZONE) - lower["x-mailer"] = lower["x-mailer"] or socket.version + lower["x-mailer"] = lower["x-mailer"] or socket.VERSION -- this can't be overriden lower["mime-version"] = "1.0" mesgt.headers = lower @@ -198,18 +196,22 @@ function message(mesgt) adjust_headers(mesgt) -- create and return message source local co = coroutine.create(function() send_message(mesgt) end) - return function() return socket.skip(1, coroutine.resume(co)) end + return function() + local ret, a, b = coroutine.resume(co) + if ret then return a, b + else return nil, a end + end end --------------------------------------------------------------------------- -- High level SMTP API ----------------------------------------------------------------------------- send = socket.protect(function(mailt) - local smtp = socket.try(open(mailt.server, mailt.port)) - smtp:greet(mailt.domain) - smtp:send(mailt) - smtp:quit() - return smtp:close() + local con = open(mailt.server, mailt.port) + con:greet(mailt.domain) + con:send(mailt) + con:quit() + return con:close() end) return smtp diff --git a/src/socket.lua b/src/socket.lua index 418cd1b..9aa6437 100644 --- a/src/socket.lua +++ b/src/socket.lua @@ -7,8 +7,8 @@ ----------------------------------------------------------------------------- -- Load LuaSocket from dynamic library ----------------------------------------------------------------------------- -local open = assert(loadlib("luasocket", "luaopen_socket")) -local socket = assert(open()) +local socket = requirelib("luasocket", "luaopen_socket", getfenv(1)) +_LOADED["socket"] = socket ----------------------------------------------------------------------------- -- Auxiliar functions @@ -116,18 +116,21 @@ socket.sourcet["by-length"] = function(sock, length) end socket.sourcet["until-closed"] = function(sock) + local done return setmetatable({ getfd = function() return sock:getfd() end, dirty = function() return sock:dirty() end }, { - __call = ltn12.source.simplify(function() + __call = function() + if done then return nil end local chunk, err, partial = sock:receive(socket.BLOCKSIZE) if not err then return chunk elseif err == "closed" then sock:close() - return partial, ltn12.source.empty() + done = 1 + return partial else return nil, err end - end) + end }) end diff --git a/src/tcp.c b/src/tcp.c index 90cfcde..845e0a3 100644 --- a/src/tcp.c +++ b/src/tcp.c @@ -9,13 +9,10 @@ #include #include -#include "luasocket.h" - #include "auxiliar.h" #include "socket.h" #include "inet.h" #include "options.h" -#include "base.h" #include "tcp.h" /*=========================================================================*\ @@ -41,7 +38,7 @@ static int meth_dirty(lua_State *L); /* tcp object methods */ static luaL_reg tcp[] = { {"__gc", meth_close}, - {"__tostring", base_meth_tostring}, + {"__tostring", aux_tostring}, {"accept", meth_accept}, {"bind", meth_bind}, {"close", meth_close}, diff --git a/src/timeout.c b/src/timeout.c index bd6c3b4..4f9a315 100644 --- a/src/timeout.c +++ b/src/timeout.c @@ -26,6 +26,14 @@ #endif #endif +/* min and max macros */ +#ifndef MIN +#define MIN(x, y) ((x) < (y) ? x : y) +#endif +#ifndef MAX +#define MAX(x, y) ((x) > (y) ? x : y) +#endif + /*=========================================================================*\ * Internal function prototypes \*=========================================================================*/ diff --git a/src/tp.lua b/src/tp.lua index 3e9dba6..56dd8bc 100644 --- a/src/tp.lua +++ b/src/tp.lua @@ -2,24 +2,28 @@ -- Unified SMTP/FTP subsystem -- LuaSocket toolkit. -- Author: Diego Nehab --- Conforming to: RFC 2616, LTN7 -- RCS ID: $Id$ ----------------------------------------------------------------------------- ----------------------------------------------------------------------------- --- Load other required modules +-- Load required modules ----------------------------------------------------------------------------- local socket = require("socket") +local ltn12 = require("ltn12") ----------------------------------------------------------------------------- -- Setup namespace ----------------------------------------------------------------------------- -tp = {} -setmetatable(tp, { __index = _G }) -setfenv(1, tp) +_LOADED["tp"] = getfenv(1) +----------------------------------------------------------------------------- +-- Program constants +----------------------------------------------------------------------------- TIMEOUT = 60 +----------------------------------------------------------------------------- +-- Implementation +----------------------------------------------------------------------------- -- gets server reply (works for SMTP and FTP) local function get_reply(control) local code, current, sep @@ -37,7 +41,6 @@ local function get_reply(control) -- reply ends with same code until code == current and sep == " " end -print(reply) return code, reply end @@ -46,6 +49,7 @@ local metat = { __index = {} } function metat.__index:check(ok) local code, reply = get_reply(self.control) +print(reply) if not code then return nil, reply end if type(ok) ~= "function" then if type(ok) == "table" then @@ -103,11 +107,11 @@ function metat.__index:close() end -- connect with server and return control object -function connect(host, port) - local control, err = socket.connect(host, port) - if not control then return nil, err end - control:settimeout(TIMEOUT) +connect = socket.protect(function(host, port, timeout) + local control = socket.try(socket.tcp()) + socket.try(control:settimeout(timeout or TIMEOUT)) + socket.try(control:connect(host, port)) return setmetatable({control = control}, metat) -end +end) return tp diff --git a/src/udp.c b/src/udp.c index 4770a2e..51d6402 100644 --- a/src/udp.c +++ b/src/udp.c @@ -9,13 +9,10 @@ #include #include -#include "luasocket.h" - #include "auxiliar.h" #include "socket.h" #include "inet.h" #include "options.h" -#include "base.h" #include "udp.h" /*=========================================================================*\ @@ -51,7 +48,7 @@ static luaL_reg udp[] = { {"close", meth_close}, {"setoption", meth_setoption}, {"__gc", meth_close}, - {"__tostring", base_meth_tostring}, + {"__tostring", aux_tostring}, {"getfd", meth_getfd}, {"setfd", meth_setfd}, {"dirty", meth_dirty}, diff --git a/src/url.lua b/src/url.lua index 2441268..aac2a47 100644 --- a/src/url.lua +++ b/src/url.lua @@ -9,9 +9,7 @@ ----------------------------------------------------------------------------- -- Setup namespace ----------------------------------------------------------------------------- -local url = {} -setmetatable(url, { __index = _G }) -setfenv(1, url) +_LOADED["url"] = getfenv(1) ----------------------------------------------------------------------------- -- Encodes a string into its escaped hexadecimal representation -- cgit v1.2.3-55-g6feb